Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,5 @@ docs/src

*/_C.so
tensordict/_version.py

scratch/*.py
18 changes: 16 additions & 2 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,22 @@ def _has_exclusive_keys(self):

@_fails_exclusive_keys
def to_dict(
    self,
    *,
    retain_none: bool = True,
    convert_tensors: bool = False,
    tolist_first: bool = False,
) -> Dict[str, Any]:
    """Returns a dictionary built from the stacked tensordicts.

    Keyword Args:
        retain_none (bool): if ``True``, ``None`` values are retained.
            Otherwise, they are discarded. Default: ``True``.
        convert_tensors (bool): if ``True``, tensors are converted to lists.
            Otherwise, they remain as tensors. Default: ``False``.
        tolist_first (bool): if ``True``, nested tensordicts with batch
            dimensions are converted to lists first. Default: ``False``.

    Returns:
        A dictionary representation of the stacked tensordicts.
    """
    result: Dict[str, Any] = {}
    # NOTE(review): dict.update is last-wins — if the stacked tensordicts
    # share keys (the usual case for a stack), entries from earlier members
    # are silently overwritten by later ones. Confirm this merge semantic
    # is intended rather than stacking the values.
    for td in self.tensordicts:
        result.update(
            td.to_dict(
                retain_none=retain_none,
                convert_tensors=convert_tensors,
                tolist_first=tolist_first,
            )
        )
    return result

def _reduce_get_metadata(self):
metadata = {}
Expand Down
47 changes: 35 additions & 12 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3692,7 +3692,7 @@ def repeat(self, repeats: torch.Size): ...
def repeat(self, *repeats: int) -> TensorDictBase:
"""Repeats this tensor along the specified dimensions.

Unlike :meth:`~.expand()`, this function copies the tensors data.
Unlike :meth:`~.expand()`, this function copies the tensor's data.

.. warning:: :meth:`~.repeat` behaves differently from :func:`~numpy.repeat`, but is more similar to
:func:`numpy.tile`. For the operator similar to :func:`numpy.repeat`, see :meth:`~tensordict.TensorDictBase.repeat_interleave`.
Expand Down Expand Up @@ -12105,7 +12105,11 @@ def as_tensor(tensor):
return self._fast_apply(as_tensor, propagate_lock=True)

def to_dict(
self, *, retain_none: bool = True, convert_tensors: bool = False
self,
*,
retain_none: bool = True,
convert_tensors: bool = False,
tolist_first: bool = False,
) -> dict[str, Any]:
"""Returns a dictionary with key-value pairs matching those of the tensordict.

Expand All @@ -12115,6 +12119,8 @@ def to_dict(
Otherwise, they will be discarded. Default: ``True``.
convert_tensors (bool): if ``True``, tensors will be converted to lists when creating the dictionary.
Otherwise, they will remain as tensors. Default: ``False``.
tolist_first (bool): if ``True``, the tensordict will be converted to a list first when
it has batch dimensions. Default: ``False``.

Returns:
A dictionary representation of the tensordict.
Expand Down Expand Up @@ -12157,16 +12163,23 @@ def to_dict(
and value.data is None
):
continue
value = value.to_dict(
retain_none=retain_none, convert_tensors=convert_tensors
)
if tolist_first:
value = value.tolist(convert_tensors=convert_tensors)
else:
value = value.to_dict(
retain_none=retain_none, convert_tensors=convert_tensors
)
elif convert_tensors and hasattr(value, "tolist"):
value = value.tolist()
result[key] = value
return result

def tolist(
self, *, convert_nodes: bool = True, convert_tensors: bool = False
self,
*,
convert_nodes: bool = True,
convert_tensors: bool = False,
tolist_first: bool = False,
) -> List[Any]:
"""Returns a nested list representation of the tensordict.

Expand All @@ -12178,6 +12191,8 @@ def tolist(
Otherwise, they will be returned as lists of values. Default: ``True``.
convert_tensors (bool): if ``True``, tensors will be converted to lists when creating the dictionary.
Otherwise, they will remain as tensors. Default: ``False``.
tolist_first (bool): if ``True``, the tensordict will be converted to a list first when
it has batch dimensions. Default: ``False``.

Returns:
A nested list representation of the tensordict.
Expand All @@ -12191,11 +12206,12 @@ def tolist(
... b=TensorDict(c=torch.arange(12).reshape(2, 3, 2), batch_size=(2, 3, 2)),
... batch_size=(2, 3)
... )
>>>
>>> print(td.tolist())
>>> print(td.tolist(tolist_first=True))
[[{'a': tensor([0, 1, 2, 3]), 'b': [{'c': tensor(0)}, {'c': tensor(1)}]}, {'a': tensor([4, 5, 6, 7]), 'b': [{'c': tensor(2)}, {'c': tensor(3)}]}, {'a': tensor([ 8, 9, 10, 11]), 'b': [{'c': tensor(4)}, {'c': tensor(5)}]}], [{'a': tensor([12, 13, 14, 15]), 'b': [{'c': tensor(6)}, {'c': tensor(7)}]}, {'a': tensor([16, 17, 18, 19]), 'b': [{'c': tensor(8)}, {'c': tensor(9)}]}, {'a': tensor([20, 21, 22, 23]), 'b': [{'c': tensor(10)}, {'c': tensor(11)}]}]]
>>> print(td.tolist(tolist_first=False))
[[{'a': tensor([0, 1, 2, 3]), 'b': {'c': tensor([0, 1])}}, {'a': tensor([4, 5, 6, 7]), 'b': {'c': tensor([2, 3])}}, {'a': tensor([ 8, 9, 10, 11]), 'b': {'c': tensor([4, 5])}}], [{'a': tensor([12, 13, 14, 15]), 'b': {'c': tensor([6, 7])}}, {'a': tensor([16, 17, 18, 19]), 'b': {'c': tensor([8, 9])}}, {'a': tensor([20, 21, 22, 23]), 'b': {'c': tensor([10, 11])}}]]
>>> print(td.tolist(convert_tensors=True))
[[{'a': [0, 1, 2, 3], 'b': {'c': [0, 1]}}, {'a': [4, 5, 6, 7], 'b': {'c': [2, 3]}}, {'a': [8, 9, 10, 11], 'b': {'c': [4, 5]}}], [{'a': [12, 13, 14, 15], 'b': {'c': [6, 7]}}, {'a': [16, 17, 18, 19], 'b': {'c': [8, 9]}}, {'a': [20, 21, 22, 23], 'b': {'c': [10, 11]}}]]
>>> print(td.tolist(convert_tensors=False))
[[{'a': [0, 1, 2, 3], 'b': [{'c': 0}, {'c': 1}]}, {'a': [4, 5, 6, 7], 'b': [{'c': 2}, {'c': 3}]}, {'a': [8, 9, 10, 11], 'b': [{'c': 4}, {'c': 5}]}], [{'a': [12, 13, 14, 15], 'b': [{'c': 6}, {'c': 7}]}, {'a': [16, 17, 18, 19], 'b': [{'c': 8}, {'c': 9}]}, {'a': [20, 21, 22, 23], 'b': [{'c': 10}, {'c': 11}]}]]
>>> print(td.tolist(convert_nodes=False))
[[[tensor([0, 1, 2, 3]), TensorDict(
fields={
Expand Down Expand Up @@ -12234,7 +12250,9 @@ def tolist(
raise TypeError("convert_tensors requires convert_nodes to be set to True")
if not self.batch_dims:
if convert_nodes:
return self.to_dict(convert_tensors=convert_tensors)
return self.to_dict(
convert_tensors=convert_tensors, tolist_first=tolist_first
)
return self

q = collections.deque()
Expand All @@ -12245,7 +12263,12 @@ def tolist(
vals = val.unbind(0)
if val.ndim == 1:
if convert_nodes:
vals = [v.to_dict(convert_tensors=convert_tensors) for v in vals]
vals = [
v.to_dict(
convert_tensors=convert_tensors, tolist_first=tolist_first
)
for v in vals
]
else:
vals = list(vals)
_result.extend(vals)
Expand Down
4 changes: 3 additions & 1 deletion tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,7 +1115,9 @@ def _write_to_tensordict(
else:
tensordict_out = tensordict
if len(tensors) > len(out_keys):
raise RuntimeError("There are more tensors than out_keys.")
raise RuntimeError(
f"There are more tensors ({len(tensors)=}) than out_keys ({out_keys=})."
)
elif len(out_keys) > len(tensors):
raise RuntimeError("There are more out_keys than tensors.")
for _out_key, _tensor in zip(out_keys, tensors):
Expand Down
31 changes: 24 additions & 7 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3354,7 +3354,11 @@ def _apply_nest(self, *args, out=None, **kwargs):
)

def to_dict(
    self,
    *,
    retain_none: bool = True,
    convert_tensors: bool = False,
    tolist_first: bool = False,
) -> dict[str, Any]:
    """Returns the wrapped non-tensor payload.

    The keyword arguments are accepted for signature compatibility with
    ``TensorDictBase.to_dict`` but are ignored here: there is nothing to
    convert, the raw data is returned as-is.
    """
    # override to_dict to return just the data
    return self.data
Expand Down Expand Up @@ -3418,19 +3422,23 @@ def _multithread_rebuild(self, *args, **kwargs):
nowarn=True,
)(*args, **kwargs)

def tolist(self, *, convert_tensors: bool = False, tolist_first: bool = False):
    """Converts the data to a list if the batch-size is non-empty.

    If the batch-size is empty, returns the data as-is.

    Keyword Args:
        convert_tensors (bool, optional): if ``True``, tensors will be converted to lists.
            Otherwise, they will remain as tensors. Default: ``False``.
        tolist_first (bool, optional): if ``True``, the tensordict will be converted to a
            list first when it has batch dimensions. Default: ``False``.
    """
    # No batch dimension: nothing to unbind, return the payload directly.
    if not self.batch_size:
        return self.data
    # Recurse along the leading batch dimension, forwarding both flags so
    # nested non-tensor nodes behave consistently.
    return [
        ntd.tolist(convert_tensors=convert_tensors, tolist_first=tolist_first)
        for ntd in self.unbind(0)
    ]

def copy_(
self, src: NonTensorDataBase | NonTensorStack, non_blocking: bool = False
Expand Down Expand Up @@ -3942,12 +3950,14 @@ def __init__(self, *args, **kwargs):
if not all(is_non_tensor(item) for item in self.tensordicts):
raise RuntimeError("All tensordicts must be non-tensors.")

def tolist(self, *, convert_tensors: bool = False):
def tolist(self, *, convert_tensors: bool = False, tolist_first: bool = False):
"""Extracts the content of a :class:`tensordict.tensorclass.NonTensorStack` in a nested list.

Keyword Args:
convert_tensors (bool): if ``True``, tensors will be converted to lists.
Otherwise, they will remain as tensors. Default: ``False``.
tolist_first (bool, optional): if ``True``, the tensordict will be converted to a list first when
it has batch dimensions. Default: ``False``.

Examples:
>>> from tensordict import NonTensorData
Expand All @@ -3960,7 +3970,10 @@ def tolist(self, *, convert_tensors: bool = False):

"""
iterator = self.tensordicts if self.stack_dim == 0 else self.unbind(0)
return [td.tolist(convert_tensors=convert_tensors) for td in iterator]
return [
td.tolist(convert_tensors=convert_tensors, tolist_first=tolist_first)
for td in iterator
]

def maybe_to_stack(self):
"""Placeholder for interchangeability between stack and non-stack of non-tensors."""
Expand Down Expand Up @@ -4029,7 +4042,11 @@ def lazy_stack(
return result

def to_dict(
    self,
    *,
    retain_none: bool = True,
    convert_tensors: bool = False,
    tolist_first: bool = False,
) -> dict[str, Any]:
    """Returns the stack's content via :meth:`tolist`.

    Keyword Args:
        retain_none (bool): accepted for interface compatibility with
            ``TensorDictBase.to_dict``; unused here.
        convert_tensors (bool): if ``True``, tensors are converted to lists.
            Default: ``False``.
        tolist_first (bool): if ``True``, nested tensordicts are converted to
            a list first when they have batch dimensions. Default: ``False``.
    """
    # FIX(review): tolist_first was accepted but silently dropped — forward
    # it so nested stacks actually honor the flag.
    return self.tolist(convert_tensors=convert_tensors, tolist_first=tolist_first)

Expand Down
12 changes: 10 additions & 2 deletions tensordict/tensorclass.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -995,10 +995,18 @@ class TensorClass:
): ...
def as_tensor(self): ...
# Stub signatures mirror the runtime definitions in tensordict/base.py.
def to_dict(
    self,
    *,
    retain_none: bool = True,
    convert_tensors: bool = False,
    tolist_first: bool = False,
) -> dict[str, Any]: ...
def tolist(
    self,
    *,
    convert_nodes: bool = True,
    convert_tensors: bool = False,
    tolist_first: bool = False,
) -> List[Any]: ...
def numpy(self): ...
def to_namedtuple(self, dest_cls: type | None = None): ...
Expand Down
29 changes: 29 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3180,6 +3180,35 @@ def test_tolist(self, convert_nodes, convert_tensors):
else:
assert isinstance(tdlist[0][0]["b"], TensorDict)

def test_tolist_first(self):
    """Tests the behavior of tolist_first parameter in tolist() method."""
    td = TensorDict(
        a=torch.arange(24).view(2, 3, 4),
        b=TensorDict(c=torch.arange(12).reshape(2, 3, 2), batch_size=(2, 3, 2)),
        batch_size=(2, 3),
    )

    # tolist_first=True: the nested node 'b' (which carries extra batch
    # dims) is itself unbound into a list of per-element dicts.
    listed = td.tolist(tolist_first=True)
    assert len(listed[0]) == 3
    entry = listed[0][0]
    assert isinstance(entry["a"], torch.Tensor)
    assert entry["a"].equal(torch.tensor([0, 1, 2, 3]))
    assert isinstance(entry["b"], list)
    assert len(entry["b"]) == 2
    assert entry["b"][0]["c"].equal(torch.tensor(0))

    # tolist_first=False: 'b' stays a plain dict whose leaves keep the
    # remaining batch dimension.
    dicted = td.tolist(tolist_first=False)
    assert len(dicted[0]) == 3
    entry = dicted[0][0]
    assert isinstance(entry["a"], torch.Tensor)
    assert entry["a"].equal(torch.tensor([0, 1, 2, 3]))
    assert isinstance(entry["b"], dict)
    assert entry["b"]["c"].equal(torch.tensor([0, 1]))

def test_unbind_batchsize(self):
td = TensorDict({"a": TensorDict({"b": torch.zeros(2, 3)}, [2, 3])}, [2])
td["a"].batch_size
Expand Down
Loading