
[BUG] Multiagent nets problems with SAC #1957

@matteobettini


I am experiencing a new series of bugs in the BenchMARL library after #1921.

I'll try to list them here as I unravel them.

The first one can be reproduced by running:

python examples/multiagent/sac.py model.shared_parameters=False

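This seems to be related to the use of ensembles in SAC. In the error below, the value being written has shape [2, 3, 256, 54], while the parameter TensorDict it is written into has batch size [3]. Presumably the leading 2 is the Q-value ensemble dimension, placed in front of the agent dimension (3) that the non-shared parameters expect.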

Traceback (most recent call last):
  File "/Users/Matteo/PycharmProjects/torchrl/examples/multiagent/sac.py", line 193, in train
    loss_module = SACLoss(
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/objectives/sac.py", line 327, in __init__
    self.convert_to_functional(
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/objectives/common.py", line 292, in convert_to_functional
    with params.apply(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/utils.py", line 1152, in new_func
    out = func(_self, *args, **kwargs)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 595, in to_module
    return self._to_module(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 411, in _to_module
    local_out = value._to_module(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 411, in _to_module
    local_out = value._to_module(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 334, in _to_module
    module.update(self)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/utils.py", line 1114, in new_func
    return func(self, *args, **kwargs)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/nn/params.py", line 417, in update
    TensorDictBase.update(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2613, in update
    target.update(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2620, in update
    self._set_tuple(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1515, in _set_tuple
    return self._set_str(key[0], value, inplace=inplace, validated=validated)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1485, in _set_str
    value = self._validate_value(value, check_shape=True)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 4378, in _validate_value
    raise RuntimeError(
RuntimeError: batch dimension mismatch, got self.batch_size=torch.Size([3]) and value.shape=torch.Size([2, 3, 256, 54]).

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
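The shape arithmetic behind this error can be sketched in plain Python. This is a simplified model, not tensordict's actual implementation: it only mimics the rule that a value stored in a TensorDict must have leading dimensions equal to the TensorDict's batch size. The names `N_AGENTS`, `NUM_QVALUE_NETS`, `ensemble_expand` and `validate_value` are hypothetical, and reading 3 as the agent dimension and 2 as the Q-value ensemble size is an inference from the error message, not something confirmed in the code.

```python
# Simplified, pure-Python model of the batch-size check that fails above.
# NOT tensordict's code: it only mimics the rule that a value's shape must
# start with the container's batch_size.

def leading_dims_match(batch_size, value_shape):
    """Return True if value_shape starts with batch_size."""
    n = len(batch_size)
    return tuple(value_shape[:n]) == tuple(batch_size)


def validate_value(batch_size, value_shape):
    """Raise a RuntimeError worded like the one in the traceback."""
    if not leading_dims_match(batch_size, value_shape):
        raise RuntimeError(
            f"batch dimension mismatch, got batch_size={tuple(batch_size)} "
            f"and value.shape={tuple(value_shape)}."
        )


def ensemble_expand(shape, num_nets):
    """Prepend an ensemble dimension, as stacking num_nets copies would."""
    return (num_nets, *shape)


# Hypothetical shapes, read off the error message (assumption):
N_AGENTS = 3          # batch_size of the per-agent parameter TensorDict
NUM_QVALUE_NETS = 2   # assumed size of the SAC Q-value ensemble
per_agent_weight = (N_AGENTS, 256, 54)

expanded = ensemble_expand(per_agent_weight, NUM_QVALUE_NETS)
# expanded is (2, 3, 256, 54): the ensemble dim sits in front of the agent dim

try:
    validate_value((N_AGENTS,), expanded)
except RuntimeError as err:
    print(err)

# If the container's batch size also carried the ensemble dimension,
# the same value would pass the check without raising.
validate_value((NUM_QVALUE_NETS, N_AGENTS), expanded)
```

Under these assumptions, the mismatch disappears only if the target container's batch size includes the ensemble dimension as well as the agent dimension, which is consistent with (but does not prove) the hypothesis that the ensemble expansion is the culprit.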
