Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2279,7 +2279,7 @@ def _cast_reduction(
except Exception:
raise RuntimeError(
f"{reduction_name} requires this object to be cast to a regular TensorDict. "
f"If you need {type(self)} to support {reduction_name}, help us by filing an issue"
f"If you need {type(self).__name__} to support {reduction_name}, help us by filing an issue"
f" on github!"
)
return td._cast_reduction(
Expand Down Expand Up @@ -3253,7 +3253,7 @@ def _has_names(self):

def _erase_names(self):
raise RuntimeError(
f"Cannot erase names of a {type(self)}. "
f"Cannot erase names of a {type(self).__name__}. "
f"Erase source TensorDict's names instead."
)

Expand Down Expand Up @@ -3379,7 +3379,7 @@ def _stack_onto_(
dim: int,
) -> T:
raise RuntimeError(
f"stacking tensordicts is not allowed for type {type(self)}"
f"stacking tensordicts is not allowed for type {type(self).__name__}"
f"consider calling 'to_tensordict()` first"
)

Expand Down Expand Up @@ -3480,9 +3480,13 @@ def to(self, *args, **kwargs) -> T:
batch_size,
pin_memory,
num_threads,
inplace,
) = _parse_to(*args, **kwargs)
if inplace:
raise TypeError(f"Cannot use inplace=True with {type(self).__name__}.to().")

if batch_size is not None:
raise TypeError(f"Cannot pass batch-size to a {type(self)}.")
raise TypeError(f"Cannot pass batch-size to {type(self).__name__}.to().")
result = self

if device is not None and dtype is None and device == self.device:
Expand Down Expand Up @@ -3757,7 +3761,7 @@ def _cast_reduction(
except Exception:
raise RuntimeError(
f"{reduction_name} requires this object to be cast to a regular TensorDict. "
f"If you need {type(self)} to support {reduction_name}, help us by filing an issue"
f"If you need {type(self).__name__} to support {reduction_name}, help us by filing an issue"
f" on github!"
)
return td._cast_reduction(
Expand Down
27 changes: 20 additions & 7 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,7 +1152,7 @@ def _multithread_rebuild(
) -> None:
if constructor_kwargs:
raise RuntimeError(
f"constructor_kwargs not supported for class {type(self)}."
f"constructor_kwargs not supported for class {type(self).__name__}."
)
# Rebuilds a tensordict from the futures of its leaves
if inplace:
Expand Down Expand Up @@ -1201,7 +1201,7 @@ def setter(
return
result.set(key, item_trsf, inplace=inplace)

elif isinstance(result, TensorDict) and checked and (inplace is not True):
elif checked and isinstance(result, TensorDict) and (inplace is not True):

def setter(
item_trsf,
Expand Down Expand Up @@ -1329,9 +1329,18 @@ def _apply_nest(
"batch_size and out.batch_size must be equal when both are provided."
)
if device is not NO_DEFAULT and device != out.device:
raise RuntimeError(
"device and out.device must be equal when both are provided."
)
if checked:
raise RuntimeError(
f"device and out.device must be equal when both are provided. Got device={device} and out.device={out.device}."
)
else:
device = torch.device(device)
out._device = device
for node in out.values(True, True, is_leaf=_is_tensor_collection):
if is_tensorclass(node):
node._tensordict._device = device
else:
node._device = device
else:

def make_result(names=names, batch_size=batch_size):
Expand Down Expand Up @@ -3594,9 +3603,13 @@ def to(self, *args, **kwargs: Any) -> T:
batch_size,
pin_memory,
num_threads,
inplace,
) = _parse_to(*args, **kwargs)
result = self

if inplace:
raise TypeError(
"Cannot send a _SubTensorDict instance to device/dtype inplace."
)
if device is not None and dtype is None and device == self.device:
return result
return self.to_tensordict().to(*args, **kwargs)
Expand Down Expand Up @@ -4093,7 +4106,7 @@ def _cast_reduction(
except Exception:
raise RuntimeError(
f"{reduction_name} requires this object to be cast to a regular TensorDict. "
f"If you need {type(self)} to support {reduction_name}, help us by filing an issue"
f"If you need {type(self).__name__} to support {reduction_name}, help us by filing an issue"
f" on github!"
)
return td._cast_reduction(
Expand Down
37 changes: 32 additions & 5 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10650,6 +10650,7 @@ def to(
device: Optional[Union[int, device]] = ...,
dtype: Optional[Union[torch.device, str]] = ...,
non_blocking: bool = ...,
inplace: bool = False,
) -> T: ...

@overload
Expand All @@ -10665,10 +10666,16 @@ def to(self: T, *, other: T, non_blocking: bool = ...) -> T: ...
def to(self: T, *, batch_size: torch.Size) -> T: ...

def _to_cuda_with_pin_mem(
self, *, num_threads, device="cuda", non_blocking=None, to: Callable
self,
*,
num_threads,
device="cuda",
non_blocking=None,
to: Callable,
inplace: bool = False,
):
if self.is_empty():
return self.to(device)
return self.to(device, inplace=inplace)
keys, vals = self._items_list(
leaves_only=True, include_nested=True, is_leaf=_NESTED_TENSORS_AS_LISTS
)
Expand Down Expand Up @@ -10700,6 +10707,8 @@ def _to_cuda_with_pin_mem(
is_leaf=_NESTED_TENSORS_AS_LISTS,
propagate_lock=True,
device=device,
out=self if inplace else None,
checked=True,
)
return result

Expand Down Expand Up @@ -10751,6 +10760,9 @@ def to(self, *args, **kwargs) -> T:
``max(1, torch.get_num_threads())`` threads will be spawn.
``num_threads=0`` will cancel any
multithreading for the `pin_memory()` calls.
inplace (bool, optional): if ``True``, the data will be written in-place in the same tensordict.
This can be significantly faster whenever building a tensordict is CPU-overhead bound.
Defaults to ``False``.

Returns:
a new tensordict instance if the device differs from the tensordict
Expand Down Expand Up @@ -10779,6 +10791,7 @@ def to(self, *args, **kwargs) -> T:
batch_size,
non_blocking_pin,
num_threads,
inplace,
) = _parse_to(*args, **kwargs)
result = self

Expand All @@ -10791,6 +10804,7 @@ def to(self, *args, **kwargs) -> T:
pin_memory=non_blocking_pin,
num_threads=num_threads,
non_blocking=non_blocking,
inplace=inplace,
)

if non_blocking is None:
Expand Down Expand Up @@ -10822,11 +10836,13 @@ def to(tensor):
if num_threads is None:
num_threads = max(1, torch.get_num_threads() // 2)
result = self._to_cuda_with_pin_mem(
num_threads=num_threads, to=to, device=device
num_threads=num_threads, to=to, device=device, inplace=inplace
)
else:
apply_kwargs["device"] = device if device is not None else self.device
apply_kwargs["batch_size"] = batch_size
apply_kwargs["out"] = self if inplace else None
apply_kwargs["checked"] = False
if non_blocking_pin:

def to_pinmem(tensor, _to=to):
Expand All @@ -10848,7 +10864,9 @@ def to_pinmem(tensor, _to=to):
self._sync_all()
return result

def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
def _to_consolidated(
self, *, device, pin_memory, num_threads, non_blocking, inplace
):
if num_threads is None:
# unspecified num_threads should mean 0
num_threads = 0
Expand Down Expand Up @@ -10911,8 +10929,17 @@ def set_(x):
storage_offset=storage_offset,
)

if inplace:
out = self
else:
out = None

result = self._fast_apply(
set_, device=torch.device(device), num_threads=num_threads
set_,
device=torch.device(device),
num_threads=num_threads,
out=out,
checked=True,
)
result._consolidated = {"storage": storage_cast}
if "metadata" in self._consolidated:
Expand Down
8 changes: 6 additions & 2 deletions tensordict/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def keys(
) -> _PersistentTDKeysView:
if is_leaf not in (None, _default_is_leaf, _is_leaf_nontensor):
raise ValueError(
f"is_leaf {is_leaf} is not supported within tensordicts of type {type(self)}."
f"is_leaf {is_leaf} is not supported within tensordicts of type {type(self).__name__}."
)
return _PersistentTDKeysView(
tensordict=self,
Expand Down Expand Up @@ -1026,7 +1026,11 @@ def to(self, *args, **kwargs: Any) -> PersistentTensorDict:
batch_size,
non_blocking_pin,
num_threads,
inplace,
) = _parse_to(*args, **kwargs)
if inplace:
raise TypeError(f"Cannot use inplace=True with {type(self).__name__}.to().")

if non_blocking_pin:
raise RuntimeError(
f"Cannot use non_blocking_pin=True {type(self).__name__}.to(). Call "
Expand Down Expand Up @@ -1181,7 +1185,7 @@ def _convert_inplace(self, inplace, key):

def _set_non_tensor(self, key: NestedKey, value: Any):
raise NotImplementedError(
f"set_non_tensor is not compatible with the tensordict type {type(self)}."
f"set_non_tensor is not compatible with the tensordict type {type(self).__name__}."
)

def _set_str(
Expand Down
4 changes: 2 additions & 2 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1707,7 +1707,7 @@ def _set_str(
):
if is_non_tensor(self):
if key != "data":
raise KeyError(f"only 'data' keys are supported for {type(self)}.")
raise KeyError(f"only 'data' keys are supported for {type(self).__name__}.")
while isinstance(value, (NonTensorData, NonTensorStack)):
value = value.data
self._non_tensordict[key] = value
Expand Down Expand Up @@ -1737,7 +1737,7 @@ def _set_at_str(
):
if is_non_tensor(self):
if key != "data":
raise KeyError(f"only 'data' keys are supported for {type(self)}.")
raise KeyError(f"only 'data' keys are supported for {type(self).__name__}.")
while isinstance(value, (NonTensorData, NonTensorStack)):
value = value.data
self._non_tensordict[key] = value
Expand Down
2 changes: 2 additions & 0 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,6 +1356,7 @@ def _parse_to(*args, **kwargs):
non_blocking_pin = kwargs.pop("non_blocking_pin", False)
num_threads = kwargs.pop("num_threads", None)
other = kwargs.pop("other", None)
inplace = kwargs.pop("inplace", False)
if not is_dynamo_compiling():
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
*args, **kwargs
Expand Down Expand Up @@ -1397,6 +1398,7 @@ def _parse_to(*args, **kwargs):
batch_size,
non_blocking_pin,
num_threads,
inplace,
)


Expand Down
5 changes: 5 additions & 0 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ def get_available_devices():
devices += [torch.device(f"cuda:{i}")]
if i == 1:
break
# if torch.backends.mps.is_available():
# for i in range(torch.mps.device_count()):
# devices += [torch.device(f"mps:{i}")]
# if i == 1:
# break
return devices


Expand Down
22 changes: 20 additions & 2 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2056,8 +2056,26 @@ def test_split(self):

def test_to(self):
td = self.get_nested()
td = td.to("cpu:1")
assert isinstance(td.get("c")[0], self.TensorClass)
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu:1")
td_device = td.to(device)
assert isinstance(td_device.get("c")[0], self.TensorClass)
assert td_device is not td
assert td_device.device == device

td_device = td.to(device, inplace=True)
assert td_device is td
assert td_device.device == device

td_cpu = td_device.to("cpu", inplace=True)
assert td_cpu.device == torch.device("cpu")

td_double = td.to(torch.float64, inplace=True)
assert td_double is td
assert td_double.dtype == torch.double
assert td_double.device == torch.device("cpu")


def test_decorator():
Expand Down
Loading
Loading