diff --git a/tensordict/_td.py b/tensordict/_td.py index 57529590d..73dd4d800 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -2579,7 +2579,7 @@ def _set_at_str(self, key, value, idx, *, validated, non_blocking: bool): ignore_lock=True, ) is_diff = dest[idx].tolist() != value.tolist() - if is_diff: + if is_diff.any(): dest_val = dest.maybe_to_stack() dest_val[idx] = value if dest_val is not dest: