Prevent pattern matches across mutation ops in Inductor pre-grad FX passes #101124

Closed
williamwen42 opened this issue May 10, 2023 · 1 comment
Labels
bug, oncall: pt2, triaged


williamwen42 commented May 10, 2023

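The Inductor pre-grad FX passes (https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_passes/pre_grad.py) are unsafe with respect to mutation: a pattern can match nodes that sit on opposite sides of an in-place op, even though an input read before the mutation no longer holds the same value after it. We need to discard matches that cross a mutation op.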
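One possible way to discard such matches is to split the graph into mutation-free regions and reject any match whose nodes span more than one region. The sketch below works on a hypothetical toy graph rather than on `torch.fx`. `Node`, `is_mutation_op`, `assign_mutation_regions` and `match_crosses_mutation` are illustrative names, not PyTorch APIs. It also assumes two things: in-place ops can be recognized by a trailing underscore (as with `copy_`), and a match is given as the set of nodes the pattern would replace, not the pattern's inputs.

```python
from dataclasses import dataclass
from typing import List, Sequence, Set


# Hypothetical toy IR: an op name plus the indices of the nodes it reads.
@dataclass
class Node:
    op: str
    args: Sequence[int] = ()
    # Filled in by assign_mutation_regions(); nodes between two mutation ops share an id.
    region: int = -1


def is_mutation_op(node: Node) -> bool:
    # Assumption: in-place ops follow PyTorch's trailing-underscore naming (e.g. copy_).
    return node.op.endswith("_")


def assign_mutation_regions(nodes: List[Node]) -> None:
    """Label every node (in topological order) with the mutation-free region it sits in."""
    region = 0
    for node in nodes:
        if is_mutation_op(node):
            region += 1  # each mutation op opens a new region
        node.region = region


def match_crosses_mutation(nodes: List[Node], matched: Set[int]) -> bool:
    """True if the matched nodes do not all lie in the same region."""
    return len({nodes[i].region for i in matched}) > 1
```

For a graph shaped like the one traced from a function that calls `torch.sin(x)`, then `x.copy_(y)`, then `torch.add(x, a)`, the matched `sin` and `add` nodes land in regions 0 and 1, so the match would be dropped. The check is deliberately conservative: it also rejects matches that straddle a mutation of an unrelated tensor, trading some missed optimizations for simplicity.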

Example repro:

    def test_match_with_mutation(self):
        from torch._inductor.pattern_matcher import (
            CallFunction,
            KeywordArg,
            PatternMatcherPass,
            register_graph_pattern,
        )

        counter = 0
        test_pass = PatternMatcherPass()

        @register_graph_pattern(
            CallFunction(
                torch.add, KeywordArg("x"), CallFunction(torch.sin, KeywordArg("x"))
            ),
            pass_dict=test_pass,
        )
        def _test(match, x):
            nonlocal counter
            counter += 1

        def fn(x, y):
            a = torch.sin(x)  # reads x before the mutation
            x.copy_(y)  # mutates x in place
            b = torch.add(x, a)  # reads x after the mutation
            return b

        args = [
            torch.randn(5, 5, device="cuda"),
            torch.randn(5, 5, device="cuda"),
        ]

        with unittest.mock.patch(
            "torch._inductor.fx_passes.pre_grad.pattern_matcher_passes", [test_pass]
        ):
            expected = fn(*args)
            actual = torch.compile(fn)(*args)
            # the pattern's two reads of x straddle x.copy_(y), so it should not match
            self.assertEqual(counter, 0)
            torch.testing.assert_close(actual, expected)

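reveals that the pre-grad FX passes currently match across mutation ops: the add(x, sin(x)) pattern matches even though x.copy_(y) sits between the sin and the add, so the counter == 0 assertion fails.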
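To see why such a match is unsafe, the plain-Python sketch below mirrors fn, with lists standing in for tensors (an assumption made only for illustration). `fn_eager` and `fn_rewritten` are hypothetical names. `fn_rewritten` models a hypothetical fused replacement for add(x, sin(x)) that reads x once, at the position of the add, which is after the mutation.

```python
import math
from typing import List


def fn_eager(x: List[float], y: List[float]) -> List[float]:
    # Mirrors the repro: sin reads x *before* it is overwritten with y.
    a = [math.sin(v) for v in x]
    x[:] = y  # in-place copy, standing in for x.copy_(y)
    return [xv + av for xv, av in zip(x, a)]


def fn_rewritten(x: List[float], y: List[float]) -> List[float]:
    # Hypothetical fused replacement for add(x, sin(x)), emitted where the add was:
    # both uses of x now see the mutated value, so sin is applied to y instead of old x.
    x[:] = y
    return [v + math.sin(v) for v in x]
```

With x = [0.0] and y = [1.0], the eager version returns [1.0] (since sin(0) = 0), while the rewritten one returns [1 + sin(1)], which is why the match has to be rejected.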

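To run it, add it to a test class in a file that already imports torch and unittest.mock (e.g. test/inductor/test_pattern_matcher.py) and run it like any other unittest. A CUDA device is required, since the inputs are created with device="cuda".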

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @anijain2305


ezyang commented May 11, 2023

oh this is literally the thing we were talking about today @bdhirsh

williamwen42 added a commit that referenced this issue May 18, 2023
…ion ops in inductor pre-grad FX passes"


Per #101124

cc soumith voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
williamwen42 added a commit that referenced this issue May 18, 2023
…tor pre-grad FX passes"


Per #101124

cc soumith voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
Chillee added the triaged label May 22, 2023