Update on "[inductor][cpp] GEMM template (infra and fp32)"
This PR adds the Cpp template infrastructure and the initial FP32 GEMM template. See RFC #125683 for more background.
1. Cpp template infrastructure
The template abstractions are similar to those of the CUTLASS template, i.e., `CppTemplate`, `CppTemplateKernel`, and `CppTemplateBuffer`, plus the `CppMicroGemm` micro-kernel abstraction that Cpp GEMM templates can use.
2. Initial FP32 gemm template
This adds a GEMM template implementation, `CppPackedGemmTemplate`, that supports GEMM with a constant weight (`B`). `N` must be a multiple of the register-blocking size, while the `M` (batch) dimension of `A` can be static or dynamic. The `B` matrix is prepacked, a typical setting for inference workloads. The template handles thread decomposition (via `thread_blocking`) and cache blocking (via `cache_blocking`), then invokes `CppMicroGemm`, which handles register blocking, instruction selection, and other CPU-architecture-specific optimizations (a simplified, illustrative sketch of this blocking hierarchy is included after the performance details below). A `CppMicroGemmFP32Vec` micro-kernel implementation is provided for FP32 matmuls, implemented with the ATen Vec abstraction.
3. Correctness and performance
The changes have been validated with FP32 inference on the three benchmark suites (torchbench, huggingface, and timm_models) with both static and dynamic shapes. Since this is an initial implementation, we are still working on further performance improvements in follow-up PRs, including kernel optimizations and fusions. Compared with the ATen kernels (implemented with MKL), perf gains are currently observed only on a selected set of models. The gains are more pronounced with dynamic shapes, since MKL only supports packed GEMM for static shapes. Details are below.

Static shapes
| Configuration | torchbench | huggingface | timm_models |
|------------|-------------|--------------|--------------|
| Multi-threaded (baseline) | 1.47x | 1.36x | 1.91x |
| Multi-threaded (max-autotune) | 1.47x | 1.36x | 1.92x |
| Single-threaded (baseline) | 1.56x | 1.19x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.52x |

Key models being sped up:
drq: 1.14x
soft_act: 1.12x
cait_m36_384: 1.18x

Dynamic shapes
| Configuration | torchbench | huggingface | timm_models |
| --- | --- | --- | --- |
| Multi-threaded (baseline) | 1.43x | 1.28x | 1.85x |
| Multi-threaded (max-autotune) | 1.47x | 1.28x | 1.85x |
| Single-threaded (baseline) | 1.55x | 1.20x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.53x |

Key models being sped up:
BERT_pytorch: 1.22x
pyhpc_turbulent: 1.13x
soft_actor_critic: 1.77x
BlenderbotForCausalLM: 1.09x
cait_m36_384: 1.17x
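
For illustration only, below is a minimal, self-contained C++ sketch of the blocking hierarchy described in item 2. It is not the code the template generates, and all names in it (`packed_gemm`, `micro_gemm`, `Mr`, `Nr`, `Mc`, `Kc`) are hypothetical. The outer loops perform cache blocking over `K` and `M`, `B` is assumed prepacked into contiguous `Nr`-wide column panels, and the `Mr x Nr` accumulator tile provides the register blocking (`N` is assumed to be a multiple of `Nr`):

```cpp
#include <algorithm>
#include <cstdint>

// Illustrative sketch only (hypothetical names):
// C (MxN, row-major) += A (MxK, row-major) * B, where B has been prepacked into Bp
// so that each Nr-wide column panel is contiguous:
// Bp[(j / Nr) * K * Nr + k * Nr + (j % Nr)].
constexpr int64_t Mr = 4, Nr = 8;    // register-blocking tile kept in accumulators
constexpr int64_t Mc = 64, Kc = 256; // cache-blocking sizes (tuning knobs)

// Register-blocked micro-kernel: accumulates an mr x Nr tile over kc reduction steps.
void micro_gemm(const float* A, int64_t lda, const float* Bpanel,
                float* C, int64_t ldc, int64_t mr, int64_t kc) {
  float acc[Mr][Nr] = {};
  for (int64_t k = 0; k < kc; ++k) {
    for (int64_t i = 0; i < mr; ++i) {
      const float a = A[i * lda + k];
      for (int64_t j = 0; j < Nr; ++j) {
        acc[i][j] += a * Bpanel[k * Nr + j];
      }
    }
  }
  for (int64_t i = 0; i < mr; ++i) {
    for (int64_t j = 0; j < Nr; ++j) {
      C[i * ldc + j] += acc[i][j];
    }
  }
}

// Cache-blocked driver: N must be a multiple of Nr; M (the batch dim) may be any size.
void packed_gemm(const float* A, const float* Bp, float* C,
                 int64_t M, int64_t N, int64_t K) {
  for (int64_t k0 = 0; k0 < K; k0 += Kc) {           // cache blocking over K
    const int64_t kc = std::min(Kc, K - k0);
    for (int64_t m0 = 0; m0 < M; m0 += Mc) {         // cache blocking over M
      const int64_t mc = std::min(Mc, M - m0);
      for (int64_t j0 = 0; j0 < N; j0 += Nr) {       // one packed B panel per Nr columns
        const float* Bpanel = Bp + (j0 / Nr) * K * Nr + k0 * Nr;
        for (int64_t i0 = 0; i0 < mc; i0 += Mr) {    // register blocking over M
          micro_gemm(A + (m0 + i0) * K + k0, /*lda=*/K, Bpanel,
                     C + (m0 + i0) * N + j0, /*ldc=*/N,
                     std::min(Mr, mc - i0), kc);
        }
      }
    }
  }
}
```

In the actual template, the outer loops are additionally partitioned across threads (`thread_blocking`), and `CppMicroGemm` emits vectorized code (e.g., via the ATen Vec abstraction) rather than the scalar accumulation shown here.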

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

Differential Revision: [D57585365](https://our.internmc.facebook.com/intern/diff/D57585365)

[ghstack-poisoned]
jgong5 committed May 21, 2024
2 parents 2993c2e + ad1684f commit ac36018
Showing 96 changed files with 2,391 additions and 1,709 deletions.
3 changes: 3 additions & 0 deletions .ci/docker/centos-rocm/Dockerfile
@@ -77,6 +77,9 @@ RUN rm install_rocm.sh
COPY ./common/install_rocm_magma.sh install_rocm_magma.sh
RUN bash ./install_rocm_magma.sh
RUN rm install_rocm_magma.sh
COPY ./common/install_amdsmi.sh install_amdsmi.sh
RUN bash ./install_amdsmi.sh
RUN rm install_amdsmi.sh
ENV PATH /opt/rocm/bin:$PATH
ENV PATH /opt/rocm/hcc/bin:$PATH
ENV PATH /opt/rocm/hip/bin:$PATH
5 changes: 5 additions & 0 deletions .ci/docker/common/install_amdsmi.sh
@@ -0,0 +1,5 @@
#!/bin/bash

set -ex

cd /opt/rocm/share/amd_smi && pip install .
6 changes: 4 additions & 2 deletions .ci/docker/common/install_rocm.sh
@@ -39,7 +39,8 @@ install_ubuntu() {
rocm-libs \
rccl \
rocprofiler-dev \
roctracer-dev
roctracer-dev \
amd-smi-lib

if [[ $(ver $ROCM_VERSION) -ge $(ver 6.1) ]]; then
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated rocm-llvm-dev
@@ -106,7 +107,8 @@ install_centos() {
rocm-libs \
rccl \
rocprofiler-dev \
roctracer-dev
roctracer-dev \
amd-smi-lib

# precompiled miopen kernels; search for all unversioned packages
# if search fails it will abort this script; use true to avoid case where search fails
5 changes: 5 additions & 0 deletions .ci/docker/ubuntu-rocm/Dockerfile
@@ -78,6 +78,11 @@ ENV MAGMA_HOME /opt/rocm/magma
ENV LANG C.UTF-8
ENV LC_ALL C.UTF-8

# Install amdsmi
COPY ./common/install_amdsmi.sh install_amdsmi.sh
RUN bash ./install_amdsmi.sh
RUN rm install_amdsmi.sh

# (optional) Install non-default CMake version
ARG CMAKE_VERSION
COPY ./common/install_cmake.sh install_cmake.sh
3 changes: 3 additions & 0 deletions .ci/pytorch/build.sh
@@ -289,6 +289,9 @@ else
fi
WERROR=1 python setup.py bdist_wheel
else
if [[ "$BUILD_ENVIRONMENT" == *xla* ]]; then
source .ci/pytorch/install_cache_xla.sh
fi
python setup.py bdist_wheel
fi
pip_install_whl "$(echo dist/*.whl)"
37 changes: 37 additions & 0 deletions .ci/pytorch/install_cache_xla.sh
@@ -0,0 +1,37 @@
#!/bin/bash

# Script for installing sccache on the xla build job, which uses xla's docker
# image and doesn't have sccache installed on it. This is mostly copied from
# .ci/docker/install_cache.sh. Changes are: removing checks that will always
# return the same thing, e.g. checks for rocm and CUDA, and changing the path
# where sccache is installed, and not changing /etc/environment.

set -ex

install_binary() {
echo "Downloading sccache binary from S3 repo"
curl --retry 3 https://s3.amazonaws.com/ossci-linux/sccache -o /tmp/cache/bin/sccache
}

mkdir -p /tmp/cache/bin
mkdir -p /tmp/cache/lib
export PATH="/tmp/cache/bin:$PATH"

install_binary
chmod a+x /tmp/cache/bin/sccache

function write_sccache_stub() {
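# The generated stub at /tmp/cache/bin/$1 execs `sccache <real compiler>` unless the
# parent process is already sccache, in which case it execs the real compiler
# directly to avoid recursive sccache invocations.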
# Unset LD_PRELOAD for ps because of asan + ps issues
# https://gcc.gnu.org/bugzilla/show_bug.cgi?id=90589
# shellcheck disable=SC2086
# shellcheck disable=SC2059
printf "#!/bin/sh\nif [ \$(env -u LD_PRELOAD ps -p \$PPID -o comm=) != sccache ]; then\n exec sccache $(which $1) \"\$@\"\nelse\n exec $(which $1) \"\$@\"\nfi" > "/tmp/cache/bin/$1"
chmod a+x "/tmp/cache/bin/$1"
}

write_sccache_stub cc
write_sccache_stub c++
write_sccache_stub gcc
write_sccache_stub g++
write_sccache_stub clang
write_sccache_stub clang++
2 changes: 1 addition & 1 deletion BUILD.bazel
@@ -772,7 +772,7 @@ cc_library(
[
"torch/*.h",
"torch/csrc/**/*.h",
"torch/csrc/distributed/c10d/**/*.hpp",
"torch/csrc/distributed/c10d/*.hpp",
"torch/lib/libshm/*.h",
],
exclude = [
14 changes: 9 additions & 5 deletions aten/src/ATen/autocast_mode.cpp
@@ -35,7 +35,11 @@ namespace {
// directly against incoming TensorImpl*s.
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
using val_type = std::tuple<weakref_type, Tensor>;
ska::flat_hash_map<TensorImpl*, val_type> cached_casts;

static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
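// Construct-on-first-use: a function-local static avoids relying on the
// initialization/destruction order of the namespace-scope global it replaces.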
static ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
return cached_casts;
}
std::mutex cached_casts_mutex;


@@ -82,7 +86,7 @@ thread_local bool cache_enabled = true;

void clear_cache() {
const std::lock_guard<std::mutex> lock(cached_casts_mutex);
cached_casts.clear();
get_cached_casts().clear();
}

int increment_nesting() {
@@ -124,12 +128,12 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_

if (can_try_cache) {
const std::lock_guard<std::mutex> lock(cached_casts_mutex);
auto it = cached_casts.find(arg.unsafeGetTensorImpl());
if (it != cached_casts.end()) {
auto it = get_cached_casts().find(arg.unsafeGetTensorImpl());
if (it != get_cached_casts().end()) {
return std::get<1>(it->second);
} else {
auto casted_arg = arg.to(to_type);
cached_casts.emplace(arg.unsafeGetTensorImpl(), val_type{weakref_type(arg.getIntrusivePtr()), casted_arg});
get_cached_casts().emplace(arg.unsafeGetTensorImpl(), val_type{weakref_type(arg.getIntrusivePtr()), casted_arg});
return casted_arg;
}
} else {
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/CUDAGraph.cpp
@@ -268,7 +268,7 @@ void CUDAGraph::debug_dump(const std::string& debug_path) {
TORCH_WARN("DEBUG: calling debug_dump()");
if (has_graph_) {
TORCH_WARN("DEBUG: calling cudaGraphDebugDotPrint() with ", debug_path);
C10_CUDA_CHECK_WARN(cudaGraphDebugDotPrint(graph_, debug_path.c_str(), 1<<10)); // most verbose output
C10_CUDA_CHECK_WARN(cudaGraphDebugDotPrint(graph_, debug_path.c_str(), cudaGraphDebugDotFlagsVerbose)); // most verbose output
AT_CUDA_CHECK(cudaGraphDestroy(graph_));
}
} else {
47 changes: 43 additions & 4 deletions aten/src/ATen/native/cpu/BlasKernel.cpp
@@ -308,16 +308,27 @@ void gemm_notrans_(
}


static float compute_dot(const float16_t *a, const float16_t *b, int64_t l) {
inline float32x4_t load_as_float32x4(const Half* ptr) {
return vcvt_f32_f16(vld1_f16(reinterpret_cast<const float16_t *>(ptr)));
}

inline float32x4_t load_as_float32x4(const BFloat16* ptr) {
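// BFloat16 holds the upper 16 bits of an IEEE float32, so widening each value to
// uint32 and shifting left by 16 reconstructs the corresponding float32 bit pattern.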
int32x4_t shift = vdupq_n_s32(16);
uint32x4_t as_int = vmovl_u16(vld1_u16(reinterpret_cast<const uint16_t *>(ptr)));
return vreinterpretq_f32_u32(vshlq_u32(as_int, shift));
}

template<typename T>
static float compute_dot(const T* a, const T* b, int64_t l) {
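// Scalar fallback when l is not a multiple of 4; otherwise accumulate four fp32
// lanes with NEON and reduce at the end.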
if ((l&3) != 0) {
return sum(l, [&](int64_t i) -> float {
return float(a[i]) * float(b[i]);
});
}
float32x4_t rcv = vdupq_n_f32(0);
for (int64_t idx = 0; idx < l; idx += 4) {
float32x4_t aVec = vcvt_f32_f16(vld1_f16(a + idx));
float32x4_t bVec = vcvt_f32_f16(vld1_f16(b + idx));
float32x4_t aVec = load_as_float32x4(a + idx);
float32x4_t bVec = load_as_float32x4(b + idx);
rcv = vaddq_f32(rcv, vmulq_f32(aVec, bVec));
}
auto sum = vpaddq_f32(rcv, rcv);
@@ -343,7 +354,35 @@ void gemm_transa_(
for (const auto i : c10::irange(begin, end)) {
const auto *b_ = b;
for (const auto j : c10::irange(n)) {
const auto dot = compute_dot(reinterpret_cast<const float16_t*>(a_), reinterpret_cast<const float16_t*>(b_), k);
const auto dot = compute_dot(a_, b_, k);
b_ += ldb;
if (beta == 0) {
c[j*ldc+i] = alpha*dot;
} else {
c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
}
}
a_ += lda;
}
});
}

template <>
void gemm_transa_(
TransposeType transa,
int64_t m, int64_t n, int64_t k,
float alpha,
const at::BFloat16 *a, int64_t lda,
const at::BFloat16 *b, int64_t ldb,
float beta,
at::BFloat16 *c, int64_t ldc) {
// c = alpha * (a.T @ b) + beta * c
parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
const auto *a_ = a + begin * lda;
for (const auto i : c10::irange(begin, end)) {
const auto *b_ = b;
for (const auto j : c10::irange(n)) {
const auto dot = compute_dot(a_, b_, k);
b_ += ldb;
if (beta == 0) {
c[j*ldc+i] = alpha*dot;
6 changes: 4 additions & 2 deletions aten/src/ATen/native/mps/operations/TensorCompare.mm
@@ -276,8 +276,6 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements,
bool invert,
const Tensor& out,
string op_name) {
TORCH_CHECK(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS),
"isin_Tensor_Tensor_out supported on MPS from MacOs_14_0 onwards");
if (elements.numel() == 0) {
return;
}
@@ -295,6 +293,10 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements,

TORCH_CHECK(elements.is_mps() && test_elements.is_mps());
TORCH_CHECK(elements.dtype() == test_elements.dtype());
TORCH_CHECK(
!(!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) && !supportedFloatingType(elements.scalar_type())),
"isin_Tensor_Tensor_out only works on floating types on MPS for pre MacOS_14_0. Received dtype: ",
elements.scalar_type());

@autoreleasepool {
string key =
29 changes: 29 additions & 0 deletions benchmarks/dynamo/common.py
@@ -3485,6 +3485,18 @@ def get_example_inputs(self):
action="store_true",
help="Measure speedup with TorchInductor",
)
group.add_argument(
"--quantization",
choices=[
"int8dynamic",
"int8weightonly",
"int4weightonly",
"autoquant",
"noquant",
],
default=None,
help="Measure speedup of torchao quantization with TorchInductor baseline",
)
group.add_argument(
"--export",
action="store_true",
@@ -3679,6 +3691,9 @@ def run(runner, args, original_dir=None):
if args.inductor:
assert args.backend is None
args.backend = "inductor"
if args.quantization:
assert args.backend is None
args.backend = "torchao"
if args.dynamic_batch_only:
args.dynamic_shapes = True
torch._dynamo.config.assume_static_by_default = True
@@ -3957,6 +3972,20 @@ def run(runner, args, original_dir=None):

# AOTInductor doesn't support control flow yet
runner.skip_models.update(runner.skip_models_due_to_control_flow)
elif args.backend == "torchao":
assert "cuda" in args.devices, "Quantization requires CUDA device."
assert args.bfloat16, "Quantization requires dtype bfloat16."
from .torchao import setup_baseline, torchao_optimize_ctx

setup_baseline()
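# Compile the baseline with Inductor so torchao quantization speedups are
# measured against an Inductor baseline.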
baseline_ctx = functools.partial(
torch.compile,
backend="inductor",
fullgraph=args.nopython,
mode=args.inductor_compile_mode,
)
runner.model_iter_fn = baseline_ctx(runner.model_iter_fn)
optimize_ctx = torchao_optimize_ctx(args.quantization)
else:
optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
experiment = speedup_experiment