|
36 | 36 |
|
37 | 37 | from tensordict.nn.functional_modules import _exclude_td_from_pytree |
38 | 38 |
|
| 39 | +from tensordict.tensorclass import TensorClass |
| 40 | + |
39 | 41 | from torch.utils._pytree import SUPPORTED_NODES, tree_map |
40 | 42 |
|
41 | 43 | TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) |
@@ -1207,6 +1209,56 @@ def test_td_input_non_tdmodule_nontensor(self, compiled): |
1207 | 1209 | func(torch.zeros(()), 2.0) |
1208 | 1210 |
|
1209 | 1211 |
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
class TestCompileNontensor:
    """Compile smoke-tests for a ``TensorClass`` that carries non-tensor data.

    Each test compiles a small constructor function that builds the class with
    a different combination of ``batch_size`` / ``device`` keyword arguments
    and returns its tensor field.
    """

    # NOTE: the same behavior reproduces with the @tensorclass decorator form.
    @pytest.fixture(scope="class")
    def data(self):
        # One shared CUDA input, reused by every test in this class.
        return torch.zeros((4, 3), device="cuda")

    class TensorClassWithNonTensorData(TensorClass["nocast"]):
        # Tensor payload plus a plain-Python (non-tensor) field.
        tensor: torch.Tensor
        non_tensor_data: int

    def fn_no_device_no_batch_size(self, data):
        # Construct with neither batch_size nor device.
        instance = self.TensorClassWithNonTensorData(tensor=data, non_tensor_data=1)
        return instance.tensor

    def fn_no_device(self, data):
        # Construct with an explicit batch_size but no device.
        instance = self.TensorClassWithNonTensorData(
            tensor=data, non_tensor_data=1, batch_size=[4]
        )
        return instance.tensor

    def fn_with_device(self, data):
        # Construct with both batch_size and device.
        instance = self.TensorClassWithNonTensorData(
            tensor=data, non_tensor_data=1, batch_size=[4], device="cuda"
        )
        return instance.tensor

    def fn_with_device_without_batch_size(self, data):
        # Construct with an explicit device but no batch_size.
        instance = self.TensorClassWithNonTensorData(
            tensor=data, non_tensor_data=1, device="cuda"
        )
        return instance.tensor

    def test_nontensor_no_device_no_batch_size(self, data):
        torch._dynamo.reset_code_caches()
        torch.compile(self.fn_no_device_no_batch_size)(data)

    def test_nontensor_no_device(self, data):
        torch._dynamo.reset_code_caches()
        torch.compile(self.fn_no_device)(data)

    def test_nontensor_with_device(self, data):
        torch._dynamo.reset_code_caches()
        torch.compile(self.fn_with_device)(data)

    def test_nontensor_with_device_without_batch_size(self, data):
        torch._dynamo.reset_code_caches()
        torch.compile(self.fn_with_device_without_batch_size)(data)
| 1261 | + |
if __name__ == "__main__":
    # Strip off any arguments argparse recognizes and hand the remainder
    # straight to pytest so extra CLI flags reach the test run.
    _, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
0 commit comments