diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py
index cc11ac8b29e..6b6fd391560 100644
--- a/torchrl/objectives/common.py
+++ b/torchrl/objectives/common.py
@@ -12,7 +12,7 @@ from typing import Iterator, List, Optional, Tuple
 
 import torch
-from tensordict import TensorDict, TensorDictBase
+from tensordict import is_tensor_collection, TensorDict, TensorDictBase
 from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams
 from torch import nn
 
@@ -248,6 +248,13 @@ def convert_to_functional(
             # For buffers, a cloned expansion (or equivalently a repeat)
             #  is returned.
             def _compare_and_expand(param):
+                if is_tensor_collection(param):
+                    return param._apply_nest(
+                        _compare_and_expand,
+                        batch_size=[expand_dim, *param.shape],
+                        filter_empty=False,
+                        call_on_nested=True,
+                    )
                 if not isinstance(param, nn.Parameter):
                     buffer = param.expand(expand_dim, *param.shape).clone()
                     return buffer
@@ -257,7 +264,7 @@ def _compare_and_expand(param):
                     # is called:
                     return expanded_param
                 else:
-                    p_out = param.repeat(expand_dim, *[1 for _ in param.shape])
+                    p_out = param.expand(expand_dim, *param.shape).clone()
                     p_out = nn.Parameter(
                         p_out.uniform_(
                             p_out.min().item(), p_out.max().item()
@@ -270,6 +277,7 @@ def _compare_and_expand(param):
                     _compare_and_expand,
                     batch_size=[expand_dim, *params.shape],
                     filter_empty=False,
+                    call_on_nested=True,
                 ),
                 no_convert=True,
             )