Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1456,7 +1456,11 @@ def is_in(self, val: torch.Tensor) -> bool:
shape = torch.broadcast_shapes(shape, val.shape)
mask_expand = self.mask.expand(shape)
gathered = mask_expand & val
return gathered.any(-1).all()
return (
(gathered.sum(-1) == 1).all()
and val.dtype == self.dtype
and val.shape[-len(self.shape) :] == self.shape
)

def __eq__(self, other):
if not hasattr(other, "mask"):
Expand Down Expand Up @@ -1763,9 +1767,12 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:

def is_in(self, val: torch.Tensor) -> bool:
try:
return (val >= self.space.low.to(val.device)).all() and (
val <= self.space.high.to(val.device)
).all()
return (
(val >= self.space.low.to(val.device)).all()
and (val <= self.space.high.to(val.device)).all()
and (val.dtype == self.dtype)
and (val.shape[-len(self.shape) :] == self.shape)
)
except RuntimeError as err:
if "The size of tensor a" in str(err):
warnings.warn(f"Got a shape mismatch: {str(err)}")
Expand Down Expand Up @@ -1894,7 +1901,7 @@ def rand(self, shape=None) -> torch.Tensor:
return torch.empty(shape, device=self.device, dtype=self.dtype).random_()

def is_in(self, val: torch.Tensor) -> bool:
return True
return val.dtype == self.dtype and val.shape[-len(self.shape) :] == self.shape

def expand(self, *shape):
if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
Expand Down Expand Up @@ -2034,7 +2041,7 @@ def rand(self, shape=None) -> torch.Tensor:
return r.to(self.device)

def is_in(self, val: torch.Tensor) -> bool:
return True
return val.dtype == self.dtype and val.shape[-len(self.shape) :] == self.shape

def expand(self, *shape):
if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
Expand Down Expand Up @@ -2589,7 +2596,11 @@ def is_in(self, val: torch.Tensor) -> bool:
shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]])
mask_expand = self.mask.expand(shape)
gathered = mask_expand.gather(-1, val.unsqueeze(-1))
return gathered.all()
return (
gathered.all()
and (val.dtype == self.dtype)
and (val.shape[-len(self.shape) :] == self.shape)
)

def __getitem__(self, idx: SHAPE_INDEX_TYPING):
"""Indexes the current TensorSpec based on the provided index."""
Expand Down