Skip to content

Commit

Permalink
[BugFix] Fix batch-size expansion in functionalization (#1959)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 23, 2024
1 parent 492091a commit 7782751
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Iterator, List, Optional, Tuple

import torch
from tensordict import TensorDict, TensorDictBase
from tensordict import is_tensor_collection, TensorDict, TensorDictBase

from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams
from torch import nn
Expand Down Expand Up @@ -248,6 +248,13 @@ def convert_to_functional(
# For buffers, a cloned expansion (or equivalently a repeat) is returned.

def _compare_and_expand(param):
if is_tensor_collection(param):
return param._apply_nest(
_compare_and_expand,
batch_size=[expand_dim, *param.shape],
filter_empty=False,
call_on_nested=True,
)
if not isinstance(param, nn.Parameter):
buffer = param.expand(expand_dim, *param.shape).clone()
return buffer
Expand All @@ -257,7 +264,7 @@ def _compare_and_expand(param):
# is called:
return expanded_param
else:
p_out = param.repeat(expand_dim, *[1 for _ in param.shape])
p_out = param.expand(expand_dim, *param.shape).clone()
p_out = nn.Parameter(
p_out.uniform_(
p_out.min().item(), p_out.max().item()
Expand All @@ -270,6 +277,7 @@ def _compare_and_expand(param):
_compare_and_expand,
batch_size=[expand_dim, *params.shape],
filter_empty=False,
call_on_nested=True,
),
no_convert=True,
)
Expand Down

0 comments on commit 7782751

Please sign in to comment.