From 781c1ba151f17fe769dd3a067a59c74d4f9d5559 Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Wed, 4 Jan 2023 15:40:04 +0100 Subject: [PATCH 01/11] Support multidimensional nvec for MultiDiscreteTensorSpec --- test/test_tensor_spec.py | 11 ++-- torchrl/data/tensor_specs.py | 103 ++++++++++++++++++++++------------- 2 files changed, 73 insertions(+), 41 deletions(-) diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py index 84367a7c4e2..e8de9bf71e5 100644 --- a/test/test_tensor_spec.py +++ b/test/test_tensor_spec.py @@ -242,6 +242,7 @@ def test_mult_onehot(shape, ns): ], [5, 2, 3], [4, 5, 1, 3], + [[1, 2], [3, 4]], ], ) @pytest.mark.parametrize( @@ -258,14 +259,16 @@ def test_multi_discrete(shape, ns): np.random.seed(0) ts = MultiDiscreteTensorSpec(ns) _real_shape = shape if shape is not None else [] - _len_ns = [len(ns)] if len(ns) > 1 else [] + nvec_shape = torch.tensor(ns).size() + if nvec_shape == torch.Size([1]): + nvec_shape = [] for _ in range(100): r = ts.rand(shape) assert r.shape == torch.Size( [ *_real_shape, - *_len_ns, + *nvec_shape, ] ) assert ts.is_in(r) @@ -273,7 +276,7 @@ def test_multi_discrete(shape, ns): torch.Size( [ *_real_shape, - *_len_ns, + *nvec_shape, ] ) ) @@ -846,7 +849,7 @@ def test_equality_multi_onehot(self, nvec): ) assert ts != ts_other - @pytest.mark.parametrize("nvec", [[3], [3, 4], [3, 4, 5]]) + @pytest.mark.parametrize("nvec", [[3], [3, 4], [3, 4, 5], [[1, 2], [3, 4]]]) def test_equality_multi_discrete(self, nvec): device = "cpu" dtype = torch.float16 diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 77efa407b9e..0b8c954fbde 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc +import itertools from copy import deepcopy from dataclasses import dataclass from textwrap import indent @@ -159,6 +160,13 @@ def __iter__(self): def __repr__(self): return f"{self.__class__.__name__}(boxes={self.boxes})" + @staticmethod + def from_nvec(nvec: torch.Tensor): + if nvec.ndim == 0: + return DiscreteBox(nvec.item()) + else: + return BoxList([BoxList.from_nvec(n) for n in nvec]) + @dataclass(repr=False) class BinaryBox(Box): @@ -990,6 +998,11 @@ def to_numpy(self, val: TensorDict, safe: bool = True) -> dict: return super().to_numpy(val, safe) def to_onehot(self) -> OneHotDiscreteTensorSpec: + if len(self.shape) > 1: + raise RuntimeError( + f"DiscreteTensorSpec with shape != tensor.Size([1]) can't be converted OneHotDiscreteTensorSpec. Got " + f"shape={self.shape}." + ) return OneHotDiscreteTensorSpec(self.space.n, self.device, self.dtype) @@ -998,8 +1011,8 @@ class MultiDiscreteTensorSpec(DiscreteTensorSpec): """A concatenation of discrete tensor spec. Args: - nvec (iterable of integers): cardinality of each of the elements of - the tensor. + nvec (iterable of integers or torch.Tensor): cardinality of each of the elements of + the tensor. Can have several axes. device (str, int or torch.device, optional): device of the tensors. dtype (str or torch.dtype, optional): dtype of the tensors. 
@@ -1014,61 +1027,77 @@ class MultiDiscreteTensorSpec(DiscreteTensorSpec): def __init__( self, - nvec: Sequence[int], + nvec: Union[Sequence[int], torch.Tensor], device: Optional[DEVICE_TYPING] = None, dtype: Optional[Union[str, torch.dtype]] = torch.long, ): + if not isinstance(nvec, torch.Tensor): + nvec = torch.tensor(nvec) + self.nvec = nvec dtype, device = _default_dtype_and_device(dtype, device) - self._size = len(nvec) - shape = torch.Size([self._size]) - space = BoxList([DiscreteBox(n) for n in nvec]) + shape = nvec.size() + space = BoxList.from_nvec(nvec) super(DiscreteTensorSpec, self).__init__( shape, space, device, dtype, domain="discrete" ) + def _rand(self, space: Box, shape: torch.Size): + x = [] + for _s in space: + if isinstance(_s, BoxList): + x.append(self._rand(_s, shape)) + else: + x.append( + torch.randint( + 0, + _s.n, + torch.Size([1, *shape]), + device=self.device, + dtype=self.dtype, + ) + ) + return torch.cat(x) + def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: if shape is None: shape = torch.Size([]) - x = torch.cat( - [ - torch.randint( - 0, - space.n, - torch.Size([1, *shape]), - device=self.device, - dtype=self.dtype, - ) - for space in self.space - ] - ).squeeze() - _size = [self._size] if self._size > 1 else [] - return x.T.reshape([*shape, *_size]) - def _split(self, val: torch.Tensor): - return [val] if self._size < 2 else val.split(1, -1) + x = self._rand(self.space, shape) + _size = [] if self.shape == torch.Size([1]) else self.shape + return x.permute(*torch.arange(x.ndim - 1, -1, -1)).reshape([*shape, *_size]) def _project(self, val: torch.Tensor) -> torch.Tensor: if val.dtype not in (torch.int, torch.long): - val = torch.round(val) - vals = self._split(val) - return torch.cat( - [ - _val.clamp_(min=0, max=space.n - 1).unsqueeze(0) - for _val, space in zip(vals, self.space) - ], - dim=-1, - ).squeeze() + val = torch.round(val).type(self.dtype) + val = val.unsqueeze(0) + for permutation in itertools.product(*[range(d) for d in self.shape]): + val.unsqueeze(0)[[..., *permutation]] = val.unsqueeze(0)[ + [..., *permutation] + ].clamp_(min=0, max=self.nvec[permutation] - 1) + return val.squeeze(0) def is_in(self, val: torch.Tensor) -> bool: - vals = self._split(val) - return self._size == len(vals) and all( - [ - (0 <= _val).all() and (_val < space.n).all() - for _val, space in zip(vals, self.space) - ] + if val.ndim < 1: + val = val.unsqueeze(0) + val_have_wrong_dim = ( + self.shape != torch.Size([1]) + and val.shape[-len(self.shape) :] != self.shape ) + if self.dtype != val.dtype or len(self.shape) > val.ndim or val_have_wrong_dim: + return False + + for permutation in itertools.product(*[range(d) for d in self.shape]): + x = val.unsqueeze(0)[[..., *permutation]] + if not ((0 <= x).all() and (x < self.nvec[permutation]).all()): + return False + return True def to_onehot(self) -> MultOneHotDiscreteTensorSpec: + if len(self.shape) > 1: + raise RuntimeError( + f"DiscreteTensorSpec with shape != tensor.Size([1]) can't be converted OneHotDiscreteTensorSpec. Got " + f"shape={self.shape}." 
+ ) return MultOneHotDiscreteTensorSpec( [_space.n for _space in self.space], self.device, self.dtype ) From a3ff995002608343781ecfc12bd61f87144f4cdf Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Thu, 5 Jan 2023 16:16:56 +0100 Subject: [PATCH 02/11] Support int for nvec argument --- test/test_tensor_spec.py | 4 +--- torchrl/data/tensor_specs.py | 26 ++++++++++++++++---------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py index e8de9bf71e5..5a1a33473f9 100644 --- a/test/test_tensor_spec.py +++ b/test/test_tensor_spec.py @@ -237,9 +237,7 @@ def test_mult_onehot(shape, ns): @pytest.mark.parametrize( "ns", [ - [ - 5, - ], + 5, [5, 2, 3], [4, 5, 1, 3], [[1, 2], [3, 4]], diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 0b8c954fbde..63949305d14 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -504,7 +504,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return out def is_in(self, val: torch.Tensor) -> bool: - return (val.sum(-1) == 1).all() + return self.dtype == val.dtype and (val.sum(-1) == 1).all() def __eq__(self, other): return ( @@ -656,9 +656,11 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return val def is_in(self, val: torch.Tensor) -> bool: - return (val >= self.space.minimum.to(val.device)).all() and ( - val <= self.space.maximum.to(val.device) - ).all() + return ( + self.dtype == val.dtype + and (val >= self.space.minimum.to(val.device)).all() + and (val <= self.space.maximum.to(val.device)).all() + ) @dataclass(repr=False) @@ -701,7 +703,7 @@ def rand(self, shape=None) -> torch.Tensor: return torch.randn(shape, device=self.device, dtype=self.dtype) def is_in(self, val: torch.Tensor) -> bool: - return True + return self.dtype == val.dtype @dataclass(repr=False) @@ -754,7 +756,7 @@ def rand(self, shape=None) -> torch.Tensor: return r.to(self.device) def is_in(self, val: torch.Tensor) -> bool: - return True + return self.dtype == val.dtype @dataclass(repr=False) @@ -802,7 +804,7 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten return tensor_to_index.gather(-1, index) def is_in(self, val: torch.Tensor) -> bool: - return ((val == 0) | (val == 1)).all() + return self.dtype == val.dtype and ((val == 0) | (val == 1)).all() @dataclass(repr=False) @@ -908,7 +910,7 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten def is_in(self, val: torch.Tensor) -> bool: vals = self._split(val) - return all( + return self.dtype == val.dtype and all( [super(MultOneHotDiscreteTensorSpec, self).is_in(_val) for _val in vals] ) @@ -982,7 +984,9 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return val.clamp_(min=0, max=self.space.n - 1) def is_in(self, val: torch.Tensor) -> bool: - return (0 <= val).all() and (val < self.space.n).all() + return ( + self.dtype == val.dtype and (0 <= val).all() and (val < self.space.n).all() + ) def __eq__(self, other): return ( @@ -1027,12 +1031,14 @@ class MultiDiscreteTensorSpec(DiscreteTensorSpec): def __init__( self, - nvec: Union[Sequence[int], torch.Tensor], + nvec: Union[Sequence[int], torch.Tensor, int], device: Optional[DEVICE_TYPING] = None, dtype: Optional[Union[str, torch.dtype]] = torch.long, ): if not isinstance(nvec, torch.Tensor): nvec = torch.tensor(nvec) + if nvec.ndim < 1: + nvec = nvec.unsqueeze(0) self.nvec = nvec dtype, device = _default_dtype_and_device(dtype, device) shape = nvec.size() From 
263b1b2e1fd1a803b95695fde34798d4153547fb Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Thu, 5 Jan 2023 16:55:55 +0100 Subject: [PATCH 03/11] Simplify _rand by avoid permutation and reshape --- test/test_tensor_spec.py | 1 + torchrl/data/tensor_specs.py | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py index 5a1a33473f9..be3da85ed84 100644 --- a/test/test_tensor_spec.py +++ b/test/test_tensor_spec.py @@ -241,6 +241,7 @@ def test_mult_onehot(shape, ns): [5, 2, 3], [4, 5, 1, 3], [[1, 2], [3, 4]], + [[[2, 4], [3, 5]], [[4, 5], [2, 3]], [[2, 3], [3, 2]]], ], ) @pytest.mark.parametrize( diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 63949305d14..ba22a06a6eb 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1057,20 +1057,21 @@ def _rand(self, space: Box, shape: torch.Size): torch.randint( 0, _s.n, - torch.Size([1, *shape]), + shape, device=self.device, dtype=self.dtype, ) ) - return torch.cat(x) + return torch.stack(x, -1) def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: if shape is None: shape = torch.Size([]) x = self._rand(self.space, shape) - _size = [] if self.shape == torch.Size([1]) else self.shape - return x.permute(*torch.arange(x.ndim - 1, -1, -1)).reshape([*shape, *_size]) + if self.nvec.ndim > 1: + x = x.transpose(len(shape), -1) + return x.squeeze(-1) def _project(self, val: torch.Tensor) -> torch.Tensor: if val.dtype not in (torch.int, torch.long): From fb004c2d82ff58bbba70a2d955b2b512983cff17 Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Thu, 5 Jan 2023 17:33:23 +0100 Subject: [PATCH 04/11] Fix to_one_hot method condition --- torchrl/data/tensor_specs.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index ba22a06a6eb..7d781cf3b0e 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1004,7 +1004,7 @@ def to_numpy(self, val: TensorDict, safe: bool = True) -> dict: def to_onehot(self) -> OneHotDiscreteTensorSpec: if len(self.shape) > 1: raise RuntimeError( - f"DiscreteTensorSpec with shape != tensor.Size([1]) can't be converted OneHotDiscreteTensorSpec. Got " + f"DiscreteTensorSpec with shape != torch.Size([1]) can't be converted OneHotDiscreteTensorSpec. Got " f"shape={self.shape}." ) return OneHotDiscreteTensorSpec(self.space.n, self.device, self.dtype) @@ -1100,10 +1100,12 @@ def is_in(self, val: torch.Tensor) -> bool: return True def to_onehot(self) -> MultOneHotDiscreteTensorSpec: - if len(self.shape) > 1: + if self.shape != torch.Size([1]): raise RuntimeError( - f"DiscreteTensorSpec with shape != tensor.Size([1]) can't be converted OneHotDiscreteTensorSpec. Got " - f"shape={self.shape}." + f"DiscreteTensorSpec with shape != torch.Size([1]) can't be converted OneHotDiscreteTensorSpec. Got " + f"shape={self.shape}. This could be accomplished via padding or nestedtensors but it is not " + f"implemented yet. If you would like to see that feature, please submit an issue of torchrl's github " + f"repo. 
" ) return MultOneHotDiscreteTensorSpec( [_space.n for _space in self.space], self.device, self.dtype From 3c5d4d02292188d384bfd6775d1d5dc2853efa28 Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Thu, 5 Jan 2023 20:00:15 +0100 Subject: [PATCH 05/11] Improve _project and is_in methods --- test/test_tensor_spec.py | 6 ++++++ torchrl/data/tensor_specs.py | 25 ++++++++++--------------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py index be3da85ed84..0dfb89c1161 100644 --- a/test/test_tensor_spec.py +++ b/test/test_tensor_spec.py @@ -280,8 +280,14 @@ def test_multi_discrete(shape, ns): ) ) projection = ts._project(rand) + assert rand.shape == projection.shape assert ts.is_in(projection) + if projection.ndim < 1: + projection.fill_(-1) + else: + projection[..., 0] = -1 + assert not ts.is_in(projection) @pytest.mark.parametrize( diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 7d781cf3b0e..37ba7eee1c6 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -6,7 +6,6 @@ from __future__ import annotations import abc -import itertools from copy import deepcopy from dataclasses import dataclass from textwrap import indent @@ -1041,7 +1040,7 @@ def __init__( nvec = nvec.unsqueeze(0) self.nvec = nvec dtype, device = _default_dtype_and_device(dtype, device) - shape = nvec.size() + shape = nvec.shape space = BoxList.from_nvec(nvec) super(DiscreteTensorSpec, self).__init__( shape, space, device, dtype, domain="discrete" @@ -1074,14 +1073,14 @@ def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: return x.squeeze(-1) def _project(self, val: torch.Tensor) -> torch.Tensor: - if val.dtype not in (torch.int, torch.long): - val = torch.round(val).type(self.dtype) - val = val.unsqueeze(0) - for permutation in itertools.product(*[range(d) for d in self.shape]): - val.unsqueeze(0)[[..., *permutation]] = val.unsqueeze(0)[ - [..., *permutation] - ].clamp_(min=0, max=self.nvec[permutation] - 1) - return val.squeeze(0) + val_is_scalar = val.ndim < 1 + if val_is_scalar: + val = val.unsqueeze(0) + if not self.dtype.is_floating_point: + val = torch.round(val) + val = val.type(self.dtype) + val[val >= self.nvec] = self.nvec.expand_as(val)[val >= self.nvec] - 1 + return val.squeeze(0) if val_is_scalar else val def is_in(self, val: torch.Tensor) -> bool: if val.ndim < 1: @@ -1093,11 +1092,7 @@ def is_in(self, val: torch.Tensor) -> bool: if self.dtype != val.dtype or len(self.shape) > val.ndim or val_have_wrong_dim: return False - for permutation in itertools.product(*[range(d) for d in self.shape]): - x = val.unsqueeze(0)[[..., *permutation]] - if not ((0 <= x).all() and (x < self.nvec[permutation]).all()): - return False - return True + return ((val >= torch.zeros(self.nvec.size())) & (val < self.nvec)).all().item() def to_onehot(self) -> MultOneHotDiscreteTensorSpec: if self.shape != torch.Size([1]): From a4effbb20456ecc9e6815b68bc9102ad1b4d1d0f Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Thu, 5 Jan 2023 20:02:57 +0100 Subject: [PATCH 06/11] Update tests --- test/test_tensor_spec.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py index 0dfb89c1161..a87024ad09f 100644 --- a/test/test_tensor_spec.py +++ b/test/test_tensor_spec.py @@ -259,8 +259,6 @@ def test_multi_discrete(shape, ns): ts = MultiDiscreteTensorSpec(ns) _real_shape = shape if shape is not None else [] nvec_shape = torch.tensor(ns).size() - if nvec_shape == 
torch.Size([1]): - nvec_shape = [] for _ in range(100): r = ts.rand(shape) From 7c9aca12124791fbcac955e0081f42c0cea4966c Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Thu, 5 Jan 2023 20:25:19 +0100 Subject: [PATCH 07/11] Support multi dtype --- test/test_tensor_spec.py | 5 +++-- torchrl/data/tensor_specs.py | 26 ++++++++++++-------------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py index a87024ad09f..722e03b9328 100644 --- a/test/test_tensor_spec.py +++ b/test/test_tensor_spec.py @@ -253,10 +253,11 @@ def test_mult_onehot(shape, ns): torch.Size([4, 5]), ], ) -def test_multi_discrete(shape, ns): +@pytest.mark.parametrize("dtype", [torch.float, torch.int, torch.long]) +def test_multi_discrete(shape, ns, dtype): torch.manual_seed(0) np.random.seed(0) - ts = MultiDiscreteTensorSpec(ns) + ts = MultiDiscreteTensorSpec(ns, dtype=dtype) _real_shape = shape if shape is not None else [] nvec_shape = torch.tensor(ns).size() for _ in range(100): diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 37ba7eee1c6..8e953a08cdd 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -503,7 +503,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return out def is_in(self, val: torch.Tensor) -> bool: - return self.dtype == val.dtype and (val.sum(-1) == 1).all() + return (val.sum(-1) == 1).all() def __eq__(self, other): return ( @@ -655,11 +655,9 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return val def is_in(self, val: torch.Tensor) -> bool: - return ( - self.dtype == val.dtype - and (val >= self.space.minimum.to(val.device)).all() - and (val <= self.space.maximum.to(val.device)).all() - ) + return (val >= self.space.minimum.to(val.device)).all() and ( + val <= self.space.maximum.to(val.device) + ).all() @dataclass(repr=False) @@ -702,7 +700,7 @@ def rand(self, shape=None) -> torch.Tensor: return torch.randn(shape, device=self.device, dtype=self.dtype) def is_in(self, val: torch.Tensor) -> bool: - return self.dtype == val.dtype + return True @dataclass(repr=False) @@ -755,7 +753,7 @@ def rand(self, shape=None) -> torch.Tensor: return r.to(self.device) def is_in(self, val: torch.Tensor) -> bool: - return self.dtype == val.dtype + return True @dataclass(repr=False) @@ -803,7 +801,7 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten return tensor_to_index.gather(-1, index) def is_in(self, val: torch.Tensor) -> bool: - return self.dtype == val.dtype and ((val == 0) | (val == 1)).all() + return ((val == 0) | (val == 1)).all() @dataclass(repr=False) @@ -909,7 +907,7 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten def is_in(self, val: torch.Tensor) -> bool: vals = self._split(val) - return self.dtype == val.dtype and all( + return all( [super(MultOneHotDiscreteTensorSpec, self).is_in(_val) for _val in vals] ) @@ -983,9 +981,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return val.clamp_(min=0, max=self.space.n - 1) def is_in(self, val: torch.Tensor) -> bool: - return ( - self.dtype == val.dtype and (0 <= val).all() and (val < self.space.n).all() - ) + return (0 <= val).all() and (val < self.space.n).all() def __eq__(self, other): return ( @@ -1079,7 +1075,9 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: if not self.dtype.is_floating_point: val = torch.round(val) val = val.type(self.dtype) - val[val >= self.nvec] = self.nvec.expand_as(val)[val >= self.nvec] - 1 + val[val >= 
self.nvec] = (self.nvec.expand_as(val)[val >= self.nvec] - 1).type( + self.dtype + ) return val.squeeze(0) if val_is_scalar else val def is_in(self, val: torch.Tensor) -> bool: From 7082960a65b4d48084ed859257f6a8baea9241a5 Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Thu, 5 Jan 2023 22:23:28 +0100 Subject: [PATCH 08/11] Fix tests --- torchrl/data/tensor_specs.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 8e953a08cdd..60259900074 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1093,12 +1093,12 @@ def is_in(self, val: torch.Tensor) -> bool: return ((val >= torch.zeros(self.nvec.size())) & (val < self.nvec)).all().item() def to_onehot(self) -> MultOneHotDiscreteTensorSpec: - if self.shape != torch.Size([1]): + if len(self.shape) > 1: raise RuntimeError( - f"DiscreteTensorSpec with shape != torch.Size([1]) can't be converted OneHotDiscreteTensorSpec. Got " - f"shape={self.shape}. This could be accomplished via padding or nestedtensors but it is not " - f"implemented yet. If you would like to see that feature, please submit an issue of torchrl's github " - f"repo. " + f"DiscreteTensorSpec with shape that has several dimensions can't be converted " + f"OneHotDiscreteTensorSpec. Got shape={self.shape}. This could be accomplished via padding or " + f"nestedtensors but it is not implemented yet. If you would like to see that feature, please submit " + f"an issue of torchrl's github repo. " ) return MultOneHotDiscreteTensorSpec( [_space.n for _space in self.space], self.device, self.dtype From 62e324dc99299a7bf700283042373c7a37cb0d6e Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Thu, 5 Jan 2023 23:46:37 +0100 Subject: [PATCH 09/11] Fix docs --- torchrl/data/tensor_specs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 60259900074..9fd891e3ae2 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -999,8 +999,8 @@ def to_numpy(self, val: TensorDict, safe: bool = True) -> dict: def to_onehot(self) -> OneHotDiscreteTensorSpec: if len(self.shape) > 1: raise RuntimeError( - f"DiscreteTensorSpec with shape != torch.Size([1]) can't be converted OneHotDiscreteTensorSpec. Got " - f"shape={self.shape}." + f"DiscreteTensorSpec with shape that has several dimensions can't be converted to " + f"OneHotDiscreteTensorSpec. Got shape={self.shape}." ) return OneHotDiscreteTensorSpec(self.space.n, self.device, self.dtype) @@ -1095,7 +1095,7 @@ def is_in(self, val: torch.Tensor) -> bool: def to_onehot(self) -> MultOneHotDiscreteTensorSpec: if len(self.shape) > 1: raise RuntimeError( - f"DiscreteTensorSpec with shape that has several dimensions can't be converted " + f"DiscreteTensorSpec with shape that has several dimensions can't be converted to" f"OneHotDiscreteTensorSpec. Got shape={self.shape}. This could be accomplished via padding or " f"nestedtensors but it is not implemented yet. If you would like to see that feature, please submit " f"an issue of torchrl's github repo. 
" From f9b5ac138785ab8befb9237e0d00e084e5e91e4e Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Fri, 6 Jan 2023 20:41:44 +0100 Subject: [PATCH 10/11] Fix the _rand method --- torchrl/data/tensor_specs.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 9fd891e3ae2..d0bad50de46 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1042,11 +1042,11 @@ def __init__( shape, space, device, dtype, domain="discrete" ) - def _rand(self, space: Box, shape: torch.Size): + def _rand(self, space: Box, shape: torch.Size, i: int): x = [] for _s in space: if isinstance(_s, BoxList): - x.append(self._rand(_s, shape)) + x.append(self._rand(_s, shape, i - 1)) else: x.append( torch.randint( @@ -1057,16 +1057,15 @@ def _rand(self, space: Box, shape: torch.Size): dtype=self.dtype, ) ) - return torch.stack(x, -1) + return torch.stack(x, -i) def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: if shape is None: shape = torch.Size([]) - - x = self._rand(self.space, shape) - if self.nvec.ndim > 1: - x = x.transpose(len(shape), -1) - return x.squeeze(-1) + x = self._rand(self.space, shape, self.nvec.ndim) + if self.shape == torch.Size([1]): + x = x.squeeze(-1) + return x def _project(self, val: torch.Tensor) -> torch.Tensor: val_is_scalar = val.ndim < 1 From f0fd64a8c9aa096fc3076c58d590dfd553b31b66 Mon Sep 17 00:00:00 2001 From: Waris Radji Date: Fri, 6 Jan 2023 22:53:48 +0100 Subject: [PATCH 11/11] empty