Add ROCm support #7

Merged 2 commits on Jun 6, 2023
4 changes: 3 additions & 1 deletion cuda_ext.py
@@ -53,8 +53,10 @@ def find_msvc():
os.path.join(library_dir, "exllama_ext/cuda_func/q4_mlp.cu"),
os.path.join(library_dir, "exllama_ext/cpu_func/rep_penalty.cpp")
],
extra_include_paths = [os.path.join(library_dir, "exllama_ext")],
verbose = verbose,
extra_ldflags = ["cublas.lib"] if windows else []
extra_ldflags = ["cublas.lib"] if windows else [],
extra_cuda_cflags = ["-U__HIP_NO_HALF_CONVERSIONS__"] if torch.version.hip else []
# extra_cflags = ["-ftime-report", "-DTORCH_USE_CUDA_DSA"]
)
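
Note on the build change above: on ROCm builds of PyTorch, torch.version.hip is a version string (it is None on CUDA builds), so it doubles as a platform check, and PyTorch's extension builder passes -D__HIP_NO_HALF_CONVERSIONS__ to the HIP compiler by default, which the -U flag undoes so the kernels' implicit half/float conversions still compile. A minimal, hypothetical sketch of that selection logic (the variable names and the os.name check are illustrative, not the repository's exact cuda_ext.py):

import os
import torch

# torch.version.hip: None on CUDA builds, a version string (e.g. "5.4.x") on ROCm.
is_rocm = torch.version.hip is not None
is_windows = os.name == "nt"  # stand-in for the script's "windows" flag

# These lists correspond to the arguments handed to torch.utils.cpp_extension.load
# in the hunk above: link against cublas.lib only on Windows, and only on ROCm
# re-enable the implicit half <-> float conversions that PyTorch's default HIP
# flags disable.
extra_cuda_cflags = ["-U__HIP_NO_HALF_CONVERSIONS__"] if is_rocm else []
extra_ldflags = ["cublas.lib"] if is_windows else []

print(f"ROCm build: {is_rocm}, extra_cuda_cflags: {extra_cuda_cflags}")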

4 changes: 2 additions & 2 deletions exllama_ext/cuda_compat.cuh
@@ -41,8 +41,8 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)

//

- #ifdef __CUDA_ARCH__
- #if __CUDA_ARCH__ < 700
+ #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
+ #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
7 changes: 7 additions & 0 deletions exllama_ext/cuda_func/column_remap.cu
@@ -1,7 +1,14 @@
#include "column_remap.cuh"
#include "../util.cuh"

// Using 1024 makes it crash with "Memory access fault by GPU node-1 (Agent
// handle: 0x012345678912) on address 0x012345678912. Reason: Page not present
// or supervisor privilege."
#if defined(USE_ROCM)
const int SHUF_BLOCKSIZE_X = 256;
#else
const int SHUF_BLOCKSIZE_X = 1024;
#endif
const int SHUF_BLOCKSIZE_Y = 16;

__global__ void column_remap_kernel
3 changes: 3 additions & 0 deletions exllama_ext/cuda_func/half_matmul.cu
@@ -2,6 +2,9 @@
#include "../util.cuh"
#include "../matrix.cuh"
#include "../cuda_compat.cuh"
#if defined(USE_ROCM)
#include "../hip_compat.cuh"
#endif

// Block size

6 changes: 6 additions & 0 deletions exllama_ext/cuda_func/half_matmul.cuh
@@ -6,6 +6,12 @@
#include <cstdint>
#include <ATen/cuda/CUDAContext.h>

// Workaround for hipify_python using rocblas instead of hipblas.
#if defined(USE_ROCM)
#include <hipblas/hipblas.h>
#define rocblas_handle hipblasHandle_t
#endif

void half_matmul_cuda
(
const half* x,
3 changes: 3 additions & 0 deletions exllama_ext/cuda_func/q4_matmul.cu
@@ -4,6 +4,9 @@
#include "../matrix.cuh"
#include "../cuda_compat.cuh"
#include "../cuda_buffers.cuh"
#if defined(USE_ROCM)
#include "../hip_compat.cuh"
#endif

const int THREADS_X = 32; // Block size and thread count along columns in w and out
const int THREADS_Y = 1; // Block size and thread count along rows in x and out
6 changes: 6 additions & 0 deletions exllama_ext/cuda_func/q4_matmul.cuh
@@ -10,6 +10,12 @@
#include "q4_matrix.cuh"
#include "../tuning.h"

// Workaround for hipify_python using rocblas instead of hipblas.
#if defined(USE_ROCM)
#include <hipblas/hipblas.h>
#define rocblas_handle hipblasHandle_t
#endif

void q4_matmul_cuda
(
ExLlamaTuning* tuningParams,
3 changes: 3 additions & 0 deletions exllama_ext/cuda_func/q4_mlp.cu
@@ -4,6 +4,9 @@
#include "../cuda_buffers.cuh"
#include "../util.cuh"
#include "../matrix.cuh"
#if defined(USE_ROCM)
#include "../hip_compat.cuh"
#endif

const int THREADS_X = 32;
const int THREADS_Y = 4;
45 changes: 45 additions & 0 deletions exllama_ext/hip_compat.cuh
@@ -0,0 +1,45 @@
#ifndef _hip_compat_cuh
#define _hip_compat_cuh

// Workaround for a bug in hipamd, backported from upstream.
__device__ __forceinline__ __half __compat_hrcp(__half x) {
return __half_raw{
static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
}

__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
}

#define hrcp __compat_hrcp
#define h2rcp __compat_h2rcp

// Workaround for hipify_python using rocblas instead of hipblas.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
hipblasOperation_t transA,
hipblasOperation_t transB,
int m,
int n,
int k,
const half* alpha,
const half* AP,
int lda,
const half* BP,
int ldb,
const half* beta,
half* CP,
int ldc) {
return hipblasHgemm(handle, transA, transB, m, n, k,
reinterpret_cast<const hipblasHalf *>(alpha),
reinterpret_cast<const hipblasHalf *>(AP), lda,
reinterpret_cast<const hipblasHalf *>(BP), ldb,
reinterpret_cast<const hipblasHalf *>(beta),
reinterpret_cast<hipblasHalf *>(CP), ldc);
}

#define rocblas_handle hipblasHandle_t
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_hgemm __compat_hipblasHgemm

#endif
4 changes: 4 additions & 0 deletions exllama_ext/util.cuh
@@ -6,7 +6,11 @@
#include <cstdint>
#include <cstdio>

#if defined(USE_ROCM)
#define cudaUnspecified hipErrorUnknown
#else
#define cudaUnspecified cudaErrorApiFailureBase
#endif

// React to failure on return code != cudaSuccess

4 changes: 3 additions & 1 deletion model_init.py
@@ -1,6 +1,7 @@
from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
import argparse, sys, os, glob
from torch import version as torch_version

def add_args(parser):

@@ -23,11 +24,12 @@ def add_args(parser):
parser.add_argument("-mmnh2", "--matmul_no_half2", action = "store_true", help = "Don't use half2 in Q4 matmul kernel")
parser.add_argument("-snh2", "--silu_no_half2", action = "store_true", help = "Don't use half2 in SiLU kernel")
parser.add_argument("-nh2", "--no_half2", action = "store_true", help = "(All of the above) disable half2 in all kernela")
parser.add_argument("-fh2", "--force_half2", action = "store_true", help = "Force enable half2 even if unsupported")


def post_parse(args):

- if args.no_half2:
+ if args.no_half2 or torch_version.hip and not args.force_half2:
args.rmsnorm_no_half2 = True
args.rope_no_half2 = True
args.matmul_no_half2 = True
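
A note on the new condition in post_parse(): since "and" binds tighter than "or", it reads as "the user passed --no_half2, or this is a HIP (ROCm) build of torch and --force_half2 was not given". A small, self-contained illustration of that precedence (the disable_half2 helper and the stubbed version strings are hypothetical, not part of the patch):

from types import SimpleNamespace

def disable_half2(args, hip_version):
    # Equivalent, parenthesized form of: args.no_half2 or hip_version and not args.force_half2
    # hip_version stands in for torch.version.hip (None on CUDA builds, a string on ROCm builds).
    return args.no_half2 or (hip_version is not None and not args.force_half2)

rocm, cuda = "5.4.22803", None  # stand-ins for torch.version.hip on ROCm / CUDA builds

print(disable_half2(SimpleNamespace(no_half2=False, force_half2=False), rocm))  # True: ROCm disables half2 by default
print(disable_half2(SimpleNamespace(no_half2=False, force_half2=True),  rocm))  # False: --force_half2 overrides
print(disable_half2(SimpleNamespace(no_half2=False, force_half2=False), cuda))  # False: CUDA keeps half2
print(disable_half2(SimpleNamespace(no_half2=True,  force_half2=False), cuda))  # True: --no_half2 always disables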