123 changes: 123 additions & 0 deletions test/test_tensor_spec.py
@@ -16,6 +16,7 @@
BoundedTensorSpec,
CompositeSpec,
DiscreteTensorSpec,
MultiDiscreteTensorSpec,
MultOneHotDiscreteTensorSpec,
OneHotDiscreteTensorSpec,
UnboundedContinuousTensorSpec,
@@ -233,6 +234,93 @@ def test_mult_onehot(shape, ns):
assert (ts.encode(np_r) == r).all()


@pytest.mark.parametrize(
"ns",
[
[
5,
],
[5, 2, 3],
[4, 5, 1, 3],
],
)
@pytest.mark.parametrize(
"shape",
[
None,
[],
torch.Size([3]),
torch.Size([4, 5]),
],
)
def test_multi_discrete(shape, ns):
torch.manual_seed(0)
np.random.seed(0)
ts = MultiDiscreteTensorSpec(ns)
_real_shape = shape if shape is not None else []
_len_ns = [len(ns)] if len(ns) > 1 else []
for _ in range(100):
r = ts.rand(shape)

assert r.shape == torch.Size(
[
*_real_shape,
*_len_ns,
]
)
assert ts.is_in(r)
rand = torch.rand(
torch.Size(
[
*_real_shape,
*_len_ns,
]
)
)
projection = ts._project(rand)
assert rand.shape == projection.shape
assert ts.is_in(projection)


@pytest.mark.parametrize(
"n",
[
1,
4,
7,
99,
],
)
@pytest.mark.parametrize("device", get_available_devices())
def test_discrete_conversion(n, device):
categorical = DiscreteTensorSpec(n, device=device)
one_hot = OneHotDiscreteTensorSpec(n, device=device)

assert categorical != one_hot
assert categorical.to_onehot() == one_hot
assert one_hot.to_categorical() == categorical


@pytest.mark.parametrize(
"ns",
[
[
5,
],
[5, 2, 3],
[4, 5, 1, 3],
],
)
@pytest.mark.parametrize("device", get_available_devices())
def test_multi_discrete_conversion(ns, device):
categorical = MultiDiscreteTensorSpec(ns, device=device)
one_hot = MultOneHotDiscreteTensorSpec(ns, device=device)

assert categorical != one_hot
assert categorical.to_onehot() == one_hot
assert one_hot.to_categorical() == categorical


@pytest.mark.parametrize("is_complete", [True, False])
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None])
@@ -758,6 +846,41 @@ def test_equality_multi_onehot(self, nvec):
)
assert ts != ts_other

@pytest.mark.parametrize("nvec", [[3], [3, 4], [3, 4, 5]])
def test_equality_multi_discrete(self, nvec):
device = "cpu"
dtype = torch.float16

ts = MultiDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype)

ts_same = MultiDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype)
assert ts == ts_same

other_nvec = np.array(nvec) + 3
ts_other = MultiDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype)
assert ts != ts_other

other_nvec = [12]
ts_other = MultiDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype)
assert ts != ts_other

other_nvec = [12, 13]
ts_other = MultiDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype)
assert ts != ts_other

ts_other = MultiDiscreteTensorSpec(nvec=nvec, device="cpu:0", dtype=dtype)
assert ts != ts_other

ts_other = MultiDiscreteTensorSpec(
nvec=nvec, device=device, dtype=torch.float64
)
assert ts != ts_other

ts_other = TestEquality._ts_make_all_fields_equal(
BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts
)
assert ts != ts_other

def test_equality_composite(self):
minimum = np.arange(12).reshape((3, 4))
maximum = minimum + 100
1 change: 1 addition & 0 deletions torchrl/data/__init__.py
@@ -20,6 +20,7 @@
CompositeSpec,
DEVICE_TYPING,
DiscreteTensorSpec,
MultiDiscreteTensorSpec,
MultOneHotDiscreteTensorSpec,
OneHotDiscreteTensorSpec,
TensorSpec,
92 changes: 92 additions & 0 deletions torchrl/data/tensor_specs.py
@@ -509,6 +509,9 @@ def __eq__(self, other):
and self.use_register == other.use_register
)

def to_categorical(self) -> DiscreteTensorSpec:
return DiscreteTensorSpec(self.space.n, device=self.device, dtype=self.dtype)


@dataclass(repr=False)
class BoundedTensorSpec(TensorSpec):
@@ -905,6 +908,11 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
vals = self._split(val)
return torch.cat([super()._project(_val) for _val in vals], -1)

def to_categorical(self) -> MultiDiscreteTensorSpec:
return MultiDiscreteTensorSpec(
[_space.n for _space in self.space], self.device, self.dtype
)


class DiscreteTensorSpec(TensorSpec):
"""A discrete tensor spec.
@@ -981,6 +989,90 @@ def __eq__(self, other):
def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
return super().to_numpy(val, safe)

def to_onehot(self) -> OneHotDiscreteTensorSpec:
return OneHotDiscreteTensorSpec(self.space.n, self.device, self.dtype)


@dataclass(repr=False)
class MultiDiscreteTensorSpec(DiscreteTensorSpec):
"""A concatenation of discrete tensor spec.

Args:
nvec (iterable of integers): cardinality of each of the elements of
the tensor.
device (str, int or torch.device, optional): device of
the tensors.
dtype (str or torch.dtype, optional): dtype of the tensors.

Examples:
>>> ts = MultiDiscreteTensorSpec((3,2,3))
>>> ts.is_in(torch.tensor([2, 0, 1]))
True
>>> ts.is_in(torch.tensor([2, 2, 1]))
False
"""

def __init__(
self,
nvec: Sequence[int],
device: Optional[DEVICE_TYPING] = None,
dtype: Optional[Union[str, torch.dtype]] = torch.long,
):
dtype, device = _default_dtype_and_device(dtype, device)
self._size = len(nvec)
shape = torch.Size([self._size])
space = BoxList([DiscreteBox(n) for n in nvec])
super(DiscreteTensorSpec, self).__init__(
shape, space, device, dtype, domain="discrete"
)

def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor:
if shape is None:
shape = torch.Size([])
x = torch.cat(
[
torch.randint(
0,
space.n,
torch.Size([1, *shape]),
device=self.device,
dtype=self.dtype,
)
for space in self.space
]
).squeeze()
_size = [self._size] if self._size > 1 else []
return x.T.reshape([*shape, *_size])
Collaborator comment:

This leads to the following warning if the number of dims of `x` is greater than 2:

<string>:3: UserWarning: The use of `x.T` on tensors of dimension other than 2 to reverse their shape is deprecated and it will throw an error in a future release. Consider `x.mT` to transpose batches of matrices or `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor. (Triggered internally at /Users/distiller/project/conda/conda-bld/pytorch_1646756029501/work/aten/src/ATen/native/TensorShape.cpp:2318.)
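
Below is a minimal sketch of one way to avoid the deprecated `x.T` call, assuming the intent of `rand()` is simply to draw one sample per sub-space and keep that axis last; the `nvec` and `shape` values are illustrative, not part of the PR:

```python
import torch

# Sketch: stack the per-sub-space samples on dim=-1 instead of concatenating
# on dim 0 and transposing afterwards, so no call to `x.T` is needed.
nvec = [5, 2, 3]
shape = torch.Size([4, 5])
x = torch.stack(
    [torch.randint(0, n, shape) for n in nvec],
    dim=-1,
)
print(x.shape)  # torch.Size([4, 5, 3]) -- sub-space axis last, no transpose
```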


def _split(self, val: torch.Tensor):
return [val] if self._size < 2 else val.split(1, -1)

def _project(self, val: torch.Tensor) -> torch.Tensor:
if val.dtype not in (torch.int, torch.long):
val = torch.round(val)
vals = self._split(val)
return torch.cat(
[
_val.clamp_(min=0, max=space.n - 1).unsqueeze(0)
for _val, space in zip(vals, self.space)
],
dim=-1,
).squeeze()

def is_in(self, val: torch.Tensor) -> bool:
vals = self._split(val)
Collaborator comment:
We should also check the dtype here
(note to self: we should check that `is_in` always checks the dtype).
(A hedged sketch of such a check follows the method below.)

return self._size == len(vals) and all(
[
(0 <= _val).all() and (_val < space.n).all()
for _val, space in zip(vals, self.space)
]
)
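
A hypothetical sketch of the dtype guard suggested in the comment above; the check torchrl eventually adopts may differ (it could, for instance, coerce rather than reject):

```python
import torch

# Hypothetical dtype check: reject values whose dtype does not match the
# spec's dtype before doing any range checks.
def _dtype_matches(val: torch.Tensor, expected: torch.dtype) -> bool:
    return val.dtype == expected

spec_dtype = torch.long
print(_dtype_matches(torch.tensor([2, 0, 1]), spec_dtype))        # True
print(_dtype_matches(torch.tensor([2.0, 0.0, 1.0]), spec_dtype))  # False
```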

def to_onehot(self) -> MultOneHotDiscreteTensorSpec:
return MultOneHotDiscreteTensorSpec(
[_space.n for _space in self.space], self.device, self.dtype
)


class CompositeSpec(TensorSpec):
"""A composition of TensorSpecs.
7 changes: 6 additions & 1 deletion torchrl/envs/libs/gym.py
@@ -13,6 +13,7 @@
BoundedTensorSpec,
CompositeSpec,
DiscreteTensorSpec,
MultiDiscreteTensorSpec,
MultOneHotDiscreteTensorSpec,
OneHotDiscreteTensorSpec,
TensorSpec,
@@ -68,7 +69,11 @@ def _gym_to_torchrl_spec_transform(
elif isinstance(spec, gym.spaces.multi_binary.MultiBinary):
return BinaryDiscreteTensorSpec(spec.n, device=device)
elif isinstance(spec, gym.spaces.multi_discrete.MultiDiscrete):
return MultOneHotDiscreteTensorSpec(spec.nvec, device=device)
return (
MultiDiscreteTensorSpec(spec.nvec, device=device)
if categorical_action_encoding
else MultOneHotDiscreteTensorSpec(spec.nvec, device=device)
)
elif isinstance(spec, gym.spaces.Box):
shape = spec.shape
if not len(shape):
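
For illustration, a usage sketch of the branch added above, assuming gym is installed and that `_gym_to_torchrl_spec_transform` accepts a `categorical_action_encoding` keyword as the diff suggests:

```python
# Usage sketch: the flag selects between the categorical and one-hot specs
# for a gym MultiDiscrete space. Keyword name is assumed from the diff.
import gym
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform

space = gym.spaces.MultiDiscrete([5, 2, 3])
cat_spec = _gym_to_torchrl_spec_transform(space, categorical_action_encoding=True)
onehot_spec = _gym_to_torchrl_spec_transform(space, categorical_action_encoding=False)
print(type(cat_spec).__name__)     # MultiDiscreteTensorSpec
print(type(onehot_spec).__name__)  # MultOneHotDiscreteTensorSpec
```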