@@ -2232,6 +2232,34 @@ def test_requires_grad(self, device):
22322232 # First stacked tensor has requires_grad == True
22332233 assert list (stacked_td .values ())[0 ].requires_grad is True
22342234
2235+ def test_refine_names_setitem_subtd (self ):
2236+ batch_size = 1
2237+ seq_len = 2
2238+ n_agents = 3
2239+ td = TensorDict (
2240+ {
2241+ "agents" : TensorDict (
2242+ {
2243+ "obs" : torch .zeros ((batch_size , seq_len , n_agents , 5 )),
2244+ "dones" : torch .zeros ((batch_size , seq_len , n_agents , 1 )),
2245+ },
2246+ batch_size = (batch_size , seq_len , n_agents ),
2247+ names = [None , "time" , "other" ],
2248+ ),
2249+ "dones" : torch .zeros ((batch_size , seq_len )),
2250+ },
2251+ batch_size = (batch_size , seq_len ),
2252+ names = [None , "time" ],
2253+ )
2254+ #
2255+ td ["agents" ] = td ["agents" ].repeat_interleave (2 , dim = - 1 )
2256+ assert len (td ["agents" ].names ) == 3
2257+ assert td ["agents" ].names [- 1 ] == "other"
2258+ td ["agents" ] = td ["agents" ].repeat (1 , 1 , 2 )
2259+ assert td ["agents" ].names [- 1 ] == "other"
2260+ td ["agents" ] = torch .cat ((td ["agents" ], td ["agents" ]), dim = 2 )
2261+ assert td ["agents" ].names [- 1 ] == "other"
2262+
22352263 def test_rename_key_nested (self ):
22362264 td = TensorDict (a = {"b" : {"c" : 0 }})
22372265 td .rename_key_ (("a" , "b" , "c" ), ("a" , "b" ))
0 commit comments