
Conversation

@zou3519 (Contributor) commented Aug 26, 2024

We say Node a is fusible into node b if node b is an auto_functionalized
node that may reinplace node a later on.
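A minimal sketch of this rule (not the PR's exact code): the caller supplies `mutable_arg_names`, the kwargs of the auto_functionalized call that may be reinplaced; how those names are discovered is elided here.

```py
# Hedged sketch only: node `a` is fusible into node `b` if `b` is an
# auto_functionalized node that may later reinplace `a`.
import torch

def is_fusible_into(a, b, mutable_arg_names) -> bool:
    if b.target is not torch.ops.higher_order.auto_functionalized:
        return False
    for name in mutable_arg_names:
        arg = b.kwargs.get(name)
        if a is arg:
            return True
        # kwargs may also be flat TensorLists (see the discussion below)
        if isinstance(arg, (list, tuple)) and any(a is x for x in arg):
            return True
    return False
```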

Fixes #134468

Test Plan:
- new test

[ghstack-poisoned]

pytorch-bot bot commented Aug 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134490

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f071437 with merge base 2553278:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@zou3519 added the keep-going (Don't stop on first failure, keep running tests until the end) and ci-no-td (Do not run TD on this PR) labels on Aug 26, 2024
…functionalized"

We say Node a is fusible into node b if node b is an auto_functionalized
node that may reinplace node a later on.

This PR also changes aten.empty to be recomputable w.r.t the Partitioner
(it is, like aten.zeros, cheap to recompute and fusible into other ops).

Fixes #134468

Test Plan:
- new test

[ghstack-poisoned]
Comment on lines 1624 to 1625
if node.target is torch.ops.higher_order.auto_functionalized:
    return False
Contributor Author

auto_functionalized nodes can have out= arguments. Those are not mutable (the auto_functionalized node is always functional), so we fix that here.
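For context, a minimal runnable illustration of the pattern being discussed (the op name `mylib::fill_out` is made up, not from the PR): an op declared as mutating `out` is rewritten during functionalization into a pure auto_functionalized call, so the `out=` kwarg on that graph node is not itself mutated.

```py
# Hedged illustration; `mylib::fill_out` is a made-up op, not the PR's code.
import torch

@torch.library.custom_op("mylib::fill_out", mutates_args={"out"})
def fill_out(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(x + 1)

@torch.compile(fullgraph=True)
def f(x):
    out = torch.empty_like(x)
    # In the traced graph this call shows up as a functional
    # torch.ops.higher_order.auto_functionalized(...) node with an out= kwarg.
    fill_out(x, out)
    return out

print(f(torch.randn(3)))  # equals x + 1
```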

Collaborator

This is unrelated?

Contributor Author

Fair enough, I'll split this out into its own PR

@zou3519 requested a review from Chillee on August 26, 2024 at 23:03
arg = b.kwargs[name]
if a is arg:
    return True
if isinstance(arg, list):
Collaborator

should this be a tree_map or something?

Contributor Author

We only support TensorList args as inputs to operators, so the list is fine

subtest(torch.empty_like, name="empty_like"),
],
)
def test_partitioner_recomputes_factory(self, factory_op):
Collaborator

I personally like test_perf tests, since they're "property-based" (i.e. they measure the number of bytes read and written). It's possible we don't support HOPs yet or something though.

Contributor Author (@zou3519, Aug 26, 2024)

I'll add some test_perf tests too (if possible)
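A rough sketch of what such a property-based perf test could look like; `count_bytes_accessed` is a hypothetical placeholder for the byte-counting machinery in inductor's test_perf suite (not a real torch API), and the expected value is illustrative.

```py
# Hedged sketch only: the helper below is a HYPOTHETICAL stand-in, and the
# assertion expresses the property being discussed (no extra memory traffic
# from the factory op), not an exact expected number from the real suite.
import torch
from torch.testing._internal.common_utils import TestCase, run_tests

def count_bytes_accessed(fn, *args) -> int:
    raise NotImplementedError("illustrative placeholder, not a torch API")

class PerfStyleTest(TestCase):
    def test_factory_adds_no_traffic(self):
        def f(x):
            out = torch.empty_like(x)
            out.copy_(x.sin())
            return out

        x = torch.randn(2**20, device="cuda")
        expected = 2 * x.numel() * x.element_size()  # read x once, write out once
        self.assertEqual(count_bytes_accessed(torch.compile(f), x), expected)

if __name__ == "__main__":
    run_tests()
```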

aten.full,
aten.as_strided,
aten.zeros,
aten.empty,
Collaborator

I think plausibly aten.empty should be treated even more specially than this, since aten.empty is always free to recompute, regardless of whether or not it's fusible into a downstream op.
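A hedged sketch of the special-casing suggested here (not the partitioner's actual code): treat aten.empty as always free to recompute, and the other factory ops as recomputable only when fusible into a consumer.

```py
# Hedged sketch only; the fusibility signal is passed in by the caller.
import torch

FUSIBLE_ONLY_FACTORY_OPS = {
    torch.ops.aten.full,
    torch.ops.aten.as_strided,
    torch.ops.aten.zeros,
}

def cheap_to_recompute(node, fusible_into_consumer: bool) -> bool:
    packet = getattr(node.target, "overloadpacket", node.target)
    if packet is torch.ops.aten.empty:
        # Uninitialized memory: free to recompute regardless of fusibility.
        return True
    if packet in FUSIBLE_ONLY_FACTORY_OPS:
        return fusible_into_consumer
    return False
```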

…functionalized"

We say Node a is fusible into node b if node b is an auto_functionalized
node that may reinplace node a later on.

This PR also changes aten.empty to be recomputable w.r.t the Partitioner
(it is, like aten.zeros, cheap to recompute and fusible into other ops).

Fixes #134468

Test Plan:
- new test

[ghstack-poisoned]
@zou3519 added the ciflow/trunk (Trigger trunk jobs on your pull request) label on Aug 27, 2024
@zou3519 (Contributor Author) commented Aug 27, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Aug 27, 2024
…134491)

Mutated arguments to Triton kernels are fusible into the Triton kernel.
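A hedged, self-contained illustration of the pattern this enables (assuming Triton and a CUDA device are available; this is not the PR's test): a buffer allocated with torch.empty_like feeds a mutated Triton-kernel argument, so it is the kind of node the partitioner can now consider fusible.

```py
# Hedged illustration of a Triton kernel whose out_ptr argument is mutated.
import torch
import triton
import triton.language as tl

@triton.jit
def add_one_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + 1, mask=mask)  # out_ptr is the mutated argument

def add_one(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)  # factory node feeding the mutated kernel arg
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    add_one_kernel[grid](x, out, n, BLOCK=1024)
    return out
```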

Test Plan:
- new test

Pull Request resolved: #134491
Approved by: https://github.com/Chillee
ghstack dependencies: #134364, #134466, #134490
pytorchmergebot pushed a commit that referenced this pull request Aug 28, 2024
ROCm doesn't trigger the layout optimization that makes the test case valid, so we're going to skip the checks.
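A hedged sketch of how such a skip is typically expressed in PyTorch's test suite (class and test names are illustrative, not the actual test), using the existing skipIfRocm helper:

```py
# Hedged sketch; names are illustrative, only the skip pattern is the point.
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm

class LayoutOptimizationTests(TestCase):
    @skipIfRocm
    def test_layout_optimization_assertions(self):
        # Layout-dependent assertions would go here; on ROCm the optimization
        # does not fire, so the whole test is skipped.
        pass

if __name__ == "__main__":
    run_tests()
```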

Should fix the following (I'll close them later)
- #134481
- #134519

Pull Request resolved: #134690
Approved by: https://github.com/FindHao
ghstack dependencies: #134466, #134490, #134491
pytorchmergebot pushed a commit that referenced this pull request Aug 28, 2024
Fixes #134119
From user feedback, it's difficult to understand what the tests do, so we clarify the docs further.
Pull Request resolved: #134692
Approved by: https://github.com/albanD
ghstack dependencies: #134466, #134490, #134491, #134690
pytorchmergebot pushed a commit that referenced this pull request Aug 28, 2024
Fixes #134278

Test Plan:
- tested locally
Pull Request resolved: #134688
Approved by: https://github.com/yushangdi
ghstack dependencies: #134466, #134490, #134491, #134690, #134692
pytorchmergebot pushed a commit that referenced this pull request Aug 29, 2024
aten.empty is almost always fusible into its consumer, so we never CSE
it. This fixes a bug that looks like the following:

```py
@torch.library.custom_op("_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"})
def sin_cos(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> None:
    out_sin.copy_(x.sin())
    out_cos.copy_(x.cos())

@torch.compile
def f(x):
    out0 = torch.empty_like(x)
    out1 = torch.empty_like(x)
    sin_cos(x, out0, out1)
    return x.clone(), out0, out1

x = torch.randn(3, requires_grad=True)
f(x)
```

- CSE would de-duplicate the empty nodes
- reinplacing would then add an additional clone (because it can't write to
  both tensors at the same time)
- the clone lowers into a new buffer plus a copy_ kernel
- the copy_ kernel is unnecessary because "empty" is special: all reinplacing
  needed was an additional buffer; it doesn't matter what the values are.

We could attempt to fix this on the reinplacing side but this seemed
better as a partitioner heuristic and the reinplacing fix is a bit more
tricky (we'd need to identify that the op never reads from the empty
node).
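A hedged sketch of the heuristic described above (op set and function name are illustrative, not the actual pass):

```py
# Hedged sketch: never CSE aten.empty, because de-duplicating two
# uninitialized buffers can force reinplacing to insert an extra clone + copy_.
import torch

NEVER_CSE = {
    torch.ops.aten.empty.memory_format,
    torch.ops.aten.empty_like.default,
}

def can_cse(node) -> bool:
    return node.target not in NEVER_CSE
```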

Test Plan:
- new test (the old number was 27, the new number is 21, so this PR
  helped).

Pull Request resolved: #134703
Approved by: https://github.com/yf225
ghstack dependencies: #134466, #134490, #134491
Labels: ci-no-td (Do not run TD on this PR), ciflow/inductor, ciflow/trunk (Trigger trunk jobs on your pull request), keep-going (Don't stop on first failure, keep running tests until the end), Merged, module: inductor, release notes: inductor