Skip to content

Commit 5c98749

Browse files
authored
[BugFix] Fix compilation of TC with non-tensor + batch-size + device (#1337)
1 parent a4a0ef2 commit 5c98749

File tree

2 files changed: +55 −2 lines

tensordict/tensorclass.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,7 +1275,6 @@ def wrapper(
12751275
new_params.append(
12761276
inspect.Parameter("names", inspect.Parameter.KEYWORD_ONLY, default=None)
12771277
)
1278-
12791278
wrapper.__signature__ = init_sig.replace(parameters=params + new_params)
12801279

12811280
return wrapper
@@ -2284,7 +2283,9 @@ def set_tensor(
22842283
self._non_tensordict[key] = value
22852284
return self
22862285
if non_tensor:
2287-
value = NonTensorData(value)
2286+
value = NonTensorData(
2287+
value, batch_size=self.batch_size, device=self.device
2288+
)
22882289
if key in self._non_tensordict:
22892290
del self._non_tensordict[key]
22902291
# Avoiding key clash, honoring the user input to assign tensor type data to the key

test/test_compile.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636

3737
from tensordict.nn.functional_modules import _exclude_td_from_pytree
3838

39+
from tensordict.tensorclass import TensorClass
40+
3941
from torch.utils._pytree import SUPPORTED_NODES, tree_map
4042

4143
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
@@ -1207,6 +1209,56 @@ def test_td_input_non_tdmodule_nontensor(self, compiled):
12071209
func(torch.zeros(()), 2.0)
12081210

12091211

1212+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
1213+
class TestCompileNontensor:
1214+
# Same issue with the decorator @tensorclass version
1215+
@pytest.fixture(scope="class")
1216+
def data(self):
1217+
return torch.zeros((4, 3), device="cuda")
1218+
1219+
class TensorClassWithNonTensorData(TensorClass["nocast"]):
1220+
tensor: torch.Tensor
1221+
non_tensor_data: int
1222+
1223+
def fn_no_device_no_batch_size(self, data):
1224+
a = self.TensorClassWithNonTensorData(tensor=data, non_tensor_data=1)
1225+
return a.tensor
1226+
1227+
def fn_no_device(self, data):
1228+
a = self.TensorClassWithNonTensorData(
1229+
tensor=data, non_tensor_data=1, batch_size=[4]
1230+
)
1231+
return a.tensor
1232+
1233+
def fn_with_device(self, data):
1234+
a = self.TensorClassWithNonTensorData(
1235+
tensor=data, non_tensor_data=1, batch_size=[4], device="cuda"
1236+
)
1237+
return a.tensor
1238+
1239+
def fn_with_device_without_batch_size(self, data):
1240+
a = self.TensorClassWithNonTensorData(
1241+
tensor=data, non_tensor_data=1, device="cuda"
1242+
)
1243+
return a.tensor
1244+
1245+
def test_nontensor_no_device_no_batch_size(self, data):
1246+
torch._dynamo.reset_code_caches()
1247+
torch.compile(self.fn_no_device_no_batch_size)(data)
1248+
1249+
def test_nontensor_no_device(self, data):
1250+
torch._dynamo.reset_code_caches()
1251+
torch.compile(self.fn_no_device)(data)
1252+
1253+
def test_nontensor_with_device(self, data):
1254+
torch._dynamo.reset_code_caches()
1255+
torch.compile(self.fn_with_device)(data)
1256+
1257+
def test_nontensor_with_device_without_batch_size(self, data):
1258+
torch._dynamo.reset_code_caches()
1259+
torch.compile(self.fn_with_device_without_batch_size)(data)
1260+
1261+
12101262
if __name__ == "__main__":
12111263
args, unknown = argparse.ArgumentParser().parse_known_args()
12121264
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

0 commit comments

Comments
 (0)