diff --git a/test/export/test_export.py b/test/export/test_export.py
index dcb8c5923d93b..221ea9ba075b5 100644
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -1693,6 +1693,34 @@ def f():
         self.assertEqual(a.size(), torch.Size([3, 4]))
         self.assertEqual(b.size(), torch.Size([3, 4]))
 
+    def test_export_then_compile_tensor_ctor(self):
+        class M(torch.nn.Module):
+            def __init__(self,):
+                super().__init__()
+
+            def forward(self, scores, mask):
+                scores = scores.masked_fill(
+                    mask, torch.tensor(torch.finfo(scores.dtype).min)
+                )  # (bs, n_heads, q_length, k_length)
+                return scores
+
+        tensor_cpu = torch.randn(2, 4)
+        mask_cpu = torch.BoolTensor(
+            [[False, True, False, False],
+             [False, False, False, False]]
+        )
+
+        m = M().eval()
+        # res_ref = m(tensor_cpu, mask_cpu)
+        # print("res_ref is: {}".format(res_ref), flush=True)
+
+        exported_model = capture_pre_autograd_graph(
+            m,
+            (tensor_cpu, mask_cpu),
+        )
+        optimized_model = torch.compile(exported_model)
+        optimized_model(tensor_cpu, mask_cpu)
+
 
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 901396750d3f9..b36bc4c5bf8be 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -1465,11 +1465,14 @@ def maybe_to_constant(t):
             return t
 
         # To constant propagate through these functions:
-        # 1, If this is a lift, the input tensor is guaranteed to be a
+        # 1, If this is a lift due to a torch.tensor call,
+        #    the input tensor is guaranteed to be a
         # constant, so we keep a copy of the original argument along so
-        # we can query it if we're asked to item() it at some later point
+        # we can query it if we're asked to item() it at some later point.
+        # (Note that you can always call a lift fn manually, so we do
+        # have to check if there are any fake tensors!)
         # 2, Some functions that allow Python numbers to bind to Tensors, e.g, torch.div
-        if func in self.lift_fns or (
+        if (func in self.lift_fns and not flat_arg_fake_tensors) or (
             should_allow_numbers_as_tensors(func)
             and not has_symbolic_sizes
             and not flat_arg_fake_tensors
@@ -1509,11 +1512,10 @@ def maybe_to_constant(t):
         # this is generated from torch.tensor(), which does not use the
         # dispatcher, to allow wrapper subclasses to wrap the new tensor
         if func in self.lift_fns:
-            assert (
-                len(kwargs) == 0 and len(args) == 1 and type(args[0]) is torch.Tensor
-            ), f"{args} {kwargs}"
+            assert len(kwargs) == 0 and len(args) == 1, f"{args} {kwargs}"
 
-            return converter(self, args[0])
+            if type(args[0]) is torch.Tensor:
+                return converter(self, args[0])
 
         # Recompute flat_arg_fake_tensors here again in case some of the inputs
         # were real tensors and fakified in validate_and_convert_non_fake_tensors
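
For reviewers, a minimal sketch (not part of the patch) of the dispatch path this change targets: a lift fn such as aten.lift_fresh_copy can be dispatched with a FakeTensor argument, e.g. when torch.compile retraces a graph captured by capture_pre_autograd_graph, so constant propagation must only fire when every input is a real tensor. The snippet assumes FakeTensorMode is importable from torch._subclasses, as in this revision of the tree.

    import torch
    from torch._subclasses import FakeTensorMode

    with FakeTensorMode() as mode:
        fake = mode.from_tensor(torch.randn(2, 4))
        # Dispatching a lift fn on a fake input previously tripped the
        # `type(args[0]) is torch.Tensor` assert; with this fix it skips
        # constant propagation and takes the ordinary fake-tensor path.
        out = torch.ops.aten.lift_fresh_copy(fake)
        print(type(out))  # a FakeTensor, not a propagated constant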