2 changes: 2 additions & 0 deletions docs/source/reference/tensordict.rst
@@ -243,3 +243,5 @@ Utils
parse_tensor_dict_string
set_capture_non_tensor_stack
set_lazy_legacy
set_list_to_stack
list_to_stack
2 changes: 2 additions & 0 deletions tensordict/__init__.py
@@ -62,9 +62,11 @@
is_non_tensor,
is_tensorclass,
lazy_legacy,
list_to_stack,
parse_tensor_dict_string,
set_capture_non_tensor_stack,
set_lazy_legacy,
set_list_to_stack,
unravel_key,
unravel_key_list,
)
30 changes: 28 additions & 2 deletions tensordict/_lazy.py
@@ -69,6 +69,7 @@
_is_number,
_maybe_correct_neg_dim,
_parse_to,
_recursive_unbind_list,
_renamed_inplace_method,
_shape,
_td_fields,
@@ -83,6 +84,7 @@
infer_size_impl,
is_non_tensor,
is_tensorclass,
list_to_stack,
lock_blocked,
NestedKey,
unravel_key_list,
@@ -198,6 +200,22 @@ class LazyStackedTensorDict(TensorDictBase):
>>> print(td_stack[:, 0] is tds[0])
True

.. note:: Lazy stacks support assignment via lists. For consistency, the lists should be
laid out like a `tensor.tolist()` data structure: the length of the first level of the
nested lists must match the first dimension of the lazy stack (whether or not this is
the stack dimension).

>>> td = LazyStackedTensorDict(TensorDict(), TensorDict(), stack_dim=0)
>>> td["a"] = [torch.ones(2), torch.zeros(1)]
>>> assert td[1]["a"] == torch.zeros(1)
>>> td["b"] = ["a string", "another string"]
>>> assert td[1]["b"] == "another string"

.. note:: When using the :meth:`~.get` method, one can pass the `as_nested_tensor`, `as_padded_tensor`
or `as_list` argument to control how the data is presented when the tensor shapes mismatch.
When passed, the nesting/padding/listing is applied whether or not the shapes actually mismatch.

"""

_is_vmapped: bool = False
@@ -587,11 +605,19 @@ def _set_str(
"their register."
) from e
if not validated:
- value = self._validate_value(value, non_blocking=non_blocking)
+ value = self._validate_value(
+     value, non_blocking=non_blocking, check_shape=not list_to_stack()
+ )
validated = True
if self._is_vmapped:
value = self.hook_in(value)
- values = value.unbind(self.stack_dim)
+ if isinstance(value, list):
+     if self.stack_dim == 0:
+         values = list(value)
+     else:
+         values = _recursive_unbind_list(value, self.stack_dim)
+ else:
+     values = value.unbind(self.stack_dim)
for tensordict, item in _zip_strict(self.tensordicts, values):
tensordict._set_str(
key,
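As an illustrative aside (not part of the diff), here is a minimal sketch of the list-assignment path documented in the docstring above, with the `as_list` keyword of `get` used as the note describes:

import torch
from tensordict import LazyStackedTensorDict, TensorDict, set_list_to_stack

with set_list_to_stack(True):
    td = LazyStackedTensorDict(TensorDict(), TensorDict(), stack_dim=0)
    # The outer list has length 2, matching the first dimension of the lazy stack,
    # i.e. it is laid out like `tensor.tolist()` data.
    td["a"] = [torch.ones(2), torch.zeros(1)]  # heterogeneous shapes
    assert (td[0]["a"] == torch.ones(2)).all()
    assert (td[1]["a"] == torch.zeros(1)).all()
    # With mismatched shapes, `get` can return the leaves as a plain list
    # (`as_nested_tensor` / `as_padded_tensor` behave analogously).
    leaves = td.get("a", as_list=True)
    assert isinstance(leaves, list) and len(leaves) == 2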
18 changes: 18 additions & 0 deletions tensordict/base.py
@@ -55,6 +55,7 @@
from tensordict.utils import (
_as_context_manager,
_CloudpickleWrapper,
_convert_list_to_stack,
_DTYPE2STRDTYPE,
_GENERIC_NESTED_ERR,
_is_dataclass as is_dataclass,
@@ -95,6 +96,7 @@
is_namedtuple_class,
is_non_tensor,
lazy_legacy,
list_to_stack,
lock_blocked,
prod,
set_capture_non_tensor_stack,
@@ -11281,6 +11283,8 @@ def _convert_to_tensor(
castable = None
if isinstance(array, (float, int, bool)):
castable = True
elif isinstance(array, list) and list_to_stack():
return _convert_list_to_stack(array)[0]
elif isinstance(array, np.bool_):
castable = True
array = array.item()
@@ -11289,6 +11293,20 @@
return TensorDictBase.from_struct_array(array, device=self.device)
castable = array.dtype.kind in ("c", "i", "f", "b", "u")
elif isinstance(array, (list, tuple)):
if isinstance(array, list) and list_to_stack(allow_none=True) is None:
warnings.warn(
"You are setting a list of elements within a tensordict without setting `set_list_to_stack`. "
"The current behaviour is that: if this list can be cast to a Tensor, it will be and will be written "
"as such. If it cannot, it will be converted to a numpy array and considered as a non-indexable "
"entity through a wrapping in a NonTensorData. If you want your list to be indexable along the "
"tensordict batch-size, use the decorator/context manager tensordict.set_list_to_stack(True), the "
"global flag `tensordict.set_list_to_stack(True).set()`, or "
"the environment variable LIST_TO_STACK=1 (or use False/0 to silence this warning). "
"This behavior will change in v0.10.0, and "
"lists will be automatically stacked.",
category=FutureWarning,
)

array = np.asarray(array)
castable = array.dtype.kind in ("c", "i", "f", "b", "u")
elif hasattr(array, "numpy"):
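As an illustrative aside (not part of the diff), the example from the `set_list_to_stack` docstring shows what the new list branch above does once the flag is enabled: the list is stacked along the batch dimension instead of being routed through numpy/NonTensorData.

import torch
from tensordict import TensorDict, set_list_to_stack

with set_list_to_stack(True):
    td = TensorDict(a=[torch.zeros(()), torch.ones(())], batch_size=2)
    # The list becomes a stacked tensor, indexable along the batch dimension.
    assert (td["a"] == torch.tensor([0.0, 1.0])).all()
    assert td[0]["a"] == 0
    assert td[1]["a"] == 1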
2 changes: 1 addition & 1 deletion tensordict/nn/probabilistic.py
@@ -1218,7 +1218,7 @@ def log_prob(
*,
dist: torch.distributions.Distribution | None = None,
**kwargs,
- ):
+ ) -> TensorDictBase | torch.Tensor:
"""Returns the log-probability of the input tensordict.

If `self.return_composite` is ``True`` and the distribution is a :class:`~tensordict.nn.CompositeDistribution`,
136 changes: 136 additions & 0 deletions tensordict/utils.py
@@ -73,6 +73,7 @@
from torch._dynamo import assume_constant_result, is_compiling

if TYPE_CHECKING:
from tensordict.tensorclass import NonTensorStack
from tensordict.tensordict import TensorDictBase

try:
@@ -2013,6 +2014,141 @@ def capture_non_tensor_stack(allow_none=False):
)


# list to stack control
_DEFAULT_LIST_TO_STACK = None
_LIST_TO_STACK = os.environ.get("LIST_TO_STACK")


class set_list_to_stack(_DecoratorContextManager):
"""Context manager and decorator to control the behavior of list handling in TensorDict.

When enabled, lists assigned to a TensorDict will be automatically stacked along the batch dimension.
This can be useful for ensuring that lists of tensors or other elements are treated as stackable entities
within a TensorDict.

Current Behavior:
If a list is assigned to a TensorDict without this context manager, it is cast to a Tensor when possible;
otherwise it is converted to a numpy array and wrapped in a NonTensorData.

Future Behavior:
In version 0.10.0, lists will be automatically stacked by default.

Args:
mode (bool): If True, enables list-to-stack conversion. If False, disables it.

.. warning::
A FutureWarning will be raised if a list is assigned to a TensorDict without setting this context manager
or the global flag, indicating that the behavior will change in the future.

Example:
>>> with set_list_to_stack(True):
... td = TensorDict(a=[torch.zeros(()), torch.ones(())], batch_size=2)
... assert (td["a"] == torch.tensor([0, 1])).all()
... assert td[0]["a"] == 0
... assert td[1]["a"] == 1

.. seealso:: :func:`~tensordict.list_to_stack`.

"""

def __init__(self, mode: bool) -> None:
super().__init__()
self.mode = mode

def clone(self) -> set_list_to_stack:
# override this method if your child class takes __init__ parameters
return type(self)(self.mode)

def __enter__(self) -> None:
self.set()

def set(self) -> None:
global _LIST_TO_STACK
self._old_mode = _LIST_TO_STACK
_LIST_TO_STACK = bool(self.mode)
os.environ["LIST_TO_STACK"] = str(_LIST_TO_STACK)

def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
global _LIST_TO_STACK
_LIST_TO_STACK = self._old_mode
os.environ["LIST_TO_STACK"] = str(_LIST_TO_STACK)


def list_to_stack(allow_none=False):
"""Retrieves the current setting for list-to-stack conversion in TensorDict.

This function checks the global environment variable or the context manager setting to determine
whether lists should be automatically stacked when assigned to a TensorDict.

Current Behavior:
Returns the current setting for list-to-stack conversion. If the setting is not defined, it returns
None when `allow_none` is True, and the default setting otherwise.

Future Behavior:
The default behavior will change in version 0.10.0 to automatically stack lists.

Args:
allow_none (bool): If True, allows the function to return None if the setting is not defined.

Returns:
bool or None: The current setting for list-to-stack conversion.

.. seealso:: :class:`~tensordict.set_list_to_stack`.

"""
global _LIST_TO_STACK
if _LIST_TO_STACK is None and allow_none:
return None
elif _LIST_TO_STACK is None:
return _DEFAULT_LIST_TO_STACK
elif _LIST_TO_STACK == "none":
return _DEFAULT_LIST_TO_STACK
return (
strtobool(_LIST_TO_STACK) if isinstance(_LIST_TO_STACK, str) else _LIST_TO_STACK
)


def _convert_list_to_stack(
a_list: list[Any],
) -> tuple[torch.Tensor | TensorDictBase | NonTensorStack, bool]: # noqa
# First, check elements and determine if there are lists within
nontensor = True
if all(isinstance(elt, list) for elt in a_list):
a_list, nontensor = zip(*[_convert_list_to_stack(elt) for elt in a_list])
nontensor = any(nontensor)
# FIXME: we should check that the type is unique
all_castable = all(isinstance(elt, (bool, int, float)) for elt in a_list)
if all_castable:
return torch.tensor(a_list), False
all_tensors = all(isinstance(elt, torch.Tensor) for elt in a_list)
if not nontensor or all_tensors:
# should we stack?
if all_tensors and len({x.shape for x in a_list}) == 1:
# FIXME: this may lead to some weird behaviours if we have nested lists and by chance one of them has
# things that can be stacked, and others don't.
return torch.stack(a_list), False
# TODO: check that LazyStack understands that a list is a bunch of elements to write in separate tds
return list(a_list), False
from tensordict.base import _is_tensor_collection

if all(_is_tensor_collection(type(elt)) for elt in a_list):
return torch.stack(a_list), False
from tensordict import NonTensorStack

return NonTensorStack(*a_list), True


def _recursive_unbind_list(a_list, dim):
if dim == 0:
return list(a_list)
try:
return map(
list, _zip_strict(*[_recursive_unbind_list(elt, dim - 1) for elt in a_list])
)
except Exception:
raise ValueError("lengths of nested lists differed.")


# Process initializer for map
def _proc_init(base_seed, queue, num_threads):
worker_id = queue.get(timeout=120)
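As an illustrative aside (not part of the diff), the sketch below shows roughly how the two private helpers added above behave; both are internal, so names and signatures may change.

import torch
from tensordict.utils import _convert_list_to_stack, _recursive_unbind_list

# `_recursive_unbind_list` slices a nested list along `dim`, so a value laid out like
# `tensor.tolist()` can be split across the stacked tensordicts when stack_dim != 0.
nested = [[1, 2], [3, 4], [5, 6]]  # "shape" (3, 2)
assert list(_recursive_unbind_list(nested, 0)) == [[1, 2], [3, 4], [5, 6]]
assert list(_recursive_unbind_list(nested, 1)) == [[1, 3, 5], [2, 4, 6]]

# `_convert_list_to_stack` turns a (possibly nested) list into a tensor, a stacked
# tensordict or a NonTensorStack, and reports whether the result is non-tensor data.
value, is_non_tensor = _convert_list_to_stack([[1, 2], [3, 4]])
assert isinstance(value, torch.Tensor) and value.shape == (2, 2)
assert not is_non_tensor

value, is_non_tensor = _convert_list_to_stack(["a string", "another string"])
assert is_non_tensor  # wrapped in a NonTensorStack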
2 changes: 2 additions & 0 deletions test/test_nn.py
@@ -21,6 +21,7 @@
is_tensor_collection,
NonTensorData,
NonTensorStack,
set_list_to_stack,
tensorclass,
TensorDict,
)
@@ -976,6 +977,7 @@ def forward(self, input):
out = module(TensorDict({"a": torch.randn(3)}, []))
assert (out["b"] == out["a"]).all()

@set_list_to_stack(True)
def test_tdmodule_inplace(self):
tdm = TensorDictModule(
lambda x: (x, x), in_keys=["x"], out_keys=["y", "z"], inplace=False
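As an illustrative aside (not part of the diff), `set_list_to_stack` also works as a decorator (it subclasses `_DecoratorContextManager`), which is how the test above opts into the new behavior for its whole body:

import torch
from tensordict import TensorDict, set_list_to_stack

@set_list_to_stack(True)
def build_td():
    # Inside the decorated function, list values are stacked along the batch dim.
    return TensorDict(a=[torch.zeros(()), torch.ones(())], batch_size=2)

td = build_td()
assert td[1]["a"] == 1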