From 00424b7cbc74641b3d2c6c6f63a3f46cb25b1ba0 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Sun, 1 Jan 2023 08:17:55 +0100
Subject: [PATCH 1/7] init (#779)

---
 torchrl/collectors/collectors.py | 37 +++++++++++++++++++++++++++-----
 torchrl/envs/common.py           |  9 +++++++-
 2 files changed, 40 insertions(+), 6 deletions(-)

diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index aa3d79e1499..47a7cec7f19 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import _pickle
 import abc
 import inspect
 import os
@@ -708,7 +709,14 @@ def shutdown(self) -> None:
         del self.env
 
     def __del__(self):
-        self.shutdown()  # make sure env is closed
+        try:
+            self.shutdown()
+        except Exception:
+            # An AttributeError will typically be raised if the collector is deleted when the program ends.
+            # In the future, insignificant changes to the close method may change the error type.
+            # We explicitly assume that any error raised during closure in
+            # __del__ will not affect the program.
+            pass
 
     def state_dict(self) -> OrderedDict:
         """Returns the local state_dict of the data collector (environment and policy).
@@ -1016,7 +1024,19 @@ def _run_processes(self) -> None:
             }
             proc = mp.Process(target=_main_async_collector, kwargs=kwargs)
             # proc.daemon can't be set as daemonic processes may be launched by the process itself
-            proc.start()
+            try:
+                proc.start()
+            except _pickle.PicklingError as err:
+                if "<lambda>" in str(err):
+                    raise RuntimeError(
+                        """Can't open a process with a doubly cloud-pickled lambda function.
+This error is likely due to an attempt to use a ParallelEnv in a
+multiprocessed data collector. To fix this, consider wrapping your
+lambda function in a `torchrl.envs.EnvCreator` wrapper as follows:
+`env = ParallelEnv(N, EnvCreator(my_lambda_function))`.
+This will not only ensure that your lambda function is cloud-pickled once, but
+also that the state dict is synchronised across processes if needed."""
+                    )
             pipe_child.close()
             self.procs.append(proc)
             self.pipes.append(pipe_parent)
@@ -1027,7 +1047,14 @@ def _run_processes(self) -> None:
         self.closed = False
 
     def __del__(self):
-        self.shutdown()
+        try:
+            self.shutdown()
+        except Exception:
+            # An AttributeError will typically be raised if the collector is deleted when the program ends.
+            # In the future, insignificant changes to the close method may change the error type.
+            # We explicitly assume that any error raised during closure in
+            # __del__ will not affect the program.
+            pass
 
     def shutdown(self) -> None:
         """Shuts down all processes. This operation is irreversible."""
@@ -1624,8 +1651,8 @@ def _main_async_collector(
                     f"without receiving a command from main. Consider increasing the maximum idle count "
                     f"if this is expected via the environment variable MAX_IDLE_COUNT "
                     f"(current value is {_MAX_IDLE_COUNT})."
-                    f"\nIf this occurs at the end of a function, it means that your collector has not been "
-                    f"collected, consider calling `collector.shutdown()` or `del collector` at the end of the function."
+                    f"\nIf this occurs at the end of a function or program, it means that your collector has not been "
+                    f"collected, consider calling `collector.shutdown()` or `del collector` before ending the program."
                )
                continue
            if msg in ("continue", "continue_random"):

diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index 406013d2480..aec32eff395 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -677,7 +677,14 @@ def __del__(self):
         # if del occurs before env has been set up, we don't want a recursion
         # error
         if "is_closed" in self.__dict__ and not self.is_closed:
-            self.close()
+            try:
+                self.close()
+            except Exception:
+                # A TypeError will typically be raised if the env is deleted when the program ends.
+                # In the future, insignificant changes to the close method may change the error type.
+                # We explicitly assume that any error raised during closure in
+                # __del__ will not affect the program.
+                pass
 
     def to(self, device: DEVICE_TYPING) -> EnvBase:
         device = torch.device(device)

From 66b7637d9344b42ea2153a6bcc081b8a74b6534a Mon Sep 17 00:00:00 2001
From: Waris Radji
Date: Sun, 1 Jan 2023 21:13:45 +0100
Subject: [PATCH 2/7] Add MultDiscreteTensorSpec

---
 torchrl/data/tensor_specs.py | 79 ++++++++++++++++++++++++++++++++++++
 1 file changed, 79 insertions(+)

diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 233ca7ab698..11263b76da6 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -982,6 +982,85 @@ def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
         return super().to_numpy(val, safe)
 
 
+@dataclass(repr=False)
+class MultDiscreteTensorSpec(DiscreteTensorSpec):
+    """A concatenation of discrete tensor specs.
+
+    Args:
+        nvec (iterable of integers): cardinality of each of the elements of
+            the tensor.
+        shape (torch.Size, optional): shape of the variables; default is (1,).
+        device (str, int or torch.device, optional): device of
+            the tensors.
+        dtype (str or torch.dtype, optional): dtype of the tensors.
+
+    Examples:
+        >>> ts = MultDiscreteTensorSpec((3,2,3))
+        >>> ts.is_in(torch.tensor([2, 0, 1]))
+        True
+        >>> ts.is_in(torch.tensor([2, 2, 1]))
+        False
+    """
+
+    def __init__(
+        self,
+        nvec: Sequence[int],
+        shape: Optional[torch.Size] = None,
+        device: Optional[DEVICE_TYPING] = None,
+        dtype: Optional[Union[str, torch.dtype]] = torch.long,
+    ):
+        if shape is None:
+            shape = torch.Size([])
+        dtype, device = _default_dtype_and_device(dtype, device)
+        self._size = len(nvec)
+        self._individual_shape = shape
+        self._first_dim_size = shape[0] if len(shape) != 0 else 1
+        shape = torch.Size([self._first_dim_size * self._size, *shape[1:]])
+        space = BoxList([DiscreteBox(n) for n in nvec])
+        super(DiscreteTensorSpec, self).__init__(
+            shape, space, device, dtype, domain="discrete"
+        )
+
+    def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor:
+        if shape is None:
+            shape = torch.Size([])
+        return torch.cat(
+            [
+                torch.randint(
+                    0,
+                    space.n,
+                    torch.Size([*shape, *self._individual_shape]),
+                    device=self.device,
+                    dtype=self.dtype,
+                )
+                for space in self.space
+            ]
+        )
+
+    def _split(self, val: torch.Tensor):
+        return val.split(self._first_dim_size)
+
+    def _project(self, val: torch.Tensor) -> torch.Tensor:
+        if val.dtype not in (torch.int, torch.long):
+            val = torch.round(val)
+        vals = self._split(val)
+        return torch.cat(
+            [
+                _val.clamp_(min=0, max=space.n - 1)
+                for _val, space in zip(vals, self.space)
+            ]
+        )
+
+    def is_in(self, val: torch.Tensor) -> bool:
+        vals = self._split(val)
+        return self._size == len(vals) and all(
+            [
+                (0 <= _val).all() and (_val < space.n).all()
+                for _val, space in zip(vals, self.space)
+            ]
+        )
+
+
 class CompositeSpec(TensorSpec):
     """A composition of TensorSpecs.
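[Between patches, for orientation: a minimal usage sketch of the spec introduced in patch 2, mirroring its docstring example. It imports from the defining module because the top-level `torchrl.data` re-export only lands in patch 5, it sticks to `is_in` because `rand` is reworked in patch 3, and the class itself is renamed to `MultiDiscreteTensorSpec` in patch 7.]

import torch

# At this point in the series the spec lives in torchrl.data.tensor_specs;
# the torchrl.data re-export is only added in patch 5.
from torchrl.data.tensor_specs import MultDiscreteTensorSpec

# One categorical variable per entry of nvec, with cardinalities 3, 2 and 3.
ts = MultDiscreteTensorSpec((3, 2, 3))

print(ts.is_in(torch.tensor([2, 0, 1])))  # True: every entry is below its cardinality
print(ts.is_in(torch.tensor([2, 2, 1])))  # False: 2 is out of range for the middle variable (n=2)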
From 7d50ef086b5ca7e60362eb6c3c277998c2ef4035 Mon Sep 17 00:00:00 2001
From: Waris Radji
Date: Mon, 2 Jan 2023 18:31:39 +0100
Subject: [PATCH 3/7] Add test_mult_discrete

---
 test/test_tensor_spec.py     | 49 ++++++++++++++++++++++++++++++++++++
 torchrl/data/tensor_specs.py | 25 ++++++++----------
 2 files changed, 60 insertions(+), 14 deletions(-)

diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py
index 1ce997c944a..fd881c49947 100644
--- a/test/test_tensor_spec.py
+++ b/test/test_tensor_spec.py
@@ -16,6 +16,7 @@
     BoundedTensorSpec,
     CompositeSpec,
     DiscreteTensorSpec,
+    MultDiscreteTensorSpec,
     MultOneHotDiscreteTensorSpec,
     OneHotDiscreteTensorSpec,
     UnboundedContinuousTensorSpec,
@@ -233,6 +234,54 @@ def test_mult_onehot(shape, ns):
         assert (ts.encode(np_r) == r).all()
 
 
+@pytest.mark.parametrize(
+    "ns",
+    [
+        [
+            5,
+        ],
+        [5, 2, 3],
+        [4, 5, 1, 3],
+    ],
+)
+@pytest.mark.parametrize(
+    "shape",
+    [
+        None,
+        [],
+        torch.Size([3]),
+        torch.Size([4, 5]),
+    ],
+)
+def test_mult_discrete(shape, ns):
+    torch.manual_seed(0)
+    np.random.seed(0)
+    ts = MultDiscreteTensorSpec(ns)
+    _real_shape = shape if shape is not None else []
+    _len_ns = [len(ns)] if len(ns) > 1 else []
+    for _ in range(100):
+        r = ts.rand(shape)
+
+        assert r.shape == torch.Size(
+            [
+                *_real_shape,
+                *_len_ns,
+            ]
+        )
+        assert ts.is_in(r)
+    rand = torch.rand(
+        torch.Size(
+            [
+                *_real_shape,
+                *_len_ns,
+            ]
+        )
+    )
+    projection = ts._project(rand)
+    assert rand.shape == projection.shape
+    assert ts.is_in(projection)
+
+
 @pytest.mark.parametrize("is_complete", [True, False])
 @pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None])
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 11263b76da6..849849265f5 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -989,7 +989,6 @@ class MultDiscreteTensorSpec(DiscreteTensorSpec):
     Args:
         nvec (iterable of integers): cardinality of each of the elements of
             the tensor.
-        shape (torch.Size, optional): shape of the variables; default is (1,).
         device (str, int or torch.device, optional): device of
             the tensors.
         dtype (str or torch.dtype, optional): dtype of the tensors.
@@ -1005,17 +1004,12 @@ class MultDiscreteTensorSpec(DiscreteTensorSpec): def __init__( self, nvec: Sequence[int], - shape: Optional[torch.Size] = None, device: Optional[DEVICE_TYPING] = None, dtype: Optional[Union[str, torch.dtype]] = torch.long, ): - if shape is None: - shape = torch.Size([]) dtype, device = _default_dtype_and_device(dtype, device) self._size = len(nvec) - self._individual_shape = shape - self._first_dim_size = shape[0] if len(shape) != 0 else 1 - shape = torch.Size([self._first_dim_size * self._size, *shape[1:]]) + shape = torch.Size([self._size]) space = BoxList([DiscreteBox(n) for n in nvec]) super(DiscreteTensorSpec, self).__init__( shape, space, device, dtype, domain="discrete" @@ -1024,21 +1018,23 @@ def __init__( def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: if shape is None: shape = torch.Size([]) - return torch.cat( + x = torch.cat( [ torch.randint( 0, space.n, - torch.Size([*shape, *self._individual_shape]), + torch.Size([1, *shape]), device=self.device, dtype=self.dtype, ) for space in self.space ] - ) + ).squeeze() + _size = [self._size] if self._size > 1 else [] + return x.T.reshape([*shape, *_size]) def _split(self, val: torch.Tensor): - return val.split(self._first_dim_size) + return [val] if self._size < 2 else val.split(1, -1) def _project(self, val: torch.Tensor) -> torch.Tensor: if val.dtype not in (torch.int, torch.long): @@ -1046,10 +1042,11 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: vals = self._split(val) return torch.cat( [ - _val.clamp_(min=0, max=space.n - 1) + _val.clamp_(min=0, max=space.n - 1).unsqueeze(0) for _val, space in zip(vals, self.space) - ] - ) + ], + dim=-1, + ).squeeze() def is_in(self, val: torch.Tensor) -> bool: vals = self._split(val) From a3f24c02278c0fe006c99619521f7e52063d11f4 Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Mon, 2 Jan 2023 19:51:41 +0100 Subject: [PATCH 4/7] Add test_equality_multi_discrete --- test/test_tensor_spec.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py index fd881c49947..660cc544c9d 100644 --- a/test/test_tensor_spec.py +++ b/test/test_tensor_spec.py @@ -807,6 +807,39 @@ def test_equality_multi_onehot(self, nvec): ) assert ts != ts_other + @pytest.mark.parametrize("nvec", [[3], [3, 4], [3, 4, 5]]) + def test_equality_multi_discrete(self, nvec): + device = "cpu" + dtype = torch.float16 + + ts = MultDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype) + + ts_same = MultDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype) + assert ts == ts_same + + other_nvec = np.array(nvec) + 3 + ts_other = MultDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype) + assert ts != ts_other + + other_nvec = [12] + ts_other = MultDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype) + assert ts != ts_other + + other_nvec = [12, 13] + ts_other = MultDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype) + assert ts != ts_other + + ts_other = MultDiscreteTensorSpec(nvec=nvec, device="cpu:0", dtype=dtype) + assert ts != ts_other + + ts_other = MultDiscreteTensorSpec(nvec=nvec, device=device, dtype=torch.float64) + assert ts != ts_other + + ts_other = TestEquality._ts_make_all_fields_equal( + BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts + ) + assert ts != ts_other + def test_equality_composite(self): minimum = np.arange(12).reshape((3, 4)) maximum = minimum + 100 From 1f97db7e47449b68132672481a87ff8fab5a5b57 Mon Sep 17 00:00:00 2001 From: 
Waris Radji
Date: Mon, 2 Jan 2023 20:13:06 +0100
Subject: [PATCH 5/7] Add conversion methods

---
 test/test_tensor_spec.py     | 39 ++++++++++++++++++++++++++++++++++++
 torchrl/data/__init__.py     |  1 +
 torchrl/data/tensor_specs.py | 16 +++++++++++++++
 3 files changed, 56 insertions(+)

diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py
index 660cc544c9d..3ee32acd6c1 100644
--- a/test/test_tensor_spec.py
+++ b/test/test_tensor_spec.py
@@ -282,6 +282,45 @@ def test_mult_discrete(shape, ns):
         assert ts.is_in(projection)
 
 
+@pytest.mark.parametrize(
+    "n",
+    [
+        1,
+        4,
+        7,
+        99,
+    ],
+)
+@pytest.mark.parametrize("device", get_available_devices())
+def test_discrete_conversion(n, device):
+    categorical = DiscreteTensorSpec(n, device=device)
+    one_hot = OneHotDiscreteTensorSpec(n, device=device)
+
+    assert categorical != one_hot
+    assert categorical.to_onehot() == one_hot
+    assert one_hot.to_categorical() == categorical
+
+
+@pytest.mark.parametrize(
+    "ns",
+    [
+        [
+            5,
+        ],
+        [5, 2, 3],
+        [4, 5, 1, 3],
+    ],
+)
+@pytest.mark.parametrize("device", get_available_devices())
+def test_mult_discrete_conversion(ns, device):
+    categorical = MultDiscreteTensorSpec(ns, device=device)
+    one_hot = MultOneHotDiscreteTensorSpec(ns, device=device)
+
+    assert categorical != one_hot
+    assert categorical.to_onehot() == one_hot
+    assert one_hot.to_categorical() == categorical
+
+
 @pytest.mark.parametrize("is_complete", [True, False])
 @pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None])
diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py
index 4b880b98760..b359f54b8ef 100644
--- a/torchrl/data/__init__.py
+++ b/torchrl/data/__init__.py
@@ -20,6 +20,7 @@
     CompositeSpec,
     DEVICE_TYPING,
     DiscreteTensorSpec,
+    MultDiscreteTensorSpec,
     MultOneHotDiscreteTensorSpec,
     OneHotDiscreteTensorSpec,
     TensorSpec,
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 849849265f5..e34244abb1d 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -509,6 +509,9 @@ def __eq__(self, other):
             and self.use_register == other.use_register
         )
 
+    def to_categorical(self) -> DiscreteTensorSpec:
+        return DiscreteTensorSpec(self.space.n, device=self.device, dtype=self.dtype)
+
 
 @dataclass(repr=False)
 class BoundedTensorSpec(TensorSpec):
@@ -905,6 +908,11 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
         vals = self._split(val)
         return torch.cat([super()._project(_val) for _val in vals], -1)
 
+    def to_categorical(self) -> MultDiscreteTensorSpec:
+        return MultDiscreteTensorSpec(
+            [_space.n for _space in self.space], self.device, self.dtype
+        )
+
 
 class DiscreteTensorSpec(TensorSpec):
     """A discrete tensor spec.
@@ -981,6 +989,9 @@ def __eq__(self, other):
     def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
         return super().to_numpy(val, safe)
 
+    def to_onehot(self) -> OneHotDiscreteTensorSpec:
+        return OneHotDiscreteTensorSpec(self.space.n, self.device, self.dtype)
+
 
 @dataclass(repr=False)
 class MultDiscreteTensorSpec(DiscreteTensorSpec):
@@ -1057,6 +1068,11 @@ def is_in(self, val: torch.Tensor) -> bool:
             ]
         )
 
+    def to_onehot(self) -> MultOneHotDiscreteTensorSpec:
+        return MultOneHotDiscreteTensorSpec(
+            [_space.n for _space in self.space], self.device, self.dtype
+        )
+
 
 class CompositeSpec(TensorSpec):
     """A composition of TensorSpecs.
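[Between patches again: the conversion round trip that patch 5 enables, condensed from `test_mult_discrete_conversion` above. Both spec classes are importable from `torchrl.data` once this patch is applied; the `Mult*` names become `Multi*` in patch 7.]

from torchrl.data import MultDiscreteTensorSpec, MultOneHotDiscreteTensorSpec

nvec = [5, 2, 3]  # one of the parametrizations used in the tests
categorical = MultDiscreteTensorSpec(nvec)    # samples look like tensor([4, 1, 2])
one_hot = MultOneHotDiscreteTensorSpec(nvec)  # samples are concatenated one-hot blocks

assert categorical != one_hot                   # distinct encodings, hence distinct specs
assert categorical.to_onehot() == one_hot       # ...but each converts into the other
assert one_hot.to_categorical() == categorical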
From 435d1e20b68368ccac97c0cb5cca6ae11b4d8e90 Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Mon, 2 Jan 2023 20:17:36 +0100 Subject: [PATCH 6/7] Add support for Categorical MultiDiscrete Gym space --- torchrl/envs/libs/gym.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 13faf2262d3..3823d7d52a9 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -13,6 +13,7 @@ BoundedTensorSpec, CompositeSpec, DiscreteTensorSpec, + MultDiscreteTensorSpec, MultOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, @@ -68,7 +69,11 @@ def _gym_to_torchrl_spec_transform( elif isinstance(spec, gym.spaces.multi_binary.MultiBinary): return BinaryDiscreteTensorSpec(spec.n, device=device) elif isinstance(spec, gym.spaces.multi_discrete.MultiDiscrete): - return MultOneHotDiscreteTensorSpec(spec.nvec, device=device) + return ( + MultDiscreteTensorSpec(spec.nvec, device=device) + if categorical_action_encoding + else MultOneHotDiscreteTensorSpec(spec.nvec, device=device) + ) elif isinstance(spec, gym.spaces.Box): shape = spec.shape if not len(shape): From 1147efcf407ca8175c43fdac357dbb109b463d51 Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Tue, 3 Jan 2023 00:31:41 +0100 Subject: [PATCH 7/7] Renaming MultDiscreteTensorSpec to MultiDiscreteTensorSpec --- test/test_tensor_spec.py | 26 ++++++++++++++------------ torchrl/data/__init__.py | 2 +- torchrl/data/tensor_specs.py | 8 ++++---- torchrl/envs/libs/gym.py | 4 ++-- 4 files changed, 21 insertions(+), 19 deletions(-) diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py index 3ee32acd6c1..84367a7c4e2 100644 --- a/test/test_tensor_spec.py +++ b/test/test_tensor_spec.py @@ -16,7 +16,7 @@ BoundedTensorSpec, CompositeSpec, DiscreteTensorSpec, - MultDiscreteTensorSpec, + MultiDiscreteTensorSpec, MultOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, @@ -253,10 +253,10 @@ def test_mult_onehot(shape, ns): torch.Size([4, 5]), ], ) -def test_mult_discrete(shape, ns): +def test_multi_discrete(shape, ns): torch.manual_seed(0) np.random.seed(0) - ts = MultDiscreteTensorSpec(ns) + ts = MultiDiscreteTensorSpec(ns) _real_shape = shape if shape is not None else [] _len_ns = [len(ns)] if len(ns) > 1 else [] for _ in range(100): @@ -312,8 +312,8 @@ def test_discrete_conversion(n, device): ], ) @pytest.mark.parametrize("device", get_available_devices()) -def test_mult_discrete_conversion(ns, device): - categorical = MultDiscreteTensorSpec(ns, device=device) +def test_multi_discrete_conversion(ns, device): + categorical = MultiDiscreteTensorSpec(ns, device=device) one_hot = MultOneHotDiscreteTensorSpec(ns, device=device) assert categorical != one_hot @@ -851,27 +851,29 @@ def test_equality_multi_discrete(self, nvec): device = "cpu" dtype = torch.float16 - ts = MultDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype) + ts = MultiDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype) - ts_same = MultDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype) + ts_same = MultiDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype) assert ts == ts_same other_nvec = np.array(nvec) + 3 - ts_other = MultDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype) + ts_other = MultiDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype) assert ts != ts_other other_nvec = [12] - ts_other = MultDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype) + ts_other = MultiDiscreteTensorSpec(nvec=other_nvec, 
device=device, dtype=dtype)
         assert ts != ts_other
 
         other_nvec = [12, 13]
-        ts_other = MultDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype)
+        ts_other = MultiDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype)
         assert ts != ts_other
 
-        ts_other = MultDiscreteTensorSpec(nvec=nvec, device="cpu:0", dtype=dtype)
+        ts_other = MultiDiscreteTensorSpec(nvec=nvec, device="cpu:0", dtype=dtype)
         assert ts != ts_other
 
-        ts_other = MultDiscreteTensorSpec(nvec=nvec, device=device, dtype=torch.float64)
+        ts_other = MultiDiscreteTensorSpec(
+            nvec=nvec, device=device, dtype=torch.float64
+        )
         assert ts != ts_other
 
         ts_other = TestEquality._ts_make_all_fields_equal(
diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py
index b359f54b8ef..4f0fe172d04 100644
--- a/torchrl/data/__init__.py
+++ b/torchrl/data/__init__.py
@@ -20,7 +20,7 @@
     CompositeSpec,
     DEVICE_TYPING,
     DiscreteTensorSpec,
-    MultDiscreteTensorSpec,
+    MultiDiscreteTensorSpec,
     MultOneHotDiscreteTensorSpec,
     OneHotDiscreteTensorSpec,
     TensorSpec,
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index e34244abb1d..77efa407b9e 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -908,8 +908,8 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
         vals = self._split(val)
         return torch.cat([super()._project(_val) for _val in vals], -1)
 
-    def to_categorical(self) -> MultDiscreteTensorSpec:
-        return MultDiscreteTensorSpec(
+    def to_categorical(self) -> MultiDiscreteTensorSpec:
+        return MultiDiscreteTensorSpec(
             [_space.n for _space in self.space], self.device, self.dtype
         )
 
@@ -994,7 +994,7 @@ def to_onehot(self) -> OneHotDiscreteTensorSpec:
 
 
 @dataclass(repr=False)
-class MultDiscreteTensorSpec(DiscreteTensorSpec):
+class MultiDiscreteTensorSpec(DiscreteTensorSpec):
     """A concatenation of discrete tensor specs.
 
     Args:
@@ -1005,7 +1005,7 @@ class MultiDiscreteTensorSpec(DiscreteTensorSpec):
         dtype (str or torch.dtype, optional): dtype of the tensors.
 
     Examples:
-        >>> ts = MultDiscreteTensorSpec((3,2,3))
+        >>> ts = MultiDiscreteTensorSpec((3,2,3))
         >>> ts.is_in(torch.tensor([2, 0, 1]))
         True
         >>> ts.is_in(torch.tensor([2, 2, 1]))
         False
diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py
index 3823d7d52a9..088e7eefbae 100644
--- a/torchrl/envs/libs/gym.py
+++ b/torchrl/envs/libs/gym.py
@@ -13,7 +13,7 @@
     BoundedTensorSpec,
     CompositeSpec,
     DiscreteTensorSpec,
-    MultDiscreteTensorSpec,
+    MultiDiscreteTensorSpec,
     MultOneHotDiscreteTensorSpec,
     OneHotDiscreteTensorSpec,
     TensorSpec,
@@ -70,7 +70,7 @@ def _gym_to_torchrl_spec_transform(
         return BinaryDiscreteTensorSpec(spec.n, device=device)
     elif isinstance(spec, gym.spaces.multi_discrete.MultiDiscrete):
         return (
-            MultDiscreteTensorSpec(spec.nvec, device=device)
+            MultiDiscreteTensorSpec(spec.nvec, device=device)
             if categorical_action_encoding
             else MultOneHotDiscreteTensorSpec(spec.nvec, device=device)
         )
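[A closing sketch of the series' end state, condensed from `test_multi_discrete` above: the renamed `MultiDiscreteTensorSpec` samples flat categorical values, one column per variable, and `_project` (a private helper, exercised directly by the tests) maps arbitrary values back into the spec. Patch 6 additionally routes Gym `MultiDiscrete` spaces to this spec whenever `categorical_action_encoding` is set in `_gym_to_torchrl_spec_transform`.]

import torch
from torchrl.data import MultiDiscreteTensorSpec

# nvec [5, 2, 3] matches one of the parametrizations in test_multi_discrete.
ts = MultiDiscreteTensorSpec([5, 2, 3])

sample = ts.rand(torch.Size([4]))  # batch of 4 draws -> shape [4, 3]
assert sample.shape == torch.Size([4, 3])
assert ts.is_in(sample)

# Arbitrary (here uniform float) values are projected back into the spec.
projection = ts._project(torch.rand(torch.Size([4, 3])))
assert projection.shape == torch.Size([4, 3])
assert ts.is_in(projection)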