Conversation

peterbell10 (Collaborator) commented Feb 23, 2024

Stack from ghstack (oldest at bottom):

Fixes #120242

The example from the issue now results in the following graph:

```python
def forward(self, arg0_1, arg1_1):
    sin = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
    copy_ = torch.ops.aten.copy_.default(arg1_1, sin);  arg1_1 = sin = None
    return (copy_,)
```

and the corresponding Inductor kernel eliminates the intermediate buffer completely:

```python
def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (5, ), (1, ))
    assert_size_stride(arg1_1, (5, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        # Source Nodes: [sin], Original ATen: [aten.sin]
        stream0 = get_raw_stream(0)
        triton_poi_fused_sin_0.run(arg0_1, arg1_1, 5, grid=grid(5), stream=stream0)
        del arg0_1
    return (arg1_1, )
```
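
For reference, a program of roughly this shape (a hedged sketch only; the function, tensor sizes, and the explicit `copy_` call are assumptions, not the exact snippet from #120242) compiles into a graph like the one above:

```python
import torch

# Hedged sketch of the kind of program described above (the exact code from
# #120242 is not reproduced here): sin(x) is written into a preallocated
# buffer, so the compiled graph sees a mutated input that is also returned.
@torch.compile
def fn(x, out):
    out.copy_(torch.sin(x))
    return out

x = torch.randn(5, device="cuda")
out = torch.empty(5, device="cuda")
fn(x, out)
```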

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang


pytorch-bot bot commented Feb 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/120514

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 39c69f3 with merge base 953c6c3:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

peterbell10 added a commit that referenced this pull request Feb 23, 2024
ghstack-source-id: ba0ba28
Pull Request resolved: #120514
pytorch-bot added the release notes: fx (release notes category) label Feb 24, 2024
peterbell10 added a commit that referenced this pull request Feb 24, 2024
ghstack-source-id: 37afb61
Pull Request resolved: #120514
peterbell10 added a commit that referenced this pull request Feb 26, 2024
ghstack-source-id: 2ab5b66
Pull Request resolved: #120514
peterbell10 marked this pull request as ready for review February 27, 2024 00:32
```python
# necessary.
if get_node_storage(node) in output_storages and (
    get_node_storage(src) in input_storages
    or get_node_storage(src) in output_storages
```
peterbell10 (Collaborator, Author) commented:

This previous logic banned all no-op eliminations where the src and node storages are inputs or outputs, but this is only problematic if the storages weren't expected to alias. In the failing test I saw, we had `node = aten.slice(argn, ...)` where, because of this change, `argn` was both an input and an output of the graph. The slice op itself was not returned, so eliminating the view is not an issue.

This also generalizes further: we might have a view of a view where the second view is returned, but it is still safe to eliminate the first view op because that tensor is not returned directly.
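
To make the second case concrete, here is a hedged sketch (hypothetical code, not a test from this PR):

```python
import torch

# Hypothetical illustration: `first` is a no-op view that shares storage with
# the returned tensor `second`, but only `second` is returned directly, so
# eliminating `first` does not change what the caller observes.
def f(x):
    first = x.view(5)        # no-op view, never returned directly
    second = first.view(5)   # returned, so its identity must be preserved
    return second

f(torch.randn(5))
```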

lezcano (Collaborator) commented Mar 8, 2024:

Just to check, `node_storage != src_storage` is equivalent to op not being a view, right? If so, could you either leave a comment or use a variable `node_is_view = node_storage == src_storage`?

```python
if copy_node == user:
# Ignore uses after the copy_ epilogue node, where the input
# has already been mutated anyway
if copy_node_loc is not None and copy_node_loc <= user_loc:
```
peterbell10 (Collaborator, Author) commented:

This change is required because the output node of the graph counts as a user, which was preventing reinplacing on mutated inputs.
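
For illustration, a small self-contained sketch of why the output node shows up as a user (an assumed example, not code from this PR):

```python
import torch
from torch.fx import symbolic_trace

# In an FX graph the `output` node is itself a node, so it appears among the
# `users` of whatever value the graph returns.
def f(x):
    y = x.sin()
    return y

gm = symbolic_trace(f)
sin_node = next(n for n in gm.graph.nodes if n.target == "sin")
print(sin_node.users)  # the `output` node is listed as a user of `sin`
```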

```python
copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
if copy_node is not None:
    graph.erase_node(copy_node)
    replace_dict[copy_node] = copy_node.args[0]
```
peterbell10 (Collaborator, Author) commented:

This change is needed because, when we have an in-place mutation on a tensor `x`, make_fx replaces all future references to `x` with the in-place node. This means the `aten.copy_` op now has a user that needs to be updated.
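
A hedged sketch of that make_fx behaviour (an assumed repro, not code from this PR or its tests):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Assumed repro of the behaviour described above: after the in-place add_,
# later uses of `x` are traced against the add_ node rather than against the
# original placeholder.
def f(x):
    x.add_(1)      # in-place mutation, recorded as an add_ node
    return x * 2   # traced as a use of the add_ node's output

gm = make_fx(f)(torch.randn(3))
print(gm.graph)    # the mul node takes the add_ node as its input
```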

Chillee requested a review from oulgen February 27, 2024 00:55
ezyang removed their request for review February 27, 2024 04:47
lezcano (Collaborator) commented Mar 4, 2024

ping @bdhirsh

ezyang (Contributor) commented Mar 5, 2024

@bdhirsh is on vacation for two weeks.

lezcano (Collaborator) commented Mar 6, 2024

Let's just wait for @bdhirsh to be back then, as this issue is not blocking anything.

ezyang (Contributor) commented Mar 6, 2024

I'm 👍 on the AOTAutograd changes, but the Inductor passes also need reviewing. I can bug @Chillee / @oulgen to look at it in person, or in a pinch I can review them too.

oulgen (Contributor) left a comment:

reinplace changes look good to me

peterbell10 (Collaborator, Author) commented:

@pytorchbot merge -r

pytorch-bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Mar 8, 2024
pytorchmergebot (Collaborator) commented:

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

pytorchmergebot (Collaborator) commented:

Rebase failed due to Command `git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict gh/peterbell10/697/orig` returned non-zero exit code 1

```
Rebasing (1/1)
Auto-merging test/functorch/test_aotdispatch.py
Auto-merging torch/_functorch/_aot_autograd/traced_function_transforms.py
Auto-merging torch/_inductor/fx_passes/post_grad.py
Auto-merging torch/_inductor/fx_passes/reinplace.py
CONFLICT (content): Merge conflict in torch/_inductor/fx_passes/reinplace.py
error: could not apply 8a18644291f... [AOTDispatch] Return mutated inputs directly when keeping mutations
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
Could not apply 8a18644291f... [AOTDispatch] Return mutated inputs directly when keeping mutations
```

Raised by https://github.com/pytorch/pytorch/actions/runs/8203733237

peterbell10 added a commit that referenced this pull request Mar 8, 2024
ghstack-source-id: eee04a3
Pull Request resolved: #120514
peterbell10 (Collaborator, Author) commented:

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


peterbell10 added the topic: not user facing (topic category) label Mar 8, 2024
github-actions bot deleted the gh/peterbell10/697/head branch April 8, 2024 01:53