From 1ff3193b574c11a80d1d53b61e8f39382e58aa35 Mon Sep 17 00:00:00 2001
From: Naman Kumar
Date: Wed, 21 Feb 2024 23:43:42 +0530
Subject: [PATCH 1/4] Checking for dtype to all tensorspec is_in funcs

---
 torchrl/data/tensor_specs.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index efe928856b9..9b05bf81774 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -1456,7 +1456,7 @@ def is_in(self, val: torch.Tensor) -> bool:
         shape = torch.broadcast_shapes(shape, val.shape)
         mask_expand = self.mask.expand(shape)
         gathered = mask_expand & val
-        return gathered.any(-1).all()
+        return gathered.any(-1).all() and val.dtype == self.dtype
 
     def __eq__(self, other):
         if not hasattr(other, "mask"):
@@ -1765,7 +1765,7 @@ def is_in(self, val: torch.Tensor) -> bool:
         try:
             return (val >= self.space.low.to(val.device)).all() and (
                 val <= self.space.high.to(val.device)
-            ).all()
+            ).all() and (val.dtype == self.dtype)
         except RuntimeError as err:
             if "The size of tensor a" in str(err):
                 warnings.warn(f"Got a shape mismatch: {str(err)}")
@@ -1894,7 +1894,7 @@ def rand(self, shape=None) -> torch.Tensor:
         return torch.empty(shape, device=self.device, dtype=self.dtype).random_()
 
     def is_in(self, val: torch.Tensor) -> bool:
-        return True
+        return val.dtype == self.dtype
 
     def expand(self, *shape):
         if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
@@ -2034,7 +2034,7 @@ def rand(self, shape=None) -> torch.Tensor:
         return r.to(self.device)
 
     def is_in(self, val: torch.Tensor) -> bool:
-        return True
+        return val.dtype == self.dtype
 
     def expand(self, *shape):
         if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
@@ -2333,7 +2333,7 @@ def is_in(self, val: torch.Tensor) -> bool:
         vals = self._split(val)
         if vals is None:
             return False
-        return all(spec.is_in(val) for val, spec in zip(vals, self._split_self()))
+        return all(spec.is_in(val) for val, spec in zip(vals, self._split_self())) and (val.dtype == self.dtype)
 
     def _project(self, val: torch.Tensor) -> torch.Tensor:
         vals = self._split(val)
@@ -2589,7 +2589,7 @@ def is_in(self, val: torch.Tensor) -> bool:
         shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]])
         mask_expand = self.mask.expand(shape)
         gathered = mask_expand.gather(-1, val.unsqueeze(-1))
-        return gathered.all()
+        return gathered.all() and (val.dtype == self.dtype)
 
     def __getitem__(self, idx: SHAPE_INDEX_TYPING):
         """Indexes the current TensorSpec based on the provided index."""
@@ -3585,7 +3585,7 @@ def is_in(self, val: Union[dict, TensorDictBase]) -> bool:
             val_item = val.get(key)
             if not item.is_in(val_item):
                 return False
-        return True
+        return val.dtype == self.dtype
 
     def project(self, val: TensorDictBase) -> TensorDictBase:
         for key, item in self.items():
@@ -4107,7 +4107,7 @@ def is_in(self, val) -> bool:
         for spec, subval in zip(self._specs, val.unbind(self.dim)):
            if not spec.is_in(subval):
                 return False
-        return True
+        return val.dtype == self.dtype
 
     def __delitem__(self, key: NestedKey):
         """Deletes a key from the stacked composite spec.
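For illustration only, not part of the patch series: a minimal standalone sketch of the dtype gate this patch adds to the bounded spec's is_in, written against plain PyTorch rather than the torchrl classes. The helper name and arguments below are hypothetical.

import torch

# Hypothetical stand-in mirroring the patched check: the value must lie
# inside [low, high] AND carry the spec's dtype to count as "in" the spec.
def bounded_is_in(val, low, high, spec_dtype):
    return bool((val >= low).all() and (val <= high).all() and val.dtype == spec_dtype)

low, high = torch.zeros(3), torch.ones(3)
print(bounded_is_in(torch.full((3,), 0.5), low, high, torch.float32))                       # True
print(bounded_is_in(torch.full((3,), 0.5, dtype=torch.float64), low, high, torch.float32))  # False: dtype mismatch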
From 47051e481290077f211a1c8a84cba4a066922cc1 Mon Sep 17 00:00:00 2001
From: Naman Kumar
Date: Wed, 28 Feb 2024 20:50:45 +0530
Subject: [PATCH 2/4] Checks for shape for is_in functions if applicable

---
 torchrl/data/tensor_specs.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 9b05bf81774..664be480230 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -1456,7 +1456,7 @@ def is_in(self, val: torch.Tensor) -> bool:
         shape = torch.broadcast_shapes(shape, val.shape)
         mask_expand = self.mask.expand(shape)
         gathered = mask_expand & val
-        return gathered.any(-1).all() and val.dtype == self.dtype
+        return gathered.any(-1).all() and val.dtype == self.dtype and val.shape[-len(self.shape):] == self.shape
 
     def __eq__(self, other):
         if not hasattr(other, "mask"):
@@ -1765,7 +1765,7 @@ def is_in(self, val: torch.Tensor) -> bool:
         try:
             return (val >= self.space.low.to(val.device)).all() and (
                 val <= self.space.high.to(val.device)
-            ).all() and (val.dtype == self.dtype)
+            ).all() and (val.dtype == self.dtype) and (val.shape[-len(self.shape):] == self.shape)
         except RuntimeError as err:
             if "The size of tensor a" in str(err):
                 warnings.warn(f"Got a shape mismatch: {str(err)}")
@@ -1894,7 +1894,7 @@ def rand(self, shape=None) -> torch.Tensor:
         return torch.empty(shape, device=self.device, dtype=self.dtype).random_()
 
     def is_in(self, val: torch.Tensor) -> bool:
-        return val.dtype == self.dtype
+        return val.dtype == self.dtype and val.shape[-len(self.shape):] == self.shape
 
     def expand(self, *shape):
         if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
@@ -2034,7 +2034,7 @@ def rand(self, shape=None) -> torch.Tensor:
         return r.to(self.device)
 
     def is_in(self, val: torch.Tensor) -> bool:
-        return val.dtype == self.dtype
+        return val.dtype == self.dtype and val.shape[-len(self.shape):] == self.shape
 
     def expand(self, *shape):
         if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
@@ -2333,7 +2333,7 @@ def is_in(self, val: torch.Tensor) -> bool:
         vals = self._split(val)
         if vals is None:
             return False
-        return all(spec.is_in(val) for val, spec in zip(vals, self._split_self())) and (val.dtype == self.dtype)
+        return all(spec.is_in(val) for val, spec in zip(vals, self._split_self()))
 
     def _project(self, val: torch.Tensor) -> torch.Tensor:
         vals = self._split(val)
@@ -2589,7 +2589,7 @@ def is_in(self, val: torch.Tensor) -> bool:
         shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]])
         mask_expand = self.mask.expand(shape)
         gathered = mask_expand.gather(-1, val.unsqueeze(-1))
-        return gathered.all() and (val.dtype == self.dtype)
+        return gathered.all() and (val.dtype == self.dtype) and (val.shape[-len(self.shape):] == self.shape)
 
     def __getitem__(self, idx: SHAPE_INDEX_TYPING):
         """Indexes the current TensorSpec based on the provided index."""
@@ -3585,7 +3585,7 @@ def is_in(self, val: Union[dict, TensorDictBase]) -> bool:
             val_item = val.get(key)
             if not item.is_in(val_item):
                 return False
-        return val.dtype == self.dtype
+        return True
 
     def project(self, val: TensorDictBase) -> TensorDictBase:
         for key, item in self.items():
@@ -4107,7 +4107,7 @@ def is_in(self, val) -> bool:
         for spec, subval in zip(self._specs, val.unbind(self.dim)):
            if not spec.is_in(subval):
                 return False
-        return val.dtype == self.dtype
+        return True
 
     def __delitem__(self, key: NestedKey):
         """Deletes a key from the stacked composite spec.
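For illustration only, not part of the patch series: a standalone sketch of the trailing-shape rule this patch adds, namely that the value's last dims must equal the spec's shape while extra leading batch dims remain allowed. The helper below is hypothetical, not the torchrl API.

import torch

# Hypothetical stand-in for the new shape check: compare only the trailing dims.
def shape_matches(val: torch.Tensor, spec_shape: torch.Size) -> bool:
    return val.shape[-len(spec_shape):] == spec_shape

spec_shape = torch.Size([3])
print(shape_matches(torch.zeros(3), spec_shape))      # True: exact match
print(shape_matches(torch.zeros(10, 3), spec_shape))  # True: leading batch dim is fine
print(shape_matches(torch.zeros(4), spec_shape))      # False: trailing dim differs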
From 7626742043a298b62d15654c4f69b5f211781092 Mon Sep 17 00:00:00 2001
From: Naman Kumar
Date: Wed, 28 Feb 2024 23:52:04 +0530
Subject: [PATCH 3/4] Fixed small bug in tensor_specs.py

---
 torchrl/data/tensor_specs.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 664be480230..7742d3cca7e 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -1456,7 +1456,7 @@ def is_in(self, val: torch.Tensor) -> bool:
         shape = torch.broadcast_shapes(shape, val.shape)
         mask_expand = self.mask.expand(shape)
         gathered = mask_expand & val
-        return gathered.any(-1).all() and val.dtype == self.dtype and val.shape[-len(self.shape):] == self.shape
+        return (gathered.sum(-1) == 1).all() and val.dtype == self.dtype and val.shape[-len(self.shape):] == self.shape
 
     def __eq__(self, other):
         if not hasattr(other, "mask"):

From b1352fbaad549ec933217f96d9533589235d0ed5 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Wed, 28 Feb 2024 17:29:28 -0500
Subject: [PATCH 4/4] lint

---
 torchrl/data/tensor_specs.py | 25 ++++++++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 7742d3cca7e..3dd95a1284b 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -1456,7 +1456,11 @@ def is_in(self, val: torch.Tensor) -> bool:
         shape = torch.broadcast_shapes(shape, val.shape)
         mask_expand = self.mask.expand(shape)
         gathered = mask_expand & val
-        return (gathered.sum(-1) == 1).all() and val.dtype == self.dtype and val.shape[-len(self.shape):] == self.shape
+        return (
+            (gathered.sum(-1) == 1).all()
+            and val.dtype == self.dtype
+            and val.shape[-len(self.shape) :] == self.shape
+        )
 
     def __eq__(self, other):
         if not hasattr(other, "mask"):
@@ -1763,9 +1767,12 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
 
     def is_in(self, val: torch.Tensor) -> bool:
         try:
-            return (val >= self.space.low.to(val.device)).all() and (
-                val <= self.space.high.to(val.device)
-            ).all() and (val.dtype == self.dtype) and (val.shape[-len(self.shape):] == self.shape)
+            return (
+                (val >= self.space.low.to(val.device)).all()
+                and (val <= self.space.high.to(val.device)).all()
+                and (val.dtype == self.dtype)
+                and (val.shape[-len(self.shape) :] == self.shape)
+            )
         except RuntimeError as err:
             if "The size of tensor a" in str(err):
                 warnings.warn(f"Got a shape mismatch: {str(err)}")
@@ -1894,7 +1901,7 @@ def rand(self, shape=None) -> torch.Tensor:
         return torch.empty(shape, device=self.device, dtype=self.dtype).random_()
 
     def is_in(self, val: torch.Tensor) -> bool:
-        return val.dtype == self.dtype and val.shape[-len(self.shape):] == self.shape
+        return val.dtype == self.dtype and val.shape[-len(self.shape) :] == self.shape
 
     def expand(self, *shape):
         if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
@@ -2034,7 +2041,7 @@ def rand(self, shape=None) -> torch.Tensor:
         return r.to(self.device)
 
     def is_in(self, val: torch.Tensor) -> bool:
-        return val.dtype == self.dtype and val.shape[-len(self.shape):] == self.shape
+        return val.dtype == self.dtype and val.shape[-len(self.shape) :] == self.shape
 
     def expand(self, *shape):
         if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
@@ -2589,7 +2596,11 @@ def is_in(self, val: torch.Tensor) -> bool:
         shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]])
         mask_expand = self.mask.expand(shape)
         gathered = mask_expand.gather(-1, val.unsqueeze(-1))
-        return gathered.all() and (val.dtype == self.dtype) and (val.shape[-len(self.shape):] == self.shape)
+        return (
+            gathered.all()
+            and (val.dtype == self.dtype)
+            and (val.shape[-len(self.shape) :] == self.shape)
+        )
 
     def __getitem__(self, idx: SHAPE_INDEX_TYPING):
         """Indexes the current TensorSpec based on the provided index."""
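For illustration only, not part of the patch series: a standalone sketch of the one-hot fix introduced in patch 3 and kept by the lint pass above. After applying the spec's mask, every sample must have exactly one active entry, whereas the earlier any(-1) test also accepted vectors with several ones. The helper below is hypothetical, not the torchrl API.

import torch

# Hypothetical stand-in for the fixed one-hot membership test.
def one_hot_is_in(val: torch.Tensor, mask: torch.Tensor) -> bool:
    gathered = mask & val.bool()                # keep only entries allowed by the mask
    return bool((gathered.sum(-1) == 1).all())  # exactly one hot entry per sample

mask = torch.tensor([True, True, False])
print(one_hot_is_in(torch.tensor([0, 1, 0]), mask))  # True: exactly one allowed entry set
print(one_hot_is_in(torch.tensor([1, 1, 0]), mask))  # False now; the old any(-1) check passed it
print(one_hot_is_in(torch.tensor([0, 0, 1]), mask))  # False: the set entry is masked out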