From c5a30661446928208cbc61608cc47e07b766d8ef Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Fri, 30 Dec 2022 01:26:13 +0100 Subject: [PATCH 1/4] Remove NdBoundedTensorSpec --- docs/source/reference/data.rst | 2 +- test/mocking_classes.py | 9 +- test/test_cost.py | 16 +- test/test_exploration.py | 10 +- test/test_helpers.py | 8 +- test/test_modules.py | 6 +- test/test_tensor_spec.py | 27 ++- test/test_tensordictmodules.py | 36 ++-- test/test_transforms.py | 69 ++++---- torchrl/collectors/collectors.py | 4 +- torchrl/data/__init__.py | 1 - torchrl/data/tensor_specs.py | 163 ++++++------------ torchrl/envs/libs/brax.py | 4 +- torchrl/envs/libs/dm_control.py | 4 +- torchrl/envs/libs/gym.py | 4 +- torchrl/envs/libs/jumanji.py | 4 +- torchrl/modules/tensordict_module/actors.py | 16 +- .../modules/tensordict_module/exploration.py | 8 +- tutorials/sphinx-tutorials/torchrl_demo.py | 4 +- 19 files changed, 168 insertions(+), 227 deletions(-) diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index b3341904ae4..fc468d346c9 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -74,7 +74,7 @@ as shape, device, dtype and domain. BoundedTensorSpec OneHotDiscreteTensorSpec UnboundedContinuousTensorSpec - NdBoundedTensorSpec + BoundedTensorSpec NdUnboundedContinuousTensorSpec BinaryDiscreteTensorSpec MultOneHotDiscreteTensorSpec diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 5f92e0da52c..15ac6975e14 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -13,7 +13,6 @@ CompositeSpec, DiscreteTensorSpec, MultOneHotDiscreteTensorSpec, - NdBoundedTensorSpec, NdUnboundedContinuousTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, @@ -26,7 +25,6 @@ "one_hot": OneHotDiscreteTensorSpec, "categorical": DiscreteTensorSpec, "unbounded": UnboundedContinuousTensorSpec, - "ndbounded": NdBoundedTensorSpec, "ndunbounded": NdUnboundedContinuousTensorSpec, "binary": BinaryDiscreteTensorSpec, "mult_one_hot": MultOneHotDiscreteTensorSpec, @@ -34,11 +32,10 @@ } default_spec_kwargs = { - BoundedTensorSpec: {"minimum": -1.0, "maximum": 1.0}, OneHotDiscreteTensorSpec: {"n": 7}, DiscreteTensorSpec: {"n": 7}, UnboundedContinuousTensorSpec: {}, - NdBoundedTensorSpec: {"minimum": -torch.ones(4), "maxmimum": torch.ones(4)}, + BoundedTensorSpec: {"minimum": -torch.ones(4), "maximum": torch.ones(4)}, NdUnboundedContinuousTensorSpec: { "shape": [ 7, @@ -376,7 +373,7 @@ def __new__( ), ) if action_spec is None: - action_spec = NdBoundedTensorSpec(-1, 1, (7,)) + action_spec = BoundedTensorSpec(-1, 1, (7,)) if reward_spec is None: reward_spec = UnboundedContinuousTensorSpec() @@ -592,7 +589,7 @@ def __new__( ) if action_spec is None: - action_spec = NdBoundedTensorSpec(-1, 1, pixel_shape[-1]) + action_spec = BoundedTensorSpec(-1, 1, pixel_shape[-1]) if reward_spec is None: reward_spec = UnboundedContinuousTensorSpec() diff --git a/test/test_cost.py b/test/test_cost.py index 08c705cbfcd..5e1dafa0937 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -28,10 +28,10 @@ from tensordict.tensordict import assert_allclose_td, TensorDict from torch import autograd, nn from torchrl.data import ( + BoundedTensorSpec, CompositeSpec, DiscreteTensorSpec, MultOneHotDiscreteTensorSpec, - NdBoundedTensorSpec, NdUnboundedContinuousTensorSpec, OneHotDiscreteTensorSpec, ) @@ -130,7 +130,7 @@ def _create_mock_actor( elif action_spec_type == "categorical": action_spec = DiscreteTensorSpec(action_dim) elif action_spec_type == "nd_bounded": - 
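Reviewer note, illustrative only (not part of the diff): the net effect of this commit is that the scalar-bound and tensor-bound cases, which previously needed two classes, both go through BoundedTensorSpec. A minimal sketch, assuming the post-patch API:

# Sketch: the two construction styles unified by this commit.
import torch
from torchrl.data import BoundedTensorSpec

scalar_spec = BoundedTensorSpec(-1.0, 1.0)  # shape defaults to torch.Size([1])
nd_spec = BoundedTensorSpec(-torch.ones(4), torch.ones(4), (4,))  # formerly NdBoundedTensorSpec
sample = nd_spec.rand()
assert nd_spec.is_in(sample)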
action_spec = NdBoundedTensorSpec( + action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) else: @@ -417,7 +417,7 @@ class TestDDPG: def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): # Actor - action_spec = NdBoundedTensorSpec( + action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) module = nn.Linear(obs_dim, action_dim) @@ -647,7 +647,7 @@ class TestSAC: def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): # Actor - action_spec = NdBoundedTensorSpec( + action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) @@ -1026,7 +1026,7 @@ class TestREDQ: def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): # Actor - action_spec = NdBoundedTensorSpec( + action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) @@ -1481,7 +1481,7 @@ class TestPPO: def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): # Actor - action_spec = NdBoundedTensorSpec( + action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) @@ -1504,7 +1504,7 @@ def _create_mock_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): # Actor - action_spec = NdBoundedTensorSpec( + action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) base_layer = nn.Linear(obs_dim, 5) @@ -1808,7 +1808,7 @@ class TestA2C: def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): # Actor - action_spec = NdBoundedTensorSpec( + action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) diff --git a/test/test_exploration.py b/test/test_exploration.py index 7506eed44c0..3a8fa17bdd0 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -11,7 +11,7 @@ from scipy.stats import ttest_1samp from tensordict.tensordict import TensorDict from torch import nn -from torchrl.data import CompositeSpec, NdBoundedTensorSpec +from torchrl.data import BoundedTensorSpec, CompositeSpec from torchrl.envs.transforms.transforms import gSDENoise from torchrl.envs.utils import set_exploration_mode from torchrl.modules import SafeModule, SafeSequential @@ -61,7 +61,7 @@ def test_ou_wrapper(device, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0): torch.manual_seed(seed) net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device) module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) - action_spec = NdBoundedTensorSpec(-torch.ones(d_act), torch.ones(d_act), (d_act,)) + action_spec = BoundedTensorSpec(-torch.ones(d_act), torch.ones(d_act), (d_act,)) policy = ProbabilisticActor( spec=action_spec, module=module, @@ -106,7 +106,7 @@ def test_additivegaussian_sd( ): torch.manual_seed(seed) net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device) - action_spec = NdBoundedTensorSpec( + action_spec = BoundedTensorSpec( -torch.ones(d_act, device=device), torch.ones(d_act, device=device), (d_act,), @@ -173,7 +173,7 @@ def test_additivegaussian_wrapper( torch.manual_seed(seed) net = 
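The actor-construction pattern these loss tests share, written out once as a sketch mirroring _create_mock_actor; the in_keys keyword for the distribution parameters is assumed from the surrounding docstring examples:

import torch
from torch import nn
from torchrl.data import BoundedTensorSpec
from torchrl.modules import NormalParamWrapper, ProbabilisticActor, SafeModule, TanhNormal

obs_dim, action_dim = 3, 4
action_spec = BoundedTensorSpec(-torch.ones(action_dim), torch.ones(action_dim), (action_dim,))
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
actor = ProbabilisticActor(
    spec=action_spec,
    module=module,
    in_keys=["loc", "scale"],  # assumed keyword, as in the docstring examples
    distribution_class=TanhNormal,
)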
NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device) module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) - action_spec = NdBoundedTensorSpec( + action_spec = BoundedTensorSpec( -torch.ones(d_act, device=device), torch.ones(d_act, device=device), (d_act,), @@ -244,7 +244,7 @@ def test_gsde( module = SafeModule(wrapper, in_keys=in_keys, out_keys=["loc", "scale"]) distribution_class = TanhNormal distribution_kwargs = {"min": -bound, "max": bound} - spec = NdBoundedTensorSpec( + spec = BoundedTensorSpec( -torch.ones(action_dim) * bound, torch.ones(action_dim) * bound, (action_dim,) ).to(device) diff --git a/test/test_helpers.py b/test/test_helpers.py index c2df97d8486..73d8d082965 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -27,7 +27,7 @@ MockSerialEnv, ) from packaging import version -from torchrl.data import CompositeSpec, NdBoundedTensorSpec +from torchrl.data import BoundedTensorSpec, CompositeSpec from torchrl.envs.libs.gym import _has_gym from torchrl.envs.transforms import ObservationNorm from torchrl.envs.transforms.transforms import ( @@ -988,7 +988,7 @@ def test_initialize_stats_from_observation_norms(device, keys, composed, initial if keys: obs_spec = CompositeSpec( **{ - key: NdBoundedTensorSpec(maximum=1, minimum=1, shape=torch.Size([1])) + key: BoundedTensorSpec(maximum=1, minimum=1, shape=torch.Size([1])) for key in keys } ) @@ -996,9 +996,7 @@ def test_initialize_stats_from_observation_norms(device, keys, composed, initial env = ContinuousActionVecMockEnv( device=device, observation_spec=obs_spec, - action_spec=NdBoundedTensorSpec( - minimum=1, maximum=2, shape=torch.Size((1,)) - ), + action_spec=BoundedTensorSpec(minimum=1, maximum=2, shape=torch.Size((1,))), ) env.out_key = "observation" else: diff --git a/test/test_modules.py b/test/test_modules.py index 1f249980921..25bcd07028f 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -13,8 +13,8 @@ from tensordict import TensorDict from torch import nn from torchrl.data.tensor_specs import ( + BoundedTensorSpec, DiscreteTensorSpec, - NdBoundedTensorSpec, OneHotDiscreteTensorSpec, ) from torchrl.modules import ( @@ -511,7 +511,7 @@ def test_dreamer_decoder( @pytest.mark.parametrize("deter_size", [20, 30]) @pytest.mark.parametrize("action_size", [3, 6]) def test_rssm_prior(self, device, batch_size, stoch_size, deter_size, action_size): - action_spec = NdBoundedTensorSpec( + action_spec = BoundedTensorSpec( shape=(action_size,), dtype=torch.float32, minimum=-1, maximum=1 ) rssm_prior = RSSMPrior( @@ -566,7 +566,7 @@ def test_rssm_posterior(self, device, batch_size, stoch_size, deter_size): def test_rssm_rollout( self, device, batch_size, temporal_size, stoch_size, deter_size, action_size ): - action_spec = NdBoundedTensorSpec( + action_spec = BoundedTensorSpec( shape=(action_size,), dtype=torch.float32, minimum=-1, maximum=1 ) rssm_prior = RSSMPrior( diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py index 7a8a455615d..0dc78de3e84 100644 --- a/test/test_tensor_spec.py +++ b/test/test_tensor_spec.py @@ -17,7 +17,6 @@ CompositeSpec, DiscreteTensorSpec, MultOneHotDiscreteTensorSpec, - NdBoundedTensorSpec, NdUnboundedContinuousTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, @@ -96,7 +95,7 @@ def test_ndbounded(dtype, shape): for _ in range(100): lb = torch.rand(10) - 1 ub = torch.rand(10) + 1 - ts = NdBoundedTensorSpec(lb, ub, dtype=dtype) + ts = BoundedTensorSpec(lb, ub, dtype=dtype) _dtype = dtype if dtype is None: _dtype = 
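A sketch of what test_ndbounded exercises after the rename: tensor-valued bounds define the spec's shape without an explicit shape argument.

import torch
from torchrl.data import BoundedTensorSpec

lb = torch.rand(10) - 1
ub = torch.rand(10) + 1
ts = BoundedTensorSpec(lb, ub)  # no explicit shape; taken from the bounds, per the test above
val = ts.rand()
assert ts.is_in(val)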
torch.get_default_dtype() @@ -243,7 +242,7 @@ def _composite_spec(is_complete=True, device=None, dtype=None): np.random.seed(0) return CompositeSpec( - obs=NdBoundedTensorSpec( + obs=BoundedTensorSpec( torch.zeros(3, 32, 32), torch.ones(3, 32, 32), dtype=dtype, @@ -256,7 +255,7 @@ def _composite_spec(is_complete=True, device=None, dtype=None): def test_getitem(self, is_complete, device, dtype): ts = self._composite_spec(is_complete, device, dtype) - assert isinstance(ts["obs"], NdBoundedTensorSpec) + assert isinstance(ts["obs"], BoundedTensorSpec) if is_complete: assert isinstance(ts["act"], NdUnboundedContinuousTensorSpec) else: @@ -587,31 +586,31 @@ def test_equality_ndbounded(self): device = "cpu" dtype = torch.float16 - ts = NdBoundedTensorSpec( + ts = BoundedTensorSpec( minimum=minimum, maximum=maximum, device=device, dtype=dtype ) - ts_same = NdBoundedTensorSpec( + ts_same = BoundedTensorSpec( minimum=minimum, maximum=maximum, device=device, dtype=dtype ) assert ts == ts_same - ts_other = NdBoundedTensorSpec( + ts_other = BoundedTensorSpec( minimum=minimum + 1, maximum=maximum, device=device, dtype=dtype ) assert ts != ts_other - ts_other = NdBoundedTensorSpec( + ts_other = BoundedTensorSpec( minimum=minimum, maximum=maximum + 1, device=device, dtype=dtype ) assert ts != ts_other - ts_other = NdBoundedTensorSpec( + ts_other = BoundedTensorSpec( minimum=minimum, maximum=maximum, device="cpu:0", dtype=dtype ) assert ts != ts_other - ts_other = NdBoundedTensorSpec( + ts_other = BoundedTensorSpec( minimum=minimum, maximum=maximum, device=device, dtype=torch.float64 ) assert ts != ts_other @@ -764,13 +763,13 @@ def test_equality_composite(self): bounded_same = BoundedTensorSpec(0, 1, device, dtype) bounded_other = BoundedTensorSpec(0, 2, device, dtype) - nd = NdBoundedTensorSpec( + nd = BoundedTensorSpec( minimum=minimum, maximum=maximum + 1, device=device, dtype=dtype ) - nd_same = NdBoundedTensorSpec( + nd_same = BoundedTensorSpec( minimum=minimum, maximum=maximum + 1, device=device, dtype=dtype ) - _ = NdBoundedTensorSpec( + _ = BoundedTensorSpec( minimum=minimum, maximum=maximum + 3, device=device, dtype=dtype ) @@ -946,7 +945,7 @@ def test_bounded_rand(self): assert (-3 <= sample).all() and (3 >= sample).all() def test_ndbounded_shape(self): - spec = NdBoundedTensorSpec(-3, 3 * torch.ones(10, 5), shape=[10, 5]) + spec = BoundedTensorSpec(-3, 3 * torch.ones(10, 5), shape=[10, 5]) sample = torch.stack([spec.rand() for _ in range(100)], 0) assert (-3 <= sample).all() and (3 >= sample).all() assert sample.shape == torch.Size([100, 10, 5]) diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index b7e1708bb77..99ad7a72749 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -11,8 +11,8 @@ from tensordict.nn.functional_modules import make_functional from torch import nn from torchrl.data.tensor_specs import ( + BoundedTensorSpec, CompositeSpec, - NdBoundedTensorSpec, NdUnboundedContinuousTensorSpec, ) from torchrl.envs.utils import set_exploration_mode @@ -114,7 +114,7 @@ def test_stateful(self, safe, spec_type, lazy): if spec_type is None: spec = None elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 4) + spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": spec = NdUnboundedContinuousTensorSpec(4) @@ -176,7 +176,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys) if spec_type is None: spec = None elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 
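test_ndbounded_shape, condensed: scalar and tensor bounds can be mixed, and rand() draws stay inside the box.

import torch
from torchrl.data import BoundedTensorSpec

spec = BoundedTensorSpec(-3, 3 * torch.ones(10, 5), shape=[10, 5])
sample = torch.stack([spec.rand() for _ in range(100)], 0)
assert sample.shape == torch.Size([100, 10, 5])
assert (-3 <= sample).all() and (3 >= sample).all()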
4) + spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": spec = NdUnboundedContinuousTensorSpec(4) else: @@ -239,7 +239,7 @@ def test_functional(self, safe, spec_type): if spec_type is None: spec = None elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 4) + spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": spec = NdUnboundedContinuousTensorSpec(4) @@ -293,7 +293,7 @@ def test_functional_probabilistic(self, safe, spec_type): if spec_type is None: spec = None elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 4) + spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": spec = NdUnboundedContinuousTensorSpec(4) else: @@ -350,7 +350,7 @@ def test_functional_with_buffer(self, safe, spec_type): if spec_type is None: spec = None elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 32) + spec = BoundedTensorSpec(-0.1, 0.1, 32) elif spec_type == "unbounded": spec = NdUnboundedContinuousTensorSpec(32) @@ -404,7 +404,7 @@ def test_functional_with_buffer_probabilistic(self, safe, spec_type): if spec_type is None: spec = None elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 32) + spec = BoundedTensorSpec(-0.1, 0.1, 32) elif spec_type == "unbounded": spec = NdUnboundedContinuousTensorSpec(32) else: @@ -464,7 +464,7 @@ def test_vmap(self, safe, spec_type): if spec_type is None: spec = None elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 4) + spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": spec = NdUnboundedContinuousTensorSpec(4) @@ -543,7 +543,7 @@ def test_vmap_probabilistic(self, safe, spec_type): if spec_type is None: spec = None elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 4) + spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": spec = NdUnboundedContinuousTensorSpec(4) else: @@ -640,7 +640,7 @@ def test_stateful(self, safe, spec_type, lazy): if spec_type is None: spec = None elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 4) + spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": spec = NdUnboundedContinuousTensorSpec(4) @@ -717,7 +717,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy): if spec_type is None: spec = None elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 4) + spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": spec = NdUnboundedContinuousTensorSpec(4) else: @@ -804,7 +804,7 @@ def test_functional(self, safe, spec_type): if spec_type is None: spec = None elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 4) + spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": spec = NdUnboundedContinuousTensorSpec(4) @@ -873,7 +873,7 @@ def test_functional_probabilistic(self, safe, spec_type): if spec_type is None: spec = None elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 4) + spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": spec = NdUnboundedContinuousTensorSpec(4) else: @@ -959,7 +959,7 @@ def test_functional_with_buffer(self, safe, spec_type): if spec_type is None: spec = None elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 7) + spec = BoundedTensorSpec(-0.1, 0.1, 7) elif spec_type == "unbounded": spec = NdUnboundedContinuousTensorSpec(7) @@ -1031,7 +1031,7 @@ def test_functional_with_buffer_probabilistic(self, safe, spec_type): if spec_type is None: spec = None elif spec_type == "bounded": - spec = 
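A sketch of the SafeModule pattern these tests run, mirroring the tutorial hunk at the end of this commit (layer sizes are hypothetical):

import torch
from torch import nn
from torchrl.data import BoundedTensorSpec
from torchrl.modules import SafeModule

spec = BoundedTensorSpec(-0.1, 0.1, 4)
module = SafeModule(
    module=nn.Linear(3, 4),  # hypothetical base network
    spec=spec,
    in_keys=["obs"],
    out_keys=["action"],
    safe=True,  # with safe=True, outputs are projected into the spec
)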
NdBoundedTensorSpec(-0.1, 0.1, 7) + spec = BoundedTensorSpec(-0.1, 0.1, 7) elif spec_type == "unbounded": spec = NdUnboundedContinuousTensorSpec(7) else: @@ -1123,7 +1123,7 @@ def test_vmap(self, safe, spec_type): if spec_type is None: spec = None elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 4) + spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": spec = NdUnboundedContinuousTensorSpec(4) @@ -1222,7 +1222,7 @@ def test_vmap_probabilistic(self, safe, spec_type): if spec_type is None: spec = None elif spec_type == "bounded": - spec = NdBoundedTensorSpec(-0.1, 0.1, 4) + spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": spec = NdUnboundedContinuousTensorSpec(4) else: @@ -1340,7 +1340,7 @@ def test_sequential_partial(self, stack, functional): net3 = NormalParamWrapper(net3) net3 = SafeModule(net3, in_keys=["c"], out_keys=["loc", "scale"]) - spec = NdBoundedTensorSpec(-0.1, 0.1, 4) + spec = BoundedTensorSpec(-0.1, 0.1, 4) kwargs = {"distribution_class": TanhNormal} diff --git a/test/test_transforms.py b/test/test_transforms.py index c07ee7cb8b3..5ec2639bf99 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -28,7 +28,6 @@ from torchrl.data import ( BoundedTensorSpec, CompositeSpec, - NdBoundedTensorSpec, NdUnboundedContinuousTensorSpec, UnboundedContinuousTensorSpec, ) @@ -453,12 +452,12 @@ def test_resize(self, interpolation, keys, nchannels, batch, device): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = NdBoundedTensorSpec(-1, 1, (nchannels, 16, 16)) + observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) observation_spec = resize.transform_observation_spec(observation_spec) assert observation_spec.shape == torch.Size([nchannels, 20, 21]) else: observation_spec = CompositeSpec( - {key: NdBoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} + {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} ) observation_spec = resize.transform_observation_spec(observation_spec) for key in keys: @@ -493,12 +492,12 @@ def test_centercrop(self, keys, h, nchannels, batch, device): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = NdBoundedTensorSpec(-1, 1, (nchannels, 16, 16)) + observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) observation_spec = cc.transform_observation_spec(observation_spec) assert observation_spec.shape == torch.Size([nchannels, 20, h]) else: observation_spec = CompositeSpec( - {key: NdBoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} + {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} ) observation_spec = cc.transform_observation_spec(observation_spec) for key in keys: @@ -533,13 +532,13 @@ def test_flatten(self, keys, size, nchannels, batch, device): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = NdBoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16)) + observation_spec = BoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16)) observation_spec = flatten.transform_observation_spec(observation_spec) assert observation_spec.shape[-3] == expected_size else: observation_spec = CompositeSpec( { - key: NdBoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16)) + key: BoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16)) for key in keys } ) @@ -646,13 +645,13 @@ def test_unsqueeze(self, keys, size, nchannels, batch, device, unsqueeze_dim): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - 
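The per-key spec dict these transform tests build over and over, as a sketch (key names hypothetical):

from torchrl.data import BoundedTensorSpec, CompositeSpec

keys = ["pixels", "pixels2"]  # hypothetical observation keys
observation_spec = CompositeSpec(
    {key: BoundedTensorSpec(0, 255, (16, 16, 3)) for key in keys}
)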
observation_spec = NdBoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16)) + observation_spec = BoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16)) observation_spec = unsqueeze.transform_observation_spec(observation_spec) assert observation_spec.shape == expected_size else: observation_spec = CompositeSpec( { - key: NdBoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16)) + key: BoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16)) for key in keys } ) @@ -794,12 +793,12 @@ def test_grayscale(self, keys, device): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = NdBoundedTensorSpec(-1, 1, (nchannels, 16, 16)) + observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) observation_spec = gs.transform_observation_spec(observation_spec) assert observation_spec.shape == torch.Size([1, 16, 16]) else: observation_spec = CompositeSpec( - {key: NdBoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} + {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} ) observation_spec = gs.transform_observation_spec(observation_spec) for key in keys: @@ -832,7 +831,7 @@ def test_totensorimage(self, keys, batch, device): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = NdBoundedTensorSpec(0, 255, (16, 16, 3)) + observation_spec = BoundedTensorSpec(0, 255, (16, 16, 3)) observation_spec = totensorimage.transform_observation_spec( observation_spec ) @@ -841,7 +840,7 @@ def test_totensorimage(self, keys, batch, device): assert (observation_spec.space.maximum == 1).all() else: observation_spec = CompositeSpec( - {key: NdBoundedTensorSpec(0, 255, (16, 16, 3)) for key in keys} + {key: BoundedTensorSpec(0, 255, (16, 16, 3)) for key in keys} ) observation_spec = totensorimage.transform_observation_spec( observation_spec @@ -878,12 +877,12 @@ def test_compose(self, keys, batch, device, nchannels=1, N=4): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = NdBoundedTensorSpec(0, 255, (nchannels, 16, 16)) + observation_spec = BoundedTensorSpec(0, 255, (nchannels, 16, 16)) observation_spec = compose.transform_observation_spec(observation_spec) assert observation_spec.shape == torch.Size([nchannels * N, 16, 16]) else: observation_spec = CompositeSpec( - {key: NdBoundedTensorSpec(0, 255, (nchannels, 16, 16)) for key in keys} + {key: BoundedTensorSpec(0, 255, (nchannels, 16, 16)) for key in keys} ) observation_spec = compose.transform_observation_spec(observation_spec) for key in keys: @@ -969,7 +968,7 @@ def test_observationnorm( assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = NdBoundedTensorSpec( + observation_spec = BoundedTensorSpec( 0, 1, (nchannels, 16, 16), device=device ) observation_spec = on.transform_observation_spec(observation_spec) @@ -983,7 +982,7 @@ def test_observationnorm( else: observation_spec = CompositeSpec( { - key: NdBoundedTensorSpec(0, 1, (nchannels, 16, 16), device=device) + key: BoundedTensorSpec(0, 1, (nchannels, 16, 16), device=device) for key in keys } ) @@ -1009,14 +1008,14 @@ def test_observationnorm_init_stats( def make_env(): base_env = ContinuousActionVecMockEnv( observation_spec=CompositeSpec( - observation=NdBoundedTensorSpec( + observation=BoundedTensorSpec( minimum=1, maximum=1, shape=torch.Size([size]) ), - observation_orig=NdBoundedTensorSpec( + observation_orig=BoundedTensorSpec( minimum=1, maximum=1, shape=torch.Size([size]) ), ), - action_spec=NdBoundedTensorSpec( + action_spec=BoundedTensorSpec( 
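The mock-env spec wiring used by the ObservationNorm tests, with the renamed class (sketch; size is hypothetical):

import torch
from torchrl.data import BoundedTensorSpec, CompositeSpec

size = 3
observation_spec = CompositeSpec(
    observation=BoundedTensorSpec(minimum=1, maximum=1, shape=torch.Size([size])),
    observation_orig=BoundedTensorSpec(minimum=1, maximum=1, shape=torch.Size([size])),
)
action_spec = BoundedTensorSpec(minimum=1, maximum=1, shape=torch.Size((size,)))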
minimum=1, maximum=1, shape=torch.Size((size,)) ), seed=0, @@ -1117,16 +1116,14 @@ def test_observationnorm_uninitialized_stats_error(self): def test_observationnorm_infinite_stats_error(self, device): base_env = ContinuousActionVecMockEnv( observation_spec=CompositeSpec( - observation=NdBoundedTensorSpec( + observation=BoundedTensorSpec( minimum=1, maximum=1, shape=torch.Size([1]) ), - observation_orig=NdBoundedTensorSpec( + observation_orig=BoundedTensorSpec( minimum=1, maximum=1, shape=torch.Size([1]) ), ), - action_spec=NdBoundedTensorSpec( - minimum=1, maximum=1, shape=torch.Size((1,)) - ), + action_spec=BoundedTensorSpec(minimum=1, maximum=1, shape=torch.Size((1,))), seed=0, ) base_env.out_key = "observation" @@ -1150,7 +1147,7 @@ def test_catframes_transform_observation_spec(self): maxes = [0.5, 1] observation_spec = CompositeSpec( { - key: NdBoundedTensorSpec( + key: BoundedTensorSpec( space_min, space_max, (1, 3, 3), dtype=torch.double ) for key, space_min, space_max in zip(keys, mins, maxes) @@ -1160,7 +1157,7 @@ def test_catframes_transform_observation_spec(self): result = cat_frames.transform_observation_spec(observation_spec) observation_spec = CompositeSpec( { - key: NdBoundedTensorSpec( + key: BoundedTensorSpec( space_min, space_max, (1, 3, 3), dtype=torch.double ) for key, space_min, space_max in zip(keys, mins, maxes) @@ -1271,20 +1268,20 @@ def test_double2float(self, keys, keys_inv, device): assert td.get("dont touch").dtype == torch.double if len(keys_total) == 1 and len(keys_inv) and keys[0] == "action": - action_spec = NdBoundedTensorSpec(0, 1, (1, 3, 3), dtype=torch.double) + action_spec = BoundedTensorSpec(0, 1, (1, 3, 3), dtype=torch.double) input_spec = CompositeSpec(action=action_spec) action_spec = double2float.transform_input_spec(input_spec) assert action_spec.dtype == torch.float elif len(keys) == 1: - observation_spec = NdBoundedTensorSpec(0, 1, (1, 3, 3), dtype=torch.double) + observation_spec = BoundedTensorSpec(0, 1, (1, 3, 3), dtype=torch.double) observation_spec = double2float.transform_observation_spec(observation_spec) assert observation_spec.dtype == torch.float else: observation_spec = CompositeSpec( { - key: NdBoundedTensorSpec(0, 1, (1, 3, 3), dtype=torch.double) + key: BoundedTensorSpec(0, 1, (1, 3, 3), dtype=torch.double) for key in keys } ) @@ -1328,12 +1325,12 @@ def test_cattensors(self, keys, device): assert td.get("dont touch").shape == dont_touch.shape if len(keys) == 1: - observation_spec = NdBoundedTensorSpec(0, 1, (1, 4, 32)) + observation_spec = BoundedTensorSpec(0, 1, (1, 4, 32)) observation_spec = cattensors.transform_observation_spec(observation_spec) assert observation_spec.shape == torch.Size([1, len(keys) * 4, 32]) else: observation_spec = CompositeSpec( - {key: NdBoundedTensorSpec(0, 1, (1, 4, 32)) for key in keys} + {key: BoundedTensorSpec(0, 1, (1, 4, 32)) for key in keys} ) observation_spec = cattensors.transform_observation_spec(observation_spec) assert observation_spec["observation_out"].shape == torch.Size( @@ -1394,8 +1391,8 @@ def test_noop_reset_env_error(self, random, device, compose): @pytest.mark.parametrize( "spec", [ - CompositeSpec(b=NdBoundedTensorSpec(-3, 3, [4])), - NdBoundedTensorSpec(-3, 3, [4]), + CompositeSpec(b=BoundedTensorSpec(-3, 3, [4])), + BoundedTensorSpec(-3, 3, [4]), ], ) @pytest.mark.parametrize("random", [True, False]) @@ -1793,7 +1790,7 @@ def test_r3mnet_transform_observation_spec( r3m_net = _R3MNet(in_keys, out_keys, model, del_keys) observation_spec = CompositeSpec( - {key: 
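Specs carry a dtype, which is what the DoubleToFloat assertions above check; a brief sketch:

import torch
from torchrl.data import BoundedTensorSpec, CompositeSpec

action_spec = BoundedTensorSpec(0, 1, (1, 3, 3), dtype=torch.double)
input_spec = CompositeSpec(action=action_spec)
assert input_spec["action"].dtype == torch.double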
NdBoundedTensorSpec(-1, 1, (3, 16, 16), device) for key in in_keys} + {key: BoundedTensorSpec(-1, 1, (3, 16, 16), device) for key in in_keys} ) if del_keys: exp_ts = CompositeSpec( @@ -2050,7 +2047,7 @@ def test_vipnet_transform_observation_spec( vip_net = _VIPNet(in_keys, out_keys, model, del_keys) observation_spec = CompositeSpec( - {key: NdBoundedTensorSpec(-1, 1, (3, 16, 16), device) for key in in_keys} + {key: BoundedTensorSpec(-1, 1, (3, 16, 16), device) for key in in_keys} ) if del_keys: exp_ts = CompositeSpec( diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 1468babedae..aa3d79e1499 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -55,8 +55,8 @@ def __init__(self, action_spec: TensorSpec): Examples: >>> from tensordict import TensorDict - >>> from torchrl.data.tensor_specs import NdBoundedTensorSpec - >>> action_spec = NdBoundedTensorSpec(-torch.ones(3), torch.ones(3)) + >>> from torchrl.data.tensor_specs import BoundedTensorSpec + >>> action_spec = BoundedTensorSpec(-torch.ones(3), torch.ones(3)) >>> actor = RandomPolicy(spec=action_spec) >>> td = actor(TensorDict(batch_size=[])) # selects a random action in the cube [-1; 1] diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 44822cbfa7e..601b65e90db 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -21,7 +21,6 @@ DEVICE_TYPING, DiscreteTensorSpec, MultOneHotDiscreteTensorSpec, - NdBoundedTensorSpec, NdUnboundedContinuousTensorSpec, NdUnboundedDiscreteTensorSpec, OneHotDiscreteTensorSpec, diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index a46af2881e4..2442ce5958e 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -36,6 +36,8 @@ _NO_CHECK_SPEC_ENCODE = get_binary_env_var("NO_CHECK_SPEC_ENCODE") +_DEFAULT_SHAPE = torch.Size([1]) + def _default_dtype_and_device( dtype: Union[None, torch.dtype], @@ -379,105 +381,6 @@ def __repr__(self): return string -@dataclass(repr=False) -class BoundedTensorSpec(TensorSpec): - """A bounded, unidimensional, continuous tensor spec. - - Args: - minimum (np.ndarray, torch.Tensor or number): lower bound of the box. - maximum (np.ndarray, torch.Tensor or number): upper bound of the box. - device (str, int or torch.device, optional): device of the tensors. - dtype (str or torch.dtype, optional): dtype of the tensors. 
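What the new module-level default means in practice, as a sketch: scalar bounds keep the old unidimensional behaviour.

import torch
from torchrl.data import BoundedTensorSpec

spec = BoundedTensorSpec(-1.0, 1.0)  # shape falls back to _DEFAULT_SHAPE, i.e. torch.Size([1])
assert spec.shape == torch.Size([1])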
- """ - - shape: torch.Size - space: ContinuousBox - device: torch.device = torch.device("cpu") - dtype: torch.dtype = torch.float - domain: str = "" - - def __init__( - self, - minimum: Union[np.ndarray, torch.Tensor, float], - maximum: Union[np.ndarray, torch.Tensor, float], - device: Optional[DEVICE_TYPING] = None, - dtype: Optional[torch.dtype] = None, - ): - dtype, device = _default_dtype_and_device(dtype, device) - if not isinstance(minimum, torch.Tensor): - minimum = torch.tensor(minimum, dtype=dtype, device=device) - if minimum.dtype is not dtype: - minimum = minimum.to(dtype) - if minimum.device != device: - minimum = minimum.to(device) - - if not isinstance(maximum, torch.Tensor): - maximum = torch.tensor(maximum, dtype=dtype, device=device) - if maximum.dtype is not dtype: - maximum = maximum.to(dtype) - if maximum.device != device: - maximum = maximum.to(device) - super().__init__( - torch.Size( - [ - 1, - ] - ), - ContinuousBox(minimum, maximum), - device, - dtype, - "continuous", - ) - - def rand(self, shape=None) -> torch.Tensor: - if shape is None: - shape = torch.Size([]) - a, b = self.space - if self.dtype in (torch.float, torch.double, torch.half): - shape = [*shape, *self.shape] - out = ( - torch.zeros(shape, dtype=self.dtype, device=self.device).uniform_() - * (b - a) - + a - ) - if (out > b).any(): - out[out > b] = b.expand_as(out)[out > b] - if (out < a).any(): - out[out < a] = a.expand_as(out)[out < a] - return out - else: - interval = self.space.maximum - self.space.minimum - r = torch.rand( - torch.Size([*shape, *interval.shape]), device=interval.device - ) - r = interval * r - r = self.space.minimum + r - r = r.to(self.dtype).to(self.device) - return r - - def _project(self, val: torch.Tensor) -> torch.Tensor: - minimum = self.space.minimum.to(val.device) - maximum = self.space.maximum.to(val.device) - try: - val = val.clamp_(minimum.item(), maximum.item()) - except ValueError: - minimum = minimum.expand_as(val) - maximum = maximum.expand_as(val) - val[val < minimum] = minimum[val < minimum] - val[val > maximum] = maximum[val > maximum] - except RuntimeError: - minimum = minimum.expand_as(val) - maximum = maximum.expand_as(val) - val[val < minimum] = minimum[val < minimum] - val[val > maximum] = maximum[val > maximum] - return val - - def is_in(self, val: torch.Tensor) -> bool: - return (val >= self.space.minimum.to(val.device)).all() and ( - val <= self.space.maximum.to(val.device) - ).all() - - @dataclass(repr=False) class OneHotDiscreteTensorSpec(TensorSpec): """A unidimensional, one-hot discrete tensor spec. @@ -679,8 +582,8 @@ def is_in(self, val: torch.Tensor) -> bool: @dataclass(repr=False) -class NdBoundedTensorSpec(BoundedTensorSpec): - """A bounded, multi-dimensional, continuous tensor spec. +class BoundedTensorSpec(TensorSpec): + """A bounded continuous tensor spec. Args: minimum (np.ndarray, torch.Tensor or number): lower bound of the box. 
@@ -694,7 +597,7 @@ def __init__( self, minimum: Union[float, torch.Tensor, np.ndarray], maximum: Union[float, torch.Tensor, np.ndarray], - shape: Optional[torch.Size] = None, + shape: Optional[torch.Size] = _DEFAULT_SHAPE, device: Optional[DEVICE_TYPING] = None, dtype: Optional[Union[torch.dtype, str]] = None, ): @@ -717,7 +620,7 @@ def __init__( if dtype is not None and maximum.dtype is not dtype: maximum = maximum.to(dtype) err_msg = ( - "NdBoundedTensorSpec requires the shape to be explicitely (via " + "BoundedTensorSpec requires the shape to be explicitely (via " "the shape argument) or implicitely defined (via either the " "minimum or the maximum or both). If the maximum and/or the " "minimum have a non-singleton shape, they must match the " @@ -765,10 +668,58 @@ def __init__( raise RuntimeError(shape_err_msg) self.shape = shape - super(BoundedTensorSpec, self).__init__( + super().__init__( shape, ContinuousBox(minimum, maximum), device, dtype, "continuous" ) + def rand(self, shape=None) -> torch.Tensor: + if shape is None: + shape = torch.Size([]) + a, b = self.space + if self.dtype in (torch.float, torch.double, torch.half): + shape = [*shape, *self.shape] + out = ( + torch.zeros(shape, dtype=self.dtype, device=self.device).uniform_() + * (b - a) + + a + ) + if (out > b).any(): + out[out > b] = b.expand_as(out)[out > b] + if (out < a).any(): + out[out < a] = a.expand_as(out)[out < a] + return out + else: + interval = self.space.maximum - self.space.minimum + r = torch.rand( + torch.Size([*shape, *interval.shape]), device=interval.device + ) + r = interval * r + r = self.space.minimum + r + r = r.to(self.dtype).to(self.device) + return r + + def _project(self, val: torch.Tensor) -> torch.Tensor: + minimum = self.space.minimum.to(val.device) + maximum = self.space.maximum.to(val.device) + try: + val = val.clamp_(minimum.item(), maximum.item()) + except ValueError: + minimum = minimum.expand_as(val) + maximum = maximum.expand_as(val) + val[val < minimum] = minimum[val < minimum] + val[val > maximum] = maximum[val > maximum] + except RuntimeError: + minimum = minimum.expand_as(val) + maximum = maximum.expand_as(val) + val[val < minimum] = minimum[val < minimum] + val[val > maximum] = maximum[val > maximum] + return val + + def is_in(self, val: torch.Tensor) -> bool: + return (val >= self.space.minimum.to(val.device)).all() and ( + val <= self.space.maximum.to(val.device) + ).all() + @dataclass(repr=False) class NdUnboundedContinuousTensorSpec(UnboundedContinuousTensorSpec): @@ -1088,10 +1039,10 @@ class CompositeSpec(TensorSpec): effect. `spec.encode` cannot be used with missing values. Examples: - >>> pixels_spec = NdBoundedTensorSpec( + >>> pixels_spec = BoundedTensorSpec( ... torch.zeros(3,32,32), ... torch.ones(3, 32, 32)) - >>> observation_vector_spec = NdBoundedTensorSpec(torch.zeros(33), + >>> observation_vector_spec = BoundedTensorSpec(torch.zeros(33), ... torch.ones(33)) >>> composite_spec = CompositeSpec( ... 
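The shape check described by err_msg above, in brief: non-singleton bounds must agree with an explicit shape argument. A sketch:

import torch
from torchrl.data import BoundedTensorSpec

BoundedTensorSpec(-torch.ones(4), torch.ones(4), (4,))  # ok: bounds match shape
try:
    BoundedTensorSpec(-torch.ones(4), torch.ones(4), (5,))  # mismatch
except RuntimeError:
    pass  # expected to raise, per the constructor's shape_err_msg above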
pixels=pixels_spec, diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index 7ab3253ba8c..68f3b74cdd5 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -4,8 +4,8 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torchrl.data import ( + BoundedTensorSpec, CompositeSpec, - NdBoundedTensorSpec, NdUnboundedContinuousTensorSpec, ) from torchrl.envs.common import _EnvWrapper @@ -118,7 +118,7 @@ def _make_state_spec(self, env: "brax.envs.env.Env"): def _make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821 self.input_spec = CompositeSpec( - action=NdBoundedTensorSpec( + action=BoundedTensorSpec( minimum=-1, maximum=1, shape=(env.action_size,), device=self.device ) ) diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index fbe73e828d3..47d519ee6fb 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -12,8 +12,8 @@ import torch from torchrl.data import ( + BoundedTensorSpec, CompositeSpec, - NdBoundedTensorSpec, NdUnboundedContinuousTensorSpec, NdUnboundedDiscreteTensorSpec, TensorSpec, @@ -60,7 +60,7 @@ def _dmcontrol_to_torchrl_spec_transform( shape = spec.shape if not len(shape): shape = torch.Size([1]) - return NdBoundedTensorSpec( + return BoundedTensorSpec( shape=shape, minimum=spec.minimum, maximum=spec.maximum, diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index abecc288c66..285f9d6d52c 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -11,10 +11,10 @@ from torchrl.data import ( BinaryDiscreteTensorSpec, + BoundedTensorSpec, CompositeSpec, DiscreteTensorSpec, MultOneHotDiscreteTensorSpec, - NdBoundedTensorSpec, NdUnboundedContinuousTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, @@ -75,7 +75,7 @@ def _gym_to_torchrl_spec_transform( shape = torch.Size([1]) if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] - return NdBoundedTensorSpec( + return BoundedTensorSpec( torch.tensor(spec.low, device=device, dtype=dtype), torch.tensor(spec.high, device=device, dtype=dtype), shape, diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 67dceafbaec..84c3a2fb901 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -5,10 +5,10 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torchrl.data import ( + BoundedTensorSpec, CompositeSpec, DEVICE_TYPING, DiscreteTensorSpec, - NdBoundedTensorSpec, NdUnboundedContinuousTensorSpec, NdUnboundedDiscreteTensorSpec, OneHotDiscreteTensorSpec, @@ -62,7 +62,7 @@ def _jumanji_to_torchrl_spec_transform( shape = spec.shape if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] - return NdBoundedTensorSpec( + return BoundedTensorSpec( shape=shape, minimum=np.asarray(spec.minimum), maximum=np.asarray(spec.maximum), diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index bdecbf86623..ed0cb94bd11 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -82,10 +82,10 @@ class ProbabilisticActor(SafeProbabilisticSequential): >>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn.functional_modules import make_functional - >>> from torchrl.data import NdBoundedTensorSpec + >>> from torchrl.data import BoundedTensorSpec >>> from torchrl.modules import ProbabilisticActor, NormalParamWrapper, SafeModule, TanhNormal >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,]) - >>> action_spec = 
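How a gym Box lands in TorchRL after this change, sketching the pattern in _gym_to_torchrl_spec_transform (the gym import and explicit dtype mapping are assumptions):

import numpy as np
import torch
from gym import spaces
from torchrl.data import BoundedTensorSpec

box = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
spec = BoundedTensorSpec(
    torch.tensor(box.low, dtype=torch.float32),
    torch.tensor(box.high, dtype=torch.float32),
    torch.Size(box.shape),
    dtype=torch.float32,
)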
NdBoundedTensorSpec(shape=torch.Size([4]), + >>> action_spec = BoundedTensorSpec(shape=torch.Size([4]), ... minimum=-1, maximum=1) >>> module = NormalParamWrapper(torch.nn.Linear(4, 8)) >>> tensordict_module = SafeModule(module, in_keys=["observation"], out_keys=["loc", "scale"]) @@ -569,7 +569,7 @@ class ActorValueOperator(SafeSequential): >>> import torch >>> from tensordict import TensorDict >>> from torchrl.modules import ProbabilisticActor, SafeModule - >>> from torchrl.data import NdUnboundedContinuousTensorSpec, NdBoundedTensorSpec + >>> from torchrl.data import NdUnboundedContinuousTensorSpec, BoundedTensorSpec >>> from torchrl.modules import ValueOperator, TanhNormal, ActorValueOperator, NormalParamWrapper >>> spec_hidden = NdUnboundedContinuousTensorSpec(4) >>> module_hidden = torch.nn.Linear(4, 4) @@ -579,7 +579,7 @@ class ActorValueOperator(SafeSequential): ... in_keys=["observation"], ... out_keys=["hidden"], ... ) - >>> spec_action = NdBoundedTensorSpec(-1, 1, torch.Size([8])) + >>> spec_action = BoundedTensorSpec(-1, 1, torch.Size([8])) >>> module_action = SafeModule( ... NormalParamWrapper(torch.nn.Linear(4, 8)), ... in_keys=["hidden"], @@ -703,7 +703,7 @@ class ActorCriticOperator(ActorValueOperator): >>> import torch >>> from tensordict import TensorDict >>> from torchrl.modules import ProbabilisticActor, SafeModule - >>> from torchrl.data import NdUnboundedContinuousTensorSpec, NdBoundedTensorSpec + >>> from torchrl.data import NdUnboundedContinuousTensorSpec, BoundedTensorSpec >>> from torchrl.modules import ValueOperator, TanhNormal, ActorCriticOperator, NormalParamWrapper, MLP >>> spec_hidden = NdUnboundedContinuousTensorSpec(4) >>> module_hidden = torch.nn.Linear(4, 4) @@ -713,7 +713,7 @@ class ActorCriticOperator(ActorValueOperator): ... in_keys=["observation"], ... out_keys=["hidden"], ... 
) - >>> spec_action = NdBoundedTensorSpec(-1, 1, torch.Size([8])) + >>> spec_action = BoundedTensorSpec(-1, 1, torch.Size([8])) >>> module_action = NormalParamWrapper(torch.nn.Linear(4, 8)) >>> module_action = SafeModule(module_action, in_keys=["hidden"], out_keys=["loc", "scale"]) >>> td_module_action = ProbabilisticActor( @@ -832,7 +832,7 @@ class ActorCriticWrapper(SafeSequential): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import NdUnboundedContinuousTensorSpec, NdBoundedTensorSpec + >>> from torchrl.data import NdUnboundedContinuousTensorSpec, BoundedTensorSpec >>> from torchrl.modules import ( ActorCriticWrapper, ProbabilisticActor, @@ -841,7 +841,7 @@ class ActorCriticWrapper(SafeSequential): TanhNormal, ValueOperator, ) - >>> action_spec = NdBoundedTensorSpec(-1, 1, torch.Size([8])) + >>> action_spec = BoundedTensorSpec(-1, 1, torch.Size([8])) >>> action_module = SafeModule( NormalParamWrapper(torch.nn.Linear(4, 8)), in_keys=["observation"], diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 2dd65bb339d..ba744497404 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -48,9 +48,9 @@ class EGreedyWrapper(TensorDictModuleWrapper): >>> import torch >>> from tensordict import TensorDict >>> from torchrl.modules import EGreedyWrapper, Actor - >>> from torchrl.data import NdBoundedTensorSpec + >>> from torchrl.data import BoundedTensorSpec >>> torch.manual_seed(0) - >>> spec = NdBoundedTensorSpec(-1, 1, torch.Size([4])) + >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4])) >>> module = torch.nn.Linear(4, 4, bias=False) >>> policy = Actor(spec=spec, module=module) >>> explorative_policy = EGreedyWrapper(policy, eps_init=0.2) @@ -280,10 +280,10 @@ class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import NdBoundedTensorSpec + >>> from torchrl.data import BoundedTensorSpec >>> from torchrl.modules import OrnsteinUhlenbeckProcessWrapper, Actor >>> torch.manual_seed(0) - >>> spec = NdBoundedTensorSpec(-1, 1, torch.Size([4])) + >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4])) >>> module = torch.nn.Linear(4, 4, bias=False) >>> policy = Actor(module=module, spec=spec) >>> explorative_policy = OrnsteinUhlenbeckProcessWrapper(policy) diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index eb4a97a4cc3..41f61665a96 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -480,10 +480,10 @@ # ------------------------------ torch.manual_seed(0) -from torchrl.data import NdBoundedTensorSpec +from torchrl.data import BoundedTensorSpec from torchrl.modules import SafeModule -spec = NdBoundedTensorSpec(-torch.ones(3), torch.ones(3)) +spec = BoundedTensorSpec(-torch.ones(3), torch.ones(3)) base_module = nn.Linear(5, 3) module = SafeModule( module=base_module, spec=spec, in_keys=["obs"], out_keys=["action"], safe=True From 5cbf2d62338b9d5106e8af7884817bb1b22c87d6 Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Fri, 30 Dec 2022 02:07:41 +0100 Subject: [PATCH 2/4] Remove NdUnboundedContinuousTensorSpec --- docs/source/reference/data.rst | 2 +- examples/dreamer/dreamer_utils.py | 6 +- test/mocking_classes.py | 57 ++++++++---------- test/test_collector.py | 8 +-- test/test_cost.py | 16 ++--- test/test_tensor_spec.py | 19 +++--- 
test/test_tensordictmodules.py | 38 ++++++------ test/test_transforms.py | 15 ++--- torchrl/data/__init__.py | 1 - torchrl/data/tensor_specs.py | 58 +++++++------------ torchrl/envs/libs/brax.py | 10 +--- torchrl/envs/libs/dm_control.py | 4 +- torchrl/envs/libs/gym.py | 4 +- torchrl/envs/libs/jax_utils.py | 4 +- torchrl/envs/libs/jumanji.py | 4 +- torchrl/envs/model_based/common.py | 10 ++-- torchrl/envs/model_based/dreamer.py | 4 +- torchrl/envs/transforms/r3m.py | 4 +- torchrl/envs/transforms/transforms.py | 5 +- torchrl/envs/transforms/vip.py | 4 +- torchrl/modules/planners/cem.py | 10 ++-- torchrl/modules/tensordict_module/actors.py | 16 ++--- torchrl/modules/tensordict_module/common.py | 4 +- torchrl/modules/tensordict_module/sequence.py | 6 +- torchrl/trainers/helpers/models.py | 14 ++--- 25 files changed, 139 insertions(+), 184 deletions(-) diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index fc468d346c9..330f96c725a 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -75,7 +75,7 @@ as shape, device, dtype and domain. OneHotDiscreteTensorSpec UnboundedContinuousTensorSpec BoundedTensorSpec - NdUnboundedContinuousTensorSpec + UnboundedContinuousTensorSpec BinaryDiscreteTensorSpec MultOneHotDiscreteTensorSpec DiscreteTensorSpec diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index dfcd444262d..36dbdc365fe 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field as dataclass_field from typing import Any, Callable, Optional, Sequence, Union -from torchrl.data import NdUnboundedContinuousTensorSpec +from torchrl.data import UnboundedContinuousTensorSpec from torchrl.envs import ParallelEnv from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import env_creator, EnvCreator @@ -125,8 +125,8 @@ def make_env_transforms( ) default_dict = { - "state": NdUnboundedContinuousTensorSpec(cfg.state_dim), - "belief": NdUnboundedContinuousTensorSpec(cfg.rssm_hidden_dim), + "state": UnboundedContinuousTensorSpec(cfg.state_dim), + "belief": UnboundedContinuousTensorSpec(cfg.rssm_hidden_dim), } env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 15ac6975e14..a429fbb5161 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -13,7 +13,6 @@ CompositeSpec, DiscreteTensorSpec, MultOneHotDiscreteTensorSpec, - NdUnboundedContinuousTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, ) @@ -25,7 +24,6 @@ "one_hot": OneHotDiscreteTensorSpec, "categorical": DiscreteTensorSpec, "unbounded": UnboundedContinuousTensorSpec, - "ndunbounded": NdUnboundedContinuousTensorSpec, "binary": BinaryDiscreteTensorSpec, "mult_one_hot": MultOneHotDiscreteTensorSpec, "composite": CompositeSpec, @@ -34,9 +32,8 @@ default_spec_kwargs = { OneHotDiscreteTensorSpec: {"n": 7}, DiscreteTensorSpec: {"n": 7}, - UnboundedContinuousTensorSpec: {}, BoundedTensorSpec: {"minimum": -torch.ones(4), "maximum": torch.ones(4)}, - NdUnboundedContinuousTensorSpec: { + UnboundedContinuousTensorSpec: { "shape": [ 7, ] @@ -111,13 +108,13 @@ def __new__( **kwargs, ): if action_spec is None: - action_spec = NdUnboundedContinuousTensorSpec((1,)) + action_spec = UnboundedContinuousTensorSpec((1,)) if observation_spec is None: observation_spec = CompositeSpec( - observation=NdUnboundedContinuousTensorSpec((1,)) + 
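The primer pattern from dreamer_utils above, with the renamed spec; a sketch in which the import path and dimensions are assumptions (the real values come from cfg.state_dim and cfg.rssm_hidden_dim):

from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs.transforms import TensorDictPrimer  # assumed import path

state_dim, rssm_hidden_dim = 30, 200  # hypothetical config values
default_dict = {
    "state": UnboundedContinuousTensorSpec(state_dim),
    "belief": UnboundedContinuousTensorSpec(rssm_hidden_dim),
}
primer = TensorDictPrimer(random=False, default_value=0, **default_dict)
# env.append_transform(primer), as in make_env_transforms above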
observation=UnboundedContinuousTensorSpec((1,)) ) if reward_spec is None: - reward_spec = NdUnboundedContinuousTensorSpec((1,)) + reward_spec = UnboundedContinuousTensorSpec((1,)) if input_spec is None: input_spec = CompositeSpec(action=action_spec) cls._reward_spec = reward_spec @@ -172,18 +169,18 @@ def __new__( **kwargs, ): if action_spec is None: - action_spec = NdUnboundedContinuousTensorSpec((1,)) + action_spec = UnboundedContinuousTensorSpec((1,)) if input_spec is None: input_spec = CompositeSpec( action=action_spec, - observation=NdUnboundedContinuousTensorSpec((1,)), + observation=UnboundedContinuousTensorSpec((1,)), ) if observation_spec is None: observation_spec = CompositeSpec( - observation=NdUnboundedContinuousTensorSpec((1,)) + observation=UnboundedContinuousTensorSpec((1,)) ) if reward_spec is None: - reward_spec = NdUnboundedContinuousTensorSpec((1,)) + reward_spec = UnboundedContinuousTensorSpec((1,)) cls._reward_spec = reward_spec cls._observation_spec = observation_spec cls._input_spec = input_spec @@ -280,8 +277,8 @@ def __new__( if observation_spec is None: cls.out_key = "observation" observation_spec = CompositeSpec( - observation=NdUnboundedContinuousTensorSpec(shape=torch.Size([size])), - observation_orig=NdUnboundedContinuousTensorSpec( + observation=UnboundedContinuousTensorSpec(shape=torch.Size([size])), + observation_orig=UnboundedContinuousTensorSpec( shape=torch.Size([size]) ), ) @@ -367,8 +364,8 @@ def __new__( if observation_spec is None: cls.out_key = "observation" observation_spec = CompositeSpec( - observation=NdUnboundedContinuousTensorSpec(shape=torch.Size([size])), - observation_orig=NdUnboundedContinuousTensorSpec( + observation=UnboundedContinuousTensorSpec(shape=torch.Size([size])), + observation_orig=UnboundedContinuousTensorSpec( shape=torch.Size([size]) ), ) @@ -468,10 +465,8 @@ def __new__( if observation_spec is None: cls.out_key = "pixels" observation_spec = CompositeSpec( - pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size([1, 7, 7])), - pixels_orig=NdUnboundedContinuousTensorSpec( - shape=torch.Size([1, 7, 7]) - ), + pixels=UnboundedContinuousTensorSpec(shape=torch.Size([1, 7, 7])), + pixels_orig=UnboundedContinuousTensorSpec(shape=torch.Size([1, 7, 7])), ) if action_spec is None: action_spec = OneHotDiscreteTensorSpec(7) @@ -520,10 +515,8 @@ def __new__( if observation_spec is None: cls.out_key = "pixels" observation_spec = CompositeSpec( - pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])), - pixels_orig=NdUnboundedContinuousTensorSpec( - shape=torch.Size([7, 7, 3]) - ), + pixels=UnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])), + pixels_orig=UnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])), ) if action_spec is None: action_spec_cls = ( @@ -582,8 +575,8 @@ def __new__( if observation_spec is None: cls.out_key = "pixels" observation_spec = CompositeSpec( - pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size(pixel_shape)), - pixels_orig=NdUnboundedContinuousTensorSpec( + pixels=UnboundedContinuousTensorSpec(shape=torch.Size(pixel_shape)), + pixels_orig=UnboundedContinuousTensorSpec( shape=torch.Size(pixel_shape) ), ) @@ -631,10 +624,8 @@ def __new__( if observation_spec is None: cls.out_key = "pixels" observation_spec = CompositeSpec( - pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])), - pixels_orig=NdUnboundedContinuousTensorSpec( - shape=torch.Size([7, 7, 3]) - ), + pixels=UnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])), + 
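A sketch showing that one class now covers both the scalar and the shaped unbounded cases used by these mock envs:

import torch
from torchrl.data import UnboundedContinuousTensorSpec

reward_spec = UnboundedContinuousTensorSpec((1,))
pixels_spec = UnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3]))  # was NdUnboundedContinuousTensorSpec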
pixels_orig=UnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])), ) return super().__new__( *args, @@ -693,13 +684,13 @@ def __init__( batch_size=batch_size, ) self.observation_spec = CompositeSpec( - hidden_observation=NdUnboundedContinuousTensorSpec((4,)) + hidden_observation=UnboundedContinuousTensorSpec((4,)) ) self.input_spec = CompositeSpec( - hidden_observation=NdUnboundedContinuousTensorSpec((4,)), - action=NdUnboundedContinuousTensorSpec((1,)), + hidden_observation=UnboundedContinuousTensorSpec((4,)), + action=UnboundedContinuousTensorSpec((1,)), ) - self.reward_spec = NdUnboundedContinuousTensorSpec((1,)) + self.reward_spec = UnboundedContinuousTensorSpec((1,)) def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict: td = TensorDict( diff --git a/test/test_collector.py b/test/test_collector.py index b12e974097a..742367e2d15 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -28,11 +28,7 @@ RandomPolicy, ) from torchrl.collectors.utils import split_trajectories -from torchrl.data import ( - CompositeSpec, - NdUnboundedContinuousTensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec from torchrl.envs import EnvCreator, ParallelEnv, SerialEnv from torchrl.envs.libs.gym import _has_gym, GymEnv from torchrl.envs.transforms import TransformedEnv, VecNorm @@ -942,7 +938,7 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe ], } if explicit_spec: - hidden_spec = NdUnboundedContinuousTensorSpec((1, hidden_size)) + hidden_spec = UnboundedContinuousTensorSpec((1, hidden_size)) policy_kwargs["spec"] = CompositeSpec( action=UnboundedContinuousTensorSpec(), hidden1=hidden_spec, diff --git a/test/test_cost.py b/test/test_cost.py index 5e1dafa0937..b84c8ed1f95 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -32,8 +32,8 @@ CompositeSpec, DiscreteTensorSpec, MultOneHotDiscreteTensorSpec, - NdUnboundedContinuousTensorSpec, OneHotDiscreteTensorSpec, + UnboundedContinuousTensorSpec, ) from torchrl.data.postprocs.postprocs import MultiStep from torchrl.envs.model_based.dreamer import DreamerEnv @@ -2022,7 +2022,7 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value): distribution_class=TanhNormal, return_log_prob=True, in_keys=["loc", "scale"], - spec=NdUnboundedContinuousTensorSpec(n_act), + spec=UnboundedContinuousTensorSpec(n_act), ) if advantage == "gae": advantage = GAE( @@ -2146,8 +2146,8 @@ def _create_value_data( def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200): mock_env = TransformedEnv(ContinuousActionConvMockEnv(pixel_shape=[3, 64, 64])) default_dict = { - "state": NdUnboundedContinuousTensorSpec(state_dim), - "belief": NdUnboundedContinuousTensorSpec(rssm_hidden_dim), + "state": UnboundedContinuousTensorSpec(state_dim), + "belief": UnboundedContinuousTensorSpec(rssm_hidden_dim), } mock_env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) @@ -2221,8 +2221,8 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=20 def _create_mb_env(self, rssm_hidden_dim, state_dim, mlp_num_units=200): mock_env = TransformedEnv(ContinuousActionConvMockEnv(pixel_shape=[3, 64, 64])) default_dict = { - "state": NdUnboundedContinuousTensorSpec(state_dim), - "belief": NdUnboundedContinuousTensorSpec(rssm_hidden_dim), + "state": UnboundedContinuousTensorSpec(state_dim), + "belief": UnboundedContinuousTensorSpec(rssm_hidden_dim), } mock_env.append_transform( 
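The explicit-spec branch of test_collector_output_keys, condensed into a sketch (hidden_size and the hidden2 entry are hypothetical here):

from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

hidden_size = 12  # hypothetical
hidden_spec = UnboundedContinuousTensorSpec((1, hidden_size))
policy_spec = CompositeSpec(
    action=UnboundedContinuousTensorSpec(),
    hidden1=hidden_spec,
    hidden2=hidden_spec,  # assumed second recurrent key
)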
TensorDictPrimer(random=False, default_value=0, **default_dict) @@ -2270,8 +2270,8 @@ def _create_mb_env(self, rssm_hidden_dim, state_dim, mlp_num_units=200): def _create_actor_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200): mock_env = TransformedEnv(ContinuousActionConvMockEnv(pixel_shape=[3, 64, 64])) default_dict = { - "state": NdUnboundedContinuousTensorSpec(state_dim), - "belief": NdUnboundedContinuousTensorSpec(rssm_hidden_dim), + "state": UnboundedContinuousTensorSpec(state_dim), + "belief": UnboundedContinuousTensorSpec(rssm_hidden_dim), } mock_env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py index 0dc78de3e84..1733c0b372c 100644 --- a/test/test_tensor_spec.py +++ b/test/test_tensor_spec.py @@ -17,7 +17,6 @@ CompositeSpec, DiscreteTensorSpec, MultOneHotDiscreteTensorSpec, - NdUnboundedContinuousTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, ) @@ -134,7 +133,7 @@ def test_ndunbounded(dtype, n, shape): torch.manual_seed(0) np.random.seed(0) - ts = NdUnboundedContinuousTensorSpec( + ts = UnboundedContinuousTensorSpec( shape=[ n, ], @@ -248,7 +247,7 @@ def _composite_spec(is_complete=True, device=None, dtype=None): dtype=dtype, device=device, ), - act=NdUnboundedContinuousTensorSpec((7,), dtype=dtype, device=device) + act=UnboundedContinuousTensorSpec((7,), dtype=dtype, device=device) if is_complete else None, ) @@ -257,7 +256,7 @@ def test_getitem(self, is_complete, device, dtype): ts = self._composite_spec(is_complete, device, dtype) assert isinstance(ts["obs"], BoundedTensorSpec) if is_complete: - assert isinstance(ts["act"], NdUnboundedContinuousTensorSpec) + assert isinstance(ts["act"], UnboundedContinuousTensorSpec) else: assert ts["act"] is None with pytest.raises(KeyError): @@ -660,25 +659,23 @@ def test_equality_ndunbounded(self, shape): device = "cpu" dtype = torch.float16 - ts = NdUnboundedContinuousTensorSpec(shape=shape, device=device, dtype=dtype) + ts = UnboundedContinuousTensorSpec(shape=shape, device=device, dtype=dtype) - ts_same = NdUnboundedContinuousTensorSpec( - shape=shape, device=device, dtype=dtype - ) + ts_same = UnboundedContinuousTensorSpec(shape=shape, device=device, dtype=dtype) assert ts == ts_same other_shape = 13 if type(shape) == int else torch.Size(np.array(shape) + 10) - ts_other = NdUnboundedContinuousTensorSpec( + ts_other = UnboundedContinuousTensorSpec( shape=other_shape, device=device, dtype=dtype ) assert ts != ts_other - ts_other = NdUnboundedContinuousTensorSpec( + ts_other = UnboundedContinuousTensorSpec( shape=shape, device="cpu:0", dtype=dtype ) assert ts != ts_other - ts_other = NdUnboundedContinuousTensorSpec( + ts_other = UnboundedContinuousTensorSpec( shape=shape, device=device, dtype=torch.float64 ) assert ts != ts_other diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 99ad7a72749..9e44a77f8c0 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -13,7 +13,7 @@ from torchrl.data.tensor_specs import ( BoundedTensorSpec, CompositeSpec, - NdUnboundedContinuousTensorSpec, + UnboundedContinuousTensorSpec, ) from torchrl.envs.utils import set_exploration_mode from torchrl.modules import NormalParamWrapper, SafeModule, TanhNormal @@ -87,8 +87,8 @@ def forward(self, x): return self.linear_1(x), self.linear_2(x) spec_dict = { - "_": NdUnboundedContinuousTensorSpec((4,)), - "out_2": NdUnboundedContinuousTensorSpec((3,)), + "_": 
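The equality semantics exercised by test_equality_ndunbounded above, in brief: specs compare equal only when shape, device and dtype all match.

import torch
from torchrl.data import UnboundedContinuousTensorSpec

a = UnboundedContinuousTensorSpec(shape=(3,), device="cpu", dtype=torch.float16)
b = UnboundedContinuousTensorSpec(shape=(3,), device="cpu", dtype=torch.float16)
assert a == b
assert a != UnboundedContinuousTensorSpec(shape=(3,), device="cpu", dtype=torch.float64)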
UnboundedContinuousTensorSpec((4,)), + "out_2": UnboundedContinuousTensorSpec((3,)), } # warning due to "_" in spec keys @@ -116,7 +116,7 @@ def test_stateful(self, safe, spec_type, lazy): elif spec_type == "bounded": spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(4) + spec = UnboundedContinuousTensorSpec(4) if safe and spec is None: with pytest.raises( @@ -178,7 +178,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys) elif spec_type == "bounded": spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(4) + spec = UnboundedContinuousTensorSpec(4) else: raise NotImplementedError @@ -241,7 +241,7 @@ def test_functional(self, safe, spec_type): elif spec_type == "bounded": spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(4) + spec = UnboundedContinuousTensorSpec(4) if safe and spec is None: with pytest.raises( @@ -295,7 +295,7 @@ def test_functional_probabilistic(self, safe, spec_type): elif spec_type == "bounded": spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(4) + spec = UnboundedContinuousTensorSpec(4) else: raise NotImplementedError @@ -352,7 +352,7 @@ def test_functional_with_buffer(self, safe, spec_type): elif spec_type == "bounded": spec = BoundedTensorSpec(-0.1, 0.1, 32) elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(32) + spec = UnboundedContinuousTensorSpec(32) if safe and spec is None: with pytest.raises( @@ -406,7 +406,7 @@ def test_functional_with_buffer_probabilistic(self, safe, spec_type): elif spec_type == "bounded": spec = BoundedTensorSpec(-0.1, 0.1, 32) elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(32) + spec = UnboundedContinuousTensorSpec(32) else: raise NotImplementedError @@ -466,7 +466,7 @@ def test_vmap(self, safe, spec_type): elif spec_type == "bounded": spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(4) + spec = UnboundedContinuousTensorSpec(4) if safe and spec is None: with pytest.raises( @@ -545,7 +545,7 @@ def test_vmap_probabilistic(self, safe, spec_type): elif spec_type == "bounded": spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(4) + spec = UnboundedContinuousTensorSpec(4) else: raise NotImplementedError @@ -642,7 +642,7 @@ def test_stateful(self, safe, spec_type, lazy): elif spec_type == "bounded": spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(4) + spec = UnboundedContinuousTensorSpec(4) kwargs = {} @@ -719,7 +719,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy): elif spec_type == "bounded": spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(4) + spec = UnboundedContinuousTensorSpec(4) else: raise NotImplementedError @@ -806,7 +806,7 @@ def test_functional(self, safe, spec_type): elif spec_type == "bounded": spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(4) + spec = UnboundedContinuousTensorSpec(4) if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") @@ -875,7 +875,7 @@ def test_functional_probabilistic(self, safe, spec_type): elif spec_type == "bounded": spec = 
BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(4) + spec = UnboundedContinuousTensorSpec(4) else: raise NotImplementedError @@ -961,7 +961,7 @@ def test_functional_with_buffer(self, safe, spec_type): elif spec_type == "bounded": spec = BoundedTensorSpec(-0.1, 0.1, 7) elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(7) + spec = UnboundedContinuousTensorSpec(7) if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") @@ -1033,7 +1033,7 @@ def test_functional_with_buffer_probabilistic(self, safe, spec_type): elif spec_type == "bounded": spec = BoundedTensorSpec(-0.1, 0.1, 7) elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(7) + spec = UnboundedContinuousTensorSpec(7) else: raise NotImplementedError @@ -1125,7 +1125,7 @@ def test_vmap(self, safe, spec_type): elif spec_type == "bounded": spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(4) + spec = UnboundedContinuousTensorSpec(4) if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") @@ -1224,7 +1224,7 @@ def test_vmap_probabilistic(self, safe, spec_type): elif spec_type == "bounded": spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": - spec = NdUnboundedContinuousTensorSpec(4) + spec = UnboundedContinuousTensorSpec(4) else: raise NotImplementedError diff --git a/test/test_transforms.py b/test/test_transforms.py index 5ec2639bf99..03ff454bbda 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -25,12 +25,7 @@ from tensordict import TensorDict from torch import multiprocessing as mp, Tensor from torchrl._utils import prod -from torchrl.data import ( - BoundedTensorSpec, - CompositeSpec, - NdUnboundedContinuousTensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec from torchrl.envs import ( BinarizeReward, CatFrames, @@ -1795,7 +1790,7 @@ def test_r3mnet_transform_observation_spec( if del_keys: exp_ts = CompositeSpec( { - key: NdUnboundedContinuousTensorSpec(r3m_net.outdim, device) + key: UnboundedContinuousTensorSpec(r3m_net.outdim, device) for key in out_keys } ) @@ -1813,7 +1808,7 @@ def test_r3mnet_transform_observation_spec( for key in in_keys: ts_dict[key] = observation_spec[key] for key in out_keys: - ts_dict[key] = NdUnboundedContinuousTensorSpec(r3m_net.outdim, device) + ts_dict[key] = UnboundedContinuousTensorSpec(r3m_net.outdim, device) exp_ts = CompositeSpec(ts_dict) observation_spec_out = r3m_net.transform_observation_spec(observation_spec) @@ -2052,7 +2047,7 @@ def test_vipnet_transform_observation_spec( if del_keys: exp_ts = CompositeSpec( { - key: NdUnboundedContinuousTensorSpec(vip_net.outdim, device) + key: UnboundedContinuousTensorSpec(vip_net.outdim, device) for key in out_keys } ) @@ -2070,7 +2065,7 @@ def test_vipnet_transform_observation_spec( for key in in_keys: ts_dict[key] = observation_spec[key] for key in out_keys: - ts_dict[key] = NdUnboundedContinuousTensorSpec(vip_net.outdim, device) + ts_dict[key] = UnboundedContinuousTensorSpec(vip_net.outdim, device) exp_ts = CompositeSpec(ts_dict) observation_spec_out = vip_net.transform_observation_spec(observation_spec) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 601b65e90db..69e4584cc4b 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -21,7 +21,6 @@ DEVICE_TYPING, DiscreteTensorSpec, 
MultOneHotDiscreteTensorSpec, - NdUnboundedContinuousTensorSpec, NdUnboundedDiscreteTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 2442ce5958e..ad2e9c7e0f1 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -36,7 +36,7 @@ _NO_CHECK_SPEC_ENCODE = get_binary_env_var("NO_CHECK_SPEC_ENCODE") -_DEFAULT_SHAPE = torch.Size([1]) +_DEFAULT_SHAPE = torch.Size((1,)) def _default_dtype_and_device( @@ -510,38 +510,6 @@ def __eq__(self, other): ) -@dataclass(repr=False) -class UnboundedContinuousTensorSpec(TensorSpec): - """An unbounded, unidimensional, continuous tensor spec. - - Args: - device (str, int or torch.device, optional): device of the tensors. - dtype (str or torch.dtype, optional): dtype of the tensors. - (should be an floating point dtype such as float, double etc.) - - """ - - shape: torch.Size - space: ContinuousBox - device: torch.device = torch.device("cpu") - dtype: torch.dtype = torch.float - domain: str = "" - - def __init__(self, device=None, dtype=None): - dtype, device = _default_dtype_and_device(dtype, device) - box = ContinuousBox(torch.tensor(-np.inf), torch.tensor(np.inf)) - super().__init__(torch.Size((1,)), box, device, dtype, "composite") - - def rand(self, shape=None) -> torch.Tensor: - if shape is None: - shape = torch.Size([]) - shape = [*shape, *self.shape] - return torch.randn(shape, device=self.device, dtype=self.dtype) - - def is_in(self, val: torch.Tensor) -> bool: - return True - - @dataclass(repr=False) class UnboundedDiscreteTensorSpec(TensorSpec): """An unbounded, unidimensional, discrete tensor spec. @@ -722,8 +690,8 @@ def is_in(self, val: torch.Tensor) -> bool: @dataclass(repr=False) -class NdUnboundedContinuousTensorSpec(UnboundedContinuousTensorSpec): - """An unbounded, multi-dimensional, continuous tensor spec. +class UnboundedContinuousTensorSpec(TensorSpec): + """An unbounded continuous tensor spec. Args: device (str, int or torch.device, optional): device of the tensors. 
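[Note] The hunks above and below fold the old one-dimensional spec and NdUnboundedContinuousTensorSpec into a single shape-aware UnboundedContinuousTensorSpec. A rough usage sketch follows — assuming the (shape, device, dtype) signature introduced here, with shape defaulting to _DEFAULT_SHAPE; the snippet is illustrative and not part of the patch:

import torch
from torchrl.data import UnboundedContinuousTensorSpec

# One class now covers both cases: an int or torch.Size selects the shape,
# and omitting it falls back to _DEFAULT_SHAPE, i.e. torch.Size((1,)).
default_spec = UnboundedContinuousTensorSpec()
vector_spec = UnboundedContinuousTensorSpec(shape=4, dtype=torch.float32)

sample = vector_spec.rand()          # torch.randn draw shaped like the spec
assert sample.shape == torch.Size([4])
assert vector_spec.is_in(sample)     # unbounded, so membership always holds

This is why the call sites rewritten earlier in the series can pass UnboundedContinuousTensorSpec(4) wherever NdUnboundedContinuousTensorSpec((4,)) appeared before.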
@@ -733,7 +701,7 @@ class NdUnboundedContinuousTensorSpec(UnboundedContinuousTensorSpec): def __init__( self, - shape: Union[torch.Size, int], + shape: Union[torch.Size, int] = _DEFAULT_SHAPE, device: Optional[DEVICE_TYPING] = None, dtype: Optional[Union[str, torch.dtype]] = None, ): @@ -741,14 +709,28 @@ def __init__( shape = torch.Size([shape]) dtype, device = _default_dtype_and_device(dtype, device) - super(UnboundedContinuousTensorSpec, self).__init__( + box = ( + ContinuousBox(torch.tensor(-np.inf), torch.tensor(np.inf)) + if shape == _DEFAULT_SHAPE + else None + ) + super().__init__( shape=shape, - space=None, + space=box, device=device, dtype=dtype, domain="continuous", ) + def rand(self, shape=None) -> torch.Tensor: + if shape is None: + shape = torch.Size([]) + shape = [*shape, *self.shape] + return torch.randn(shape, device=self.device, dtype=self.dtype) + + def is_in(self, val: torch.Tensor) -> bool: + return True + @dataclass(repr=False) class NdUnboundedDiscreteTensorSpec(UnboundedDiscreteTensorSpec): diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index 68f3b74cdd5..7d768f34cd3 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -3,11 +3,7 @@ import torch from tensordict.tensordict import TensorDict, TensorDictBase -from torchrl.data import ( - BoundedTensorSpec, - CompositeSpec, - NdUnboundedContinuousTensorSpec, -) +from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec from torchrl.envs.common import _EnvWrapper try: @@ -122,14 +118,14 @@ def _make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821 minimum=-1, maximum=1, shape=(env.action_size,), device=self.device ) ) - self.reward_spec = NdUnboundedContinuousTensorSpec( + self.reward_spec = UnboundedContinuousTensorSpec( shape=[ 1, ], device=self.device, ) self.observation_spec = CompositeSpec( - observation=NdUnboundedContinuousTensorSpec( + observation=UnboundedContinuousTensorSpec( shape=(env.observation_size,), device=self.device ) ) diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index 47d519ee6fb..56df3eafc8a 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -14,9 +14,9 @@ from torchrl.data import ( BoundedTensorSpec, CompositeSpec, - NdUnboundedContinuousTensorSpec, NdUnboundedDiscreteTensorSpec, TensorSpec, + UnboundedContinuousTensorSpec, ) from ...data.utils import DEVICE_TYPING, numpy_to_torch_dtype_dict @@ -74,7 +74,7 @@ def _dmcontrol_to_torchrl_spec_transform( if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] if dtype in (torch.float, torch.double, torch.half): - return NdUnboundedContinuousTensorSpec( + return UnboundedContinuousTensorSpec( shape=shape, dtype=dtype, device=device ) else: diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 285f9d6d52c..19add975418 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -15,9 +15,9 @@ CompositeSpec, DiscreteTensorSpec, MultOneHotDiscreteTensorSpec, - NdUnboundedContinuousTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, + UnboundedContinuousTensorSpec, ) from ..._utils import implement_for @@ -271,7 +271,7 @@ def _make_specs(self, env: "gym.Env") -> None: else: observation_spec = CompositeSpec(observation=observation_spec) self.observation_spec = observation_spec - self.reward_spec = NdUnboundedContinuousTensorSpec( + self.reward_spec = UnboundedContinuousTensorSpec( shape=[1], device=self.device, ) diff --git a/torchrl/envs/libs/jax_utils.py 
b/torchrl/envs/libs/jax_utils.py index 93bd9325300..e05d7cce672 100644 --- a/torchrl/envs/libs/jax_utils.py +++ b/torchrl/envs/libs/jax_utils.py @@ -9,9 +9,9 @@ from torch.utils import dlpack as torch_dlpack from torchrl.data import ( CompositeSpec, - NdUnboundedContinuousTensorSpec, NdUnboundedDiscreteTensorSpec, TensorSpec, + UnboundedContinuousTensorSpec, ) @@ -94,7 +94,7 @@ def _tensordict_to_object(tensordict: TensorDictBase, object_example): def _extract_spec(data: Union[torch.Tensor, TensorDictBase]) -> TensorSpec: if isinstance(data, torch.Tensor): if data.dtype in (torch.float, torch.double, torch.half): - return NdUnboundedContinuousTensorSpec( + return UnboundedContinuousTensorSpec( shape=data.shape, dtype=data.dtype, device=data.device ) else: diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 84c3a2fb901..98bb12ab61b 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -9,10 +9,10 @@ CompositeSpec, DEVICE_TYPING, DiscreteTensorSpec, - NdUnboundedContinuousTensorSpec, NdUnboundedDiscreteTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, + UnboundedContinuousTensorSpec, ) from torchrl.data.utils import numpy_to_torch_dtype_dict from torchrl.envs import GymLikeEnv @@ -74,7 +74,7 @@ def _jumanji_to_torchrl_spec_transform( if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] if dtype in (torch.float, torch.double, torch.half): - return NdUnboundedContinuousTensorSpec( + return UnboundedContinuousTensorSpec( shape=shape, dtype=dtype, device=device ) else: diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index e63825f13f7..d7885268164 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -28,18 +28,18 @@ class ModelBasedEnvBase(EnvBase, metaclass=abc.ABCMeta): Example: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import CompositeSpec, NdUnboundedContinuousTensorSpec + >>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec >>> class MyMBEnv(ModelBasedEnvBase): ... def __init__(self, world_model, device="cpu", dtype=None, batch_size=None): ... super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size) ... self.observation_spec = CompositeSpec( - ... hidden_observation=NdUnboundedContinuousTensorSpec((4,)) + ... hidden_observation=UnboundedContinuousTensorSpec((4,)) ... ) ... self.input_spec = CompositeSpec( - ... hidden_observation=NdUnboundedContinuousTensorSpec((4,)), - ... action=NdUnboundedContinuousTensorSpec((1,)), + ... hidden_observation=UnboundedContinuousTensorSpec((4,)), + ... action=UnboundedContinuousTensorSpec((1,)), ... ) - ... self.reward_spec = NdUnboundedContinuousTensorSpec((1,)) + ... self.reward_spec = UnboundedContinuousTensorSpec((1,)) ... ... def _reset(self, tensordict: TensorDict) -> TensorDict: ... 
tensordict = TensorDict({}, diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index fb902d692f7..f1606a3c332 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -40,10 +40,10 @@ def set_specs_from_env(self, env: EnvBase): """Sets the specs of the environment from the specs of the given environment.""" super().set_specs_from_env(env) # self.observation_spec = CompositeSpec( - # next_state=NdUnboundedContinuousTensorSpec( + # next_state=UnboundedContinuousTensorSpec( # shape=self.prior_shape, device=self.device # ), - # next_belief=NdUnboundedContinuousTensorSpec( + # next_belief=UnboundedContinuousTensorSpec( # shape=self.belief_shape, device=self.device # ), # ) diff --git a/torchrl/envs/transforms/r3m.py b/torchrl/envs/transforms/r3m.py index 43b3d61c069..42a44b0ae92 100644 --- a/torchrl/envs/transforms/r3m.py +++ b/torchrl/envs/transforms/r3m.py @@ -12,8 +12,8 @@ from torchrl.data.tensor_specs import ( CompositeSpec, - NdUnboundedContinuousTensorSpec, TensorSpec, + UnboundedContinuousTensorSpec, ) from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.transforms.transforms import ( @@ -98,7 +98,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec del observation_spec[in_key] for out_key in self.out_keys: - observation_spec[out_key] = NdUnboundedContinuousTensorSpec( + observation_spec[out_key] = UnboundedContinuousTensorSpec( shape=torch.Size([*dim, self.outdim]), device=device ) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index ac7858fe119..95b48d755dc 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -21,7 +21,6 @@ CompositeSpec, ContinuousBox, DEVICE_TYPING, - NdUnboundedContinuousTensorSpec, TensorSpec, UnboundedContinuousTensorSpec, ) @@ -1855,7 +1854,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec device = spec0.device shape[self.dim] = sum_shape shape = torch.Size(shape) - observation_spec[out_key] = NdUnboundedContinuousTensorSpec( + observation_spec[out_key] = UnboundedContinuousTensorSpec( shape=shape, dtype=spec0.dtype, device=device, @@ -2066,7 +2065,7 @@ class TensorDictPrimer(Transform): >>> from torchrl.envs.libs.gym import GymEnv >>> base_env = GymEnv("Pendulum-v1") >>> env = TransformedEnv(base_env) - >>> env.append_transform(TensorDictPrimer(mykey=NdUnboundedContinuousTensorSpec([3]))) + >>> env.append_transform(TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3]))) >>> print(env.reset()) TensorDict( fields={ diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py index ecb571656a3..3541d786da4 100644 --- a/torchrl/envs/transforms/vip.py +++ b/torchrl/envs/transforms/vip.py @@ -12,8 +12,8 @@ from torchrl.data.tensor_specs import ( CompositeSpec, - NdUnboundedContinuousTensorSpec, TensorSpec, + UnboundedContinuousTensorSpec, ) from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.transforms import ( @@ -90,7 +90,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec del observation_spec[in_key] for out_key in self.out_keys: - observation_spec[out_key] = NdUnboundedContinuousTensorSpec( + observation_spec[out_key] = UnboundedContinuousTensorSpec( shape=torch.Size([*dim, self.outdim]), device=device ) diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py index dd69c8b4e16..aa1a8f40b8b 100644 --- a/torchrl/modules/planners/cem.py +++ 
b/torchrl/modules/planners/cem.py @@ -45,20 +45,20 @@ class CEMPlanner(MPCPlannerBase): Examples: >>> from tensordict import TensorDict - >>> from torchrl.data import CompositeSpec, NdUnboundedContinuousTensorSpec + >>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec >>> from torchrl.envs.model_based import ModelBasedEnvBase >>> from torchrl.modules import SafeModule >>> class MyMBEnv(ModelBasedEnvBase): ... def __init__(self, world_model, device="cpu", dtype=None, batch_size=None): ... super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size) ... self.observation_spec = CompositeSpec( - ... next_hidden_observation=NdUnboundedContinuousTensorSpec((4,)) + ... next_hidden_observation=UnboundedContinuousTensorSpec((4,)) ... ) ... self.input_spec = CompositeSpec( - ... hidden_observation=NdUnboundedContinuousTensorSpec((4,)), - ... action=NdUnboundedContinuousTensorSpec((1,)), + ... hidden_observation=UnboundedContinuousTensorSpec((4,)), + ... action=UnboundedContinuousTensorSpec((1,)), ... ) - ... self.reward_spec = NdUnboundedContinuousTensorSpec((1,)) + ... self.reward_spec = UnboundedContinuousTensorSpec((1,)) ... ... def _reset(self, tensordict: TensorDict) -> TensorDict: ... tensordict = TensorDict({}, diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index ed0cb94bd11..e0382326fcb 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -29,10 +29,10 @@ class Actor(SafeModule): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import NdUnboundedContinuousTensorSpec + >>> from torchrl.data import UnboundedContinuousTensorSpec >>> from torchrl.modules import Actor >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,]) - >>> action_spec = NdUnboundedContinuousTensorSpec(4) + >>> action_spec = UnboundedContinuousTensorSpec(4) >>> module = torch.nn.Linear(4, 4) >>> td_module = Actor( ... module=module, @@ -148,7 +148,7 @@ class ValueOperator(SafeModule): >>> from tensordict import TensorDict >>> from tensordict.nn.functional_modules import make_functional >>> from torch import nn - >>> from torchrl.data import NdUnboundedContinuousTensorSpec + >>> from torchrl.data import UnboundedContinuousTensorSpec >>> from torchrl.modules import ValueOperator >>> td = TensorDict({"observation": torch.randn(3, 4), "action": torch.randn(3, 2)}, [3,]) >>> class CustomModule(nn.Module): @@ -569,9 +569,9 @@ class ActorValueOperator(SafeSequential): >>> import torch >>> from tensordict import TensorDict >>> from torchrl.modules import ProbabilisticActor, SafeModule - >>> from torchrl.data import NdUnboundedContinuousTensorSpec, BoundedTensorSpec + >>> from torchrl.data import UnboundedContinuousTensorSpec, BoundedTensorSpec >>> from torchrl.modules import ValueOperator, TanhNormal, ActorValueOperator, NormalParamWrapper - >>> spec_hidden = NdUnboundedContinuousTensorSpec(4) + >>> spec_hidden = UnboundedContinuousTensorSpec(4) >>> module_hidden = torch.nn.Linear(4, 4) >>> td_module_hidden = SafeModule( ... 
module=module_hidden, @@ -703,9 +703,9 @@ class ActorCriticOperator(ActorValueOperator): >>> import torch >>> from tensordict import TensorDict >>> from torchrl.modules import ProbabilisticActor, SafeModule - >>> from torchrl.data import NdUnboundedContinuousTensorSpec, BoundedTensorSpec + >>> from torchrl.data import UnboundedContinuousTensorSpec, BoundedTensorSpec >>> from torchrl.modules import ValueOperator, TanhNormal, ActorCriticOperator, NormalParamWrapper, MLP - >>> spec_hidden = NdUnboundedContinuousTensorSpec(4) + >>> spec_hidden = UnboundedContinuousTensorSpec(4) >>> module_hidden = torch.nn.Linear(4, 4) >>> td_module_hidden = SafeModule( ... module=module_hidden, @@ -832,7 +832,7 @@ class ActorCriticWrapper(SafeSequential): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import NdUnboundedContinuousTensorSpec, BoundedTensorSpec + >>> from torchrl.data import UnboundedContinuousTensorSpec, BoundedTensorSpec >>> from torchrl.modules import ( ActorCriticWrapper, ProbabilisticActor, diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index d666031d028..2d95e539fda 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -121,10 +121,10 @@ class SafeModule(TensorDictModule): >>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn.functional_modules import make_functional - >>> from torchrl.data import NdUnboundedContinuousTensorSpec + >>> from torchrl.data import UnboundedContinuousTensorSpec >>> from torchrl.modules import SafeModule >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,]) - >>> spec = NdUnboundedContinuousTensorSpec(8) + >>> spec = UnboundedContinuousTensorSpec(8) >>> module = torch.nn.GRUCell(4, 8) >>> td_fmodule = SafeModule( ... module=module, diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py index 5e4886b70b2..31f98afe6f6 100644 --- a/torchrl/modules/tensordict_module/sequence.py +++ b/torchrl/modules/tensordict_module/sequence.py @@ -34,11 +34,11 @@ class SafeSequential(TensorDictSequential, SafeModule): >>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn.functional_modules import make_functional - >>> from torchrl.data import CompositeSpec, NdUnboundedContinuousTensorSpec + >>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec >>> from torchrl.modules import TanhNormal, SafeSequential, SafeModule, NormalParamWrapper >>> from torchrl.modules.tensordict_module import SafeProbabilisticModule >>> td = TensorDict({"input": torch.randn(3, 4)}, [3,]) - >>> spec1 = CompositeSpec(hidden=NdUnboundedContinuousTensorSpec(4), loc=None, scale=None) + >>> spec1 = CompositeSpec(hidden=UnboundedContinuousTensorSpec(4), loc=None, scale=None) >>> net1 = NormalParamWrapper(torch.nn.Linear(4, 8)) >>> module1 = SafeModule(net1, in_keys=["input"], out_keys=["loc", "scale"]) >>> td_module1 = SafeProbabilisticModule( @@ -49,7 +49,7 @@ class SafeSequential(TensorDictSequential, SafeModule): ... distribution_class=TanhNormal, ... return_log_prob=True, ... ) - >>> spec2 = NdUnboundedContinuousTensorSpec(8) + >>> spec2 = UnboundedContinuousTensorSpec(8) >>> module2 = torch.nn.Linear(4, 8) >>> td_module2 = SafeModule( ... 
module=module2, diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index 947e3e9a613..99db2a95652 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -13,7 +13,7 @@ from torchrl.data import ( CompositeSpec, DiscreteTensorSpec, - NdUnboundedContinuousTensorSpec, + UnboundedContinuousTensorSpec, ) from torchrl.data.utils import DEVICE_TYPING from torchrl.envs import TensorDictPrimer, TransformedEnv @@ -1651,11 +1651,11 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): out_keys=["loc", "scale"], spec=CompositeSpec( **{ - "loc": NdUnboundedContinuousTensorSpec( + "loc": UnboundedContinuousTensorSpec( proof_environment.action_spec.shape, device=proof_environment.action_spec.device, ), - "scale": NdUnboundedContinuousTensorSpec( + "scale": UnboundedContinuousTensorSpec( proof_environment.action_spec.shape, device=proof_environment.action_spec.device, ), @@ -1701,10 +1701,10 @@ def _dreamer_make_actor_real( out_keys=["loc", "scale"], spec=CompositeSpec( **{ - "loc": NdUnboundedContinuousTensorSpec( + "loc": UnboundedContinuousTensorSpec( proof_environment.action_spec.shape, ), - "scale": NdUnboundedContinuousTensorSpec( + "scale": UnboundedContinuousTensorSpec( proof_environment.action_spec.shape, ), } @@ -1798,8 +1798,8 @@ def _dreamer_make_mbenv( model_based_env.set_specs_from_env(proof_environment) model_based_env = TransformedEnv(model_based_env) default_dict = { - "state": NdUnboundedContinuousTensorSpec(state_dim), - "belief": NdUnboundedContinuousTensorSpec(rssm_hidden_dim), + "state": UnboundedContinuousTensorSpec(state_dim), + "belief": UnboundedContinuousTensorSpec(rssm_hidden_dim), # "action": proof_environment.action_spec, } model_based_env.append_transform( From 1036055036ff76cd0385e1d863b1fc3cccffc105 Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Fri, 30 Dec 2022 03:27:38 +0100 Subject: [PATCH 3/4] Fix tests --- test/test_tensor_spec.py | 56 +++++++++++++++------------ test/test_tensordictmodules.py | 2 +- test/test_transforms.py | 4 +- torchrl/data/tensor_specs.py | 2 +- torchrl/envs/transforms/transforms.py | 1 + 5 files changed, 38 insertions(+), 27 deletions(-) diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py index 1733c0b372c..1ce997c944a 100644 --- a/test/test_tensor_spec.py +++ b/test/test_tensor_spec.py @@ -28,7 +28,9 @@ def test_bounded(dtype): np.random.seed(0) for _ in range(100): bounds = torch.randn(2).sort()[0] - ts = BoundedTensorSpec(bounds[0].item(), bounds[1].item(), dtype=dtype) + ts = BoundedTensorSpec( + bounds[0].item(), bounds[1].item(), torch.Size((1,)), dtype=dtype + ) _dtype = dtype if dtype is None: _dtype = torch.get_default_dtype() @@ -473,7 +475,7 @@ def test_nested_composite_spec_update(self, is_complete, device, dtype): td2 = CompositeSpec(nested_cp=CompositeSpec(act=None).to(device)) ts.update(td2) td2 = CompositeSpec( - nested_cp=CompositeSpec(act=UnboundedContinuousTensorSpec(device)) + nested_cp=CompositeSpec(act=UnboundedContinuousTensorSpec(device=device)) ) ts.update(td2) assert set(ts.keys()) == { @@ -509,25 +511,31 @@ def test_equality_bounded(self): device = "cpu" dtype = torch.float16 - ts = BoundedTensorSpec(minimum, maximum, device, dtype) + ts = BoundedTensorSpec(minimum, maximum, torch.Size((1,)), device, dtype) - ts_same = BoundedTensorSpec(minimum, maximum, device, dtype) + ts_same = BoundedTensorSpec(minimum, maximum, torch.Size((1,)), device, dtype) assert ts == ts_same - ts_other = BoundedTensorSpec(minimum + 1, 
maximum, device, dtype) + ts_other = BoundedTensorSpec( + minimum + 1, maximum, torch.Size((1,)), device, dtype + ) assert ts != ts_other - ts_other = BoundedTensorSpec(minimum, maximum + 1, device, dtype) + ts_other = BoundedTensorSpec( + minimum, maximum + 1, torch.Size((1,)), device, dtype + ) assert ts != ts_other - ts_other = BoundedTensorSpec(minimum, maximum, "cpu:0", dtype) + ts_other = BoundedTensorSpec(minimum, maximum, torch.Size((1,)), "cpu:0", dtype) assert ts != ts_other - ts_other = BoundedTensorSpec(minimum, maximum, device, torch.float64) + ts_other = BoundedTensorSpec( + minimum, maximum, torch.Size((1,)), device, torch.float64 + ) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - UnboundedContinuousTensorSpec(device, dtype), ts + UnboundedContinuousTensorSpec(device=device, dtype=dtype), ts ) assert ts != ts_other @@ -555,7 +563,7 @@ def test_equality_onehot(self): assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - UnboundedContinuousTensorSpec(device, dtype), ts + UnboundedContinuousTensorSpec(device=device, dtype=dtype), ts ) assert ts != ts_other @@ -563,19 +571,19 @@ def test_equality_unbounded(self): device = "cpu" dtype = torch.float16 - ts = UnboundedContinuousTensorSpec(device, dtype) + ts = UnboundedContinuousTensorSpec(device=device, dtype=dtype) - ts_same = UnboundedContinuousTensorSpec(device, dtype) + ts_same = UnboundedContinuousTensorSpec(device=device, dtype=dtype) assert ts == ts_same - ts_other = UnboundedContinuousTensorSpec("cpu:0", dtype) + ts_other = UnboundedContinuousTensorSpec(device="cpu:0", dtype=dtype) assert ts != ts_other - ts_other = UnboundedContinuousTensorSpec(device, torch.float64) + ts_other = UnboundedContinuousTensorSpec(device=device, dtype=torch.float64) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - BoundedTensorSpec(0, 1, device, dtype), ts + BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts ) assert ts != ts_other @@ -615,7 +623,7 @@ def test_equality_ndbounded(self): assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - BoundedTensorSpec(0, 1, device, dtype), ts + UnboundedContinuousTensorSpec(device=device, dtype=dtype), ts ) assert ts != ts_other @@ -643,7 +651,7 @@ def test_equality_discrete(self): assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - UnboundedContinuousTensorSpec(device, dtype), ts + UnboundedContinuousTensorSpec(device=device, dtype=dtype), ts ) assert ts != ts_other @@ -681,7 +689,7 @@ def test_equality_ndunbounded(self, shape): assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - BoundedTensorSpec(0, 1, device, dtype), ts + BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts ) assert ts != ts_other @@ -705,7 +713,7 @@ def test_equality_binary(self): assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - BoundedTensorSpec(0, 1, device, dtype), ts + BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts ) assert ts != ts_other @@ -746,7 +754,7 @@ def test_equality_multi_onehot(self, nvec): assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - BoundedTensorSpec(0, 1, device, dtype), ts + BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts ) assert ts != ts_other @@ -756,9 +764,9 @@ def test_equality_composite(self): device = "cpu" dtype = torch.float16 - bounded = BoundedTensorSpec(0, 1, device, dtype) - bounded_same = BoundedTensorSpec(0, 1, device, dtype) - bounded_other = 
BoundedTensorSpec(0, 2, device, dtype)
+        bounded = BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype)
+        bounded_same = BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype)
+        bounded_other = BoundedTensorSpec(0, 2, torch.Size((1,)), device, dtype)

         nd = BoundedTensorSpec(
             minimum=minimum, maximum=maximum + 1, device=device, dtype=dtype
         )
@@ -937,7 +945,7 @@ def test_categorical_action_spec_encode(self):
         ).all()

     def test_bounded_rand(self):
-        spec = BoundedTensorSpec(-3, 3)
+        spec = BoundedTensorSpec(-3, 3, torch.Size((1,)))
         sample = torch.stack([spec.rand() for _ in range(100)])
         assert (-3 <= sample).all() and (3 >= sample).all()
diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py
index 9e44a77f8c0..eddf40546d2 100644
--- a/test/test_tensordictmodules.py
+++ b/test/test_tensordictmodules.py
@@ -114,7 +114,7 @@ def test_stateful(self, safe, spec_type, lazy):
         if spec_type is None:
             spec = None
         elif spec_type == "bounded":
-            spec = BoundedTensorSpec(-0.1, 0.1, 4)
+            spec = BoundedTensorSpec(-0.1, 0.1, torch.Size((1,)), 4)
         elif spec_type == "unbounded":
             spec = UnboundedContinuousTensorSpec(4)
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 03ff454bbda..8748a9cb1a6 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -308,7 +308,9 @@ def test_added_transforms_are_in_eval_mode():
 class TestTransformedEnv:
     def test_independent_obs_specs_from_shared_env(self):
-        obs_spec = CompositeSpec(observation=BoundedTensorSpec(minimum=0, maximum=10))
+        obs_spec = CompositeSpec(
+            observation=BoundedTensorSpec(minimum=0, maximum=10, shape=torch.Size((1,)))
+        )
         base_env = ContinuousActionVecMockEnv(observation_spec=obs_spec)
         t1 = TransformedEnv(base_env, transform=ObservationNorm(loc=3, scale=2))
         t2 = TransformedEnv(base_env, transform=ObservationNorm(loc=1, scale=6))
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index ad2e9c7e0f1..580f1e699a5 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -565,7 +565,7 @@ def __init__(
         self,
         minimum: Union[float, torch.Tensor, np.ndarray],
         maximum: Union[float, torch.Tensor, np.ndarray],
-        shape: Optional[torch.Size] = _DEFAULT_SHAPE,
+        shape: Optional[torch.Size] = None,
         device: Optional[DEVICE_TYPING] = None,
         dtype: Optional[Union[torch.dtype, str]] = None,
     ):
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 95b48d755dc..92185d64768 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -852,6 +852,7 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
             return BoundedTensorSpec(
                 self.clamp_min,
                 self.clamp_max,
+                torch.Size((1,)),
                 device=reward_spec.device,
                 dtype=reward_spec.dtype,
             )

From 0112fb7896726835f9a6ecc4afd202b1cf69442f Mon Sep 17 00:00:00 2001
From: Waris Radji
Date: Fri, 30 Dec 2022 03:45:36 +0100
Subject: [PATCH 4/4] Remove NdUnboundedDiscreteTensorSpec

---
 test/test_tensordictmodules.py  |  2 +-
 torchrl/data/__init__.py        |  1 -
 torchrl/data/tensor_specs.py    | 62 ++++++++++-----------------------
 torchrl/envs/libs/dm_control.py |  6 ++--
 torchrl/envs/libs/jax_utils.py  |  4 +--
 torchrl/envs/libs/jumanji.py    |  6 ++--
 6 files changed, 25 insertions(+), 56 deletions(-)

diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py
index eddf40546d2..9e44a77f8c0 100644
--- a/test/test_tensordictmodules.py
+++ b/test/test_tensordictmodules.py
@@ -114,7 +114,7 @@ def test_stateful(self, safe, spec_type, lazy):
         if spec_type is
None: spec = None elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, torch.Size((1,)), 4) + spec = BoundedTensorSpec(-0.1, 0.1, 4) elif spec_type == "unbounded": spec = UnboundedContinuousTensorSpec(4) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 69e4584cc4b..4b880b98760 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -21,7 +21,6 @@ DEVICE_TYPING, DiscreteTensorSpec, MultOneHotDiscreteTensorSpec, - NdUnboundedDiscreteTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, UnboundedContinuousTensorSpec, diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 580f1e699a5..233ca7ab698 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -510,45 +510,6 @@ def __eq__(self, other): ) -@dataclass(repr=False) -class UnboundedDiscreteTensorSpec(TensorSpec): - """An unbounded, unidimensional, discrete tensor spec. - - Args: - device (str, int or torch.device, optional): device of the tensors. - dtype (str or torch.dtype, optional): dtype of the tensors - (should be an integer dtype such as long, uint8 etc.) - - """ - - shape: torch.Size - space: ContinuousBox - device: torch.device = torch.device("cpu") - dtype: torch.dtype = torch.uint8 - domain: str = "" - - def __init__(self, device=None, dtype=None): - dtype, device = _default_dtype_and_device(dtype, device) - box = ContinuousBox( - torch.tensor(torch.iinfo(dtype).min, device=device), - torch.tensor(torch.iinfo(dtype).max, device=device), - ) - super().__init__(torch.Size((1,)), box, device, dtype, "composite") - - def rand(self, shape=None) -> torch.Tensor: - if shape is None: - shape = torch.Size([]) - interval = self.space.maximum - self.space.minimum - r = torch.rand(torch.Size([*shape, *interval.shape]), device=interval.device) - r = r * interval - r = self.space.minimum + r - r = r.to(self.dtype) - return r.to(self.device) - - def is_in(self, val: torch.Tensor) -> bool: - return True - - @dataclass(repr=False) class BoundedTensorSpec(TensorSpec): """A bounded continuous tensor spec. @@ -565,7 +526,7 @@ def __init__( self, minimum: Union[float, torch.Tensor, np.ndarray], maximum: Union[float, torch.Tensor, np.ndarray], - shape: Optional[torch.Size] = None, + shape: Optional[Union[torch.Size, int]] = None, device: Optional[DEVICE_TYPING] = None, dtype: Optional[Union[torch.dtype, str]] = None, ): @@ -733,8 +694,8 @@ def is_in(self, val: torch.Tensor) -> bool: @dataclass(repr=False) -class NdUnboundedDiscreteTensorSpec(UnboundedDiscreteTensorSpec): - """An unbounded, multi-dimensional, discrete tensor spec. +class UnboundedDiscreteTensorSpec(TensorSpec): + """An unbounded discrete tensor spec. Args: device (str, int or torch.device, optional): device of the tensors. 
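[Note] The discrete spec above is consolidated the same way as the continuous one. A matching sketch, with the signature read off this hunk rather than a released API — note that the moved implementation keeps domain="continuous" and samples between the torch.iinfo bounds of the chosen dtype:

import torch
from torchrl.data import UnboundedDiscreteTensorSpec

# shape defaults to torch.Size((1,)) exactly as in the continuous case;
# passing the dtype explicitly controls the iinfo bounds used by rand().
spec = UnboundedDiscreteTensorSpec(shape=3, dtype=torch.int32)

draw = spec.rand()                   # uniform draw within the int32 bounds
assert draw.dtype == torch.int32
assert draw.shape == torch.Size([3])
assert spec.is_in(draw)              # is_in is unconditionally True here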
@@ -744,7 +705,7 @@ class NdUnboundedDiscreteTensorSpec(UnboundedDiscreteTensorSpec): def __init__( self, - shape: Union[torch.Size, int], + shape: Union[torch.Size, int] = _DEFAULT_SHAPE, device: Optional[DEVICE_TYPING] = None, dtype: Optional[Union[str, torch.dtype]] = None, ): @@ -763,7 +724,7 @@ def __init__( torch.full(shape, max_value, device=device), ) - super(UnboundedDiscreteTensorSpec, self).__init__( + super().__init__( shape=shape, space=space, device=device, @@ -771,6 +732,19 @@ def __init__( domain="continuous", ) + def rand(self, shape=None) -> torch.Tensor: + if shape is None: + shape = torch.Size([]) + interval = self.space.maximum - self.space.minimum + r = torch.rand(torch.Size([*shape, *interval.shape]), device=interval.device) + r = r * interval + r = self.space.minimum + r + r = r.to(self.dtype) + return r.to(self.device) + + def is_in(self, val: torch.Tensor) -> bool: + return True + @dataclass(repr=False) class BinaryDiscreteTensorSpec(TensorSpec): diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index 56df3eafc8a..aece9e33dda 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -14,9 +14,9 @@ from torchrl.data import ( BoundedTensorSpec, CompositeSpec, - NdUnboundedDiscreteTensorSpec, TensorSpec, UnboundedContinuousTensorSpec, + UnboundedDiscreteTensorSpec, ) from ...data.utils import DEVICE_TYPING, numpy_to_torch_dtype_dict @@ -78,9 +78,7 @@ def _dmcontrol_to_torchrl_spec_transform( shape=shape, dtype=dtype, device=device ) else: - return NdUnboundedDiscreteTensorSpec( - shape=shape, dtype=dtype, device=device - ) + return UnboundedDiscreteTensorSpec(shape=shape, dtype=dtype, device=device) else: raise NotImplementedError(type(spec)) diff --git a/torchrl/envs/libs/jax_utils.py b/torchrl/envs/libs/jax_utils.py index e05d7cce672..1319b5cf77b 100644 --- a/torchrl/envs/libs/jax_utils.py +++ b/torchrl/envs/libs/jax_utils.py @@ -9,9 +9,9 @@ from torch.utils import dlpack as torch_dlpack from torchrl.data import ( CompositeSpec, - NdUnboundedDiscreteTensorSpec, TensorSpec, UnboundedContinuousTensorSpec, + UnboundedDiscreteTensorSpec, ) @@ -98,7 +98,7 @@ def _extract_spec(data: Union[torch.Tensor, TensorDictBase]) -> TensorSpec: shape=data.shape, dtype=data.dtype, device=data.device ) else: - return NdUnboundedDiscreteTensorSpec( + return UnboundedDiscreteTensorSpec( shape=data.shape, dtype=data.dtype, device=data.device ) elif isinstance(data, TensorDictBase): diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 98bb12ab61b..95a213b96a3 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -9,10 +9,10 @@ CompositeSpec, DEVICE_TYPING, DiscreteTensorSpec, - NdUnboundedDiscreteTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, UnboundedContinuousTensorSpec, + UnboundedDiscreteTensorSpec, ) from torchrl.data.utils import numpy_to_torch_dtype_dict from torchrl.envs import GymLikeEnv @@ -78,9 +78,7 @@ def _jumanji_to_torchrl_spec_transform( shape=shape, dtype=dtype, device=device ) else: - return NdUnboundedDiscreteTensorSpec( - shape=shape, dtype=dtype, device=device - ) + return UnboundedDiscreteTensorSpec(shape=shape, dtype=dtype, device=device) elif isinstance(spec, jumanji.specs.Spec) and hasattr(spec, "__dict__"): new_spec = {} for key, value in spec.__dict__.items():
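[Note] Taken end to end, the series leaves shape-aware base specs in place of the Nd* variants. A closing sketch of the resulting constructors side by side; the signatures are inferred from the hunks in this series, and CompositeSpec.rand is assumed to draw one sample per sub-spec:

import torch
from torchrl.data import (
    BoundedTensorSpec,
    CompositeSpec,
    UnboundedContinuousTensorSpec,
    UnboundedDiscreteTensorSpec,
)

# BoundedTensorSpec now takes the shape as its third positional argument,
# i.e. BoundedTensorSpec(minimum, maximum, shape, device, dtype).
action_spec = BoundedTensorSpec(-torch.ones(4), torch.ones(4), torch.Size((4,)))
obs_spec = UnboundedContinuousTensorSpec(shape=7)
count_spec = UnboundedDiscreteTensorSpec(shape=1, dtype=torch.int32)

spec = CompositeSpec(action=action_spec, observation=obs_spec, count=count_spec)
sample = spec.rand()                 # assumed: a TensorDict with one entry per key

Folding shape into the base classes removes the parallel Nd* hierarchy, while the torch.Size((1,)) default keeps the old scalar call sites of the unbounded specs working unchanged.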