
[WIP] Upstream push 0627 #80355

Closed — jjsjann123 wanted to merge 13 commits into master from upstream_push_0627

Conversation

@jjsjann123 (Collaborator) commented Jun 27, 2022

Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:

  • TransformPropagator refactor: switched to Dijkstra instead of exhaustive enumeration of all possible paths to reduce compilation time on transform propagation (see the sketch after this list);
  • Indexing refactor: remove reference tensor creation in all tensor indexing logic (#1690)
  • (more) generic grouped grid reduction kernel;
  • Minor parser/fuser patches:
    1. zero-dim tensor reduction support
    2. no-op binary removal within fused graph
    3. expand supported in fusion
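
For the TransformPropagator bullet above, here is a minimal sketch of the Dijkstra-style idea, using hypothetical types and names rather than nvfuser's actual TransformPropagator/MaxInfoSpanningTree classes: every tensor in the DAG is settled exactly once, along the path that preserves the most transform information, instead of enumerating every possible propagation path.

```
// Hand-wavy sketch, not nvfuser code: Dijkstra-style propagation over a tensor
// DAG. Each node is settled once, along the most-informative path, rather than
// exploring every path exhaustively.
#include <cstdio>
#include <queue>
#include <unordered_map>
#include <utility>
#include <vector>

struct Node {
  std::vector<int> neighbors;  // producers and consumers of this tensor
};

// Hypothetical scoring: how much transform information survives one more hop.
// The only requirement is that it never increases along a path.
int infoAfterStep(int /*from*/, int /*to*/, int infoSoFar) { return infoSoFar - 1; }

void propagateFrom(int start, const std::unordered_map<int, Node>& dag) {
  std::unordered_map<int, int> settled;           // tensor id -> best info score
  std::priority_queue<std::pair<int, int>> todo;  // max-heap of (info, tensor id)
  todo.push({100, start});
  while (!todo.empty()) {
    auto [info, id] = todo.top();
    todo.pop();
    if (settled.count(id)) continue;  // already reached with >= info
    settled[id] = info;
    std::printf("propagate transforms to tensor %d (info %d)\n", id, info);
    for (int next : dag.at(id).neighbors) {
      if (!settled.count(next)) todo.push({infoAfterStep(id, next, info), next});
    }
  }
}

int main() {
  // Tiny diamond-shaped DAG: 0 -> {1, 2} -> 3.
  std::unordered_map<int, Node> dag = {
      {0, {{1, 2}}}, {1, {{0, 3}}}, {2, {{0, 3}}}, {3, {{1, 2}}}};
  propagateFrom(0, dag);
  return 0;
}
```

As in shortest-path Dijkstra, the scheme works only because a hop can never increase the information score, so the first time a node is popped from the priority queue it has already been reached along its best path.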

Commits squashed to work around (WAR) the GitHub API.
Commits that are actually in this PR from the devel branch:

a054b3efcf5af58ea518de283f55aaf9fe06ff5f Refactor TransormPropagator to allow specifying a position and propagating to part of the DAG (#1775)
d67e1cda9b802036841a371318014a818a849b0a Indexing refactor stage 1: remove reference tensor creation in all tensor indexing logic (#1690)
1b6529956a1ace220898ad09dde0bf85e49827f7 Issue 1770 (#1774)
35b04276b648c9b55cdb6a67f3889f54e745c3d2 Avoid compilation errors like below: (#1773)
452c77326a340d2a4130b7802f4f319aec60e72a Ignore reductions of zero-dim tensors per PyTorch conventions (#1771)
31d6c56d88afba09ac53b2d5dd3493d625f8cd57 TransformPropagator refactor (#1769)
570c5a84b91a3cf67207331be9650d26a2d37e3d Merge pull request #1767 from csarofeen/upstream_merge_0621
9d6c3d84be86da643df6fd51695543938111f20d merging upstream 61305cd638b6fcd73a0b66b4cde7014fecb9e8ce
0ed815f76b08f285bda855dd500692ff10a8abce New TransformPropagator algorithm (#1763)
6c195200c0a92fb0f38c833431a8940ed07569b9 no-op binary removal (#1764)
ec7fa4187c177186527409dfc5c7b1754d30bc92 Proper propagation of IterType (#1762)
b263562dbc3c865007ad7d7d42a58a20be8d7922 Fix dimensionality check (#1759)
2d6343f6cc1e47b63ef20a50d1446f6480736478 More generic grouped grid reduction kernel (#1740)
64e2b56df2c8b9fd22a362d9cc05974a8607ef3d [nvfuser] prevent spamming warning message (#77777) (#1758)
0c431624ff15b6458b9f9b674a3852373fc426b1 [nvFuser] Improving bitwise ops support (#77158) (#1757)
b93a14777fde3b9b39684b9cf1715651a806b281 Parser expand (#1754)

RUN_TORCHBENCH: nvfuser

@facebook-github-bot (Contributor) commented Jun 27, 2022


❌ 1 New Failure

As of commit bb73c20 (more details on the Dr. CI page):

  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build TorchBench CI (pytorch-linux-py3.7-cu102) / run-torchbench (1/1)

Step: "Run TorchBench" (full log | diagnosis details | 🔁 rerun)

2022-07-12T05:54:14.8977292Z ##[error]Process completed with exit code 128.
b93a14777fde3b9b39684b9cf1715651a806b281 Parser expand (#1754)

RUN_TORCHBENCH: nvfuser
2022-07-12T05:54:14.8813995Z PR_BASE_SHA: d321be6
2022-07-12T05:54:14.8814335Z PR_HEAD_SHA: bb73c20
2022-07-12T05:54:14.8814632Z TORCHBENCH_BRANCH: main
2022-07-12T05:54:14.8814997Z ##[endgroup]
2022-07-12T05:54:14.8850545Z ~/pytorch ~/nvme/pytorch-org-runner/_work/pytorch/pytorch
2022-07-12T05:54:14.8957340Z fatal: Not a valid commit name bb73c20
2022-07-12T05:54:14.8977292Z ##[error]Process completed with exit code 128.
2022-07-12T05:54:14.9036372Z Post job cleanup.
2022-07-12T05:54:15.0015597Z [command]/usr/bin/git version
2022-07-12T05:54:15.0058492Z git version 2.23.3
2022-07-12T05:54:15.0093059Z [command]/usr/bin/git config --local --name-only --get-regexp core.sshCommand
2022-07-12T05:54:15.0152924Z [command]/usr/bin/git submodule foreach --recursive git config --local --name-only --get-regexp 'core.sshCommand' && git config --local --unset-all 'core.sshCommand' || :
2022-07-12T05:54:15.0646526Z Entering 'submodules/FAMBench'
2022-07-12T05:54:15.0728952Z Entering 'submodules/FAMBench/FBGEMM'
2022-07-12T05:54:15.0810986Z Entering 'submodules/FAMBench/FBGEMM/third_party/asmjit'
2022-07-12T05:54:15.0895128Z Entering 'submodules/FAMBench/FBGEMM/third_party/cpuinfo'
2022-07-12T05:54:15.0978918Z Entering 'submodules/FAMBench/FBGEMM/third_party/googletest'



@facebook-github-bot added the oncall: jit label on Jun 27, 2022
@jjsjann123 force-pushed the upstream_push_0627 branch 2 times, most recently from f48d7b3 to 5caf4a2 on June 27, 2022 18:44
@jjsjann123 (Collaborator, Author):

hmmm, the win error looks nasty. https://pipelines.actions.githubusercontent.com/serviceHosts/7d146c05-69c3-4c20-a0e7-818111670117/_apis/pipelines/1/runs/2090862/signedlogcontent/40?urlExpires=2022-06-27T17%3A32%3A49.9990256Z&urlSigningMethod=HMACV1&urlSignature=ZEbsHmbmDmbJ%2B8x8xw6QAV1HmLCDqH3R38WG5ENxSWU%3D

```
2022-06-27T17:12:48.3910454Z C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools\VC\Tools\MSVC\14.28.29333\include\xmemory(701): error C2280: 'std::pair<const torch::jit::fuser::cuda::VectorOfUniqueEntries<torch::jit::fuser::cuda::IterDomain *,std::hash<torch::jit::fuser::cuda::IterDomain *>> *const ,torch::jit::fuser::cuda::ComputeAtMap::DoubleBufferIndicesPtr>::pair(const std::pair<const torch::jit::fuser::cuda::VectorOfUniqueEntries<torch::jit::fuser::cuda::IterDomain *,std::hash<torch::jit::fuser::cuda::IterDomain *>> *const ,torch::jit::fuser::cuda::ComputeAtMap::DoubleBufferIndicesPtr> &)': attempting to reference a deleted function
```
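For readers hitting the same C2280: the deleted function is the copy constructor of the map's std::pair<const Key, Value> entry, which happens when the mapped type is move-only, for example a std::unique_ptr (which DoubleBufferIndicesPtr appears to be). A minimal sketch of the pattern, with hypothetical stand-in types rather than the actual nvfuser code:

```
// Sketch of the C2280 pattern: a map whose mapped type is move-only cannot be
// copied, because copying the map copies each std::pair<const Key, Value>, and
// that pair's copy constructor is deleted when Value has no copy constructor.
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

struct IterDomain {};                         // stand-in for the real class
using Indices = std::vector<int>;
using IndicesPtr = std::unique_ptr<Indices>;  // move-only mapped type

int main() {
  std::unordered_map<const IterDomain*, IndicesPtr> m;
  m[nullptr] = std::make_unique<Indices>();

  // std::unordered_map<const IterDomain*, IndicesPtr> copy = m;  // C2280: deleted pair copy ctor
  auto moved = std::move(m);  // OK: moving never touches the deleted copy constructor
  (void)moved;
  return 0;
}
```

Typical ways out are to move instead of copy, to define or delete the enclosing class's copy operations explicitly, or to switch to a shared_ptr; which of these the eventual fix used is not visible from this thread.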

@jjsjann123 added the ciflow/slow and ciflow/trunk labels on Jun 30, 2022
@pytorch-bot (bot) commented Jun 30, 2022

We have recently simplified the CIFlow labels and ciflow/slow is no longer in use.
You can use any of the following:

  • ciflow/trunk (.github/workflows/trunk.yml): all jobs we run per-commit on master
  • ciflow/periodic (.github/workflows/periodic.yml): all jobs we run periodically on master
  • ciflow/android (.github/workflows/run_android_tests.yml): android build and test
  • ciflow/nightly (.github/workflows/nightly.yml): all jobs we run nightly
  • ciflow/binaries: all binary build and upload jobs
  • ciflow/binaries_conda: binary build and upload job for conda
  • ciflow/binaries_libtorch: binary build and upload job for libtorch
  • ciflow/binaries_wheel: binary build and upload job for wheel

@jjsjann123 (Collaborator, Author):

hmmmm. expand is patched here: csarofeen#1790.
I'm waiting on CI to finish before cherry-picking that and reverting the expand disablement.

@jjsjann123 marked this pull request as ready for review on June 30, 2022 20:59
@jjsjann123 (Collaborator, Author):

@davidberard98 is there something wrong with the torchbench job? The log seems to be complaining about the commit name:
fatal: Not a valid commit name 1fa9cac858f5cdcdd12b8c460a62a221a0aa752f

@davidberard98 (Contributor):

Unfortunately I think you don't have permissions to run torchbench on PRs; I can see if I can get it to run.

@davidberard98 (Contributor):

@jjsjann123 can you comment on the expand enablement here? Do you think there's any risk, given that expand is a view op and we had issues with view ops last time?
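
For readers unfamiliar with why expand raises that question: expand is a view; it aliases the original storage and gives the broadcast dimension a stride of 0 instead of copying data, which is the property that has historically made view ops tricky for the fuser. A tiny standalone libtorch illustration (not part of this PR):

```
// Small libtorch demo: expand produces a zero-copy view whose broadcast
// dimension has stride 0, i.e. it aliases the original buffer.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto t = torch::randn({1, 4});
  auto e = t.expand({3, 4});                              // view, no allocation or copy
  std::cout << "sizes:   " << e.sizes() << "\n";          // [3, 4]
  std::cout << "strides: " << e.strides() << "\n";        // broadcast dim has stride 0
  std::cout << "aliases: " << e.is_alias_of(t) << "\n";   // 1 (true)
  return 0;
}
```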

@facebook-github-bot (Contributor):

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@davidberard98 (Contributor):

@jjsjann123 we probably need another rebase in order to skip over the macos-11-py3-x86-64 test failure

@jjsjann123 (Collaborator, Author):

Apparently viable/strict is too old... I'll rebase again.

@davidberard98 (Contributor):

@jjsjann123 looks like viable/strict is only 5 hrs old on https://hud.pytorch.org/metrics ?

@jjsjann123 (Collaborator, Author):

I saw this warning:
::error::Your PR is based on a version of master that is too old for our CI to work. Please rebase your PR on latest master and resubmit.

But looking at the full log, I also noticed the errors below, so it looks like a real thing. Let me fix it.

```
/var/lib/jenkins/workspace/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h:40:10:error: ‘struct torch::jit::fuser::cuda::MaxInfoSpanningTree::Selector’ has virtual functions and accessible non-virtual destructor [-Werror=non-virtual-dtor]
   40 |   struct Selector {
      |          ^~~~~~~~
/var/lib/jenkins/workspace/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h:46:10:error: ‘struct torch::jit::fuser::cuda::MaxInfoSpanningTree::Propagator’ has virtual functions and accessible non-virtual destructor [-Werror=non-virtual-dtor]
   46 |   struct Propagator {
      |          ^~~~~~~~~~
In file included from /var/lib/jenkins/workspace/torch/csrc/jit/codegen/cuda/compute_at.cpp:9:
/var/lib/jenkins/workspace/torch/csrc/jit/codegen/cuda/transform_replay.h:159:25:error: base class ‘struct torch::jit::fuser::cuda::MaxInfoSpanningTree::Propagator’ has accessible non-virtual destructor [-Werror=non-virtual-dtor]
  159 | class TORCH_CUDA_CU_API TransformPropagator
      |                         ^~~~~~~~~~~~~~~~~~~
/var/lib/jenkins/workspace/torch/csrc/jit/codegen/cuda/transform_replay.h:159:25:error: ‘class torch::jit::fuser::cuda::TransformPropagator’ has virtual functions and accessible non-virtual destructor [-Werror=non-virtual-dtor]
```
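
These are the standard -Wnon-virtual-dtor diagnostics for a struct that has virtual functions but an accessible non-virtual destructor. A minimal sketch of the two usual fixes (simplified, hypothetical members, not the actual nvfuser headers):

```
// A struct with virtual functions but an accessible non-virtual destructor
// trips -Wnon-virtual-dtor, because deleting a derived object through a base
// pointer would be undefined behavior.
struct Selector {
  virtual bool allow(int from, int to) = 0;  // placeholder virtual interface
  virtual ~Selector() = default;             // fix 1: make the destructor virtual
};

struct Propagator {
  virtual void propagate(int tv) = 0;        // placeholder virtual interface

 protected:
  ~Propagator() = default;                   // fix 2: non-virtual but protected, so
                                             // nobody can delete via a Propagator*
};
```

Either form silences the warning; the thread doesn't show which one the subsequent fix used.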

@facebook-github-bot (Contributor):

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@jjsjann123 (Collaborator, Author):

This one looks promising. @davidberard98 should we start importing it?

@facebook-github-bot (Contributor):

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@jjsjann123 (Collaborator, Author):

bumping threads~

@davidberard98 (Contributor):

@jjsjann123 5-6% regression on autogen-40 and autogen-43, see https://github.com/pytorch/pytorch/runs/7223997510?check_suite_focus=true

Not sure if this is significant or just random, but they do both look like they are pretty similar benchmarks (native_batch_norm + pointwise)

https://github.com/pytorch/benchmark/blob/main/userbenchmark/nvfuser/ir.py

Also, depending on your judgement of whether or not this regression is okay, could you rebase one more time?

@jjsjann123 (Collaborator, Author):

> @jjsjann123 5-6% regression on autogen-40 and autogen-43, see https://github.com/pytorch/pytorch/runs/7223997510?check_suite_focus=true
>
> Not sure if this is significant or just random, but they do both look like they are pretty similar benchmarks (native_batch_norm + pointwise)
>
> https://github.com/pytorch/benchmark/blob/main/userbenchmark/nvfuser/ir.py
>
> Also depending on your judgement of whether or not this regression is okay could you rebase one more time?

Neat!

I mean, regressions are bad, but it's nice that we are catching them with easy repros! Let me see if I can repro them locally and open issues internally to track them, so our next perf tuning pass can hopefully fix those.
^^^ I think that'll be our approach for minor regressions moving forward.

@jjsjann123 (Collaborator, Author):

@pytorchbot rebase

@pytorchmergebot (Collaborator):

Successfully rebased upstream_push_0627 onto refs/remotes/origin/master, please pull locally before adding more changes (for example, via git checkout upstream_push_0627 && git pull --rebase)

@jjsjann123 (Collaborator, Author):

XLA failure seems unrelated. Is this one good to be merged? @davidberard98

re: regression on the microbenchmark, I haven't forgotten about that, but it will be easier for me to repro after the merge (comparing perf across a single commit is easier) 😉 Will do it after the merge.

@facebook-github-bot (Contributor):

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@davidberard98 (Contributor) left a comment:

LGTM! But I'm going to merge it via the internal workflow, so don't use pytorchbot to merge.

@facebook-github-bot (Contributor):

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot (Collaborator):

@pytorchbot successfully started a merge job. Check the current status here

@github-actions (bot):

Hey @jjsjann123.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Jul 13, 2022
Summary:
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:

- TransformPropagator refactor: switched to Dijkstra instead of exhaustive enumeration on all possible paths to reduce compilation time on transform propagation;
- Indexing refactor: remove reference tensor creation in all tensor indexing logic (#1690)
- (more) generic grouped grid reduction kernel;
- Minor parser/fuser patches:
  1. zero-dim tensor reduction support
  2. no-op binary removal within fused graph
  3. expand supported in fusion

Commits squashed to work around (WAR) the GitHub API.
Commits that are actually in this PR from the devel branch:

```
a054b3e Refactor TransormPropagator to allow specifying a position and propagating to part of the DAG (#1775)
d67e1cd Indexing refactor stage 1: remove reference tensor creation in all tensor indexing logic (#1690)
1b65299 Issue 1770 (#1774)
35b0427 Avoid compilation errors like below: (#1773)
452c773 Ignore reductions of zero-dim tensors per PyTorch conventions (#1771)
31d6c56 TransformPropagator refactor (#1769)
570c5a8 Merge pull request #1767 from csarofeen/upstream_merge_0621
9d6c3d8 merging upstream 61305cd
0ed815f New TransformPropagator algorithm (#1763)
6c19520 no-op binary removal (#1764)
ec7fa41 Proper propagation of IterType (#1762)
b263562 Fix dimensionality check (#1759)
2d6343f More generic grouped grid reduction kernel (#1740)
64e2b56 [nvfuser] prevent spamming warning message (#77777) (#1758)
0c43162 [nvFuser] Improving bitwise ops support (#77158) (#1757)
b93a147 Parser expand (#1754)
```

RUN_TORCHBENCH: nvfuser

Pull Request resolved: #80355

Reviewed By: qihqi

Differential Revision: D37573400

Pulled By: davidberard98

fbshipit-source-id: 52ab68d89ec01ef61f69f5abeb18c9d3a312aa64
@jjsjann123 deleted the upstream_push_0627 branch on July 14, 2022 06:32
@jjsjann123 (Collaborator, Author):

> @jjsjann123 5-6% regression on autogen-40 and autogen-43, see https://github.com/pytorch/pytorch/runs/7223997510?check_suite_focus=true
>
> Not sure if this is significant or just random, but they do both look like they are pretty similar benchmarks (native_batch_norm + pointwise)
>
> https://github.com/pytorch/benchmark/blob/main/userbenchmark/nvfuser/ir.py
>
> Also depending on your judgement of whether or not this regression is okay could you rebase one more time?

FYI, running the microbenchmark locally seems to indicate that benchmark results are a little bit flaky.
I don't see real issues or changes to generated code.

Looks like our kernel time varies a little bit, which is not surprising for a small kernel running < 20 us:

```
kernel1 run in 0.016384 ms, achieved: 612.813 GB/s      total time: 0.016384     total buffer:10040320
kernel1 run in 0.017408 ms, achieved: 576.765 GB/s      total time: 0.017408     total buffer:10040320
kernel1 run in 0.016384 ms, achieved: 612.813 GB/s      total time: 0.016384     total buffer:10040320
kernel1 run in 0.017408 ms, achieved: 576.765 GB/s      total time: 0.017408     total buffer:10040320
kernel1 run in 0.016384 ms, achieved: 612.813 GB/s      total time: 0.016384     total buffer:10040320
kernel1 run in 0.017408 ms, achieved: 576.765 GB/s      total time: 0.017408     total buffer:10040320
kernel1 run in 0.018432 ms, achieved: 544.722 GB/s      total time: 0.018432     total buffer:10040320
kernel1 run in 0.022528 ms, achieved: 445.682 GB/s      total time: 0.022528     total buffer:10040320
kernel1 run in 0.017408 ms, achieved: 576.765 GB/s      total time: 0.017408     total buffer:10040320
kernel1 run in 0.017408 ms, achieved: 576.765 GB/s      total time: 0.017408     total buffer:10040320
kernel1 run in 0.017408 ms, achieved: 576.765 GB/s      total time: 0.017408     total buffer:10040320
```
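
For reference, the GB/s figures above are simply the buffer size divided by the kernel time (decimal gigabytes), so a one-sample jump from 16.4 us to 17.4 us is already a ~6% swing, about the size of the reported regression. A quick check of that arithmetic, with the values taken from the log:

```
// Recompute the achieved-bandwidth figures from the log: GB/s = bytes / seconds / 1e9.
#include <cstdio>

int main() {
  const double total_buffer_bytes = 10040320.0;  // "total buffer" from the log
  const double kernel_ms[] = {0.016384, 0.017408, 0.018432, 0.022528};
  for (double ms : kernel_ms) {
    const double gbps = total_buffer_bytes / (ms * 1e-3) / 1e9;
    std::printf("%.6f ms -> %.3f GB/s\n", ms, gbps);  // 612.8, 576.8, 544.7, 445.7
  }
  return 0;
}
```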

Also, autogen-40 & autogen-43 are just pointwise kernels (even though there's a batch_norm in them, it's running in inference mode).

@davidberard98 (Contributor):

@jjsjann123 thanks for investigating! Figured it was worth flagging since both microbenchmarks looked similar and both were performing badly... but if your tests don't repro any issue or any kernel changes then it's probably fine.

jjsjann123 added a commit to csarofeen/pytorch that referenced this pull request Jul 18, 2022
upstream fixes cherry-picked from pytorch#80355
jjsjann123 added a commit to jjsjann123/nvfuser that referenced this pull request Oct 29, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:

- TransformPropagator refactor: switched to Dijkstra instead of exhaustive enumeration on all possible paths to reduce compilation time on transform propagation;
- Indexing refactor: remove reference tensor creation in all tensor indexing logic (#1690)
- (more) generic grouped grid reduction kernel;
- Minor parser/fuser patches:
  1. zero-dim tensor reduction support
  2. no-op binary removal within fused graph
  3. expand supported in fusion

Commits squashed to work around (WAR) the GitHub API.
Commits that are actually in this PR from the devel branch:

```
a054b3efcf5af58ea518de283f55aaf9fe06ff5f Refactor TransormPropagator to allow specifying a position and propagating to part of the DAG (#1775)
d67e1cda9b802036841a371318014a818a849b0a Indexing refactor stage 1: remove reference tensor creation in all tensor indexing logic (#1690)
1b6529956a1ace220898ad09dde0bf85e49827f7 Issue 1770 (#1774)
35b04276b648c9b55cdb6a67f3889f54e745c3d2 Avoid compilation errors like below: (#1773)
452c77326a340d2a4130b7802f4f319aec60e72a Ignore reductions of zero-dim tensors per PyTorch conventions (#1771)
31d6c56d88afba09ac53b2d5dd3493d625f8cd57 TransformPropagator refactor (#1769)
570c5a84b91a3cf67207331be9650d26a2d37e3d Merge pull request #1767 from csarofeen/upstream_merge_0621
9d6c3d84be86da643df6fd51695543938111f20d merging upstream 61305cd638b6fcd73a0b66b4cde7014fecb9e8ce
0ed815f76b08f285bda855dd500692ff10a8abce New TransformPropagator algorithm (#1763)
6c195200c0a92fb0f38c833431a8940ed07569b9 no-op binary removal (#1764)
ec7fa4187c177186527409dfc5c7b1754d30bc92 Proper propagation of IterType (#1762)
b263562dbc3c865007ad7d7d42a58a20be8d7922 Fix dimensionality check (#1759)
2d6343f6cc1e47b63ef20a50d1446f6480736478 More generic grouped grid reduction kernel (#1740)
64e2b56df2c8b9fd22a362d9cc05974a8607ef3d [nvfuser] prevent spamming warning message (#77777) (#1758)
0c431624ff15b6458b9f9b674a3852373fc426b1 [nvFuser] Improving bitwise ops support (#77158) (#1757)
b93a14777fde3b9b39684b9cf1715651a806b281 Parser expand (#1754)
```

RUN_TORCHBENCH: nvfuser
Pull Request resolved: pytorch/pytorch#80355
Approved by: https://github.com/davidberard98
jjsjann123 added a commit to jjsjann123/nvfuser that referenced this pull request Nov 10, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:

- TransformPropagator refactor: switched to Dijkstra instead of exhaustive enumeration on all possible paths to reduce compilation time on transform propagation;
- Indexing refactor: remove reference tensor creation in all tensor indexing logic (#1690)
- (more) generic grouped grid reduction kernel;
- Minor parser/fuser patches:
  1. zero-dim tensor reduction support
  2. no-op binary removal within fused graph
  3. expand supported in fusion

Commits squashed to work around (WAR) the GitHub API.
Commits that are actually in this PR from the devel branch:

```
2848a5f Refactor TransormPropagator to allow specifying a position and propagating to part of the DAG (#1775)
97d3b84 Indexing refactor stage 1: remove reference tensor creation in all tensor indexing logic (#1690)
c8b4f42 Issue 1770 (#1774)
35b04276b648c9b55cdb6a67f3889f54e745c3d2 Avoid compilation errors like below: (#1773)
0773c33 Ignore reductions of zero-dim tensors per PyTorch conventions (#1771)
074d078 TransformPropagator refactor (#1769)
3e9637b Merge pull request #1767 from csarofeen/upstream_merge_0621
690900a merging upstream 61305cd638b6fcd73a0b66b4cde7014fecb9e8ce
86fc20a New TransformPropagator algorithm (#1763)
073e521 no-op binary removal (#1764)
dfaca9a Proper propagation of IterType (#1762)
4bc0e6b Fix dimensionality check (#1759)
7ec1263 More generic grouped grid reduction kernel (#1740)
bf9c6c6 [nvfuser] prevent spamming warning message (#77777) (#1758)
6f631a8 [nvFuser] Improving bitwise ops support (#77158) (#1757)
921efc8 Parser expand (#1754)
```

RUN_TORCHBENCH: nvfuser
Pull Request resolved: pytorch/pytorch#80355
Approved by: https://github.com/davidberard98
Labels: ciflow/trunk, cla signed, Merged, oncall: jit, open source