Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
unravel_key_list,
)
from torch import distributed as dist, multiprocessing as mp, nn, Tensor
from torch.nn.parameter import UninitializedTensorMixin
from torch.nn.parameter import Parameter, UninitializedTensorMixin
from torch.utils._pytree import tree_map

try:
Expand Down Expand Up @@ -3196,21 +3196,22 @@ def count_bytes(tensor):
if isinstance(tensor, MemoryMappedTensor):
add(tensor)
return
if type(tensor) is not torch.Tensor:
try:
attrs, ctx = tensor.__tensor_flatten__()
for attr in attrs:
t = getattr(tensor, attr)
count_bytes(t)
return
except AttributeError:
warnings.warn(
"The sub-tensor doesn't ot have a __tensor_flatten__ attribute, making it "
"impossible to count the bytes it contains. Falling back on regular count.",
category=UserWarning,
)
count_bytes(torch.as_tensor(tensor))
return
if type(tensor) in (Tensor, Parameter, Buffer):
pass
elif hasattr(tensor, "__tensor_flatten__"):
attrs, ctx = tensor.__tensor_flatten__()
for attr in attrs:
t = getattr(tensor, attr)
count_bytes(t)
return
else:
warnings.warn(
"The sub-tensor doesn't ot have a __tensor_flatten__ attribute, making it "
"impossible to count the bytes it contains. Falling back on regular count.",
category=UserWarning,
)
count_bytes(torch.as_tensor(tensor))
return

grad = getattr(tensor, "grad", None)
if grad is not None:
Expand Down
Loading