
[BUG] info['_weight'] device for Importance Sampling in PER #2518

@EladSharony

Description


Describe the bug

The device of info['_weight'] doesn't match the storage device.

To Reproduce

# From the documentation
import torch
from tensordict import TensorDict
from torchrl.data.replay_buffers import LazyTensorStorage, PrioritizedSampler, ReplayBuffer

rb = ReplayBuffer(
    storage=LazyTensorStorage(10, device=torch.device('cuda')),
    sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
)
priority = torch.tensor([0, 1000])
data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
rb.add(data_0)
rb.add(data_1)
rb.update_priority(torch.tensor([0, 1]), priority=priority)
sample, info = rb.sample(10, return_info=True)

# Check devices
print(f"sample device: {sample.device}\n"
      f"info['_weight'] device: {info['_weight'].device}")
# Output:
# sample device: cuda:0
# info['_weight'] device: cpu

Expected behavior

Both should be on the same device, i.e. the one passed to the storage via storage(..., device=...), since these weights are later used to compute the loss.
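
Until this is fixed, a minimal user-side workaround is to move the weights onto the sample's device before using them in the loss. The sketch below assumes a standard PER-weighted loss; td_error is a hypothetical per-sample loss tensor and is not defined in the reproduction above:

sample, info = rb.sample(10, return_info=True)
weights = info['_weight'].to(sample.device)  # align IS weights with the sample's device
loss = (weights * td_error).mean()  # td_error: placeholder per-sample loss on the storage device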

System info

import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
# Output:
# 2024.10.23 1.26.4 3.10.15 (main, Oct  3 2024, 07:27:34) [GCC 11.2.0] linux

Reason and Possible fixes

Specify the device argument when building the weight tensor in samplers.py (L508):

weight = torch.as_tensor(self._sum_tree[index], device=storage.device)
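
With that one-line change, the reproduction above should yield matching devices. A quick sanity check (a sketch, assuming a CUDA-enabled build):

sample, info = rb.sample(10, return_info=True)
assert info['_weight'].device == sample.device  # both cuda:0 after the fix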

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
