From 5bf16b1925035bafbd308ec2038e11401b9aee53 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 24 Oct 2024 12:22:50 -0700
Subject: [PATCH] Update

[ghstack-poisoned]
---
 tensordict/base.py | 33 +++++++++++++++++----------------
 1 file changed, 17 insertions(+), 16 deletions(-)

diff --git a/tensordict/base.py b/tensordict/base.py
index 358cae1b1..a7b65c04f 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -94,7 +94,7 @@
     unravel_key_list,
 )
 from torch import distributed as dist, multiprocessing as mp, nn, Tensor
-from torch.nn.parameter import UninitializedTensorMixin
+from torch.nn.parameter import Parameter, UninitializedTensorMixin
 from torch.utils._pytree import tree_map
 
 try:
@@ -3196,21 +3196,22 @@ def count_bytes(tensor):
             if isinstance(tensor, MemoryMappedTensor):
                 add(tensor)
                 return
-            if type(tensor) is not torch.Tensor:
-                try:
-                    attrs, ctx = tensor.__tensor_flatten__()
-                    for attr in attrs:
-                        t = getattr(tensor, attr)
-                        count_bytes(t)
-                    return
-                except AttributeError:
-                    warnings.warn(
-                        "The sub-tensor doesn't ot have a __tensor_flatten__ attribute, making it "
-                        "impossible to count the bytes it contains. Falling back on regular count.",
-                        category=UserWarning,
-                    )
-                count_bytes(torch.as_tensor(tensor))
-                return
+            if type(tensor) in (Tensor, Parameter, Buffer):
+                pass
+            elif hasattr(tensor, "__tensor_flatten__"):
+                attrs, ctx = tensor.__tensor_flatten__()
+                for attr in attrs:
+                    t = getattr(tensor, attr)
+                    count_bytes(t)
+                return
+            else:
+                warnings.warn(
+                    "The sub-tensor doesn't have a __tensor_flatten__ attribute, making it "
+                    "impossible to count the bytes it contains. Falling back on regular count.",
+                    category=UserWarning,
+                )
+                count_bytes(torch.as_tensor(tensor))
+                return
             grad = getattr(tensor, "grad", None)
             if grad is not None:
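
Note (not part of the patch): below is a minimal, self-contained sketch of the traversal the new branch implements, assuming the standard tensor-subclass protocol in which __tensor_flatten__() returns the attribute names of the inner tensors plus an opaque context object. The count_bytes helper and its seen set here are illustrative stand-ins, not tensordict's actual implementation, which additionally special-cases MemoryMappedTensor and Buffer and supports duplicate counting.

    import warnings

    import torch
    from torch.nn.parameter import Parameter


    def count_bytes(tensor, seen=None):
        """Recursively sum storage bytes, unwrapping tensor subclasses."""
        if seen is None:
            seen = set()
        if type(tensor) in (torch.Tensor, Parameter):
            # Plain tensors and Parameters are counted directly below.
            pass
        elif hasattr(tensor, "__tensor_flatten__"):
            # Tensor subclasses expose their inner tensors through the
            # __tensor_flatten__ protocol: recurse into each named attribute.
            attrs, _ctx = tensor.__tensor_flatten__()
            return sum(count_bytes(getattr(tensor, attr), seen) for attr in attrs)
        else:
            warnings.warn(
                "The sub-tensor doesn't have a __tensor_flatten__ attribute; "
                "falling back on a regular count.",
                category=UserWarning,
            )
            tensor = torch.as_tensor(tensor)
        # Count each underlying storage once, so views sharing a storage
        # aren't double-counted.
        storage = tensor.untyped_storage()
        if storage.data_ptr() in seen:
            return 0
        seen.add(storage.data_ptr())
        total = storage.nbytes()
        grad = getattr(tensor, "grad", None)
        if grad is not None:
            total += count_bytes(grad, seen)
        return total


    print(count_bytes(Parameter(torch.randn(4, 4))))  # 64: 16 fp32 elements

Checking type(tensor) against the concrete classes (rather than the removed `type(tensor) is not torch.Tensor` test) is what lets Parameter and Buffer take the fast path, while genuine subclasses such as DTensor go through __tensor_flatten__ instead of raising inside a try/except.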