From 84656cd65dbaa3f1e06aaf94ff6e0d79c5c3527f Mon Sep 17 00:00:00 2001 From: soulitzer Date: Fri, 19 Apr 2024 14:15:28 -0400 Subject: [PATCH] [NJT] Inline through torch.nested.nested_tensor_from_jagged instead of graph break (#124343) Pull Request resolved: https://github.com/pytorch/pytorch/pull/124343 Approved by: https://github.com/jbschlosser --- test/dynamo/test_subclasses.py | 8 ++++++++ test/profiler/test_profiler.py | 1 + torch/_dynamo/trace_rules.py | 1 + 3 files changed, 10 insertions(+) diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 387b6bf59b1a..8005d6e3a28c 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -1361,6 +1361,14 @@ def fn(x): self._check_recompiles(fn, (nt,), (nt2,), False) self._check_recompiles(fn, (nt,), (nt3,), True) + def test_inline_nested_tensor_from_jagged(self): + nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) + + def fn(x): + return torch.nested.nested_tensor_from_jagged(x.values() * 2, x.offsets()) + + torch.compile(fn, fullgraph=True, backend="aot_eager")(nt) + def _get_views(self): # Test all cases with both an NT base and a dense base # Subclass -> Subclass diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 8e4e31718d9c..e149ea379b80 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -634,6 +634,7 @@ def test_execution_trace_no_capture(self): found_root_node = True assert found_root_node + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/124500") def test_execution_trace_nested_tensor(self): fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) fp.close() diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 763f6482cb92..daeb8626c109 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -173,6 +173,7 @@ "torch.nn.Parameter": TorchInGraphFunctionVariable, "torch._nested_tensor_from_mask": SkipFunctionVariable, "torch._nested_from_padded": SkipFunctionVariable, + "torch.nested.nested_tensor_from_jagged": UserFunctionVariable, # symbol operators implemented in Python "torch.sym_not": TorchInGraphFunctionVariable, "torch.sym_float": TorchInGraphFunctionVariable,