Contributor

@shunting314 shunting314 commented Nov 7, 2024

Stack from ghstack (oldest at bottom):

I recently added a new pattern in #139136 to remove pointless view/permute pairs. In that PR, I already updated the matched pattern/node count in test_linear_binary to account for the new pattern. But it looks like one more pattern is matched when the cpp wrapper is enabled.

7 patterns without cpp-wrapper:

========== pattern matched <code object pointless_view at 0x7f6d25c67aa0, file "/home/shunting/ws/pytorch/torch/_inductor/fx_passes/joint_graph.py", line 568> =======
========== pattern matched <code object pointless_view_pair at 0x7f6d25c67b50, file "/home/shunting/ws/pytorch/torch/_inductor/fx_passes/joint_graph.py", line 581> =======
========== pattern matched <code object pointless_view at 0x7f6d25c67aa0, file "/home/shunting/ws/pytorch/torch/_inductor/fx_passes/joint_graph.py", line 568> =======
========== pattern matched <code object pointless_view at 0x7f6d25c67aa0, file "/home/shunting/ws/pytorch/torch/_inductor/fx_passes/joint_graph.py", line 568> =======
========== pattern matched <code object linear at 0x7f6d176e5dc0, file "/home/shunting/ws/pytorch/torch/_inductor/fx_passes/mkldnn_fusion.py", line 1121> =======
========== pattern matched <code object reshape_linear_reshape_pattern at 0x7f6d176e5210, file "/home/shunting/ws/pytorch/torch/_inductor/fx_passes/mkldnn_fusion.py", line 732> =======
========== pattern matched <code object fn at 0x7f6d176d3ec0, file "/home/shunting/ws/pytorch/torch/_inductor/fx_passes/mkldnn_fusion.py", line 476> =======

8 patterns with cpp wrapper:
========== pattern matched <code object pointless_view at 0x7f8e78bf07c0, file "/home/shunting/ws/pytorch/torch/_inductor/fx_passes/joint_graph.py", line 568> =======
========== pattern matched <code object pointless_view_pair at 0x7f8e78bf0870, file "/home/shunting/ws/pytorch/torch/_inductor/fx_passes/joint_graph.py", line 581> =======
========== pattern matched <code object pointless_view at 0x7f8e78bf07c0, file "/home/shunting/ws/pytorch/torch/_inductor/fx_passes/joint_graph.py", line 568> =======
========== pattern matched <code object pointless_view at 0x7f8e78bf07c0, file "/home/shunting/ws/pytorch/torch/_inductor/fx_passes/joint_graph.py", line 568> =======
========== pattern matched <code object pointless_view at 0x7f8e78bf07c0, file "/home/shunting/ws/pytorch/torch/_inductor/fx_passes/joint_graph.py", line 568> =======
========== pattern matched <code object linear at 0x7f8e59c04190, file "/home/shunting/ws/pytorch/torch/_inductor/fx_passes/mkldnn_fusion.py", line 1121> =======
========== pattern matched <code object reshape_linear_reshape_pattern at 0x7f8e59dfb520, file "/home/shunting/ws/pytorch/torch/_inductor/fx_passes/mkldnn_fusion.py", line 732> =======
========== pattern matched <code object fn at 0x7f8e59dfa290, file "/home/shunting/ws/pytorch/torch/_inductor/fx_passes/mkldnn_fusion.py", line 476> =======
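
To see where the extra match comes from, the two logs can be tallied by pattern name. The following is a rough sketch, not part of the PR: the `tally` and `make_log` helpers are hypothetical, and the addresses and file paths in the rebuilt log lines are placeholders rather than the real values shown above. Only the sequence of pattern names is taken from the logs.

```python
import re
from collections import Counter

# Matches the pattern name in lines shaped like the debug output above.
NAME_RE = re.compile(r"pattern matched <code object (\w+) at 0x")


def tally(log_text: str) -> Counter:
    """Count how many times each pattern name appears in a log."""
    return Counter(NAME_RE.findall(log_text))


def make_log(names):
    """Rebuild log lines of the same shape (address and path are placeholders)."""
    return "\n".join(
        f'========== pattern matched <code object {name} at 0x0, '
        f'file "placeholder.py", line 1> ======='
        for name in names
    )


# Pattern names in the order they appear in the two logs above.
without_wrapper = make_log([
    "pointless_view", "pointless_view_pair", "pointless_view",
    "pointless_view", "linear", "reshape_linear_reshape_pattern", "fn",
])
with_wrapper = make_log([
    "pointless_view", "pointless_view_pair", "pointless_view",
    "pointless_view", "pointless_view", "linear",
    "reshape_linear_reshape_pattern", "fn",
])

# Counter subtraction keeps only the names whose count went up.
extra = tally(with_wrapper) - tally(without_wrapper)
print(extra)  # Counter({'pointless_view': 1})
```

The only difference between the two runs is one additional `pointless_view` match (4 instead of 3); the three mkldnn_fusion.py matches are identical in both.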

I fixed this test by adding 1 to the expected count when the cpp wrapper is enabled. But fundamentally, can we stop asserting on the total number of patterns matched in the test? I think that makes the test very fragile: people adding new patterns may keep breaking these unrelated tests. One possible improvement is to keep a counter for each specific pattern and, in the tests, check the match count for only the RELEVANT patterns instead of the total number of patterns matched. That should reduce false positives from broken tests. cc possible test creator @jgong5
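
A small sketch of the difference between the two assertion styles, using the per-pattern counts from the logs above. The helper names and the choice of which patterns count as "relevant" to test_linear_binary are assumptions made for illustration; this is not the actual test code.

```python
from collections import Counter

# Per-pattern match counts read off the two logs above.
without_wrapper = Counter(
    pointless_view=3,
    pointless_view_pair=1,
    linear=1,
    reshape_linear_reshape_pattern=1,
    fn=1,
)
# With the cpp wrapper, one extra pointless_view match shows up.
with_wrapper = without_wrapper + Counter(pointless_view=1)

# ASSUMPTION: the mkldnn_fusion.py patterns are the ones this test cares about.
RELEVANT = {"linear": 1, "reshape_linear_reshape_pattern": 1, "fn": 1}


def total_matches(matches: Counter) -> int:
    """Fragile style: a single number that every pattern contributes to."""
    return sum(matches.values())


def relevant_counts_ok(matches: Counter, expected: dict) -> bool:
    """Robust style: only the named patterns are checked."""
    return all(matches[name] == count for name, count in expected.items())


print(total_matches(without_wrapper), total_matches(with_wrapper))  # 7 8
print(relevant_counts_ok(without_wrapper, RELEVANT))  # True
print(relevant_counts_ok(with_wrapper, RELEVANT))     # True
```

The total changes from 7 to 8 when an unrelated pattern fires once more, while the per-pattern check gives the same answer in both configurations.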

Fixes #139812 (we need to have this to run this disabled test on your PR)

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

pytorch-bot bot commented Nov 7, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139942

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit b4a6dbf with merge base 8f077b8 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

shunting314 added a commit that referenced this pull request Nov 7, 2024
Contributor

@huydhn huydhn left a comment

Thank you for the fix!

@leslie-fang-intel
Collaborator

leslie-fang-intel commented Nov 7, 2024

One possible way to improve is, we have a counter for each specific pattern, in the tests, instead of check the total number of patterns matched, just check the match count for the RELEVANT patterns. That should reduce false-positive for broken tests.

I think we have already added specific counters for the oneDNN quantization pattern matchers, such as:

counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2
So we probably also need to add specific counters for the other oneDNN pattern matchers.
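
As an illustration of how such a per-matcher counter could be asserted on, here is a minimal sketch. It assumes the counter table behaves like a `defaultdict` of `Counter` objects keyed first by subsystem and then by counter name; the `counters` object below is a local stand-in, not the real one, and `record_match` and `some_unrelated_matcher_count` are hypothetical names.

```python
from collections import Counter, defaultdict

# Local stand-in for a nested counter table (ASSUMPTION: subsystem -> name -> count).
counters = defaultdict(Counter)


def record_match(counter_name: str) -> None:
    """Hypothetical hook a pattern matcher would call on each successful match."""
    counters["inductor"][counter_name] += 1


# Simulate two matches of the pattern under test and one unrelated match.
record_match("qlinear_weight_prepack_matcher_count")
record_match("qlinear_weight_prepack_matcher_count")
record_match("some_unrelated_matcher_count")  # hypothetical key

# The test only looks at the counter it cares about.
assert counters["inductor"]["qlinear_weight_prepack_matcher_count"] == 2

# A Counter returns 0 for keys that were never incremented,
# so checking a pattern that should not fire is also straightforward.
assert counters["inductor"]["never_matched_count"] == 0
```

With this shape, a new unrelated pattern only adds a new key and leaves existing assertions untouched.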

Collaborator

@jgong5 jgong5 left a comment

So, we probably also need to add some other specific count for the other oneDNN pattern matchers.

@leslie-fang-intel Can we add a BE task to revise existing inductor cpp tests?

@leslie-fang-intel
Collaborator

Tracking this task in #139970. @Valentine233, could you help with this?

@shunting314
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 7, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@huydhn
Contributor

huydhn commented Nov 8, 2024

@pytorchbot revert -m 'Sorry for revert this, but I think we miss running the test and it is now failing in trunk' -c nosignal

inductor/test_cpu_cpp_wrapper.py::TestCppWrapper::test_linear_binary_cpp_wrapper GH job link HUD commit link

It's one of those slow tests, so it needs the ciflow/slow label to run.

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@shunting314 your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Nov 8, 2024
…139942)"

This reverts commit 0618c7f.

Reverted #139942 on behalf of https://github.com/huydhn due to Sorry for revert this, but I think we miss running the test and it is now failing in trunk ([comment](#139942 (comment)))
@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Nov 8, 2024
@shunting314
Contributor Author

Interesting, I think they pass on my dev gpu. Let me double check

@shunting314
Contributor Author

Hmm, this is just more evidence of how fragile these tests are.

Reverting my PR makes the previously failing tests pass now. I think something may have changed recently in cpp-wrapper that changed the number of patterns being matched. @huydhn I'll close this PR. I can look into it more if you still see this failure on trunk. In the longer term, I think the Intel folks have agreed to improve these tests.

@shunting314 shunting314 closed this Nov 8, 2024
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Pull Request resolved: pytorch#139942
Approved by: https://github.com/huydhn, https://github.com/jgong5
@github-actions github-actions bot deleted the gh/shunting314/185/head branch December 9, 2024 02:14
Labels

ci-no-td Do not run TD on this PR ciflow/slow ciflow/trunk Trigger trunk jobs on your pull request Merged module: inductor Reverted topic: not user facing topic category

5 participants