pytorch/tensordict #688 (Closed)
Labels: bug (Something isn't working)
Description
I am running into a new series of bugs in the BenchMARL library since #1921. I'll list them here as I unravel them.
The first one can be reproduced by running:
python examples/multiagent/sac.py model.shared_parameters=False
It seems to be related to the use of ensembles in SAC; the error is raised while constructing SACLoss:
Traceback (most recent call last):
  File "/Users/Matteo/PycharmProjects/torchrl/examples/multiagent/sac.py", line 193, in train
    loss_module = SACLoss(
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/objectives/sac.py", line 327, in __init__
    self.convert_to_functional(
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/objectives/common.py", line 292, in convert_to_functional
    with params.apply(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/utils.py", line 1152, in new_func
    out = func(_self, *args, **kwargs)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 595, in to_module
    return self._to_module(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 411, in _to_module
    local_out = value._to_module(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 411, in _to_module
    local_out = value._to_module(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 334, in _to_module
    module.update(self)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/utils.py", line 1114, in new_func
    return func(self, *args, **kwargs)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/nn/params.py", line 417, in update
    TensorDictBase.update(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2613, in update
    target.update(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2620, in update
    self._set_tuple(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1515, in _set_tuple
    return self._set_str(key[0], value, inplace=inplace, validated=validated)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1485, in _set_str
    value = self._validate_value(value, check_shape=True)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 4378, in _validate_value
    raise RuntimeError(
RuntimeError: batch dimension mismatch, got self.batch_size=torch.Size([3]) and value.shape=torch.Size([2, 3, 256, 54]).
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
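One possible reading of the shapes in the error (an interpretation, not confirmed in this issue): 2 is the size of the SAC Q-value ensemble, 3 is the number of agents with non-shared parameters, and [256, 54] is the weight of a single linear layer. The sketch uses a hypothetical check_batch_dims helper, which is a simplified stand-in and not tensordict's actual _validate_value. It shows why a value with the ensemble dimension prepended fails against a container whose batch size only covers the agent dimension.

```python
from typing import Sequence


def check_batch_dims(batch_size: Sequence[int], value_shape: Sequence[int]) -> None:
    """Hypothetical stand-in for the leading-dimension check behind the error.

    A value stored in a container with batch size ``batch_size`` must have a
    shape whose first ``len(batch_size)`` entries equal ``batch_size``.
    """
    n = len(batch_size)
    if tuple(value_shape[:n]) != tuple(batch_size):
        raise RuntimeError(
            f"batch dimension mismatch, got batch_size={tuple(batch_size)} "
            f"and value.shape={tuple(value_shape)}"
        )


# Assumed per-agent weight of one linear layer: 3 agents, 256 outputs, 54 inputs.
per_agent_weight = (3, 256, 54)
# The same weight with an assumed Q-value ensemble dimension of 2 prepended.
ensembled_weight = (2,) + per_agent_weight

check_batch_dims((3,), per_agent_weight)      # passes: leading dim 3 matches
check_batch_dims((2, 3), ensembled_weight)    # passes: ensemble dim is part of the batch

try:
    check_batch_dims((3,), ensembled_weight)  # mirrors the failing case in the traceback
except RuntimeError as err:
    print(err)
```

Under this reading, the destination of the update seems to expect batch size [3], while the incoming parameters still carry the extra ensemble dimension.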