[RFC] CPU float16 performance optimization on eager mode. #97068
More clarification on avx512-fp16: we are not going to leverage avx512-fp16 to optimize non-conv/gemm ATen ops (primarily pointwise and reduction ops). Instead, we plan to follow the same "fused type cast" approach used to optimize those ops for the bf16 data type: fp16 data are converted to/from fp32, and the computation itself happens in fp32, with the type cast fused into the kernel (a minimal sketch of this pattern is shown below). The type cast will rely on the f16c instruction set, as noted by @mingfeima in the RFC description. This is due to the following considerations:
amx-fp16 will be leveraged to optimize conv and gemm ops via the oneDNN library.
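For illustration only, here is a minimal sketch (not PyTorch code) of the fused type-cast pattern for a pointwise fp16 op using the F16C intrinsics; the function name and the op (y = a*x + b) are made up for the example:

```cpp
// Minimal sketch of the "fused type cast" pattern for a pointwise fp16 op.
// Not PyTorch code; the function and the op (y = a*x + b) are illustrative.
// Requires F16C + AVX; compile with e.g. -mavx -mf16c.
#include <immintrin.h>
#include <cstdint>

void scale_add_fp16(const uint16_t* x, uint16_t* y, int64_t n, float a, float b) {
  const __m256 va = _mm256_set1_ps(a);
  const __m256 vb = _mm256_set1_ps(b);
  int64_t i = 0;
  for (; i + 8 <= n; i += 8) {
    // Load 8 fp16 values and up-convert to fp32 (F16C).
    __m128i xh = _mm_loadu_si128(reinterpret_cast<const __m128i*>(x + i));
    __m256 xf = _mm256_cvtph_ps(xh);
    // Compute in fp32.
    __m256 yf = _mm256_add_ps(_mm256_mul_ps(xf, va), vb);
    // Down-convert back to fp16 with round-to-nearest-even (RNE).
    __m128i yh = _mm256_cvtps_ph(yf, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(y + i), yh);
  }
  for (; i < n; ++i) {
    // Scalar tail: same convert-compute-convert pattern.
    float xf = _cvtsh_ss(x[i]);
    y[i] = _cvtss_sh(xf * a + b, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
  }
}
```

The point of the pattern is that fp16 is only a storage format here; every arithmetic step runs in fp32, so accuracy matches the bf16-style kernels while still halving memory traffic.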
…v on CPU" The PR is part of #97068, which is to add fp16 support for mkldnn conv and mkldnn deconv to leverage avx512-fp16 or amx-fp16 via the oneDNN library. cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 [ghstack-poisoned]
The PR is part of #97068, which is to add fp16 support for mkldnn conv and mkldnn deconv to leverage avx512-fp16 or amx-fp16 via the oneDNN library. cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 [ghstack-poisoned]
…v on CPU" The PR is part of #97068, which is to add fp16 support for mkldnn conv and mkldnn deconv to leverage avx_ne_convert, avx512-fp16, and amx-fp16 via the oneDNN library. cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 [ghstack-poisoned]
The PR is part of #97068, which is to add fp16 support for mkldnn conv and mkldnn deconv to leverage avx_ne_convert, avx512-fp16, and amx-fp16 via the oneDNN library. cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 [ghstack-poisoned]
…v on CPU" The PR is part of #97068, which is to add fp16 support for mkldnn conv and mkldnn deconv to leverage avx_ne_convert, avx512-fp16, and amx-fp16 via the oneDNN library. cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 [ghstack-poisoned]
The PR is part of #97068, which is to add fp16 support for mkldnn conv and mkldnn deconv to leverage avx_ne_convert, avx512-fp16, and amx-fp16 via the oneDNN library. cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 [ghstack-poisoned]
Related: PR #99496 (approved by https://github.com/jgong5 and https://github.com/cpuhrsch) is part of #97068 and adds fp16 support for mkldnn conv and mkldnn deconv, leveraging avx_ne_convert, avx512-fp16, and amx-fp16 via the oneDNN library.
🚀 The feature, motivation and pitch
The RFC is to improve float16 performance as well as op coverage on the PyTorch CPU backend in eager mode.
Float16 and BFloat16 are both commonly used reduced-precision floating point types for performance improvement in neural network inference/training. On the CPU side, previous optimization efforts have focused more on BFloat16, leaving float16 support at a relatively primitive stage.
On the 4th generation Intel® Xeon® Scalable processor (Sapphire Rapids), a new fp16 instruction set architecture for Intel® AVX-512 has been added, namely avx512-fp16. The instruction set supports a wide range of general-purpose numeric operations for fp16. On the next generation of Xeon, Intel® Advanced Matrix Extensions (AMX) will add fp16 support, namely amx-fp16.
This proposal would help the scenario in which a model is pre-trained on GPU with float16/float32 mixed precision and users intend to deploy it on the CPU side without modifying the model weights; many HuggingFace models, for instance, fall into this scenario.
This project will target:
Technically, the optimization will be carried out as follows:
Compute-intensive ops (e.g. Convolution, Gemm, and RNN): optimize via the oneDNN library, which can leverage avx512-fp16 and amx-fp16.
Generic ATen kernels: add explicit vectorization support for the Half data type. Add native convert intrinsics: _mm256_cvtph_ps/_mm256_cvtps_ph (rounding mode: RNE).
Test Plan:
vec_test_all_types_AVX2 and vec_test_all_types_AVX512 for float16.
torch/testing/_internal/common_methods_invocations.py.
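In the spirit of the test plan above, a rough, hypothetical accuracy check might compare an fp16 result against an fp32 reference with fp16-appropriate tolerances. This is not the actual test code, and it assumes the fp16 CPU path for the chosen op is implemented:

```cpp
// Rough illustration (not the actual test code) of an fp16-vs-fp32 accuracy
// check, using the public ATen C++ API.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor x = at::randn({1024});                            // fp32 reference input
  at::Tensor ref = at::gelu(x);                                // reference computed in fp32
  // fp16 path (assumes the CPU fp16 kernel exists); upcast for comparison.
  at::Tensor out = at::gelu(x.to(at::kHalf)).to(at::kFloat);
  // fp16 has ~10 bits of mantissa, so use loose tolerances.
  std::cout << at::allclose(out, ref, /*rtol=*/1e-2, /*atol=*/1e-3) << "\n";
  return 0;
}
```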
Alternatives
No response
Additional context
Previous RFC on extending AMP fp16 on CPU:
Float16 support in torch inductor is being worked on in parallel (implemented in a similar way to BFloat16 support) and depends on the explicit vectorization utils of at::vec::Vectorized<Half>.
Pull Requests related to this feature request:
cc @jgong5 @XiaobingSuper @sanchitintel @ashokei @jingxu10