11 changes: 9 additions & 2 deletions hgemm/hgemm.cu
@@ -999,7 +999,8 @@ void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor
void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
// from hgemm_cublas.cu
void hgemm_cublas_tensor_op_row_major(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_cublas_tensor_op_nn(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_cublas_tensor_op_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c);
// from hgemm_wmma.cu
void hgemm_wmma_m16n16k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_mma4x2(torch::Tensor a, torch::Tensor b, torch::Tensor c);
@@ -1018,6 +1019,9 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4(torch::Tensor a, torch::Tensor b, torch::
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
// from hgemm_mma_stage_tn.cu
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// CUDA Cores FP16
@@ -1037,7 +1041,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
TORCH_BINDING_COMMON_EXTENSION(hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf)
TORCH_BINDING_COMMON_EXTENSION(hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf_async)
// cuBLAS Tensor Cores
TORCH_BINDING_COMMON_EXTENSION(hgemm_cublas_tensor_op_row_major)
TORCH_BINDING_COMMON_EXTENSION(hgemm_cublas_tensor_op_nn)
TORCH_BINDING_COMMON_EXTENSION(hgemm_cublas_tensor_op_tn)
// WMMA API Tensor Cores
TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_naive)
TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2)
@@ -1056,5 +1061,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4_stages)
TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem)
TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem)
// TN: A row major MxK, B col major NxK, C row major MxN
TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn)
}
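For reference, a minimal sketch (not part of this PR) of how the newly bound TN entry points are invoked from Python once the hgemm_lib extension from hgemm.py has been built and loaded as `lib`; the shapes and the stages/swizzle/swizzle_stride values below are illustrative assumptions, and the argument order follows the declarations above:

import torch

M, N, K = 4096, 4096, 1024  # illustrative sizes
a = torch.randn(M, K, dtype=torch.half, device='cuda')  # A: row major MxK
b = torch.randn(K, N, dtype=torch.half, device='cuda')
c = torch.zeros(M, N, dtype=torch.half, device='cuda')  # C: row major MxN

# TN MMA kernel: positional args are (a, b, c, stages, swizzle, swizzle_stride);
# B is handed over as an (N, K) operand, exactly as hgemm.py does below.
lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(a, b.transpose(1, 0), c, 2, False, 1024)
# TN cuBLAS path added by this PR.
lib.hgemm_cublas_tensor_op_tn(a, b.transpose(1, 0), c)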

20 changes: 17 additions & 3 deletions hgemm/hgemm.py
@@ -16,13 +16,15 @@ def get_args():
parser.add_argument("--iters", "--i", type=int, default=10, help="Benchmark iters")
parser.add_argument("--show-all", "--show", action="store_true", help="Show all matrix values ")
parser.add_argument("--enable-mma", "--mma", action="store_true", help="Enable MMA kernel tests")
parser.add_argument("--enable-mma-tn", "--mma-tn", action="store_true", help="Enable TN MMA kernel tests")
parser.add_argument("--enable-wmma", "--wmma", action="store_true", help="Enable WMMA kernel tests")
parser.add_argument("--enable-cuda", "--cuda", action="store_true", help="Enable CUDA kernel tests")
parser.add_argument("--enable-mma-all", "--mma-all", action="store_true", help="Enable all MMA kernel tests")
parser.add_argument("--enable-wmma-all", "--wmma-all", action="store_true", help="Enable all WMMA kernel tests")
parser.add_argument("--enable-cuda-all", "--cuda-all", action="store_true", help="Enable all CUDA kernel tests")
parser.add_argument("--enable-torch", "--torch", action="store_true", help="Enable torch matmul")
parser.add_argument("--disable-cublas", "--no-cublas", action="store_true", help="Disable cublas hgemm")
parser.add_argument("--disable-cublas-tn", "--no-cublas-tn", action="store_true", help="Disable cublas TN hgemm")
parser.add_argument("--sleep-duration", "--sleep", type=float, default=0.1, help="Sleep duration")
parser.add_argument("--swizzle-factor", "--swizzle", type=float, default=0.25, help="Swizzle factor")
return parser.parse_args()
@@ -35,7 +37,8 @@ def get_args():
lib = load(name='hgemm_lib',
sources=['hgemm.cu', 'hgemm_async.cu', 'hgemm_wmma.cu',
'hgemm_wmma_stage.cu', 'hgemm_cublas.cu',
'hgemm_mma.cu', 'hgemm_mma_stage.cu'],
'hgemm_mma.cu', 'hgemm_mma_stage.cu',
'hgemm_mma_stage_tn.cu'],
extra_cuda_cflags=[
"-O3",
"-U__CUDA_NO_HALF_OPERATORS__",
@@ -65,6 +68,8 @@ def run_benchmark(perf_func: callable,
M = a.size(0)
K = a.size(1)
N = b.size(1)
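# TN kernels take B as an (N, K) tensor (a transposed view of the K x N input), so read N from dim 0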
if 'tn' in tag:
N = b.size(0)
if swizzle:
# make the swizzle stride about N/4 or N/2, rounded down to a multiple of 256
swizzle_stride = int((int(N * args.swizzle_factor) // 256) * 256)
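# e.g. with the default swizzle_factor 0.25: N = 4096 -> swizzle_stride = 1024, N = 8192 -> 2048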
@@ -217,8 +222,17 @@ def run_benchmark(perf_func: callable,
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage4+dsmem+swizzle)", c, stages=4, swizzle=True)
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
if not args.disable_cublas:
run_benchmark(lib.hgemm_cublas_tensor_op_row_major, a, b, "(cublas)", c)
if (not args.disable_cublas) and any((
args.enable_mma, args.enable_mma_all, args.enable_wmma, args.enable_wmma_all,
args.enable_cuda, args.enable_cuda_all, args.enable_torch)):
run_benchmark(lib.hgemm_cublas_tensor_op_nn, a, b, "(cublas)", c)
if args.enable_mma_tn:
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage3+dsmem)", c, stages=3)
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage2+dsmem)", c, stages=2)
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
if not args.disable_cublas_tn:
run_benchmark(lib.hgemm_cublas_tensor_op_tn, a, b.transpose(1, 0), "tn(cublas)", c)
if args.enable_torch:
run_benchmark(partial(torch.matmul, out=c), a, b, "(torch)")
torch.cuda.synchronize()
57 changes: 49 additions & 8 deletions hgemm/hgemm_cublas.cu
@@ -14,8 +14,8 @@

#include "cublas_v2.h"

void cublas_tensor_op_row_major(half *A, half *B, half *C, size_t M,
size_t N, size_t K) {
// NN: A, B and C are all row major
void cublas_tensor_op_nn(half *A, half *B, half *C, size_t M, size_t N, size_t K) {

static cublasHandle_t handle = nullptr;
cublasCreate(&handle);
@@ -36,11 +36,33 @@ void cublas_tensor_op_row_major(half *A, half *B, half *C, size_t M,
CUBLAS_COMPUTE_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);

// why does this line slow cuBLAS down?
// cublasDestroy(handle);
}

// TODO: add cublas_tensor_op_col_major
// TN: A row major MxK, B col major NxK, C row major MxN
void cublas_tensor_op_tn(half *A, half *B, half *C, size_t M, size_t N, size_t K) {

static cublasHandle_t handle = nullptr;
cublasCreate(&handle);
cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);

static half alpha = 1.0;
static half beta = 0.0;
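// Row-major C (MxN) = A (MxK) * B^T is computed here as column-major C^T = B * A^T:
//   the B buffer (N x K, K contiguous) viewed column-major with ld = K is B^T, so OP_T yields B;
//   the A buffer (M x K, K contiguous) viewed column-major with ld = K is A^T, so OP_N keeps A^T;
//   the N x M column-major result with ld = N occupies the same memory as row-major C.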

cublasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
N, M, K,
&alpha,
B, CUDA_R_16F, K,
A, CUDA_R_16F, K,
&beta,
C, CUDA_R_16F, N,
CUBLAS_COMPUTE_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);

// cublasDestroy(handle);
}

// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
@@ -58,8 +80,8 @@ if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \
throw std::runtime_error("Tensor size mismatch!"); \
}

// cublas tensor op with row major B matrix
void hgemm_cublas_tensor_op_row_major(
// NN: A, B and C are all row major
void hgemm_cublas_tensor_op_nn(
torch::Tensor a, torch::Tensor b, torch::Tensor c) {
CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
@@ -71,12 +93,31 @@ void hgemm_cublas_tensor_op_row_major(
CHECK_TORCH_TENSOR_SHAPE(b, K, N)
CHECK_TORCH_TENSOR_SHAPE(c, M, N)

cublas_tensor_op_row_major(
cublas_tensor_op_nn(
reinterpret_cast<half*>(a.data_ptr()),
reinterpret_cast<half*>(b.data_ptr()),
reinterpret_cast<half*>(c.data_ptr()),
M, N, K
);
}

// TODO: add cublas_tensor_op_col_major
// TN: A row major MxK, B col major NxK, C row major MxN
void hgemm_cublas_tensor_op_tn(
torch::Tensor a, torch::Tensor b, torch::Tensor c) {
CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf)
const int M = a.size(0);
const int K = a.size(1);
const int N = b.size(0);
CHECK_TORCH_TENSOR_SHAPE(a, M, K)
CHECK_TORCH_TENSOR_SHAPE(b, N, K)
CHECK_TORCH_TENSOR_SHAPE(c, M, N)

cublas_tensor_op_tn(
reinterpret_cast<half*>(a.data_ptr()),
reinterpret_cast<half*>(b.data_ptr()),
reinterpret_cast<half*>(c.data_ptr()),
M, N, K
);
}
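A quick way to sanity-check the new TN path (a sketch, not part of this PR; shapes and tolerances are illustrative) is to materialize B^T as a contiguous N x K tensor, matching the K leading dimensions used by cublas_tensor_op_tn above, and compare against torch.matmul:

import torch
# assumes the hgemm_lib extension from hgemm.py has been built and loaded as `lib`

M, N, K = 1024, 1024, 512
a  = torch.randn(M, K, dtype=torch.half, device='cuda')
b  = torch.randn(K, N, dtype=torch.half, device='cuda')
bt = b.t().contiguous()   # N x K, K contiguous: row n of bt is column n of b
c  = torch.zeros(M, N, dtype=torch.half, device='cuda')

lib.hgemm_cublas_tensor_op_tn(a, bt, c)          # C[m][n] = sum_k A[m][k] * bt[n][k]
torch.testing.assert_close(c.float(), (a @ b).float(), rtol=1e-2, atol=1e-2)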
57 changes: 0 additions & 57 deletions hgemm/hgemm_mma_stage_col_major.cu

This file was deleted.
