Update on "[inductor][cpp] epilogue support for gemm template"
As part of #125683, this PR adds epilogue support for the C++ GEMM template by reusing the C++ vector codegen on sub-slices of tensors. This is implemented by retracing the epilogue IR nodes with new ranges and offsets. New `codegen_loop_bodies` and `codegen_functions` methods are added to the C++ vector codegen for this purpose; the `store_output` method of the template kernel leverages them for epilogue codegen and for the store to the final result (a conceptual sketch follows below).

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
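
Not part of the commit, and not the actual Inductor API: below is a minimal standalone sketch (hypothetical names, NumPy only) of the idea behind retracing the epilogue over a sub-slice. The epilogue is expressed against global output indices, and the template kernel re-evaluates it with a tile's offsets and ranges before storing into the final buffer; the real template emits vectorized C++ for the tile via the retraced IR rather than a Python loop.

import numpy as np

def apply_epilogue_on_tile(acc_tile, epilogue, out, row_off, col_off):
    # Re-base the tile's local indices to global output indices (the "new
    # ranges and offsets") so the epilogue runs on the sub-slice before the
    # result is stored into the full output buffer.
    tile_m, tile_n = acc_tile.shape
    for i in range(tile_m):
        for j in range(tile_n):
            gi, gj = row_off + i, col_off + j
            out[gi, gj] = epilogue(gi, gj, acc_tile[i, j])

# Example: a bias-add + ReLU epilogue fused after each GEMM tile.
m, n, k, tile = 8, 8, 4, 4
a, b = np.random.rand(m, k), np.random.rand(k, n)
bias = np.random.rand(n)
out = np.empty((m, n))
for ro in range(0, m, tile):
    for co in range(0, n, tile):
        acc = a[ro:ro + tile] @ b[:, co:co + tile]  # per-tile GEMM result
        apply_epilogue_on_tile(
            acc, lambda i, j, v: max(v + bias[j], 0.0), out, ro, co
        )
assert np.allclose(out, np.maximum(a @ b + bias, 0.0))
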
jgong5 committed May 21, 2024
2 parents df28970 + 0f831b6 commit 05d5463
Showing 96 changed files with 2,394 additions and 1,710 deletions.
3 changes: 3 additions & 0 deletions .ci/docker/centos-rocm/Dockerfile
@@ -77,6 +77,9 @@ RUN rm install_rocm.sh
COPY ./common/install_rocm_magma.sh install_rocm_magma.sh
RUN bash ./install_rocm_magma.sh
RUN rm install_rocm_magma.sh
COPY ./common/install_amdsmi.sh install_amdsmi.sh
RUN bash ./install_amdsmi.sh
RUN rm install_amdsmi.sh
ENV PATH /opt/rocm/bin:$PATH
ENV PATH /opt/rocm/hcc/bin:$PATH
ENV PATH /opt/rocm/hip/bin:$PATH
5 changes: 5 additions & 0 deletions .ci/docker/common/install_amdsmi.sh
@@ -0,0 +1,5 @@
#!/bin/bash

set -ex

cd /opt/rocm/share/amd_smi && pip install .
6 changes: 4 additions & 2 deletions .ci/docker/common/install_rocm.sh
@@ -39,7 +39,8 @@ install_ubuntu() {
rocm-libs \
rccl \
rocprofiler-dev \
roctracer-dev
roctracer-dev \
amd-smi-lib

if [[ $(ver $ROCM_VERSION) -ge $(ver 6.1) ]]; then
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated rocm-llvm-dev
@@ -106,7 +107,8 @@ install_centos() {
rocm-libs \
rccl \
rocprofiler-dev \
roctracer-dev
roctracer-dev \
amd-smi-lib

# precompiled miopen kernels; search for all unversioned packages
# if search fails it will abort this script; use true to avoid case where search fails
5 changes: 5 additions & 0 deletions .ci/docker/ubuntu-rocm/Dockerfile
@@ -78,6 +78,11 @@ ENV MAGMA_HOME /opt/rocm/magma
ENV LANG C.UTF-8
ENV LC_ALL C.UTF-8

# Install amdsmi
COPY ./common/install_amdsmi.sh install_amdsmi.sh
RUN bash ./install_amdsmi.sh
RUN rm install_amdsmi.sh

# (optional) Install non-default CMake version
ARG CMAKE_VERSION
COPY ./common/install_cmake.sh install_cmake.sh
3 changes: 3 additions & 0 deletions .ci/pytorch/build.sh
@@ -289,6 +289,9 @@ else
fi
WERROR=1 python setup.py bdist_wheel
else
if [[ "$BUILD_ENVIRONMENT" == *xla* ]]; then
source .ci/pytorch/install_cache_xla.sh
fi
python setup.py bdist_wheel
fi
pip_install_whl "$(echo dist/*.whl)"
37 changes: 37 additions & 0 deletions .ci/pytorch/install_cache_xla.sh
@@ -0,0 +1,37 @@
#!/bin/bash

# Script for installing sccache on the XLA build job, which uses XLA's docker
# image and doesn't have sccache installed on it. This is mostly copied from
# .ci/docker/install_cache.sh. Changes are: checks that would always return the
# same thing (e.g. for ROCm and CUDA) are removed, sccache is installed to a
# different path, and /etc/environment is not modified.

set -ex

install_binary() {
echo "Downloading sccache binary from S3 repo"
curl --retry 3 https://s3.amazonaws.com/ossci-linux/sccache -o /tmp/cache/bin/sccache
}

mkdir -p /tmp/cache/bin
mkdir -p /tmp/cache/lib
export PATH="/tmp/cache/bin:$PATH"

install_binary
chmod a+x /tmp/cache/bin/sccache

function write_sccache_stub() {
# Unset LD_PRELOAD for ps because of asan + ps issues
# https://gcc.gnu.org/bugzilla/show_bug.cgi?id=90589
# shellcheck disable=SC2086
# shellcheck disable=SC2059
printf "#!/bin/sh\nif [ \$(env -u LD_PRELOAD ps -p \$PPID -o comm=) != sccache ]; then\n exec sccache $(which $1) \"\$@\"\nelse\n exec $(which $1) \"\$@\"\nfi" > "/tmp/cache/bin/$1"
chmod a+x "/tmp/cache/bin/$1"
}

write_sccache_stub cc
write_sccache_stub c++
write_sccache_stub gcc
write_sccache_stub g++
write_sccache_stub clang
write_sccache_stub clang++
2 changes: 1 addition & 1 deletion BUILD.bazel
@@ -772,7 +772,7 @@ cc_library(
[
"torch/*.h",
"torch/csrc/**/*.h",
"torch/csrc/distributed/c10d/**/*.hpp",
"torch/csrc/distributed/c10d/*.hpp",
"torch/lib/libshm/*.h",
],
exclude = [
14 changes: 9 additions & 5 deletions aten/src/ATen/autocast_mode.cpp
@@ -35,7 +35,11 @@ namespace {
// directly against incoming TensorImpl*s.
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
using val_type = std::tuple<weakref_type, Tensor>;
ska::flat_hash_map<TensorImpl*, val_type> cached_casts;

static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
static ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
return cached_casts;
}
std::mutex cached_casts_mutex;


@@ -82,7 +86,7 @@ thread_local bool cache_enabled = true;

void clear_cache() {
const std::lock_guard<std::mutex> lock(cached_casts_mutex);
cached_casts.clear();
get_cached_casts().clear();
}

int increment_nesting() {
@@ -124,12 +128,12 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_

if (can_try_cache) {
const std::lock_guard<std::mutex> lock(cached_casts_mutex);
auto it = cached_casts.find(arg.unsafeGetTensorImpl());
if (it != cached_casts.end()) {
auto it = get_cached_casts().find(arg.unsafeGetTensorImpl());
if (it != get_cached_casts().end()) {
return std::get<1>(it->second);
} else {
auto casted_arg = arg.to(to_type);
cached_casts.emplace(arg.unsafeGetTensorImpl(), val_type{weakref_type(arg.getIntrusivePtr()), casted_arg});
get_cached_casts().emplace(arg.unsafeGetTensorImpl(), val_type{weakref_type(arg.getIntrusivePtr()), casted_arg});
return casted_arg;
}
} else {
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/CUDAGraph.cpp
@@ -268,7 +268,7 @@ void CUDAGraph::debug_dump(const std::string& debug_path) {
TORCH_WARN("DEBUG: calling debug_dump()");
if (has_graph_) {
TORCH_WARN("DEBUG: calling cudaGraphDebugDotPrint() with ", debug_path);
C10_CUDA_CHECK_WARN(cudaGraphDebugDotPrint(graph_, debug_path.c_str(), 1<<10)); // most verbose output
C10_CUDA_CHECK_WARN(cudaGraphDebugDotPrint(graph_, debug_path.c_str(), cudaGraphDebugDotFlagsVerbose)); // most verbose output
AT_CUDA_CHECK(cudaGraphDestroy(graph_));
}
} else {
47 changes: 43 additions & 4 deletions aten/src/ATen/native/cpu/BlasKernel.cpp
@@ -308,16 +308,27 @@ void gemm_notrans_(
}


static float compute_dot(const float16_t *a, const float16_t *b, int64_t l) {
inline float32x4_t load_as_float32x4(const Half* ptr) {
return vcvt_f32_f16(vld1_f16(reinterpret_cast<const float16_t *>(ptr)));
}

inline float32x4_t load_as_float32x4(const BFloat16* ptr) {
int32x4_t shift = vdupq_n_s32(16);
uint32x4_t as_int = vmovl_u16(vld1_u16(reinterpret_cast<const uint16_t *>(ptr)));
return vreinterpretq_f32_u32(vshlq_u32(as_int, shift));
}

template<typename T>
static float compute_dot(const T* a, const T* b, int64_t l) {
if ((l&3) != 0) {
return sum(l, [&](int64_t i) -> float {
return float(a[i]) * float(b[i]);
});
}
float32x4_t rcv = vdupq_n_f32(0);
for (int64_t idx = 0; idx < l; idx += 4) {
float32x4_t aVec = vcvt_f32_f16(vld1_f16(a + idx));
float32x4_t bVec = vcvt_f32_f16(vld1_f16(b + idx));
float32x4_t aVec = load_as_float32x4(a + idx);
float32x4_t bVec = load_as_float32x4(b + idx);
rcv = vaddq_f32(rcv, vmulq_f32(aVec, bVec));
}
auto sum = vpaddq_f32(rcv, rcv);
@@ -343,7 +354,35 @@ void gemm_transa_(
for (const auto i : c10::irange(begin, end)) {
const auto *b_ = b;
for (const auto j : c10::irange(n)) {
const auto dot = compute_dot(reinterpret_cast<const float16_t*>(a_), reinterpret_cast<const float16_t*>(b_), k);
const auto dot = compute_dot(a_, b_, k);
b_ += ldb;
if (beta == 0) {
c[j*ldc+i] = alpha*dot;
} else {
c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
}
}
a_ += lda;
}
});
}

template <>
void gemm_transa_(
TransposeType transa,
int64_t m, int64_t n, int64_t k,
float alpha,
const at::BFloat16 *a, int64_t lda,
const at::BFloat16 *b, int64_t ldb,
float beta,
at::BFloat16 *c, int64_t ldc) {
// c = alpha * (a.T @ b) + beta * c
parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
const auto *a_ = a + begin * lda;
for (const auto i : c10::irange(begin, end)) {
const auto *b_ = b;
for (const auto j : c10::irange(n)) {
const auto dot = compute_dot(a_, b_, k);
b_ += ldb;
if (beta == 0) {
c[j*ldc+i] = alpha*dot;
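
For context on the BFloat16 path added above (illustration only, not part of the commit): BFloat16 keeps the upper 16 bits of an IEEE-754 float32, so widening each 16-bit value to 32 bits and shifting it left by 16 reconstructs the corresponding float32 exactly. That is what the vmovl_u16 + vshlq_u32 sequence in load_as_float32x4 does, four lanes at a time. A quick scalar check in Python:

import struct

def bf16_bits_to_float(bits16: int) -> float:
    # Place the 16-bit pattern into the high half of a 32-bit word and
    # reinterpret those bits as an IEEE-754 float32.
    return struct.unpack('>f', ((bits16 & 0xFFFF) << 16).to_bytes(4, 'big'))[0]

assert bf16_bits_to_float(0x3F80) == 1.0   # bfloat16 1.0
assert bf16_bits_to_float(0xC000) == -2.0  # bfloat16 -2.0
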
6 changes: 4 additions & 2 deletions aten/src/ATen/native/mps/operations/TensorCompare.mm
@@ -276,8 +276,6 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements,
bool invert,
const Tensor& out,
string op_name) {
TORCH_CHECK(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS),
"isin_Tensor_Tensor_out supported on MPS from MacOs_14_0 onwards");
if (elements.numel() == 0) {
return;
}
@@ -295,6 +293,10 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements,

TORCH_CHECK(elements.is_mps() && test_elements.is_mps());
TORCH_CHECK(elements.dtype() == test_elements.dtype());
TORCH_CHECK(
!(!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) && !supportedFloatingType(elements.scalar_type())),
"isin_Tensor_Tensor_out only works on floating types on MPS for pre MacOS_14_0. Received dtype: ",
elements.scalar_type());

@autoreleasepool {
string key =
29 changes: 29 additions & 0 deletions benchmarks/dynamo/common.py
@@ -3485,6 +3485,18 @@ def get_example_inputs(self):
action="store_true",
help="Measure speedup with TorchInductor",
)
group.add_argument(
"--quantization",
choices=[
"int8dynamic",
"int8weightonly",
"int4weightonly",
"autoquant",
"noquant",
],
default=None,
help="Measure speedup of torchao quantization with TorchInductor baseline",
)
group.add_argument(
"--export",
action="store_true",
@@ -3679,6 +3691,9 @@ def run(runner, args, original_dir=None):
if args.inductor:
assert args.backend is None
args.backend = "inductor"
if args.quantization:
assert args.backend is None
args.backend = "torchao"
if args.dynamic_batch_only:
args.dynamic_shapes = True
torch._dynamo.config.assume_static_by_default = True
@@ -3957,6 +3972,20 @@ def run(runner, args, original_dir=None):

# AOTInductor doesn't support control flow yet
runner.skip_models.update(runner.skip_models_due_to_control_flow)
elif args.backend == "torchao":
assert "cuda" in args.devices, "Quantization requires CUDA device."
assert args.bfloat16, "Quantization requires dtype bfloat16."
from .torchao import setup_baseline, torchao_optimize_ctx

setup_baseline()
baseline_ctx = functools.partial(
torch.compile,
backend="inductor",
fullgraph=args.nopython,
mode=args.inductor_compile_mode,
)
runner.model_iter_fn = baseline_ctx(runner.model_iter_fn)
optimize_ctx = torchao_optimize_ctx(args.quantization)
else:
optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
experiment = speedup_experiment