
Add BFloat16 dtype support for oneDNN Graph JIT fuser #85591

Closed

Conversation

sanchitintel
Collaborator

@sanchitintel sanchitintel commented Sep 24, 2022

BFloat16 dtype support for faster inference with TorchScript using oneDNN Graph

Intel Xeon Cooper Lake platforms & beyond support the AVX512_BF16 ISA, which provides essentially native BFloat16 support.
oneDNN Graph delivers high inference performance with BFloat16 on such machines.

oneDNN Graph can still be used with BFloat16 on older machines that lack the avx512_bf16 ISA but support the avx512bw, avx512vl & avx512dq ISAs; however, BF16 performance on such machines will be significantly worse (possibly even worse than Float32), since they lack native BF16 support.
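
To check whether a machine has native BF16 support before benchmarking, here is a minimal, Linux-only sketch (an illustration, not part of this PR) that simply reads the CPU flags from /proc/cpuinfo:

# Minimal, Linux-only sketch (illustrative, not part of this PR):
# check whether the CPU advertises the avx512_bf16 flag.
def has_avx512_bf16(cpuinfo_path="/proc/cpuinfo"):
    try:
        with open(cpuinfo_path) as f:
            return "avx512_bf16" in f.read()
    except OSError:
        return False  # /proc/cpuinfo is unavailable on non-Linux platforms

print("Native BF16 (AVX512_BF16) support:", has_avx512_bf16())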

Currently, AMP support for eager mode & JIT mode is divergent in PyTorch.
So, to use oneDNN Graph with BFloat16, eager-mode AMP should be leveraged while AMP for JIT mode is turned off via torch._C._jit_set_autocast_mode(False) in Python code, so as to avoid conflicts.

Please use the following environment variable to view JIT logs:
PYTORCH_JIT_LOG_LEVEL=">>graph_helper:>>graph_fuser:>>kernel:>>interface"
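
As a hedged sketch, the variable can also be set from Python before importing torch (exporting it in the shell before launching the script works just as well):

import os
# Set the JIT logging level before importing torch so the fuser picks it up.
os.environ["PYTORCH_JIT_LOG_LEVEL"] = ">>graph_helper:>>graph_fuser:>>kernel:>>interface"

import torch  # imported after setting the environment variable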

Changes being made in this PR

  1. This PR does NOT change the oneDNN commit or any ideep files. Although the ideep commit is being updated, only files pertaining to oneDNN Graph are being updated; oneDNN Graph is being upgraded to version 0.5.2 (alpha patch release 2).
    To put things into perspective, ideep is a git submodule of PyTorch. oneDNN Graph is a git submodule of ideep (ideep/mkl-dnn), and oneDNN is a git submodule of oneDNN Graph (ideep/mkl-dnn/third_party/oneDNN).
  2. Unit-tests are being updated. We now use the existing dtypes decorator (a usage sketch follows this list).
  3. Suggestions made by @eellison in the FP32 PR are being incorporated/addressed:

Action-item | Status
checkInputCompatibility follow-up | Fixed
the mayConvertScalarInputToTensor logic we can consider | Added type-promotion code
Use OpInfo tests | Using the dtypes decorator for now. Will use OpInfo in a subsequent PR, if that'd be possible. Should we create a list of ops from opDB that are supported by oneDNN Graph, and add it to common_methods_invocations.py?
fix up fixConvOptionalBias | The current approach seems correct
inferDevice torch_check call | Perhaps not necessary now, as only CPU is supported for now. We'd add it by the oneDNN Graph beta release, though, so that by then users might be able to use other fusers with oneDNN Graph (NNC/TensorExpr are already compatible with the oneDNN Graph fuser). We can still add it now if you'd insist.
not checking shapes of input mkldnn tensor to LLGA guard | Those checks should not be present because oneDNN Graph may use a blocked or channels-last layout, so the strides would differ. They're only skipped when an LLGA subgraph's output is input to another LLGA subgraph, which lets LLGA choose an optimal layout between them.
fix test failures with respect to unsupported inputs | We'll address them with the upcoming oneDNN Graph beta release
  4. More PyTorch ops are being mapped to oneDNN Graph.
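
Below is a hedged sketch of how the existing dtypes decorator is typically used in PyTorch device-type tests (referenced in item 2 above); the class and test names here are illustrative, not the exact tests changed in this PR:

import torch
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyCPU,
)
from torch.testing._internal.common_utils import TestCase, run_tests


class TestFusionPatternSketch(TestCase):
    @onlyCPU
    @dtypes(torch.float32, torch.bfloat16)
    def test_conv_relu(self, device, dtype):
        # The dtypes decorator parametrizes this test over FP32 & BF16.
        x = torch.rand(1, 3, 8, 8, device=device, dtype=dtype)
        conv = torch.nn.Conv2d(3, 8, kernel_size=3).to(device=device, dtype=dtype)
        y = torch.relu(conv(x))
        self.assertEqual(y.dtype, dtype)


instantiate_device_type_tests(TestFusionPatternSketch, globals(), only_for="cpu")

if __name__ == "__main__":
    run_tests()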

Example of using oneDNN Graph with BFloat16

# Assuming we have an eval-mode model named 'model'
import torch

example_input = torch.rand(1, 3, 224, 224)

# enable oneDNN Graph fusion
torch.jit.enable_onednn_fusion(True)
# disable AMP for JIT mode to avoid conflicts with eager-mode AMP
torch._C._jit_set_autocast_mode(False)
with torch.no_grad(), torch.cpu.amp.autocast():
    model = torch.jit.trace(model, (example_input,))
    model = torch.jit.freeze(model)
    # 2 warm-up runs (2 for tracing/scripting with an example input, 3 without one)
    model(example_input)
    model(example_input)

    # speedup would be observed in subsequent runs
    model(example_input)
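
For completeness, a hedged sketch of the scripting path (assuming model here is the original eval-mode nn.Module, not the traced module from above); as noted in the comment above, scripting without an example input needs 3 warm-up runs:

with torch.no_grad(), torch.cpu.amp.autocast():
    scripted = torch.jit.script(model)
    scripted = torch.jit.freeze(scripted)
    # 3 warm-up runs when scripting without an example input
    scripted(example_input)
    scripted(example_input)
    scripted(example_input)

    # subsequent runs use the fused oneDNN Graph BF16 kernels
    scripted(example_input)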

TorchBench based Benchmarks

URL: https://github.com/sanchitintel/benchmark/tree/onednn_graph_benchmark (instructions present at URL).
Batch-size(s): TorchBench-default for each model
Baseline : PyTorch JIT OFI FP32
Machine: Intel(R) Xeon(R) Platinum 8371HC (Cooper Lake)
Sockets used: 1
Number of cores on one socket: 26
Intel OpenMP & tcmalloc were preloaded

Benchmark results with a single thread:

Name    Latency of PyTorch JIT OFI FP32 (s)    Latency of oneDNN Graph BF16 (s)    % change
test_eval[alexnet-cpu-jit] 1.063851 0.509820 -52.1%
test_eval[mnasnet1_0-cpu-jit] 0.218435 0.107100 -51.0%
test_eval[mobilenet_v2-cpu-jit] 0.114467 0.058359 -49.0%
test_eval[mobilenet_v3_large-cpu-jit] 0.233873 0.117614 -49.7%
test_eval[resnet18-cpu-jit] 0.160584 0.075854 -52.8%
test_eval[resnet50-cpu-jit] 1.652846 0.713373 -56.8%
test_eval[resnext50_32x4d-cpu-jit] 0.471174 0.209431 -55.6%
test_eval[shufflenet_v2_x1_0-cpu-jit] 0.310306 0.167090 -46.2%
test_eval[squeezenet1_1-cpu-jit] 0.161247 0.045684 -71.7%
test_eval[timm_efficientnet-cpu-jit] 1.643772 0.800099 -51.3%
test_eval[timm_regnet-cpu-jit] 5.732272 2.333417 -59.3%
test_eval[timm_resnest-cpu-jit] 1.366464 0.715252 -47.7%
test_eval[timm_vision_transformer-cpu-jit] 0.508521 0.271598 -46.6%
test_eval[timm_vovnet-cpu-jit] 2.756692 1.125033 -59.2%
test_eval[vgg16-cpu-jit] 0.711533 0.312344 -56.1%

Benchmark results with 26 threads:

Name    Latency of PyTorch JIT OFI FP32 (s)    Latency of oneDNN Graph BF16 (s)    % change
test_eval[alexnet-cpu-jit] 0.062871 0.034198 -45.6%
test_eval[mnasnet1_0-cpu-jit] 0.022490 0.008172 -63.7%
test_eval[mobilenet_v2-cpu-jit] 0.012730 0.005866 -53.9%
test_eval[mobilenet_v3_large-cpu-jit] 0.025948 0.010346 -60.1%
test_eval[resnet18-cpu-jit] 0.011194 0.005726 -48.9%
test_eval[resnet50-cpu-jit] 0.124662 0.045599 -63.4%
test_eval[resnext50_32x4d-cpu-jit] 0.034737 0.015214 -56.2%
test_eval[shufflenet_v2_x1_0-cpu-jit] 0.028820 0.012517 -56.6%
test_eval[squeezenet1_1-cpu-jit] 0.012557 0.003876 -69.1%
test_eval[timm_efficientnet-cpu-jit] 0.203177 0.051879 -74.5%
test_eval[timm_regnet-cpu-jit] 0.452050 0.151113 -66.6%
test_eval[timm_resnest-cpu-jit] 0.117072 0.052848 -54.9%
test_eval[timm_vision_transformer-cpu-jit] 0.046048 0.023275 -49.5%
test_eval[timm_vovnet-cpu-jit] 0.213187 0.077482 -63.7%
test_eval[vgg16-cpu-jit] 0.044726 0.021998 -50.8%

@pytorch-bot

pytorch-bot bot commented Sep 24, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/85591

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures, 4 Pending

As of commit a1c08be:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the 'release notes: jit' label Sep 24, 2022
@facebook-github-bot facebook-github-bot added the 'cla signed' and 'oncall: jit' labels Sep 24, 2022
@sanchitintel

This comment was marked as off-topic.

@pytorch-bot pytorch-bot bot added the 'ciflow/trunk' label Sep 24, 2022
@sanchitintel sanchitintel changed the title [WIP] Add BF16 support for LLGA [WIP] Add BF16 support for oneDNN Graph Sep 24, 2022
@sanchitintel sanchitintel force-pushed the onednn_graph_bf16 branch 3 times, most recently from 305a249 to 142cb17, September 27, 2022 05:49
@sanchitintel sanchitintel marked this pull request as ready for review September 27, 2022 06:00
@sanchitintel sanchitintel changed the title [WIP] Add BF16 support for oneDNN Graph Add BFloat16 dtype support for oneDNN Graph JIT fuser Sep 27, 2022
Collaborator Author

@sanchitintel sanchitintel left a comment


@XiaobingSuper, added you as a reviewer because this PR uses a distinct ideep commit (same oneDNN commit in ideep/mkl-dnn/third_party/oneDNN; it doesn't change any ideep file, only oneDNN Graph files in ideep/mkl-dnn). Thanks!

@sanchitintel

This comment was marked as resolved.

@frank-wei
Contributor

cc @malfet, was the option of importing the diff internally removed?

@malfet
Contributor

malfet commented Sep 28, 2022

cc @malfet, was the option of importing the diff internally removed?

No, but the plugin has been finicky throughout the day

@facebook-github-bot
Contributor

@malfet has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

torch/csrc/jit/codegen/onednn/operator.h (review thread, outdated, resolved)
test/test_jit_llga_fuser.py (review thread, outdated, resolved)
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@pytorchmergebot
Collaborator

Merge failed

Reason: 1 additional jobs have failed, first few of them are: Meta Internal-Only Changes Check


@sanchitintel
Collaborator Author

Thanks again, @frank-wei! It looks like it'd have to be imported again before merging.

@facebook-github-bot
Contributor

@frank-wei has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@sanchitintel
Collaborator Author

Hi @malfet, please help me with a query about the pytorchmergebot rule of not merging a PR if its last push was more than 3 days ago. My last push to this PR was on Oct 10. Does the rule mean that the PR must be merged on or before Oct 13, or that it would have to be rebased on Oct 14 (and subsequently re-imported as well, since pytorchmergebot didn't merge this PR last time because of the Meta Internal-Only Changes Check, the rebased commit being different from the commit for which Meta-internal CI checks were run, so I'm assuming it'd do so again, if it comes to that)?

Thanks!

@frank-wei
Contributor

I feel like once we have imported this PR, it always has to be synced internally whenever we want to merge it from outside. This could be improved, since once the PR is on the diff train, the diff is supposed to be re-created internally.
For the stale-state check, I do not have a clear answer. Better to check with @malfet.

@sanchitintel
Collaborator Author

Thanks for your inputs, @frank-wei! :)

If the answer to my query about the pytorchmergebot stale check is yes, then IMHO the 3-day period seems a bit short, as it might require Meta engineers to frequently re-import PRs. On the other hand, I guess it does help keep trunk CI greener, so maybe it was deemed worth the trade-off. :)

@sanchitintel sanchitintel added the 'intel' and 'intel priority' labels Oct 12, 2022
@frank-wei
Contributor

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@github-actions

Hey @sanchitintel.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@sanchitintel
Collaborator Author

Thanks again, @frank-wei! :)

@sanchitintel sanchitintel added the 'topic: performance' label and removed the 'topics' label Oct 13, 2022
@malfet
Contributor

malfet commented Oct 14, 2022

Note to internal oncall: this PR updates ideep, it needs to be kept in sync internally.
cc: @weiwangmeta

malfet added commits that referenced this pull request Oct 14, 2022
Fix DOS newlines introduced by #85591
@sanchitintel
Collaborator Author

sanchitintel commented Oct 14, 2022

this PR updates ideep, it needs to be kept in sync internally.

Hi @weiwangmeta, we'd be submitting a PR to update ideep again today with a new oneDNN version. I'll add you as a reviewer in that PR. Thanks!

pytorchmergebot pushed a commit that referenced this pull request Oct 14, 2022
pytorchmergebot pushed a commit that referenced this pull request Oct 15, 2022
Fix DOS newlines in `onednn/decompose_silu.[cpp|h]` introduced by #85591 as well as one in `.github/PULL_REQUEST_TEMPLATE.md`

Pull Request resolved: #86973
Approved by: https://github.com/huydhn, https://github.com/izaitsevfb
Labels
ciflow/trunk, cla signed, intel, intel priority, Merged, oncall: jit, open source, release notes: jit, topic: performance