From 892aa3302778af3207eb792ffc043caeb1c7c4da Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Fri, 10 Oct 2025 16:25:57 -0700 Subject: [PATCH] Support aoti_torch_cuda__weight_int4pack_mm Summary: When quantizing a model with 4w_hqq (https://github.com/huggingface/optimum-executorch/pull/164), AOTI-generated code will call aoti_torch_cuda__weight_int4pack_mm as a fallback op. This PR borrows the CUDA implementation of _weight_int4pack_mm_cuda from libtorch, by replacing at::Tensor and relevant utility functions with ET equivalents. Using the Voxtral runner as an example, With the bfloat16 format, here is the generated ptd file size and latency. ``` aoti_cuda_blob.ptd: 9.0 GB Program load latency (ms): 0.054 Method load latency (ms): audio_encoder: 1492.989 token_embedding: 803.561 text_decoder: 6556.770 Run latency (ms): audio_encoder: 76.848 token_embedding: 6.479 text_decoder: 149.128 ``` With `--qlinear 4w_hqq --qlinear_encoder 4w_hqq`, the ptd file size is cut more than half, with slowdowns in the encoder and decoder parts. ``` aoti_cuda_blob.ptd: 3.7 GB Program load latency (ms): 0.051 Method load latency (ms): audio_encoder: 716.667 token_embedding: 633.476 text_decoder: 1840.760 Run latency (ms): audio_encoder: 329.274 token_embedding: 4.285 text_decoder: 335.590 ``` [ghstack-poisoned] --- backends/aoti/common_shims.cpp | 4 + backends/aoti/common_shims.h | 1 + backends/aoti/utils.h | 2 + backends/cuda/CMakeLists.txt | 14 +- backends/cuda/cuda_backend.py | 4 +- backends/cuda/runtime/TARGETS | 3 + backends/cuda/runtime/shims/int4mm.cu | 57 + backends/cuda/runtime/shims/int4mm.cuh | 1311 +++++++++++++++++ backends/cuda/runtime/shims/int4mm.h | 34 + backends/cuda/runtime/shims/tests/targets.bzl | 1 + ...st_aoti_torch_cuda__weight_int4pack_mm.cpp | 342 +++++ backends/cuda/runtime/utils.h | 28 +- 12 files changed, 1796 insertions(+), 5 deletions(-) create mode 100644 backends/cuda/runtime/shims/int4mm.cu create mode 100644 backends/cuda/runtime/shims/int4mm.cuh create mode 100644 backends/cuda/runtime/shims/int4mm.h create mode 100644 backends/cuda/runtime/shims/tests/test_aoti_torch_cuda__weight_int4pack_mm.cpp diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp index abc83779443..5e205f76325 100644 --- a/backends/aoti/common_shims.cpp +++ b/backends/aoti/common_shims.cpp @@ -135,6 +135,10 @@ int32_t aoti_torch_dtype_bfloat16() { return 15; // PyTorch's bfloat16 dtype code } +int32_t aoti_torch_dtype_int32() { + return 3; // PyTorch's int32 dtype code +} + int32_t aoti_torch_dtype_int64() { return 4; // PyTorch's int64 dtype code } diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h index 5f54cd1c878..a50b3691321 100644 --- a/backends/aoti/common_shims.h +++ b/backends/aoti/common_shims.h @@ -59,6 +59,7 @@ int32_t aoti_torch_device_type_cpu(); int32_t aoti_torch_layout_strided(); int32_t aoti_torch_dtype_float32(); int32_t aoti_torch_dtype_bfloat16(); +int32_t aoti_torch_dtype_int32(); int32_t aoti_torch_dtype_int64(); // Autograd mode functions diff --git a/backends/aoti/utils.h b/backends/aoti/utils.h index 78c07bcea6e..fb171e52ffb 100644 --- a/backends/aoti/utils.h +++ b/backends/aoti/utils.h @@ -34,6 +34,8 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) { // Convert based on known PyTorch dtype codes (without CUDA-specific // dependency) switch (dtype) { + case 3: // PyTorch's int32 dtype code + return executorch::aten::ScalarType::Int; case 4: // PyTorch's int64 dtype code return executorch::aten::ScalarType::Long; 
case 6: // PyTorch's float32 dtype code diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index 575f676e4cc..8fea0e3131b 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -38,8 +38,20 @@ find_package_torch() set(_aoti_cuda_sources runtime/cuda_backend.cpp runtime/shims/memory.cpp runtime/shims/tensor_attribute.cpp runtime/guard.cpp - runtime/shims/cuda_guard.cpp + runtime/shims/cuda_guard.cpp runtime/shims/int4mm.cu ) +# Set default CUDA architectures if not specified (int4mm requires sm_80+) +if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + set(CMAKE_CUDA_ARCHITECTURES + "80;86;89;90" + CACHE STRING "CUDA architectures" + ) + message( + STATUS + "CMAKE_CUDA_ARCHITECTURES not set, using default: 80;86;89;90 (Ampere+)" + ) + message(STATUS " Override with: cmake -DCMAKE_CUDA_ARCHITECTURES= ...") +endif() add_library(aoti_cuda STATIC ${_aoti_cuda_sources}) target_include_directories( aoti_cuda diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 8ed8cdefbb1..8bc38c6e715 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -28,7 +28,9 @@ from torch.nn.attention import SDPBackend # exist fallback operators in et namespace; -supported_fallback_kernels: Dict[str, Any] = {} +supported_fallback_kernels: Dict[str, Any] = { + "at::_ops::_weight_int4pack_mm::call": None, +} # required fallback kernels but not supported missing_fallback_kernels: Set[str] = set() diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS index 54412269287..7cb8baf1041 100644 --- a/backends/cuda/runtime/TARGETS +++ b/backends/cuda/runtime/TARGETS @@ -7,12 +7,15 @@ runtime.cxx_library( srcs = [ "guard.cpp", "shims/cuda_guard.cpp", + "shims/int4mm.cu", "shims/memory.cpp", "shims/tensor_attribute.cpp", ], headers = [ "guard.h", "shims/cuda_guard.h", + "shims/int4mm.cuh", + "shims/int4mm.h", "shims/memory.h", "shims/tensor_attribute.h", "utils.h", diff --git a/backends/cuda/runtime/shims/int4mm.cu b/backends/cuda/runtime/shims/int4mm.cu new file mode 100644 index 00000000000..82965e189b4 --- /dev/null +++ b/backends/cuda/runtime/shims/int4mm.cu @@ -0,0 +1,57 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include + +#include +#include +#include +#include + +namespace executorch::backends::cuda { +#ifdef __cplusplus +extern "C" { +#endif + +AOTITorchError aoti_torch_cuda__weight_int4pack_mm( + Tensor* self, + Tensor* mat2, + int64_t qGroupSize, + Tensor* qScaleAndZeros, + Tensor** ret0) { + // Validate input parameters first + ET_CHECK_OR_RETURN_ERROR( + self != nullptr, + InvalidArgument, + "aoti_torch_cuda__weight_int4pack_mm failed: self tensor is null"); + + ET_CHECK_OR_RETURN_ERROR( + mat2 != nullptr, + InvalidArgument, + "aoti_torch_cuda__weight_int4pack_mm failed: mat2 tensor is null"); + + ET_CHECK_OR_RETURN_ERROR( + qScaleAndZeros != nullptr, + InvalidArgument, + "aoti_torch_cuda__weight_int4pack_mm failed: qScaleAndZeros tensor is null"); + + ET_CHECK_OR_RETURN_ERROR( + ret0 != nullptr, + InvalidArgument, + "aoti_torch_cuda__weight_int4pack_mm failed: ret0 is null"); + + *ret0 = _weight_int4pack_mm_cuda(*self, *mat2, qGroupSize, *qScaleAndZeros); + ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR(); + return Error::Ok; +} + +#ifdef __cplusplus +} +#endif +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/int4mm.cuh b/backends/cuda/runtime/shims/int4mm.cuh new file mode 100644 index 00000000000..fcbb32893fd --- /dev/null +++ b/backends/cuda/runtime/shims/int4mm.cuh @@ -0,0 +1,1311 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// This is a clone of aten/src/ATen/native/cuda/int4mm.cu from PyTorch, +// with at::Tensor replaced with ETensor and aten utility functions/macros +// replaced with their executorch equivalents. +// +// In future, we should consider making the PyTorch code generic enough +// to be reusable in executorch. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) +#include +#include +#include +#if !defined(USE_ROCM) +#include +#endif +#endif + +namespace executorch::backends::cuda { +using executorch::backends::aoti::Tensor; + +template +constexpr __host__ __device__ auto divDown(U a, V b) -> decltype(a + b) { + static_assert(std::is_integral_v && std::is_integral_v, ""); + return (a / b); +} + +template +constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { + static_assert(std::is_integral_v && std::is_integral_v, ""); + // Overflow safe variant of (a + b - 1) / b + const uint64_t blocks = a / b + (a % b != 0); + return blocks; +} + +template +constexpr __host__ __device__ auto roundDown(U a, V b) -> decltype(a + b) { + static_assert(std::is_integral_v && std::is_integral_v, ""); + return divDown(a, b) * b; +} + +template +constexpr __host__ __device__ auto roundUp(U a, V b) -> decltype(a + b) { + static_assert(std::is_integral_v && std::is_integral_v, ""); + return divUp(a, b) * b; +} + +template +constexpr __host__ __device__ bool isEvenDivisor(U a, V b) { + static_assert(std::is_integral_v && std::is_integral_v, ""); + return (a % V(b) == 0) && ((a / V(b)) >= 1); +} + +template +constexpr __host__ __device__ T pow(T n, int power) { + return (power > 0 ? 
n * pow(n, power - 1) : 1); +} + +template +constexpr __host__ __device__ T pow2(int power) { + return pow(2, power); +} + +static_assert(pow2(8) == 256, "pow2"); + +template +constexpr __host__ __device__ int log2(T n, int p = 0) { + return (n <= 1) ? p : log2(n / 2, p + 1); +} + +static_assert(log2(2) == 1, "log2"); +static_assert(log2(3) == 1, "log2"); +static_assert(log2(4) == 2, "log2"); + +template +constexpr __host__ __device__ bool isPowerOf2(T v) { + static_assert(std::is_integral_v, ""); + return (v && !(v & (v - 1))); +} + +static_assert(isPowerOf2(2048), "isPowerOf2"); +static_assert(!isPowerOf2(3333), "isPowerOf2"); + +template +constexpr __host__ __device__ T nextHighestPowerOf2(T v) { + static_assert(std::is_integral_v, ""); + return (isPowerOf2(v) ? (T)2 * v : ((T)1 << (log2(v) + 1))); +} + +static_assert(nextHighestPowerOf2(1) == 2, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(2) == 4, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(3) == 4, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(4) == 8, "nextHighestPowerOf2"); + +static_assert(nextHighestPowerOf2(15) == 16, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(16) == 32, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(17) == 32, "nextHighestPowerOf2"); + +static_assert( + nextHighestPowerOf2(1536000000u) == 2147483648u, + "nextHighestPowerOf2"); +static_assert( + nextHighestPowerOf2((size_t)2147483648ULL) == (size_t)4294967296ULL, + "nextHighestPowerOf2"); + +template +constexpr __host__ __device__ T nextLowestPowerOf2(T v) { + static_assert(std::is_integral_v, ""); + return (isPowerOf2(v) ? v / (T)2 : ((T)1 << (log2(v)))); +} + +static_assert(nextLowestPowerOf2(1) == 0, "nextLowestPowerOf2"); +static_assert(nextLowestPowerOf2(2) == 1, "nextLowestPowerOf2"); +static_assert(nextLowestPowerOf2(3) == 2, "nextLowestPowerOf2"); +static_assert(nextLowestPowerOf2(4) == 2, "nextLowestPowerOf2"); + +static_assert(nextLowestPowerOf2(15) == 8, "nextLowestPowerOf2"); +static_assert(nextLowestPowerOf2(16) == 8, "nextLowestPowerOf2"); +static_assert(nextLowestPowerOf2(17) == 16, "nextLowestPowerOf2"); + +inline __host__ __device__ bool isPointerAligned(const void* p, int align) { + return reinterpret_cast(p) % align == 0; +} + +// Returns the increment needed to aligned the pointer to the next highest +// aligned address +template +inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) { + static_assert(isPowerOf2(Align), ""); + const uint32_t diff = uint32_t(uintptr_t(p) & uintptr_t(Align - 1)); + return diff == 0 ? 
0 : uint32_t(Align) - diff; +} + +#if defined (__gfx90a__) || defined(__gfx942__) +#define CDNA2_OR_LATER 1 +#else +#define CDNA2_OR_LATER 0 +#endif + +#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) + +#if defined(USE_ROCM) +// TODO: Support RDNA +constexpr int32_t kWarpSize = 64; + +template +using VecT = T __attribute__((ext_vector_type(Rank))); + +/* + * Not used by ET +static bool isCDNA2orLater(int index) { + return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx942"}, index); +} +*/ + +#else +constexpr int32_t kWarpSize = 32; +#endif + +// f16 vector types +struct __align__(2) f16x1 { + __half vals[1]; +}; + +struct __align__(4) f16x2 { + __half vals[2]; +}; + +struct __align__(8) f16x4 { + __half vals[4]; +}; + +struct __align__(16) f16x8 { + __half vals[8]; +}; + +// bf16 vector types +struct __align__(2) bf16x1 { + __nv_bfloat16 vals[1]; +}; + +struct __align__(4) bf16x2 { + __nv_bfloat16 vals[2]; +}; + +struct __align__(8) bf16x4 { + __nv_bfloat16 vals[4]; +}; + +struct __align__(16) bf16x8 { + __nv_bfloat16 vals[8]; +}; + +// bf162 vector types +struct __align__(4) bf16x2x1 { + __nv_bfloat162 vals[1]; +}; + +struct __align__(8) bf16x2x2 { + __nv_bfloat162 vals[2]; +}; + +struct __align__(16) bf16x2x4 { + __nv_bfloat162 vals[4]; +}; + +struct __align__(16) bf16x2x4_u32 { +#if defined(USE_ROCM) + VecT val[2]; +#else + uint32_t vals[4]; +#endif +}; + +struct __align__(8) bf16x2x2_u32 { +#if defined(USE_ROCM) + VecT val; +#else + uint32_t vals[2]; +#endif +}; + +struct __align__(4) bf16x2x1_u32 { + uint32_t vals[1]; +}; + +template +struct __align__(sizeof(T) * N) VectorType { + T vals[N]; +}; + +// from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { + bf16x2x4 result; + constexpr int kElements = 8; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = source; + + // First, we extract the i4s and construct an intermediate fp16 number. +#if !defined(USE_ROCM) + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; +#endif + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so + // we must loop. No shift needed for first item. + uint32_t i4s = source_i4s; + +#if defined(USE_ROCM) + asm volatile("v_and_or_b32 %0, %1, %2, %3" + : "=v"(h[0]) + : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM)); +#else + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); +#endif + +#pragma unroll + for (int ii = 1; ii < kElements / 2; ++ii) { + i4s >>= 4; // or is it 8? + // (i4s & 0x000f000f) | 0x43004300 +#if defined(USE_ROCM) + asm volatile("v_and_or_b32 %0, %1, %2, %3" + : "=v"(h[ii]) + : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM)); +#else + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); +#endif + } + + // This is the BF16 {-136, -136} represented as an integer. 
+#if defined(USE_ROCM) +#if ROCM_VERSION >= 60200 + auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308})); + auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80})); +#else + auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16{0xC308}); + auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16{0x3F80}); +#endif +#else + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; +#endif + +// Finally, we construct the output numbers. +#pragma unroll + for (int ii = 0; ii < kElements / 2; ++ii) { + // Since this section is for Ampere+, we use bf16 fma to do the bias + // subtraction +#if defined(USE_ROCM) + result.vals[ii] = __hfma2(result.vals[ii], BF16_ONE, BF16_BIAS); +#else + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[ii]) + : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); +#endif + } + + return result; +} + + + +enum class KReductionType { + // No k-reduction is needed between blocks as the number of k-tiles processed + // per block are exact and we can directly write the output + None, +}; + +// Loads the A matrix in 16-bit standard m x k row major layout, and writes +// the C matrix in 16-bit standard m x n row major layout: +// +// size [m][k] +template +struct ALayout_RM { + static constexpr int32_t kMTileSize = 16; +#if defined(USE_ROCM) + static constexpr int32_t kNTileSize = 16; +#else + static constexpr int32_t kNTileSize = 8; +#endif + static constexpr int32_t kKTileSize = 16; + + template + static __device__ void load( + const void* A, + int32_t m, + int32_t k, + int32_t mTiles, + int32_t mTile, + int32_t kTiles, + int32_t kTileStart, + int32_t laneId, +#if defined(USE_ROCM) + bf16x2x2_u32 out[KTilesToLoad] +#else + bf16x2x4_u32 out[KTilesToLoad] +#endif + ) { +#if defined(USE_ROCM) + const auto mLane = mTile * kMTileSize + (laneId % kMTileSize); + const auto kLane = kTileStart * kKTileSize + (laneId / kMTileSize) * 4; +#else + const auto mLane = mTile * kMTileSize + (laneId / 4); + const auto kLane = kTileStart * kKTileSize + (laneId % 4) * 2; +#endif + + // access + // [mTile * kMTileSize + (laneId / 4)] + // [kTileStart * kKTileSize + (laneId % 4) * 2] + auto aPtr = reinterpret_cast(A) + mLane * k + kLane; + bool m0InBounds = mLane < m; + +#if !defined(USE_ROCM) + auto aPtrPlus8Rows = aPtr + 8 * k; + + bool m1InBounds = (mLane + 8) < m; +#endif + +#pragma unroll + for (int i = 0; i < KTilesToLoad; ++i) { +#if defined(USE_ROCM) + out[i].val = m0InBounds ? *((VecT *)(aPtr + i * kKTileSize)) : VecT{0, 0, 0, 0}; +#else + out[i].vals[0] = m0InBounds + ? *reinterpret_cast(aPtr + i * kKTileSize) + : uint32_t(0); + out[i].vals[1] = m1InBounds + ? *reinterpret_cast(aPtrPlus8Rows + i * kKTileSize) + : uint32_t(0); + + out[i].vals[2] = m0InBounds + ? *reinterpret_cast(aPtr + i * kKTileSize + 8) + : uint32_t(0); + out[i].vals[3] = m1InBounds ? 
*reinterpret_cast( + aPtrPlus8Rows + i * kKTileSize + 8) + : uint32_t(0); +#endif + } + } + + static __device__ void store( + void* C, + int32_t m, + int32_t n, + int32_t mOutTiles, + int32_t mTile, + int32_t nOutTiles, + int32_t nTile, + int32_t laneId, + const float4& out) { + static_assert(ReduceType == KReductionType::None, ""); + + if constexpr (ReduceType == KReductionType::None) { +#if defined(USE_ROCM) + const int outRow = mTile * kMTileSize + (laneId / kNTileSize) * 4; + const int outCol = nTile * kNTileSize + (laneId % kNTileSize); +#else + // sum.x / sum.y are written at + // [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1] + // sum.z / sum.w are written at + // [8 + (laneId / 4)], [(laneId % 4) * 2, (laneId % 4) * 2 + 1] + // i.e., same columns, different row. + const int outRow = mTile * kMTileSize + (laneId / 4); + const int outCol = nTile * kNTileSize + (laneId % 4) * 2; +#endif + + // Pointer where sum.x / sum.y is written + auto cPtr = reinterpret_cast<__nv_bfloat16*>(C) + outRow * n + outCol; + +#if defined(USE_ROCM) + if (outRow < m) + cPtr[0] = __float2bfloat16(out.x); + if ((outRow + 1) < m) + cPtr[n] = __float2bfloat16(out.y); + if ((outRow + 2) < m) + cPtr[2*n] = __float2bfloat16(out.z); + if ((outRow + 3) < m) + cPtr[3*n] = __float2bfloat16(out.w); +#else + auto v01 = __float22bfloat162_rn(float2{out.x, out.y}); + auto v23 = __float22bfloat162_rn(float2{out.z, out.w}); + + if (outRow < m) { + *reinterpret_cast<__nv_bfloat162*>(cPtr) = v01; + } + + // sum.z, sum.w at +8 rows from cPtr + if (outRow + 8 < m) { + *reinterpret_cast<__nv_bfloat162*>(cPtr + 8 * n) = v23; + } +#endif + } + } +}; + +template +struct BLayout_TC_int4 { + static constexpr int32_t kInnerKTiles = InnerKTiles; + static constexpr int32_t kMTileSize = 16; +#if defined(USE_ROCM) + static constexpr int32_t kNTileSize = 16; +#else + static constexpr int32_t kNTileSize = 8; +#endif + static constexpr int32_t kKTileSize = 16; + + template + static __device__ void load( + // type uint32, size [n / 8][k / (InnerKTiles * 16)][32][InnerKTiles / 2] + // n-tiles: n / 8 for NV, n /16 for AMD + // k / (InnerKTiles * 16): TC size per k-tile is 16 (m16n8k16 for NV, m16n16k16 for AMD) + // value per warp lane: 32 for NV, 64 for AMD + // (InnerKTiles / 2): B layout has 4 values per lane (16 bits) per k-tile. 
+ // 2 k-tiles packed is a uint32 (hence InnerKTiles == 2 is our smallest + // value) 4 k-tiles packed is a uint32x2 (64 bits) 8 k-tiles packed is a + // uint32x4 (128 bits) + const void* __restrict__ B, + // size [k / qGroupSize][n][2] + // Contains the scale and zero point of each of the quantized int4 values + // within B + // v_reconstructed = (bf16(B_int4_val) * scale) - zero + const void* __restrict__ quantizationInfo, + int32_t n, + int32_t k, + int32_t nTiles, + int32_t nTile, + int32_t kTiles, + int32_t kTileStart, + int32_t laneId, + bf16x2x4_u32 out[KTilesToLoad / InnerKTiles][InnerKTiles / 2]) { + // offset [nTile][kTileStart / InnerKTiles][laneId][0] + auto bPtr = reinterpret_cast(B) + + (((nTile * (kTiles / InnerKTiles) + (kTileStart / InnerKTiles)) * + kWarpSize) + + laneId) * + (InnerKTiles / 2); + + int32_t b_int4[KTilesToLoad / InnerKTiles][InnerKTiles / 2]; + +#pragma unroll + for (int i = 0; i < KTilesToLoad / InnerKTiles; ++i) { + auto bPtrCur = bPtr + i * kWarpSize * (InnerKTiles / 2); + + if constexpr (InnerKTiles == 2) { + b_int4[i][0] = bPtrCur[0]; + } + + if constexpr (InnerKTiles == 4) { + // asm volatile("ld.global.cs.v2.u32 {%0, %1}, [%2];\n" + // : "=r"(b_int4[i][0]), "=r"(b_int4[i][1]) + // : "l"(bPtrCur)); + + int2 load8 = reinterpret_cast(bPtrCur)[0]; + b_int4[i][0] = load8.x; + b_int4[i][1] = load8.y; + } + + if constexpr (InnerKTiles == 8) { + // asm volatile("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];\n" + // : "=r"(b_int4[i][0]), "=r"(b_int4[i][1]), + // "=r"(b_int4[i][2]), "=r"(b_int4[i][3]) : "l"(bPtrCur)); + + int4 load16 = reinterpret_cast(bPtrCur)[0]; + b_int4[i][0] = load16.x; + b_int4[i][1] = load16.y; + b_int4[i][2] = load16.z; + b_int4[i][3] = load16.w; + } + } + + // Load needed info for dequantization + + static_assert(isPowerOf2(QGroupSize), ""); + static_assert(isEvenDivisor(QGroupSize, kKTileSize), ""); + // smallest quantization group size is 32 (2 k-tiles are packed in an int32) + static_assert(QGroupSize >= kKTileSize * 2, ""); + constexpr int kKTilesPerQGroup = (QGroupSize / kKTileSize); + // a q-group could be larger than what we are handling in a single warp + constexpr int kNumQGroups = (KTilesToLoad / kKTilesPerQGroup) < 1 + ? 1 + : (KTilesToLoad / kKTilesPerQGroup); + + __nv_bfloat162 qScaleAndZero[kNumQGroups]; + { +#if defined(USE_ROCM) + int32_t laneN = nTile * kNTileSize + (laneId % kNTileSize); +#else + int32_t laneN = nTile * kNTileSize + (laneId / 4); +#endif + int32_t groupStart = (kTileStart * kKTileSize) / QGroupSize; + + int32_t n = nTiles * kNTileSize; + + // offset [qScale_kGroup][qScale_n][0] + auto qInfoPtr = reinterpret_cast(quantizationInfo) + + (groupStart * n + laneN) * 2; + +#pragma unroll + for (int i = 0; i < kNumQGroups; ++i) { + qScaleAndZero[i] = + *reinterpret_cast(qInfoPtr + i * n * 2); + } + } + + // + // De-quantize int4 values to bf16. Values are dequantized as truly int4 + // [-8, 7] range; dequant = (bf16(int4_value) * bf16_scale) + bf16_zero + // + { + // FIXME: does this negatively affect register counts, or will nvcc + // move this expansion (and data loads above) closer to the point of use? 
+ __nv_bfloat162 qScale[kNumQGroups]; + __nv_bfloat162 qZero[kNumQGroups]; + +#pragma unroll + for (int i = 0; i < kNumQGroups; ++i) { + qScale[i] = __bfloat162bfloat162(qScaleAndZero[i].x); + qZero[i] = __bfloat162bfloat162(qScaleAndZero[i].y); + } + +#pragma unroll + for (int i = 0; i < KTilesToLoad / InnerKTiles; ++i) { +#pragma unroll + for (int j = 0; j < InnerKTiles / 2; ++j) { + bf16x2x4 v = convert_i4x8_to_bf16x2x4(b_int4[i][j]); + + int curKTile = i * InnerKTiles + j * 2; + int curQGroup = (curKTile * kKTileSize) / QGroupSize; + + // The dequantized values in `v` for a given lane have the same n + // dimension (the B tensor core layout has all values in the same + // thread along the same n) but different k dimension, but all are + // guaranteed to occur within the same quantization group, so we need + // only load a single scale + zero to cover what this lane has +#pragma unroll + for (int k = 0; k < 4; ++k) { + v.vals[k] = __hfma2(v.vals[k], qScale[curQGroup], qZero[curQGroup]); + } + + // type pun, the __nv_bfloat162 value in bf16x2x4 is a struct and + // can't be used as a 32-bit asm register argument for `mma` + static_assert(sizeof(bf16x2x4) == sizeof(out[0][0]), ""); + std::memcpy(&out[i][j], &v, sizeof(bf16x2x4_u32)); + } + } + } + } +}; + +template < + typename ALayout, + typename BLayout, + typename CLayout, + int Warps, + int KTilesPerIteration> +__global__ +__launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( + // Data for the A matrix, loaded as per ALayout + const void* const __restrict__ A, + + // Data for the B matrix, loaded as per BLayout + const void* const __restrict__ B, + + // Optional quantization data for dequantizing B, loaded as per BLayout + const void* const __restrict__ B_quantizationInfo, + + // Output data for the C matrix, stored as per CLayout + void* __restrict__ C, + + // The size of the matrix multiplication + int32_t m, + int32_t n, + int32_t k, + + // The size of the matrix multiplication, in multiples of our TC tile size + int32_t mTiles, + int32_t nTiles, + int32_t kTiles) { + constexpr int32_t kMTileSize = 16; +#if defined(USE_ROCM) + constexpr int32_t kNTileSize = 16; +#else + constexpr int32_t kNTileSize = 8; +#endif + constexpr int32_t kKTileSize = 16; + +#if !defined(USE_ROCM) || CDNA2_OR_LATER + + static_assert( + ALayout::kMTileSize == kMTileSize && ALayout::kNTileSize == kNTileSize && + ALayout::kKTileSize == kKTileSize, + ""); + + static_assert( + BLayout::kMTileSize == kMTileSize && BLayout::kNTileSize == kNTileSize && + BLayout::kKTileSize == kKTileSize, + ""); + + static_assert( + CLayout::kMTileSize == kMTileSize && CLayout::kNTileSize == kNTileSize && + CLayout::kKTileSize == kKTileSize, + ""); + + constexpr int kInnerKTiles = BLayout::kInnerKTiles; + + // 2/4/8 inner k-tiles correspond to 4, 8 and 16 byte innermost loads + static_assert( + kInnerKTiles == 2 || kInnerKTiles == 4 || kInnerKTiles == 8, ""); + + // We always process at least kInnerKTiles k-tiles back to back in a warp + static_assert( + KTilesPerIteration >= kInnerKTiles && + isEvenDivisor(KTilesPerIteration, kInnerKTiles), + ""); + + auto warpId = threadIdx.y; + auto laneId = threadIdx.x; + + int32_t mTile = blockIdx.z; + int32_t nTile = blockIdx.y; + +#if defined(USE_ROCM) + VecT c{0.0f, 0.0f, 0.0f, 0.0f}; +#else + float4 c{0.0f, 0.0f, 0.0f, 0.0f}; +#endif + + // First, handle whole multiples of KTilesPerIteration + auto kTilesLimit = roundDown(kTiles, KTilesPerIteration); + + // Each warp handles a set of KTilesPerIteration under the above 
limit + for (int32_t kTileBase = (blockIdx.x * Warps + warpId) * KTilesPerIteration; + kTileBase < kTilesLimit; + kTileBase += Warps * KTilesPerIteration) { + // + // Load data from A + // +#if defined(USE_ROCM) + bf16x2x2_u32 a[KTilesPerIteration]; +#else + bf16x2x4_u32 a[KTilesPerIteration]; +#endif + ALayout::template load( + A, m, k, mTiles, mTile, kTiles, kTileBase, laneId, a); + + // + // Load data from B and de-quantize as needed + // Each k-tile is bf16x2x2 + // + bf16x2x4_u32 b[KTilesPerIteration / kInnerKTiles][kInnerKTiles / 2]; + BLayout::template load( + B, + B_quantizationInfo, + n, + k, + nTiles, + nTile, + kTiles, + kTileBase, + laneId, + b); + + // + // Now, perform the matrix multiplication + // + + // We accumulate across k-tiles here +#pragma unroll + for (int i = 0; i < KTilesPerIteration / kInnerKTiles; ++i) { + static_assert(isEvenDivisor(kInnerKTiles, 2) && kInnerKTiles >= 2, ""); +#pragma unroll + for (int j = 0; j < kInnerKTiles / 2; ++j) { + // We don't simply accumulate into `c` as this creates a too-strong + // execution dependency. Instead, we only periodically accumulate into + // `c` +#if defined(USE_ROCM) + VecT cTmp[2]; +#else + float4 cTmp[2]; +#endif + +#pragma unroll + for (int k = 0; k < 2; ++k) { +#if defined(USE_ROCM) + cTmp[k] = VecT{0.0f, 0.0f, 0.0f, 0.0f}; +#else + cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f}; +#endif + } + +#pragma unroll + for (int k = 0; k < 2; ++k) { +#if defined(USE_ROCM) + cTmp[k] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( + a[i * kInnerKTiles + j * 2 + k].val, + b[i][(j * 2 + k) / 2].val[((j * 2 + k) % 2)], + cTmp[k], 0, 0, 0); +#else + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};" + : "=f"(cTmp[k].x), + "=f"(cTmp[k].y), + "=f"(cTmp[k].z), + "=f"(cTmp[k].w) + : "r"(a[i * kInnerKTiles + j * 2 + k].vals[0]), + "r"(a[i * kInnerKTiles + j * 2 + k].vals[1]), + "r"(a[i * kInnerKTiles + j * 2 + k].vals[2]), + "r"(a[i * kInnerKTiles + j * 2 + k].vals[3]), + "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]), + "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]), + "f"(cTmp[k].x), + "f"(cTmp[k].y), + "f"(cTmp[k].z), + "f"(cTmp[k].w)); +#endif + } + +#pragma unroll + for (int k = 0; k < 2; ++k) { +#if defined(USE_ROCM) + c[0] += cTmp[k][0]; + c[1] += cTmp[k][1]; + c[2] += cTmp[k][2]; + c[3] += cTmp[k][3]; +#else + c.x += cTmp[k].x; + c.y += cTmp[k].y; + c.z += cTmp[k].z; + c.w += cTmp[k].w; +#endif + } + } + } + } // for all tiles under kTilesLimit + + // Now, there could be a remainder of 1 to KTilesPerIteration - 1 k-tiles + // remaining. 
We guarantee that the number of warps is >= KTilesPerIteration / + // kInnerKTiles, so that each warp can simply load kInnerKTiles and do its + // thing without needing more warps + static_assert(Warps >= KTilesPerIteration / kInnerKTiles, ""); + + auto kTileBaseRemaining = kTilesLimit + warpId * kInnerKTiles; + + // If we have any remainder k-tiles, some warps will handle them, processing + // kInnerKTiles k-tiles at a time + if (kTileBaseRemaining < kTiles) { +#if defined(USE_ROCM) + bf16x2x2_u32 a[kInnerKTiles]; +#else + bf16x2x4_u32 a[kInnerKTiles]; +#endif + ALayout::template load( + A, m, k, mTiles, mTile, kTiles, kTileBaseRemaining, laneId, a); + + bf16x2x4_u32 b[1][kInnerKTiles / 2]; + BLayout::template load( + B, + B_quantizationInfo, + n, + k, + nTiles, + nTile, + kTiles, + kTileBaseRemaining, + laneId, + b); + +#pragma unroll + for (int j = 0; j < kInnerKTiles / 2; ++j) { + // We don't simply accumulate into `c` as this creates a too-strong + // execution dependency. Instead, we only periodically accumulate into + // `c` +#if defined(USE_ROCM) + VecT cTmp[2]; +#else + float4 cTmp[2]; +#endif + +#pragma unroll + for (int k = 0; k < 2; ++k) { +#if defined(USE_ROCM) + cTmp[k] = VecT{0.0f, 0.0f, 0.0f, 0.0f}; +#else + cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f}; +#endif + } + +#pragma unroll + for (int k = 0; k < 2; ++k) { +#if defined(USE_ROCM) + cTmp[k] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( + a[j * 2 + k].val, + b[0][(j * 2 + k) / 2].val[((j * 2 + k) % 2)], + cTmp[k], 0, 0, 0); +#else + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};" + : "=f"(cTmp[k].x), "=f"(cTmp[k].y), "=f"(cTmp[k].z), "=f"(cTmp[k].w) + : "r"(a[j * 2 + k].vals[0]), + "r"(a[j * 2 + k].vals[1]), + "r"(a[j * 2 + k].vals[2]), + "r"(a[j * 2 + k].vals[3]), + "r"(b[0][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]), + "r"(b[0][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]), + "f"(cTmp[k].x), + "f"(cTmp[k].y), + "f"(cTmp[k].z), + "f"(cTmp[k].w)); +#endif + } + +#pragma unroll + for (int k = 0; k < 2; ++k) { +#if defined(USE_ROCM) + c[0] += cTmp[k][0]; + c[1] += cTmp[k][1]; + c[2] += cTmp[k][2]; + c[3] += cTmp[k][3]; +#else + c.x += cTmp[k].x; + c.y += cTmp[k].y; + c.z += cTmp[k].z; + c.w += cTmp[k].w; +#endif + } + } + } + + // + // Reduce independent k-tiles (same m/n) across warps + // + __shared__ float4 smem_sum[Warps][kWarpSize]; + + // FIXME: this likely doesn't need to be a true reduction tree, can just be a + // serial sum, maybe (unless nvcc/ptxas goes back to its old ways) + // smem_sum[warpId][laneId] = TreeReduce4::reduce(c); +#if defined(USE_ROCM) + smem_sum[warpId][laneId].x = c[0]; + smem_sum[warpId][laneId].y = c[1]; + smem_sum[warpId][laneId].z = c[2]; + smem_sum[warpId][laneId].w = c[3]; +#else + smem_sum[warpId][laneId] = c; +#endif + + __syncthreads(); + + if (warpId == 0) { + float4 sum_f32{0.0f, 0.0f, 0.0f, 0.0f}; + + // Reduce across the block in the first warp + for (int i = 0; i < Warps; ++i) { + float4 v = smem_sum[i][laneId]; + sum_f32.x += v.x; + sum_f32.y += v.y; + sum_f32.z += v.z; + sum_f32.w += v.w; + } + + // Write the reduced result (in the first warp) into the output + CLayout::store( + C, + m, + n, + mTiles, + mTile, + // n for C output becomes k for A input, so for m16n8k16, + // we need to halve the tiles + nTiles / 2, + nTile, + laneId, + sum_f32); + } +#else + printf("__builtin_amdgcn_mfma_f32_16x16x16bf16_1k is only supported on AMD gpu arch greater than or equal to CDNA2\n"); +#endif +} 
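+
+// Worked example (illustrative) of the tile arithmetic consumed by
+// launch_tinygemm_kernel and _weight_int4pack_mm_cuda below, using the
+// smallest shapes exercised by the unit tests in this patch
+// (m = 2, k = 128, n = 64, qGroupSize = 128, innerKTiles = 8) with the
+// NVIDIA layout (kNTileSize == 8):
+//   mTiles = divUp(m, kMTileSize) = divUp(2, 16)   = 1
+//   kTiles = divUp(k, kKTileSize) = divUp(128, 16) = 8
+//   nTiles = n / kNTileSize       = 64 / 8         = 8
+//   B (int32) has shape [n / 8][k / (innerKTiles * 16)][32][innerKTiles / 2]
+//     = [8][1][32][4]; qScaleAndZeros (bf16) has shape [k / qGroupSize][n][2]
+//     = [1][64][2].
+static_assert(divUp(2, 16) == 1, "tile arithmetic example");
+static_assert(divUp(128, 16) == 8, "tile arithmetic example");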
+ +template < + typename ALayout, + typename BLayout, + typename CLayout, + int Warps, + int KTilesPerWarp> +void launch_tinygemm_kernel( + const Tensor& A, + const Tensor& B, + const Tensor* qScaleAndZeros, /* optional */ + Tensor& C_final, + int32_t mTiles, + int32_t nTiles, + int32_t kTiles, + int32_t m, + int32_t n, + int32_t k, + cudaStream_t stream) { + // The chunking kernel requires that kTiles is a multiple of kInnerKTiles + ET_CHECK( + kTiles >= BLayout::kInnerKTiles && + isEvenDivisor(kTiles, BLayout::kInnerKTiles)); + + ET_CHECK( + KTilesPerWarp >= BLayout::kInnerKTiles && + isEvenDivisor(KTilesPerWarp, BLayout::kInnerKTiles)); + + // After intra-block reduction across the k dimension, we are left with this + // many tiles + // int32_t postKernelKTiles = kTiles / (Warps * KTilesPerWarp); + int32_t postKernelKTiles = 1; // we loop + + auto grid = dim3(postKernelKTiles, nTiles, mTiles); + auto block = dim3(kWarpSize, Warps); + + auto func = + tinygemm_m16n8k16_chunk_kernel; + + func<<>>( + A.data_ptr(), + B.data_ptr(), + qScaleAndZeros ? qScaleAndZeros->data_ptr() : nullptr, + C_final.data_ptr(), + m, + n, + k, + mTiles, + nTiles, + kTiles); + + ET_CUDA_KERNEL_LAUNCH_CHECK(); + + cudaFuncAttributes funcAttr; +#if defined(USE_ROCM) + ET_CUDA_CHECK(cudaFuncGetAttributes(&funcAttr, (void *)func)); +#else + ET_CUDA_CHECK(cudaFuncGetAttributes(&funcAttr, func)); +#endif +} + +/* + * Not used by ET +// FIXME: parallelize better, smem staging etc? +template +__global__ void matrix_to_m16n8k16_Bint4_layout( + // size [n][k / 2] + const at::PackedTensorAccessor32 in, + // size [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2] + at::PackedTensorAccessor32 out) { + // int4 values are packed into int32 values, which require at least 8. Given + // m16n8k16 B layout requires 4 scalar values/lane, the minimum number of + // innermost k-tiles that we can use is 2. + static_assert(InnerKTiles >= 2 && isPowerOf2(InnerKTiles), ""); + +#if defined(USE_ROCM) + constexpr int32_t kNTileSize = 16; +#else + constexpr int32_t kNTileSize = 8; +#endif + constexpr int32_t kKTileSize = 16; + + // gridDim.x corresponds to the number of k-tiles divided by InnerKTiles + auto kOuterTile = blockIdx.x; + auto nTile = blockIdx.y; + auto t = threadIdx.x; + + // Two k-tiles are packed into an int32 at a time +#pragma unroll + for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) { + // n dimension that this lane loads from +#if defined(USE_ROCM) + auto n0 = nTile * kNTileSize + (t % kNTileSize); +#else + auto n0 = nTile * kNTileSize + (t / 4); +#endif + + bool n0Valid = n0 < in.size(0); + + // Four uint8 are packed into an int32 + int32_t ks[4]; + + auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize / 2; + +#if defined(USE_ROCM) + ks[0] = kBase0 + (t / kNTileSize) * 2; + ks[1] = ks[0] + 1; + + auto kBase1 = kBase0 + kKTileSize / 2; + ks[2] = kBase1 + (t / kNTileSize) * 2; + ks[3] = ks[2] + 1; +#else + ks[0] = kBase0 + t % 4; + ks[1] = ks[0] + 4; + + auto kBase1 = kBase0 + kKTileSize / 2; + ks[2] = kBase1 + t % 4; + ks[3] = ks[2] + 4; +#endif + + auto pIn = &in[n0][0]; + + uint8_t v[4]; +#pragma unroll + for (int i = 0; i < 4; ++i) { + v[i] = (n0Valid && ks[i] < in.size(1)) ? 
pIn[ks[i]] : uint8_t(0); + } + + // To clearly explain the packed result with 8 int4 values (4 uint8) + // into one int32, we use the follow figure: + // [n][k] int32: v[0] v[1] v[2] v[3] v[4] v[5] v[6] v[7] + // [n][k / 2] uint8: v[0] v[1] v[2] v[3] + // When using int32 weight as input, the packed result is consisted of + // v[7] | v[5] | v[3] | v[1] | v[6] | v[4] | v[2] | v[0], + // which epuals to + // v[3]L | v[2]L | v[1]L | v[0]L | v[3]H | v[2]H | v[1]H | v[0]H + // when using uint8 weight as input. + int32_t pack = ((uint32_t)(v[3] & 0xF) << 28) | + ((uint32_t)(v[2] & 0xF) << 24) | ((uint32_t)(v[1] & 0xF) << 20) | + ((uint32_t)(v[0] & 0xF) << 16) | ((uint32_t)(v[3] & 0xF0) << 8) | + ((uint32_t)(v[2] & 0xF0) << 4) | ((uint32_t)(v[1] & 0xF0)) | + ((uint32_t)(v[0] & 0xF0) >> 4); + + // inner k-tiles pack two at a time +#if defined(USE_ROCM) + // The output tensor shape is [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2], which is specific to Nvidia + // But AMD needs [ceil(n / 16)][ceil(k / (InnerKTiles * 16))][64][InnerKTiles / 2] + // So construct the pointer accordingly + auto bPtr = out.data() + + ((nTile * out.size(1) * kWarpSize * (InnerKTiles / 2)) + + (kOuterTile * kWarpSize * (InnerKTiles / 2)) + + (t * (InnerKTiles / 2)) + + (innerKTile / 2)); + *bPtr = pack; +#else + out[nTile][kOuterTile][t][innerKTile / 2] = pack; +#endif + } +} +*/ + +#endif // defined(USE_ROCM) || CUDA_VERSION >= 12000 + + +Tensor* _weight_int4pack_mm_cuda( + const Tensor& A, + const Tensor& B, + int64_t qGroupSize, + const Tensor& qScaleAndZeros) { + // Skip CUDAGuard because ETensor doesn't carry device information + // auto result = CUDAGuard::create(0); + + // Skip device check because ETensor doesn't carry device information + // ET_CHECK( + // A.device() == B.device() && A.device() == qScaleAndZeros.device()); + +#if defined(USE_ROCM) + if (!isCDNA2orLater(A.device().index())) { + ET_CHECK(false, "_weight_int4pack_mm_cuda is only supported on AMD gpu arch greater than or equal to CDNA2"); + } +#endif + + constexpr int32_t kMTileSize = 16; +#if defined(USE_ROCM) + constexpr int32_t kNTileSize = 16; +#else + constexpr int32_t kNTileSize = 8; +#endif + constexpr int32_t kKTileSize = 16; + + // row major layout + auto m = A.size(0); + auto mTiles = divUp(m, kMTileSize); + + // To convert the nTiles from tensor storage layout to the actual matrix core layout + constexpr int32_t kNTileSizeTensor = 8; + auto nTileScaleFactor = (kNTileSize / kNTileSizeTensor); + + // tensor core layout + auto nTiles = (B.size(0) / nTileScaleFactor); + auto n = nTiles * kNTileSize; + + // row major layout + auto k = A.size(1); + auto kTiles = divUp(k, kKTileSize); + + // The number of inner k tiles is the innermost dimension of times 2 + // 2 k-tiles (4 values per lane per tile, 8 values total) quantized to int4 + // packed into 1 int32 for int4 B + auto B_innerKTiles = B.size(3) * 2; + ET_CHECK(B_innerKTiles == 2 || B_innerKTiles == 4 || B_innerKTiles == 8); + + // A is standard row major + ET_CHECK(A.dtype() == executorch::aten::ScalarType::BFloat16); + // ET only supports contiguous tensors for now + // ET_CHECK(A.is_contiguous()); + ET_CHECK(A.dim() == 2); + + // B has B_innerKTiles k-tiles in the innermost dimension + ET_CHECK(B.dtype() == executorch::aten::ScalarType::Int); + // ET only supports contiguous tensors for now + // ET_CHECK(B.is_contiguous()); + ET_CHECK(B.dim() == 4); + ET_CHECK(B.size(1) == k / (B_innerKTiles * kKTileSize)); + ET_CHECK(B.size(2) == 32); + + // Validate the scale and zero 
point tensor for dequantization + // These are the only versions handled at the moment + ET_CHECK( + qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 || + qGroupSize == 256); + + ET_CHECK(qScaleAndZeros.dim() == 3); + auto numQGroups = qScaleAndZeros.size(0); + ET_CHECK( + kTiles * kKTileSize >= qGroupSize && + isEvenDivisor(kTiles * kKTileSize, qGroupSize)); + ET_CHECK(qScaleAndZeros.size(1) == n); + ET_CHECK(qScaleAndZeros.size(2) == 2); + + // Output is a standard row-major matrix + Tensor* C_final = nullptr; + std::array shape = {m, n}; + std::array stride = {n, 1}; + aoti_torch_empty_strided( + 2, + shape.data(), + stride.data(), + static_cast(SupportedDTypes::BFLOAT16), + static_cast(SupportedDevices::CUDA), + 0, + &C_final + ); + +#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) + auto stream_result = getCurrentCUDAStream(0); + ET_CHECK_MSG(stream_result.ok(), "Failed to get CUDA stream"); + cudaStream_t stream = stream_result.get(); +#define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \ + do { \ + using ACLayout = ALayout_RM; \ + \ + ET_CHECK( \ + K_TILES_PER_WARP >= B_innerKTiles && \ + isEvenDivisor(K_TILES_PER_WARP, B_innerKTiles)); \ + \ + switch (B_innerKTiles) { \ + case 2: \ + if constexpr (K_TILES_PER_WARP >= 2) { \ + using BLayout = BLayout_TC_int4<2, Q_GROUP_SIZE>; \ + launch_tinygemm_kernel< \ + ACLayout, \ + BLayout, \ + ACLayout, \ + WARPS, \ + K_TILES_PER_WARP>( \ + A, \ + B, \ + &qScaleAndZeros, \ + *C_final, \ + mTiles, \ + nTiles, \ + kTiles, \ + m, \ + n, \ + k, \ + stream); \ + } \ + break; \ + case 4: \ + if constexpr (K_TILES_PER_WARP >= 4) { \ + using BLayout = BLayout_TC_int4<4, Q_GROUP_SIZE>; \ + launch_tinygemm_kernel< \ + ACLayout, \ + BLayout, \ + ACLayout, \ + WARPS, \ + K_TILES_PER_WARP>( \ + A, \ + B, \ + &qScaleAndZeros, \ + *C_final, \ + mTiles, \ + nTiles, \ + kTiles, \ + m, \ + n, \ + k, \ + stream); \ + } \ + break; \ + case 8: \ + if constexpr (K_TILES_PER_WARP >= 8) { \ + using BLayout = BLayout_TC_int4<8, Q_GROUP_SIZE>; \ + launch_tinygemm_kernel< \ + ACLayout, \ + BLayout, \ + ACLayout, \ + WARPS, \ + K_TILES_PER_WARP>( \ + A, \ + B, \ + &qScaleAndZeros, \ + *C_final, \ + mTiles, \ + nTiles, \ + kTiles, \ + m, \ + n, \ + k, \ + stream); \ + } \ + break; \ + default: \ + break; \ + } \ + } while (false) + +#define HANDLE_Q_GROUP(WARPS, K_TILES_PER_WARP, REDUCE_TYPE) \ + do { \ + switch (qGroupSize) { \ + case 32: \ + RUN_GEMM(WARPS, K_TILES_PER_WARP, 32, REDUCE_TYPE); \ + break; \ + case 64: \ + RUN_GEMM(WARPS, K_TILES_PER_WARP, 64, REDUCE_TYPE); \ + break; \ + case 128: \ + RUN_GEMM(WARPS, K_TILES_PER_WARP, 128, REDUCE_TYPE); \ + break; \ + case 256: \ + RUN_GEMM(WARPS, K_TILES_PER_WARP, 256, REDUCE_TYPE); \ + break; \ + } \ + } while (false) + + HANDLE_Q_GROUP(8, 8, KReductionType::None); + +#undef HANDLE_Q_GROUP +#undef RUN_GEMM + + return C_final; +#endif + ET_CHECK_MSG(false, "_weight_int4pack_mm_cuda is not available for build."); + return C_final; +} + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/int4mm.h b/backends/cuda/runtime/shims/int4mm.h new file mode 100644 index 00000000000..49e28046f58 --- /dev/null +++ b/backends/cuda/runtime/shims/int4mm.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace executorch::backends::cuda { + +using executorch::backends::aoti::AOTITorchError; +using executorch::backends::aoti::Tensor; + +#ifdef __cplusplus +extern "C" { +#endif + +AOTITorchError aoti_torch_cuda__weight_int4pack_mm( + Tensor* self, + Tensor* mat2, + int64_t qGroupSize, + Tensor* qScaleAndZeros, + Tensor** ret0); + +#ifdef __cplusplus +} +#endif + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl index 70f27b86bec..34a9d60582f 100644 --- a/backends/cuda/runtime/shims/tests/targets.bzl +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -33,3 +33,4 @@ def define_common_targets(): cuda_shim_cpp_unittest("aoti_torch__reinterpret_tensor") cuda_shim_cpp_unittest("aoti_torch_copy_") cuda_shim_cpp_unittest("aoti_torch_cuda_guard") + cuda_shim_cpp_unittest("aoti_torch_cuda__weight_int4pack_mm") diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda__weight_int4pack_mm.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda__weight_int4pack_mm.cpp new file mode 100644 index 00000000000..1b59fc1abdb --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda__weight_int4pack_mm.cpp @@ -0,0 +1,342 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::cuda; +using namespace executorch::backends::aoti; +using namespace executorch::runtime; + +// Test fixture for aoti_torch_cuda__weight_int4pack_mm tests +class AOTITorchInt4MMTest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Check if GPU supports sm_80+ (required for int4mm) + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + int compute_capability = prop.major * 10 + prop.minor; + if (compute_capability < 80) { + GTEST_SKIP() << "GPU compute capability " << compute_capability + << " < 80 (Ampere+), int4mm requires sm_80+"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + + // Clear any remaining tensors from previous tests + clear_all_tensors(); + } + + void TearDown() override { + // Clean up metadata + cleanup_tensor_metadata(); + + // Clear the global tensor storage using the provided function + clear_all_tensors(); + } + + // Helper to create a BFloat16 tensor + Tensor* create_bfloat16_tensor(const std::vector& sizes) { + Tensor* tensor; + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, // default strides + static_cast(SupportedDTypes::BFLOAT16), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor); + + return (error == Error::Ok) ? 
tensor : nullptr; + } + + // Helper to create an Int32 tensor + Tensor* create_int32_tensor(const std::vector& sizes) { + Tensor* tensor; + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, // default strides + static_cast(SupportedDTypes::INT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor); + + return (error == Error::Ok) ? tensor : nullptr; + } +}; + +// Test basic int4mm functionality with minimal valid inputs +TEST_F(AOTITorchInt4MMTest, BasicFunctionality) { + // Create input tensor A: [m, k] = [2, 128] in BFloat16 + int64_t m = 2; + int64_t k = 128; + int64_t n = 64; + int64_t qGroupSize = 128; + + Tensor* A = create_bfloat16_tensor({m, k}); + ASSERT_NE(A, nullptr) << "Failed to create input tensor A"; + + // Create weight tensor B (int4 packed): [n/8, k/(innerKTiles*16), 32, 4] in + // Int32 For int4mm, innerKTiles is typically 8, so k/(8*16) = 128/128 = 1 + int64_t B_innerKTiles = 8; + int64_t B_kTiles = k / (B_innerKTiles * 16); + Tensor* B = create_int32_tensor({n / 8, B_kTiles, 32, 4}); + ASSERT_NE(B, nullptr) << "Failed to create weight tensor B"; + + // Create scale and zeros tensor: [k/qGroupSize, n, 2] in BFloat16 + // For k=128, qGroupSize=128, k/qGroupSize=1 + Tensor* qScaleAndZeros = create_bfloat16_tensor({k / qGroupSize, n, 2}); + ASSERT_NE(qScaleAndZeros, nullptr) + << "Failed to create qScaleAndZeros tensor"; + + // Create output tensor: [m, n] in BFloat16 + Tensor* output = create_bfloat16_tensor({m, n}); + ASSERT_NE(output, nullptr) << "Failed to create output tensor"; + + printf("Testing int4mm with shapes:\n"); + printf(" A: [%ldx%ld] BFloat16\n", m, k); + printf(" B: [%ldx%ldx32x4] Int32\n", n / 8, B_kTiles); + printf(" qScaleAndZeros: [%ldx%ldx2] BFloat16\n", k / qGroupSize, n); + printf(" qGroupSize: %ld\n", qGroupSize); + printf(" Output: [%ldx%ld] BFloat16\n", m, n); + + // Call int4mm + AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( + A, B, qGroupSize, qScaleAndZeros, &output); + + // Check if the function succeeded + EXPECT_EQ(error, Error::Ok) << "int4mm operation should succeed"; + + // Verify output tensor properties + EXPECT_EQ(output->dim(), 2); + EXPECT_EQ(output->size(0), m); + EXPECT_EQ(output->size(1), n); + + printf("int4mm test passed successfully!\n"); +} + +// Test with different qGroupSize values +TEST_F(AOTITorchInt4MMTest, DifferentQGroupSizes) { + int64_t m = 4; + int64_t k = 256; + int64_t n = 128; + int64_t B_innerKTiles = 8; + + // Test qGroupSize = 64 + { + int64_t qGroupSize = 64; + + Tensor* A = create_bfloat16_tensor({m, k}); + ASSERT_NE(A, nullptr); + + Tensor* B = create_int32_tensor({n / 8, k / (B_innerKTiles * 16), 32, 4}); + ASSERT_NE(B, nullptr); + + Tensor* qScaleAndZeros = create_bfloat16_tensor({k / qGroupSize, n, 2}); + ASSERT_NE(qScaleAndZeros, nullptr); + + Tensor* output = create_bfloat16_tensor({m, n}); + ASSERT_NE(output, nullptr); + + printf("Testing int4mm with qGroupSize=64\n"); + + AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( + A, B, qGroupSize, qScaleAndZeros, &output); + EXPECT_EQ(error, Error::Ok) << "int4mm with qGroupSize=64 should succeed"; + } + + // Test qGroupSize = 128 + { + int64_t qGroupSize = 128; + + Tensor* A = create_bfloat16_tensor({m, k}); + ASSERT_NE(A, nullptr); + + Tensor* B = create_int32_tensor({n / 8, k / (B_innerKTiles * 16), 32, 4}); + ASSERT_NE(B, nullptr); + + Tensor* qScaleAndZeros = create_bfloat16_tensor({k / qGroupSize, n, 2}); + ASSERT_NE(qScaleAndZeros, nullptr); + + Tensor* output = 
create_bfloat16_tensor({m, n}); + ASSERT_NE(output, nullptr); + + printf("Testing int4mm with qGroupSize=128\n"); + + AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( + A, B, qGroupSize, qScaleAndZeros, &output); + EXPECT_EQ(error, Error::Ok) << "int4mm with qGroupSize=128 should succeed"; + } + + // Test qGroupSize = 256 + { + int64_t qGroupSize = 256; + + Tensor* A = create_bfloat16_tensor({m, k}); + ASSERT_NE(A, nullptr); + + Tensor* B = create_int32_tensor({n / 8, k / (B_innerKTiles * 16), 32, 4}); + ASSERT_NE(B, nullptr); + + Tensor* qScaleAndZeros = create_bfloat16_tensor({k / qGroupSize, n, 2}); + ASSERT_NE(qScaleAndZeros, nullptr); + + Tensor* output = create_bfloat16_tensor({m, n}); + ASSERT_NE(output, nullptr); + + printf("Testing int4mm with qGroupSize=256\n"); + + AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( + A, B, qGroupSize, qScaleAndZeros, &output); + EXPECT_EQ(error, Error::Ok) << "int4mm with qGroupSize=256 should succeed"; + } +} + +// Test error handling with null inputs +TEST_F(AOTITorchInt4MMTest, NullInputHandling) { + int64_t m = 2; + int64_t k = 128; + int64_t n = 64; + int64_t qGroupSize = 128; + int64_t B_innerKTiles = 8; + + Tensor* A = create_bfloat16_tensor({m, k}); + Tensor* B = create_int32_tensor({n / 8, k / (B_innerKTiles * 16), 32, 4}); + Tensor* qScaleAndZeros = create_bfloat16_tensor({k / qGroupSize, n, 2}); + Tensor* output = create_bfloat16_tensor({m, n}); + + // Test null A + { + AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( + nullptr, B, qGroupSize, qScaleAndZeros, &output); + EXPECT_EQ(error, Error::InvalidArgument) + << "Should fail with null A tensor"; + } + + // Test null B + { + AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( + A, nullptr, qGroupSize, qScaleAndZeros, &output); + EXPECT_EQ(error, Error::InvalidArgument) + << "Should fail with null B tensor"; + } + + // Test null qScaleAndZeros + { + AOTITorchError error = + aoti_torch_cuda__weight_int4pack_mm(A, B, qGroupSize, nullptr, &output); + EXPECT_EQ(error, Error::InvalidArgument) + << "Should fail with null qScaleAndZeros tensor"; + } + + // Test null output pointer + { + AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( + A, B, qGroupSize, qScaleAndZeros, nullptr); + EXPECT_EQ(error, Error::InvalidArgument) + << "Should fail with null output pointer"; + } + + // Test null output tensor (ret0 points to null) + { + Tensor* null_output = nullptr; + AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( + A, B, qGroupSize, qScaleAndZeros, &null_output); + EXPECT_EQ(error, Error::InvalidArgument) + << "Should fail with null output tensor"; + } +} + +// Test with larger batch size +TEST_F(AOTITorchInt4MMTest, LargerBatchSize) { + int64_t m = 16; // Batch size + int64_t k = 256; + int64_t n = 128; + int64_t qGroupSize = 128; + int64_t B_innerKTiles = 8; + + Tensor* A = create_bfloat16_tensor({m, k}); + ASSERT_NE(A, nullptr); + + Tensor* B = create_int32_tensor({n / 8, k / (B_innerKTiles * 16), 32, 4}); + ASSERT_NE(B, nullptr); + + Tensor* qScaleAndZeros = create_bfloat16_tensor({k / qGroupSize, n, 2}); + ASSERT_NE(qScaleAndZeros, nullptr); + + Tensor* output = create_bfloat16_tensor({m, n}); + ASSERT_NE(output, nullptr); + + printf("Testing int4mm with larger batch: m=%ld\n", m); + + AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( + A, B, qGroupSize, qScaleAndZeros, &output); + + EXPECT_EQ(error, Error::Ok) << "int4mm with larger batch should succeed"; + EXPECT_EQ(output->size(0), m); + EXPECT_EQ(output->size(1), 
n); +} + +// Test with larger tensors +TEST_F(AOTITorchInt4MMTest, LargerTensors) { + int64_t m = 8; + int64_t k = 512; + int64_t n = 256; + int64_t qGroupSize = 128; + int64_t B_innerKTiles = 8; + + Tensor* A = create_bfloat16_tensor({m, k}); + ASSERT_NE(A, nullptr); + + Tensor* B = create_int32_tensor({n / 8, k / (B_innerKTiles * 16), 32, 4}); + ASSERT_NE(B, nullptr); + + Tensor* qScaleAndZeros = create_bfloat16_tensor({k / qGroupSize, n, 2}); + ASSERT_NE(qScaleAndZeros, nullptr); + + Tensor* output = create_bfloat16_tensor({m, n}); + ASSERT_NE(output, nullptr); + + printf( + "Testing int4mm with larger tensors: [%ldx%ld] x [weight] -> [%ldx%ld]\n", + m, + k, + m, + n); + + AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( + A, B, qGroupSize, qScaleAndZeros, &output); + + EXPECT_EQ(error, Error::Ok) << "int4mm with larger tensors should succeed"; + EXPECT_EQ(output->dim(), 2); + EXPECT_EQ(output->size(0), m); + EXPECT_EQ(output->size(1), n); +} diff --git a/backends/cuda/runtime/utils.h b/backends/cuda/runtime/utils.h index 2d805724090..90aa07a6333 100644 --- a/backends/cuda/runtime/utils.h +++ b/backends/cuda/runtime/utils.h @@ -14,7 +14,7 @@ #include #include -// CUDA error checking macro +// CUDA error checking macro (with return) #define ET_CUDA_CHECK_OR_RETURN_ERROR(EXPR) \ do { \ const cudaError_t err = EXPR; \ @@ -30,14 +30,34 @@ return Error::Internal; \ } while (0) -// Kernel launch check macro +// CUDA error checking macro (without return, for use in void functions) +#define ET_CUDA_CHECK(EXPR) \ + do { \ + const cudaError_t err = EXPR; \ + if (err == cudaSuccess) { \ + break; \ + } \ + ET_LOG( \ + Error, \ + "%s:%d CUDA error: %s", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(err)); \ + ET_CHECK_MSG(false, "CUDA error: %s", cudaGetErrorString(err)); \ + } while (0) + +// Kernel launch check macro (with return) #define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \ ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError()) +// Kernel launch check macro (without return, for use in void functions) +#define ET_CUDA_KERNEL_LAUNCH_CHECK() ET_CUDA_CHECK(cudaGetLastError()) + namespace executorch::backends::cuda { // Enum for supported data types in et-cuda backend enum class SupportedDTypes : int32_t { + INT32 = 3, // PyTorch's int64 dtype code INT64 = 4, // PyTorch's int64 dtype code FLOAT32 = 6, // PyTorch's float32 dtype code BFLOAT16 = 15, // PyTorch's bfloat16 dtype code @@ -99,6 +119,7 @@ using AOTITorchError = Error; // Helper function to check if a dtype is supported in ET CUDA backend inline bool is_dtype_supported_in_et_cuda(int32_t dtype) { switch (dtype) { + case static_cast(SupportedDTypes::INT32): case static_cast(SupportedDTypes::INT64): case static_cast(SupportedDTypes::FLOAT32): case static_cast(SupportedDTypes::BFLOAT16): @@ -113,8 +134,9 @@ inline AOTITorchError validate_dtype(int32_t dtype) { ET_CHECK_OR_RETURN_ERROR( is_dtype_supported_in_et_cuda(dtype), InvalidArgument, - "Unsupported dtype: %d. Supported dtypes: %d (int64), %d (float32), %d (bfloat16)", + "Unsupported dtype: %d. Supported dtypes: %d (int32), %d (int64), %d (float32), %d (bfloat16)", dtype, + static_cast(SupportedDTypes::INT32), static_cast(SupportedDTypes::INT64), static_cast(SupportedDTypes::FLOAT32), static_cast(SupportedDTypes::BFLOAT16));
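
For reference, a minimal caller sketch for the new shim, mirroring the shape contract exercised by the `BasicFunctionality` unit test above (A: `[m, k]` bf16, B: `[n / 8, k / (innerKTiles * 16), 32, innerKTiles / 2]` int32, qScaleAndZeros: `[k / qGroupSize, n, 2]` bf16). The include paths and the use of `aoti_torch_empty_strided` for allocation are assumptions based on the files touched in this patch; the helper and function names other than the shim itself are illustrative, not part of the change.

```cpp
// Hedged usage sketch: allocate CUDA tensors via the AOTI shims and call the
// new int4 weight-only matmul fallback. Shapes follow the unit tests above.
#include <cstdint>
#include <vector>

#include <executorch/backends/cuda/runtime/shims/int4mm.h>  // assumed path
#include <executorch/backends/cuda/runtime/shims/memory.h>  // assumed path
#include <executorch/backends/cuda/runtime/utils.h>         // assumed path

using namespace executorch::backends::cuda;
using namespace executorch::backends::aoti;
using namespace executorch::runtime;

Tensor* run_int4mm_example() {
  const int64_t m = 2, k = 128, n = 64, qGroupSize = 128, innerKTiles = 8;

  // Small helper around aoti_torch_empty_strided (contiguous, CUDA device 0).
  auto make = [](std::vector<int64_t> sizes, SupportedDTypes dtype) {
    Tensor* t = nullptr;
    aoti_torch_empty_strided(
        sizes.size(), sizes.data(), /*strides=*/nullptr,
        static_cast<int32_t>(dtype),
        static_cast<int32_t>(SupportedDevices::CUDA), /*device_index=*/0, &t);
    return t;
  };

  Tensor* A = make({m, k}, SupportedDTypes::BFLOAT16);
  Tensor* B = make({n / 8, k / (innerKTiles * 16), 32, innerKTiles / 2},
                   SupportedDTypes::INT32);
  Tensor* scaleZeros = make({k / qGroupSize, n, 2}, SupportedDTypes::BFLOAT16);
  Tensor* out = make({m, n}, SupportedDTypes::BFLOAT16);

  // The shim writes the handle of the computed [m, n] bf16 result into *ret0,
  // as in the unit tests above.
  AOTITorchError err =
      aoti_torch_cuda__weight_int4pack_mm(A, B, qGroupSize, scaleZeros, &out);
  return (err == Error::Ok) ? out : nullptr;
}
```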