Skip to content

Commit

Permalink
[NJT] Inline through torch.nested.nested_tensor_from_jagged instead o…
Browse files Browse the repository at this point in the history
…f graph break (#124343)

Pull Request resolved: #124343
Approved by: https://github.com/jbschlosser
  • Loading branch information
soulitzer authored and pytorchmergebot committed Apr 19, 2024
1 parent acbf888 commit cf5ca58
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 0 deletions.
8 changes: 8 additions & 0 deletions test/dynamo/test_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,14 @@ def fn(x):
self._check_recompiles(fn, (nt,), (nt2,), False)
self._check_recompiles(fn, (nt,), (nt3,), True)

def test_inline_nested_tensor_from_jagged(self):
nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)

def fn(x):
return torch.nested.nested_tensor_from_jagged(x.values() * 2, x.offsets())

torch.compile(fn, fullgraph=True, backend="aot_eager")(nt)

def _get_views(self):
# Test all cases with both an NT base and a dense base
# Subclass -> Subclass
Expand Down
1 change: 1 addition & 0 deletions test/profiler/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,7 @@ def test_execution_trace_no_capture(self):
found_root_node = True
assert found_root_node

@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/124500")
def test_execution_trace_nested_tensor(self):
fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
fp.close()
Expand Down
1 change: 1 addition & 0 deletions torch/_dynamo/trace_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@
"torch.nn.Parameter": TorchInGraphFunctionVariable,
"torch._nested_tensor_from_mask": SkipFunctionVariable,
"torch._nested_from_padded": SkipFunctionVariable,
"torch.nested.nested_tensor_from_jagged": UserFunctionVariable,
# symbol operators implemented in Python
"torch.sym_not": TorchInGraphFunctionVariable,
"torch.sym_float": TorchInGraphFunctionVariable,
Expand Down

0 comments on commit cf5ca58

Please sign in to comment.