Skip to content

Commit

Permalink
Update base for Update on "[AOTI] support freezing for MKLDNN"
Browse files Browse the repository at this point in the history
## Description
Fixes #114450. This PR builds upon the work from imzhuhl done in #114451.

This PR requires #122472 to land firstly.

We leverage the serialization and deserialization API from oneDNN v3.4.1 to save the opaque MKLDNN tensor during the compilation and restore the opaque tensor when loading the compiled .so.
ideep version is updated so that we won't break any pipeline even if third_party/ideep is not updated at the same time.

### Test plan:
```sh
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_freezing_non_abi_compatible_cpu
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_conv_freezing_non_abi_compatible_cpu
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_deconv_freezing_non_abi_compatible_cpu
python -u test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCpu.test_linear_freezing_non_abi_compatible_cpu
```

### TODOs in follow-up PRs
1. We found that using `AOTI_TORCH_CHECK` will cause performance drop on several models (`DistillGPT2`, `MBartForConditionalGeneration`, `T5ForConditionalGeneration`, `T5Small`) compared with JIT Inductor which uses `TORCH_CHECK`. This may need further discussion how to address (`AOTI_TORCH_CHECK` is introduced in 
 #119220).
2. Freezing in non-ABI compatible mode will work with the support in this PR. While for ABI compatible mode, we need to firstly address this issue: `AssertionError: None, i.e. optional output is not supported`.
https://github.com/pytorch/pytorch/blob/6c4f43f82675b5fcfe8cf3e5983d0c0f326408aa/torch/_inductor/codegen/cpp_wrapper_cpu.py#L2023-L2024

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
  • Loading branch information
chunyuan-w committed May 25, 2024
2 parents 56c31ce + 5f15110 commit 7e9a39a
Show file tree
Hide file tree
Showing 213 changed files with 5,035 additions and 26,324 deletions.
1 change: 1 addition & 0 deletions .ci/pytorch/multigpu-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ time python test/run_test.py --verbose -i distributed/test_c10d_gloo
time python test/run_test.py --verbose -i distributed/test_c10d_nccl
time python test/run_test.py --verbose -i distributed/test_c10d_spawn_gloo
time python test/run_test.py --verbose -i distributed/test_c10d_spawn_nccl
time python test/run_test.py --verbose -i distributed/test_cuda_p2p
time python test/run_test.py --verbose -i distributed/test_store
time python test/run_test.py --verbose -i distributed/test_pg_wrapper
time python test/run_test.py --verbose -i distributed/rpc/cuda/test_tensorpipe_agent
Expand Down
8 changes: 5 additions & 3 deletions .github/workflows/assigntome-docathon.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ on:
jobs:
assign:
runs-on: ubuntu-latest
permissions:
issues: write
steps:
- name: Check for "/assigntome" in comment
uses: actions/github-script@v6
Expand All @@ -26,14 +28,14 @@ jobs:
repo: context.repo.repo,
issue_number: issueNumber
});
const hasLabel = issue.labels.some(label => label.name === 'docathon-h2-2023');
const hasLabel = issue.labels.some(label => label.name === 'docathon-h1-2024');
if (hasLabel) {
if (issue.assignee !== null) {
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: issueNumber,
body: "The issue is already assigned. Please pick an opened and unnasigned issue with the [docathon-h2-2023 label](https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3Adocathon-h2-2023)."
body: "The issue is already assigned. Please pick an opened and unnasigned issue with the [docathon-h1-2024 label](https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3Adocathon-h1-2024)."
});
} else {
await github.rest.issues.addAssignees({
Expand All @@ -44,7 +46,7 @@ jobs:
});
}
} else {
const commmentMessage = "This issue does not have the correct label. Please pick an opened and unnasigned issue with the [docathon-h2-2023 label](https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3Adocathon-h2-2023)."
const commmentMessage = "This issue does not have the correct label. Please pick an opened and unnasigned issue with the [docathon-h1-2024 label](https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3Adocathon-h1-2024)."
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
Expand Down
2 changes: 1 addition & 1 deletion .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2128,7 +2128,7 @@ init_command = [
'python3',
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'ruff==0.4.4',
'ruff==0.4.5',
]
is_formatter = true

Expand Down
62 changes: 0 additions & 62 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -456,50 +456,8 @@ CAFFE2_COPTS = COMMON_COPTS + [
filegroup(
name = "caffe2_core_srcs",
srcs = [
"caffe2/core/allocator.cc",
"caffe2/core/blob_serialization.cc",
"caffe2/core/blob_stats.cc",
"caffe2/core/common.cc",
"caffe2/core/context.cc",
"caffe2/core/context_base.cc",
"caffe2/core/db.cc",
"caffe2/core/event.cc",
"caffe2/core/export_c10_op_to_caffe2.cc",
"caffe2/core/graph.cc",
"caffe2/core/init.cc",
"caffe2/core/init_denormals.cc",
"caffe2/core/init_intrinsics_check.cc",
"caffe2/core/init_omp.cc",
"caffe2/core/int8_serialization.cc",
"caffe2/core/memonger.cc",
"caffe2/core/module.cc",
"caffe2/core/net.cc",
"caffe2/core/net_async_base.cc",
"caffe2/core/net_async_scheduling.cc",
"caffe2/core/net_async_task.cc",
"caffe2/core/net_async_task_future.cc",
"caffe2/core/net_async_task_graph.cc",
"caffe2/core/net_async_tracing.cc",
"caffe2/core/net_dag_utils.cc",
"caffe2/core/net_parallel.cc",
"caffe2/core/net_simple.cc",
"caffe2/core/net_simple_refcount.cc",
"caffe2/core/nomnigraph/Representations/NeuralNet.cc",
"caffe2/core/nomnigraph/tests/test_util.cc",
"caffe2/core/numa.cc",
"caffe2/core/operator.cc",
"caffe2/core/operator_schema.cc",
"caffe2/core/plan_executor.cc",
"caffe2/core/prof_dag_counters.cc",
"caffe2/core/qtensor.cc",
"caffe2/core/qtensor_serialization.cc",
"caffe2/core/stats.cc",
"caffe2/core/tensor.cc",
"caffe2/core/tensor_int8.cc",
"caffe2/core/test_utils.cc",
"caffe2/core/transform.cc",
"caffe2/core/types.cc",
"caffe2/core/workspace.cc",
],
)

Expand Down Expand Up @@ -534,17 +492,9 @@ filegroup(
srcs = [
"caffe2/utils/bench_utils.cc",
"caffe2/utils/cpuid.cc",
"caffe2/utils/math/broadcast.cc",
"caffe2/utils/math/elementwise.cc",
"caffe2/utils/math/reduce.cc",
"caffe2/utils/math/transpose.cc",
"caffe2/utils/math/utils.cc",
"caffe2/utils/math_cpu.cc",
"caffe2/utils/murmur_hash3.cc",
"caffe2/utils/proto_utils.cc",
"caffe2/utils/proto_wrap.cc",
"caffe2/utils/signal_handler.cc",
"caffe2/utils/smart_tensor_printer.cc",
"caffe2/utils/string_utils.cc",
"caffe2/utils/threadpool/ThreadPool.cc",
"caffe2/utils/threadpool/pthreadpool.cc",
Expand Down Expand Up @@ -585,18 +535,9 @@ cc_library(
name = "caffe2_headers",
hdrs = glob(
[
"caffe2/core/*.h",
"caffe2/core/nomnigraph/include/nomnigraph/Converters/*.h",
"caffe2/core/nomnigraph/include/nomnigraph/Generated/*.h",
"caffe2/core/nomnigraph/include/nomnigraph/Graph/*.h",
"caffe2/core/nomnigraph/include/nomnigraph/Representations/*.h",
"caffe2/core/nomnigraph/include/nomnigraph/Support/*.h",
"caffe2/core/nomnigraph/include/nomnigraph/Transformations/*.h",
"caffe2/core/nomnigraph/tests/*.h",
"caffe2/perfkernels/*.h",
"caffe2/serialize/*.h",
"caffe2/utils/*.h",
"caffe2/utils/math/*.h",
"caffe2/utils/threadpool/*.h",
"modules/**/*.h",
],
Expand All @@ -608,9 +549,6 @@ cc_library(
"caffe2/image/*.h",
])),
copts = CAFFE2_COPTS,
includes = [
"caffe2/core/nomnigraph/include",
],
visibility = ["//visibility:public"],
deps = [
":caffe2_core_macros",
Expand Down
140 changes: 88 additions & 52 deletions aten/src/ATen/native/BlasKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <ATen/Parallel.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/Unroll.h>
#include <c10/util/complex.h>
#include <c10/util/irange.h>
#include <algorithm>
Expand Down Expand Up @@ -244,26 +245,30 @@ static inline float16_t reduce(float16x8_t x) {
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#define F16_ELEMENTS_PER_ITERATION 32
#define F16_ELEMENTS_PER_REGISTER 8
#define F16_REGISTERS_PER_ITERATION (F16_ELEMENTS_PER_ITERATION / F16_ELEMENTS_PER_REGISTER)
static inline double reduce(float16x8_t x[F16_REGISTERS_PER_ITERATION]) {
int offset = F16_REGISTERS_PER_ITERATION / 2;
for (int i = 0; i < offset; ++i) {
x[i] = vaddq_f16(x[i], x[offset + i]);
}
offset /= 2;
for (int i = 0; i < offset; ++i) {
x[i] = vaddq_f16(x[i], x[offset + i]);
}
offset /= 2;
for (int i = 0; i < offset; ++i) {
x[i] = vaddq_f16(x[i], x[offset + i]);
}
// We need the shift for reduce(), hence the extra constants.
static constexpr auto kF16ElementsPerIterationShift = 7;
static constexpr auto kF16ElementsPerIteration = 1 << kF16ElementsPerIterationShift;
static_assert(kF16ElementsPerIteration == 128);

static constexpr auto kF16ElementsPerRegisterShift = 3;
static constexpr auto kF16ElementsPerRegister = 1 << kF16ElementsPerRegisterShift;
static_assert(kF16ElementsPerRegister == 8);

static constexpr auto kF16RegistersPerIterationShift = kF16ElementsPerIterationShift - kF16ElementsPerRegisterShift;
static constexpr auto kF16RegistersPerIteration = 1 << kF16RegistersPerIterationShift;
static_assert(kF16RegistersPerIteration == kF16ElementsPerIteration / kF16ElementsPerRegister);

static inline double reduce(float16x8_t x[kF16RegistersPerIteration]) {
int offset = kF16RegistersPerIteration;
c10::ForcedUnroll<kF16RegistersPerIterationShift>{}([&offset, &x](auto idx) {
offset /= 2;
for (int i = 0; i < offset; ++i) {
x[i] = vaddq_f16(x[i], x[offset + i]);
}
});
const float32x4_t t0 = vcvt_f32_f16(vget_low_f16(x[0]));
const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0]));
return (double)vaddvq_f32(vaddq_f32(t0, t1));

}

static inline float16x8_t f16_fma(float16x8_t a, float16x8_t b, float16x8_t c) {
Expand All @@ -280,13 +285,13 @@ static inline float16x8_t f16_fma(float16x8_t a, float16x8_t b, float16x8_t c) {
static void fp16_gemv_trans_fp16_arith_by_dot_products(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) {
parallel_for(0, n, 1, [&](int begin, int end) {
for (int i = begin; i < end; ++i) {
float16x8_t sum[F16_REGISTERS_PER_ITERATION] = {vdupq_n_f16(0)};
float16x8_t sum[kF16RegistersPerIteration] = {vdupq_n_f16(0)};

const auto m_aligned = m & ~(F16_ELEMENTS_PER_ITERATION - 1);
for (int j = 0; j < m_aligned ; j += F16_ELEMENTS_PER_ITERATION) {
for (int k = 0; k < F16_REGISTERS_PER_ITERATION; ++k) {
const auto temp_x = vld1q_f16(x + j + k * F16_ELEMENTS_PER_REGISTER);
const auto temp_a = vld1q_f16(a + lda * i + j + k * F16_ELEMENTS_PER_REGISTER);
const auto m_aligned = m & ~(kF16ElementsPerIteration - 1);
for (int j = 0; j < m_aligned ; j += kF16ElementsPerIteration) {
for (int k = 0; k < kF16RegistersPerIteration; ++k) {
const auto temp_x = vld1q_f16(x + j + k * kF16ElementsPerRegister);
const auto temp_a = vld1q_f16(a + lda * i + j + k * kF16ElementsPerRegister);
sum[k] = f16_fma(sum[k], temp_x, temp_a);
}
}
Expand Down Expand Up @@ -315,45 +320,76 @@ static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
#endif
}

static inline float32x4_t f32_fma_low_f16(float32x4_t a, float16x8_t b, float16x8_t c) {
#ifdef __ARM_FEATURE_FP16_FML
// NOTE: this instruction is an optional instruction in ARM v8.2 and
// v8.3, but mandatory in v8.4 per
// https://developer.arm.com/documentation/ddi0596/2021-03/SIMD-FP-Instructions/FMLAL--FMLAL2--vector---Floating-point-fused-Multiply-Add-Long-to-accumulator--vector--?lang=en
// I'm not certain that I have the right feature test macro.
return vfmlalq_low_f16(a, b, c);
#else
return f32_fma(a, vcvt_f32_f16(vget_low_f16(b)), vcvt_f32_f16(vget_low_f16(c)));
#endif
}

static inline float32x4_t f32_fma_high_f16(float32x4_t a, float16x8_t b, float16x8_t c) {
#ifdef __ARM_FEATURE_FP16_FML
// See above note about this instruction.
return vfmlalq_high_f16(a, b, c);
#else
return f32_fma(a, vcvt_f32_f16(vget_high_f16(b)), vcvt_f32_f16(vget_high_f16(c)));
#endif
}

// The below reduce overload and
// fp16_gemv_trans_fp32_arith_by_dot_products are adapted from
// llama.cpp's ggml_vec_dot_f32 and surrounding utility functions. See
// NOTE [ GGML Copyright Notice ] above for the required notice.
#define F32_ELEMENTS_PER_ITERATION 16
#define F32_ELEMENTS_PER_REGISTER 4
#define F32_REGISTERS_PER_ITERATION (F32_ELEMENTS_PER_ITERATION / F32_ELEMENTS_PER_REGISTER)
static inline double reduce(float32x4_t x[F32_REGISTERS_PER_ITERATION]) {
int offset = F32_REGISTERS_PER_ITERATION / 2;
for (int i = 0; i < offset; ++i) {
x[i] = vaddq_f32(x[i], x[offset + i]);
}
offset /= 2;
for (int i = 0; i < offset; ++i) {
x[i] = vaddq_f32(x[i], x[offset + i]);
}
offset /= 2;
for (int i = 0; i < offset; ++i) {
x[i] = vaddq_f32(x[i], x[offset + i]);
}
offset /= 2;
for (int i = 0; i < offset; ++i) {
x[i] = vaddq_f32(x[i], x[offset + i]);
}

// We need the shift for reduce(), hence the extra constants.
static constexpr auto kF32ElementsPerIterationShift = 5;
static constexpr auto kF32ElementsPerIteration = 1 << kF32ElementsPerIterationShift;
static_assert(kF32ElementsPerIteration == 32);

static constexpr auto kF32ElementsPerRegisterShift = 2;
static constexpr auto kF32ElementsPerRegister = 1 << kF32ElementsPerRegisterShift;
static_assert(kF32ElementsPerRegister == 4);

static constexpr auto kF32RegisterPairsPerIteration = 4;
static constexpr auto kF32RegistersPerIteration = kF32RegisterPairsPerIteration * 2;
static constexpr auto kF32RegistersPerIterationShift = 3;
static_assert(kF32RegistersPerIteration == kF32ElementsPerIteration / kF32ElementsPerRegister);
static_assert(kF32RegistersPerIteration == 1 << kF32RegistersPerIterationShift);

static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) {
int offset = kF32RegistersPerIteration;
c10::ForcedUnroll<kF32RegistersPerIterationShift>{}([&offset, &x](auto idx) {
offset /= 2;
for (int i = 0; i < offset; ++i) {
x[i] = vaddq_f32(x[i], x[offset + i]);
}
});
return vaddvq_f32(x[0]);
}

// On my Apple M1 Macbook (which is ARM v8.5 and thus has the
// instructions f32_fma_{low,high}_f16 is targeting), this kernel has
// equivalent performance to the fp16-native kernel.
static void fp16_gemv_trans_fp32_arith_by_dot_products(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) {
parallel_for(0, n, 1, [&](int begin, int end) {
for (int i = begin; i < end; ++i) {
float32x4_t sum[F32_REGISTERS_PER_ITERATION] = {vdupq_n_f32(0)};

const auto m_aligned = m & ~(F32_ELEMENTS_PER_ITERATION - 1);
for (int j = 0; j < m_aligned ; j += F32_ELEMENTS_PER_ITERATION) {
for (int k = 0; k < F32_REGISTERS_PER_ITERATION; ++k) {
const auto temp_x = vcvt_f32_f16(vld1_f16(x + j + k * F32_ELEMENTS_PER_REGISTER));
const auto temp_a = vcvt_f32_f16(vld1_f16(a + lda * i + j + k * F32_ELEMENTS_PER_REGISTER));
sum[k] = f32_fma(sum[k], temp_x, temp_a);
}
float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};

const auto m_aligned = m & ~(kF32ElementsPerIteration - 1);
for (int j = 0; j < m_aligned ; j += kF32ElementsPerIteration) {
c10::ForcedUnroll<kF32RegisterPairsPerIteration>{}([x, a, lda, i, j, &sum](auto k) {
// Load a pair of f32 registers at a time.
const auto temp_x = vld1q_f16(x + j + k * 2 * kF32ElementsPerRegister);
const auto temp_a = vld1q_f16(a + lda * i + j + k * 2 * kF32ElementsPerRegister);

sum[2 * k] = f32_fma_low_f16(sum[2 * k], temp_x, temp_a);
sum[2 * k + 1] = f32_fma_high_f16(sum[2 * k + 1], temp_x, temp_a);
});
}
auto reducedSum = reduce(sum);

Expand Down
Loading

0 comments on commit 7e9a39a

Please sign in to comment.