Commit

[NJT] Inline through torch.nested.nested_tensor_from_jagged instead of graph break (#124343)

Pull Request resolved: #124343
Approved by: https://github.com/jbschlosser
soulitzer authored and pytorchmergebot committed Apr 18, 2024
1 parent bbb6e36 commit ef93402
Showing 2 changed files with 9 additions and 0 deletions.
8 changes: 8 additions & 0 deletions test/dynamo/test_subclasses.py
@@ -1361,6 +1361,14 @@ def fn(x):
self._check_recompiles(fn, (nt,), (nt2,), False)
self._check_recompiles(fn, (nt,), (nt3,), True)

def test_inline_nested_tensor_from_jagged(self):
nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)

def fn(x):
return torch.nested.nested_tensor_from_jagged(x.values() * 2, x.offsets())

torch.compile(fn, fullgraph=True, backend="aot_eager")(nt)

def _get_views(self):
# Test all cases with both an NT base and a dense base
# Subclass -> Subclass
1 change: 1 addition & 0 deletions torch/_dynamo/trace_rules.py
@@ -173,6 +173,7 @@
"torch.nn.Parameter": TorchInGraphFunctionVariable,
"torch._nested_tensor_from_mask": SkipFunctionVariable,
"torch._nested_from_padded": SkipFunctionVariable,
"torch.nested.nested_tensor_from_jagged": UserFunctionVariable,
# symbol operators implemented in Python
"torch.sym_not": TorchInGraphFunctionVariable,
"torch.sym_float": TorchInGraphFunctionVariable,
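The one-line trace_rules change is the heart of this commit: mapping `torch.nested.nested_tensor_from_jagged` to `UserFunctionVariable` tells Dynamo to inline through the function instead of treating it as a skip (graph break), as `SkipFunctionVariable` does for the neighboring `_nested_tensor_from_mask` entry. A simplified, hypothetical sketch of how such a name-to-rule table dispatches tracing behavior (the real logic in `torch/_dynamo/trace_rules.py` is considerably more involved; the `tracing_strategy` helper below is invented for illustration):

```python
# Hypothetical, minimal model of a trace-rules lookup table.
# Real Dynamo maps qualified names to variable-tracker classes; here
# we use plain strings to show the dispatch idea only.
TRACE_RULES = {
    # Skipped functions cause a graph break when traced.
    "torch._nested_tensor_from_mask": "SkipFunctionVariable",
    "torch._nested_from_padded": "SkipFunctionVariable",
    # After this commit, this function is inlined like user code.
    "torch.nested.nested_tensor_from_jagged": "UserFunctionVariable",
}

def tracing_strategy(qualname):
    # Unknown names would fall through to Dynamo's default heuristics;
    # this sketch simply returns None for them.
    return TRACE_RULES.get(qualname)

print(tracing_strategy("torch.nested.nested_tensor_from_jagged"))
```

Under this model, the test added above passes with `fullgraph=True` because the inlined rule means tracing never hits a graph break inside the function.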

1 comment on commit ef93402

@pytorchmergebot
Collaborator


Reverted #124343 on behalf of https://github.com/DanilBaibak due to Broken trunk (comment)
