diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index 136cbc237..8a16f1a56 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -21,10 +21,7 @@ ) FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality." -import torch - from tensordict.nn.common import TensorDictModule -from tensordict.nn.probabilistic import ProbabilisticTensorDictModule from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase from tensordict.utils import _normalize_key, NESTED_KEY from torch import nn @@ -49,7 +46,7 @@ class TensorDictSequential(TensorDictModule): stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts looking for those that have the required keys, if any. - TensorDictSequence supports functional, modular and vmap coding: + TensorDictSequential supports functional, modular and vmap coding: Examples: >>> import torch >>> from tensordict import TensorDict @@ -173,7 +170,7 @@ def select_subsequence( out_keys: output keys of the subsequence we want to select Returns: - A new TensorDictSequential with only the modules that are necessary acording to the given input and output keys. + A new TensorDictSequential with only the modules that are necessary according to the given input and output keys. """ if in_keys is None: in_keys = deepcopy(self.in_keys) @@ -257,40 +254,3 @@ def __setitem__(self, index: int, tensordict_module: TensorDictModule) -> None: def __delitem__(self, index: Union[int, slice]) -> None: self.module.__delitem__(idx=index) - - def get_dist( - self, - tensordict: TensorDictBase, - **kwargs, - ) -> Tuple[torch.distributions.Distribution, ...]: - if isinstance(self.module[-1], ProbabilisticTensorDictModule): - if kwargs: - raise RuntimeError( - "TensorDictSequential does not support keyword arguments other than 'params', 'buffers' and 'vmap'" - ) - tensordict = self[:-1](tensordict) - out = self[-1].get_dist(tensordict) - return out - else: - raise RuntimeError( - "Cannot call get_dist on a sequence of tensordicts that does not end with a probabilistic TensorDict. " - f"The sequence items were of type: {[type(m) for m in self.module]}" - ) - - def get_dist_params( - self, - tensordict: TensorDictBase, - **kwargs, - ) -> Tuple[torch.distributions.Distribution, ...]: - if isinstance(self.module[-1], ProbabilisticTensorDictModule): - if kwargs: - raise RuntimeError( - "TensorDictSequential does not support keyword arguments." - ) - tensordict = self[:-1](tensordict) - return self[-1].get_dist_params(tensordict) - else: - raise RuntimeError( - "Cannot call get_dist on a sequence of tensordicts that does not end with a probabilistic TensorDict. " - f"The sequence items were of type: {[type(m) for m in self.module]}" - ) diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 5af8e8758..657b4d349 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -358,9 +358,6 @@ def test_stateful(self, lazy): assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) - with pytest.raises(RuntimeError, match="Cannot call get_dist on a sequence"): - dist, *_ = tdmodule.get_dist(td) - @pytest.mark.parametrize("lazy", [True, False]) def test_stateful_probabilistic(self, lazy): torch.manual_seed(0) @@ -374,36 +371,36 @@ def test_stateful_probabilistic(self, lazy): dummy_net = nn.Linear(4, 4) net2 = nn.Linear(4, 4 * param_multiplier) net2 = NormalParamWrapper(net2) - net2 = TensorDictModule( - module=net2, in_keys=["hidden"], out_keys=["loc", "scale"] - ) kwargs = {"distribution_class": Normal} tdmodule1 = TensorDictModule(net1, in_keys=["in"], out_keys=["hidden"]) dummy_tdmodule = TensorDictModule( dummy_net, in_keys=["hidden"], out_keys=["hidden"] ) - tdmodule2 = ProbabilisticTensorDictModule( - module=net2, + tdmodule2 = TensorDictModule( + net2, in_keys=["hidden"], out_keys=["loc", "scale"] + ) + + tdmodule = ProbabilisticTensorDictModule( + module=TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2), dist_in_keys=["loc", "scale"], sample_out_key=["out"], **kwargs, ) - tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 3 - tdmodule[1] = tdmodule2 - assert len(tdmodule) == 3 + assert hasattr(tdmodule.module, "__setitem__") + assert len(tdmodule.module) == 3 + tdmodule.module[1] = tdmodule2 + assert len(tdmodule.module) == 3 - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 3 - del tdmodule[2] - assert len(tdmodule) == 2 + assert hasattr(tdmodule.module, "__delitem__") + assert len(tdmodule.module) == 3 + del tdmodule.module[2] + assert len(tdmodule.module) == 2 - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 + assert hasattr(tdmodule.module, "__getitem__") + assert tdmodule.module[0] is tdmodule1 + assert tdmodule.module[1] is tdmodule2 td = TensorDict({"in": torch.randn(3, 3)}, [3]) tdmodule(td) @@ -454,9 +451,6 @@ def test_functional(self): assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) - with pytest.raises(RuntimeError, match="Cannot call get_dist on a sequence"): - dist, *_ = tdmodule.get_dist(td, params=params) - @pytest.mark.skipif( not _has_functorch, reason=f"functorch not found: err={FUNCTORCH_ERR}" ) @@ -482,9 +476,6 @@ def test_functional_functorch(self): assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) - with pytest.raises(RuntimeError, match="Cannot call get_dist on a sequence"): - dist, *_ = tdmodule.get_dist(td, params=params) - @pytest.mark.skipif( not _has_functorch, reason=f"functorch not found: err={FUNCTORCH_ERR}" ) @@ -497,38 +488,40 @@ def test_functional_probabilistic(self): net2 = nn.Linear(4, 4 * param_multiplier) net2 = NormalParamWrapper(net2) - net2 = TensorDictModule( - module=net2, in_keys=["hidden"], out_keys=["loc", "scale"] - ) - kwargs = {"distribution_class": Normal} tdmodule1 = TensorDictModule(net1, in_keys=["in"], out_keys=["hidden"]) dummy_tdmodule = TensorDictModule( dummy_net, in_keys=["hidden"], out_keys=["hidden"] ) - tdmodule2 = ProbabilisticTensorDictModule( - net2, dist_in_keys=["loc", "scale"], sample_out_key=["out"], **kwargs + tdmodule2 = TensorDictModule( + net2, in_keys=["hidden"], out_keys=["loc", "scale"] + ) + + tdmodule = ProbabilisticTensorDictModule( + module=TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2), + dist_in_keys=["loc", "scale"], + sample_out_key=["out"], + **kwargs, ) - tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) params = make_functional(tdmodule, funs_to_decorate=["forward", "get_dist"]) - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 3 - tdmodule[1] = tdmodule2 - params["module", "1"] = params["module", "2"] - assert len(tdmodule) == 3 + assert hasattr(tdmodule.module, "__setitem__") + assert len(tdmodule.module) == 3 + tdmodule.module[1] = tdmodule2 + params["module", "module", "1"] = params["module", "module", "2"] + assert len(tdmodule.module) == 3 - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 3 - del tdmodule[2] - del params["module", "2"] - assert len(tdmodule) == 2 + assert hasattr(tdmodule.module, "__delitem__") + assert len(tdmodule.module) == 3 + del tdmodule.module[2] + del params["module", "module", "2"] + assert len(tdmodule.module) == 2 - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 + assert hasattr(tdmodule.module, "__getitem__") + assert tdmodule.module[0] is tdmodule1 + assert tdmodule.module[1] is tdmodule2 td = TensorDict({"in": torch.randn(3, 3)}, [3]) tdmodule(td, params=params) @@ -579,9 +572,6 @@ def test_functional_with_buffer(self): td = TensorDict({"in": torch.randn(3, 7)}, [3]) tdmodule(td, params=params) - with pytest.raises(RuntimeError, match="Cannot call get_dist on a sequence"): - dist, *_ = tdmodule.get_dist(td, params=params) - assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 7]) @@ -598,35 +588,39 @@ def test_functional_with_buffer_probabilistic(self): nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier) ) net2 = NormalParamWrapper(net2) - net2 = TensorDictModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"]) kwargs = {"distribution_class": Normal} tdmodule1 = TensorDictModule(net1, in_keys=["in"], out_keys=["hidden"]) dummy_tdmodule = TensorDictModule( dummy_net, in_keys=["hidden"], out_keys=["hidden"] ) - tdmodule2 = ProbabilisticTensorDictModule( - net2, dist_in_keys=["loc", "scale"], sample_out_key=["out"], **kwargs + tdmodule2 = TensorDictModule( + net2, in_keys=["hidden"], out_keys=["loc", "scale"] + ) + tdmodule = ProbabilisticTensorDictModule( + module=TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2), + dist_in_keys=["loc", "scale"], + sample_out_key=["out"], + **kwargs, ) - tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) params = make_functional(tdmodule, ["forward", "get_dist"]) - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 3 - tdmodule[1] = tdmodule2 - params["module", "1"] = params["module", "2"] - assert len(tdmodule) == 3 + assert hasattr(tdmodule.module, "__setitem__") + assert len(tdmodule.module) == 3 + tdmodule.module[1] = tdmodule2 + params["module", "module", "1"] = params["module", "module", "2"] + assert len(tdmodule.module) == 3 - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 3 - del tdmodule[2] - del params["module", "2"] - assert len(tdmodule) == 2 + assert hasattr(tdmodule.module, "__delitem__") + assert len(tdmodule.module) == 3 + del tdmodule.module[2] + del params["module", "module", "2"] + assert len(tdmodule.module) == 2 - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 + assert hasattr(tdmodule.module, "__getitem__") + assert tdmodule.module[0] is tdmodule1 + assert tdmodule.module[1] is tdmodule2 td = TensorDict({"in": torch.randn(3, 7)}, [3]) tdmodule(td, params=params) @@ -700,14 +694,18 @@ def test_vmap_probabilistic(self): net2 = nn.Linear(4, 4 * param_multiplier) net2 = NormalParamWrapper(net2) - net2 = TensorDictModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"]) kwargs = {"distribution_class": Normal} tdmodule1 = TensorDictModule(net1, in_keys=["in"], out_keys=["hidden"]) - tdmodule2 = ProbabilisticTensorDictModule( - net2, sample_out_key=["out"], dist_in_keys=["loc", "scale"], **kwargs + tdmodule2 = TensorDictModule( + net2, in_keys=["hidden"], out_keys=["loc", "scale"] + ) + tdmodule = ProbabilisticTensorDictModule( + module=TensorDictSequential(tdmodule1, tdmodule2), + sample_out_key=["out"], + dist_in_keys=["loc", "scale"], + **kwargs, ) - tdmodule = TensorDictSequential(tdmodule1, tdmodule2) params = make_functional(tdmodule)