diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py
index 84367a7c4e2..722e03b9328 100644
--- a/test/test_tensor_spec.py
+++ b/test/test_tensor_spec.py
@@ -237,11 +237,11 @@ def test_mult_onehot(shape, ns):
 @pytest.mark.parametrize(
     "ns",
     [
-        [
-            5,
-        ],
+        5,
         [5, 2, 3],
         [4, 5, 1, 3],
+        [[1, 2], [3, 4]],
+        [[[2, 4], [3, 5]], [[4, 5], [2, 3]], [[2, 3], [3, 2]]],
     ],
 )
 @pytest.mark.parametrize(
@@ -253,19 +253,20 @@ def test_mult_onehot(shape, ns):
         torch.Size([4, 5]),
     ],
 )
-def test_multi_discrete(shape, ns):
+@pytest.mark.parametrize("dtype", [torch.float, torch.int, torch.long])
+def test_multi_discrete(shape, ns, dtype):
     torch.manual_seed(0)
     np.random.seed(0)
-    ts = MultiDiscreteTensorSpec(ns)
+    ts = MultiDiscreteTensorSpec(ns, dtype=dtype)
     _real_shape = shape if shape is not None else []
-    _len_ns = [len(ns)] if len(ns) > 1 else []
+    nvec_shape = torch.tensor(ns).size()
     for _ in range(100):
         r = ts.rand(shape)
 
         assert r.shape == torch.Size(
             [
                 *_real_shape,
-                *_len_ns,
+                *nvec_shape,
             ]
         )
         assert ts.is_in(r)
@@ -273,13 +274,19 @@ def test_multi_discrete(shape, ns):
         torch.Size(
             [
                 *_real_shape,
-                *_len_ns,
+                *nvec_shape,
             ]
         )
     )
     projection = ts._project(rand)
 
+    assert rand.shape == projection.shape
     assert ts.is_in(projection)
+    if projection.ndim < 1:
+        projection.fill_(-1)
+    else:
+        projection[..., 0] = -1
+    assert not ts.is_in(projection)
 
 
 @pytest.mark.parametrize(
@@ -846,7 +853,7 @@ def test_equality_multi_onehot(self, nvec):
         )
         assert ts != ts_other
 
-    @pytest.mark.parametrize("nvec", [[3], [3, 4], [3, 4, 5]])
+    @pytest.mark.parametrize("nvec", [[3], [3, 4], [3, 4, 5], [[1, 2], [3, 4]]])
     def test_equality_multi_discrete(self, nvec):
         device = "cpu"
         dtype = torch.float16
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 77efa407b9e..d0bad50de46 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -159,6 +159,13 @@ def __iter__(self):
     def __repr__(self):
         return f"{self.__class__.__name__}(boxes={self.boxes})"
 
+    @staticmethod
+    def from_nvec(nvec: torch.Tensor):
+        if nvec.ndim == 0:
+            return DiscreteBox(nvec.item())
+        else:
+            return BoxList([BoxList.from_nvec(n) for n in nvec])
+
 
 @dataclass(repr=False)
 class BinaryBox(Box):
@@ -990,6 +997,11 @@ def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
         return super().to_numpy(val, safe)
 
     def to_onehot(self) -> OneHotDiscreteTensorSpec:
+        if len(self.shape) > 1:
+            raise RuntimeError(
+                f"DiscreteTensorSpec with shape that has several dimensions can't be converted to "
+                f"OneHotDiscreteTensorSpec. Got shape={self.shape}."
+            )
         return OneHotDiscreteTensorSpec(self.space.n, self.device, self.dtype)
 
 
@@ -998,8 +1010,8 @@ class MultiDiscreteTensorSpec(DiscreteTensorSpec):
     """A concatenation of discrete tensor spec.
 
     Args:
-        nvec (iterable of integers): cardinality of each of the elements of
-            the tensor.
+        nvec (iterable of integers or torch.Tensor): cardinality of each of the elements of
+            the tensor. Can have several axes.
         device (str, int or torch.device, optional): device of the tensors.
         dtype (str or torch.dtype, optional): dtype of the tensors.
 
@@ -1014,61 +1026,79 @@ class MultiDiscreteTensorSpec(DiscreteTensorSpec):
 
     def __init__(
         self,
-        nvec: Sequence[int],
+        nvec: Union[Sequence[int], torch.Tensor, int],
         device: Optional[DEVICE_TYPING] = None,
         dtype: Optional[Union[str, torch.dtype]] = torch.long,
     ):
+        if not isinstance(nvec, torch.Tensor):
+            nvec = torch.tensor(nvec)
+        if nvec.ndim < 1:
+            nvec = nvec.unsqueeze(0)
+        self.nvec = nvec
         dtype, device = _default_dtype_and_device(dtype, device)
-        self._size = len(nvec)
-        shape = torch.Size([self._size])
-        space = BoxList([DiscreteBox(n) for n in nvec])
+        shape = nvec.shape
+        space = BoxList.from_nvec(nvec)
         super(DiscreteTensorSpec, self).__init__(
             shape, space, device, dtype, domain="discrete"
         )
 
+    def _rand(self, space: Box, shape: torch.Size, i: int):
+        x = []
+        for _s in space:
+            if isinstance(_s, BoxList):
+                x.append(self._rand(_s, shape, i - 1))
+            else:
+                x.append(
+                    torch.randint(
+                        0,
+                        _s.n,
+                        shape,
+                        device=self.device,
+                        dtype=self.dtype,
+                    )
+                )
+        return torch.stack(x, -i)
+
     def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor:
         if shape is None:
             shape = torch.Size([])
-        x = torch.cat(
-            [
-                torch.randint(
-                    0,
-                    space.n,
-                    torch.Size([1, *shape]),
-                    device=self.device,
-                    dtype=self.dtype,
-                )
-                for space in self.space
-            ]
-        ).squeeze()
-        _size = [self._size] if self._size > 1 else []
-        return x.T.reshape([*shape, *_size])
-
-    def _split(self, val: torch.Tensor):
-        return [val] if self._size < 2 else val.split(1, -1)
+        x = self._rand(self.space, shape, self.nvec.ndim)
+        if self.shape == torch.Size([1]):
+            x = x.squeeze(-1)
+        return x
 
     def _project(self, val: torch.Tensor) -> torch.Tensor:
-        if val.dtype not in (torch.int, torch.long):
+        val_is_scalar = val.ndim < 1
+        if val_is_scalar:
+            val = val.unsqueeze(0)
+        if not self.dtype.is_floating_point:
             val = torch.round(val)
-        vals = self._split(val)
-        return torch.cat(
-            [
-                _val.clamp_(min=0, max=space.n - 1).unsqueeze(0)
-                for _val, space in zip(vals, self.space)
-            ],
-            dim=-1,
-        ).squeeze()
+        val = val.type(self.dtype)
+        val[val >= self.nvec] = (self.nvec.expand_as(val)[val >= self.nvec] - 1).type(
+            self.dtype
+        )
+        return val.squeeze(0) if val_is_scalar else val
 
     def is_in(self, val: torch.Tensor) -> bool:
-        vals = self._split(val)
-        return self._size == len(vals) and all(
-            [
-                (0 <= _val).all() and (_val < space.n).all()
-                for _val, space in zip(vals, self.space)
-            ]
+        if val.ndim < 1:
+            val = val.unsqueeze(0)
+        val_have_wrong_dim = (
+            self.shape != torch.Size([1])
+            and val.shape[-len(self.shape) :] != self.shape
         )
+        if self.dtype != val.dtype or len(self.shape) > val.ndim or val_have_wrong_dim:
+            return False
+
+        return ((val >= torch.zeros(self.nvec.size())) & (val < self.nvec)).all().item()
 
     def to_onehot(self) -> MultOneHotDiscreteTensorSpec:
+        if len(self.shape) > 1:
+            raise RuntimeError(
+                f"DiscreteTensorSpec with shape that has several dimensions can't be converted to "
+                f"OneHotDiscreteTensorSpec. Got shape={self.shape}. This could be accomplished via padding or "
+                f"nested tensors but it is not implemented yet. If you would like to see that feature, please submit "
+                f"an issue on torchrl's GitHub repo."
+            )
         return MultOneHotDiscreteTensorSpec(
             [_space.n for _space in self.space], self.device, self.dtype
        )
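
Usage sketch (not part of the patch): with this change, nvec may carry several axes and the
spec's shape follows nvec's shape. The snippet below only illustrates the rand/is_in/_project
behavior exercised by the updated tests; it assumes MultiDiscreteTensorSpec is importable from
torchrl.data, and is a hypothetical example rather than code taken from the diff.

    import torch

    from torchrl.data import MultiDiscreteTensorSpec  # assumed import path

    # A nested nvec: the spec's shape follows nvec's shape, here torch.Size([2, 2]).
    spec = MultiDiscreteTensorSpec([[1, 2], [3, 4]], dtype=torch.long)
    assert spec.shape == torch.Size([2, 2])

    # Samples gain the nvec shape as trailing dimensions.
    sample = spec.rand(torch.Size([5]))
    assert sample.shape == torch.Size([5, 2, 2])
    assert spec.is_in(sample)

    # Continuous values are rounded, cast to the spec dtype and clamped element-wise
    # against nvec, mirroring the projection test above.
    projected = spec._project(torch.rand(torch.Size([5, 2, 2])))
    assert spec.is_in(projected)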