diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 619b04c700e..b7e1708bb77 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -692,9 +692,6 @@ def test_stateful(self, safe, spec_type, lazy): assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) - with pytest.raises(RuntimeError, match="Cannot call get_dist on a sequence"): - dist, *_ = tdmodule.get_dist(td) - # test bounds if not safe and spec_type == "bounded": assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() @@ -856,9 +853,6 @@ def test_functional(self, safe, spec_type): assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) - with pytest.raises(RuntimeError, match="Cannot call get_dist on a sequence"): - dist, *_ = tdmodule.get_dist(td, params=params) - # test bounds if not safe and spec_type == "bounded": assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() @@ -1012,9 +1006,6 @@ def test_functional_with_buffer(self, safe, spec_type): td = TensorDict({"in": torch.randn(3, 7)}, [3]) tdmodule(td, params=params) - with pytest.raises(RuntimeError, match="Cannot call get_dist on a sequence"): - dist, *_ = tdmodule.get_dist(td, params=params) - assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 7]) diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 7e4bc6a68ca..a399edde0b5 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -6,10 +6,10 @@ import warnings from typing import Optional, Sequence, Type, Union -from tensordict.nn import TensorDictModule -from tensordict.nn.prototype import ( +from tensordict.nn import ( ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, + TensorDictModule, ) from tensordict.tensordict import TensorDictBase @@ -20,7 +20,7 @@ class SafeProbabilisticModule(ProbabilisticTensorDictModule): - """A :obj:``SafeProbabilisticModule`` is an :obj:``tensordict.nn.prototype.ProbabilisticTensorDictModule`` subclass that accepts a :obj:``TensorSpec`` as argument to control the output domain. + """A :obj:``SafeProbabilisticModule`` is an :obj:``tensordict.nn.ProbabilisticTensorDictModule`` subclass that accepts a :obj:``TensorSpec`` as argument to control the output domain. `SafeProbabilisticModule` is a non-parametric module representing a probability distribution. It reads the distribution parameters from an input @@ -190,7 +190,7 @@ def random_sample(self, tensordict: TensorDictBase) -> TensorDictBase: class SafeProbabilisticSequential(ProbabilisticTensorDictSequential, SafeSequential): - """A :obj:``SafeProbabilisticSequential`` is an :obj:``tensordict.nn.prototype.ProbabilisticTensorDictSequential`` subclass that accepts a :obj:``TensorSpec`` as argument to control the output domain. + """A :obj:``SafeProbabilisticSequential`` is an :obj:``tensordict.nn.ProbabilisticTensorDictSequential`` subclass that accepts a :obj:``TensorSpec`` as argument to control the output domain. Similarly to :obj:`TensorDictSequential`, but enforces that the final module in the sequence is an :obj:`ProbabilisticTensorDictModule` and also exposes ``get_dist`` diff --git a/tutorials/sphinx-tutorials/tensordict_module.py b/tutorials/sphinx-tutorials/tensordict_module.py index 442a2f242d3..7e8b9dfef1d 100644 --- a/tutorials/sphinx-tutorials/tensordict_module.py +++ b/tutorials/sphinx-tutorials/tensordict_module.py @@ -189,7 +189,7 @@ def forward(self, x): print("the output tensordict shape is: ", result_td.shape) -from tensordict.nn.prototype import ( +from tensordict.nn import ( ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, ) diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 27a7c97e3a6..eb4a97a4cc3 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -505,7 +505,7 @@ tensordict = TensorDict({"obs": torch.randn(5)}, batch_size=[]) actor(tensordict) # action is the default value -from tensordict.nn.prototype import ( +from tensordict.nn import ( ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, )