diff --git a/tensordict/_td.py b/tensordict/_td.py index a2e0361fa..23d1d228f 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -2330,8 +2330,6 @@ def rename_key_( # these checks are not perfect, tuples that are not tuples of strings or empty # tuples could go through but (1) it will raise an error anyway and (2) # those checks are expensive when repeated often. - if old_key == new_key: - return self if not isinstance(old_key, (str, tuple)): raise TypeError( f"Expected old_name to be a string or a tuple of strings but found {type(old_key)}" @@ -2340,13 +2338,17 @@ def rename_key_( raise TypeError( f"Expected new_name to be a string or a tuple of strings but found {type(new_key)}" ) + old_key = unravel_key(old_key) + new_key = unravel_key(new_key) + if old_key == new_key: + return self if safe and (new_key in self.keys(include_nested=True)): raise KeyError(f"key {new_key} already present in TensorDict.") if isinstance(new_key, str): self._set_str( new_key, - self.get(old_key), + self.get(old_key, default=NO_DEFAULT), inplace=False, validated=True, non_blocking=False, @@ -2354,7 +2356,7 @@ def rename_key_( else: self._set_tuple( new_key, - self.get(old_key), + self.get(old_key, default=NO_DEFAULT), inplace=False, validated=True, non_blocking=False, @@ -4195,7 +4197,7 @@ def __contains__(self, key: NestedKey) -> bool: if self.leaves_only: # TODO: make this faster for LazyStacked without compromising regular return not _is_tensor_collection( - type(self.tensordict._get_str(key)) + type(self.tensordict._get_str(key, NO_DEFAULT)) ) return True return False diff --git a/tensordict/utils.py b/tensordict/utils.py index 2cb4b8d6d..67df379d8 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1500,6 +1500,8 @@ def _get_leaf_tensordict( tensordict = hook(tensordict, key) else: tensordict = tensordict.get(key[0]) + if tensordict is None: + raise KeyError key = key[1:] return tensordict, key[0]