Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 2 additions & 42 deletions tensordict/nn/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@
)
FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality."

import torch

from tensordict.nn.common import TensorDictModule
from tensordict.nn.probabilistic import ProbabilisticTensorDictModule
from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
from tensordict.utils import _normalize_key, NESTED_KEY
from torch import nn
Expand All @@ -49,7 +46,7 @@ class TensorDictSequential(TensorDictModule):
stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts
looking for those that have the required keys, if any.

TensorDictSequence supports functional, modular and vmap coding:
TensorDictSequential supports functional, modular and vmap coding:
Examples:
>>> import torch
>>> from tensordict import TensorDict
Expand Down Expand Up @@ -173,7 +170,7 @@ def select_subsequence(
out_keys: output keys of the subsequence we want to select

Returns:
A new TensorDictSequential with only the modules that are necessary acording to the given input and output keys.
A new TensorDictSequential with only the modules that are necessary according to the given input and output keys.
"""
if in_keys is None:
in_keys = deepcopy(self.in_keys)
Expand Down Expand Up @@ -257,40 +254,3 @@ def __setitem__(self, index: int, tensordict_module: TensorDictModule) -> None:

def __delitem__(self, index: Union[int, slice]) -> None:
    """Remove the entry (or entries) at ``index`` from the wrapped container.

    Args:
        index: an integer position or a slice; it is forwarded as the
            ``idx`` keyword of ``self.module.__delitem__``.
    """
    container = self.module
    container.__delitem__(idx=index)

def get_dist(
    self,
    tensordict: TensorDictBase,
    **kwargs,
) -> Tuple[torch.distributions.Distribution, ...]:
    """Return the distribution built by the last module of the sequence.

    All modules except the last one are applied to ``tensordict``; the result
    is then passed to the ``get_dist`` method of the last module.

    Args:
        tensordict: input fed to the leading modules of the sequence.
        **kwargs: rejected; any keyword argument reaching this method raises.

    Returns:
        Whatever ``get_dist`` of the last module returns.

    Raises:
        RuntimeError: if the last module is not a
            ``ProbabilisticTensorDictModule``, or if keyword arguments are
            passed.
    """
    # Guard: only a sequence ending in a probabilistic module can build a distribution.
    if not isinstance(self.module[-1], ProbabilisticTensorDictModule):
        module_types = [type(m) for m in self.module]
        raise RuntimeError(
            "Cannot call get_dist on a sequence of tensordicts that does not end with a probabilistic TensorDict. "
            f"The sequence items were of type: {module_types}"
        )
    # Guard: every keyword argument that reaches this point is refused.
    if kwargs:
        raise RuntimeError(
            "TensorDictSequential does not support keyword arguments other than 'params', 'buffers' and 'vmap'"
        )
    # Run everything but the last module, then delegate to the last one.
    intermediate = self[:-1](tensordict)
    return self[-1].get_dist(intermediate)

def get_dist_params(
    self,
    tensordict: TensorDictBase,
    **kwargs,
) -> Tuple[torch.distributions.Distribution, ...]:
    """Return the distribution parameters computed by the last module.

    All modules except the last one are applied to ``tensordict``; the result
    is then passed to the ``get_dist_params`` method of the last module.

    Args:
        tensordict: input fed to the leading modules of the sequence.
        **kwargs: rejected; any keyword argument reaching this method raises.

    Returns:
        Whatever ``get_dist_params`` of the last module returns.
        NOTE(review): the return annotation is identical to ``get_dist``'s and
        may be a copy-paste leftover -- confirm against
        ``ProbabilisticTensorDictModule.get_dist_params``.

    Raises:
        RuntimeError: if the last module is not a
            ``ProbabilisticTensorDictModule``, or if keyword arguments are
            passed.
    """
    # Guard: only a sequence ending in a probabilistic module exposes distribution parameters.
    if not isinstance(self.module[-1], ProbabilisticTensorDictModule):
        module_types = [type(m) for m in self.module]
        # The message names this method; it previously said ``get_dist``.
        raise RuntimeError(
            "Cannot call get_dist_params on a sequence of tensordicts that does not end with a probabilistic TensorDict. "
            f"The sequence items were of type: {module_types}"
        )
    # Guard: every keyword argument that reaches this point is refused.
    if kwargs:
        raise RuntimeError(
            "TensorDictSequential does not support keyword arguments."
        )
    # Run everything but the last module, then delegate to the last one.
    intermediate = self[:-1](tensordict)
    return self[-1].get_dist_params(intermediate)
138 changes: 68 additions & 70 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,6 @@ def test_stateful(self, lazy):
assert td.shape == torch.Size([3])
assert td.get("out").shape == torch.Size([3, 4])

with pytest.raises(RuntimeError, match="Cannot call get_dist on a sequence"):
dist, *_ = tdmodule.get_dist(td)

@pytest.mark.parametrize("lazy", [True, False])
def test_stateful_probabilistic(self, lazy):
torch.manual_seed(0)
Expand All @@ -374,36 +371,36 @@ def test_stateful_probabilistic(self, lazy):
dummy_net = nn.Linear(4, 4)
net2 = nn.Linear(4, 4 * param_multiplier)
net2 = NormalParamWrapper(net2)
net2 = TensorDictModule(
module=net2, in_keys=["hidden"], out_keys=["loc", "scale"]
)

kwargs = {"distribution_class": Normal}
tdmodule1 = TensorDictModule(net1, in_keys=["in"], out_keys=["hidden"])
dummy_tdmodule = TensorDictModule(
dummy_net, in_keys=["hidden"], out_keys=["hidden"]
)
tdmodule2 = ProbabilisticTensorDictModule(
module=net2,
tdmodule2 = TensorDictModule(
net2, in_keys=["hidden"], out_keys=["loc", "scale"]
)

tdmodule = ProbabilisticTensorDictModule(
module=TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2),
dist_in_keys=["loc", "scale"],
sample_out_key=["out"],
**kwargs,
)
tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2)

assert hasattr(tdmodule, "__setitem__")
assert len(tdmodule) == 3
tdmodule[1] = tdmodule2
assert len(tdmodule) == 3
assert hasattr(tdmodule.module, "__setitem__")
assert len(tdmodule.module) == 3
tdmodule.module[1] = tdmodule2
assert len(tdmodule.module) == 3

assert hasattr(tdmodule, "__delitem__")
assert len(tdmodule) == 3
del tdmodule[2]
assert len(tdmodule) == 2
assert hasattr(tdmodule.module, "__delitem__")
assert len(tdmodule.module) == 3
del tdmodule.module[2]
assert len(tdmodule.module) == 2

assert hasattr(tdmodule, "__getitem__")
assert tdmodule[0] is tdmodule1
assert tdmodule[1] is tdmodule2
assert hasattr(tdmodule.module, "__getitem__")
assert tdmodule.module[0] is tdmodule1
assert tdmodule.module[1] is tdmodule2

td = TensorDict({"in": torch.randn(3, 3)}, [3])
tdmodule(td)
Expand Down Expand Up @@ -454,9 +451,6 @@ def test_functional(self):
assert td.shape == torch.Size([3])
assert td.get("out").shape == torch.Size([3, 4])

with pytest.raises(RuntimeError, match="Cannot call get_dist on a sequence"):
dist, *_ = tdmodule.get_dist(td, params=params)

@pytest.mark.skipif(
not _has_functorch, reason=f"functorch not found: err={FUNCTORCH_ERR}"
)
Expand All @@ -482,9 +476,6 @@ def test_functional_functorch(self):
assert td.shape == torch.Size([3])
assert td.get("out").shape == torch.Size([3, 4])

with pytest.raises(RuntimeError, match="Cannot call get_dist on a sequence"):
dist, *_ = tdmodule.get_dist(td, params=params)

@pytest.mark.skipif(
not _has_functorch, reason=f"functorch not found: err={FUNCTORCH_ERR}"
)
Expand All @@ -497,38 +488,40 @@ def test_functional_probabilistic(self):
net2 = nn.Linear(4, 4 * param_multiplier)
net2 = NormalParamWrapper(net2)

net2 = TensorDictModule(
module=net2, in_keys=["hidden"], out_keys=["loc", "scale"]
)

kwargs = {"distribution_class": Normal}

tdmodule1 = TensorDictModule(net1, in_keys=["in"], out_keys=["hidden"])
dummy_tdmodule = TensorDictModule(
dummy_net, in_keys=["hidden"], out_keys=["hidden"]
)
tdmodule2 = ProbabilisticTensorDictModule(
net2, dist_in_keys=["loc", "scale"], sample_out_key=["out"], **kwargs
tdmodule2 = TensorDictModule(
net2, in_keys=["hidden"], out_keys=["loc", "scale"]
)

tdmodule = ProbabilisticTensorDictModule(
module=TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2),
dist_in_keys=["loc", "scale"],
sample_out_key=["out"],
**kwargs,
)
tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2)

params = make_functional(tdmodule, funs_to_decorate=["forward", "get_dist"])

assert hasattr(tdmodule, "__setitem__")
assert len(tdmodule) == 3
tdmodule[1] = tdmodule2
params["module", "1"] = params["module", "2"]
assert len(tdmodule) == 3
assert hasattr(tdmodule.module, "__setitem__")
assert len(tdmodule.module) == 3
tdmodule.module[1] = tdmodule2
params["module", "module", "1"] = params["module", "module", "2"]
assert len(tdmodule.module) == 3

assert hasattr(tdmodule, "__delitem__")
assert len(tdmodule) == 3
del tdmodule[2]
del params["module", "2"]
assert len(tdmodule) == 2
assert hasattr(tdmodule.module, "__delitem__")
assert len(tdmodule.module) == 3
del tdmodule.module[2]
del params["module", "module", "2"]
assert len(tdmodule.module) == 2

assert hasattr(tdmodule, "__getitem__")
assert tdmodule[0] is tdmodule1
assert tdmodule[1] is tdmodule2
assert hasattr(tdmodule.module, "__getitem__")
assert tdmodule.module[0] is tdmodule1
assert tdmodule.module[1] is tdmodule2

td = TensorDict({"in": torch.randn(3, 3)}, [3])
tdmodule(td, params=params)
Expand Down Expand Up @@ -579,9 +572,6 @@ def test_functional_with_buffer(self):
td = TensorDict({"in": torch.randn(3, 7)}, [3])
tdmodule(td, params=params)

with pytest.raises(RuntimeError, match="Cannot call get_dist on a sequence"):
dist, *_ = tdmodule.get_dist(td, params=params)

assert td.shape == torch.Size([3])
assert td.get("out").shape == torch.Size([3, 7])

Expand All @@ -598,35 +588,39 @@ def test_functional_with_buffer_probabilistic(self):
nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier)
)
net2 = NormalParamWrapper(net2)
net2 = TensorDictModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"])

kwargs = {"distribution_class": Normal}
tdmodule1 = TensorDictModule(net1, in_keys=["in"], out_keys=["hidden"])
dummy_tdmodule = TensorDictModule(
dummy_net, in_keys=["hidden"], out_keys=["hidden"]
)
tdmodule2 = ProbabilisticTensorDictModule(
net2, dist_in_keys=["loc", "scale"], sample_out_key=["out"], **kwargs
tdmodule2 = TensorDictModule(
net2, in_keys=["hidden"], out_keys=["loc", "scale"]
)
tdmodule = ProbabilisticTensorDictModule(
module=TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2),
dist_in_keys=["loc", "scale"],
sample_out_key=["out"],
**kwargs,
)
tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2)

params = make_functional(tdmodule, ["forward", "get_dist"])

assert hasattr(tdmodule, "__setitem__")
assert len(tdmodule) == 3
tdmodule[1] = tdmodule2
params["module", "1"] = params["module", "2"]
assert len(tdmodule) == 3
assert hasattr(tdmodule.module, "__setitem__")
assert len(tdmodule.module) == 3
tdmodule.module[1] = tdmodule2
params["module", "module", "1"] = params["module", "module", "2"]
assert len(tdmodule.module) == 3

assert hasattr(tdmodule, "__delitem__")
assert len(tdmodule) == 3
del tdmodule[2]
del params["module", "2"]
assert len(tdmodule) == 2
assert hasattr(tdmodule.module, "__delitem__")
assert len(tdmodule.module) == 3
del tdmodule.module[2]
del params["module", "module", "2"]
assert len(tdmodule.module) == 2

assert hasattr(tdmodule, "__getitem__")
assert tdmodule[0] is tdmodule1
assert tdmodule[1] is tdmodule2
assert hasattr(tdmodule.module, "__getitem__")
assert tdmodule.module[0] is tdmodule1
assert tdmodule.module[1] is tdmodule2

td = TensorDict({"in": torch.randn(3, 7)}, [3])
tdmodule(td, params=params)
Expand Down Expand Up @@ -700,14 +694,18 @@ def test_vmap_probabilistic(self):

net2 = nn.Linear(4, 4 * param_multiplier)
net2 = NormalParamWrapper(net2)
net2 = TensorDictModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"])

kwargs = {"distribution_class": Normal}
tdmodule1 = TensorDictModule(net1, in_keys=["in"], out_keys=["hidden"])
tdmodule2 = ProbabilisticTensorDictModule(
net2, sample_out_key=["out"], dist_in_keys=["loc", "scale"], **kwargs
tdmodule2 = TensorDictModule(
net2, in_keys=["hidden"], out_keys=["loc", "scale"]
)
tdmodule = ProbabilisticTensorDictModule(
module=TensorDictSequential(tdmodule1, tdmodule2),
sample_out_key=["out"],
dist_in_keys=["loc", "scale"],
**kwargs,
)
tdmodule = TensorDictSequential(tdmodule1, tdmodule2)

params = make_functional(tdmodule)

Expand Down