diff --git a/test/test_shared.py b/test/test_shared.py index f93adcaa90b..f28c4d81d4a 100644 --- a/test/test_shared.py +++ b/test/test_shared.py @@ -59,7 +59,9 @@ def test_shared(self, indexing_method): td = tensordict.clone().share_memory_() if indexing_method == 0: subtd = TensorDict( - source={key: item[0] for key, item in td.items()}, batch_size=[] + source={key: item[0] for key, item in td.items()}, + batch_size=[], + _is_shared=True, ) elif indexing_method == 1: subtd = td.get_sub_tensordict(0)