diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index efe928856b9..3dd95a1284b 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1456,7 +1456,11 @@ def is_in(self, val: torch.Tensor) -> bool: shape = torch.broadcast_shapes(shape, val.shape) mask_expand = self.mask.expand(shape) gathered = mask_expand & val - return gathered.any(-1).all() + return ( + (gathered.sum(-1) == 1).all() + and val.dtype == self.dtype + and val.shape[-len(self.shape) :] == self.shape + ) def __eq__(self, other): if not hasattr(other, "mask"): @@ -1763,9 +1767,12 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: def is_in(self, val: torch.Tensor) -> bool: try: - return (val >= self.space.low.to(val.device)).all() and ( - val <= self.space.high.to(val.device) - ).all() + return ( + (val >= self.space.low.to(val.device)).all() + and (val <= self.space.high.to(val.device)).all() + and (val.dtype == self.dtype) + and (val.shape[-len(self.shape) :] == self.shape) + ) except RuntimeError as err: if "The size of tensor a" in str(err): warnings.warn(f"Got a shape mismatch: {str(err)}") @@ -1894,7 +1901,7 @@ def rand(self, shape=None) -> torch.Tensor: return torch.empty(shape, device=self.device, dtype=self.dtype).random_() def is_in(self, val: torch.Tensor) -> bool: - return True + return val.dtype == self.dtype and val.shape[-len(self.shape) :] == self.shape def expand(self, *shape): if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): @@ -2034,7 +2041,7 @@ def rand(self, shape=None) -> torch.Tensor: return r.to(self.device) def is_in(self, val: torch.Tensor) -> bool: - return True + return val.dtype == self.dtype and val.shape[-len(self.shape) :] == self.shape def expand(self, *shape): if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): @@ -2589,7 +2596,11 @@ def is_in(self, val: torch.Tensor) -> bool: shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) mask_expand = self.mask.expand(shape) gathered = mask_expand.gather(-1, val.unsqueeze(-1)) - return gathered.all() + return ( + gathered.all() + and (val.dtype == self.dtype) + and (val.shape[-len(self.shape) :] == self.shape) + ) def __getitem__(self, idx: SHAPE_INDEX_TYPING): """Indexes the current TensorSpec based on the provided index."""