Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,9 +692,6 @@ def test_stateful(self, safe, spec_type, lazy):
assert td.shape == torch.Size([3])
assert td.get("out").shape == torch.Size([3, 4])

with pytest.raises(RuntimeError, match="Cannot call get_dist on a sequence"):
dist, *_ = tdmodule.get_dist(td)

# test bounds
if not safe and spec_type == "bounded":
assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any()
Expand Down Expand Up @@ -856,9 +853,6 @@ def test_functional(self, safe, spec_type):
assert td.shape == torch.Size([3])
assert td.get("out").shape == torch.Size([3, 4])

with pytest.raises(RuntimeError, match="Cannot call get_dist on a sequence"):
dist, *_ = tdmodule.get_dist(td, params=params)

# test bounds
if not safe and spec_type == "bounded":
assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any()
Expand Down Expand Up @@ -1012,9 +1006,6 @@ def test_functional_with_buffer(self, safe, spec_type):
td = TensorDict({"in": torch.randn(3, 7)}, [3])
tdmodule(td, params=params)

with pytest.raises(RuntimeError, match="Cannot call get_dist on a sequence"):
dist, *_ = tdmodule.get_dist(td, params=params)

assert td.shape == torch.Size([3])
assert td.get("out").shape == torch.Size([3, 7])

Expand Down
8 changes: 4 additions & 4 deletions torchrl/modules/tensordict_module/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import warnings
from typing import Optional, Sequence, Type, Union

from tensordict.nn import TensorDictModule
from tensordict.nn.prototype import (
from tensordict.nn import (
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
TensorDictModule,
)
from tensordict.tensordict import TensorDictBase

Expand All @@ -20,7 +20,7 @@


class SafeProbabilisticModule(ProbabilisticTensorDictModule):
"""A :obj:``SafeProbabilisticModule`` is an :obj:``tensordict.nn.prototype.ProbabilisticTensorDictModule`` subclass that accepts a :obj:``TensorSpec`` as argument to control the output domain.
"""A :obj:``SafeProbabilisticModule`` is an :obj:``tensordict.nn.ProbabilisticTensorDictModule`` subclass that accepts a :obj:``TensorSpec`` as argument to control the output domain.

`SafeProbabilisticModule` is a non-parametric module representing a
probability distribution. It reads the distribution parameters from an input
Expand Down Expand Up @@ -190,7 +190,7 @@ def random_sample(self, tensordict: TensorDictBase) -> TensorDictBase:


class SafeProbabilisticSequential(ProbabilisticTensorDictSequential, SafeSequential):
"""A :obj:``SafeProbabilisticSequential`` is an :obj:``tensordict.nn.prototype.ProbabilisticTensorDictSequential`` subclass that accepts a :obj:``TensorSpec`` as argument to control the output domain.
"""A :obj:``SafeProbabilisticSequential`` is an :obj:``tensordict.nn.ProbabilisticTensorDictSequential`` subclass that accepts a :obj:``TensorSpec`` as argument to control the output domain.

Similarly to :obj:`TensorDictSequential`, but enforces that the final module in the
sequence is an :obj:`ProbabilisticTensorDictModule` and also exposes ``get_dist``
Expand Down
2 changes: 1 addition & 1 deletion tutorials/sphinx-tutorials/tensordict_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def forward(self, x):
print("the output tensordict shape is: ", result_td.shape)


from tensordict.nn.prototype import (
from tensordict.nn import (
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
)
Expand Down
2 changes: 1 addition & 1 deletion tutorials/sphinx-tutorials/torchrl_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@
tensordict = TensorDict({"obs": torch.randn(5)}, batch_size=[])
actor(tensordict) # action is the default value

from tensordict.nn.prototype import (
from tensordict.nn import (
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
)
Expand Down