5 changes: 4 additions & 1 deletion tensordict/_lazy.py
@@ -389,7 +389,10 @@ def get_item_shape(self, key):
item = self.get(key)
return item.shape
except RuntimeError as err:
if re.match(r"Found more than one unique shape in the tensors", str(err)):
if re.match(
r"Found more than one unique shape in the tensors|Could not run 'aten::stack' with arguments from the",
str(err),
):
shape = None
for td in self.tensordicts:
if shape is None:
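For context, a minimal sketch (not part of the diff) of what the widened error match enables: `get_item_shape` on a lazy stack whose entries disagree on a trailing dimension now falls through to the per-tensordict shape computation instead of re-raising the nested-tensor `aten::stack` error. The placeholder value reported for the ragged dimension is an assumption here.

```python
# Hypothetical usage sketch; assumes the ragged dimension is reported as -1.
import torch
from tensordict import TensorDict, LazyStackedTensorDict

td0 = TensorDict({"x": torch.zeros(3, 4)}, batch_size=[3])
td1 = TensorDict({"x": torch.zeros(3, 5)}, batch_size=[3])
lazy = LazyStackedTensorDict(td0, td1, stack_dim=0)

# Densely stacking "x" would hit the aten::stack error caught above, so the
# shape is assembled entry by entry, e.g. torch.Size([2, 3, -1]).
print(lazy.get_item_shape("x"))
```
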
27 changes: 18 additions & 9 deletions tensordict/_torch_func.py
@@ -22,6 +22,7 @@
from tensordict.utils import (
_check_keys,
_ErrorInteceptor,
_shape,
DeviceType,
is_non_tensor,
is_tensorclass,
@@ -474,16 +475,24 @@ def _stack(
) or is_tensorclass(tensor_cls)
if is_not_init is None:
is_not_init = isinstance(tensor, UninitializedTensorMixin)
if not is_not_init and tensor_shape is None:
tensor_shape = tensor.shape
elif not is_not_init and tensor.shape != tensor_shape:
if maybe_dense_stack:
with set_lazy_legacy(True):
return _stack(list_of_tensordicts, dim=dim)
if not is_not_init:
new_tensor_shape = _shape(tensor)
if tensor_shape is not None:
if len(new_tensor_shape) != len(tensor_shape) or not all(
s1 == s2 and s1 != -1
for s1, s2 in zip(_shape(tensor), tensor_shape)
):
# Nested tensors will require a lazy stack
if maybe_dense_stack:
with set_lazy_legacy(True):
return _stack(list_of_tensordicts, dim=dim)
else:
raise RuntimeError(
"The shapes of the tensors to stack is incompatible."
)
else:
raise RuntimeError(
"The shapes of the tensors to stack is incompatible."
)
tensor_shape = new_tensor_shape

out[key].append(tensor)
out[key] = (out[key], is_not_init, is_tensor)

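A rough illustration of the behaviour the new element-wise shape check targets (API usage assumed, not taken from the diff): densely stacking tensordicts whose leaves have mismatched shapes raises, while the `maybe_dense_stack` path degrades to a lazy stack instead.

```python
# Sketch only: heterogeneous leaf shapes cannot be densely stacked.
import torch
from tensordict import TensorDict, LazyStackedTensorDict

a = TensorDict({"x": torch.zeros(4)}, batch_size=[])
b = TensorDict({"x": torch.zeros(5)}, batch_size=[])

try:
    torch.stack([a, b], dim=0)  # with a dense stack this is expected to raise
except RuntimeError as err:
    print(err)

# maybe_dense_stack falls back to a LazyStackedTensorDict instead of raising.
lazy = LazyStackedTensorDict.maybe_dense_stack([a, b], dim=0)
print(type(lazy).__name__)
```
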
2 changes: 2 additions & 0 deletions tensordict/base.py
@@ -6699,6 +6699,8 @@ def numpy(self):

def to_numpy(x):
if hasattr(x, "numpy"):
if getattr(x, "is_nested", False):
return tuple(_x.numpy() for _x in x.unbind(0))
return x.numpy()
return x

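A hypothetical example of the guard added above, assuming `.numpy()` returns a dict-like container of arrays: a nested-tensor entry cannot become a single rectangular array, so it is unbound along dim 0 and converted row by row.

```python
# Sketch; the nested entry comes back as a tuple of numpy arrays, one per row.
import torch
from tensordict import TensorDict

nested = torch.nested.nested_tensor([torch.zeros(3), torch.zeros(5)])
td = TensorDict({"x": nested, "y": torch.ones(2)}, batch_size=[2])

arrays = td.numpy()
print([a.shape for a in arrays["x"]])  # [(3,), (5,)]
print(arrays["y"].shape)               # (2,)
```
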
225 changes: 205 additions & 20 deletions tensordict/tensorclass.py
@@ -54,6 +54,7 @@
)
from torch import multiprocessing as mp, Tensor
from torch.multiprocessing import Manager
from torch.utils._pytree import tree_map

T = TypeVar("T", bound=TensorDictBase)
# We use an abstract AnyType instead of Any because Any isn't recognised as a type for python < 3.10
@@ -88,6 +89,162 @@ def __subclasscheck__(self, subclass):
torch.cat,
torch.gather,
}
# Methods to be executed from tensordict, any ref to self means 'tensorclass'
_METHOD_FROM_TD = [
"gather",
"replace",
]
# Methods to be executed from tensordict, any ref to self means 'self._tensordict'
_FALLBACK_METHOD_FROM_TD = [
"__abs__",
"__add__",
"__iadd__",
"__imul__",
"__ipow__",
"__isub__",
"__itruediv__",
"__mul__",
"__pow__",
"__sub__",
"__truediv__",
"_add_batch_dim",
"apply",
"_apply_nest",
"_fast_apply",
"apply_",
"named_apply",
"_check_unlock",
"unsqueeze",
"squeeze",
"_erase_names", # TODO: must be specialized
"_exclude", # TODO: must be specialized
"_get_str",
"_get_tuple",
"_set_at_tuple",
"_has_names",
"_propagate_lock",
"_propagate_unlock",
"_remove_batch_dim",
"is_memmap",
"is_shared",
"_select", # TODO: must be specialized
"_set_str",
"_set_tuple",
"all",
"any",
"empty",
"exclude",
"expand",
"expand_as",
"is_empty",
"is_shared",
"items",
"keys",
"lock_",
"masked_fill",
"masked_fill_",
"permute",
"flatten",
"unflatten",
"ndimension",
"rename_", # TODO: must be specialized
"reshape",
"select",
"to",
"transpose",
"unlock_",
"values",
"view",
"zero_",
"add",
"add_",
"mul",
"mul_",
"abs",
"abs_",
"acos",
"acos_",
"exp",
"exp_",
"neg",
"neg_",
"reciprocal",
"reciprocal_",
"sigmoid",
"sigmoid_",
"sign",
"sign_",
"sin",
"sin_",
"sinh",
"sinh_",
"tan",
"tan_",
"tanh",
"tanh_",
"trunc",
"trunc_",
"norm",
"lgamma",
"lgamma_",
"frac",
"frac_",
"expm1",
"expm1_",
"log",
"log_",
"log10",
"log10_",
"log1p",
"log1p_",
"log2",
"log2_",
"ceil",
"ceil_",
"floor",
"floor_",
"round",
"round_",
"erf",
"erf_",
"erfc",
"erfc_",
"asin",
"asin_",
"atan",
"atan_",
"cos",
"cos_",
"cosh",
"cosh_",
"lerp",
"lerp_",
"addcdiv",
"addcdiv_",
"addcmul",
"addcmul_",
"sub",
"sub_",
"maximum_",
"maximum",
"minimum_",
"minimum",
"clamp_max_",
"clamp_max",
"clamp_min_",
"clamp_min",
"pow",
"pow_",
"div",
"div_",
"sqrt",
"sqrt_",
]
_FALLBACK_METHOD_FROM_TD_COPY = [
"_clone", # TODO: must be specialized
"clone", # TODO: must be specialized
"copy", # TODO: must be specialized
]


class tensorclass:
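The net effect of these lists, as a sketch with a hypothetical `MyData` class (not from the diff): any name in `_FALLBACK_METHOD_FROM_TD` is answered by the wrapped `_tensordict` and the result is re-wrapped as the tensorclass with its non-tensor fields carried over, so arithmetic and shape ops no longer need the generic `_wrap_method` fallback.

```python
# Illustration only; MyData and its fields are hypothetical.
import torch
from tensordict import tensorclass

@tensorclass
class MyData:
    obs: torch.Tensor
    note: str  # non-tensor field, kept in _non_tensordict

data = MyData(obs=torch.randn(3, 4), note="hello", batch_size=[3])

doubled = data * 2        # __mul__ is in _FALLBACK_METHOD_FROM_TD
flat = data.reshape(-1)   # so are reshape, abs, clamp_min, ...
print(type(doubled).__name__, doubled.note)  # MyData hello
```
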
@@ -173,7 +330,7 @@ def __torch_function__(
kwargs: dict[str, Any] | None = None,
) -> Callable:
if func not in _TD_PASS_THROUGH or not all(
issubclass(t, (Tensor, cls)) for t in types
issubclass(t, (Tensor, cls, TensorDictBase)) for t in types
):
return NotImplemented

@@ -266,28 +423,23 @@ def __torch_function__(
cls.share_memory_ = _share_memory_
if not hasattr(cls, "update"):
cls.update = _update
if not hasattr(cls, "replace"):
cls.replace = TensorDictBase.replace
if not hasattr(cls, "update_"):
cls.update_ = _update_
if not hasattr(cls, "update_at_"):
cls.update_at_ = _update_at_
if not hasattr(cls, "_set_str"):
cls._set_str = TensorDict._set_str
if not hasattr(cls, "_set_tuple"):
cls._set_tuple = TensorDict._set_tuple

cls.__add__ = TensorDict.__add__
cls.__iadd__ = TensorDict.__iadd__
cls.__abs__ = TensorDict.__abs__
cls.__truediv__ = TensorDict.__truediv__
cls.__itruediv__ = TensorDict.__itruediv__
cls.__mul__ = TensorDict.__mul__
cls.__imul__ = TensorDict.__imul__
cls.__sub__ = TensorDict.__sub__
cls.__isub__ = TensorDict.__isub__
cls.__pow__ = TensorDict.__pow__
cls.__ipow__ = TensorDict.__ipow__
for method_name in _METHOD_FROM_TD:
if not hasattr(cls, method_name):
setattr(cls, method_name, getattr(TensorDict, method_name))
for method_name in _FALLBACK_METHOD_FROM_TD:
if not hasattr(cls, method_name):
setattr(cls, method_name, _wrap_td_method(method_name))
for method_name in _FALLBACK_METHOD_FROM_TD_COPY:
if not hasattr(cls, method_name):
setattr(
cls,
method_name,
_wrap_td_method(method_name, copy_non_tensor=True),
)

if not hasattr(cls, "_apply_nest"):
cls._apply_nest = TensorDict._apply_nest
@@ -564,7 +716,7 @@ def wrapper(cls, tensordict, non_tensordict=None): # noqa: D417
del non_tensordict[key]
continue
raise KeyError(
f"{key} is present in both tensor and non-tensor dicts"
f"{key} is present in both tensor and non-tensor dicts."
)
# bypass initialisation. this means we don't incur any overhead creating an
# empty tensordict and writing values to it. we can skip this because we already
@@ -744,7 +896,40 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417
return wrapper


def _wrap_td_method(funcname, *, copy_non_tensor=False):
def wrapped_func(self, *args, **kwargs):
td = super(type(self), self).__getattribute__("_tensordict")
result = getattr(td, funcname)(*args, **kwargs)

def check_out(kwargs, result):
out = kwargs.get("out")
if out is result:
# No need to transform output
return True
return False

if isinstance(result, TensorDictBase) and not check_out(kwargs, result):
if result is td:
return self
nontd = super(type(self), self).__getattribute__("_non_tensordict")
if copy_non_tensor:
# use tree_map to copy
nontd = tree_map(lambda x: x, nontd)
return super(type(self), self).__getattribute__("_from_tensordict")(
result, nontd
)
return result

return wrapped_func


def _wrap_method(self, attr, func):
warnings.warn(
f"The method {func} wasn't explicitly implemented for tensorclass. "
f"This fallback will be deprecated in future releases because it is inefficient "
f"and non-compilable. Please raise an issue in tensordict repo to support this method!"
)

@functools.wraps(func)
def wrapped_func(*args, **kwargs):
args = tuple(_arg_to_tensordict(arg) for arg in args)
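To summarise `_wrap_td_method`'s contract (again with a hypothetical `MyData`; the assertions encode assumptions, not tested behaviour from this PR): when the underlying call returns the same tensordict, the original tensorclass instance is handed back, and the `*_COPY` variants shallow-copy the non-tensor dict via `tree_map` before re-wrapping.

```python
# Sketch of the expected wrapper semantics; MyData is hypothetical.
import torch
from tensordict import tensorclass

@tensorclass
class MyData:
    obs: torch.Tensor
    note: str

data = MyData(obs=torch.ones(2), note="meta", batch_size=[2])

assert data.mul_(2) is data  # in-place call: td is returned unchanged -> self
clone = data.clone()         # _FALLBACK_METHOD_FROM_TD_COPY: non-tensor dict copied
assert clone is not data and clone.note == "meta"
```
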
1 change: 0 additions & 1 deletion tensordict/utils.py
@@ -40,7 +40,6 @@
import numpy as np
import torch


try:
from functorch import dim as ftdim

2 changes: 2 additions & 0 deletions test/test_tensorclass.py
@@ -1399,12 +1399,14 @@ class MyDataNested:
assert data.z == "test_tensorclass"
assert data.y.z == "test_tensorclass"
data_replace = data.replace(replacement)

assert isinstance(data_replace, MyDataNested)
assert isinstance(data_replace.y, MyDataNested)
assert data.z == "test_tensorclass"
assert data.y.z == "test_tensorclass"
assert data_replace.z == "replacement"
assert data_replace.y.z == "replacement"

assert (data.X == 1).all()
assert (data.y.X == 1).all()
assert (data_replace.X == 0).all()