diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index b043445bf94..6888b8b043c 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -179,7 +179,7 @@ def split_trajectories( torch.ones( out_split.shape, dtype=torch.bool, - device=out_split.get(("next", "done")).device, + device=out_split.device, ), ) if len(out_splits) > 1: