Open
Labels: bug
Description
Describe the bug
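When sampling with `return_info=True` from a `ReplayBuffer` that uses a `PrioritizedSampler`, the device of `info['_weight']` does not match the storage device. The weights stay on the CPU while the sampled data is on CUDA.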
To Reproduce
```python
# From documentation
import torch
from tensordict import TensorDict
from torchrl.data.replay_buffers import LazyTensorStorage, PrioritizedSampler, ReplayBuffer

# Storage lives on CUDA; the prioritized sampler produces the importance-sampling weights
rb = ReplayBuffer(
    storage=LazyTensorStorage(10, device=torch.device("cuda")),
    sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
)
priority = torch.tensor([0, 1000])
data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
rb.add(data_0)
rb.add(data_1)
rb.update_priority(torch.tensor([0, 1]), priority=priority)
sample, info = rb.sample(10, return_info=True)

# Check devices
print(f"sample device: {sample.device}\n"
      f"info['_weight'] device: {info['_weight'].device}")
```
```
sample device: cuda:0
info['_weight'] device: cpu
```
Expected behavior
Both should be on the device passed to the storage (`LazyTensorStorage(..., device=...)`), since these weights are later used to compute the loss.
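Until the weights are created on the storage device, one user-side workaround is to move them onto the data's device before weighting the loss. The sketch below assumes a squared TD-error loss; `weighted_td_loss` and `td_error` are hypothetical names, not torchrl API.

```python
import torch


def weighted_td_loss(td_error: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: importance-weighted squared TD error."""
    # Align the sampler's weights with the data's device before combining them;
    # an elementwise op between a CUDA tensor and a non-scalar CPU tensor
    # raises a RuntimeError.
    weight = weight.to(td_error.device)
    return (weight * td_error.pow(2)).mean()


# Usage with the reproduction above (td_error is hypothetical, computed on sample.device):
# loss = weighted_td_loss(td_error, info["_weight"])
```

The `.to()` copy of a batch-sized weight vector is cheap. It only hides the mismatch in user code, though; it does not fix the sampler.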
System info
```python
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
```
```
2024.10.23 1.26.4 3.10.15 (main, Oct 3 2024, 07:27:34) [GCC 11.2.0] linux
```
Reason and Possible fixes
Specify the `device` argument when building the weights in `samplers.py` (L508):
```python
weight = torch.as_tensor(self._sum_tree[index], device=storage.device)
```
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)
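The fix proposed under Reason and Possible fixes amounts to creating the weight tensor directly on the storage device. The sketch below shows that idea in isolation. It uses the standard prioritized-replay importance-sampling weight normalized by its maximum, which reduces to w_i = (p_i / p_min)^(-beta). This is a stand-in under stated assumptions, not torchrl's implementation: `importance_weights` is a hypothetical helper, and a plain tensor of already alpha-scaled, strictly positive priorities replaces the sampler's sum tree.

```python
from typing import Union

import torch


def importance_weights(
    priorities: torch.Tensor,
    index: torch.Tensor,
    beta: float,
    device: Union[torch.device, str],
) -> torch.Tensor:
    """Hypothetical helper: PER importance-sampling weights created on `device`.

    `priorities` stands in for the sampler's sum-tree leaves (p_i = priority**alpha)
    and is assumed to be strictly positive.
    """
    # Build every tensor directly on the target (storage) device, mirroring the
    # proposed `torch.as_tensor(..., device=storage.device)` fix.
    p = torch.as_tensor(priorities, dtype=torch.float32, device=device)
    idx = torch.as_tensor(index, device=device)
    # Normalizing w_i = (N * P(i)) ** -beta by its maximum reduces to
    # (p_i / p_min) ** -beta, so the largest possible weight is 1.
    return torch.pow(p[idx] / p.min(), -beta)
```

On CPU, priorities `[1, 2, 4]` sampled at indices `[0, 2]` with `beta=1.0` give weights `[1.0, 0.25]`. Passing `device=storage.device` on a CUDA setup would return the weights already on the same device as the sampled data.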