diff --git a/tensordict/_td.py b/tensordict/_td.py index d18e37e39..619c06b2b 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -860,7 +860,11 @@ def __setitem__( if isinstance(value, (TensorDictBase, dict)): indexed_bs = _getitem_batch_size(self.batch_size, index) if isinstance(value, dict): - value = self.from_dict_instance(value, batch_size=indexed_bs) + value = self.from_dict_instance( + value, batch_size=indexed_bs, device=self.device + ) + elif value.device != self.device: + value = value.to(self.device) # value = self.empty(recurse=True)[index].update(value) if value.batch_size != indexed_bs: if value.shape == indexed_bs[-len(value.shape) :]: @@ -883,7 +887,7 @@ def __setitem__( for value_key, item in value.items(): if value_key in keys: self._set_at_str( - value_key, item, index, validated=False, non_blocking=False + value_key, item, index, validated=True, non_blocking=False ) else: if subtd is None: @@ -3129,6 +3133,7 @@ def _exclude( # self._maybe_set_shared_attributes(result) return result + # @cache def keys( self, include_nested: bool = False,