From 25bd2b6ab0cc68b80c12e6fc1678c15bd4b90b8e Mon Sep 17 00:00:00 2001
From: flozi00
Date: Mon, 20 May 2024 22:00:04 +0200
Subject: [PATCH 01/32] start porting latest tgi

---
 Dockerfile.dev | 19 +
 server/Makefile-awq | 15 +
 server/Makefile-eetq | 2 +-
 server/Makefile-flash-att | 4 +-
 server/Makefile-flash-att-v2 | 36 +-
 server/Makefile-vllm | 28 +-
 .../exllama_kernels/cu_compat.cuh | 58 ++
 .../exllama_kernels/cuda_buffers.cu | 71 ++
 .../exllama_kernels/cuda_buffers.cuh | 52 ++
 .../exllama_kernels/cuda_func/column_remap.cu | 61 ++
 .../cuda_func/column_remap.cuh | 19 +
 .../exllama_kernels/cuda_func/q4_matmul.cu | 256 ++++++
 .../exllama_kernels/cuda_func/q4_matmul.cuh | 37 +
 .../exllama_kernels/cuda_func/q4_matrix.cu | 220 ++++++
 .../exllama_kernels/cuda_func/q4_matrix.cuh | 53 ++
 .../exllama_kernels/exllama_ext.cpp | 253 ++++++
 .../exllama_kernels/hip_compat.cuh} | 28 +-
 .../exllama_kernels/matrix.cuh | 294 +++++++
 .../exllama_kernels/exllama_kernels/tuning.h | 13 +
 .../exllama_kernels/exllama_kernels/util.cuh | 33 +
 server/exllama_kernels/setup.py | 19 +
 .../exllamav2_kernels/cpp/quantize_func.cpp | 52 --
 .../exllamav2_kernels/cpp/quantize_func.h | 25 -
 .../exllamav2_kernels/cpp/safetensors.cpp | 372 ---------
 .../exllamav2_kernels/cpp/safetensors.h | 46 --
 .../exllamav2_kernels/cpp/sampling.cpp | 730 ------------------
 .../exllamav2_kernels/cpp/sampling.h | 138 ----
 .../exllamav2_kernels/cuda/cache.cu | 161 ----
 .../exllamav2_kernels/cuda/cache.cuh | 14 -
 .../cuda/comp_units/kernel_select.cu | 21 -
 .../cuda/comp_units/kernel_select.cuh | 81 --
 .../cuda/comp_units/unit_exl2_1a.cu | 10 -
 .../cuda/comp_units/unit_exl2_1b.cu | 10 -
 .../cuda/comp_units/unit_exl2_2a.cu | 10 -
 .../cuda/comp_units/unit_exl2_2b.cu | 10 -
 .../cuda/comp_units/unit_exl2_3a.cu | 10 -
 .../cuda/comp_units/unit_exl2_3b.cu | 10 -
 .../cuda/comp_units/unit_gptq_1.cu | 10 -
 .../cuda/comp_units/unit_gptq_2.cu | 10 -
 .../cuda/comp_units/unit_gptq_3.cu | 10 -
 .../exllamav2_kernels/cuda/h_gemm.cu | 274 -------
 .../exllamav2_kernels/cuda/h_gemm.cuh | 39 -
 .../exllamav2_kernels/cuda/layer_norm.cu | 171 ----
 .../exllamav2_kernels/cuda/layer_norm.cuh | 20 -
 .../exllamav2_kernels/cuda/lora.cu | 33 -
 .../exllamav2_kernels/cuda/lora.cuh | 24 -
 .../exllamav2_kernels/cuda/matrix_view.cuh | 2 +-
 .../exllamav2_kernels/cuda/pack_tensor.cu | 268 -------
 .../exllamav2_kernels/cuda/pack_tensor.cuh | 35 -
 .../exllamav2_kernels/cuda/q_attn.cu | 167 ----
 .../exllamav2_kernels/cuda/q_attn.cuh | 102 ---
 .../exllamav2_kernels/cuda/q_gemm.cu | 5 +-
 .../exllamav2_kernels/cuda/q_gemm.cuh | 2 +-
 .../cuda/q_gemm_kernel_gptq.cuh | 16 +-
 .../exllamav2_kernels/cuda/q_matrix.cu | 27 +-
 .../exllamav2_kernels/cuda/q_mlp.cu | 332 --------
 .../exllamav2_kernels/cuda/q_mlp.cuh | 139 ----
 .../exllamav2_kernels/cuda/q_mlp_softmax.cuh | 156 ----
 .../exllamav2_kernels/cuda/quant/qdq_2.cuh | 2 +-
 .../exllamav2_kernels/cuda/quant/qdq_4.cuh | 2 +-
 .../exllamav2_kernels/cuda/quant/qdq_5.cuh | 2 +-
 .../exllamav2_kernels/cuda/quant/qdq_6.cuh | 2 -
 .../exllamav2_kernels/cuda/quant/qdq_8.cuh | 2 +-
 .../exllamav2_kernels/cuda/quantize.cu | 343 --------
 .../exllamav2_kernels/cuda/quantize.cuh | 71 --
 .../exllamav2_kernels/cuda/rms_norm.cu | 133 ----
 .../exllamav2_kernels/cuda/rms_norm.cuh | 19 -
 .../exllamav2_kernels/cuda/rope.cu | 142 ----
 .../exllamav2_kernels/cuda/rope.cuh | 22 -
 .../exllamav2_kernels/cuda/util.cu | 24 -
 .../exllamav2_kernels/cuda/util.cuh | 2 +-
 server/exllamav2_kernels/setup.py | 11 +
 72 files changed, 1584 insertions(+), 4306 deletions(-)
 create mode 100644 Dockerfile.dev
 create mode 100644 server/Makefile-awq
 create mode 100644 server/exllama_kernels/exllama_kernels/cu_compat.cuh
 create mode 100644 server/exllama_kernels/exllama_kernels/cuda_buffers.cu
 create mode 100644 server/exllama_kernels/exllama_kernels/cuda_buffers.cuh
 create mode 100644 server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cu
 create mode 100644 server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cuh
 create mode 100644 server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu
 create mode 100644 server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh
 create mode 100644 server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu
 create mode 100644 server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cuh
 create mode 100644 server/exllama_kernels/exllama_kernels/exllama_ext.cpp
 rename server/{exllamav2_kernels/exllamav2_kernels/cuda/compat_gemm.cuh => exllama_kernels/exllama_kernels/hip_compat.cuh} (69%)
 create mode 100644 server/exllama_kernels/exllama_kernels/matrix.cuh
 create mode 100644 server/exllama_kernels/exllama_kernels/tuning.h
 create mode 100644 server/exllama_kernels/exllama_kernels/util.cuh
 create mode 100644 server/exllama_kernels/setup.py
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cpp/quantize_func.cpp
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cpp/quantize_func.h
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cpp/safetensors.cpp
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cpp/safetensors.h
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cpp/sampling.cpp
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cpp/sampling.h
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/cache.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/cache.cuh
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/kernel_select.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/kernel_select.cuh
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_1a.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_1b.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_2a.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_2b.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_3a.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_3b.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_gptq_1.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_gptq_2.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_gptq_3.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/h_gemm.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/h_gemm.cuh
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/layer_norm.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/layer_norm.cuh
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/lora.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/lora.cuh
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/pack_tensor.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/pack_tensor.cuh
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/q_attn.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/q_attn.cuh
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/q_mlp.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/q_mlp.cuh
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/q_mlp_softmax.cuh
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/quantize.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/quantize.cuh
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/rms_norm.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/rms_norm.cuh
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/rope.cu
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/rope.cuh
 delete mode 100644 server/exllamav2_kernels/exllamav2_kernels/cuda/util.cu

diff --git a/Dockerfile.dev b/Dockerfile.dev
new file mode 100644
index 000000000..5de18e760
--- /dev/null
+++ b/Dockerfile.dev
@@ -0,0 +1,19 @@
+# LoRAX base image
+FROM ghcr.io/predibase/lorax:latest as base
+
+# Install server
+COPY proto proto
+COPY server server
+COPY server/Makefile server/Makefile
+
+# Final image
+FROM base
+
+COPY container-entrypoint.sh entrypoint.sh
+RUN chmod +x entrypoint.sh
+COPY sync.sh sync.sh
+RUN chmod +x sync.sh
+
+# ENTRYPOINT ["./entrypoint.sh"]
+ENTRYPOINT ["lorax-launcher"]
+CMD ["--json-output"]
diff --git a/server/Makefile-awq b/server/Makefile-awq
new file mode 100644
index 000000000..4e074a133
--- /dev/null
+++ b/server/Makefile-awq
@@ -0,0 +1,15 @@
+# Fork that adds only the correct stream to this kernel in order
+# to make cuda graphs work.
+awq_commit := bd1dc2d5254345cc76ab71894651fb821275bdd4
+
+awq:
+	rm -rf llm-awq
+	git clone https://github.com/huggingface/llm-awq
+
+build-awq: awq
+	cd llm-awq/ && git fetch && git checkout $(awq_commit)
+	cd llm-awq/awq/kernels && python setup.py build
+
+install-awq: build-awq
+	pip uninstall awq_inference_engine -y || true
+	cd llm-awq/awq/kernels && python setup.py install
diff --git a/server/Makefile-eetq b/server/Makefile-eetq
index 2cd3e59af..726e47b57 100644
--- a/server/Makefile-eetq
+++ b/server/Makefile-eetq
@@ -1,4 +1,4 @@
-eetq_commit := cc2fdb4637e03652ac264eaef44dd8492472de01 # 323827dd471458a84e9c840f614e4592b157a4b1
+eetq_commit := 1657b1504faa359e2ce0ac02999439d7ac8c74c0
 
 eetq:
 	# Clone eetq
diff --git a/server/Makefile-flash-att b/server/Makefile-flash-att
index bc1d37ef5..ffa304aa5 100644
--- a/server/Makefile-flash-att
+++ b/server/Makefile-flash-att
@@ -2,7 +2,7 @@ flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec
 
 flash-attention:
 	# Clone flash attention
-	pip install packaging
+	pip install -U packaging ninja --no-cache-dir
 	git clone https://github.com/HazyResearch/flash-attention.git
 
 build-flash-attention: flash-attention
@@ -13,4 +13,4 @@ build-flash-attention: flash-attention
 
 install-flash-attention: build-flash-attention
 	pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true
-	cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install
\ No newline at end of file
+	cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install
diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2
index 3b6baf281..36ef576ae 100644
--- a/server/Makefile-flash-att-v2
+++ b/server/Makefile-flash-att-v2
@@ -1,19 +1,29 @@ -flash_att_v2_commit := 2c3baba4a63c4007c8a132c5380edc9430f88a22 +flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9 +flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6 -flash-attention-v2: - # Clone flash attention - pip install packaging + +flash-attention-v2-cuda: + # Clone flash attention + pip install -U packaging ninja --no-cache-dir git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2 -build-flash-attention-v2: flash-attention-v2 - cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit) +build-flash-attention-v2-cuda: flash-attention-v2-cuda + cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_cuda) + cd flash-attention-v2 && git submodule update --init --recursive cd flash-attention-v2 && python setup.py build -# install-flash-attention-v2: build-flash-attention-v2 -# cd flash-attention-v2 && python setup.py install +install-flash-attention-v2-cuda: build-flash-attention-v2-cuda + cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install + +flash-attention-v2-rocm: + # Clone flash attention + pip install -U packaging ninja --no-cache-dir + git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 + +build-flash-attention-v2-rocm: flash-attention-v2-rocm + cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) + cd flash-attention-v2 && git submodule update --init --recursive + cd flash-attention-v2 && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build -# Install from pip because the target commit is actually a release commit -# and the pip wheels do not require nvcc to be installed. -# Reference: https://github.com/Dao-AILab/flash-attention/issues/509 -install-flash-attention-v2: - FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE pip install flash-attn==2.3.0 --no-build-isolation +install-flash-attention-v2-rocm: build-flash-attention-v2-rocm + cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install diff --git a/server/Makefile-vllm b/server/Makefile-vllm index cf59b03fa..62fa413f4 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,13 +1,25 @@ -vllm_commit := 6d592eb430a37a7f8f5f9beb2dbc014bf3aa76bc - -vllm: +vllm-cuda: # Clone vllm - git clone https://github.com/vllm-project/vllm.git + pip install -U ninja packaging --no-cache-dir + git clone https://github.com/Narsil/vllm.git vllm -build-vllm: vllm - cd vllm && git fetch && git checkout $(vllm_commit) +build-vllm-cuda: vllm-cuda + cd vllm && git fetch && git checkout b5dfc61db88a81069e45b44f7cc99bd9e62a60fa cd vllm && python setup.py build -install-vllm: build-vllm +install-vllm-cuda: build-vllm-cuda + pip uninstall vllm -y || true + cd vllm && python setup.py install + +vllm-rocm: + # Clone vllm + pip install -U ninja packaging --no-cache-dir + git clone https://github.com/fxmarty/rocm-vllm.git vllm + +build-vllm-rocm: vllm-rocm + cd vllm && git fetch && git checkout ca6913b3c2ffacdcb7d15e914dc34adbc6c89479 + cd vllm && PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install + +install-vllm-rocm: build-vllm-rocm pip uninstall vllm -y || true - cd vllm && python setup.py install \ No newline at end of file + cd vllm && python setup.py install diff --git a/server/exllama_kernels/exllama_kernels/cu_compat.cuh b/server/exllama_kernels/exllama_kernels/cu_compat.cuh new file mode 100644 index 000000000..c5258813e --- /dev/null +++ 
b/server/exllama_kernels/exllama_kernels/cu_compat.cuh @@ -0,0 +1,58 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _cuda_compat_cuh +#define _cuda_compat_cuh + +// atomicAdd for half types, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) +{ + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do + { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } + while (assumed != old); +} + +// atomicAdd for half2 types + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) +{ + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do + { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } + while (assumed != old); +} + +// + +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) +#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + +__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } + +#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } +#endif + +#endif +#endif + +#endif diff --git a/server/exllama_kernels/exllama_kernels/cuda_buffers.cu b/server/exllama_kernels/exllama_kernels/cuda_buffers.cu new file mode 100644 index 000000000..ee2cbee23 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_buffers.cu @@ -0,0 +1,71 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#define _cuda_buffers_cu +#include "cuda_buffers.cuh" + +CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL}; +// __constant__ half2 q4_table[16][256]; +// half2 q4_table_host[16][256]; +// bool q4_table_init = false; + +CudaBuffers::CudaBuffers +( + int _device, + half* _temp_state, + half* _temp_dq +) : + device(_device), + temp_state(_temp_state), + temp_dq(_temp_dq) +{ + cudaSetDevice(_device); + + cudaStreamCreate(&alt_stream_1); + cudaStreamCreate(&alt_stream_2); + cudaStreamCreate(&alt_stream_3); + cudaEventCreate(&alt_stream_1_done); + cudaEventCreate(&alt_stream_2_done); + cudaEventCreate(&alt_stream_3_done); +} + +CudaBuffers::~CudaBuffers() +{ + cudaStreamDestroy(alt_stream_1); + cudaStreamDestroy(alt_stream_2); + cudaStreamDestroy(alt_stream_3); + cudaEventDestroy(alt_stream_1_done); + cudaEventDestroy(alt_stream_2_done); + cudaEventDestroy(alt_stream_3_done); +} + +CudaBuffers* get_buffers(const int device_index) +{ + return g_buffers[device_index]; +} + +void prepare_buffers_cuda +( + int _device, + half* _temp_state, + half* _temp_dq +) +{ + CudaBuffers* buffers = new CudaBuffers + ( + _device, + _temp_state, + _temp_dq + ); + + g_buffers[_device] = buffers; +} + +void cleanup_buffers_cuda() +{ + for (int i = 0; i < CUDA_MAX_DEVICES; i++) + { + if (!g_buffers[i]) continue; + delete g_buffers[i]; + g_buffers[i] = NULL; + } +} diff --git a/server/exllama_kernels/exllama_kernels/cuda_buffers.cuh b/server/exllama_kernels/exllama_kernels/cuda_buffers.cuh new file mode 100644 index 000000000..afb60a012 --- /dev/null 
+++ b/server/exllama_kernels/exllama_kernels/cuda_buffers.cuh @@ -0,0 +1,52 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _cuda_buffers_cuh +#define _cuda_buffers_cuh + +#include +#include +#include +#include + +const int CUDA_MAX_DEVICES = 16; + +// #ifndef _cuda_buffers_cu +// extern __constant__ half2 q4_table[16][256]; +// #endif + +class CudaBuffers +{ +public: + int device; + + half* temp_state; // [max_hidden_rows * intermediate_size] + half* temp_dq; // size of largest quant tensor * 8 + + cudaStream_t alt_stream_1; + cudaStream_t alt_stream_2; + cudaStream_t alt_stream_3; + cudaEvent_t alt_stream_1_done; + cudaEvent_t alt_stream_2_done; + cudaEvent_t alt_stream_3_done; + + CudaBuffers + ( + int _device, + half* _temp_state, + half* _temp_dq + ); + ~CudaBuffers(); +}; + +CudaBuffers* get_buffers(const int device_index); + +void prepare_buffers_cuda +( + int _device, + half* _temp_state, + half* _temp_dq +); + +void cleanup_buffers_cuda(); + +#endif diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cu b/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cu new file mode 100644 index 000000000..c25b0206b --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cu @@ -0,0 +1,61 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include "column_remap.cuh" +#include "../util.cuh" + +const int SHUF_BLOCKSIZE_X = 256; +const int SHUF_BLOCKSIZE_Y = 16; + +__global__ void column_remap_kernel +( + const half* __restrict__ x, + half* __restrict__ x_new, + const int x_width, + const int x_height, + const uint32_t* x_map +) +{ + int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; + int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y; + + int x_stride = x_width; + int x_idx = x_row * x_stride + x_column; + + int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height); + int x_idx_end = x_row_end * x_stride + x_column; + + int s_column = x_map[x_column]; + int s_idx = x_row * x_stride + s_column; + + while (x_idx < x_idx_end) + { + x_new[x_idx] = x[s_idx]; + x_idx += x_stride; + s_idx += x_stride; + } +} + +// Remap columns in x to correspond to sequential group index before matmul +// +// perform x -> seq_x such that seq_x @ seq_w == x @ w + +void column_remap_cuda +( + const half* x, + half* x_new, + const int x_height, + const int x_width, + const uint32_t* x_map +) +{ + dim3 threads(SHUF_BLOCKSIZE_X, 1, 1); + + dim3 blocks + ( + (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X, + (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y, + 1 + ); + + column_remap_kernel<<>>(x, x_new, x_width, x_height, x_map); +} diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cuh b/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cuh new file mode 100644 index 000000000..0364e38c4 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cuh @@ -0,0 +1,19 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _column_remap_cuh +#define _column_remap_cuh + +#include +#include +#include + +void column_remap_cuda +( + const half* x, + half* x_new, + const int x_height, + const int x_width, + const uint32_t* x_map +); + +#endif diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu new file mode 100644 index 000000000..1b0f79564 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu @@ 
-0,0 +1,256 @@ +#include "q4_matmul.cuh" +#include "column_remap.cuh" +#include +#include "../util.cuh" +#include "../matrix.cuh" +#include "../cu_compat.cuh" +#include "../cuda_buffers.cuh" +#if defined(USE_ROCM) +#include "../hip_compat.cuh" +#endif + +const int THREADS_X = 32; // Block size and thread count along columns in w and out +const int THREADS_Y = 1; // Block size and thread count along rows in x and out + +typedef void (*fp_q4_matmul_kernel) +( + const half*, + const uint32_t*, + half*, + const half*, + const uint32_t*, + const int, + const int, + const int, + const int, + const int, + const uint32_t*, + bool +); + +template +__global__ void q4_matmul_kernel +( + const half* __restrict__ x, + const uint32_t* __restrict__ w, + half* __restrict__ out, + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int height, + const int dim, + const int width, + const int groupsize, + const int block_size_z, + const uint32_t* __restrict__ x_map, + bool no_zero +) +{ + // Start of block + + int x_column = block_size_z * blockIdx.z; + int x_column_end = min(dim, block_size_z * (blockIdx.z + 1)); + + int w_column = THREADS_X * blockIdx.x + threadIdx.x; + int x_row = THREADS_Y * blockIdx.y + threadIdx.y; + + int iterations = (x_column_end - x_column) / 8; + + // Views + + MatrixView_half x_(x, height, dim); + MatrixView_half w_scales_(w_scales, dim / groupsize, width); + MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width); + MatrixView_q4_column w_(w, dim, width); + MatrixView_half_rw out_(out, height, width); + + // Zero output + + if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0) + { + *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0; + __syncthreads(); + } + + // Loop over part of x row (and w column) + + half2 acc = {}; + half acc_h = {}; + + if constexpr (use_groupsize) + { + // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this + // could be slightly faster + + for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize) + { + if constexpr (use_half2) + { + half2 w_scale = w_scales_.item_half2half2(group, w_column); + uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); + } + else + { + half w_scale = w_scales_.item(group, w_column); + uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; + + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); + } + } + } + else + { + // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache + + for (int k = x_column; k < x_column + iterations * 8; k += 8) + { + if constexpr (use_half2) + { + int group = k / groupsize; + half2 w_scale = w_scales_.item_half2half2(group, w_column); + uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + } + else + { + int group = k / groupsize; + half w_scale = w_scales_.item(group, w_column); + uint32_t 
w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; + + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + } + } + } + + // Add to block result + + if constexpr (use_half2) + { + half result = __hadd(__low2half(acc), __high2half(acc)); + atomicAdd(out_.item_ptr(x_row, w_column), result); + } + else + { + atomicAdd(out_.item_ptr(x_row, w_column), acc_h); + } +} + +fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map) +{ + // + if (tuningParams->matmul_no_half2) { + if (block_size_z % groupsize == 0) { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } else { + if (block_size_z % groupsize == 0) + { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } +}; + +// Compute y = x @ w + +void q4_matmul_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + const Q4Matrix* w, + half* out, + bool no_zero, + cudaStream_t alt_stream +) +{ + int height = x_height; + int dim = w->height; + int width = w->width; + + cudaSetDevice(w->device); + + uint32_t* x_map = w->cuda_x_map; + const half* x_mapped = x; + if (x_map && !tuningParams->matmul_fused_remap && !alt_stream) + { + CudaBuffers* buffers = get_buffers(w->device); + column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); + x_mapped = buffers->temp_state; + x_map = NULL; + } + + int block_size_z; + if (w->width == 4096) block_size_z = 384; // 7B + else if (w->width == 11008) block_size_z = 256; + else if (w->width == 5120) block_size_z = 384; // 13B + else if (w->width == 13824) block_size_z = 256; + else if (w->width == 6656) block_size_z = 256; // 33B + else if (w->width == 17920) block_size_z = 128; + else block_size_z = 256; + + //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half)); + + dim3 threads(THREADS_X, THREADS_Y, 1); + + dim3 blocks + ( + (width + threads.x - 1) / threads.x, + (height + threads.y - 1) / threads.y, + (dim + block_size_z - 1) / block_size_z + ); + + fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map); + + kernel<<>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); +} + +void q4_matmul_recons_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + Q4Matrix* w, + half* out, + bool no_zero, + const cublasHandle_t handle +) +{ + int height = x_height; + int dim = w->height; + int width = w->width; + + cudaSetDevice(w->device); + CudaBuffers* buffers = get_buffers(w->device); + + const half* x_mapped = x; + if (w->cuda_x_map) + { + column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); + x_mapped = buffers->temp_state; + } + + w->reconstruct(buffers->temp_dq); + + const half alpha = __float2half(1.0f); + const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f); + cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width); + +// const float alpha = 1.0f; +// const float beta = no_zero ? 
1.0f : 0.0f; +// cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width, +// x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width); +} diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh new file mode 100644 index 000000000..4c7a66692 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh @@ -0,0 +1,37 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _q4_matmul_cuh +#define _q4_matmul_cuh + +#include +#include +#include +#include +#include + +#include "q4_matrix.cuh" +#include "../tuning.h" + +void q4_matmul_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + const Q4Matrix* w, + half* out, + bool no_zero, + cudaStream_t alt_stream +); + +void q4_matmul_recons_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + Q4Matrix* w, + half* out, + bool no_zero, + const cublasHandle_t handle +); + +#endif diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu new file mode 100644 index 000000000..1f32e6b83 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu @@ -0,0 +1,220 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include +#include "q4_matrix.cuh" +#include +#include "../util.cuh" +#include "../matrix.cuh" + +using namespace std; + +const int UNSHUF_BLOCKSIZE_X = 64; + +const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column +const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows + +vector g_q4_matrices; + +void g_q4_keep_matrix(Q4Matrix* m) +{ + g_q4_matrices.push_back(m); +} + +void g_q4_free_matrices() +{ + for (const auto& m : g_q4_matrices) delete m; + g_q4_matrices.clear(); +} + +Q4Matrix::Q4Matrix +( + const int _height, + const int _width, + const int _groups, + + uint32_t* _qweight, + uint32_t* _qzeros, + half* _scales, + uint32_t* _g_idx, + + const int _device +) : + height(_height), + width(_width), + groups(_groups), + device(_device) +{ + cudaSetDevice(device); + + cuda_qweight = _qweight; + cuda_qzeros = _qzeros; + cuda_scales = _scales; + + groupsize = height / groups; + + if (_g_idx) make_sequential(_g_idx); +} + +Q4Matrix::~Q4Matrix() +{ +} + +// Make sequential + +__global__ void make_sequential_kernel +( + const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const uint32_t* __restrict__ x_map, + const int w_height, + const int w_width +) +{ + const uint64_t* w2 = (uint64_t*) w; + uint64_t* w_new2 = (uint64_t*) w_new; + int w2_stride = w_width >> 1; + + int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; + int w_new2_row = blockIdx.y; + + int x_map_idx = w_new2_row << 3; + + uint64_t dst = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + int source_row = x_map[x_map_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx) +{ + uint32_t* cuda_new_qweight = NULL; 
+ cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); + cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch + + uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); + uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); + uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); + + // Group histogram + + for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; + + // Group map + + for (int i = 0, acc = 0; i < groups; i++) + { + short tmp = cpu_g_idx_map[i]; + cpu_g_idx_map[i] = acc; + acc += tmp; + } + + // X map (inverse) + + for (int row = 0; row < height; row++) + { + uint32_t target_group = cpu_g_idx[row]; + uint32_t target_row = cpu_g_idx_map[target_group]; + cpu_g_idx_map[target_group]++; + cpu_x_map_inv[row] = target_row; + } + + // X map + + for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; + + // Move to CUDA + + cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice); + + // Rearrange rows in w + + dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1); + dim3 blocks(width / UNSHUF_BLOCKSIZE_X / 2, height / 8, 1); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + make_sequential_kernel<<>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width); + + // Replace qweights + + cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); + + // Cleanup + + cudaDeviceSynchronize(); + cudaFree(cuda_new_qweight); + free(cpu_g_idx_map); + free(cpu_x_map); + free(cpu_x_map_inv); +} + +__global__ void reconstruct_kernel +( + const uint32_t* __restrict__ w, + half* __restrict__ out, // (y) + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int height, + const int width, + const int groupsize +) +{ + // Start of block + + int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x; + int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8; + + // Views + + MatrixView_q4_column w_(w, height, width); + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, height / groupsize, width); + MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width); + + // Groupsize version + + int group = row / groupsize; + + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = (w_zeros_.item(group, column) + 1) & 0x0F; + + uint32_t w_read = w_.item_uint32_t(row, column); + half* out_ptr = out_.item_ptr(row, column); + + #pragma unroll + for (int s = 0; s < 32; s += 4) + { + half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale); + *out_ptr = w_item; out_ptr += out_.width; + } +} + +void Q4Matrix::reconstruct(half* out) +{ + dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1); + + dim3 blocks + ( + (width + threads.x - 1) / threads.x, + (height / 8 + threads.y - 1) / threads.y, + 1 + ); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + reconstruct_kernel<<>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize); +} diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cuh b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cuh new file mode 100644 index 000000000..49431dc95 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cuh @@ -0,0 +1,53 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _q4_matrix_cuh +#define _q4_matrix_cuh + 
+#include +#include +#include + +class Q4Matrix +{ +public: + + int device; + + int height; + int width; + int groups; + int groupsize; + + uint32_t* cuda_qweight = NULL; + uint32_t* cuda_qzeros = NULL; + half* cuda_scales = NULL; + uint32_t* cuda_x_map = NULL; + + Q4Matrix + ( + const int _height, + const int _width, + const int _groups, + + uint32_t* _qweight, + uint32_t* _qzeros, + half* _scales, + uint32_t* _g_idx, + + const int _device + ); + + ~Q4Matrix(); + + void reconstruct(half* out); + +private: + + void make_sequential(const uint32_t* cpu_g_idx); + +}; + +void g_q4_keep_matrix(Q4Matrix* m); +void g_q4_free_matrices(); + +#endif diff --git a/server/exllama_kernels/exllama_kernels/exllama_ext.cpp b/server/exllama_kernels/exllama_kernels/exllama_ext.cpp new file mode 100644 index 000000000..f2df80e8f --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/exllama_ext.cpp @@ -0,0 +1,253 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include +#include +#include +#include +#include +#include +#include +#include "util.cuh" +#include "tuning.h" +#include "cuda_buffers.cuh" +#include "cuda_func/q4_matrix.cuh" +#include "cuda_func/q4_matmul.cuh" +#include "cuda_func/column_remap.cuh" + +// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a +// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of +// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. + +void check_cuda(cudaError_t ret) +{ + switch (ret) + { + case cudaSuccess: + break; + + case cudaUnspecified: + printf(" **** Unspecified error\n"); + TORCH_CHECK(false, "CUDA error"); + break; + + default: + printf(" **** CUDA error\n"); \ + printf(" **** %s\n", cudaGetErrorString(ret)); \ + TORCH_CHECK(false, "CUDA error"); \ + break; + } +} + +// Some decluttering macros + +#define STRINGIFY_(__x) #__x +#define STRINGIFY(__x) STRINGIFY_(__x) +#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) + +#define TORCH_CHECK_DEVICE_INDEX(__index) \ +do { \ + TORCH_CHECK(__index >= 0, "no device index"); \ + TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \ +} while(0) + +#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \ +do { \ + TORCH_CHECK_DTYPE(__w, kInt); \ + TORCH_CHECK_DTYPE(__w_scales, kHalf); \ + TORCH_CHECK_DTYPE(__w_zeros, kInt); \ + TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ + TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ + TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ + TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ +} while(0) + +int get_groupsize(torch::Tensor w, 
torch::Tensor w_zeros) +{ + int groupsize = w.size(0) * 8 / w_zeros.size(0); + TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") + return groupsize; +} + + +// Tuning parameters + +ExLlamaTuning tuningParams; + +void set_tuning_params +( + int matmul_recons_thd, + bool matmul_fused_remap, + bool matmul_no_half2 +) +{ + tuningParams.matmul_recons_thd = matmul_recons_thd; + tuningParams.matmul_fused_remap = matmul_fused_remap; + tuningParams.matmul_no_half2 = matmul_no_half2; +} + + +// Release all unmanaged objects allocated by the extension + +void cleanup() +{ + cleanup_buffers_cuda(); + g_q4_free_matrices(); +} + + +// Prepare buffers for forward pass + +void prepare_buffers +( + torch::Device device, + torch::Tensor temp_state, + torch::Tensor temp_dq +) +{ + int device_index = device.index(); + TORCH_CHECK_DEVICE_INDEX(device_index); + const at::cuda::OptionalCUDAGuard device_guard(device); + + prepare_buffers_cuda + ( + device_index, + (half*) temp_state.data_ptr(), + (half*) temp_dq.data_ptr() + ); +} + + +// Create Q4Matrix, return handle + +uintptr_t make_q4 +( + torch::Tensor qweight, + torch::Tensor qzeros, + torch::Tensor scales, + torch::Tensor g_idx, + int device +) +{ + TORCH_CHECK_DTYPE(qweight, kInt); + TORCH_CHECK_DTYPE(qzeros, kInt); + TORCH_CHECK_DTYPE(scales, kHalf); + TORCH_CHECK_DTYPE_OPT(g_idx, kInt); + TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); + TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); + TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); + + int width = qweight.size(1); + int height = qweight.size(0) * 8; + int groups = qzeros.size(0); + + Q4Matrix* m = new Q4Matrix + ( + height, + width, + groups, + + (uint32_t*) qweight.data_ptr(), + (uint32_t*) qzeros.data_ptr(), + (half*) scales.data_ptr(), + g_idx.device().is_meta() ? 
NULL : (uint32_t*) g_idx.data_ptr(), + + device + ); + + g_q4_keep_matrix(m); + return reinterpret_cast (m); +} + + +// Matmul half @ quant -> half + +void q4_matmul +( + torch::Tensor x, + uintptr_t w, + torch::Tensor out +) +{ + Q4Matrix* wm = reinterpret_cast (w); + + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(out, kHalf); + TORCH_CHECK_SHAPES(x, 0, out, 0, 1); + TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + int x_height = x.size(0); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) + { + q4_matmul_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr(), + false, + stream + ); + } + else + { + q4_matmul_recons_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr(), + false, + at::cuda::getCurrentCUDABlasHandle() + ); + } +} + + +// Remap columns in half tensor + +void column_remap +( + torch::Tensor x, + torch::Tensor x_new, + torch::Tensor x_map +) +{ + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(x_new, kHalf); + TORCH_CHECK_DTYPE(x_map, kInt); + TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); + + int height = x.size(0); + int width = x.size(1); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + column_remap_cuda + ( + (half*) x.data_ptr(), + (half*) x_new.data_ptr(), + height, + width, + (uint32_t*) x_map.data_ptr() + ); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); + m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); + m.def("cleanup", &cleanup, "cleanup"); + m.def("make_q4", &make_q4, "make_q4"); + m.def("q4_matmul", &q4_matmul, "q4_matmul"); +} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/compat_gemm.cuh b/server/exllama_kernels/exllama_kernels/hip_compat.cuh similarity index 69% rename from server/exllamav2_kernels/exllamav2_kernels/cuda/compat_gemm.cuh rename to server/exllama_kernels/exllama_kernels/hip_compat.cuh index 19b1e4a60..f2a3dcada 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/compat_gemm.cuh +++ b/server/exllama_kernels/exllama_kernels/hip_compat.cuh @@ -1,12 +1,24 @@ -#ifndef _compat_gemm_cuh -#define _compat_gemm_cuh +// Adapted from turboderp exllama: https://github.com/turboderp/exllama -#if defined(USE_ROCM) +#ifndef _hip_compat_cuh +#define _hip_compat_cuh -// For some reason this include is not present anywhere in exllama_v2 codebase, but it is required -// for symbols as hipblasHalf. -#include +// Workaround for a bug in hipamd, backported from upstream, this is fixed in ROCm 5.6. +__device__ __forceinline__ __half __compat_hrcp(__half x) { + return __half_raw{ + static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))}; +} + +__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) { + return _Float16_2{ + _Float16_2{static_cast<_Float16>(1.0f), + static_cast<_Float16>(1.0f)} / x.data}; +} +#define hrcp __compat_hrcp +#define h2rcp __compat_h2rcp + +// Automatic conversion of hipblasHgemm doesn't convert half to hipblasHalf. 
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, hipblasOperation_t transA, hipblasOperation_t transB, @@ -31,8 +43,10 @@ __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t #define hipblasHgemm __compat_hipblasHgemm // Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. +#define rocblas_handle hipblasHandle_t #define rocblas_operation_none HIPBLAS_OP_N +#define rocblas_get_stream hipblasGetStream +#define rocblas_set_stream hipblasSetStream #define rocblas_hgemm __compat_hipblasHgemm -#endif #endif diff --git a/server/exllama_kernels/exllama_kernels/matrix.cuh b/server/exllama_kernels/exllama_kernels/matrix.cuh new file mode 100644 index 000000000..2fd5ab0b3 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/matrix.cuh @@ -0,0 +1,294 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _matrix_cuh +#define _matrix_cuh + +#include +#include + +class MatrixView_half +{ +public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } +}; + +class MatrixView_half_rw +{ +public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } + __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } +}; + +class MatrixView_q4_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } +}; + +class MatrixView_q4_column +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } + __device__ 
__forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } +}; + +// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu + +// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale + +__device__ __forceinline__ half2 dot_product_8 +( + const half2 acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half2 v_scale_2, + const uint32_t v_zero, // + 1 (!!) + const int count +) +{ + const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column); + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half2 result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half2 v_01 = __halves2half2(v_0, v_1); + half2 v_23 = __halves2half2(v_2, v_3); + half2 v_45 = __halves2half2(v_4, v_5); + half2 v_67 = __halves2half2(v_6, v_7); + +// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently) +// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff]; +// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff]; +// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ]; + + half2 tmp = __hmul2(*h_ptr++, v_01); + tmp = __hfma2(*h_ptr++, v_23, tmp); + tmp = __hfma2(*h_ptr++, v_45, tmp); + tmp = __hfma2(*h_ptr++, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); + } + + return result; +} + +__device__ __forceinline__ half dot_product_8_h +( + const half acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half v_scale, + const uint32_t v_zero, // + 1 (!!) 
+ const int count +) +{ + const half* h_ptr = h_.item_ptr(h_row, h_column); + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half tmp = __hmul(*h_ptr++, v_0); + tmp = __hfma(*h_ptr++, v_1, tmp); + tmp = __hfma(*h_ptr++, v_2, tmp); + tmp = __hfma(*h_ptr++, v_3, tmp); + tmp = __hfma(*h_ptr++, v_4, tmp); + tmp = __hfma(*h_ptr++, v_5, tmp); + tmp = __hfma(*h_ptr++, v_6, tmp); + tmp = __hfma(*h_ptr++, v_7, tmp); + result = __hfma(v_scale, tmp, result); + } + + return result; +} + +// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map + +__device__ __forceinline__ half2 dot_product_8_x_map +( + const half2 acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half2 v_scale_2, + const uint32_t v_zero, // + 1 (!!) + const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half2 result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half2 v_01 = __halves2half2(v_0, v_1); + half2 v_23 = __halves2half2(v_2, v_3); + half2 v_45 = __halves2half2(v_4, v_5); + half2 v_67 = __halves2half2(v_6, v_7); + + half h_0 = h_ptr[*x_map_ptr++]; + half h_1 = h_ptr[*x_map_ptr++]; + half h_2 = h_ptr[*x_map_ptr++]; + half h_3 = h_ptr[*x_map_ptr++]; + half h_4 = h_ptr[*x_map_ptr++]; + half h_5 = h_ptr[*x_map_ptr++]; + half h_6 = h_ptr[*x_map_ptr++]; + half h_7 = h_ptr[*x_map_ptr++]; + + half2 h_01 = __halves2half2(h_0, h_1); + half2 h_23 = __halves2half2(h_2, h_3); + half2 h_45 = __halves2half2(h_4, h_5); + half2 h_67 = __halves2half2(h_6, h_7); + + half2 tmp = __hmul2(h_01, v_01); + tmp = __hfma2(h_23, v_23, tmp); + tmp = __hfma2(h_45, v_45, tmp); + tmp = __hfma2(h_67, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); + } + + return result; +} + +__device__ __forceinline__ half dot_product_8_x_map_h +( + const half acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half v_scale, + const uint32_t v_zero, // + 
1 (!!) + const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half tmp = __hmul(h_ptr[*x_map_ptr++], v_0); + tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp); + result = __hfma(v_scale, tmp, result); + } + + return result; +} + +#endif diff --git a/server/exllama_kernels/exllama_kernels/tuning.h b/server/exllama_kernels/exllama_kernels/tuning.h new file mode 100644 index 000000000..770ca46aa --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/tuning.h @@ -0,0 +1,13 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _tuning_h +#define _tuning_h + +struct ExLlamaTuning +{ + int matmul_recons_thd; + bool matmul_fused_remap; + bool matmul_no_half2; +}; + +#endif diff --git a/server/exllama_kernels/exllama_kernels/util.cuh b/server/exllama_kernels/exllama_kernels/util.cuh new file mode 100644 index 000000000..7b3975732 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/util.cuh @@ -0,0 +1,33 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _util_cuh +#define _util_cuh + +#include +#include +#include +#include + +#if defined(USE_ROCM) +#define cudaUnspecified hipErrorUnknown +#else +#define cudaUnspecified cudaErrorApiFailureBase +#endif + +// React to failure on return code != cudaSuccess + +#define _cuda_check(fn) \ +do { \ + {_cuda_err = fn;} \ + if (_cuda_err != cudaSuccess) goto _cuda_fail; \ +} while(false) + +// React to failure on return code == 0 + +#define _alloc_check(fn) \ +do { \ + if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \ + else _cuda_err = cudaSuccess; \ +} while(false) + +#endif diff --git a/server/exllama_kernels/setup.py b/server/exllama_kernels/setup.py new file mode 100644 index 000000000..987d181eb --- /dev/null +++ b/server/exllama_kernels/setup.py @@ -0,0 +1,19 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name="exllama_kernels", + ext_modules=[ + CUDAExtension( + name="exllama_kernels", + sources=[ + "exllama_kernels/exllama_ext.cpp", + "exllama_kernels/cuda_buffers.cu", + "exllama_kernels/cuda_func/column_remap.cu", + "exllama_kernels/cuda_func/q4_matmul.cu", + "exllama_kernels/cuda_func/q4_matrix.cu", + ], + ) + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/server/exllamav2_kernels/exllamav2_kernels/cpp/quantize_func.cpp b/server/exllamav2_kernels/exllamav2_kernels/cpp/quantize_func.cpp deleted file mode 100644 index 
c1772dd97..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cpp/quantize_func.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include "quantize_func.h" -#include "../cuda/quantize.cuh" - -void quantize_range -( - torch::Tensor quant, - torch::Tensor scale, - torch::Tensor out_q, - float qzero, - float maxq, - torch::Tensor hessian_inv, - torch::Tensor weights, - torch::Tensor error, - int a, - int b -) -{ - int columns = weights.size(1); - int hcolumns = hessian_inv.size(1); - TORCH_CHECK(hcolumns == weights.size(0), "H shape mismatch") - - for (int c = a; c < b; c++) - { - fused_quantize_adjust_cuda - ( - (const float*) weights.data_ptr(), - (float*) quant.data_ptr(), - (const float*) scale.data_ptr(), - out_q.device().is_meta() ? NULL : (uint16_t*) out_q.data_ptr(), - (const float*) hessian_inv.data_ptr(), - (float*) error.data_ptr(), - c, // row - hcolumns, // rows - columns, - qzero, - maxq - ); - - vv_mul_sub_cuda - ( - ((const float*) hessian_inv.data_ptr()) + c * hcolumns + c, - ((const float*) error.data_ptr()) + c * columns, - ((float*) weights.data_ptr()) + c * columns, - b - c, - columns - ); - } - - torch::Tensor x = hessian_inv.slice(0, a, b).slice(1, b).transpose(0, 1); - torch::Tensor y = error.slice(0, a, b); - weights.slice(0, b).addmm_(x, y, 1.0f, -1.0f); -} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cpp/quantize_func.h b/server/exllamav2_kernels/exllamav2_kernels/cpp/quantize_func.h deleted file mode 100644 index cd111bad3..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cpp/quantize_func.h +++ /dev/null @@ -1,25 +0,0 @@ -#ifndef _quantize_func_h -#define _quantize_func_h - -#include -#include -#include -#include -#include -#include - -void quantize_range -( - torch::Tensor quant, - torch::Tensor scale, - torch::Tensor out_q, - float qzero, - float maxq, - torch::Tensor hessian_inv, - torch::Tensor weights, - torch::Tensor error, - int a, - int b -); - -#endif \ No newline at end of file diff --git a/server/exllamav2_kernels/exllamav2_kernels/cpp/safetensors.cpp b/server/exllamav2_kernels/exllamav2_kernels/cpp/safetensors.cpp deleted file mode 100644 index 09c7cff4c..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cpp/safetensors.cpp +++ /dev/null @@ -1,372 +0,0 @@ -#include "safetensors.h" - -#include -#include "util.h" - -#ifdef __linux__ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#endif - -#define MAX_BLOCK_SIZE (128*1024) -#define MAX_PAGES 4 -#define PAGESIZE (16*1024*1024) -#define Q_DEPTH 1 -#define PINNED_MEMORY (MAX_PAGES * PAGESIZE) - -//#define ST_DEBUG - -void* pinned_buffer = nullptr; -void* aligned_buffer = nullptr; - -struct STPage -{ -public: - int file_descriptor; - #ifdef __linux__ - size_t file_a; - size_t file_b; - long access; - std::atomic locks; - char* ptr; - #endif -}; - -STPage pages[MAX_PAGES]; -long serial = 1; - -void safetensors_pinned_buffer() -{ - #ifdef __linux__ - - if (pinned_buffer) return; - cudaMallocHost((void**) &pinned_buffer, PINNED_MEMORY + MAX_BLOCK_SIZE); - TORCH_CHECK(pinned_buffer, "Unable to allocate pinned memory"); - aligned_buffer = (void*) ((char *)(((uintptr_t)pinned_buffer + MAX_BLOCK_SIZE - 1) & ~(MAX_BLOCK_SIZE - 1))); - - for (int i = 0; i < MAX_PAGES; i++) - { - pages[i].file_descriptor = -1; - pages[i].file_a = 0; - pages[i].file_b = 0; - pages[i].access = -1; - pages[i].locks.store(0); - pages[i].ptr = ((char*) aligned_buffer) + i * PAGESIZE; - } - - #endif -} - -void safetensors_free_pinned_buffer() -{ - #ifdef __linux__ - - if 
(!pinned_buffer) return; - cudaFreeHost((void*) pinned_buffer); - pinned_buffer = nullptr; - aligned_buffer = nullptr; - - #endif -} - -STPage* get_cache_page(size_t file_descriptor, size_t block_size, size_t filesize, size_t file_a, size_t file_b) -{ - #ifdef __linux__ - - #ifdef ST_DEBUG - printf("-- get cache page\n"); - DBGX3(file_descriptor, file_a, file_b); - DBGX2(block_size, filesize); - #endif - - // Find existing page in cache - - for (int i = 0; i < MAX_PAGES; i++) - { - if (pages[i].file_descriptor == file_descriptor && - pages[i].file_a == file_a && - pages[i].file_b == file_b) - { - pages[i].access = serial++; - return &pages[i]; - } - } - - // Find page to evict - - int oldest_i = -1; - long oldest = std::numeric_limits::max(); - - while (oldest_i == -1) - { - for (int i = 0; i < MAX_PAGES; i++) - { - if (pages[i].locks.load() > 0) continue; - if (pages[i].access < oldest) - { - oldest_i = i; - oldest = pages[i].access; - } - } - } - - #ifdef ST_DEBUG - printf("-- evict page\n"); - DBGX(oldest_i); - #endif - - // Load page - - #ifdef ST_DEBUG - printf("-- load page\n"); - DBGX3(file_descriptor, file_a, file_b); - #endif - - int p = oldest_i; - pages[p].access = serial++; - pages[p].file_a = file_a; - pages[p].file_b = file_b; - pages[p].file_descriptor = file_descriptor; - - aiocb aiocb_list[Q_DEPTH]; - size_t read_lens[Q_DEPTH]; - int num_chunks = 0; - - size_t q_chunk = PAGESIZE / Q_DEPTH; - size_t q_a = file_a; - char* page_ptr = pages[p].ptr; - - for (int i = 0; i < Q_DEPTH; ++i) - { - size_t q_b = q_a + q_chunk; - if (q_b > filesize) q_b = filesize; - - size_t read_len = q_b - q_a; - read_lens[i] = read_len; - //read_len = DIVIDE(read_len, 2 * block_size) * 2 * block_size; - read_len = q_chunk; - - memset(&aiocb_list[i], 0, sizeof(aiocb)); - aiocb_list[i].aio_fildes = file_descriptor; - aiocb_list[i].aio_buf = page_ptr; - aiocb_list[i].aio_nbytes = read_len; - aiocb_list[i].aio_offset = q_a; - - #ifdef ST_DEBUG - DBGX3(q_a, q_b, read_len); - DBGX2(filesize, read_lens[i]); - #endif - - aio_read(&aiocb_list[i]); - num_chunks++; - - if (q_b >= filesize) break; - - page_ptr += q_chunk; - q_a += q_chunk; - } - - q_a = file_a; - - for (int i = 0; i < num_chunks; ++i) - { - struct aiocb *aiocb_active[1]; - aiocb_active[0] = &aiocb_list[i]; - aio_suspend(aiocb_active, 1, NULL); - - int err = aio_error(&aiocb_list[i]); - - #ifdef ST_DEBUG - DBGX(err); - #endif - - TORCH_CHECK(err == 0, "Async read error (1)"); - - ssize_t bytes_read = aio_return(&aiocb_list[i]); - - #ifdef ST_DEBUG - DBGX2(bytes_read, read_lens[i]); - #endif - - TORCH_CHECK(bytes_read == read_lens[i], "Async read error (2)"); - } - - return &pages[p]; - - #else - TORCH_CHECK(false, "fasttensors only supported on Linux"); - return NULL; - #endif -} - -uintptr_t safetensors_open(const char* filename) -{ - #ifdef __linux__ - - STFile* f = new STFile(filename); - return reinterpret_cast (f); - - - #else - TORCH_CHECK(false, "fasttensors only supported on Linux"); - return 0; - #endif -} - -void safetensors_close(uintptr_t handle) -{ - #ifdef __linux__ - - STFile* f = reinterpret_cast (handle); - delete f; - - #else - TORCH_CHECK(false, "fasttensors only supported on Linux"); - #endif -} - -STFile::STFile(const char* filename) -{ - #ifdef __linux__ - - file_descriptor = open(filename, O_RDONLY | O_DIRECT); - TORCH_CHECK(file_descriptor != -1, "Safetensors file I/O error"); - - struct stat sb; - auto res = fstat(file_descriptor, &sb); - TORCH_CHECK(res != -1, "Safetensors fstat failed"); - filesize = sb.st_size; - 
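The get_cache_page routine deleted above implements a small direct-I/O page cache: look up a page by (file descriptor, byte range), otherwise evict the least-recently-used page that is not pinned by an in-flight GPU copy. A stripped-down C++ sketch of that policy follows; names are hypothetical and no real I/O or locking is modelled.

// Illustrative host-only model of the get_cache_page policy above.
#include <cstdio>
#include <limits>

struct Page { int fd = -1; size_t a = 0, b = 0; long access = -1; int locks = 0; };

constexpr int MAX_PAGES = 4;
Page pages[MAX_PAGES];
long serial = 1;

Page* get_cache_page_ref(int fd, size_t file_a, size_t file_b)
{
    // Hit: same file descriptor and byte range.
    for (Page& p : pages)
        if (p.fd == fd && p.a == file_a && p.b == file_b) { p.access = serial++; return &p; }

    // Miss: evict the unlocked page with the oldest access stamp (spin if all are locked).
    int victim = -1;
    while (victim == -1)
    {
        long oldest = std::numeric_limits<long>::max();
        for (int i = 0; i < MAX_PAGES; ++i)
            if (pages[i].locks == 0 && pages[i].access < oldest) { victim = i; oldest = pages[i].access; }
    }

    Page& p = pages[victim];
    p = {fd, file_a, file_b, serial++, 0};   // caller then reads the range into the page
    return &p;
}

int main()
{
    get_cache_page_ref(3, 0, 1 << 20);
    get_cache_page_ref(3, 0, 1 << 20);       // second call is a cache hit
    printf("page[0] access stamp: %ld\n", pages[0].access);   // prints 2
}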
block_size = sb.st_blksize; - padded_size = DIVIDE(filesize, block_size) * block_size; - TORCH_CHECK(block_size <= MAX_BLOCK_SIZE, "Block size too large") - - #else - TORCH_CHECK(false, "fasttensors only supported on Linux"); - #endif -} - -STFile::~STFile() -{ - #ifdef __linux__ - - close(file_descriptor); - - #else - TORCH_CHECK(false, "fasttensors only supported on Linux"); - #endif -} - -void dec_lock(cudaStream_t stream, cudaError_t status, void *user_data) -{ - #ifdef __linux__ - STPage* p = (STPage*) user_data; - p->locks--; - #endif -} - -void STFile::load -( - torch::Tensor target, - size_t offset, - size_t length, - bool gpu -) -{ - #ifdef __linux__ - - safetensors_pinned_buffer(); - - #ifdef ST_DEBUG - printf("-- load tensor\n"); - DBGX2(offset, length); - DBGI(length); - #endif - - // Get cache pages - - size_t file_b = offset / PAGESIZE * PAGESIZE; - size_t file_c = DIVIDE(offset + length, PAGESIZE) * PAGESIZE; - - // Loop over pages - - size_t file_a = file_b; - size_t tensor_offset = 0; - - while (tensor_offset < length) - { - file_a = file_b; - file_b += PAGESIZE; - - STPage* page = get_cache_page(file_descriptor, block_size, filesize, file_a, file_b); - ssize_t left = offset - file_a; - if (left < 0) left = 0; - ssize_t right = offset + length - file_a; - if (right > PAGESIZE) right = PAGESIZE; - ssize_t copy_len = right - left; - - #ifdef ST_DEBUG - printf("-- copy chunk\n"); - DBGX3(left, right, copy_len); - DBGX(tensor_offset); - DBGI(copy_len); - #endif - - char* src = page->ptr + left; - char* dst = ((char*) target.data_ptr()) + tensor_offset; - - if (gpu) - { - page->locks++; - cudaMemcpyAsync(dst, src, copy_len, cudaMemcpyHostToDevice); - cudaStreamAddCallback(NULL, dec_lock, (void*) page, 0); - } - else - { - memcpy(dst, src, copy_len); - } - - //cudaDeviceSynchronize(); - - tensor_offset += copy_len; - } - - #else - TORCH_CHECK(false, "fasttensors only supported on Linux"); - #endif -} - -void safetensors_load -( - uintptr_t handle, - torch::Tensor target, - size_t offset, - size_t length -) -{ - #ifdef __linux__ - - STFile* f = reinterpret_cast (handle); - c10::optional device = torch::device_of(target); - - if (device.has_value() && device->type() == torch::kCPU) - { - f->load(target, offset, length, false); - } - else - { - const at::cuda::OptionalCUDAGuard device_guard(device); - f->load(target, offset, length, true); - } - - #else - TORCH_CHECK(false, "fasttensors only supported on Linux"); - #endif -} - diff --git a/server/exllamav2_kernels/exllamav2_kernels/cpp/safetensors.h b/server/exllamav2_kernels/exllamav2_kernels/cpp/safetensors.h deleted file mode 100644 index e37e9e6ca..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cpp/safetensors.h +++ /dev/null @@ -1,46 +0,0 @@ -#ifndef _safetensors_h -#define _safetensors_h - -#include -#include -#include -#include -#include -#include - -class STFile -{ -public: - STFile(const char* filename); - ~STFile(); - - void load - ( - torch::Tensor target, - size_t offset, - size_t length, - bool gpu - ); - - int file_descriptor; - size_t filesize; - size_t padded_size; - size_t block_size; - void* aligned_buffer; -}; - -void safetensors_pinned_buffer(); -void safetensors_free_pinned_buffer(); - -uintptr_t safetensors_open(const char* filename); -void safetensors_close(uintptr_t handle); - -void safetensors_load -( - uintptr_t handle, - torch::Tensor target, - size_t offset, - size_t length -); - -#endif \ No newline at end of file diff --git a/server/exllamav2_kernels/exllamav2_kernels/cpp/sampling.cpp 
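STFile::load above walks the requested byte range page by page, clamping the copy window to each PAGESIZE-aligned page. The range-splitting arithmetic is the easy part to get wrong, so here is a standalone C++ sketch of just that step; it prints the copy plan instead of performing I/O, and the names are illustrative.

// Sketch of the paging arithmetic in STFile::load: a read of [offset, offset + length)
// is split into chunks, one per PAGESIZE-aligned page.
#include <cstddef>
#include <cstdio>
#include <algorithm>

constexpr size_t PAGESIZE = 16 * 1024 * 1024;

void plan_page_copies(size_t offset, size_t length)
{
    size_t file_a = offset / PAGESIZE * PAGESIZE;   // first page containing the range
    size_t copied = 0;
    while (copied < length)
    {
        ptrdiff_t left  = std::max<ptrdiff_t>(0, (ptrdiff_t)offset - (ptrdiff_t)file_a);
        ptrdiff_t right = std::min<ptrdiff_t>((ptrdiff_t)PAGESIZE,
                                              (ptrdiff_t)(offset + length) - (ptrdiff_t)file_a);
        size_t copy_len = (size_t)(right - left);
        printf("page @ %zu: copy bytes [%td, %td) -> tensor offset %zu\n",
               file_a, left, right, copied);
        copied += copy_len;
        file_a += PAGESIZE;                          // next page
    }
}

int main() { plan_page_copies(10 * 1024 * 1024, 20 * 1024 * 1024); }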
b/server/exllamav2_kernels/exllamav2_kernels/cpp/sampling.cpp deleted file mode 100644 index 057a6e699..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cpp/sampling.cpp +++ /dev/null @@ -1,730 +0,0 @@ -#include "sampling.h" -#include "util.h" -#include "algorithm" -#include -#include -#include -#include - -const int top_k_heap_threshold = 500; - -bool* g_rep_mask = NULL; -int g_vocab_size = 0; - -void apply_rep_penalty_cpu -( - const int vocab_size, - const uint64_t* sequence, - const float penalty_max, - const int sustain, - const int decay, - const float alpha_frequency, - const float alpha_presence, - const int seq_len, - float* logits -) -{ - // Map of which logits have already had penalties applied - - if (vocab_size > g_vocab_size) - { - if (g_rep_mask) free(g_rep_mask); - g_vocab_size = vocab_size; - g_rep_mask = (bool*) malloc(g_vocab_size * sizeof(bool)); - } - - memset(g_rep_mask, 0, g_vocab_size * sizeof(bool)); - - // Penalties to apply - - float rep_p = penalty_max; // Multiplicative penalty, as in HF repetition penalty - float freq_p = alpha_frequency; // Additive frequency penalty, as in OAI spec - float pres_p = alpha_presence; // Additive presence penalty, as in OAI spec - - // Change in penalties over the "decay" range of the context - - float d_rep_p = 0.0f; - float d_freq_p = 0.0f; - float d_pres_p = 0.0f; - if (decay) - { - d_rep_p = (1.0f - rep_p) / (float) decay; - d_freq_p = (0.0f - freq_p) / (float) decay; - d_pres_p = (0.0f - pres_p) / (float) decay; - } - - // "sustain" length, range of the context over which penalties are fixed - - int sust = sustain == -1 ? seq_len : sustain; - - // First token of the penalty range, including decay - - int beg = seq_len - sust - decay; - if (beg < 0) beg = 0; - - // Iter over context, backwards - - for (int i = seq_len; i > beg;) - { - uint64_t t = sequence[--i]; - - // If t has not been encountered before, apply rep_p and pres_p - - if (!g_rep_mask[t]) - { - if (logits[t] > 0.0) logits[t] /= rep_p; // Multiplicative penalty - else logits[t] *= rep_p; - - logits[t] -= pres_p; // Additive penalty - - g_rep_mask[t] = true; // Only once per logit - } - - // Apply freq_p penalty for every time a token is encountered, so the total additive penalty is count * freq_p - - logits[t] -= freq_p; - - // If we're in the "decay" range, reduce penalties for every token - - if (--sust < 0) - { - rep_p += d_rep_p; - freq_p += d_freq_p; - pres_p += d_pres_p; - } - } -} - -void softmax_cpu -( - const int vocab_size, - const float temperature, - const float* logits, - const bool* logits_filter, - float* output -) -{ - float esum = 0.0f; - float itemp = 1.0f / temperature; - float maxl = -1e38; - - #pragma unroll(32) - for (int i = 0; i < vocab_size; i++) - { - if (!logits_filter[i]) continue; - maxl = fmaxf(logits[i], maxl); - } - maxl *= itemp; - - #pragma unroll(32) - for (int i = 0; i < vocab_size; i++) - { - if (!logits_filter[i]) continue; - float e = expf(logits[i] * itemp - maxl); - output[i] = e; - esum += e; - } - float isum = 1.0f / esum; - - #pragma unroll(32) - for (int i = 0; i < vocab_size; i++) - { - if (logits_filter[i]) - output[i] *= isum; - else - output[i] = 0.0f; - } - -// printf("Softmax:"); -// float summ = 0.0f; -// for (int i = 0; i < vocab_size; i++) -// { -// if (logits_filter[i]) -// { -// printf("%d, %f\n", i, output[i]); -// summ += output[i]; -// } -// } -// printf("sum: %f\n", summ); -} - -int post_softmax_temperature -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - float 
temperature -) -{ -// printf("---- pre\n"); -// for (int i = 0; i < num_candidates; ++i) -// DBGIF(i, temp_probs[i]); - - float psum = 0.0f; - float itemp = 1.0f / temperature; - for (int i = 0; i < num_candidates; ++i) - { - float p = powf(temp_probs[i], itemp); - psum += p; - temp_probs[i] = p; - } - - float ipsum = 1.0f / psum; - for (int i = 0; i < num_candidates; ++i) - temp_probs[i] *= ipsum; - -// printf("---- post\n"); -// DBGF(temperature); -// printf("----\n"); -// for (int i = 0; i < num_candidates; ++i) -// DBGIF(i, temp_probs[i]); -// printf("\n"); - - return num_candidates; -} - - -void normalize_cpu -( - const int num_candidates, - float* probs -) -{ - float sum = 0.0f; - #pragma unroll(32) - for (int i = 0; i < num_candidates; i++) sum += probs[i]; - float isum = 1.0f / sum; - #pragma unroll(32) - for (int i = 0; i < num_candidates; i++) probs[i] *= isum; -} - -int greedy_sample -( - const int num_candidates, - const float* probs, - const bool* logits_filter -) -{ - int maxidx = -1; - float max = -1e38; - - for(int i = 1; i < num_candidates; i++) - { - if (logits_filter[i] && (maxidx == -1 || probs[i] > max)) - { - max = probs[i]; - maxidx = i; - } - } - return maxidx; -} - -template -inline void swap(T &a, T &b) -{ - T temp = a; - a = b; - b = temp; -} - -inline bool cmp_asc(const float& a, const float& b) -{ - return a > b; -} - -inline bool cmp_desc(const float& a, const float& b) -{ - return a < b; -} - -template -void quicksort_with_idx -( - float* arr, - int* idx, - int low, - int high, - int max_index -) -{ - if (low >= high) return; - - // Bubblesort very short segments - - if (high - low == 1) - { - int i0 = low; - int i1 = low + 1; - - if (cmp_func(arr[i0], arr[i1])) { swap(arr[i0], arr[i1]); swap(idx[i0], idx[i1]); } - return; - } - - if (high - low == 2) - { - int i0 = low; - int i1 = low + 1; - int i2 = low + 2; - - if (cmp_func(arr[i0], arr[i1])) { swap(arr[i0], arr[i1]); swap(idx[i0], idx[i1]); } - if (cmp_func(arr[i1], arr[i2])) { swap(arr[i1], arr[i2]); swap(idx[i1], idx[i2]); } - if (cmp_func(arr[i0], arr[i1])) { swap(arr[i0], arr[i1]); swap(idx[i0], idx[i1]); } - return; - } - - if (high - low == 3) - { - int i0 = low; - int i1 = low + 1; - int i2 = low + 2; - int i3 = low + 3; - - if (cmp_func(arr[i0], arr[i1])) { swap(arr[i0], arr[i1]); swap(idx[i0], idx[i1]); } - if (cmp_func(arr[i1], arr[i2])) { swap(arr[i1], arr[i2]); swap(idx[i1], idx[i2]); } - if (cmp_func(arr[i2], arr[i3])) { swap(arr[i2], arr[i3]); swap(idx[i2], idx[i3]); } - if (cmp_func(arr[i0], arr[i1])) { swap(arr[i0], arr[i1]); swap(idx[i0], idx[i1]); } - if (cmp_func(arr[i1], arr[i2])) { swap(arr[i1], arr[i2]); swap(idx[i1], idx[i2]); } - if (cmp_func(arr[i0], arr[i1])) { swap(arr[i0], arr[i1]); swap(idx[i0], idx[i1]); } - return; - } - - if (high - low == 4) - { - int i0 = low; - int i1 = low + 1; - int i2 = low + 2; - int i3 = low + 3; - int i4 = low + 4; - - if (cmp_func(arr[i0], arr[i1])) { swap(arr[i0], arr[i1]); swap(idx[i0], idx[i1]); } - if (cmp_func(arr[i1], arr[i2])) { swap(arr[i1], arr[i2]); swap(idx[i1], idx[i2]); } - if (cmp_func(arr[i2], arr[i3])) { swap(arr[i2], arr[i3]); swap(idx[i2], idx[i3]); } - if (cmp_func(arr[i3], arr[i4])) { swap(arr[i3], arr[i4]); swap(idx[i3], idx[i4]); } - if (cmp_func(arr[i0], arr[i1])) { swap(arr[i0], arr[i1]); swap(idx[i0], idx[i1]); } - if (cmp_func(arr[i1], arr[i2])) { swap(arr[i1], arr[i2]); swap(idx[i1], idx[i2]); } - if (cmp_func(arr[i2], arr[i3])) { swap(arr[i2], arr[i3]); swap(idx[i2], idx[i3]); } - if (cmp_func(arr[i0], arr[i1])) { 
swap(arr[i0], arr[i1]); swap(idx[i0], idx[i1]); } - if (cmp_func(arr[i1], arr[i2])) { swap(arr[i1], arr[i2]); swap(idx[i1], idx[i2]); } - if (cmp_func(arr[i0], arr[i1])) { swap(arr[i0], arr[i1]); swap(idx[i0], idx[i1]); } - return; - } - - float pivot = arr[high]; - int i = low - 1; - for (int j = low; j < high; j++) - { - if (!cmp_func(arr[j], pivot)) - { - i++; - swap(arr[i], arr[j]); - swap(idx[i], idx[j]); - } - } - - swap(arr[i + 1], arr[high]); - swap(idx[i + 1], idx[high]); - int pos = i + 1; - - if (max_index == 0 || low <= max_index) - quicksort_with_idx(arr, idx, low, pos - 1, max_index); - if (max_index == 0 || pos <= max_index) - quicksort_with_idx(arr, idx, pos + 1, high, max_index); -} - -// Discard tiny probabilities, improves performance when temperature is very low - -int pre_sort_descending -( - const int num_candidates, - float* arr, - int* idx -) -{ - const float eps = 1e-8; - int i = 0; - int j = num_candidates - 1; - - while (i <= j) - { - if (arr[j] < eps) { j--; continue; } - if (arr[i] >= eps) { i++; continue; } - swap(arr[i], arr[j]); - swap(idx[i], idx[j]); - i++; - j--; - } - - return i; -} - -int sort_descending -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - int max_index -) -{ - int pre = pre_sort_descending(num_candidates, temp_probs, temp_indices); - quicksort_with_idx(temp_probs, temp_indices, 0, pre - 1, max_index); - -// int m = (max_index == 0 ? num_candidates : max_index); -// for (int i = 0; i < m; i++) printf("%i - %f \n", temp_indices[i], temp_probs[i] * 10000.0); -// for (int i = 0; i < m - 1; i++) if (temp_probs[i] < temp_probs[i + 1] - 2e-8) DBGI(i); - - return pre; -} - -int top_k_cpu -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - int top_k -) -{ - //TIME_START; - - // Use min-heap for lower values of K - - if (top_k <= top_k_heap_threshold) - { - std::priority_queue, std::vector>, std::greater>> min_heap; - - for (int i = 0; i < top_k; ++i) min_heap.push({temp_probs[i], temp_indices[i]}); - - for (int i = top_k; i < num_candidates; i++) - { - if (temp_probs[i] > min_heap.top().first) - { - min_heap.pop(); - min_heap.push({temp_probs[i], temp_indices[i]}); - } - } - - int j = top_k; - for (int i = 0; i < top_k; i++) - { - j--; - temp_probs[j] = min_heap.top().first; - temp_indices[j] = min_heap.top().second; - min_heap.pop(); - } - } - - // For larger values, quicksort is still faster - - else - { - sort_descending(num_candidates, temp_probs, temp_indices, top_k); - } - - //TIME_STOP; - - return top_k; -} - -int top_p_cpu -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - float top_p -) -{ - std::priority_queue, std::vector>, std::greater>> min_heap; - - //TIME_START; - - float min_p = 1e-6; - - float sum = 0.0f; - for (int i = 0; i < num_candidates; i++) - { - if (temp_probs[i] < min_p) continue; - if (sum > top_p && temp_probs[i] < min_heap.top().first) continue; - - min_heap.push({temp_probs[i], temp_indices[i]}); - sum += temp_probs[i]; - - while (sum > top_p && min_heap.size() > 1) - { - sum -= min_heap.top().first; - min_heap.pop(); - } - } - - int j = min_heap.size(); - int k = j; - while (j > 0) - { - j--; - temp_probs[j] = min_heap.top().first; - temp_indices[j] = min_heap.top().second; - min_heap.pop(); - } - - //TIME_STOP; - - return k; -} - -int keep_threshold -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - float threshold -) -{ - int i = 0; - int j = num_candidates - 1; - - while (j >= i) - { - while (temp_probs[i] >= threshold 
&& j >= i) i++; - if (temp_probs[j] >= threshold) - { - swap(temp_probs[i], temp_probs[j]); - swap(temp_indices[i], temp_indices[j]); - i++; - } - j--; - } - return i; -} - -int top_a_cpu -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - float top_a -) -{ - // Find top probability - float top_prob = temp_probs[0]; - for (int i = 1; i < num_candidates; i++) - if (temp_probs[i] > top_prob) top_prob = temp_probs[i]; - - // Calculate the threshold - float threshold = top_a * top_prob * top_prob; - - // Use the keep_threshold function to keep only probabilities above the threshold - int n = keep_threshold(num_candidates, temp_probs, temp_indices, threshold); - - return n; -} - -int min_p_cpu -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - float min_p -) -{ - //TIME_START; - - float top_prob = temp_probs[0]; - for (int i = 1; i < num_candidates; i++) - if (temp_probs[i] > top_prob) top_prob = temp_probs[i]; - - float threshold = top_prob * min_p; - int n = keep_threshold(num_candidates, temp_probs, temp_indices, threshold); - - //TIME_STOP; - - return n; -} - -int tfs_cpu -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - float tfs -) -{ - //TIME_START; - - if (num_candidates < 3) return num_candidates; // Discrete 2nd derivative undefined - - // 2nd derivative of sorted probs - - int nc = sort_descending(num_candidates, temp_probs, temp_indices, num_candidates); - - float* derivative = (float*) malloc(nc * sizeof(float)); - float dsum = 0.0f; - for (int i = 0; i < nc - 2; i++) - { - float d = fabs(- temp_probs[i] + 2 * temp_probs[i + 1] - temp_probs[i + 2]); - dsum += d; - derivative[i] = d; - } - - // Keep probs for cumulative sum of normalized 2nd derivative <= threshold - - float dsum_i = 1.0f / dsum; - int k = 0; - float cumsum = 0.0f; - while (k < nc - 2) - { - cumsum += derivative[k] * dsum_i; - if (cumsum > tfs) break; - k++; - } - - // Center distribution on the cutoff point - - k++; - - //TIME_STOP; - - free(derivative); - return k; -} - -int mirostat_pre_cpu -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - float mirostat_mu, - float mirostat_tau, - float mirostat_eta -) -{ - //TIME_START; - - // If mu not yet initialized, initialize here - - float mu = mirostat_mu; - if (mu == 0.0f) mu = mirostat_tau * 2.0f; - - // Discard tokens with surprise greater than mu - - int nc = sort_descending(num_candidates, temp_probs, temp_indices, num_candidates); - - float target_prob = powf(2, -mu); - int k = 1; - for (; k < nc; k++) - if (temp_probs[k] < target_prob) break; - - //TIME_STOP; - - return k; -} - -float mirostat_post_cpu -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - float mirostat_mu, - float mirostat_tau, - float mirostat_eta -) -{ - // If mu not yet initializer, initialize here - - float mu = mirostat_mu; - if (mu == 0.0f) mu = mirostat_tau * 2.0f; - - // Adjust mu based on probability of final choice - - float observed_surprise = -log2(temp_probs[0]); - mu += mirostat_eta * (mirostat_tau - observed_surprise); - - return mu; -} - -int typical_cpu -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - float typical -) -{ - //TIME_START; - - const float epsilon = 1e-10; - - float* temp = (float*) malloc(num_candidates * sizeof(float)); - int* entropy_dev_order = (int*) malloc(num_candidates * sizeof(int)); - int* temp_indices_2 = (int*) malloc(num_candidates * sizeof(int)); - - float neg_entropy = 0.0f; - for (int i = 0; i < 
num_candidates; i++) - { - float x = temp_probs[i]; - float y = x + logf(x + epsilon); - neg_entropy += x * y; - temp[i] = y; // temp = log_probs - } - - for (int i = 0; i < num_candidates; i++) - { - temp[i] = fabs(temp[i] - neg_entropy); // temp = entropy_dev - entropy_dev_order[i] = i; - } - - quicksort_with_idx(temp, entropy_dev_order, 0, num_candidates - 1, num_candidates); - - memcpy(temp, temp_probs, num_candidates * sizeof(float)); // temp = temp_probs - memcpy(temp_indices_2, temp_indices, num_candidates * sizeof(int)); - - float cumprob = 0.0f; - int num = 0; - - while (true) - { - int j = entropy_dev_order[num]; - float p = temp[j]; - temp_probs[num] = p; - temp_indices[num] = temp_indices_2[j]; - - cumprob += p; - if (cumprob >= typical) break; - num++; - if (num >= num_candidates) break; - } - - free(temp); - free(entropy_dev_order); - free(temp_indices_2); - - //TIME_STOP; - - if (num == 0) num = 1; - return num; -} - -int multinomial_cpu -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - float random -) -{ - int idx = 0; - float accum = temp_probs[idx]; - - while (true) - { - if (accum >= random) break; - if (idx == num_candidates - 1) break; - idx++; - accum += temp_probs[idx]; - } - - temp_probs[0] = temp_probs[idx]; - temp_indices[0] = temp_indices[idx]; - - return 1; -} - - - diff --git a/server/exllamav2_kernels/exllamav2_kernels/cpp/sampling.h b/server/exllamav2_kernels/exllamav2_kernels/cpp/sampling.h deleted file mode 100644 index 5ce1dc153..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cpp/sampling.h +++ /dev/null @@ -1,138 +0,0 @@ -#ifndef _sampling_h -#define _sampling_h - -#include -#include -#include -#include - -void apply_rep_penalty_cpu -( - const int vocab_size, - const uint64_t* sequence, - const float penalty_max, - const int sustain, - const int decay, - const float alpha_frequency, - const float alpha_presence, - const int seq_len, - float* logits -); - -void softmax_cpu -( - const int vocab_size, - const float temperature, - const float* logits, - const bool* logits_filter, - float* output -); - -void normalize_cpu -( - const int num_candidates, - float* probs -); - -int greedy_sample -( - const int num_candidates, - const float* probs, - const bool* logits_filter -); - -int sort_descending -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - int max_index -); - -int top_k_cpu -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - int top_k -); - -int top_p_cpu -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - float top_p -); - -int top_a_cpu -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - float top_a -); - -int min_p_cpu -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - float min_p -); - -int tfs_cpu -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - float tfs -); - -int typical_cpu -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - float typical -); - -int mirostat_pre_cpu -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - float mirostat_mu, - float mirostat_tau, - float mirostat_eta -); - -float mirostat_post_cpu -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - float mirostat_mu, - float mirostat_tau, - float mirostat_eta -); - -int post_softmax_temperature -( - const int num_candidates, - float* temp_probs, - int* temp_indices, - float temperature -); - -int multinomial_cpu -( - const int 
num_candidates, - float* temp_probs, - int* temp_indices, - float random -); - -#endif - - diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/cache.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/cache.cu deleted file mode 100644 index 900a92b53..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/cache.cu +++ /dev/null @@ -1,161 +0,0 @@ -#include "cache.cuh" - -// #if defined(CUDART_VERSION) && CUDART_VERSION >= 11080 -// -// #include - -#include "quant/qdq_util.cuh" -#include "util.cuh" - -#define THREADS 32 - -// The upper 8 bits of FP16 are equivalent to FP8 E5M2. -// -// The range of values typically cached seem to be in the range of +/- 16, with an exponent component (with bias) up to -// about 20. Empirically, the MSE over the whole range of observed values in the K/V cache works out the same for E4M3 -// and E5M2. However, over 80% of values in the cache tensors fall within the range of -1..1, where E5M2 produces about -// a 25% lower MSE. - -__device__ inline uint32_t compress(uint32_t v) -{ - uint32_t vh = (v & 0xff000000) >> 16; - uint32_t vl = (v & 0x0000ff00) >> 8; - return vh | vl; -} - -__device__ inline uint32_t decompress(uint32_t v) -{ - uint32_t vh = (v & 0xff00) << 16; - uint32_t vl = (v & 0x00ff) << 8; - return vh | vl; -} - -__global__ void nv_fp16_to_fp8 -( - const half* __restrict__ pIn, - unsigned char* __restrict__ pOut, - int stride, - int height, - int min, - int max -) -{ - int x = min + (blockIdx.x * blockDim.x + threadIdx.x) * 8; - int y = blockIdx.y; - if (x >= max) return; - - int4* in_ptr = (int4*) (pIn + y * stride + x); - int2* out_ptr = (int2*) (pOut + y * stride + x); - - int4 in = *in_ptr; - uint32_t c0 = compress(in.x); - uint32_t c1 = compress(in.y); - uint32_t c2 = compress(in.z); - uint32_t c3 = compress(in.w); - int2 out = make_int2(c0 | (c1 << 16), c2 | (c3 << 16)); - *out_ptr = out; -} - -__global__ void nv_fp8_to_fp16 -( - const unsigned char* __restrict__ pIn, - half* __restrict__ pOut, - int stride, - int height, - int min, - int max -) -{ - int x = min + (blockIdx.x * blockDim.x + threadIdx.x) * 8; - int y = blockIdx.y; - if (x >= max) return; - - int2* in_ptr = (int2*) (pIn + y * stride + x); - int4* out_ptr = (int4*) (pOut + y * stride + x); - - int2 in = *in_ptr; - uint32_t c0 = decompress(in.x); - uint32_t c1 = decompress(in.x >> 16); - uint32_t c2 = decompress(in.y); - uint32_t c3 = decompress(in.y >> 16); - int4 out = make_int4(c0, c1, c2, c3); - *out_ptr = out; -} - -// __global__ void nv_fp32_to_fp16(const float* pIn, half* pOut, int size) -// { -// int i = blockIdx.x * blockDim.x + threadIdx.x; -// if (i < size) { -// pOut[i] = __float2half(pIn[i]); -// } -// } - -// __global__ void nv_fp16_to_fp8_ref(const half* pIn, unsigned char *pOut, int size) -// { -// int i = blockIdx.x * blockDim.x + threadIdx.x; -// if (i < size) { -// pOut[i] = __nv_cvt_halfraw_to_fp8(pIn[i], __NV_SATFINITE, __NV_E4M3); -// } -// } -// -// __global__ void nv_fp8_to_fp16_ref(const unsigned char* pIn, half* pOut, int size) -// { -// int i = blockIdx.x * blockDim.x + threadIdx.x; -// if (i < size) { -// pOut[i] = __nv_cvt_fp8_to_halfraw(pIn[i], __NV_E4M3); -// } -// } - -void array_fp16_to_fp8_cuda(const half* pIn, unsigned char *pOut, int stride, int height, int offset, int width) -{ - int min = offset; - int max = offset + width; - min = min / 8 * 8; - max = min + (max - min + 7) / 8 * 8; - - dim3 blockDim, gridDim; - blockDim.x = THREADS; - gridDim.x = DIVIDE((max - min) / 8, THREADS); - gridDim.y = height; - - 
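Among the samplers deleted above, top_p_cpu is the least obvious: it keeps a running min-heap of retained candidates and evicts the smallest probabilities whenever the cumulative mass exceeds top_p. The following is a self-contained C++ reimplementation of that idea, for reference only and not the deleted source itself.

// Nucleus (top-p) filter: keep the smallest set of candidates whose probabilities
// sum to at least top_p, maintained with a min-heap so low-probability entries go first.
#include <cstdio>
#include <queue>
#include <vector>
#include <utility>
#include <functional>

std::vector<std::pair<float,int>> top_p_filter(const std::vector<float>& probs, float top_p)
{
    using P = std::pair<float,int>;                  // (probability, token index)
    std::priority_queue<P, std::vector<P>, std::greater<P>> min_heap;

    float sum = 0.0f;
    for (int i = 0; i < (int)probs.size(); ++i)
    {
        if (probs[i] < 1e-6f) continue;              // drop negligible candidates
        if (sum > top_p && probs[i] < min_heap.top().first) continue;
        min_heap.push({probs[i], i});
        sum += probs[i];
        while (sum > top_p && min_heap.size() > 1)   // shrink back to the nucleus
        {
            sum -= min_heap.top().first;
            min_heap.pop();
        }
    }

    std::vector<P> kept;
    while (!min_heap.empty()) { kept.push_back(min_heap.top()); min_heap.pop(); }
    return kept;                                     // ascending by probability
}

int main()
{
    std::vector<float> probs = {0.40f, 0.25f, 0.15f, 0.10f, 0.05f, 0.05f};
    for (auto [p, i] : top_p_filter(probs, 0.8f)) printf("token %d p=%.2f\n", i, p);
}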
nv_fp16_to_fp8<<>>(pIn, pOut, stride, height, min, max); - // cuda_check( cudaPeekAtLastError() ); -} - -void array_fp8_to_fp16_cuda(const unsigned char* pIn, half* pOut, int stride, int height, int offset, int width) -{ - int min = offset; - int max = offset + width; - min = min / 8 * 8; - max = min + (max - min + 7) / 8 * 8; - - dim3 blockDim, gridDim; - blockDim.x = THREADS; - gridDim.x = DIVIDE((max - min) / 8, THREADS); - gridDim.y = height; - - nv_fp8_to_fp16<<>>(pIn, pOut, stride, height, min, max); - // cuda_check( cudaPeekAtLastError() ); -} - -// void array_fp16_to_fp8_ref_cuda(const half* pIn, unsigned char *pOut, int size) -// { -// const int threads = 512; -// int blocks = DIVIDE(size / 1, threads); -// nv_fp16_to_fp8_ref<<>>(pIn, pOut, size); -// } -// -// void array_fp8_to_fp16_ref_cuda(const unsigned char* pIn, half* pOut, int size) -// { -// const int threads = 512; -// int blocks = DIVIDE(size / 1, threads); -// nv_fp8_to_fp16_ref<<>>(pIn, pOut, size); -// } - -// #else -// -// void array_fp16_to_fp8_cuda(const half* pIn, unsigned char *pOut, int size) { } -// -// void array_fp8_to_fp16_cuda(const unsigned char* pIn, half* pOut, int size) { } -// -// #endif \ No newline at end of file diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/cache.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/cache.cuh deleted file mode 100644 index 4cb291b65..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/cache.cuh +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef _cache_cuh -#define _cache_cuh - -#include -#include -#include -#include - -void array_fp16_to_fp8_cuda(const half* pIn, unsigned char *pOut, int stride, int height, int offset, int width); -void array_fp8_to_fp16_cuda(const unsigned char* pIn, half* pOut, int stride, int height, int offset, int width); -// void array_fp16_to_fp8_ref_cuda(const half* pIn, unsigned char *pOut, int size); -// void array_fp8_to_fp16_ref_cuda(const unsigned char* pIn, half* pOut, int size); - -#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/kernel_select.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/kernel_select.cu deleted file mode 100644 index 844274c2b..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/kernel_select.cu +++ /dev/null @@ -1,21 +0,0 @@ -#include "kernel_select.cuh" -#include "../util.cuh" - -fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(const int m_count, bool r_weights, bool mul_r_weights) -{ - fp_gemm_half_q_half_gptq_kernel k; - k = pick_gemm_half_q_half_gptq_kernel_1(m_count, r_weights, mul_r_weights); if (k) return k; - k = pick_gemm_half_q_half_gptq_kernel_2(m_count, r_weights, mul_r_weights); if (k) return k; - k = pick_gemm_half_q_half_gptq_kernel_3(m_count, r_weights, mul_r_weights); return k; -} - -fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int perm, bool r_weights, bool mul_r_weights) -{ - fp_gemm_half_q_half_kernel k; - k = pick_gemm_half_q_half_kernel_1a(perm, r_weights, mul_r_weights); if (k) return k; - k = pick_gemm_half_q_half_kernel_1b(perm, r_weights, mul_r_weights); if (k) return k; - k = pick_gemm_half_q_half_kernel_2a(perm, r_weights, mul_r_weights); if (k) return k; - k = pick_gemm_half_q_half_kernel_2b(perm, r_weights, mul_r_weights); if (k) return k; - k = pick_gemm_half_q_half_kernel_3a(perm, r_weights, mul_r_weights); if (k) return k; - k = pick_gemm_half_q_half_kernel_3b(perm, r_weights, mul_r_weights); return k; -} diff --git 
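The removed cache.cu relies on the fact noted in its header comment: the upper 8 bits of an IEEE FP16 value are exactly an FP8 E5M2 value (1 sign, 5 exponent, 2 mantissa bits), so the K/V cache compresses by byte truncation. A bit-level C++ sketch of the round trip follows, per value rather than the packed uint32 pairs the kernels operate on, and ignoring any rounding.

// FP16 <-> FP8 E5M2 by truncating / zero-padding the low byte. Illustration only.
#include <cstdint>
#include <cstdio>

uint8_t  fp16_to_fp8_e5m2(uint16_t h) { return (uint8_t)(h >> 8); }   // keep sign + exponent + 2 mantissa bits
uint16_t fp8_e5m2_to_fp16(uint8_t f)  { return (uint16_t)((uint16_t)f << 8); }  // low mantissa bits become zero

int main()
{
    uint16_t h = 0x3C66;                   // FP16 value ~1.0996
    uint8_t  c = fp16_to_fp8_e5m2(h);      // 0x3C
    uint16_t r = fp8_e5m2_to_fp16(c);      // 0x3C00 == 1.0, low mantissa truncated
    printf("half 0x%04X -> fp8 0x%02X -> half 0x%04X\n", h, c, r);
}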
a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/kernel_select.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/kernel_select.cuh deleted file mode 100644 index cdae46ace..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/kernel_select.cuh +++ /dev/null @@ -1,81 +0,0 @@ -#ifndef _kernel_select_cuh -#define _kernel_select_cuh - -#include -#include -#include -#include -#include - -#include "../q_gemm_kernel.cuh" -#include "../q_gemm_kernel_gptq.cuh" - -fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(const int m_count, bool r_weights, bool mul_r_weights); -fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel_1(const int m_count, bool r_weights, bool mul_r_weights); -fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel_2(const int m_count, bool r_weights, bool mul_r_weights); -fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel_3(const int m_count, bool r_weights, bool mul_r_weights); - -fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count, bool r_weights, bool mul_r_weights); -fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel_1a(const int m_count, bool r_weights, bool mul_r_weights); -fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel_1b(const int m_count, bool r_weights, bool mul_r_weights); -fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel_2a(const int m_count, bool r_weights, bool mul_r_weights); -fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel_2b(const int m_count, bool r_weights, bool mul_r_weights); -fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel_3a(const int m_count, bool r_weights, bool mul_r_weights); -fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel_3b(const int m_count, bool r_weights, bool mul_r_weights); - -template -struct map_m_count_gptq { - static constexpr fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(int m_count) - { - if (m_count == GPTQ_BLOCK_M_SIZE_MAX) return gemm_half_q_half_gptq_kernel; - printf(" ## No GPTQ kernel found for block size %i\n", m_count); - return NULL; - } -}; - -template -struct map_m_count_exl2_a { - static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int perm) - { - switch (perm) - { - case 0b00000010: return gemm_half_q_half_kernel<0b00000010, use_r_weights, mul_r_weights>; - case 0b00000110: return gemm_half_q_half_kernel<0b00000110, use_r_weights, mul_r_weights>; - case 0b00000100: return gemm_half_q_half_kernel<0b00000100, use_r_weights, mul_r_weights>; - case 0b00001110: return gemm_half_q_half_kernel<0b00001110, use_r_weights, mul_r_weights>; - case 0b00001100: return gemm_half_q_half_kernel<0b00001100, use_r_weights, mul_r_weights>; - case 0b00001000: return gemm_half_q_half_kernel<0b00001000, use_r_weights, mul_r_weights>; - case 0b00011000: return gemm_half_q_half_kernel<0b00011000, use_r_weights, mul_r_weights>; - case 0b00010000: return gemm_half_q_half_kernel<0b00010000, use_r_weights, mul_r_weights>; - case 0b00110000: return gemm_half_q_half_kernel<0b00110000, use_r_weights, mul_r_weights>; - case 0b00100000: return gemm_half_q_half_kernel<0b00100000, use_r_weights, mul_r_weights>; - default: - return NULL; - } - } -}; - -template -struct map_m_count_exl2_b { - static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int perm) - { - switch (perm) - { - case 0b10000000: return gemm_half_q_half_kernel<0b10000000, use_r_weights, mul_r_weights>; - case 0b00100110: return 
gemm_half_q_half_kernel<0b00100110, use_r_weights, mul_r_weights>; - case 0b00010100: return gemm_half_q_half_kernel<0b00010100, use_r_weights, mul_r_weights>; - case 0b10001100: return gemm_half_q_half_kernel<0b10001100, use_r_weights, mul_r_weights>; - case 0b10011000: return gemm_half_q_half_kernel<0b10011000, use_r_weights, mul_r_weights>; - case 0b10110000: return gemm_half_q_half_kernel<0b10110000, use_r_weights, mul_r_weights>; - case 0b10101100: return gemm_half_q_half_kernel<0b10101100, use_r_weights, mul_r_weights>; - case 0b00001010: return gemm_half_q_half_kernel<0b00001010, use_r_weights, mul_r_weights>; - case 0b00101000: return gemm_half_q_half_kernel<0b00101000, use_r_weights, mul_r_weights>; - case 0b10100000: return gemm_half_q_half_kernel<0b10100000, use_r_weights, mul_r_weights>; - default: return gemm_half_q_half_kernel<0b10111110, use_r_weights, mul_r_weights>; - // printf(" ## No kernel found for permutation %x\n", perm); - // return NULL; - } - } -}; - -#endif \ No newline at end of file diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_1a.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_1a.cu deleted file mode 100644 index 9aaf957dd..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_1a.cu +++ /dev/null @@ -1,10 +0,0 @@ -#include "kernel_select.cuh" - -fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel_1a(const int perm, bool r_weights, bool mul_r_weights) -{ - if (!r_weights && !mul_r_weights) return map_m_count_exl2_a::pick_gemm_half_q_half_kernel(perm); - // if (!r_weights && mul_r_weights) return map_m_count_exl2_a::pick_gemm_half_q_half_kernel(perm); - // if ( r_weights && !mul_r_weights) return map_m_count_exl2_a< true, false>::pick_gemm_half_q_half_kernel(perm); - // if ( r_weights && mul_r_weights) return map_m_count_exl2_a< true, true>::pick_gemm_half_q_half_kernel(perm); - return NULL; -} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_1b.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_1b.cu deleted file mode 100644 index f480b99c0..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_1b.cu +++ /dev/null @@ -1,10 +0,0 @@ -#include "kernel_select.cuh" - -fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel_1b(const int perm, bool r_weights, bool mul_r_weights) -{ - if (!r_weights && !mul_r_weights) return map_m_count_exl2_b::pick_gemm_half_q_half_kernel(perm); - // if (!r_weights && mul_r_weights) return map_m_count_exl2_b::pick_gemm_half_q_half_kernel(perm); - // if ( r_weights && !mul_r_weights) return map_m_count_exl2_b< true, false>::pick_gemm_half_q_half_kernel(perm); - // if ( r_weights && mul_r_weights) return map_m_count_exl2_b< true, true>::pick_gemm_half_q_half_kernel(perm); - return NULL; -} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_2a.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_2a.cu deleted file mode 100644 index 1eace7ed0..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_2a.cu +++ /dev/null @@ -1,10 +0,0 @@ -#include "kernel_select.cuh" - -fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel_2a(const int perm, bool r_weights, bool mul_r_weights) -{ - // if (!r_weights && !mul_r_weights) return map_m_count_exl2_a::pick_gemm_half_q_half_kernel(perm); - // if (!r_weights && mul_r_weights) return map_m_count_exl2_a::pick_gemm_half_q_half_kernel(perm); 
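The kernel_select translation units above all follow one pattern: a runtime permutation bitmask is switched over to pick a fully specialized template instantiation, which is returned as a plain function pointer, and each picker returns NULL so the caller can fall through to the next one. A toy C++ sketch of that dispatch pattern, with made-up names and flag values:

// Runtime flag -> compile-time-specialized function pointer, as in kernel_select.cuh.
#include <cstdio>

using kernel_fn = void (*)(int);

template <int perm, bool use_r_weights>
void toy_kernel(int n)
{
    // In the real code this would be a __global__ GEMM kernel specialized on <perm, ...>.
    printf("perm=0x%02X r_weights=%d n=%d\n", perm, (int)use_r_weights, n);
}

template <bool use_r_weights>
struct kernel_map
{
    static kernel_fn pick(int perm)
    {
        switch (perm)
        {
            case 0b0010: return toy_kernel<0b0010, use_r_weights>;
            case 0b0110: return toy_kernel<0b0110, use_r_weights>;
            case 0b1000: return toy_kernel<0b1000, use_r_weights>;
            default:     return nullptr;   // caller falls through to the next picker
        }
    }
};

int main()
{
    kernel_fn k = kernel_map<false>::pick(0b0110);
    if (k) k(128);
}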
- if ( r_weights && !mul_r_weights) return map_m_count_exl2_a< true, false>::pick_gemm_half_q_half_kernel(perm); - // if ( r_weights && mul_r_weights) return map_m_count_exl2_a< true, true>::pick_gemm_half_q_half_kernel(perm); - return NULL; -} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_2b.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_2b.cu deleted file mode 100644 index c63d5a5aa..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_2b.cu +++ /dev/null @@ -1,10 +0,0 @@ -#include "kernel_select.cuh" - -fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel_2b(const int perm, bool r_weights, bool mul_r_weights) -{ - // if (!r_weights && !mul_r_weights) return map_m_count_exl2_b::pick_gemm_half_q_half_kernel(perm); - // if (!r_weights && mul_r_weights) return map_m_count_exl2_b::pick_gemm_half_q_half_kernel(perm); - if ( r_weights && !mul_r_weights) return map_m_count_exl2_b< true, false>::pick_gemm_half_q_half_kernel(perm); - // if ( r_weights && mul_r_weights) return map_m_count_exl2_b< true, true>::pick_gemm_half_q_half_kernel(perm); - return NULL; -} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_3a.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_3a.cu deleted file mode 100644 index 07c0fda5a..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_3a.cu +++ /dev/null @@ -1,10 +0,0 @@ -#include "kernel_select.cuh" - -fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel_3a(const int perm, bool r_weights, bool mul_r_weights) -{ - // if (!r_weights && !mul_r_weights) return map_m_count_exl2_a::pick_gemm_half_q_half_kernel(perm); - // if (!r_weights && mul_r_weights) return map_m_count_exl2_a::pick_gemm_half_q_half_kernel(perm); - // if ( r_weights && !mul_r_weights) return map_m_count_exl2_a< true, false>::pick_gemm_half_q_half_kernel(perm); - if ( r_weights && mul_r_weights) return map_m_count_exl2_a< true, true>::pick_gemm_half_q_half_kernel(perm); - return NULL; -} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_3b.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_3b.cu deleted file mode 100644 index fd0a88d87..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_exl2_3b.cu +++ /dev/null @@ -1,10 +0,0 @@ -#include "kernel_select.cuh" - -fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel_3b(const int perm, bool r_weights, bool mul_r_weights) -{ - // if (!r_weights && !mul_r_weights) return map_m_count_exl2_b::pick_gemm_half_q_half_kernel(perm); - // if (!r_weights && mul_r_weights) return map_m_count_exl2_b::pick_gemm_half_q_half_kernel(perm); - // if ( r_weights && !mul_r_weights) return map_m_count_exl2_b< true, false>::pick_gemm_half_q_half_kernel(perm); - if ( r_weights && mul_r_weights) return map_m_count_exl2_b< true, true>::pick_gemm_half_q_half_kernel(perm); - return NULL; -} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_gptq_1.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_gptq_1.cu deleted file mode 100644 index 002a64112..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_gptq_1.cu +++ /dev/null @@ -1,10 +0,0 @@ -#include "kernel_select.cuh" - -fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel_1(const int m_count, bool r_weights, bool mul_r_weights) -{ - if (!r_weights && !mul_r_weights) return 
map_m_count_gptq::pick_gemm_half_q_half_gptq_kernel(m_count); - // if (!r_weights && mul_r_weights) return map_m_count_gptq::pick_gemm_half_q_half_gptq_kernel(m_count); - // if ( r_weights && !mul_r_weights) return map_m_count_gptq< true, false>::pick_gemm_half_q_half_gptq_kernel(m_count); - // if ( r_weights && mul_r_weights) return map_m_count_gptq< true, true>::pick_gemm_half_q_half_gptq_kernel(m_count); - return NULL; -} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_gptq_2.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_gptq_2.cu deleted file mode 100644 index 9bee74647..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_gptq_2.cu +++ /dev/null @@ -1,10 +0,0 @@ -#include "kernel_select.cuh" - -fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel_2(const int m_count, bool r_weights, bool mul_r_weights) -{ - // if (!r_weights && !mul_r_weights) return map_m_count_gptq::pick_gemm_half_q_half_gptq_kernel(m_count); - // if (!r_weights && mul_r_weights) return map_m_count_gptq::pick_gemm_half_q_half_gptq_kernel(m_count); - if ( r_weights && !mul_r_weights) return map_m_count_gptq< true, false>::pick_gemm_half_q_half_gptq_kernel(m_count); - // if ( r_weights && mul_r_weights) return map_m_count_gptq< true, true>::pick_gemm_half_q_half_gptq_kernel(m_count); - return NULL; -} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_gptq_3.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_gptq_3.cu deleted file mode 100644 index d3bbd3bb3..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/comp_units/unit_gptq_3.cu +++ /dev/null @@ -1,10 +0,0 @@ -#include "kernel_select.cuh" - -fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel_3(const int m_count, bool r_weights, bool mul_r_weights) -{ - // if (!r_weights && !mul_r_weights) return map_m_count_gptq::pick_gemm_half_q_half_gptq_kernel(m_count); - // if (!r_weights && mul_r_weights) return map_m_count_gptq::pick_gemm_half_q_half_gptq_kernel(m_count); - // if ( r_weights && !mul_r_weights) return map_m_count_gptq< true, false>::pick_gemm_half_q_half_gptq_kernel(m_count); - if ( r_weights && mul_r_weights) return map_m_count_gptq< true, true>::pick_gemm_half_q_half_gptq_kernel(m_count); - return NULL; -} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/h_gemm.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/h_gemm.cu deleted file mode 100644 index 6f536c53d..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/h_gemm.cu +++ /dev/null @@ -1,274 +0,0 @@ -#include "h_gemm.cuh" -#include "util.cuh" -#include "../config.h" -#include "matrix_view.cuh" - -// union half2_uint32 -// { -// uint32_t as_uint32; -// half2 as_half2; -// __device__ half2_uint32(uint32_t val) : as_uint32(val) {} -// __device__ half2_uint32(half2 val) : as_half2(val) {} -// }; - -// TODO: Improve tall kernel, maybe special cases for size_n = 1, 2, 4, 8, 16 - -const int T_THREADS_M = 1; -const int T_THREADS_N = 8; -const int T_BLOCKSIZE_K = 32; -const int T_MAX_M = 16; -const int T_MAX_N = 64; -const int T_MAX_K = 1024 / T_THREADS_N * T_BLOCKSIZE_K; -const int T_MAX_BLOCKS_K = T_MAX_K / T_BLOCKSIZE_K; - -__global__ void h_gemm_tall_kernel -( - const int size_m, - const int size_n, - const int size_k, - const half* __restrict__ a, - const half* __restrict__ b, - half* __restrict__ c, - bool clear -) -{ - __shared__ half accum[T_MAX_BLOCKS_K][T_THREADS_N]; - - int m = blockIdx.y * T_THREADS_M + 
threadIdx.z; - int n = blockIdx.x * T_THREADS_N + threadIdx.x; - int k = threadIdx.y * T_BLOCKSIZE_K; - - if (n >= size_n) return; - if (m >= size_m) return; - - MatrixView_half a_(a, size_m, size_k); - MatrixView_half b_(b, size_k, size_n); - MatrixView_half_rw c_(c, size_m, size_n); - - int k_end = min(k + T_BLOCKSIZE_K, size_k); - - const half* a_ptr = a_.item_ptr(m, k); - const half* a_ptr_end = a_.item_ptr(m, k_end); - const half* b_ptr = b_.item_ptr(k, n); - half* c_ptr = c_.item_ptr(m, n); - - half2 r2 = {}; - - while(a_ptr <= a_ptr_end - 8) - { - int4 a_int4 = *((int4*) a_ptr); - half2 a_01 = ((half2_uint32) a_int4.x).as_half2; - half2 a_23 = ((half2_uint32) a_int4.y).as_half2; - half2 a_45 = ((half2_uint32) a_int4.z).as_half2; - half2 a_67 = ((half2_uint32) a_int4.w).as_half2; - a_ptr += 8; - - half b_0 = *b_ptr; b_ptr += size_n; - half b_1 = *b_ptr; b_ptr += size_n; - half b_2 = *b_ptr; b_ptr += size_n; - half b_3 = *b_ptr; b_ptr += size_n; - half b_4 = *b_ptr; b_ptr += size_n; - half b_5 = *b_ptr; b_ptr += size_n; - half b_6 = *b_ptr; b_ptr += size_n; - half b_7 = *b_ptr; b_ptr += size_n; - half2 b_01 = __halves2half2(b_0, b_1); - half2 b_23 = __halves2half2(b_2, b_3); - half2 b_45 = __halves2half2(b_4, b_5); - half2 b_67 = __halves2half2(b_6, b_7); - - r2 = __hfma2(a_01, b_01, r2); - r2 = __hfma2(a_23, b_23, r2); - r2 = __hfma2(a_45, b_45, r2); - r2 = __hfma2(a_67, b_67, r2); - } - - while(a_ptr <= a_ptr_end - 4) - { - int2 a_int2 = *((int2*) a_ptr); - half2 a_01 = ((half2_uint32) a_int2.x).as_half2; - half2 a_23 = ((half2_uint32) a_int2.y).as_half2; - a_ptr += 4; - - half b_0 = *b_ptr; b_ptr += size_n; - half b_1 = *b_ptr; b_ptr += size_n; - half b_2 = *b_ptr; b_ptr += size_n; - half b_3 = *b_ptr; b_ptr += size_n; - half2 b_01 = __halves2half2(b_0, b_1); - half2 b_23 = __halves2half2(b_2, b_3); - - r2 = __hfma2(a_01, b_01, r2); - r2 = __hfma2(a_23, b_23, r2); - } - - half r = __hadd(__low2half(r2), __high2half(r2)); - - while(a_ptr < a_ptr_end) - { - half a_item = *a_ptr++; - half b_item = *b_ptr; b_ptr += size_n; - r = __hfma(a_item, b_item, r); - } - - accum[threadIdx.y][threadIdx.x] = r; - __syncthreads(); - - if (threadIdx.y == 0) - { - half acc = accum[0][threadIdx.x]; - for (int i = 1; i < blockDim.y; ++i) acc = __hadd(accum[i][threadIdx.x], acc); - if (!clear) acc = __hadd(acc, *c_ptr); - *c_ptr = acc; - } -} - - -const int W_MAX_M = 16; -const int W_MAX_N = 65536; -const int W_MAX_K = 32; -const int W_THREADS_M = 1; -const int W_THREADS_N = 32; - -__global__ void h_gemm_wide_kernel -( - const int size_m, - const int size_n, - const int size_k, - const half* __restrict__ a, - const half* __restrict__ b, - half* __restrict__ c, - bool clear -) -{ - int m = blockIdx.y * W_THREADS_M + threadIdx.y; - int n = blockIdx.x * W_THREADS_N + threadIdx.x; - - if (m >= size_m) return; - - MatrixView_half a_(a, size_m, size_k); - MatrixView_half b_(b, size_k, size_n); - MatrixView_half_rw c_(c, size_m, size_n); - - half* c_ptr = c_.item_ptr(m, n); - - __shared__ half read_a[W_MAX_K]; - int t = threadIdx.x; - - if (t < size_k) - { - read_a[t] = a_.item(m, t); - } - __syncthreads(); - - if (n >= size_n) return; - - half r = {}; - - for (int k = 0; k < size_k; ++k) - { - half item_a = read_a[k]; - half item_b = b_.item(k, n); - r = __hfma(item_a, item_b, r); - } - - if (threadIdx.y == 0) - { - if (!clear) r = __hadd(r, *c_ptr); - *c_ptr = r; - } -} - - -// cuBLAS - -void h_gemm_cublas -( - cublasHandle_t cublas_handle, - const int size_m, - const int size_n, - const int size_k, - const 
half* a, - const half* b, - half* c, - const float alpha, - const float beta -) -{ - half alpha_ = __float2half(alpha); - half beta_ = __float2half(beta); - cublasHgemm(cublas_handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - size_n, size_m, size_k, - &alpha_, b, size_n, - a, size_k, - &beta_, c, size_n); -} - - -// alpha * ( a[m,k] @ b[k,n] ) + beta * c[m,n] -> c[m,n] - -void h_gemm_cuda -( - cublasHandle_t cublas_handle, - const int size_m, - const int size_n, - const int size_k, - const half* a, - const half* b, - half* c, - const float alpha, - const float beta -) -{ - if ((beta == 1.0f || beta == 0.0f) && (alpha == 1.0f)) - { - bool clear = (beta == 0.0f); - - //DBGI3(size_m, size_n, size_k); - - if (size_m <= T_MAX_M && size_n <= T_MAX_N && size_k <= T_MAX_K) - { - // Tall - - dim3 blockDim, gridDim; - blockDim.x = T_THREADS_N; - blockDim.y = DIVIDE(size_k, T_BLOCKSIZE_K); - blockDim.z = T_THREADS_M; - gridDim.x = DIVIDE(size_n, T_THREADS_N); - gridDim.y = DIVIDE(size_m, T_THREADS_M); - gridDim.z = 1; - -// DBGI3(blockDim.x, blockDim.y, blockDim.z); -// DBGI3(gridDim.x, gridDim.y, gridDim.z); - - h_gemm_tall_kernel<<>>(size_m, size_n, size_k, a, b, c, clear); - cuda_check( cudaPeekAtLastError() ); - return; - } - - if (size_m <= W_MAX_M && size_n <= W_MAX_N && size_k <= W_MAX_K) - { - // Wide - - dim3 blockDim, gridDim; - blockDim.x = W_THREADS_N; - blockDim.y = W_THREADS_M; - blockDim.z = 1; - gridDim.x = DIVIDE(size_n, W_THREADS_N); - gridDim.y = DIVIDE(size_m, W_THREADS_M); - gridDim.z = 1; - -// DBGI3(blockDim.x, blockDim.y, blockDim.z); -// DBGI3(gridDim.x, gridDim.y, gridDim.z); - - h_gemm_wide_kernel<<>>(size_m, size_n, size_k, a, b, c, clear); - cuda_check( cudaPeekAtLastError() ); - return; - } - } - - h_gemm_cublas(cublas_handle, size_m, size_n, size_k, a, b, c, alpha, beta); -// DBGI3(size_m, size_n, size_k); - cuda_check( cudaPeekAtLastError() ); - -} \ No newline at end of file diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/h_gemm.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/h_gemm.cuh deleted file mode 100644 index 4f9bab454..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/h_gemm.cuh +++ /dev/null @@ -1,39 +0,0 @@ -#ifndef _h_gemm_cuh -#define _h_gemm_cuh - -#include -#include -#include -#include -#include - -// alpha * ( a[m,k] @ b[k,n] ) + beta * c[m,n] -> c[m,n] - -void h_gemm_cuda -( - cublasHandle_t cublas_handle, - const int size_m, - const int size_n, - const int size_k, - const half* a, - const half* b, - half* c, - const float alpha, - const float beta -); - -void h_gemm_cublas -( - cublasHandle_t cublas_handle, - const int size_m, - const int size_n, - const int size_k, - const half* a, - const half* b, - half* c, - const float alpha, - const float beta -); - -#endif - diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/layer_norm.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/layer_norm.cu deleted file mode 100644 index d81e98b2e..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/layer_norm.cu +++ /dev/null @@ -1,171 +0,0 @@ -#include "layer_norm.cuh" -#include "util.cuh" - -#if defined(USE_ROCM) -#define __shfl_xor_sync(mask, var, laneMask) __shfl_xor(var, laneMask) -#define NUM_WARPS (1024 / warpSize) -#define WARP_SIZE (warpSize) -#else -#define NUM_WARPS 32 -#define WARP_SIZE 32 -#endif - -// y = x * w / sqrt(row_mean(x * x) + epsilon) - -#define BLOCK_SIZE WARP_SIZE -#define NUM_THREADS (NUM_WARPS * WARP_SIZE) - -typedef void (*fp_layer_norm_kernel) -( - const half*, - const half*, - const half*, 
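For reference, the contract stated in h_gemm.cuh above is c[m,n] = alpha * (a[m,k] @ b[k,n]) + beta * c[m,n] with row-major storage; the cuBLAS path reaches the same result by passing the operands swapped, since cublasHgemm works in column-major order. A naive float reference implementation of that contract, handy for checking the fast paths:

// Reference row-major GEMM: c = alpha * (a @ b) + beta * c. Plain triple loop for clarity.
#include <cstdio>
#include <vector>

void h_gemm_ref(int size_m, int size_n, int size_k,
                const float* a, const float* b, float* c,
                float alpha, float beta)
{
    for (int m = 0; m < size_m; ++m)
        for (int n = 0; n < size_n; ++n)
        {
            float acc = 0.0f;
            for (int k = 0; k < size_k; ++k)
                acc += a[m * size_k + k] * b[k * size_n + n];
            c[m * size_n + n] = alpha * acc + beta * c[m * size_n + n];
        }
}

int main()
{
    std::vector<float> a = {1, 2, 3, 4};   // 2x2
    std::vector<float> b = {5, 6, 7, 8};   // 2x2
    std::vector<float> c(4, 1.0f);         // 2x2, pre-filled so beta has an effect
    h_gemm_ref(2, 2, 2, a.data(), b.data(), c.data(), 1.0f, 1.0f);
    printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]);   // 20 23 44 51
}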
- half*, - const float, - const float, - const int, - const int -); - -template -__global__ void layer_norm_kernel -( - const half* __restrict__ x, - const half* __restrict__ w, - const half* __restrict__ b, - half* __restrict__ y, - const float epsilon, - const float r_dim, - const int rows, - const int dim -) -{ - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; - int row = blockIdx.x; - const half* x_row = x + row * dim; - half* y_row = y + row * dim; - - float itemf[blocks_per_warp]; - - // Compute sum for each block - - float sum = 0.0f; - - #pragma unroll - for (int i = 0; i < blocks_per_warp; i++) - { - int column = warp_id * WARP_SIZE + lane_id + NUM_THREADS * i; - if (column >= dim) break; - - float f = __half2float(x_row[column]); - f = fmaxf(-65504.0f, fminf(f, 65504.0f)); - itemf[i] = f; - sum += f; - } - - // Shuffle to sum across lanes - - __shared__ float sums[NUM_WARPS]; - - for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset); - if (lane_id == 0) sums[warp_id] = sum; - __syncthreads(); - - // Load partial sums from across warps, shuffle again across lanes - - sum = sums[lane_id]; - for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset); - - // Compute mean - - float mean = sum * r_dim; - - // Compute square of distance to mean - - sum = 0.0f; - - #pragma unroll - for (int i = 0; i < blocks_per_warp; i++) - { - int column = warp_id * WARP_SIZE + lane_id + NUM_THREADS * i; - if (column >= dim) break; - - float f = itemf[i]; - f = fmaxf(-65504.0f, fminf(f, 65504.0f)); - f -= mean; - itemf[i] = f; - sum = fma(f, f, sum); - } - - // Shuffle to sum across lanes - - for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset); - if (lane_id == 0) sums[warp_id] = sum; - __syncthreads(); - - // Load partial sums from across warps, shuffle again across lanes - - sum = sums[lane_id]; - for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset); - - // Get 1/sqrt(variance) - - float rsvar = rsqrtf(sum * r_dim + epsilon); - - // Normalize x, scaling by w - - #pragma unroll 4 - for (int i = 0; i < blocks_per_warp; i++) - { - int column = warp_id * WARP_SIZE + lane_id + NUM_THREADS * i; - if (column >= dim) return; - - float x_itemf = itemf[i]; - float w_itemf = __half2float(w[column]); - float n = x_itemf * w_itemf * rsvar; - - half nh = __float2half_rn(n); - if (b) nh = __hadd(nh, b[column]); // Optional bias - - y_row[column] = nh; - } -} - -fp_layer_norm_kernel pick_layer_norm_kernel(const int blocks_per_warp) -{ - if (blocks_per_warp == 1) return layer_norm_kernel<1>; - if (blocks_per_warp == 2) return layer_norm_kernel<2>; - if (blocks_per_warp == 3) return layer_norm_kernel<3>; - if (blocks_per_warp == 4) return layer_norm_kernel<4>; - if (blocks_per_warp == 5) return layer_norm_kernel<5>; - if (blocks_per_warp == 6) return layer_norm_kernel<6>; - if (blocks_per_warp == 7) return layer_norm_kernel<7>; - if (blocks_per_warp == 8) return layer_norm_kernel<8>; - return NULL; -} - - -void layer_norm_cuda -( - const half* x, - const half* w, - const half* b, - half* y, - const float epsilon, - const int rows, - const int dim -) -{ - dim3 blockDim, gridDim; - blockDim.x = NUM_THREADS; - blockDim.y = 1; - gridDim.x = rows; - gridDim.y = 1; - - float r_dim = 1.0f / (float) dim; - - int blocks_per_warp = DIVIDE(dim, NUM_THREADS); - fp_layer_norm_kernel kernel = 
pick_layer_norm_kernel(blocks_per_warp); - kernel<<>>(x, w, b, y, epsilon, r_dim, rows, dim); -} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/layer_norm.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/layer_norm.cuh deleted file mode 100644 index 2b10d4d65..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/layer_norm.cuh +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef _layer_norm_cuh -#define _layer_norm_cuh - -#include -#include -#include -#include - -void layer_norm_cuda -( - const half* x, - const half* w, - const half* b, - half* y, - const float epsilon, - const int rows, - const int dim -); - -#endif \ No newline at end of file diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/lora.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/lora.cu deleted file mode 100644 index 49a19ec0f..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/lora.cu +++ /dev/null @@ -1,33 +0,0 @@ -#include "lora.cuh" -#include "util.cuh" -#include "h_gemm.cuh" - -void apply_loras_cuda -( - cublasHandle_t cublas_handle, - const std::unordered_map>& adapters, - const std::vector& ids, - QMatrix* base, - const half* input, - half* output, - half* temp, - int rows -) -{ - for (uintptr_t lora_id : ids) - { - auto it = adapters.find(lora_id); - if (it == adapters.end()) continue; - - const std::tuple& lora = it->second; - half* lora_a = std::get<0>(lora); - half* lora_b = std::get<1>(lora); - int rank = std::get<2>(lora); - -// DBGI3(rows, rank, base->height); -// DBGI3(rows, base->width, rank); - - h_gemm_cuda(cublas_handle, rows, rank, base->height, input, lora_a, temp, 1.0f, 0.0f); - h_gemm_cuda(cublas_handle, rows, base->width, rank, temp, lora_b, output, 1.0f, 1.0f); - } -} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/lora.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/lora.cuh deleted file mode 100644 index b505a201c..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/lora.cuh +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef _lora_cuh -#define _lora_cuh - -#include -#include -#include -#include -#include - -#include "q_matrix.cuh" - -void apply_loras_cuda -( - cublasHandle_t cublas_handle, - const std::unordered_map>& adapters, - const std::vector& ids, - QMatrix* base, - const half* input, - half* output, - half* temp, - int rows -); - -#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/matrix_view.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/matrix_view.cuh index 55af84f23..a72bc7bc1 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/matrix_view.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/matrix_view.cuh @@ -118,4 +118,4 @@ public: } }; -#endif \ No newline at end of file +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/pack_tensor.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/pack_tensor.cu deleted file mode 100644 index 3a29f2fcd..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/pack_tensor.cu +++ /dev/null @@ -1,268 +0,0 @@ -#include "pack_tensor.cuh" -#include "util.cuh" - -#define BLOCKSIZE_X 32 -#define BLOCKSIZE_Y 16 - -// Pack rows: -// 0000 0000 0000 aaaa 0000 0000 0000 bbbb 0000 0000 0000 cccc ... 
-> hhhh gggg ffff eeee dddd cccc bbbb aaaa - -__global__ void pack_rows_4_kernel -( - const uint16_t* __restrict__ input, - uint32_t* __restrict__ output, - int rows, - int out_columns -) -{ - int out_column = blockIdx.x * blockDim.x + threadIdx.x; - int row = blockIdx.y * blockDim.y + threadIdx.y; - - if (row >= rows) return; - if (out_column >= out_columns) return; - - uint32_t packed = 0; - - #pragma unroll - for (int i = 0; i < 8; i++) - { - uint16_t x = input[row * out_columns * 8 + out_column * 8 + i]; - x -= 1; - packed |= (((uint32_t)x) << (i * 4)); - } - - output[row * out_columns + out_column] = packed; -} - -void pack_rows_4_cuda -( - const uint16_t* input, - uint32_t* output, - const int rows, - const int columns -) -{ - int out_columns = columns * 4 / 32; - - dim3 threads(BLOCKSIZE_X, BLOCKSIZE_Y); - dim3 blocks(DIVIDE(out_columns, BLOCKSIZE_X), DIVIDE(rows, BLOCKSIZE_Y)); - - pack_rows_4_kernel<<>>(input, output, rows, out_columns); -} - -// Pack rows: -// 0000 0000 0000 aaaa 0000 0000 0000 bbbb 0000 0000 0000 cccc ... -> hhhh gggg ffff eeee dddd cccc bbbb aaaa - -__global__ void pack_rows_6_kernel -( - const uint16_t* __restrict__ input, - uint32_t* __restrict__ output, - int rows, - int out_columns -) -{ - int out_column = blockIdx.x * blockDim.x + threadIdx.x; - int row = blockIdx.y * blockDim.y + threadIdx.y; - - if (row >= rows) return; - if (out_column >= out_columns) return; - - uint32_t packed = 0; - - #pragma unroll - for (int i = 0; i < 8; i++) - { - uint16_t x = input[row * out_columns * 8 + out_column * 8 + i]; - x -= 1; - packed |= (((uint32_t)x) << (i * 4)); - } - - output[row * out_columns + out_column] = packed; -} - -void pack_rows_6_cuda -( - const uint16_t* input, - uint32_t* output, - const int rows, - const int columns -) -{ - int out_columns = columns * 6 / 32; - - dim3 threads(BLOCKSIZE_X, BLOCKSIZE_Y); - dim3 blocks(DIVIDE(out_columns, BLOCKSIZE_X), DIVIDE(rows, BLOCKSIZE_Y)); - - pack_rows_6_kernel<<>>(input, output, rows, out_columns); -} - -// Pack columns - -__forceinline__ __device__ uint32_t wshift(uint32_t x, int j) -{ - if (j < 0) - { - if (j <= -32) return 0; // Else undefined in CUDA - return x >> (-j); - } - else - { - if (j >= 32) return 0; // Else undefined in CUDA - return x << j; - } -} - -template -__global__ void pack_columns_kernel -( - const uint16_t* __restrict__ input, - uint32_t* __restrict__ output, - int out_rows, - int columns -) -{ - int column = blockIdx.x * blockDim.x + threadIdx.x; - int out_row = blockIdx.y * blockDim.y + threadIdx.y; - - if (column >= columns) return; - if (out_row >= out_rows) return; - - uint32_t x; - - if constexpr (bits == 2) - { - int row = out_row * 32 / 2; - uint32_t packed = 0; - - # pragma unroll - for (int i = 0, j = 0; i < 16; i++, j += 2) - { - x = (uint32_t) input[(row + i) * columns + column]; - packed |= (x << (i * 2)); - } - output[out_row * columns + column] = packed; - } - - if constexpr (bits == 3) - { - if (out_row % 3) return; // Only run for every third row - int row = out_row * 32 / 3; - uint32_t packed0 = 0; - uint32_t packed1 = 0; - uint32_t packed2 = 0; - - #pragma unroll - for (int i = 0, j = 0; i < 32; i++, j += 3) - { - x = (uint32_t) input[(row + i) * columns + column]; - packed0 |= wshift(x, j); - packed1 |= wshift(x, j - 32); - packed2 |= wshift(x, j - 64); - } - - output[(out_row + 0) * columns + column] = packed0; - output[(out_row + 1) * columns + column] = packed1; - output[(out_row + 2) * columns + column] = packed2; - } - - if constexpr (bits == 4) - { - int row = 
out_row * 32 / 4; - uint32_t packed = 0; - - #pragma unroll - for (int i = 0, j = 0; i < 8; i++, j += 4) - { - x = (uint32_t) input[(row + i) * columns + column]; - packed |= (x << j); - } - output[out_row * columns + column] = packed; - } - - if constexpr (bits == 5) - { - if (out_row % 5) return; // Only run for every fifth row - int row = out_row * 32 / 5; - uint32_t packed0 = 0; - uint32_t packed1 = 0; - uint32_t packed2 = 0; - uint32_t packed3 = 0; - uint32_t packed4 = 0; - - #pragma unroll - for (int i = 0, j = 0; i < 32; i++, j += 5) - { - x = (uint32_t) input[(row + i) * columns + column]; - packed0 |= wshift(x, j); - packed1 |= wshift(x, j - 32); - packed2 |= wshift(x, j - 64); - packed3 |= wshift(x, j - 96); - packed4 |= wshift(x, j - 128); - } - - output[(out_row + 0) * columns + column] = packed0; - output[(out_row + 1) * columns + column] = packed1; - output[(out_row + 2) * columns + column] = packed2; - output[(out_row + 3) * columns + column] = packed3; - output[(out_row + 4) * columns + column] = packed4; - } - - if constexpr (bits == 6) - { - if (out_row % 3) return; // Only run for every third row - int row = out_row * 32 / 6; - uint32_t packed0 = 0; - uint32_t packed1 = 0; - uint32_t packed2 = 0; - - #pragma unroll - for (int i = 0, j = 0; i < 16; i++, j += 6) - { - x = (uint32_t) input[(row + i) * columns + column]; - packed0 |= wshift(x, j); - packed1 |= wshift(x, j - 32); - packed2 |= wshift(x, j - 64); - } - - output[(out_row + 0) * columns + column] = packed0; - output[(out_row + 1) * columns + column] = packed1; - output[(out_row + 2) * columns + column] = packed2; - } - - if constexpr (bits == 8) - { - int row = out_row * 32 / 8; - uint32_t packed = 0; - - #pragma unroll - for (int i = 0, j = 0; i < 4; i++, j += 8) - { - x = (uint32_t) input[(row + i) * columns + column]; - packed |= (x << j); - } - - output[out_row * columns + column] = packed; - } -} - -void pack_columns_cuda -( - const uint16_t* input, - uint32_t* output, - const int in_rows, - const int out_rows, - const int columns, - const int bits -) -{ - dim3 threads(BLOCKSIZE_X, BLOCKSIZE_Y); - dim3 blocks(DIVIDE(columns, BLOCKSIZE_X), DIVIDE(out_rows, BLOCKSIZE_Y)); - - if (bits == 2) pack_columns_kernel<2><<>>(input, output, out_rows, columns); - if (bits == 3) pack_columns_kernel<3><<>>(input, output, out_rows, columns); - if (bits == 4) pack_columns_kernel<4><<>>(input, output, out_rows, columns); - if (bits == 5) pack_columns_kernel<5><<>>(input, output, out_rows, columns); - if (bits == 6) pack_columns_kernel<6><<>>(input, output, out_rows, columns); - if (bits == 8) pack_columns_kernel<8><<>>(input, output, out_rows, columns); -} - diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/pack_tensor.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/pack_tensor.cuh deleted file mode 100644 index 4231cbb9c..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/pack_tensor.cuh +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef _pack_tensor_cuh -#define _pack_tensor_cuh - -#include -#include -#include -#include - -void pack_rows_4_cuda -( - const uint16_t* input, - uint32_t* output, - const int rows, - const int columns -); - -void pack_rows_6_cuda -( - const uint16_t* input, - uint32_t* output, - const int rows, - const int columns -); - -void pack_columns_cuda -( - const uint16_t* input, - uint32_t* output, - const int in_rows, - const int out_rows, - const int columns, - const int bits -); - -#endif \ No newline at end of file diff --git 
a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_attn.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_attn.cu deleted file mode 100644 index 158fffd41..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_attn.cu +++ /dev/null @@ -1,167 +0,0 @@ -#include "q_attn.cuh" -#include "q_gemm.cuh" -#include "rms_norm.cuh" -#include "layer_norm.cuh" -#include "rope.cuh" -#include "util.cuh" -#include "lora.cuh" - -const int THREADS_X = 32; -const int THREADS_Y = 1; -const int THREADS_Z = 4; -const int BLOCKSIZE_X = 2; // 2*half == 1*uint32_t -const int BLOCKSIZE_Z = 4; // num_heads must be divisible by BLOCKSIZE_Z - -__global__ void update_cache_kernel -( - const half* __restrict__ key_states, - const half* __restrict__ value_states, - half* __restrict__ key_cache, - half* __restrict__ value_cache, - const int head_dim, - const int num_kv_heads, - const int q_len, - const int cache_seq_len, - const int past_len -) -{ - //int state_shape[] = { num_kv_heads, q_len, head_dim }; - int state_stride[] = { head_dim, head_dim * num_kv_heads, 1 }; - int state_pos[] = { 0, 0, 0 }; - - //int cache_shape[] = { num_kv_heads, cache_seq_len, head_dim }; - int cache_stride[] = { cache_seq_len * head_dim, head_dim, 1 }; - int cache_pos[] = { 0, past_len, 0 }; - - int size[] = { num_kv_heads, q_len, head_dim }; - - int x = (blockIdx.x * THREADS_X + threadIdx.x) * BLOCKSIZE_X; - int y = blockIdx.y * THREADS_Y + threadIdx.y; - int z = (blockIdx.z * THREADS_Z + threadIdx.z) * BLOCKSIZE_Z; - - if (x >= size[2]) return; - if (y >= size[1]) return; - if (z >= size[0]) return; - - int state_offset = (z + state_pos[0]) * state_stride[0] + (y + state_pos[1]) * state_stride[1] + (x + state_pos[2]) * state_stride[2]; - int cache_offset = (z + cache_pos[0]) * cache_stride[0] + (y + cache_pos[1]) * cache_stride[1] + (x + cache_pos[2]) * cache_stride[2]; - - const uint32_t* key_ptr = (uint32_t*) (key_states + state_offset); - const uint32_t* value_ptr = (uint32_t*) (value_states + state_offset); - uint32_t* key_cache_ptr = (uint32_t*) (key_cache + cache_offset); - uint32_t* value_cache_ptr = (uint32_t*) (value_cache + cache_offset); - - #pragma unroll - for (int k = 0; k < BLOCKSIZE_Z; k++) - { - *key_cache_ptr = *key_ptr; - key_ptr += state_stride[0] / BLOCKSIZE_X; - key_cache_ptr += cache_stride[0] / BLOCKSIZE_X; - } - - #pragma unroll - for (int k = 0; k < BLOCKSIZE_Z; k++) - { - *value_cache_ptr = *value_ptr; - value_ptr += state_stride[0] / BLOCKSIZE_X; - value_cache_ptr += cache_stride[0] / BLOCKSIZE_X; - } -} - -QAttn::QAttn -( - half* _layernorm, - half* _layernorm_bias, - bool _layernorm_is_rms, - float _norm_epsilon, - QMatrix* _q_proj, - QMatrix* _k_proj, - QMatrix* _v_proj, - QMatrix* _o_proj, - half* _temp_state, -// half* _temp_q, -// half* _temp_k, -// half* _temp_v, - half* _temp_dq, - int _max_rows, - int _hidden_size, - int _num_heads, - int _num_kv_heads, - int _head_dim, - int _max_seq_len -): - layernorm(_layernorm), - layernorm_bias(_layernorm_bias), - layernorm_is_rms(_layernorm_is_rms), - norm_epsilon(_norm_epsilon), - q_proj(_q_proj), - k_proj(_k_proj), - v_proj(_v_proj), - o_proj(_o_proj), - temp_state(_temp_state), -// temp_q(_temp_q), -// temp_k(_temp_k), -// temp_v(_temp_v), - temp_dq(_temp_dq), - max_rows(_max_rows), - hidden_size(_hidden_size), - num_heads(_num_heads), - num_kv_heads(_num_kv_heads), - head_dim(_head_dim), - max_seq_len(_max_seq_len) -{ -} - -QAttn::~QAttn() -{ -} - -void QAttn::forward_cuda_1 -( - cublasHandle_t cublas_handle, - half* x, - int batch_size, - 
int q_len, - int past_len, - const int32_t* past_lens, - half* temp_q, - half* temp_k, - half* temp_v, - const half* sin, - const half* cos, - const std::vector& loras, - half* lora_temp -) -{ - if (layernorm_is_rms) - rms_norm_cuda(x, layernorm, temp_state, norm_epsilon, q_len * batch_size, hidden_size); - else - layer_norm_cuda(x, layernorm, layernorm_bias, temp_state, norm_epsilon, q_len * batch_size, hidden_size); - - gemm_half_q_half_cuda(cublas_handle, temp_state, q_proj, temp_q, q_len * batch_size, q_proj->width, hidden_size, true, temp_dq); - gemm_half_q_half_cuda(cublas_handle, temp_state, k_proj, temp_k, q_len * batch_size, k_proj->width, hidden_size, true, temp_dq); - gemm_half_q_half_cuda(cublas_handle, temp_state, v_proj, temp_v, q_len * batch_size, v_proj->width, hidden_size, true, temp_dq); - - apply_loras_cuda(cublas_handle, q_proj_lora, loras, q_proj, temp_state, temp_q, lora_temp, q_len * batch_size); - apply_loras_cuda(cublas_handle, k_proj_lora, loras, k_proj, temp_state, temp_k, lora_temp, q_len * batch_size); - apply_loras_cuda(cublas_handle, v_proj_lora, loras, v_proj, temp_state, temp_v, lora_temp, q_len * batch_size); - - rope_cuda(temp_q, sin, cos, batch_size, q_len * num_heads, head_dim, num_heads, past_len, past_lens); - rope_cuda(temp_k, sin, cos, batch_size, q_len * num_kv_heads, head_dim, num_kv_heads, past_len, past_lens); -} - -void QAttn::forward_cuda_2 -( - cublasHandle_t cublas_handle, - const half* attn_output, - half* hidden_state, - int q_len, - int batch_size, - const std::vector& loras, - half* lora_temp -) -{ - gemm_half_q_half_cuda(cublas_handle, attn_output, o_proj, hidden_state, q_len * batch_size, o_proj->width, hidden_size, false, temp_dq); - - apply_loras_cuda(cublas_handle, o_proj_lora, loras, o_proj, attn_output, hidden_state, lora_temp, q_len * batch_size); -} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_attn.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_attn.cuh deleted file mode 100644 index 0fa3b321b..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_attn.cuh +++ /dev/null @@ -1,102 +0,0 @@ -#ifndef _q_attn_cuh -#define _q_attn_cuh - -#include -#include -#include -#include -#include - -#include "q_matrix.cuh" - -class QAttn -{ -public: - - half* layernorm; - half* layernorm_bias; - bool layernorm_is_rms; - float norm_epsilon; - - QMatrix* q_proj; - QMatrix* k_proj; - QMatrix* v_proj; - QMatrix* o_proj; - - half* temp_state; -// half* temp_q; -// half* temp_k; -// half* temp_v; - half* temp_dq; - - int device; - int max_rows; - int hidden_size; - int num_heads; - int num_kv_heads; - int head_dim; - int max_seq_len; - - std::unordered_map> q_proj_lora; - std::unordered_map> k_proj_lora; - std::unordered_map> v_proj_lora; - std::unordered_map> o_proj_lora; - - QAttn - ( - half* _layernorm, - half* _layermorm_bias, - bool _layernorm_is_rms, - float _norm_epsilon, - QMatrix* _q_proj, - QMatrix* _k_proj, - QMatrix* _v_proj, - QMatrix* _o_proj, - half* _temp_state, -// half* _temp_q, -// half* _temp_k, -// half* _temp_v, - half* _temp_dq, - int _max_rows, - int _hidden_size, - int _num_heads, - int _num_kv_heads, - int _head_dim, - int _max_seq_len - ); - - ~QAttn(); - - void forward_cuda_1 - ( - cublasHandle_t cublas_handle, - half* x, - int batch_size, - int q_len, - int past_len, - const int32_t* past_lens, - half* temp_q, - half* temp_k, - half* temp_v, - const half* sin, - const half* cos, - const std::vector& loras, - half* lora_temp - ); - - void forward_cuda_2 - ( - cublasHandle_t 
cublas_handle, - const half* attn_output, - half* hidden_state, - int q_len, - int batch_size, - const std::vector& loras, - half* lora_temp - ); - -private: - -}; - -#endif \ No newline at end of file diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu index b4e4cf22b..5b99f1ba8 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu @@ -38,6 +38,7 @@ void gemm_half_q_half_cuda_part bool mul_r_weights ) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (!b->is_gptq) { dim3 blockDim, gridDim; @@ -50,7 +51,7 @@ void gemm_half_q_half_cuda_part fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights); - kernel<<>> + kernel<<>> ( a, b->cuda_q_weight, @@ -91,7 +92,7 @@ void gemm_half_q_half_cuda_part // print_global_mem(r_weights, 1, 1, 1); // DBGI(r_weights_stride); - kernel<<>> + kernel<<>> ( a, b->cuda_q_weight, diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh index b643f915a..e49457f3a 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh @@ -33,4 +33,4 @@ void clear_tensor_cuda int size_n ); -#endif \ No newline at end of file +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh index 74b0db2b0..f816fd9dd 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh @@ -152,10 +152,10 @@ __global__ void gemm_half_q_half_gptq_kernel half2 y1y16[4][2]; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); // __syncthreads(); @@ -174,10 +174,10 @@ __global__ void gemm_half_q_half_gptq_kernel nextgroup += groupsize; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); } #pragma unroll diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu index ae08cc1fe..f7a91e29e 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu @@ -168,8 +168,9 @@ 
QMatrix::QMatrix blockDim.y = 1; gridDim.x = DIVIDE(width, THREADS_X); gridDim.y = 1; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - shuffle_kernel<<>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2); + shuffle_kernel<<>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2); } QMatrix::~QMatrix() @@ -237,10 +238,10 @@ __global__ void reconstruct_gptq_kernel half2 y1y16[4][2]; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); __syncthreads(); @@ -255,10 +256,10 @@ __global__ void reconstruct_gptq_kernel nextgroup += groupsize; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); } for (int p = 0; p < 4; p++) @@ -475,11 +476,12 @@ void QMatrix::reconstruct(half* out) blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (!is_gptq) { gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); - reconstruct_kernel<<>> + reconstruct_kernel<<>> ( cuda_q_weight, cuda_q_perm, @@ -502,7 +504,7 @@ void QMatrix::reconstruct(half* out) else { gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4); - reconstruct_gptq_kernel<<>> + reconstruct_gptq_kernel<<>> ( cuda_q_weight, cuda_q_perm, @@ -563,6 +565,7 @@ __global__ void make_sequential_kernel bool QMatrix::make_sequential(const uint32_t* cpu_g_idx) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); uint32_t* cuda_new_qweight = NULL; cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); if (err != cudaSuccess) { @@ -621,7 +624,7 @@ bool QMatrix::make_sequential(const uint32_t* cpu_g_idx) gridDim.x = DIVIDE(width, THREADS_X); gridDim.y = height / 8; - make_sequential_kernel<<>> + make_sequential_kernel<<>> ( cuda_q_weight, cuda_new_qweight, diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_mlp.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_mlp.cu deleted file mode 100644 index 1de12a9a5..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_mlp.cu +++ /dev/null @@ -1,332 +0,0 @@ -#include "q_mlp.cuh" -#include "q_gemm.cuh" -#include "rms_norm.cuh" -#include "layer_norm.cuh" -#include "util.cuh" -#include "matrix_view.cuh" -#include "lora.cuh" -#include "quant/qdq_util.cuh" -#include "../config.h" - -#include "q_mlp_softmax.cuh" - -#if defined(USE_ROCM) -__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) { - return 
_Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half2_raw>(x).data.x)), - static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half2_raw>(x).data.y))}; -} -#define h2rcp __compat_h2rcp -#endif - -const int THREADS_X = 32; -const int THREADS_Y = 4; -// const int MAX_DIMENSION = 8192; - -__device__ __forceinline__ half silu(half x) -{ - half one = __float2half(1.0f); - half neg_x = __hneg(x); - half e = hexp(neg_x); - half sum = __hadd(one, e); - half r = hrcp(sum); - half result = __hmul(x, r); - return result; -} - -__device__ __forceinline__ half2 silu(half2 x) -{ - half2 one = __float2half2_rn(1.0f); - half2 neg_x = __hneg2(x); - half2 e = h2exp(neg_x); - half2 sum = __hadd2(one, e); - half2 r = h2rcp(sum); - half2 result = __hmul2(x, r); - return result; -} - -typedef void (*fp_silu_mul_kernel) -( - half*, - const half*, - const int, - const int, - const half*, - const int -); - -template -__global__ void silu_mul_kernel -( - half* __restrict__ x, - const half* __restrict__ y, - const int height, - const int width, - const half* r_weights, - const int r_weights_stride -) -{ - MatrixView_half_rw x_(x, height, width); - MatrixView_half y_(y, height, width); - - int column = (THREADS_X * blockIdx.x + threadIdx.x); if constexpr (use_half2) column *= 2; - int row = THREADS_Y * blockIdx.y + threadIdx.y; - if (row >= height) return; - - if constexpr (use_r_weights) - { - half_uint16 weight(r_weights[row * r_weights_stride]); - if (!weight.as_uint16) - { -// half2 ppp = __float2half2_rn(6.9f); -// x_.set_half2(row, column, ppp); - return; - } - } - - // silu(x) * y - - if constexpr (use_half2) - { - half2 x_item = x_.item_half2(row, column); - half2 y_item = y_.item_half2(row, column); - - x_item = silu(x_item); - x_item = __hmul2(x_item, y_item); - - x_.set_half2(row, column, x_item); - } - else - { - half x_item = x_.item(row, column); - half y_item = y_.item(row, column); - - x_item = silu(x_item); - x_item = __hmul(x_item, y_item); - - x_.set(row, column, x_item); - } -} - -fp_silu_mul_kernel pick_silu_mul_kernel(bool use_half2, bool mul_r_weights) -{ - if ( use_half2 && !mul_r_weights) return silu_mul_kernel< true, false>; - if ( use_half2 && mul_r_weights) return silu_mul_kernel< true, true>; - if (!use_half2 && !mul_r_weights) return silu_mul_kernel; - if (!use_half2 && mul_r_weights) return silu_mul_kernel; - return NULL; -}; - -QMLP::QMLP -( - half* _layernorm, - half* _layernorm_bias, - bool _layernorm_is_rms, - float _norm_epsilon, - QMatrix* _gate, - QMatrix* _up, - QMatrix* _down, - half* _temp_state, - half* _temp_a, - half* _temp_b, - half* _temp_dq, - int _max_rows -): - layernorm(_layernorm), - layernorm_bias(_layernorm_bias), - layernorm_is_rms(_layernorm_is_rms), - norm_epsilon(_norm_epsilon), - gate(_gate), - up(_up), - down(_down), - temp_state(_temp_state), - temp_a(_temp_a), - temp_b(_temp_b), - temp_dq(_temp_dq), - max_rows(_max_rows) -{ -} - -QMLP::~QMLP() { -} - -void QMLP::forward_ -( - cublasHandle_t cublas_handle, - half* x, - int rows, - int columns, - const std::vector& loras, - half* lora_temp -) -{ - bool use_half2 = true; - int intermediate_size = gate->width; - - if (layernorm_is_rms) - rms_norm_cuda(x, layernorm, temp_state, norm_epsilon, rows, columns); - else - layer_norm_cuda(x, layernorm, layernorm_bias, temp_state, norm_epsilon, rows, columns); - - gemm_half_q_half_cuda(cublas_handle, temp_state, gate, temp_a, rows, intermediate_size, columns, true, temp_dq); - gemm_half_q_half_cuda(cublas_handle, temp_state, up, temp_b, rows, 
intermediate_size, columns, true, temp_dq); - - apply_loras_cuda(cublas_handle, gate_proj_lora, loras, gate, temp_state, temp_a, lora_temp, rows); - apply_loras_cuda(cublas_handle, up_proj_lora, loras, up, temp_state, temp_b, lora_temp, rows); - - dim3 blockDim, gridDim; - blockDim.x = THREADS_X; - blockDim.y = THREADS_Y; - gridDim.x = DIVIDE(up->width, THREADS_X) / (use_half2 ? 2 : 1); - gridDim.y = DIVIDE(rows, THREADS_Y); - - fp_silu_mul_kernel kernel = pick_silu_mul_kernel(use_half2, false); - kernel<<>>(temp_a, temp_b, rows, intermediate_size, NULL, 0); - - gemm_half_q_half_cuda(cublas_handle, temp_a, down, x, rows, columns, intermediate_size, false, temp_dq); - - apply_loras_cuda(cublas_handle, down_proj_lora, loras, down, temp_a, x, lora_temp, rows); -} - - -QMoEMLP::QMoEMLP -( - half* _layernorm, - half* _layernorm_bias, - bool _layernorm_is_rms, - float _norm_epsilon, - half* _gate, - int _num_experts, - int _num_experts_per_token, - std::vector& _w1, - std::vector& _w2, - std::vector& _w3, - half* _temp_state, - half* _temp_gathered_state, - half* _temp_a, - half* _temp_b, - half* _temp_logits, - half* _temp_dq, - int _max_rows, - int _hidden_dim -): - layernorm(_layernorm), - layernorm_bias(_layernorm_bias), - layernorm_is_rms(_layernorm_is_rms), - norm_epsilon(_norm_epsilon), - gate(_gate), - num_experts(_num_experts), - num_experts_per_token(_num_experts_per_token), - w1(_w1), - w2(_w2), - w3(_w3), - temp_state(_temp_state), - temp_gathered_state(_temp_gathered_state), - temp_a(_temp_a), - temp_b(_temp_b), - temp_logits(_temp_logits), - temp_dq(_temp_dq), - max_rows(_max_rows), - hidden_dim(_hidden_dim) -{ -// for (int i = 0; i < num_experts; ++i) -// { -// std::unordered_map> w1; -// std::unordered_map> w2; -// std::unordered_map> w3; -// w1_lora.push_back(w1); -// w2_lora.push_back(w2); -// w3_lora.push_back(w3); -// } -} - -QMoEMLP::~QMoEMLP() { -} - -void QMoEMLP::forward_ -( - cublasHandle_t cublas_handle, - half* x, - int rows, - int columns -// const std::vector& loras, -// half* lora_temp -) -{ - if (num_experts != 8 && num_experts != 4) - { - printf(" ## num_experts != 4 or 8 not implemented\n"); - return; - } - - bool use_half2 = true; - - // Norm - - if (layernorm_is_rms) - rms_norm_cuda(x, layernorm, temp_state, norm_epsilon, rows, columns); - else - layer_norm_cuda(x, layernorm, layernorm_bias, temp_state, norm_epsilon, rows, columns); - - // Compute gate logits - - half alpha_ = __float2half(1.0f); - half beta_ = __float2half(0.0f); - cublasHgemm(cublas_handle, - CUBLAS_OP_T, // gate is column-major - CUBLAS_OP_N, - num_experts, rows, hidden_dim, - &alpha_, - gate, hidden_dim, - temp_state, hidden_dim, - &beta_, - temp_logits, num_experts); - - // Compute softmax filter to and normalize top-k outputs - - dim3 blockDim, gridDim; - blockDim.x = WARPS; - blockDim.y = 1; - gridDim.x = 1; - gridDim.y = DIVIDE(rows, WARPS); - if (num_experts == 4) - softmax4_topk_norm_kernel<<>>(temp_logits, rows, num_experts_per_token); - else if (num_experts == 8) - softmax8_topk_norm_kernel<<>>(temp_logits, rows, num_experts_per_token); - - // For small no. rows, execute all kernels but pass the routing weights. Rows with a weight of zero will skip dot - // product accum and kernels launched with only zero-weights will exit prematurely. 
- - if (rows <= MAX_Q_GEMM_WEIGHTS) - { - int intermediate_size = w1[0]->width; - fp_silu_mul_kernel kernel = pick_silu_mul_kernel(use_half2, true); - - for (int i = 0; i < num_experts; i++) - { - gemm_half_q_half_cuda(cublas_handle, temp_state, w1[i], temp_a, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false); - gemm_half_q_half_cuda(cublas_handle, temp_state, w3[i], temp_b, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false); - -// apply_loras_cuda(cublas_handle, w1_lora[i], loras, w1[i], temp_state, temp_a, lora_temp, rows); -// apply_loras_cuda(cublas_handle, w3_lora[i], loras, w3[i], temp_state, temp_b, lora_temp, rows); - - blockDim.x = THREADS_X; - blockDim.y = THREADS_Y; - gridDim.x = DIVIDE(intermediate_size, THREADS_X) / (use_half2 ? 2 : 1); - gridDim.y = DIVIDE(rows, THREADS_Y); - kernel<<>>(temp_a, temp_b, rows, intermediate_size, temp_logits + i, num_experts); - - gemm_half_q_half_cuda(cublas_handle, temp_a, w2[i], x, rows, columns, intermediate_size, false, temp_dq, true, temp_logits + i, num_experts, true); - -// apply_loras_cuda(cublas_handle, w2_lora[i], loras, w2[i], temp_a, x, lora_temp, rows); - } - } - - // Gather larger number of rows in separate batches according to which experts they trigger, evaluate each MLP - // only on the affected rows and scale by routing weights while adding back directly onto the residual hidden state - - else - { - printf(" ## ropws > %i not implemented\n", MAX_Q_GEMM_WEIGHTS); - DBGI(rows); - } -} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_mlp.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_mlp.cuh deleted file mode 100644 index e23d1d56f..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_mlp.cuh +++ /dev/null @@ -1,139 +0,0 @@ -#ifndef _q_mlp_cuh -#define _q_mlp_cuh - -#include -#include -#include -#include -#include - -#include "q_matrix.cuh" - -class QMLP -{ -public: - - half* layernorm; - half* layernorm_bias; - bool layernorm_is_rms; - float norm_epsilon; - - QMatrix* gate; - QMatrix* up; - QMatrix* down; - - half* temp_state; - half* temp_a; - half* temp_b; - half* temp_dq; - - int device; - int max_rows; - - std::unordered_map> gate_proj_lora; - std::unordered_map> up_proj_lora; - std::unordered_map> down_proj_lora; - - QMLP - ( - half* _layernorm, - half* _layermorm_bias, - bool _layernorm_is_rms, - float _norm_epsilon, - QMatrix* _gate, - QMatrix* _up, - QMatrix* _down, - half* _temp_state, - half* _temp_a, - half* _temp_b, - half* _temp_dq, - int _max_rows - ); - - ~QMLP(); - - void forward_ - ( - cublasHandle_t cublas_handle, - half* x, - int rows, - int columns, - const std::vector& loras, - half* lora_temp - ); - -private: - -}; - -class QMoEMLP -{ -public: - - half* layernorm; - half* layernorm_bias; - bool layernorm_is_rms; - float norm_epsilon; - - half* gate; - int num_experts; - int num_experts_per_token; - - std::vector w1; - std::vector w2; - std::vector w3; - - half* temp_state; - half* temp_gathered_state; - half* temp_a; - half* temp_b; - half* temp_logits; - half* temp_dq; - - int device; - int max_rows; - int hidden_dim; - -// std::vector>> w1_lora; -// std::vector>> w2_lora; -// std::vector>> w3_lora; - - QMoEMLP - ( - half* _layernorm, - half* _layermorm_bias, - bool _layernorm_is_rms, - float _norm_epsilon, - half* _gate, - int _num_experts, - int _num_experts_per_token, - std::vector& w1, - std::vector& w2, - std::vector& w3, - half* _temp_state, - half* _temp_gathered_state, - half* 
_temp_a, - half* _temp_b, - half* _temp_logits, - half* _temp_dq, - int _max_rows, - int _hidden_dim - ); - - ~QMoEMLP(); - - void forward_ - ( - cublasHandle_t cublas_handle, - half* x, - int rows, - int columns -// const std::vector& loras, -// half* lora_temp - ); - -private: - -}; - -#endif \ No newline at end of file diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_mlp_softmax.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_mlp_softmax.cuh deleted file mode 100644 index 91bc271d0..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_mlp_softmax.cuh +++ /dev/null @@ -1,156 +0,0 @@ - -#define WARPS 32 - -__global__ void softmax8_topk_norm_kernel -( - half* __restrict__ x, - const int rows, - const int topk -) -{ - int row = blockIdx.y * WARPS + threadIdx.x; - if (row >= rows) return; - - // Softmax - - int4* row_ptr = (int4*) (x + row * 8); - int4 logits_int4 = *row_ptr; - half2_uint32 l01(logits_int4.x); - half2_uint32 l23(logits_int4.y); - half2_uint32 l45(logits_int4.z); - half2_uint32 l67(logits_int4.w); - float f[] = - { - __low2float(l01.as_half2), - __high2float(l01.as_half2), - __low2float(l23.as_half2), - __high2float(l23.as_half2), - __low2float(l45.as_half2), - __high2float(l45.as_half2), - __low2float(l67.as_half2), - __high2float(l67.as_half2) - }; - - float maxf1 = fmaxf(f[0], f[1]); - float maxf2 = fmaxf(f[2], f[3]); - float maxf3 = fmaxf(f[4], f[5]); - float maxf4 = fmaxf(f[6], f[7]); - maxf1 = fmaxf(maxf1, maxf2); - maxf2 = fmaxf(maxf3, maxf4); - maxf1 = fmaxf(maxf1, maxf2); - - float sum = 0; - for (int i = 0; i < 8; ++i) - { - float e = expf(f[i] - maxf1); - sum += e; - f[i] = e; - } - float epsilon = 1e-8; - float isum = 1.0f / (sum + 8 * epsilon); - for (int i = 0; i < 8; ++i) f[i] = f[i] * isum + epsilon; - - // This is awful but surely faster than synchronizing or launching more kernels (??) - - sum = 1.0f; - for (int i = 0; i < 8 - topk; ++i) - { - float minf = 1.0f; - int minj = -1; - for (int j = 0; j < 8; ++j) - { - if (f[j] > 0 && f[j] < minf) - { - minf = f[j]; - minj = j; - } - } - sum -= f[minj]; - f[minj] = 0.0f; - } - - __syncthreads(); - - isum = 1.0f / sum; - for (int i = 0; i < 8; ++i) f[i] *= isum; - - l01.as_half2 = __floats2half2_rn(f[0], f[1]); - l23.as_half2 = __floats2half2_rn(f[2], f[3]); - l45.as_half2 = __floats2half2_rn(f[4], f[5]); - l67.as_half2 = __floats2half2_rn(f[6], f[7]); - logits_int4.x = l01.as_uint32; - logits_int4.y = l23.as_uint32; - logits_int4.z = l45.as_uint32; - logits_int4.w = l67.as_uint32; - *row_ptr = logits_int4; -} - -__global__ void softmax4_topk_norm_kernel -( - half* __restrict__ x, - const int rows, - const int topk -) -{ - int row = blockIdx.y * WARPS + threadIdx.x; - if (row >= rows) return; - - // Softmax - - int2* row_ptr = (int2*) (x + row * 4); - int2 logits_int2 = *row_ptr; - half2_uint32 l01(logits_int2.x); - half2_uint32 l23(logits_int2.y); - float f[] = - { - __low2float(l01.as_half2), - __high2float(l01.as_half2), - __low2float(l23.as_half2), - __high2float(l23.as_half2), - }; - - float maxf1 = fmaxf(f[0], f[1]); - float maxf2 = fmaxf(f[2], f[3]); - maxf1 = fmaxf(maxf1, maxf2); - - float sum = 0; - for (int i = 0; i < 4; ++i) - { - float e = expf(f[i] - maxf1); - sum += e; - f[i] = e; - } - float epsilon = 1e-8; - float isum = 1.0f / (sum + 4 * epsilon); - for (int i = 0; i < 4; ++i) f[i] = f[i] * isum + epsilon; - - // This is awful but surely faster than synchronizing or launching more kernels (??) 
- - sum = 1.0f; - for (int i = 0; i < 4 - topk; ++i) - { - float minf = 1.0f; - int minj = -1; - for (int j = 0; j < 4; ++j) - { - if (f[j] > 0 && f[j] < minf) - { - minf = f[j]; - minj = j; - } - } - sum -= f[minj]; - f[minj] = 0.0f; - } - - __syncthreads(); - - isum = 1.0f / sum; - for (int i = 0; i < 4; ++i) f[i] *= isum; - - l01.as_half2 = __floats2half2_rn(f[0], f[1]); - l23.as_half2 = __floats2half2_rn(f[2], f[3]); - logits_int2.x = l01.as_uint32; - logits_int2.y = l23.as_uint32; - *row_ptr = logits_int2; -} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_2.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_2.cuh index 3beaeefa9..90c18a0c6 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_2.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_2.cuh @@ -100,4 +100,4 @@ __forceinline__ __device__ void dequant_2bit_16 #endif -#endif \ No newline at end of file +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_4.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_4.cuh index 5fb070d06..ad95edb46 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_4.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_4.cuh @@ -224,4 +224,4 @@ __forceinline__ __device__ void dequant_4bit_8_gptq #endif -#endif \ No newline at end of file +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_5.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_5.cuh index 454e4b93b..78d81f92e 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_5.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_5.cuh @@ -204,4 +204,4 @@ __forceinline__ __device__ void dequant_5bit_32 #endif -#endif \ No newline at end of file +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_6.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_6.cuh index c2eb8cfbf..562fe695f 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_6.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_6.cuh @@ -40,5 +40,3 @@ __forceinline__ __device__ void dequant_6bit_16 #endif #endif - - diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_8.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_8.cuh index e2409efac..6e6bedbdf 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_8.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_8.cuh @@ -35,4 +35,4 @@ __forceinline__ __device__ void dequant_8bit_8 #endif -#endif \ No newline at end of file +#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quantize.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/quantize.cu deleted file mode 100644 index b3d2af2dc..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/quantize.cu +++ /dev/null @@ -1,343 +0,0 @@ -#include "quantize.cuh" -#include "util.cuh" -#include -#include "compat.cuh" - -#define BLOCKSIZE_X 32 -#define BLOCKSIZE_Y 32 -#define MAX_P_GRID 64 - -__global__ void fused_quantize_adjust_kernel -( - const float* __restrict__ weights, // input weights [rows, columns] - float* __restrict__ quant, // output quantized weights [rows, columns] - const float* __restrict__ scale, // input scales [1, columns] - uint16_t* __restrict__ out_q, // output qweights [rows, columns] - const float* __restrict__ hessian_inv, // input hessian [rows, rows] - float* __restrict__ error, // output error [rows, 
columns] - int row, // row index to quantize - int rows, // num rows - int columns, // num columns - float qzero, // 2^(bits - 1) - float maxq // (2^bits) - 1 -) -{ - int column = blockIdx.x * blockDim.x + threadIdx.x; - if (column >= columns) return; - - int idx = row * columns + column; - - // Quantize - - float x = weights[idx]; - float s = scale[column]; - x /= s; - x = rintf(x); - x += qzero; - x = clamp(x, 0.0f, maxq); - - // Optionally save quant - - if (out_q) out_q[idx] = static_cast(x); - - // Downcast while quantizing - - half h_s = __float2half_rn(s); - half h_x = __float2half_rn(x); - half h_qzero = __float2half_rn(qzero); - - // Dequantize - - h_x = __hsub(h_x, h_qzero); - h_x = __hmul(h_x, h_s); - float q = __half2float(h_x); - quant[idx] = q; - - // Adjust error - - float d = hessian_inv[row * rows + row]; // H diagonal - float w = weights[idx]; - error[idx] = (w - q) / d; -} - -void fused_quantize_adjust_cuda -( - const float* weights, - float* quant, - const float* scale, - uint16_t* out_q, - const float* hessian_inv, - float* error, - int row, - int rows, - int columns, - float qzero, - float maxq -) -{ - dim3 threads(BLOCKSIZE_X, 1); - dim3 blocks(DIVIDE(columns, BLOCKSIZE_X), 1); - - fused_quantize_adjust_kernel<<>> - ( - weights, - quant, - scale, - out_q, - hessian_inv, - error, - row, - rows, - columns, - qzero, - maxq - ); -} - -__global__ void quantize_kernel -( - const float* __restrict__ input, - float* __restrict__ output, - const float* __restrict__ scale, - uint16_t* __restrict__ out_q, - int rows, - int columns, - float qzero, - float maxq -) -{ - int column = blockIdx.x * blockDim.x + threadIdx.x; - int row = blockIdx.y * blockDim.y + threadIdx.y; - if (column >= columns) return; - if (row >= rows) return; - - // Quantize - - float x = input[row * columns + column]; - float s = scale[column]; - x /= s; - x = rintf(x); - x += qzero; - x = clamp(x, 0.0f, maxq); - - // Optionally save quant - - if (out_q) - { - uint16_t q = static_cast(x); - out_q[row * columns + column] = q; - } - - half h_s = __float2half_rn(s); - half h_x = __float2half_rn(x); - half h_qzero = __float2half_rn(qzero); - - h_x = __hsub(h_x, h_qzero); - h_x = __hmul(h_x, h_s); - - // Dequantize - - output[row * columns + column] = __half2float(h_x); -} - -void quantize_cuda -( - const float* input, - float* output, - const float* scale, - uint16_t* out_q, - int rows, - int columns, - float qzero, - float maxq -) -{ - dim3 threads(BLOCKSIZE_X, BLOCKSIZE_Y); - dim3 blocks(DIVIDE(columns, BLOCKSIZE_X), DIVIDE(rows, BLOCKSIZE_Y)); - -// DBGI2(rows, columns); -// DBGF2(qzero, maxq); - - quantize_kernel<<>> - ( - input, - output, - scale, - out_q, - rows, - columns, - qzero, - maxq - ); -} - -__global__ void adjust_error_row_kernel -( - const float* __restrict__ hessian_inv, - float* __restrict__ error, - const float* __restrict__ weights, - const float* __restrict__ quant, - int c, - int columns, - int hcolumns -) -{ - int column = blockIdx.x * blockDim.x + threadIdx.x; - if (column >= columns) return; - - float d = hessian_inv[c * hcolumns + c]; - - int idx = c * columns + column; - float w = weights[idx]; - float q = quant[idx]; - error[idx] = (w - q) / d; -} - -void adjust_error_row_cuda -( - const float* hessian_inv, - float* error, - const float* weights, - const float* quant, - int c, - int columns, - int hcolumns -) -{ - dim3 threads(BLOCKSIZE_X, 1); - dim3 blocks(DIVIDE(columns, BLOCKSIZE_X), 1); - - adjust_error_row_kernel<<>>(hessian_inv, error, weights, quant, c, columns, hcolumns); -} - 
-__global__ void quantize_err_kernel -( - const float* __restrict__ input, - float* __restrict__ output, - const float* __restrict__ scale, - int rows, - int columns, - float qzero, - float maxq, - float err_norm, - float min_p, - float max_p, - int p_grid -) -{ - int column_ = blockIdx.x * blockDim.x + threadIdx.x; - int column = column_ * 4; - int row = blockIdx.y * blockDim.y + threadIdx.y; - if (column >= columns) return; - if (row >= rows) return; - - float clamp_min = -qzero; - float clamp_max = maxq - qzero; - - float4 sc4 = *((float4*) (scale + column)); - float4 w4 = *((float4*) (input + row * columns + column)); - - for (int i = 0; i <= p_grid; i++) - { - float pi = __int2float_rn(i) / __int2float_rn(p_grid); - float p = min_p * (1.0f - pi) + max_p * pi; - - float4 s4 = sc4; - s4.x *= p; - s4.y *= p; - s4.z *= p; - s4.w *= p; - - float err = __powf(fabsf(clamp(rintf(w4.x / s4.x), clamp_min, clamp_max) * s4.x - w4.x), err_norm); - err += __powf(fabsf(clamp(rintf(w4.y / s4.y), clamp_min, clamp_max) * s4.y - w4.y), err_norm); - err += __powf(fabsf(clamp(rintf(w4.z / s4.z), clamp_min, clamp_max) * s4.z - w4.z), err_norm); - err += __powf(fabsf(clamp(rintf(w4.w / s4.w), clamp_min, clamp_max) * s4.w - w4.w), err_norm); - - atomicAdd(&output[i * 128 + column_ % 128], err); - } -} - -void quantize_err_cuda -( - const float* input, - float* output, - const float* scale, - int rows, - int columns, - float qzero, - float maxq, - float err_norm, - float min_p, - float max_p, - int p_grid -) -{ - dim3 threads(BLOCKSIZE_X, BLOCKSIZE_Y); - dim3 blocks(DIVIDE(columns / 4, BLOCKSIZE_X), DIVIDE(rows, BLOCKSIZE_Y)); - -// DBGI2(rows, columns); -// DBGF2(qzero, maxq); - - quantize_err_kernel<<>> - ( - input, - output, - scale, - rows, - columns, - qzero, - maxq, - err_norm, - min_p, - max_p, - p_grid - ); -} - -// Compute z = z - x.T @ y - -__global__ void vv_mul_sub_kernel -( - const float* __restrict__ x, - const float* __restrict__ y, - float* __restrict__ z, - int x_size, - int y_size -) -{ - int y_idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; - int x_idx = blockIdx.y * blockDim.y + threadIdx.y; - if (y_idx >= y_size) return; - if (x_idx >= x_size) return; - - int z_idx = y_size * x_idx + y_idx; - - float vx = x[x_idx]; - float4 vy = *((float4*) (y + y_idx)); - float4 vz = *((float4*) (z + z_idx)); - vz.x -= vy.x * vx; - vz.y -= vy.y * vx; - vz.z -= vy.z * vx; - vz.w -= vy.w * vx; - *((float4*) (z + z_idx)) = vz; -} - -void vv_mul_sub_cuda -( - const float* x, - const float* y, - float* z, - int x_size, - int y_size -) -{ - dim3 blockDim, gridDim; - blockDim.x = BLOCKSIZE_X; - blockDim.y = BLOCKSIZE_Y; - blockDim.z = 1; - gridDim.x = DIVIDE(y_size / 4, BLOCKSIZE_X); - gridDim.y = DIVIDE(x_size, BLOCKSIZE_Y); - gridDim.z = 1; - - vv_mul_sub_kernel<<>>(x, y, z, x_size, y_size); -} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/quantize.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/quantize.cuh deleted file mode 100644 index fd75cd1ef..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/quantize.cuh +++ /dev/null @@ -1,71 +0,0 @@ -#ifndef _quantize_cuh -#define _quantize_cuh - -#include -#include -#include -#include - -void quantize_cuda -( - const float* input, - float* output, - const float* scale, - uint16_t* out_q, - int rows, - int columns, - float qzero, - float maxq -); - -void adjust_error_row_cuda -( - const float* hessian_inv, - float* error, - const float* weights, - const float* quant, - int c, - int columns, - int hcolumns -); - -void 
fused_quantize_adjust_cuda -( - const float* weights, - float* quant, - const float* scale, - uint16_t* out_q, - const float* hessian_inv, - float* error, - int row, - int rows, - int columns, - float qzero, - float maxq -); - -void quantize_err_cuda -( - const float* input, - float* output, - const float* scale, - int rows, - int columns, - float qzero, - float maxq, - float err_norm, - float min_p, - float max_p, - int p_grid -); - -void vv_mul_sub_cuda -( - const float* x, - const float* y, - float* z, - int x_size, - int y_size -); - -#endif \ No newline at end of file diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/rms_norm.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/rms_norm.cu deleted file mode 100644 index 03089ee59..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/rms_norm.cu +++ /dev/null @@ -1,133 +0,0 @@ -#include "rms_norm.cuh" -#include "util.cuh" - -#if defined(USE_ROCM) -#define __shfl_xor_sync(mask, var, laneMask) __shfl_xor(var, laneMask) -#define NUM_WARPS (1024 / warpSize) -#define WARP_SIZE (warpSize) -#else -#define NUM_WARPS 32 -#define WARP_SIZE 32 -#endif - -// y = x * w / sqrt(row_mean(x * x) + epsilon) - -#define BLOCK_SIZE WARP_SIZE -#define NUM_THREADS (NUM_WARPS * WARP_SIZE) - -typedef void (*fp_rms_norm_kernel) -( - const half*, - const half*, - half*, - const float, - const float, - const int, - const int -); - -template -__global__ void rms_norm_kernel -( - const half* __restrict__ x, - const half* __restrict__ w, - half* __restrict__ y, - const float epsilon, - const float r_dim, - const int rows, - const int dim -) -{ - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; - int row = blockIdx.x; - const half* x_row = x + row * dim; - half* y_row = y + row * dim; - - //int blocks_per_warp = DIVIDE(dim, NUM_THREADS); - - // Compute sum of squares for each block - - float sum = 0.0f; - float itemf[blocks_per_warp]; - - #pragma unroll - for (int i = 0; i < blocks_per_warp; i++) - { - int column = warp_id * WARP_SIZE + lane_id + NUM_THREADS * i; - if (column >= dim) break; - - float f = __half2float(x_row[column]); - f = fmaxf(-65504.0f, fminf(f, 65504.0f)); - itemf[i] = f; - sum = fma(f, f, sum); - } - - // Shuffle to sum across lanes - - __shared__ float sums[NUM_WARPS]; - - for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset); - if (lane_id == 0) sums[warp_id] = sum; - __syncthreads(); - - // Load partial sums from across warps, shuffle again across lanes - - sum = sums[lane_id]; - for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset); - - // Get norm - - float rmf = rsqrtf(sum * r_dim + epsilon); - - // Normalize x, scaling by w - - #pragma unroll 4 - for (int i = 0; i < blocks_per_warp; i++) - { - int column = warp_id * WARP_SIZE + lane_id + NUM_THREADS * i; - if (column >= dim) return; - - float x_itemf = itemf[i]; - float w_itemf = __half2float(w[column]); - float n = x_itemf * w_itemf * rmf; - y_row[column] = __float2half_rn(n); - } -} - -fp_rms_norm_kernel pick_rms_norm_kernel(const int blocks_per_warp) -{ - if (blocks_per_warp == 1) return rms_norm_kernel<1>; - if (blocks_per_warp == 2) return rms_norm_kernel<2>; - if (blocks_per_warp == 3) return rms_norm_kernel<3>; - if (blocks_per_warp == 4) return rms_norm_kernel<4>; - if (blocks_per_warp == 5) return rms_norm_kernel<5>; - if (blocks_per_warp == 6) return rms_norm_kernel<6>; - if (blocks_per_warp == 7) return rms_norm_kernel<7>; - if 
(blocks_per_warp == 8) return rms_norm_kernel<8>; - return NULL; -} - - -void rms_norm_cuda -( - const half* x, - const half* w, - half* y, - const float epsilon, - const int rows, - const int dim -) -{ - dim3 blockDim, gridDim; - blockDim.x = NUM_THREADS; - blockDim.y = 1; - gridDim.x = rows; - gridDim.y = 1; - - float r_dim = 1.0f / (float) dim; - - int blocks_per_warp = DIVIDE(dim, NUM_THREADS); - fp_rms_norm_kernel kernel = pick_rms_norm_kernel(blocks_per_warp); - kernel<<>>(x, w, y, epsilon, r_dim, rows, dim); -} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/rms_norm.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/rms_norm.cuh deleted file mode 100644 index 4cb0fea97..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/rms_norm.cuh +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef _rms_norm_cuh -#define _rms_norm_cuh - -#include -#include -#include -#include - -void rms_norm_cuda -( - const half* x, - const half* w, - half* y, - const float epsilon, - const int rows, - const int dim -); - -#endif \ No newline at end of file diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/rope.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/rope.cu deleted file mode 100644 index e4d4ef3c4..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/rope.cu +++ /dev/null @@ -1,142 +0,0 @@ -#include "rope.cuh" -#include "util.cuh" -#include "matrix_view.cuh" - -const int THREADS_X = 32; -const int THREADS_Y = 4; -const int MAX_POS_EMBEDDINGS = 32768; // Actual number doesn't matter -const int MAX_ROWS = 32768; // Actual number doesn't matter - -typedef void (*fp_rope_cuda_kernel) -( - half*, - const half*, - const half*, - int, - int, - int, - int, - const int32_t*, - int -); - -template -__global__ void rope_cuda_kernel -( - half* __restrict__ x, - const half* __restrict__ sin, - const half* __restrict__ cos, - int rows_per_batch, - int head_dim, - int num_heads, - int past_len, - const int32_t* __restrict__ past_lens, - int threads_y -) -{ - MatrixView_half_rw x_(x, MAX_ROWS, head_dim); - MatrixView_half sin_(sin, MAX_POS_EMBEDDINGS, head_dim); - MatrixView_half cos_(cos, MAX_POS_EMBEDDINGS, head_dim); - - int column = (blockIdx.x * THREADS_X + threadIdx.x); if constexpr (use_half2) column *= 2; - int half_dim = head_dim / 2; - if (column >= half_dim) return; - - int row = blockIdx.y * threads_y + threadIdx.y; - if (row >= rows_per_batch) return; - int batch_offset = blockIdx.z * rows_per_batch; - int row_offset = batch_offset + row; - - // Get sin and cos - - if (past_len == -1) - { - past_len = past_lens[blockIdx.z]; - past_len = max(past_len, 0); - } - else if (past_lens) - { - past_len += past_lens[blockIdx.z]; - } - - int sincos_row = past_len + row / num_heads; - sincos_row = max(sincos_row, 0); - - if constexpr (use_half2) - { - half2 cos2_l = cos_.item_half2(sincos_row, column); - half2 cos2_r = cos_.item_half2(sincos_row, column + half_dim); - half2 sin2_l = sin_.item_half2(sincos_row, column); - half2 sin2_r = sin_.item_half2(sincos_row, column + half_dim); - sin2_l = __hneg2(sin2_l); - - // Apply embedding to row - - half2 item2_l = x_.item_half2(row_offset, column); - half2 item2_r = x_.item_half2(row_offset, column + half_dim); - half2 item2_ls = __hmul2(item2_r, sin2_l); - half2 item2_rs = __hmul2(item2_l, sin2_r); - item2_l = __hfma2(item2_l, cos2_l, item2_ls); - item2_r = __hfma2(item2_r, cos2_r, item2_rs); - x_.set_half2(row_offset, column, item2_l); - x_.set_half2(row_offset, column + half_dim, item2_r); - } - else - { - half cos_l = 
cos_.item(sincos_row, column); - half cos_r = cos_.item(sincos_row, column + half_dim); - half sin_l = sin_.item(sincos_row, column); - half sin_r = sin_.item(sincos_row, column + half_dim); - sin_l = __hneg(sin_l); - - // Apply embedding to row - - half item_l = x_.item(row_offset, column); - half item_r = x_.item(row_offset, column + half_dim); - half item_ls = __hmul(item_r, sin_l); - half item_rs = __hmul(item_l, sin_r); - item_l = __hfma(item_l, cos_l, item_ls); - item_r = __hfma(item_r, cos_r, item_rs); - x_.set(row_offset, column, item_l); - x_.set(row_offset, column + half_dim, item_r); - } -} - -fp_rope_cuda_kernel pick_rope_cuda_kernel(bool use_half2) -{ - if (use_half2) return rope_cuda_kernel; - else return rope_cuda_kernel; -}; - -void rope_cuda -( - half* x, - const half* sin, - const half* cos, - const int batch_size, - const int rows_per_batch, - const int head_dim, - const int num_heads, - const int past_len, - const int32_t* past_lens -) -{ - bool use_half2 = true; - - // For large batch sizes we risk exceeding grid dimension of 65535, so shift to block dimension instead - - int threads_y = THREADS_Y; - while (DIVIDE(rows_per_batch, threads_y) > 65535) threads_y *= 2; - - dim3 blockDim, gridDim; - blockDim.x = THREADS_X; - blockDim.y = threads_y; - gridDim.x = DIVIDE(head_dim / (use_half2 ? 2 : 1), THREADS_X); - gridDim.y = DIVIDE(rows_per_batch, threads_y); - gridDim.z = batch_size; - - fp_rope_cuda_kernel kernel = pick_rope_cuda_kernel(use_half2); - kernel<<>>(x, sin, cos, rows_per_batch, head_dim, num_heads, past_len, past_lens, threads_y); - - cuda_check( cudaPeekAtLastError() ); -} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/rope.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/rope.cuh deleted file mode 100644 index c9bc1da8d..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/rope.cuh +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef _rope_cuh -#define _rope_cuh - -#include -#include -#include -#include - -void rope_cuda -( - half* x, - const half* sin, - const half* cos, - const int batch_size, - const int rows_per_batch, - const int head_dim, - const int num_heads, - const int past_len, - const int32_t* past_lens -); - -#endif diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cu deleted file mode 100644 index 4f385791b..000000000 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cu +++ /dev/null @@ -1,24 +0,0 @@ -#include "util.cuh" - -void print_global_mem(const half* ptr, int rows, int columns, int stride) -{ - half* temp = (half*) malloc(rows * columns * sizeof(half)); - - cudaDeviceSynchronize(); - cudaMemcpyAsync(temp, ptr, rows * columns * sizeof(half), cudaMemcpyDeviceToHost); - cudaDeviceSynchronize(); - - for (int r = 0; r < rows; ++r) - { - printf("%4d: ", r); - for (int c = 0; c < columns; c++) - { - float v = __half2float(temp[r * stride + c]); - printf("%10.6f", v); - if (c < columns - 1) printf(" "); - } - printf("\n"); - } - - free(temp); -} diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh index f56eda791..e167bc23e 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh @@ -51,4 +51,4 @@ inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort= void print_global_mem(const half* ptr, int rows, int columns, int stride); -#endif \ No newline at end of file +#endif 
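The surviving exllamav2 hunks above make two functional changes alongside the file deletions: GPTQ zero points are masked with `& 0x0F` before `dequant_4bit_8_prep_zero` so the `+ 1` offset cannot spill outside the 4-bit range, and the kernel launches in q_gemm.cu / q_matrix.cu now target the stream returned by `at::cuda::getCurrentCUDAStream()` instead of the legacy default stream, which keeps them ordered with the surrounding PyTorch ops. A minimal sketch of that launch pattern follows; `scale_half_kernel` and `launch_on_torch_stream` are illustrative names, not part of the patch, and the only API assumed is the same ATen call the patch itself introduces.

    #include <ATen/cuda/CUDAContext.h>
    #include <cuda_fp16.h>

    // Toy kernel standing in for the exllamav2 GEMM/reconstruct kernels (illustration only).
    __global__ void scale_half_kernel(const half* __restrict__ in, half* __restrict__ out, int n, float s)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i >= n) return;
        out[i] = __float2half(__half2float(in[i]) * s);
    }

    void launch_on_torch_stream(const half* in, half* out, int n, float s)
    {
        // Enqueue on the stream PyTorch is currently recording work on, so the kernel
        // is ordered with the rest of the forward pass without extra synchronization.
        const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

        dim3 blockDim(256);
        dim3 gridDim((n + blockDim.x - 1) / blockDim.x);
        scale_half_kernel<<<gridDim, blockDim, 0, stream>>>(in, out, n, s);
    }

Passing 0 for dynamic shared memory keeps the launch otherwise identical to the stock `<<<grid, block>>>` form, so the only behavioural difference is which stream the work lands on.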
diff --git a/server/exllamav2_kernels/setup.py b/server/exllamav2_kernels/setup.py index 518db1df9..4a16b546f 100644 --- a/server/exllamav2_kernels/setup.py +++ b/server/exllamav2_kernels/setup.py @@ -1,5 +1,15 @@ from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import torch + +extra_cuda_cflags = ["-lineinfo", "-O3"] + +if torch.version.hip: + extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"] + +extra_compile_args = { + "nvcc": extra_cuda_cflags, +} setup( name="exllamav2_kernels", @@ -11,6 +21,7 @@ "exllamav2_kernels/cuda/q_matrix.cu", "exllamav2_kernels/cuda/q_gemm.cu", ], + extra_compile_args=extra_compile_args, ) ], cmdclass={"build_ext": BuildExtension}, From 1337aa56390783e30d9cb1152ed5104020a11a7c Mon Sep 17 00:00:00 2001 From: flozi00 Date: Mon, 20 May 2024 22:03:24 +0200 Subject: [PATCH 02/32] awq --- server/Makefile | 1 + .../attention/cuda_bf16_fallbacks.cuh | 257 --- .../awq_cuda/attention/cuda_bf16_wrapper.h | 23 - .../decoder_masked_multihead_attention.cu | 152 -- .../decoder_masked_multihead_attention.h | 184 -- ...er_masked_multihead_attention_template.hpp | 1608 --------------- ...decoder_masked_multihead_attention_utils.h | 1786 ----------------- .../awq_cuda/attention/ft_attention.cpp | 182 -- .../awq_cuda/attention/ft_attention.h | 15 - .../awq_cuda/layernorm/layernorm.cu | 113 -- .../awq_cuda/layernorm/layernorm.h | 3 - .../awq_cuda/layernorm/reduction.cuh | 82 - .../position_embedding/pos_encoding.h | 9 - .../pos_encoding_kernels.cu | 88 - server/awq_kernels/awq_cuda/pybind_awq.cpp | 15 - server/awq_kernels/awq_cuda/pybind_ft.cpp | 11 - .../awq_cuda/quantization/dequantize.cuh | 79 - .../awq_cuda/quantization/gemm_cuda.h | 7 - .../awq_cuda/quantization/gemm_cuda_gen.cu | 848 -------- .../awq_cuda/quantization/gemv_cuda.cu | 249 --- .../awq_cuda/quantization/gemv_cuda.h | 9 - server/awq_kernels/setup.py | 92 - 22 files changed, 1 insertion(+), 5812 deletions(-) delete mode 100644 server/awq_kernels/awq_cuda/attention/cuda_bf16_fallbacks.cuh delete mode 100644 server/awq_kernels/awq_cuda/attention/cuda_bf16_wrapper.h delete mode 100644 server/awq_kernels/awq_cuda/attention/decoder_masked_multihead_attention.cu delete mode 100644 server/awq_kernels/awq_cuda/attention/decoder_masked_multihead_attention.h delete mode 100644 server/awq_kernels/awq_cuda/attention/decoder_masked_multihead_attention_template.hpp delete mode 100644 server/awq_kernels/awq_cuda/attention/decoder_masked_multihead_attention_utils.h delete mode 100644 server/awq_kernels/awq_cuda/attention/ft_attention.cpp delete mode 100644 server/awq_kernels/awq_cuda/attention/ft_attention.h delete mode 100644 server/awq_kernels/awq_cuda/layernorm/layernorm.cu delete mode 100644 server/awq_kernels/awq_cuda/layernorm/layernorm.h delete mode 100644 server/awq_kernels/awq_cuda/layernorm/reduction.cuh delete mode 100644 server/awq_kernels/awq_cuda/position_embedding/pos_encoding.h delete mode 100644 server/awq_kernels/awq_cuda/position_embedding/pos_encoding_kernels.cu delete mode 100644 server/awq_kernels/awq_cuda/pybind_awq.cpp delete mode 100644 server/awq_kernels/awq_cuda/pybind_ft.cpp delete mode 100644 server/awq_kernels/awq_cuda/quantization/dequantize.cuh delete mode 100644 server/awq_kernels/awq_cuda/quantization/gemm_cuda.h delete mode 100644 server/awq_kernels/awq_cuda/quantization/gemm_cuda_gen.cu delete mode 100644 server/awq_kernels/awq_cuda/quantization/gemv_cuda.cu delete mode 100644 server/awq_kernels/awq_cuda/quantization/gemv_cuda.h delete mode 100644 
server/awq_kernels/setup.py diff --git a/server/Makefile b/server/Makefile index e6cd87395..75cb9e942 100644 --- a/server/Makefile +++ b/server/Makefile @@ -3,6 +3,7 @@ include Makefile-flash-att-v2 include Makefile-vllm include Makefile-megablocks include Makefile-eetq +include Makefile-awq unit-tests: pytest -s -vv -m "not private" tests diff --git a/server/awq_kernels/awq_cuda/attention/cuda_bf16_fallbacks.cuh b/server/awq_kernels/awq_cuda/attention/cuda_bf16_fallbacks.cuh deleted file mode 100644 index f5641f616..000000000 --- a/server/awq_kernels/awq_cuda/attention/cuda_bf16_fallbacks.cuh +++ /dev/null @@ -1,257 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh -/* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "cuda_bf16_wrapper.h" -#include - -namespace fastertransformer { - -#ifdef ENABLE_BF16 -inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = __low2float(val); - f_val.y = __high2float(val); - return f_val; -#else - return __bfloat1622float2(val); -#endif -} - -inline __device__ int16_t bf1622int16(__nv_bfloat162 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = max(min(__low2float(val), 127.f), -128.f); - f_val.y = max(min(__high2float(val), 127.f), -128.f); - union { int8_t int8[2]; int16_t int16; }; - int8[0] = static_cast(static_cast(f_val.x)); - int8[1] = static_cast(static_cast(f_val.y)); - return int16; -#else - val = __hmin2(val, make_bfloat162(127., 127.)); - val = __hmax2(val, make_bfloat162(-128., -128.)); - union { int8_t int8[2]; int16_t int16; }; - int8[0] = static_cast(static_cast(val.x)); - int8[1] = static_cast(static_cast(val.y)); - return int16; -#endif -} - -inline __device__ __nv_bfloat162 float22bf162(const float2 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __floats2bfloat162_rn(val.x, val.y); -#else - return __float22bfloat162_rn(val); -#endif -} - -inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - __nv_bfloat162 val2; - val2.x = val; - val2.y = val; - return val2; -#else - return __bfloat162bfloat162(val); -#endif -} - -inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); -#else - return __hadd2(x, y); -#endif -} - -inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) ); 
-#else - return __hadd(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl - fyl, fxh - fyh); -#else - return __hsub2(x, y); -#endif -} - -inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) ); -#else - return __hsub(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); -#else - return __hmul2(x, y); -#endif -} - -inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) ); -#else - return __hmul(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh, fzl, fzh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - fzl = __low2float(z); - fzh = __high2float(z); - return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh); -#else - return __hfma2(x, y, z); -#endif -} - -inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z)); -#else - return __hfma(x, y, z); -#endif -} - -inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh; - fxl = __low2float(x); - fxh = __high2float(x);; - return __floats2bfloat162_rn(expf(fxl), expf(fxh)); -#else - return h2exp(x); -#endif -} - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) -inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); }; -inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); }; - -inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) -{ - __nv_bfloat162 t; t.x = x; t.y = y; return t; -} - -#endif - -inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c)); -#else - return a + b + c; -#endif -} - -inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d)); -#else - return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d); -#endif -} - -inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { -#if 
defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch); -#else - return a + b + c; -#endif -} - -inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c)); -#else - return a * b * c; -#endif -} - -inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); -#else - return a * b * c; -#endif -} - -inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch, fdl, fdh; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - fdl = __low2float(d); - fdh = __high2float(d); - return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh); -#else - return a * b * c + d; -#endif -} - -#endif // ENABLE_BF16 - -} // namespace fastertransformer diff --git a/server/awq_kernels/awq_cuda/attention/cuda_bf16_wrapper.h b/server/awq_kernels/awq_cuda/attention/cuda_bf16_wrapper.h deleted file mode 100644 index efb6e7987..000000000 --- a/server/awq_kernels/awq_cuda/attention/cuda_bf16_wrapper.h +++ /dev/null @@ -1,23 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h -/* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#ifdef ENABLE_BF16 -#include -#endif diff --git a/server/awq_kernels/awq_cuda/attention/decoder_masked_multihead_attention.cu b/server/awq_kernels/awq_cuda/attention/decoder_masked_multihead_attention.cu deleted file mode 100644 index 12fc1dea2..000000000 --- a/server/awq_kernels/awq_cuda/attention/decoder_masked_multihead_attention.cu +++ /dev/null @@ -1,152 +0,0 @@ -// Adapted from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "decoder_masked_multihead_attention.h" -#include "decoder_masked_multihead_attention_utils.h" -#include "cuda_bf16_wrapper.h" -#include -#include -#include - -#include "decoder_masked_multihead_attention_template.hpp" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - auto kernel = mmha::masked_multihead_attention_kernel; \ - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ - dim3 grid(params.num_heads, params.batch_size); \ - kernel<<>>(params) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// !!! Specialize the launcher for Cross attention -template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; - // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION); - if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream); - } - else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream); - } - else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#undef MMHA_LAUNCH_KERNEL - -template -void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - switch (params.hidden_size_per_head) { - case 32: - mmha_launch_kernel(params, stream); - break; - case 48: - mmha_launch_kernel(params, stream); - break; - case 64: - mmha_launch_kernel(params, stream); - break; - case 80: - mmha_launch_kernel(params, stream); - break; - case 96: - mmha_launch_kernel(params, stream); - break; - case 112: - mmha_launch_kernel(params, stream); - break; - case 128: - mmha_launch_kernel(params, stream); - break; - case 160: - mmha_launch_kernel(params, stream); - break; - case 192: - mmha_launch_kernel(params, stream); - break; - case 224: - mmha_launch_kernel(params, stream); - break; - case 256: - mmha_launch_kernel(params, stream); - break; - default: - assert(false); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void 
masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream) -{ - multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream); -} -#endif -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream) -{ - multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/awq_kernels/awq_cuda/attention/decoder_masked_multihead_attention.h b/server/awq_kernels/awq_cuda/attention/decoder_masked_multihead_attention.h deleted file mode 100644 index 292fa3401..000000000 --- a/server/awq_kernels/awq_cuda/attention/decoder_masked_multihead_attention.h +++ /dev/null @@ -1,184 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "cuda_bf16_wrapper.h" -#include -#include -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define CHECK_CUDA(call) \ - do { \ - cudaError_t status_ = call; \ - if (status_ != cudaSuccess) { \ - fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ - exit(1); \ - } \ - } while (0) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// The structure of parameters for the masked multihead attention kernel. -// -// We use the following terminology to describe the different dimensions. -// -// B: Batch size (number of sequences), -// L: Sequence length, -// D: Hidden dimension, -// H: Number of heads, -// Dh: Hidden dimension per head - Dh = D / H. 
- -template -struct Multihead_attention_params_base { - - // The output buffer. Dimensions B x D. - T* out = nullptr; - - // The input Qs and the associated bias. Dimensions B x D and D, resp. - const T *q = nullptr, *q_bias = nullptr; - // The input Ks and the associated bias. Dimensions B x D and D, resp. - const T *k = nullptr, *k_bias = nullptr; - // The input Vs and the associated bias. Dimensions B x D and D, resp. - const T *v = nullptr, *v_bias = nullptr; - - // The cache for the Ks. The size must be at least B x L x D. - T* k_cache = nullptr; - // The cache for the Vs. The size must be at least B x L x D. - T* v_cache = nullptr; - // The indirections to use for cache when beam sampling. - const int* cache_indir = nullptr; - - // Stride to handle the case when KQV is a single buffer - int stride = 0; - - // The batch size. - int batch_size = 0; - // The beam width - int beam_width = 0; - // The sequence length. - int memory_max_len = 0; - // The number of heads (H). - int num_heads = 0; - // The number of heads for KV cache. - int num_kv_heads = 0; - // The hidden dimension per head (Dh). - int hidden_size_per_head = 0; - // The per-head latent space reserved for rotary embeddings. - int rotary_embedding_dim = 0; - bool neox_rotary_style = false; - float rotary_base = 0.0f; - // The maximum length of input sentences. - int max_input_length = 0; - // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention? - int timestep = 0; - // The current timestep of each sentences (support different timestep for different sentences) - - // The 1.f / sqrt(Dh). Computed on the host. - float inv_sqrt_dh = 0.0f; - - // Used when we have some input context like gpt - const int* total_padding_tokens = nullptr; - - const bool* masked_tokens = nullptr; - const int* prefix_prompt_lengths = nullptr; - int max_prefix_prompt_length = 0; - - const T* relative_attention_bias = nullptr; - int relative_attention_bias_stride = 0; - // The slope per head of linear position bias to attention score (H). 
- const float* linear_bias_slopes = nullptr; - - const T* ia3_key_weights = nullptr; - const T* ia3_value_weights = nullptr; - const int* ia3_tasks = nullptr; - - const float* qkv_scale_out = nullptr; - const float* attention_out_scale = nullptr; - int int8_mode = 0; -}; - -template -struct Multihead_attention_params: public Multihead_attention_params_base { - // output cross attentions - float* cross_attention_out = nullptr; - int max_decoder_seq_len = 0; - bool is_return_cross_attentions = false; - - // allows to exist attention eary - bool* finished = nullptr; - - // required in case of cross attention - // will need it here till if constexpr in c++17 - int* memory_length_per_sample = nullptr; - - // required in case of masked attention with different length - const int* length_per_sample = nullptr; -}; - -template -struct Multihead_attention_params: public Multihead_attention_params_base { - // output cross attentions - float* cross_attention_out = nullptr; - int max_decoder_seq_len = 0; - bool is_return_cross_attentions = false; - - // allows to exist attention eary - bool* finished = nullptr; - - // required in case of cross attention - int* memory_length_per_sample = nullptr; - - // required in case of masked attention with different length - const int* length_per_sample = nullptr; -}; - -template -using Masked_multihead_attention_params = Multihead_attention_params; - -template -using Cross_multihead_attention_params = Multihead_attention_params; - -template -struct outputCrossAttentionParam { - // max decoder output length - int max_decoder_seq_len = 0; - T* cross_attention_out = nullptr; - bool is_return_cross_attentions = false; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream); -#endif -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream); -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/server/awq_kernels/awq_cuda/attention/decoder_masked_multihead_attention_template.hpp b/server/awq_kernels/awq_cuda/attention/decoder_masked_multihead_attention_template.hpp deleted file mode 100644 index 84eda817b..000000000 --- a/server/awq_kernels/awq_cuda/attention/decoder_masked_multihead_attention_template.hpp +++ /dev/null @@ -1,1608 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "decoder_masked_multihead_attention.h" -#include "decoder_masked_multihead_attention_utils.h" -#include "cuda_bf16_wrapper.h" -#include "cuda_bf16_fallbacks.cuh" -#include -#include -#include - -// #define MMHA_USE_HMMA_FOR_REDUCTION - -// Below are knobs to extend FP32 accumulation for higher FP16 accuracy - -// Does not seem to affect the accuracy that much -#define MMHA_USE_FP32_ACUM_FOR_FMA - -// Seems to slightly improve the accuracy -#define MMHA_USE_FP32_ACUM_FOR_OUT - -#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT) - // Does not seem to improve the accuracy - //#define MMHA_USE_FP32_ACUM_FOR_LOGITS -#endif - -namespace mmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// -// We use the following terminology to describe the different dimensions. -// -// B: Batch size (number of sequences), -// L: Sequence length, -// D: Hidden dimension, -// H: Number of heads, -// Dh: Hidden dimension per head - Dh = D / H. -// -// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use -// 64, 128 and 256 threads per block. -// -// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to -// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The -// cache buffer helps with memory accesses and contains keys with bias. -// -// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and -// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The -// values for x are chosen to create chunks of 16 bytes. -// -// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs -// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At -// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an -// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32. -// -// After that loop, a parallel softmax is computed across the different Q * K^T values stored in -// shared memory. -// -// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many -// timesteps are computed by loop iteration. As with the keys, the values are read from a cache -// except for the current timestep. The layout of the cache buffer for the values is much simpler -// as it is [B, H, L, Dh]. 
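To make the cache layout described above concrete: for FP16, x = 16 bytes / sizeof(half) = 8, and the key cache is indexed as [B, H, Dh/x, L, x]. The short NumPy check below (hypothetical names, FP16 sizes assumed; not code from this patch) verifies that the flat offset arithmetic used later in the deleted kernel, bhi * L * Dh + co * L * x + t * x + ci, is exactly that layout:

import numpy as np

B, H, L, Dh = 2, 4, 16, 64           # batch, heads, max sequence length, head dim
x = 8                                # elements per 16-byte chunk for FP16
k_cache = np.arange(B * H * L * Dh).reshape(B, H, Dh // x, L, x)

def k_cache_offset(b, h, d, t):
    # mirrors: bhi * memory_max_len * Dh + co * memory_max_len * x + t * x + ci
    bhi = b * H + h
    co, ci = d // x, d % x
    return bhi * L * Dh + co * L * x + t * x + ci

b, h, d, t = 1, 3, 42, 7
assert k_cache.flat[k_cache_offset(b, h, d, t)] == k_cache[b, h, d // x, t, d % x]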
-// - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Qk_vec_ { -}; - -template<> -struct Qk_vec_ { - using Type = float; -}; -template<> -struct Qk_vec_ { - using Type = float2; -}; -template<> -struct Qk_vec_ { - using Type = float4; -}; -template<> -struct Qk_vec_ { - using Type = float4; -}; -template<> -struct Qk_vec_ { - using Type = uint32_t; -}; -template<> -struct Qk_vec_ { - using Type = uint32_t; -}; -template<> -struct Qk_vec_ { - using Type = uint2; -}; -template<> -struct Qk_vec_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct Qk_vec_<__nv_bfloat16, 32> { - using Type = __nv_bfloat162; -}; -template<> -struct Qk_vec_<__nv_bfloat16, 64> { - using Type = __nv_bfloat162; -}; -template<> -struct Qk_vec_<__nv_bfloat16, 128> { - using Type = bf16_4_t; -}; -template<> -struct Qk_vec_<__nv_bfloat16, 256> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct K_vec_ { -}; - -template<> -struct K_vec_ { - using Type = float; -}; -template<> -struct K_vec_ { - using Type = float2; -}; -template<> -struct K_vec_ { - using Type = float4; -}; -template<> -struct K_vec_ { - using Type = uint32_t; -}; -template<> -struct K_vec_ { - using Type = uint2; -}; -template<> -struct K_vec_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct K_vec_<__nv_bfloat16, 4> { - using Type = __nv_bfloat162; -}; -template<> -struct K_vec_<__nv_bfloat16, 2> { - using Type = bf16_4_t; -}; -template<> -struct K_vec_<__nv_bfloat16, 1> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct V_vec_ { -}; - -template<> -struct V_vec_ { - using Type = float; -}; -template<> -struct V_vec_ { - using Type = float2; -}; -template<> -struct V_vec_ { - using Type = float4; -}; -template<> -struct V_vec_ { - using Type = uint32_t; -}; -template<> -struct V_vec_ { - using Type = uint2; -}; -template<> -struct V_vec_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct V_vec_<__nv_bfloat16, 2> { - using Type = __nv_bfloat162; -}; -template<> -struct V_vec_<__nv_bfloat16, 4> { - using Type = bf16_4_t; -}; -template<> -struct V_vec_<__nv_bfloat16, 8> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA -template -struct Qk_vec_acum_fp32_ { -}; - -template<> -struct Qk_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float4; -}; -// template<> struct Qk_vec_acum_fp32_ { using Type = float; }; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float8_; -}; -template<> -struct Qk_vec_acum_fp32_<__nv_bfloat16> { - using Type = float; -}; -template<> -struct Qk_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float8_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct 
K_vec_acum_fp32_ { -}; - -template<> -struct K_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float4; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float8_; -}; -template<> -struct K_vec_acum_fp32_<__nv_bfloat16> { - using Type = float; -}; -template<> -struct K_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float8_; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT -template -struct V_vec_acum_fp32_ { -}; - -template<> -struct V_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float4; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float8_; -}; -#ifdef ENABLE_BF16 -template<> -struct V_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float8_; -}; -#endif // ENABLE_BF16 -#endif -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) -{ -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using K_vec_acum = typename K_vec_acum_fp32_::Type; -#else - using K_vec_acum = K_vec; -#endif - // Compute the parallel products for Q*K^T (treat vector lanes separately). - K_vec_acum qk_vec = mul(q[0], k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } - - // Finalize the reduction across lanes. 
- float qk = sum(qk_vec); -#pragma unroll - for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(uint32_t(-1), qk, mask); - } - return qk; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Qk_dot { - template - static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) - { - return qk_dot_(q, k); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) -{ - float4 c; - float zero = 0.f; - asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" - " {%0, %1, %2, %3}, \n" - " {%4, %5}, \n" - " {%6}, \n" - " {%7, %7, %7, %7}; \n" - - : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) - : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using K_vec_acum = typename K_vec_acum_fp32_::Type; -#else - using K_vec_acum = uint32_t; -#endif - K_vec_acum qk_vec = mul(q[0], k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - uint32_t qk_vec_ = float2_to_half2(qk_vec); - return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; -#else - return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x; -#endif -#else - return 0.f; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Qk_dot { - template - static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) - { -#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION) - return qk_hmma_dot_(q, k); -#else - return qk_dot_<4>(q, k); -#endif // defined MMHA_USE_HMMA_FOR_REDUCTION - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float block_sum(float* red_smem, float sum) -{ - - // Decompose the thread index into warp / lane. - int warp = threadIdx.x / WARP_SIZE; - int lane = threadIdx.x % WARP_SIZE; - -// Compute the sum per warp. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Warp leaders store the data to shared memory. - if (lane == 0) { - red_smem[warp] = sum; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The warps compute the final sums. - if (lane < WARPS_PER_BLOCK) { - sum = red_smem[lane]; - } - -// Parallel reduction inside the warp. -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Broadcast to other threads. 
- return __shfl_sync(uint32_t(-1), sum, 0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float& dst, float src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint16_t& dst, float src) -{ - dst = float_to_half(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint32_t& dst, float2 src) -{ - dst = float2_to_half2(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -#ifdef ENABLE_BF16 -inline __device__ void convert_from_float(__nv_bfloat16& dst, float src) -{ - dst = __float2bfloat16(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst = __float22bfloat162_rn(src); -#else - dst = __floats2bfloat162_rn(src.x, src.y); -#endif -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint2& dst, Float4_ src) -{ - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint2& dst, float4 src) -{ - convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint4& dst, Float8_ src) -{ - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); - dst.z = float2_to_half2(src.z); - dst.w = float2_to_half2(src.w); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = __float22bfloat162_rn(src.y); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(bf16_4_t& dst, float4 src) -{ - convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = __float22bfloat162_rn(src.y); - dst.z = __float22bfloat162_rn(src.z); - dst.w = __float22bfloat162_rn(src.w); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); - dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); - dst.w = __floats2bfloat162_rn(src.w.x, src.w.y); -#endif -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - 
-inline __device__ void convert_from_float(float2& dst, float2 src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float4& dst, float4 src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float convert_to_float(float4 u) -{ - return u.x; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float convert_to_float(uint4 u) -{ - float2 tmp = half2_to_float2(u.x); - return tmp.x; -} - -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float cast_to_float(float u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 cast_to_float(float2 u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 cast_to_float(float4 u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ cast_to_float(Float4_ u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ cast_to_float(Float8_ u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 cast_to_float(uint32_t u) -{ - return half2_to_float2(u); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ cast_to_float(uint2 u) -{ - Float4_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - return tmp; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ cast_to_float(uint4 u) -{ - Float8_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - tmp.z = half2_to_float2(u.z); - tmp.w = half2_to_float2(u.w); - return tmp; -} - -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float float_from_int8(int8_t u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 float_from_int8(int16_t u) -{ - union { - int16_t int16; - int8_t int8[2]; - }; - int16 = u; - return make_float2(int8[0], int8[1]); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 float_from_int8(int32_t u) -{ - union { - int32_t int32; - int8_t int8[4]; - }; - int32 = u; - return make_float4(int8[0], int8[1], int8[2], int8[3]); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// clang-format off -inline __device__ Float8_ float_from_int8(int64_t u) -{ - union { - int64_t int64; - int16_t int16[4]; - }; - int64 = u; - return Float8_ {float_from_int8(int16[0]), - float_from_int8(int16[1]), - float_from_int8(int16[2]), - float_from_int8(int16[3])}; -} -// clang-format on - 
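The float_from_int8 / cast_to_int8 helpers above back the int8_mode == 2 path further down, where quantized Q/K values are unpacked and rescaled with qkv_scale_out; cvt.rni.sat.s8.f32 rounds to nearest-even and saturates to [-128, 127]. A rough host-side NumPy model of that pack/unpack (illustrative names only, not the device code):

import numpy as np

def cast_to_int8(vals):
    rounded = np.rint(np.asarray(vals, dtype=np.float32))  # round to nearest even
    return np.clip(rounded, -128, 127).astype(np.int8)     # saturate like cvt.rni.sat.s8.f32

def float_from_int8(packed):
    # reinterpret the packed word as int8 lanes and widen to float,
    # as the int32/int64 overloads above do
    return packed.view(np.int8).astype(np.float32)

packed = cast_to_int8([13.6, -300.0, 0.49, 127.8]).view(np.int32)  # 4 lanes in one int32
q_dequant = 0.05 * float_from_int8(packed)   # mirrors the int8_mode == 2 dequant
print(q_dequant)                             # roughly [0.7, -6.4, 0.0, 6.35]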
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ int8_t cast_to_int8(float val) -{ - union { - int8_t int8[2]; - int16_t int16; - }; - asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val)); - return int8[0]; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ int32_t cast_to_int8(float4 val) -{ - union { - int8_t int8[4]; - int32_t int32; - }; - int8[0] = cast_to_int8(val.x); - int8[1] = cast_to_int8(val.y); - int8[2] = cast_to_int8(val.z); - int8[3] = cast_to_int8(val.w); - return int32; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ int64_t cast_to_int8(Float8_ val) -{ - union { - int8_t int8[8]; - int64_t int64; - }; - int8[0] = cast_to_int8(val.x.x); - int8[1] = cast_to_int8(val.x.y); - int8[2] = cast_to_int8(val.y.x); - int8[3] = cast_to_int8(val.y.y); - int8[4] = cast_to_int8(val.z.x); - int8[5] = cast_to_int8(val.z.y); - int8[6] = cast_to_int8(val.w.x); - int8[7] = cast_to_int8(val.w.y); - return int64; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ __host__ T div_up(T m, T n) -{ - return (m + n - 1) / n; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline size_t smem_size_in_bytes(const Multihead_attention_params& params, - int threads_per_value, - int threads_per_block) -{ - // The amount of shared memory needed to store the Q*K^T values in float. - const int max_timesteps = min(params.timestep, params.memory_max_len); - size_t qk_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16; - - // The extra memory needed if we are not using floats for the final logits. - size_t logits_sz = 0; -#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS - if (sizeof(T) != 4) { - // TDOD - logits_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 4 * sizeof(T) : - div_up(max_timesteps + 1, 4) * 4 * sizeof(T); - } -#endif - - // The total size needed during softmax. - size_t softmax_sz = qk_sz + logits_sz; - - // The number of partial rows to reduce in the final reduction. - int rows_per_red = threads_per_block / threads_per_value; - // The amount of storage needed to finalize the outputs. - size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2; - - size_t transpose_rotary_size = 0; - if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { - transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(T); - } - - // The max. - return max(max(softmax_sz, red_sz), transpose_rotary_size); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ constexpr uint32_t shfl_mask(int threads) -{ - return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The type of the inputs. Supported types: float and half. - typename T, - // The hidden dimension per head. - int Dh, - int Dh_MAX, - // The number of threads per key. - int THREADS_PER_KEY, - // The number of threads per value. - int THREADS_PER_VALUE, - // The number of threads in a threadblock. 
- int THREADS_PER_BLOCK, - bool DO_CROSS_ATTENTION> -__global__ void masked_multihead_attention_kernel(Multihead_attention_params params) -{ - - // Make sure the hidden dimension per head is a multiple of the number of threads per key. - static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); - // Make sure the hidden dimension per head is a multiple of the number of threads per value. - static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); - - // The size of a warp. - constexpr int WARP_SIZE = 32; - // The number of warps in a threadblock. - constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; - - // Use smem_size_in_bytes (above) to determine the amount of shared memory. - extern __shared__ char smem_[]; - - // The shared memory for the Q*K^T values and partial logits in softmax. - float* qk_smem = reinterpret_cast(smem_); - - // The shared memory for the logits. For FP32, that's the same buffer as qk_smem. - char* logits_smem_ = smem_; -#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS - if (sizeof(T) != 4) { - // TODO - change to tlength - const int max_timesteps = min(params.timestep, params.memory_max_len); - logits_smem_ += - (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16; - } - T* logits_smem = reinterpret_cast(logits_smem_); -#else - float* logits_smem = reinterpret_cast(logits_smem_); -#endif - - // The shared memory to do the final reduction for the output values. Reuse qk_smem. - T* out_smem = reinterpret_cast(smem_); - - // The shared memory buffers for the block-wide reductions. One for max, one for sum. - __shared__ float red_smem[WARPS_PER_BLOCK * 2]; - - // A vector of Q or K elements for the current timestep. - using Qk_vec = typename Qk_vec_::Type; - - // Use alignment for safely casting the shared buffers as Qk_vec. - // Shared memory to store Q inputs. - __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; - - // This is one of the reasons we should have a separate kernel for cross attention - __shared__ __align__(sizeof(Qk_vec)) T bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1]; - - // A vector of Q or K elements for the current timestep. - using Qk_vec = typename Qk_vec_::Type; - // The number of elements per vector. - constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); - // Make sure the hidden size per head is a multiple of the vector size. - static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); - // We will use block wide reduction if needed - // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); - // The number of vectors per warp. - constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; - - // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread - // owns x elements, we have to decompose the linear index into chunks of x values and the posi- - // tion of the thread in that chunk. - - // The number of elements in a chunk of 16B (that's the x in the above formula). - constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); - // The number of K vectors in 16B. - constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec); - - // The batch/beam idx - const int bi = blockIdx.y; - if (params.finished != nullptr && params.finished[bi] == true) { - return; - } - // The beam idx - const int beami = bi % params.beam_width; - // The "beam-aware" batch idx - const int bbi = bi / params.beam_width; - // The head. 
- const int num_kv_heads = params.num_kv_heads; - const int kv_rep = (params.num_heads / num_kv_heads); - const int hi = blockIdx.x; - const int hi_kv = hi / kv_rep; - - // Combine the batch and the head indices. - const int bhi = bi * params.num_heads + hi; - const int bhi_kv = bi * (params.num_heads / kv_rep) + hi_kv; - // Combine the "beam-aware" batch idx and the head indices. - const int bbhi = bbi * params.beam_width * params.num_heads + hi; - const int bbhi_kv = bbi * params.beam_width * (params.num_heads / kv_rep) + hi_kv; - // The thread in the block. - const int tidx = threadIdx.x; - - const bool handle_kv = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0); - // Every kv_rep threads have the same kv_cache values. So only the first one writes back. - const int write_kv_cache = handle_kv && (hi % kv_rep == 0); - - // While doing the product Q*K^T for the different keys we track the max. - float qk_max = -FLT_MAX; - - float qk = 0.0F; - - // int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh; - const int q_base_offset = bi * params.stride + hi * Dh; - const int k_base_offset = bi * params.stride + hi_kv * Dh; - const int v_base_offset = k_base_offset; - - const size_t bi_seq_len_offset = bi * params.memory_max_len; - - // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 : - (params.length_per_sample == nullptr) ? - params.timestep : - params.length_per_sample[bi] + params.max_prefix_prompt_length; - const int first_step = max(0, tlength + 1 - params.memory_max_len); - const int tlength_circ = tlength % params.memory_max_len; - - // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. - const bool is_masked = tidx >= QK_VECS_PER_WARP; - - // The offset in the Q and K buffer also accounts for the batch. - // int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE; - int q_offset = q_base_offset + tidx * QK_VEC_SIZE; - int k_offset = k_base_offset + tidx * QK_VEC_SIZE; - int v_offset = k_offset; - - // The offset in the bias buffer. - // int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; - int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; - int k_bias_offset = hi_kv * Dh + tidx * QK_VEC_SIZE; - int v_bias_offset = k_bias_offset; - - const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr; - const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0; - - // Trigger the loads from the Q and K buffers. - Qk_vec q; - zero(q); - if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto q_scaling = params.qkv_scale_out[0]; - const auto q_quant = - *reinterpret_cast(&reinterpret_cast(params.q)[q_offset]); - - convert_from_float(q, mul(q_scaling, float_from_int8(q_quant))); - } - else { - q = *reinterpret_cast(¶ms.q[q_offset]); - } - } - - Qk_vec k; - zero(k); - if (DO_CROSS_ATTENTION) { - // The 16B chunk written by the thread. - int co = tidx / QK_VECS_IN_16B; - // The position of the thread in that 16B chunk. - int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; - - // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. 
- int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + - // params.timestep*QK_ELTS_IN_16B + - tlength * QK_ELTS_IN_16B + ci; - k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? - *reinterpret_cast(¶ms.k_cache[offset]) : - k; - } - else { - if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto k_scaling = params.qkv_scale_out[1]; - const auto k_quant = - *reinterpret_cast(&reinterpret_cast(params.k)[k_offset]); - - convert_from_float(k, mul(k_scaling, float_from_int8(k_quant))); - } - else { - k = *reinterpret_cast(¶ms.k[k_offset]); - } - } - } - - // Trigger the loads from the Q and K bias buffers. - Qk_vec q_bias; - zero(q_bias); - q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ? - *reinterpret_cast(¶ms.q_bias[q_bias_offset]) : - q_bias; - - Qk_vec k_bias; - zero(k_bias); - if (handle_kv) { - k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ? - *reinterpret_cast(¶ms.k_bias[k_bias_offset]) : - k_bias; - } - - // Computes the Q/K values with bias. - q = add(q, q_bias); - if (handle_kv) { - k = add(k, k_bias); - } - if (do_ia3 && !is_masked) { - k = mul( - k, - *reinterpret_cast( - ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE])); - } - - // Padded len - const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi]; - if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) { - if (handle_kv) { - apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); - } - else { - apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); - } - } - else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { - const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim; - - T* q_smem = reinterpret_cast(smem_); - T* k_smem = q_smem + params.rotary_embedding_dim; - - const int half_rotary_dim = params.rotary_embedding_dim / 2; - const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim; - const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim; - const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts - - assert(half_rotary_dim % QK_VEC_SIZE == 0); - - if (do_rotary) { - *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx) = q; - - if (handle_kv) { - *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx) = k; - } - } - - __syncthreads(); - - const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; - constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? 
QK_VEC_SIZE / 2 : 1; - if (do_rotary) { - mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch); - - if (handle_kv) { - mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch); - - mmha::apply_rotary_embedding( - q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); - - mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch); - } - else { - mmha::apply_rotary_embedding( - q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base); - } - mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch); - } - - __syncthreads(); - - if (do_rotary) { - q = *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx); - if (handle_kv) { - k = *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx); - } - } - - __syncthreads(); - } - - if (!is_masked) { - // Store the Q values to shared memory. - *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; - - // Store Dh values of k_bias into smem, since will need to add later - // if params.timestep == 0 - if (DO_CROSS_ATTENTION && params.timestep == 0) { - *reinterpret_cast(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias; - } - - // Write the K values to the global memory cache. - // - // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory - // system. We designed it this way as it allows much better memory loads (and there are many - // more loads) + the stores are really "write and forget" since we won't need the ack before - // the end of the kernel. There's plenty of time for the transactions to complete. - - // The 16B chunk written by the thread. - int co = tidx / QK_VECS_IN_16B; - // The position of the thread in that 16B chunk. - int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; - - // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. - int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + - // params.timestep*QK_ELTS_IN_16B + - tlength_circ * QK_ELTS_IN_16B + ci; - - if (write_kv_cache) { - // Trigger the stores to global memory. - if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { - *reinterpret_cast(¶ms.k_cache[offset]) = k; - } - } - - // Compute \sum_i Q[i] * K^T[i] for the current timestep. -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using Qk_vec_acum = typename Qk_vec_acum_fp32_::Type; -#else - using Qk_vec_acum = Qk_vec; -#endif - qk = dot(q, k); - if (QK_VECS_PER_WARP <= WARP_SIZE) { -#pragma unroll - for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); - } - } - } - - if (QK_VECS_PER_WARP > WARP_SIZE) { - constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; - qk = block_sum(&red_smem[WARPS_PER_RED], qk); - } - - // Store that value in shared memory. Keep the Q*K^T value in register for softmax. - if (tidx == 0) { - // Normalize qk. - qk *= params.inv_sqrt_dh; - if (params.relative_attention_bias != nullptr) { - // TODO (Haotian): check whether we should replace hi with hi_kv, - // although params.relative_attention_bias is usually not used. - qk = add(qk, - params.relative_attention_bias[hi * params.relative_attention_bias_stride - * params.relative_attention_bias_stride - + (tlength - padd_len) * params.relative_attention_bias_stride - + (tlength - padd_len)]); - } - // Add alibi positional encoding - // qk += (alibi_slope != 0) ? 
alibi_slope * (params.timestep - params.memory_max_len) : 0; - // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0. - - qk_max = qk; - qk_smem[tlength - first_step] = qk; - // qk_smem[params.timestep] = qk; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The type of queries and keys for the math in the Q*K^T product. - using K_vec = typename K_vec_::Type; - // The number of elements per vector. - constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); - // Make sure the hidden size per head is a multiple of the vector size. - static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); - // The number of elements per thread. - constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; - // The number of vectors per thread. - constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; - - // The position the first key loaded by each thread from the cache buffer (for this B * H). - int ko = tidx / THREADS_PER_KEY; - // The position of the thread in the chunk of keys. - int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; - - static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD); - - // Load the Q values from shared memory. The values are reused during the loop on K. - K_vec q_vec[K_VECS_PER_THREAD]; -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - q_vec[ii] = *reinterpret_cast(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); - } - - K_vec k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1]; - if (DO_CROSS_ATTENTION && params.timestep == 0) { -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - k_bias_vec[ii] = *reinterpret_cast(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); - } - } - - // The number of timesteps loaded per iteration. - constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; - // The number of keys per warp. - constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; - - // The base pointer for the key in the cache buffer. - T* k_cache = ¶ms.k_cache[bhi_kv * params.memory_max_len * Dh + ki]; - // Base pointer for the beam's batch, before offsetting with indirection buffer - T* k_cache_batch = ¶ms.k_cache[bbhi_kv * params.memory_max_len * Dh + ki]; - - // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). - // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; - int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; - - // prefix prompt length if has - const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi]; - - // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values. - const bool has_beams = params.cache_indir != nullptr; - const int* beam_indices = has_beams ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr; - - for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) { - const int ti_circ = ti % params.memory_max_len; - - // The keys loaded from the key cache. 
- K_vec k[K_VECS_PER_THREAD]; - K_vec k_vec_zero; - zero(k_vec_zero); -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.memory_max_len + ti_circ; - // if( ti < params.timestep ) { - const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len); - if (ti < tlength) { - if (!within_bounds) { - k[ii] = k_vec_zero; - } - else { - if (has_beams) { - const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; - k[ii] = *reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]); - } - else { - k[ii] = *reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]); - } - } - // add bias and update k_cache - if (DO_CROSS_ATTENTION && params.timestep == 0) { - k[ii] = add(k[ii], k_bias_vec[ii]); - - if (do_ia3) { - k[ii] = mul( - k[ii], - *reinterpret_cast( - ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + ki - + ii * THREADS_PER_KEY * K_VEC_SIZE])); - } - - if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) { - *reinterpret_cast(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii]; - } - } - } - } - - // Perform the dot product and normalize qk. - // - // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! - float qk = Qk_dot::dot(q_vec, k) * params.inv_sqrt_dh; - bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; - - // Store the product to shared memory. There's one qk value per timestep. Update the max. - // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) { - if (ti < tlength && tidx % THREADS_PER_KEY == 0) { - if (params.relative_attention_bias != nullptr) { - qk = add(qk, - params.relative_attention_bias[hi * params.relative_attention_bias_stride - * params.relative_attention_bias_stride - + tlength * params.relative_attention_bias_stride + ti]); - } - if (params.linear_bias_slopes != nullptr) { - // Apply the linear position bias: (ki - qi) * slope[hi]. - // The padding token locates between the input context and the generated tokens. - // We need to remove the number of padding tokens in the distance computation. - // ti : 0 1 2 3 4 5 6 7 8 9(tlength) - // token: i i i i p p p o o o where i=input, p=pad, o=output. - // e.g. ti = 2, dist = (9 - 3) - 2 = 4. - int max_context_length = params.max_prefix_prompt_length + params.max_input_length; - float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength; - - qk += mul(params.linear_bias_slopes[hi], dist); - } - // Add alibi positional encoding - // qk += (alibi_slope != 0) ? alibi_slope * (params.timestep - params.memory_max_len) : 0; - qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); - qk_smem[ti - first_step] = qk; - } - } - -// Perform the final reduction to compute the max inside each warp. -// -// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the -// group so it's not needed to run the reduction inside the group (again). -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - // Decompose the thread index into warp and lane. - const int warp = tidx / WARP_SIZE; - const int lane = tidx % WARP_SIZE; - - // The warp leader writes the max to shared memory. - if (lane == 0) { - red_smem[warp] = qk_max; - } - - // Make sure the products are in shared memory. - __syncthreads(); - - // The warps finalize the reduction. - qk_max = lane < WARPS_PER_BLOCK ? 
red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - // Broadcast to all the threads in the warp. - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - - // Compute the logits and start the sum. - float sum = 0.f; - // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { - for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { - bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; - float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max); - sum += logit; - qk_smem[ti - first_step] = logit; - } - - // Compute the sum. - sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); - - // Normalize the logits. - float inv_sum = __fdividef(1.f, sum + 1.e-6f); - // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { - const size_t cross_attention_out_offset = - params.is_return_cross_attentions ? - bhi_kv * params.max_decoder_seq_len * params.memory_max_len + params.timestep * params.memory_max_len : - 0; - for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { - float logit = qk_smem[ti - first_step] * inv_sum; - if (params.is_return_cross_attentions) { - params.cross_attention_out[cross_attention_out_offset + ti] = logit; - } - convert_from_float(logits_smem[ti - first_step], logit); - } - - // Put Values part below so we leverage __syncthreads - // from the previous step - - // The number of elements per vector. - constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; - // A vector of V elements for the current timestep. - using V_vec = typename V_vec_::Type; - - // The value computed by this thread. - int vo = tidx / THREADS_PER_VALUE; - // The hidden dimensions computed by this particular thread. - int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; - - // The base pointer for the value in the cache buffer. - T* v_cache = ¶ms.v_cache[bhi_kv * params.memory_max_len * Dh + vi]; - // Base pointer for the beam's batch, before offsetting with indirection buffer - T* v_cache_batch = ¶ms.v_cache[bbhi_kv * params.memory_max_len * Dh + vi]; - - // The number of values processed per iteration of the loop. - constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; - - // One group of threads computes the product(s) for the current timestep. - V_vec v_bias; - zero(v_bias); - // if( vo == params.timestep % V_PER_ITER ) { - if (Dh == Dh_MAX || vi < Dh) { - if (handle_kv) { - if (vo == tlength % V_PER_ITER) { - // Trigger the loads from the V bias buffer. - if (params.v_bias != nullptr) { - v_bias = *reinterpret_cast(¶ms.v_bias[hi_kv * Dh + vi]); - } - if (DO_CROSS_ATTENTION) { - *reinterpret_cast(&bias_smem[vi]) = v_bias; - } - } - } - } - - // From previous, before values, step - // Also make sure the logits are in shared memory. - __syncthreads(); - - // Values continued -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - using V_vec_acum = typename V_vec_acum_fp32_::Type; -#else - using V_vec_acum = V_vec; -#endif - // The partial outputs computed by each thread. - V_vec_acum out; - zero(out); - - // Loop over the timesteps to compute the partial outputs. 
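// Illustration: a minimal host-side sketch (not the kernel's own code) of the softmax
// normalization carried out above on qk_smem. The running maximum is subtracted before
// exponentiating so expf() cannot overflow, and the sum is guarded with the same small
// epsilon the kernel uses before taking the reciprocal.
#include <cmath>  // expf

static inline void softmax_in_place(float* scores, int n)  // n >= 1 assumed
{
    float max_val = scores[0];
    for (int i = 1; i < n; ++i) {
        if (scores[i] > max_val) max_val = scores[i];  // mirrors the qk_max block reduction
    }
    float sum = 0.f;
    for (int i = 0; i < n; ++i) {
        scores[i] = expf(scores[i] - max_val);         // shifted exponential, as in qk_smem
        sum += scores[i];
    }
    const float inv_sum = 1.f / (sum + 1.e-6f);        // same epsilon as the kernel
    for (int i = 0; i < n; ++i) {
        scores[i] *= inv_sum;                          // probabilities used to weight V below
    }
}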
- // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) { - if (Dh == Dh_MAX || vi < Dh) { - for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { - const int ti_circ = ti % params.memory_max_len; - - // Fetch offset based on cache_indir when beam sampling - const int beam_src = (params.cache_indir != nullptr) ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0; - const int beam_offset = beam_src * params.num_heads * params.memory_max_len * Dh; - // Load the values from the cache. - V_vec v = *reinterpret_cast(&v_cache_batch[beam_offset + ti_circ * Dh]); - if (DO_CROSS_ATTENTION && params.timestep == 0) { - v = add(v, *reinterpret_cast(&bias_smem[vi])); - if (do_ia3) { - v = mul( - v, - *reinterpret_cast( - ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); - } - *reinterpret_cast(&v_cache[ti * Dh]) = v; - } - // Load the logits from shared memory. -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - float logit = logits_smem[ti - first_step]; - out = fma(logit, cast_to_float(v), out); -#else - T logit = logits_smem[ti - first_step]; - - // Update the partial sums. - out = fma(logit, v, out); -#endif - } - } - - // One group of threads computes the product(s) for the current timestep. - // if( vo == params.timestep % V_PER_ITER ) { - if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) { - - V_vec v; - if (DO_CROSS_ATTENTION) { - v = *reinterpret_cast(&v_cache[tlength * Dh]); - } - else { - // Trigger the loads from the V buffer. - const auto v_offset = v_base_offset + vi; - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto v_scaling = params.qkv_scale_out[2]; - const auto v_quant = - *reinterpret_cast(&reinterpret_cast(params.v)[v_offset]); - - convert_from_float(v, mul(v_scaling, float_from_int8(v_quant))); - } - else { - v = *reinterpret_cast(¶ms.v[v_offset]); - } - // Trigger the loads from the V bias buffer. - // V_vec v_bias = *reinterpret_cast(¶ms.v_bias[hi*Dh + vi]); - } - - // Compute the V values with bias. - v = add(v, v_bias); - if (write_kv_cache) { - - if (do_ia3) { - v = mul( - v, - *reinterpret_cast( - ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); - } - - // Store the values with bias back to global memory in the cache for V. - //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; - *reinterpret_cast(&v_cache[tlength_circ * Dh]) = v; - } - - // Initialize the output value with the current timestep. -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - // out = fma(logits_smem[params.timestep], cast_to_float(v), out); - out = fma(logits_smem[tlength - first_step], cast_to_float(v), out); -#else - // out = fma(logits_smem[params.timestep], v, out); - out = fma(logits_smem[tlength - first_step], v, out); -#endif - } - - // Make sure we can start writing to shared memory. - __syncthreads(); - - // Run the final reduction amongst the different groups computing different partial outputs. - if (Dh == Dh_MAX || vi < Dh) { -#pragma unroll - for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) { - - // The midpoint in the number of active groups. - int midpoint = active_groups / 2; - - // The upper part of active threads store to shared memory. 
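// Illustration: a standalone model (not the kernel's own code) of the halving reduction
// performed by this loop. Each "group" of threads owns one partial output vector; on every
// pass the upper half of the active groups publishes its partials (to out_smem in the kernel,
// to the plain array here) and the lower half accumulates them, until group 0 holds the
// complete sum. One float stands in for one V_vec; num_groups is assumed to be a power of
// two, as V_PER_ITER is.
static inline float reduce_partial_outputs(float* partials, int num_groups)
{
    for (int active = num_groups; active >= 2; active /= 2) {
        const int midpoint = active / 2;
        for (int vo = midpoint; vo < active; ++vo) {
            partials[vo - midpoint] += partials[vo];   // lower half absorbs the upper half
        }
    }
    return partials[0];                                // group 0 ends up with the final output
}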
- if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - convert_from_float(*reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), out); -#else - *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out; -#endif - } - __syncthreads(); - - // The bottom warps update their values. - if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { - out = add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); - } - __syncthreads(); - } - } - - // Output the final values. - if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - out = mul(*params.attention_out_scale, out); - *reinterpret_cast(&(reinterpret_cast(params.out)[bhi * Dh + vi])) = - cast_to_int8(out); - } - else { - convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), out); - } -#else - // TODO: support int8_mode? - *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = out; -#endif - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace mmha - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream); diff --git a/server/awq_kernels/awq_cuda/attention/decoder_masked_multihead_attention_utils.h b/server/awq_kernels/awq_cuda/attention/decoder_masked_multihead_attention_utils.h deleted file mode 100644 index fe462574c..000000000 --- a/server/awq_kernels/awq_cuda/attention/decoder_masked_multihead_attention_utils.h +++ /dev/null @@ -1,1786 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include "cuda_bf16_wrapper.h" -#include "cuda_bf16_fallbacks.cuh" -#include - -using namespace fastertransformer; - -namespace mmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Float8_ { - float2 x; - float2 y; - float2 z; - float2 w; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Float4_ { - float2 x; - float2 y; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -struct bf16_4_t { - __nv_bfloat162 x; - __nv_bfloat162 y; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct bf16_8_t { - __nv_bfloat162 x; - __nv_bfloat162 y; - __nv_bfloat162 z; - __nv_bfloat162 w; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct num_elems; -template<> -struct num_elems { - static constexpr int value = 1; -}; -template<> -struct num_elems { - static constexpr int value = 2; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 8; -}; - -template<> -struct num_elems { - static constexpr int value = 2; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 8; -}; - -#ifdef ENABLE_BF16 -template<> -struct num_elems<__nv_bfloat162> { - static constexpr int value = 2; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 8; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct packed_type; -template -struct packed_type { - using type = T; -}; -template<> -struct packed_type { - using type = int16_t; -}; -template<> -struct packed_type { - using type = int32_t; -}; -template<> -struct packed_type { - using type = int64_t; -}; - -template<> -struct packed_type { - using type = float2; -}; -template<> -struct packed_type { - using type = float4; -}; -template<> -struct packed_type { - using type = Float8_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float add(float a, float b) -{ - return a + b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 add(float2 a, float2 b) -{ - float2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 add(float4 a, float4 b) -{ - float4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) -{ - return a + b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hadd2(a, b); -} - 
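// Illustration: the wide accumulator types declared above (Float4_, Float8_, bf16_4_t,
// bf16_8_t) are plain structs of two-element members, and the arithmetic helpers in this
// header all follow one pattern: apply the two-wide operation member by member. A minimal
// self-contained sketch of that pattern, with a locally defined Float8-style struct so it
// compiles on its own (names are local to the sketch):
#include <cuda_runtime.h>   // float2, make_float2

struct MyFloat8 {            // stand-in for Float8_: eight floats held as four float2 members
    float2 x, y, z, w;
};

__host__ __device__ inline float2 add2(float2 a, float2 b)
{
    return make_float2(a.x + b.x, a.y + b.y);
}

__host__ __device__ inline MyFloat8 add8(MyFloat8 a, MyFloat8 b)
{
    MyFloat8 c;
    c.x = add2(a.x, b.x);    // each member falls back to the two-wide case, exactly as the
    c.y = add2(a.y, b.y);    // bf16_4_t / bf16_8_t / uint2 / uint4 add() overloads do
    c.z = add2(a.z, b.z);
    c.w = add2(a.w, b.w);
    return c;
}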
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) -{ - bf16_4_t c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) -{ - bf16_8_t c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint16_t add(uint16_t a, uint16_t b) -{ - uint16_t c; - asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t add(uint32_t a, uint32_t b) -{ - uint32_t c; - asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 add(uint2 a, uint2 b) -{ - uint2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 add(uint4 a, uint4 b) -{ - uint4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint16_t float_to_half(float f) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; -#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better? 
- float zero = 0.f; - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f)); -#else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); -#endif - return tmp.u16[0]; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t float2_to_half2(float2 f) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); -#else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); -#endif - return tmp.u32; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float half_to_float(uint16_t h) -{ - float f; - asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); - return f; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 half2_to_float2(uint32_t v) -{ - uint16_t lo, hi; - asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); - return make_float2(half_to_float(lo), half_to_float(hi)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float add(float a, uint16_t b) -{ - return a + half_to_float(b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ float add(float a, __nv_bfloat16 b) -{ - return a + __bfloat162float(b); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 add(uint32_t a, float2 fb) -{ - float2 fa = half2_to_float2(a); - return add(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ add(uint2 a, Float4_ fb) -{ - Float4_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ add(uint4 a, Float8_ fb) -{ - Float8_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - fc.z = add(a.z, fb.z); - fc.w = add(a.w, fb.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t h0_h0(uint16_t a) -{ - uint32_t b; - asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); - return b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(float a, float b, float c) -{ - return a * b + c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(float2 a, float2 b, float2 c) -{ - float2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(float a, float2 b, float2 c) -{ - float2 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ 
float4 fma(float4 a, float4 b, float4 c) -{ - float4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 fma(float a, float4 b, float4 c) -{ - float4 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) -{ - Float4_ d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) -{ - Float8_ d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ float2 add(__nv_bfloat162 a, float2 fb) -{ - float2 fa = bf1622float2(a); - return add(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) -{ - Float4_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) -{ - Float8_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - fc.z = add(a.z, fb.z); - fc.w = add(a.w, fb.w); - return fc; -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) -{ - uint32_t d; - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) -{ - return fma(h0_h0(a), b, c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) -{ - uint2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) -{ - uint32_t s = h0_h0(a); - uint2 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) -{ - uint4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) -{ - uint32_t s = h0_h0(a); - uint4 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - d.z = fma(s, b.z, c.z); - d.w = fma(s, b.w, c.w); - return d; -} - 
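// Illustration: the fma() overloads above keep fp16 data packed two to a 32-bit register and
// issue fma.rn.f16x2 through inline PTX; h0_h0() broadcasts a scalar half into both lanes so
// one scalar can be applied across a packed vector. The sketch below expresses the same idea
// with the cuda_fp16 intrinsics instead of PTX; it is an illustrative equivalent (assuming a
// device of compute capability 5.3 or newer), not the code used in this header.
#include <cuda_fp16.h>

struct Half8 {                            // eight halves as four packed __half2, mirroring uint4 above
    __half2 x, y, z, w;
};

__device__ inline Half8 scalar_fma(__half a, Half8 b, Half8 c)
{
    const __half2 s = __half2half2(a);    // same role as h0_h0(): broadcast `a` into both lanes
    Half8 d;
    d.x = __hfma2(s, b.x, c.x);           // one packed fma per pair, like fma.rn.f16x2
    d.y = __hfma2(s, b.y, c.y);
    d.z = __hfma2(s, b.z, c.z);
    d.w = __hfma2(s, b.w, c.w);
    return d;
}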
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(uint16_t a, uint16_t b, float fc) -{ - float fa = half_to_float(a); - float fb = half_to_float(b); - return fa * fb + fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) -{ - float2 fa = half2_to_float2(a); - float2 fb = half2_to_float2(b); - return fma(fa, fb, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) -{ - return fma(h0_h0(a), b, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) -{ - Float4_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) -{ - uint32_t s = h0_h0(a); - Float4_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) -{ - Float8_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - fd.z = fma(a.z, b.z, fc.z); - fd.w = fma(a.w, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) -{ - uint32_t s = h0_h0(a); - Float8_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - fd.z = fma(s, b.z, fc.z); - fd.w = fma(s, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -#ifdef ENABLE_BF16 -inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hfma2(a, b, c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hfma2(bf162bf162(a), b, c); -} -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) -{ - bf16_4_t d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_4_t d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) -{ - bf16_8_t d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_8_t d; - d.x = 
fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - d.z = fma(s, b.z, c.z); - d.w = fma(s, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) -{ - return __bfloat162float(a) * __bfloat162float(b) + fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) -{ - float2 fa = bf1622float2(a); - float2 fb = bf1622float2(b); - return fma(fa, fb, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) -{ - return fma(bf162bf162(a), b, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) -{ - Float4_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) -{ - __nv_bfloat162 s = bf162bf162(a); - Float4_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) -{ - Float8_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - fd.z = fma(a.z, b.z, fc.z); - fd.w = fma(a.w, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) -{ - __nv_bfloat162 s = bf162bf162(a); - Float8_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - fd.z = fma(s, b.z, fc.z); - fd.w = fma(s, b.w, fc.w); - return fd; -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ Acc mul(A a, B b) -{ - return a * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(float a, float b) -{ - return a * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(float2 a, float2 b) -{ - float2 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(float a, float2 b) -{ - float2 c; - c.x = a * b.x; - c.y = a * b.y; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float4 mul(float4 a, float4 b) -{ - float4 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - c.z = a.z * b.z; - c.w = a.w * b.w; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float4 mul(float a, float4 b) -{ - float4 c; - c.x = a * b.x; - c.y = a * b.y; - c.z = a * b.z; - c.w = a * b.w; - return c; -} - 
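// Illustration: the mul() overloads in this header follow an accumulator-selection pattern —
// the result type is an explicit template parameter, so the call site decides whether a
// product stays in the input precision or is widened. A minimal standalone sketch of that
// pattern (names are local to the sketch, not taken from the header):
#include <cuda_runtime.h>   // float2, make_float2

template<typename Acc, typename A, typename B>
__host__ __device__ inline Acc my_mul(A a, B b)
{
    return a * b;                              // generic fallback for scalar operands
}

template<>
__host__ __device__ inline float2 my_mul<float2, float2, float2>(float2 a, float2 b)
{
    return make_float2(a.x * b.x, a.y * b.y);  // elementwise, keeping the packed type
}

template<>
__host__ __device__ inline float2 my_mul<float2, float, float2>(float a, float2 b)
{
    return make_float2(a * b.x, a * b.y);      // scalar broadcast over the packed type
}

// The same mechanism is what lets packed fp16/bf16 operands be multiplied directly into
// float2 / Float4_ / Float8_ accumulators further down: the caller simply names the wider
// accumulator type as the first template argument.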
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(float a, Float8_ b) -{ - Float8_ c; - c.x = make_float2(a * b.x.x, a * b.x.y); - c.y = make_float2(a * b.y.x, a * b.y.y); - c.z = make_float2(a * b.z.x, a * b.z.y); - c.w = make_float2(a * b.w.x, a * b.w.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint16_t mul(uint16_t a, uint16_t b) -{ - uint16_t c; - asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint32_t mul(uint32_t a, uint32_t b) -{ - uint32_t c; - asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint32_t mul(uint16_t a, uint32_t b) -{ - return mul(h0_h0(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint2 mul(uint2 a, uint2 b) -{ - uint2 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint2 mul(uint16_t a, uint2 b) -{ - uint32_t s = h0_h0(a); - uint2 c; - c.x = mul(s, b.x); - c.y = mul(s, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint4 mul(uint4 a, uint4 b) -{ - uint4 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - c.z = mul(a.z, b.z); - c.w = mul(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint4 mul(uint16_t a, uint4 b) -{ - uint32_t s = h0_h0(a); - uint4 c; - c.x = mul(s, b.x); - c.y = mul(s, b.y); - c.z = mul(s, b.z); - c.w = mul(s, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(uint16_t a, uint16_t b) -{ - float fa = half_to_float(a); - float fb = half_to_float(b); - return fa * fb; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(uint16_t a, float b) -{ - return half_to_float(a) * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(uint32_t a, uint32_t b) -{ - float2 fa = half2_to_float2(a); - float2 fb = half2_to_float2(b); - return mul(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(uint16_t a, uint32_t b) -{ - return mul(h0_h0(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(uint2 a, uint2 b) -{ - Float4_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ 
mul(uint16_t a, uint2 b) -{ - uint32_t s = h0_h0(a); - Float4_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(uint4 a, uint4 b) -{ - Float8_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - fc.z = mul(a.z, b.z); - fc.w = mul(a.w, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(uint16_t a, uint4 b) -{ - uint32_t s = h0_h0(a); - Float8_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - fc.z = mul(s, b.z); - fc.w = mul(s, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -template<> -inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - return __hmul(a, b); -#else - return bf16hmul(a, b); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hmul2(a, b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) -{ - return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) -{ - bf16_4_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_4_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) -{ - bf16_8_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); - c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z); - c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_8_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); - c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); - c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) -{ - float fa = (float)a; - float fb = (float)b; - return 
fa * fb; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(__nv_bfloat16 a, float b) -{ - return __bfloat162float(a) * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) -{ - float2 fa = bf1622float2(a); - float2 fb = bf1622float2(b); - return mul(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) -{ - return mul(bf162bf162(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) -{ - Float4_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - Float4_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) -{ - Float8_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - fc.z = mul(a.z, b.z); - fc.w = mul(a.w, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - Float8_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - fc.z = mul(s, b.z); - fc.w = mul(s, b.w); - return fc; -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float v) -{ - return v; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float2 v) -{ - return v.x + v.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float4 v) -{ - return v.x + v.y + v.z + v.w; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ float sum(__nv_bfloat162 v) -{ - float2 vf = bf1622float2(v); - return vf.x + vf.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(bf16_4_t v) -{ - return sum(v.x) + sum(v.y); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(bf16_8_t v) -{ - return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint16_t v) -{ - return half_to_float(v); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint32_t v) -{ - float2 tmp = half2_to_float2(v); - return tmp.x + tmp.y; -} - 
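// Illustration: every sum() overload above collapses one packed vector into a single fp32
// value, and the dot() helpers defined a little further down are simply sum(mul(a, b)) — so
// dot products computed through these helpers accumulate in fp32 regardless of the storage
// type. A minimal self-contained sketch of that composition (illustrative names, float4 only):
#include <cuda_runtime.h>   // float4, make_float4

__host__ __device__ inline float4 my_mul4(float4 a, float4 b)
{
    return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);  // elementwise product
}

__host__ __device__ inline float my_sum4(float4 v)
{
    return v.x + v.y + v.z + v.w;              // reduce the packed vector to one float
}

__host__ __device__ inline float my_dot4(float4 a, float4 b)
{
    return my_sum4(my_mul4(a, b));             // dot(a, b) == sum(mul(a, b)), as in this header
}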
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint2 v) -{ - uint32_t c = add(v.x, v.y); - return sum(c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint4 v) -{ -#if 1 - uint32_t c = add(v.x, v.y); - c = add(c, v.z); - c = add(c, v.w); -#else - uint32_t c = add(v.x, v.y); - uint32_t d = add(v.z, v.w); - c = add(c, d); -#endif - return sum(c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(Float4_ v) -{ - return v.x.x + v.x.y + v.y.x + v.y.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(Float8_ v) -{ - return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float dot(T a, T b) -{ - return sum(mul(a, b)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float dot(T a, T b) -{ - return sum(mul(a, b)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void zero(uint16_t& dst) -{ - dst = uint16_t(0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void zero(T& dst) -{ - constexpr int WORDS = sizeof(T) / 4; - union { - T raw; - uint32_t words[WORDS]; - } tmp; -#pragma unroll - for (int ii = 0; ii < WORDS; ++ii) { - tmp.words[ii] = 0u; - } - dst = tmp.raw; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step, const float base) -{ - const float inv_freq = t_step / pow(base, zid / (float)rot_embed_dim); - return {cos(inv_freq), sin(inv_freq)}; -} - -inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef) -{ - float2 rot_v; - rot_v.x = coef.x * v.x - coef.y * v.y; - rot_v.y = coef.x * v.y + coef.y * v.x; - return rot_v; -} - -inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef) -{ - float2 fv = half2_to_float2(v); - float2 rot_fv = rotary_embedding_transform(fv, coef); - return float2_to_half2(rot_fv); -} - -#ifdef ENABLE_BF16 -inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef) -{ - float2 fv = bf1622float2(v); - float2 rot_fv = rotary_embedding_transform(fv, coef); - return __floats2bfloat162_rn(rot_fv.x, rot_fv.y); -} -#endif - -inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - return; -} - -inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - return; -} - -inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void 
apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - - Float4_& q_ = *reinterpret_cast(&q); - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q_.x = rotary_embedding_transform(q_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q_.y = rotary_embedding_transform(q_.y, coef1); -} - -inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - - Float4_& q_ = *reinterpret_cast(&q); - Float4_& k_ = *reinterpret_cast(&k); - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q_.x = rotary_embedding_transform(q_.x, coef0); - k_.x = rotary_embedding_transform(k_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q_.y = rotary_embedding_transform(q_.y, coef1); - k_.y = rotary_embedding_transform(k_.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = 
rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); - q.z = rotary_embedding_transform(q.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); - q.w = rotary_embedding_transform(q.w, coef3); -} - -inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); - q.z = rotary_embedding_transform(q.z, coef2); - k.z = rotary_embedding_transform(k.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); - q.w = rotary_embedding_transform(q.w, coef3); - k.w = rotary_embedding_transform(k.w, coef3); -} - -#ifdef ENABLE_BF16 -inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void -apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); -} - -inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); -} - -inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, 
coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); - q.z = rotary_embedding_transform(q.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); - q.w = rotary_embedding_transform(q.w, coef3); -} - -inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); - q.z = rotary_embedding_transform(q.z, coef2); - k.z = rotary_embedding_transform(k.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); - q.w = rotary_embedding_transform(q.w, coef3); - k.w = rotary_embedding_transform(k.w, coef3); -} -#endif // ENABLE_BF16 - -template -__device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch); - -template<> -__device__ __inline__ void vec_from_smem_transpose(float& vec, float* smem, int transpose_idx, int smem_pitch) -{ - return; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; - tmp.u16[0] = smem[transpose_idx]; - tmp.u16[1] = smem[smem_pitch + transpose_idx]; - - vec = tmp.u32; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp_1, tmp_2; - tmp_1.u32 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u32 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - union { - uint2 u32x2; - uint16_t u16[4]; - } tmp_3; - tmp_3.u16[0] = tmp_1.u16[0]; - tmp_3.u16[1] = tmp_2.u16[0]; - tmp_3.u16[2] = tmp_1.u16[1]; - tmp_3.u16[3] = tmp_2.u16[1]; - - vec = tmp_3.u32x2; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint64_t u64; - uint16_t u16[4]; - } tmp_1, tmp_2; - tmp_1.u64 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u64 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - union { - uint4 u32x4; - uint16_t u16[8]; - } tmp_3; - tmp_3.u16[0] = tmp_1.u16[0]; - tmp_3.u16[1] = tmp_2.u16[0]; - tmp_3.u16[2] = tmp_1.u16[1]; - tmp_3.u16[3] = tmp_2.u16[1]; - tmp_3.u16[4] = tmp_1.u16[2]; - tmp_3.u16[5] = tmp_2.u16[2]; - tmp_3.u16[6] = tmp_1.u16[3]; - tmp_3.u16[7] = tmp_2.u16[3]; - - vec = tmp_3.u32x4; -} - -#ifdef ENABLE_BF16 -template<> -__device__ __inline__ void -vec_from_smem_transpose(bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - __nv_bfloat16 bf16[2]; - } tmp_1, tmp_2; - tmp_1.u32 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u32 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]}; - vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]}; -} - -template<> -__device__ __inline__ void -vec_from_smem_transpose(bf16_8_t& vec, 
__nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - union { - uint64_t u64; - __nv_bfloat16 bf16[4]; - } tmp_1, tmp_2; - tmp_1.u64 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u64 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]}; - vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]}; - vec.z = __nv_bfloat162{tmp_1.bf16[2], tmp_2.bf16[2]}; - vec.w = __nv_bfloat162{tmp_1.bf16[3], tmp_2.bf16[3]}; -} -#endif // ENABLE_BF16 - -template<> -__device__ __inline__ void vec_from_smem_transpose(float4& vec, float* smem, int transpose_idx, int smem_pitch) -{ - vec.x = smem[transpose_idx]; - vec.z = smem[transpose_idx + 1]; - vec.y = smem[smem_pitch + transpose_idx]; - vec.w = smem[smem_pitch + transpose_idx + 1]; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, half* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - half u16[2]; - } tmp; - tmp.u16[0] = smem[transpose_idx]; - tmp.u16[1] = smem[smem_pitch + transpose_idx]; - - vec = tmp.u32; -} - -#ifdef ENABLE_BF16 -template<> -__device__ __inline__ void -vec_from_smem_transpose(__nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - vec.x = smem[transpose_idx]; - vec.y = smem[smem_pitch + transpose_idx]; -} -#endif - -template<> -__device__ __inline__ void vec_from_smem_transpose(float2& vec, float* smem, int transpose_idx, int smem_pitch) -{ - vec.x = smem[transpose_idx]; - vec.y = smem[smem_pitch + transpose_idx]; -} - -template -__device__ __inline__ void write_smem_transpose(const Vec_T& vec, T* smem, int transpose_idx, int smem_pitch); - -template<> -__device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch) -{ - return; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint64_t u64; - uint16_t u16[4]; - } tmp_1, tmp_2; - - union { - uint4 u32x4; - uint16_t u16[8]; - } tmp_3; - tmp_3.u32x4 = vec; - tmp_1.u16[0] = tmp_3.u16[0]; - tmp_2.u16[0] = tmp_3.u16[1]; - tmp_1.u16[1] = tmp_3.u16[2]; - tmp_2.u16[1] = tmp_3.u16[3]; - tmp_1.u16[2] = tmp_3.u16[4]; - tmp_2.u16[2] = tmp_3.u16[5]; - tmp_1.u16[3] = tmp_3.u16[6]; - tmp_2.u16[3] = tmp_3.u16[7]; - - *reinterpret_cast(&smem[transpose_idx]) = tmp_1.u64; - *reinterpret_cast(&smem[smem_pitch + transpose_idx]) = tmp_2.u64; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp_1, tmp_2; - - union { - uint2 u32x2; - uint16_t u16[4]; - } tmp_3; - tmp_3.u32x2 = vec; - tmp_1.u16[0] = tmp_3.u16[0]; - tmp_2.u16[0] = tmp_3.u16[1]; - tmp_1.u16[1] = tmp_3.u16[2]; - tmp_2.u16[1] = tmp_3.u16[3]; - - *reinterpret_cast(&smem[transpose_idx]) = tmp_1.u32; - *reinterpret_cast(&smem[smem_pitch + transpose_idx]) = tmp_2.u32; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; - tmp.u32 = vec; - - smem[transpose_idx] = tmp.u16[0]; - smem[smem_pitch + transpose_idx] = tmp.u16[1]; -} - -template<> -__device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, int transpose_idx, int smem_pitch) -{ - smem[transpose_idx] = vec.x; - smem[transpose_idx + 1] = vec.z; - smem[smem_pitch + transpose_idx] = vec.y; - smem[smem_pitch + 
transpose_idx + 1] = vec.w; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - half u16[2]; - } tmp; - - tmp.u32 = vec; - smem[transpose_idx] = tmp.u16[0]; - smem[smem_pitch + transpose_idx] = tmp.u16[1]; -} - -#ifdef ENABLE_BF16 -template<> -__device__ __inline__ void -write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - smem[transpose_idx] = vec.x; - smem[smem_pitch + transpose_idx] = vec.y; -} - -template<> -__device__ __inline__ void -write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - write_smem_transpose(reinterpret_cast(vec), reinterpret_cast(smem), transpose_idx, smem_pitch); -} - -template<> -__device__ __inline__ void -write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - write_smem_transpose(reinterpret_cast(vec), reinterpret_cast(smem), transpose_idx, smem_pitch); -} -#endif - -template<> -__device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, int transpose_idx, int smem_pitch) -{ - smem[transpose_idx] = vec.x; - smem[smem_pitch + transpose_idx] = vec.y; -} - -} // namespace mmha diff --git a/server/awq_kernels/awq_cuda/attention/ft_attention.cpp b/server/awq_kernels/awq_cuda/attention/ft_attention.cpp deleted file mode 100644 index bef83ec1c..000000000 --- a/server/awq_kernels/awq_cuda/attention/ft_attention.cpp +++ /dev/null @@ -1,182 +0,0 @@ -// Adapted from NVIDIA/FasterTransformer and FlashAttention - -#include -#include "ATen/cuda/CUDAContext.h" -#include - -#include "ft_attention.h" -#include "decoder_masked_multihead_attention.h" - -#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA") -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - -#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...) 
\ - if (TYPE == at::ScalarType::Half) { \ - using scalar_t = at::Half; \ - __VA_ARGS__(); \ - } else if (TYPE == at::ScalarType::BFloat16) { \ - using scalar_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (TYPE == at::ScalarType::Float) { \ - using scalar_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \ - } - -template -void masked_multihead_attention(const Masked_multihead_attention_params& params, - const cudaStream_t& stream); - -template -void cross_multihead_attention(const Masked_multihead_attention_params& params, - const cudaStream_t& stream); - -template -struct SATypeConverter { - using Type = T; -}; - -template<> -struct SATypeConverter { - using Type = uint16_t; -}; - -template<> -struct SATypeConverter { - using Type = __nv_bfloat16; -}; - -template -void set_params(Masked_multihead_attention_params ¶ms, - const size_t batch_size, - const size_t nheads, - const size_t nheads_kv, - const size_t memory_max_seqlen, - const size_t headdim, - const int timestep, - const int rotary_embedding_dim, - const float rotary_base, - const bool neox_rotary_style, - const int qkv_batch_stride, - T *q_ptr, - T *k_ptr, - T *v_ptr, - T *k_cache_ptr, - T *v_cache_ptr, - int *length_per_sample, - float *alibi_slopes_ptr, - T *out_ptr) { - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - params.q = q_ptr; - params.k = k_ptr; - params.v = v_ptr; - params.q_bias = nullptr; - params.k_bias = nullptr; - params.v_bias = nullptr; - params.k_cache = k_cache_ptr; - params.v_cache = v_cache_ptr; - params.linear_bias_slopes = alibi_slopes_ptr; - params.out = out_ptr; - params.cache_indir = nullptr; - params.stride = qkv_batch_stride; - params.batch_size = batch_size; - params.beam_width = 1; - params.memory_max_len = memory_max_seqlen; - params.num_heads = nheads; - params.num_kv_heads = nheads_kv; - params.hidden_size_per_head = headdim; - params.rotary_embedding_dim = rotary_embedding_dim; - params.rotary_base = rotary_base; - params.neox_rotary_style = neox_rotary_style; - params.timestep = timestep; - params.inv_sqrt_dh = 1.f / sqrt(float(headdim)); - params.total_padding_tokens = nullptr; - params.masked_tokens = nullptr; - params.prefix_prompt_lengths = nullptr; - params.max_prefix_prompt_length = 0; - params.relative_attention_bias = nullptr; - params.relative_attention_bias_stride = 0; - params.cross_attention_out = nullptr; - params.max_decoder_seq_len = 0; - params.is_return_cross_attentions = false; - params.finished = nullptr; - params.memory_length_per_sample = nullptr; - params.length_per_sample = length_per_sample; -} - -torch::Tensor single_query_attention(const torch::Tensor q, - const torch::Tensor k, - const torch::Tensor v, - torch::Tensor k_cache, - torch::Tensor v_cache, - c10::optional length_per_sample_, - c10::optional alibi_slopes_, - const int timestep, - const int rotary_embedding_dim, - const float rotary_base, - // neox_rotary_style = not interleaved - const bool neox_rotary_style) { - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache); - int batch_size = v_cache.size(0); - int nheads = q.size(1); - int nheads_kv = v_cache.size(1); - int memory_max_seqlen = v_cache.size(2); - int headdim = v_cache.size(3); - CHECK_SHAPE(q, batch_size, nheads, headdim); - CHECK_SHAPE(k, batch_size, nheads_kv, headdim); - CHECK_SHAPE(v, batch_size, nheads_kv, headdim); - CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim); - // k_cache shape: [B, H, 
Dh/x, L, x] where x=8 for fp16 and x=4 for fp32 - int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8; - CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize); - TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim); - TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim); - TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim); - // TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0)); - CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache); - - if (length_per_sample_.has_value()) { - auto length_per_sample = length_per_sample_.value(); - CHECK_DEVICE(length_per_sample); - CHECK_SHAPE(length_per_sample, batch_size); - CHECK_CONTIGUOUS(length_per_sample); - TORCH_CHECK(length_per_sample.dtype() == torch::kInt32); - } - - if (alibi_slopes_.has_value()) { - auto alibi_slopes = alibi_slopes_.value(); - CHECK_DEVICE(alibi_slopes); - CHECK_SHAPE(alibi_slopes, nheads); - CHECK_CONTIGUOUS(alibi_slopes); - TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32); - } - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)q.get_device()}; - - torch::Tensor out = torch::empty_like(q); - - DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] { - using DataType = typename SATypeConverter::Type; - Masked_multihead_attention_params params; - set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim, - timestep, rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0), - reinterpret_cast(q.data_ptr()), - reinterpret_cast(k.data_ptr()), - reinterpret_cast(v.data_ptr()), - reinterpret_cast(k_cache.data_ptr()), - reinterpret_cast(v_cache.data_ptr()), - length_per_sample_.has_value() - ? length_per_sample_.value().data_ptr() : nullptr, - alibi_slopes_.has_value() - ? 
alibi_slopes_.value().data_ptr(): nullptr, - reinterpret_cast(out.data_ptr())); - auto stream = at::cuda::getCurrentCUDAStream(); - masked_multihead_attention(params, stream); - }); - return out; -} \ No newline at end of file diff --git a/server/awq_kernels/awq_cuda/attention/ft_attention.h b/server/awq_kernels/awq_cuda/attention/ft_attention.h deleted file mode 100644 index df67d5363..000000000 --- a/server/awq_kernels/awq_cuda/attention/ft_attention.h +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once -#include - - -torch::Tensor single_query_attention(const torch::Tensor q, - const torch::Tensor k, - const torch::Tensor v, - torch::Tensor k_cache, - torch::Tensor v_cache, - c10::optional length_per_sample_, - c10::optional alibi_slopes_, - const int timestep, - const int rotary_embedding_dim = 0, - const float rotary_base = 10000.0f, - const bool neox_rotary_style=true); \ No newline at end of file diff --git a/server/awq_kernels/awq_cuda/layernorm/layernorm.cu b/server/awq_kernels/awq_cuda/layernorm/layernorm.cu deleted file mode 100644 index 0a5d2d251..000000000 --- a/server/awq_kernels/awq_cuda/layernorm/layernorm.cu +++ /dev/null @@ -1,113 +0,0 @@ -/* - -Adapted from NVIDIA FasterTransformer: -https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu - -*/ - -#include -#include -#include "reduction.cuh" -#include "layernorm.h" -#include -#include - -static inline __device__ float to_float(half src) -{ - return __half2float(src); -} - -static inline __device__ float to_float(float src) -{ - return src; -} - -template -__global__ void generalT5LayerNorm( - const T* __restrict input, const T* __restrict gamma, T* output, const float layernorm_eps, int m, int n) -{ - // layernorm module in the T5 style No bias and no subtraction of mean. - const int tid = threadIdx.x; - - __shared__ float s_variance; - float variance = 0.0f; - - float local_var_sum = 0.0f; - for (int i = tid; i < n; i += blockDim.x) { - float diff = to_float(__ldg(&input[blockIdx.x * n + i])); - local_var_sum += diff * diff; - } - variance = blockReduceSum(local_var_sum); - - if (threadIdx.x == 0) { - s_variance = rsqrtf(variance / (float)n + layernorm_eps); - } - __syncthreads(); - - for (int i = tid; i < n; i += blockDim.x) { - output[blockIdx.x * n + i] = - clamp_inf_for_half((to_float(input[blockIdx.x * n + i]) * s_variance) * to_float(__ldg(&gamma[i]))); - } -} - - -template -void invokeGeneralT5LayerNorm(T* out, - const T* input, - const T* gamma, - // const T* beta, - const float layernorm_eps, - const int m, - const int n) -{ - dim3 grid(m); - dim3 block(min(n, 1024)); - - /* For general cases, n is equal to hidden_units, e.g., 512/1024. - Since we have warp shuffle inside the code, block.x % 32 should be 0. 
- */ - if (n % 32 != 0) { - block.x = 1024; - } - - block.x = block.x / (4 / sizeof(T)); // if using half, only need half of block.x - - /* should pay attention to the rsqrt precision*/ - generalT5LayerNorm<<>>(input, gamma, out, layernorm_eps, m, n); // For gpt-3 -} - -template void invokeGeneralT5LayerNorm(half* out, - const half* input, - const half* gamma, - // const half* beta, - const float layernorm_eps, - const int m, - const int n); - -template void invokeGeneralT5LayerNorm(float* out, - const float* input, - const float* gamma, - // const half* beta, - const float layernorm_eps, - const int m, - const int n); - - - -// input b, n, c -void layernorm_forward_cuda( - torch::Tensor _input, - torch::Tensor _gamma, - torch::Tensor _out, - float eps) -{ - int m = _input.size(0) * _input.size(1); - int n = _input.size(2); - const at::cuda::OptionalCUDAGuard device_guard(device_of(_input)); - - auto input = reinterpret_cast(_input.data_ptr()); - auto gamma = reinterpret_cast(_gamma.data_ptr()); - auto out = reinterpret_cast(_out.data_ptr()); - - invokeGeneralT5LayerNorm(out, input, gamma, eps, m, n); -} diff --git a/server/awq_kernels/awq_cuda/layernorm/layernorm.h b/server/awq_kernels/awq_cuda/layernorm/layernorm.h deleted file mode 100644 index de43ccac6..000000000 --- a/server/awq_kernels/awq_cuda/layernorm/layernorm.h +++ /dev/null @@ -1,3 +0,0 @@ -#include - -void layernorm_forward_cuda(torch::Tensor _input, torch::Tensor _gamma, torch::Tensor _out, float eps); diff --git a/server/awq_kernels/awq_cuda/layernorm/reduction.cuh b/server/awq_kernels/awq_cuda/layernorm/reduction.cuh deleted file mode 100644 index f670d1857..000000000 --- a/server/awq_kernels/awq_cuda/layernorm/reduction.cuh +++ /dev/null @@ -1,82 +0,0 @@ -/* - -Adapted from NVIDIA FasterTransformer: -https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/reduce_kernel_utils.cuh -*/ - -#pragma once -#include -#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) -#include -#else -#include -#endif -#include -#include -#include -#include - -#define HALF_FLT_MAX 65504.F -#define FINAL_MASK 0xffffffff - - -template -inline __device__ T add(T a, T b) { - return a + b; -} - -template<> -inline __device__ half2 add(half2 a, half2 b) { - return __hadd2(a, b); -} - -template<> -inline __device__ half add(half a, half b) { - return __hadd(a, b); -} - -template -__inline__ __device__ T warpReduceSum(T val) -{ -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80 - return val; -} - -/* Calculate the sum of all elements in a block */ -template -__inline__ __device__ T blockReduceSum(T val) -{ - static __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - val = warpReduceSum(val); - - if (lane == 0) - shared[wid] = val; - - __syncthreads(); - - // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent - // blockDim.x is not divided by 32 - val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f); - val = warpReduceSum(val); - - return val; -} - - -template -__device__ __forceinline__ T clamp_inf_for_half(const float input) -{ - return input; -} - -template<> -__device__ __forceinline__ half clamp_inf_for_half(const float input) -{ - // clamp inf values to enable fp16 training - return input > 0.0f ? 
__float2half(min(input, HALF_FLT_MAX - 1000)) : __float2half(max(input, -HALF_FLT_MAX + 1000)); -} diff --git a/server/awq_kernels/awq_cuda/position_embedding/pos_encoding.h b/server/awq_kernels/awq_cuda/position_embedding/pos_encoding.h deleted file mode 100644 index 04e205b23..000000000 --- a/server/awq_kernels/awq_cuda/position_embedding/pos_encoding.h +++ /dev/null @@ -1,9 +0,0 @@ -#pragma once -#include - -void rotary_embedding_neox( - torch::Tensor& positions, - torch::Tensor& query, - torch::Tensor& key, - int head_size, - torch::Tensor& cos_sin_cache); \ No newline at end of file diff --git a/server/awq_kernels/awq_cuda/position_embedding/pos_encoding_kernels.cu b/server/awq_kernels/awq_cuda/position_embedding/pos_encoding_kernels.cu deleted file mode 100644 index 883b59c41..000000000 --- a/server/awq_kernels/awq_cuda/position_embedding/pos_encoding_kernels.cu +++ /dev/null @@ -1,88 +0,0 @@ -/* - -Adapted from the VLLM project: -https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu - -*/ - -#include -#include -#include "pos_encoding.h" - -template -__global__ void rotary_embedding_neox_kernel( - const int64_t* __restrict__ positions, // [num_tokens] - scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] - const int rot_dim, - const int stride, - const int num_heads, - const int head_size) { - // Each thread block is responsible for one token. - const int token_idx = blockIdx.x; - int64_t pos = positions[token_idx]; - const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; - - const int embed_dim = rot_dim / 2; - const int n = num_heads * embed_dim; - for (int i = threadIdx.x; i < n; i += blockDim.x) { - const int head_idx = i / embed_dim; - const int token_head = token_idx * stride + head_idx * head_size; - - const int rot_offset = i % embed_dim; - const int x_index = rot_offset; - const int y_index = embed_dim + rot_offset; - - const int out_x = token_idx * stride + head_idx * head_size + x_index; - const int out_y = token_idx * stride + head_idx * head_size + y_index; - - const scalar_t cos = __ldg(cache_ptr + x_index); - const scalar_t sin = __ldg(cache_ptr + y_index); - - const scalar_t q_x = query[token_head + x_index]; - const scalar_t q_y = query[token_head + y_index]; - query[out_x] = q_x * cos - q_y * sin; - query[out_y] = q_y * cos + q_x * sin; - - const scalar_t k_x = key[token_head + x_index]; - const scalar_t k_y = key[token_head + y_index]; - key[out_x] = k_x * cos - k_y * sin; - key[out_y] = k_y * cos + k_x * sin; - } -} - -void rotary_embedding_neox( - torch::Tensor& positions, // [b, num_tokens] - torch::Tensor& query, // [b, num_tokens, 1, num_heads, head_size] - torch::Tensor& key, // [b, num_tokens, 1, num_heads, head_size] - int head_size, - torch::Tensor& cos_sin_cache) // [max_position, rot_dim] -{ - int num_tokens = query.size(0) * query.size(1); - int rot_dim = cos_sin_cache.size(1); - int num_heads = query.size(-2); - int stride = num_heads * head_size; - // TORCH_CHECK(stride == key.stride(0)); - - dim3 grid(num_tokens); - dim3 block(std::min(num_heads * rot_dim / 2, 512)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - query.scalar_type(), - "rotary_embedding_neox", - [&] { - rotary_embedding_neox_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - 
key.data_ptr(), - cos_sin_cache.data_ptr(), - rot_dim, - stride, - num_heads, - head_size); - }); -} - diff --git a/server/awq_kernels/awq_cuda/pybind_awq.cpp b/server/awq_kernels/awq_cuda/pybind_awq.cpp deleted file mode 100644 index e4b828483..000000000 --- a/server/awq_kernels/awq_cuda/pybind_awq.cpp +++ /dev/null @@ -1,15 +0,0 @@ -#include -#include -#include "layernorm/layernorm.h" -#include "quantization/gemm_cuda.h" -#include "quantization/gemv_cuda.h" -#include "position_embedding/pos_encoding.h" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel"); - m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel."); - m.def("gemmv2_forward_cuda", &gemmv2_forward_cuda, "Quantized v2 GEMM kernel."); - m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel."); - m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key"); -} \ No newline at end of file diff --git a/server/awq_kernels/awq_cuda/pybind_ft.cpp b/server/awq_kernels/awq_cuda/pybind_ft.cpp deleted file mode 100644 index 17c5c860c..000000000 --- a/server/awq_kernels/awq_cuda/pybind_ft.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include -#include -#include "attention/ft_attention.h" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("single_query_attention", &single_query_attention, "Attention with a single query", - py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"), - py::arg("length_per_sample_"), py::arg("alibi_slopes_"), py::arg("timestep"), py::arg("rotary_embedding_dim")=0, - py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true); -} \ No newline at end of file diff --git a/server/awq_kernels/awq_cuda/quantization/dequantize.cuh b/server/awq_kernels/awq_cuda/quantization/dequantize.cuh deleted file mode 100644 index 5d333b35c..000000000 --- a/server/awq_kernels/awq_cuda/quantization/dequantize.cuh +++ /dev/null @@ -1,79 +0,0 @@ -/* -Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - -@article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, - author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, - journal={arXiv}, - year={2023} -} -*/ - -#pragma once - - -__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) -{ - uint4 result; - - uint32_t* h = reinterpret_cast(&result); - uint32_t const i4s = reinterpret_cast(source); - - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t BOTTOM_MASK = 0x000f000f; - static constexpr uint32_t TOP_MASK = 0x00f000f0; - static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; - - // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing - // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. - // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and - // elt_67 to fp16 without having to shift them to the bottom bits before hand. - - // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue - // immediately before required. 
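// For readability, a straightforward scalar equivalent of this routine (a
// sketch only, assuming <cuda_fp16.h>; the real code keeps the interleaved
// element order imposed by the register packing) would be:
//
//     static inline __device__ void dequantize_s4_reference(uint32_t packed, half out[8])
//     {
//         #pragma unroll
//         for (int i = 0; i < 8; ++i) {
//             out[i] = __int2half_rn(packed & 0xF);   // one unsigned 4-bit weight
//             packed >>= 4;
//         }
//     }
//
// The PTX version avoids the int->float conversions by exploiting the fp16 bit
// layout: for v in [0, 16), the half with bit pattern (0x6400 | v) equals
// 1024 + v, so OR-ing each nibble into the low bits of the magic constant and
// subtracting 1024 (or folding the correction into an FMA) recovers the weight.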
- const uint32_t top_i4s = i4s >> 8; - // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[1]) - : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[2]) - : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[3]) - : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - - // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the - // half2 ctor. In this case, I chose performance reliability over code readability. - - // This is the half2 {1032, 1032} represented as an integer. - // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; - // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] - static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; - // This is the half2 {1 / 16, 1 / 16} represented as an integer. - static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; - // This is the half2 {-72, -72} represented as an integer. - // static constexpr uint32_t NEG_72 = 0xd480d480; - // Haotian: Let's use {-64, -64}. - static constexpr uint32_t NEG_64 = 0xd400d400; - - // Finally, we construct the output numbers. - // Convert elt_01 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_23 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); - // Convert elt_45 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_67 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); - - return result; -} - diff --git a/server/awq_kernels/awq_cuda/quantization/gemm_cuda.h b/server/awq_kernels/awq_cuda/quantization/gemm_cuda.h deleted file mode 100644 index 6793658d4..000000000 --- a/server/awq_kernels/awq_cuda/quantization/gemm_cuda.h +++ /dev/null @@ -1,7 +0,0 @@ -#include - -torch::Tensor gemm_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel, - torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters); - -torch::Tensor gemmv2_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel, - torch::Tensor _scaling_factors, torch::Tensor _zeros, int group_size, int split_k_iters); \ No newline at end of file diff --git a/server/awq_kernels/awq_cuda/quantization/gemm_cuda_gen.cu b/server/awq_kernels/awq_cuda/quantization/gemm_cuda_gen.cu deleted file mode 100644 index 207c171c6..000000000 --- a/server/awq_kernels/awq_cuda/quantization/gemm_cuda_gen.cu +++ /dev/null @@ -1,848 +0,0 @@ -/* - -@article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, - author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, - journal={arXiv}, - year={2023} -} - - */ - -#include -#include -#include "gemm_cuda.h" -#include "dequantize.cuh" -#include -#include - - -// Pack two half values. 
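// __pack_half2 below stores x in the low 16 bits and y in the high 16 bits of
// an unsigned, e.g. (usage sketch, assuming <cuda_fp16.h>):
//
//     unsigned packed = __pack_half2(__float2half(1.0f), __float2half(2.0f));
//
// make_divisible, defined just after it, is a ceiling division,
// (c + divisor - 1) / divisor, despite its name.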
-static inline __device__ __host__ unsigned -__pack_half2(const half x, const half y) { - unsigned v0 = *((unsigned short *)&x); - unsigned v1 = *((unsigned short *)&y); - return (v1 << 16) | v0; -} - -__device__ __forceinline__ int make_divisible(int c, int divisor){ - return (c + divisor - 1) / divisor; -} - -__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) -{ - static constexpr uint32_t ZERO = 0x0; - float C_warp[32]; - __shared__ half A_shared[16 * (32 + 8)]; - __shared__ half B_shared[32 * (128 + 8)]; - - int j_factors1 = ((OC + 128 - 1) / 128); - int blockIdx_x = 0; - int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1); - int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1); - - half A_shared_warp[8]; - half B_shared_warp[32]; - for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) { - for (int i = 0; i < 8; ++i) { - C_warp[(j_0_4_init * 8) + i] = 0.0; - } - } - - static constexpr int row_stride_warp = 32 * 8 / 32; - static constexpr int row_stride = 2 * 32 * 8 / 128; - bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128; - // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id - // bool wb_C_flag = (threadIdx.x / 4) < M; - - half* A_ptr = A - + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC - + (((int)threadIdx.x) % (32 / 8)) * 8; - - int* B_ptr = B - + ((int)threadIdx.y) * (OC / 8) * 2 - + (((int)threadIdx.x) / (128 / 8)) * (OC / 8) - + (((int)blockIdx_y) % j_factors1) * (128 / 8) - + (((int)threadIdx.x) % (128 / 8)) * 1; -// Why * 1 in the above line? - - half* A_shared_ptr = A_shared - + ((int)threadIdx.y) * row_stride_warp * (32 + 8) - + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) - + (((int)threadIdx.x) % (32 / 8) ) * 8; - - half* B_shared_ptr = B_shared - + ((int)threadIdx.y) * (row_stride / 2) * (128 + 8) - + (((int)threadIdx.x) / (128 / 8)) * (128 + 8) - + (((int)threadIdx.x) % (128 / 8)) * 8; - - int* zeros_ptr = zeros - + (((int)blockIdx_y) % j_factors1) * (128 / 8) - + ((int)threadIdx.x) % (128 / 8); - - half* scaling_factors_ptr = scaling_factors - + (((int)blockIdx_y) % j_factors1) * (128) - + (((int)threadIdx.x) % (128 / 8)) * 8; - - half* C_ptr = C - + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim - + (((int)blockIdx_y) % j_factors1) * 128 - + ((int)threadIdx.y) * 64 - + (((int)threadIdx.x) % 4) * 2; - - // preload s.f. 
and zeros - int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; - if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1; - for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { - int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; - __syncthreads(); - // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - if (ld_A_flag) - { - *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); - } - else - { - *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); - } - - // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { - uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); - uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); - uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); - /* - if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ - printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); - } - */ - // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); - int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); - - for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) { - - // B: 32 x 136 (128+8) float16 - // each warp: 32 x 4 - // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 - // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); - // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) - uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); - uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); - - // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); - // - zero and * scale - // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. 
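// What the eight instructions below compute, per packed half2 lane (a hedged
// sketch, not additional kernel code; q, zero and scale stand for one lane of
// B_loaded_fp16, B_loaded_zero and B_loaded_scale):
//
//     deq = (q - zero) * scale        // sub.f16x2, then fma.rn.f16x2 with addend 0
//
// or, with half2 intrinsics from <cuda_fp16.h>:
//
//     __half2 deq = __hmul2(__hsub2(q, zero), scale);
//
// The TODO above points out that rewriting this as q * scale + (-zero * scale)
// would let the (-zero * scale) term be computed once per quantization group,
// replacing SUB + FMA with a single FMA per element pair.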
- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); - /* - if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ - printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); - } - */ - - // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16; - } - __syncthreads(); - - for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) { - { - unsigned int addr; - asm volatile( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - - - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) - : "r"(addr) - ); - } - - for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) { - { - unsigned int addr; - asm volatile( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) - : "r"(addr) - ); - } - } - for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float 
*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } -#else - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } -#endif - } - } - } - -// TODO: Shang: Hoist loop invariance. 
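// Write-back of the accumulator fragments (an annotation derived from the index
// arithmetic in C_ptr and the loop below): each thread holds eight fp32
// accumulators per value of ax1_0_1 (the warp's four 16x16 output tiles along
// N), and they map back to global C as
//
//     row = (blockIdx_y / j_factors1) * 16 + (threadIdx.x / 4) + ((local_id % 4) / 2) * 8;
//     col = (blockIdx_y % j_factors1) * 128 + threadIdx.y * 64 + (threadIdx.x % 4) * 2
//           + ax1_0_1 * 16 + (local_id / 4) * 8 + (local_id % 2);
//
// with the row guard handling M that is not a multiple of 16. Each blockIdx_z
// writes its own M x OC slice of C (C_ptr includes blockIdx_z * M * OC), so the
// split-k partial results presumably still have to be summed by the caller.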
- for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { - for (int local_id = 0; local_id < 8; ++local_id) { - int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; - if (row_offset < M) - { - *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); - } - } - } -} - - -__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) -{ - static constexpr uint32_t ZERO = 0x0; - float C_warp[32]; - __shared__ half A_shared[16 * (32 + 8)]; - __shared__ half B_shared[32 * (64 + 8)]; - - __shared__ half scaling_factors_shared[64]; - __shared__ half zeros_shared[64]; - - int j_factors1 = ((OC + 64 - 1) / 64); - - int blockIdx_x = 0; - int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1); - int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1); - - half A_shared_warp[8]; - half B_shared_warp[16]; - for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) { - for (int i = 0; i < 8; ++i) { - C_warp[(j_0_4_init * 8) + i] = 0.0; - } - } - - static constexpr int row_stride_warp = 32 * 8 / 32; - static constexpr int row_stride = 2 * 32 * 8 / 64; - bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64; - // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id - // bool wb_C_flag = (threadIdx.x / 4) < M; - - half* A_ptr = A - + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC - + (((int)threadIdx.x) % (32 / 8)) * 8; - - int* B_ptr = B - + ((int)threadIdx.y) * (OC / 8) * 4 - + (((int)threadIdx.x) / (64 / 8)) * (OC / 8) - + (((int)blockIdx_y) % j_factors1) * (64 / 8) - + (((int)threadIdx.x) % (64 / 8)) * 1; -// Why * 1 in the above line? - - half* A_shared_ptr = A_shared - + ((int)threadIdx.y) * row_stride_warp * (32 + 8) - + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) - + (((int)threadIdx.x) % (32 / 8) ) * 8; - - half* B_shared_ptr = B_shared - + ((int)threadIdx.y) * (row_stride / 2) * (64 + 8) - + (((int)threadIdx.x) / (64 / 8)) * (64 + 8) - + (((int)threadIdx.x) % (64 / 8)) * 8; - - int* zeros_ptr = zeros - + (((int)blockIdx_y) % j_factors1) * (64 / 8) - + ((int)threadIdx.x) % (64 / 8); - - half* scaling_factors_ptr = scaling_factors - + (((int)blockIdx_y) % j_factors1) * (64) - + (((int)threadIdx.x) % (64 / 8)) * 8; - - half* C_ptr = C - + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim - + (((int)blockIdx_y) % j_factors1) * 64 - + ((int)threadIdx.y) * 32 - + (((int)threadIdx.x) % 4) * 2; - - // preload s.f. 
and zeros - int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; - if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1; - for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { - int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; - __syncthreads(); - // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - if (ld_A_flag) - { - *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); - } - else - { - *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); - } - - // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { - uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); - uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); - uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); - /* - if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ - printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); - } - */ - // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); - int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); - - for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) { - - // B: 32 x 136 (128+8) float16 - // each warp: 32 x 4 - // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 - // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); - // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) - uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); - uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); - - // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); - // - zero and * scale - // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. 
- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); - /* - if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ - printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); - } - */ - - // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16; - } - __syncthreads(); - - for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) - { - { - unsigned int addr; - asm volatile( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) - : "r"(addr) - ); - } - - - for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) - { - { - unsigned int addr; - asm volatile( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) - : "r"(addr) - ); - } - } - - for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) - { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), 
"=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } -#else - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } -#endif - } - } - } - -// TODO: Shang: Hoist loop invariance. 
- for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) { - for (int local_id = 0; local_id < 8; ++local_id) { - int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; - if (row_offset < M) - { - *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); - } - } - } -} - -template -__global__ void __launch_bounds__(128) gemmv2_forward_4bit_cuda_m128n64k32(int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* zeros, int M, int IC, int OC, half* __restrict__ C) -{ - static constexpr uint32_t ZERO = 0x0; - float C_warp[64]; - __shared__ half A_shared[128 * (32 + 8)]; - __shared__ half B_shared[64 * (32 + 8)]; - - // __shared__ half scaling_factors_shared[64]; - // __shared__ half zeros_shared[64]; - - int j_factors1 = ((OC + 64 - 1) / 64); - - int blockIdx_x = 0; - int blockIdx_y = blockIdx.x % ((M + 128 - 1) / 128 * j_factors1); - int blockIdx_z = blockIdx.x / ((M + 128 - 1) / 128 * j_factors1); - - half A_shared_warp[32]; - half B_shared_warp[16]; - for (int i_0_3_init = 0; i_0_3_init < 4; ++i_0_3_init) { - for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) { - for (int i = 0; i < 8; ++i) { - C_warp[((i_0_3_init * 16) + (j_0_4_init * 8)) + i] = 0.0; - } - } - } - - static constexpr int row_stride_warp = 32 * 8 / 32; - static constexpr int row_stride_A = 4 * 32 * 8 / 32; - static constexpr int row_stride = 4 * 32 * 8 / 32; - const int make_divisible_multipler = 128 / G; - const int zeros_w = make_divisible(make_divisible(IC / G, 8), make_divisible_multipler) * make_divisible_multipler; - const int sf_w = zeros_w * 8; - - bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64; - int ld_A_row = (blockIdx_y / j_factors1 * 128 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32); // threadIdx.y is warp_id - // bool wb_C_flag = (threadIdx.x / 4) < M; - - half* A_ptr = A - + (((int)blockIdx_y) / j_factors1 * 128 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC - + (((int)threadIdx.x) % (32 / 8)) * 8; - - int* B_ptr = B - + ((int)threadIdx.y) * (IC / 8) * 8 - + (((int)threadIdx.x) / (32 / 8)) * (IC / 8) - + (((int)blockIdx_y) % j_factors1) * 64 * (IC / 8) - + (((int)threadIdx.x) % (32 / 8)) * 1; - -// Why * 1 in the above line? - - half* A_shared_ptr = A_shared - + ((int)threadIdx.y) * row_stride_warp * (32 + 8) - + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) - + (((int)threadIdx.x) % (32 / 8) ) * 8; - - half* B_shared_ptr = B_shared - + ((int)threadIdx.y) * (row_stride / 4) * (32 + 8) - + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) - + (((int)threadIdx.x) % (32 / 8)) * 8; - - - int* zeros_ptr = zeros - + ((int)threadIdx.y) * zeros_w * 8 - + (((int)threadIdx.x) / (32 / 8)) * zeros_w - + (((int)blockIdx_y) % j_factors1) * 64 * zeros_w - // this term is zero - + (((int)threadIdx.x) % (32 / 8)) / G ; - - half* scaling_factors_ptr = scaling_factors - + ((int)threadIdx.y) * sf_w * 8 - + (((int)threadIdx.x) / (32 / 8)) * sf_w - + (((int)blockIdx_y) % j_factors1) * (64) * sf_w - // this term is zero - + (((int)threadIdx.x) % (32 / 8)) * 8 / G; - - - // Haotian: TBD, check, May 29 11:46 AM PST - half* C_ptr = C - + static_cast(blockIdx_z) * M * OC // blockIdx_z -> split_k dim - + (((int)blockIdx_y) % j_factors1) * 64 - + (((int)threadIdx.y) / 2) * 32 - + (((int)threadIdx.x) % 4) * 2; - - // preload s.f. 
and zeros - int k_bound = make_divisible(IC / 32, split_k_iters); // (IC / 32 + split_k_iters - 1) / split_k_iters; - if ((k_bound - 1) * 32 + blockIdx_z >= IC) k_bound -= 1; - - // TODO (Haotian): load scales and zero points to smem - - for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { - int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; - __syncthreads(); - // TODO: Haotian: Here we assume M % cta_M = 0. - for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) - { - if (ld_A_row + ax0_ax1_fused_0 * row_stride_A < M) - { - *(uint4*)(A_shared_ptr + ax0_ax1_fused_0 * row_stride_A * 40) = *(uint4*)(A_ptr + (ax0_ax1_fused_0 * row_stride_A * IC) + (k_0_0 * 32)); - } - else - { - *(uint4*)(A_shared_ptr + ax0_ax1_fused_0 * row_stride_A * 40) = make_uint4(0, 0, 0, 0); - } - } - - - int* zeros_ptr_local = zeros_ptr + k_0_0 * 32 / G / 8; - half* scaling_factors_ptr_local = scaling_factors_ptr + k_0_0 * 32 / G; - - // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); - int* B_ptr_local = B_ptr + k_0_0 * (32 / 8); - - for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { - - // B: 32 x 136 (128+8) float16 - // each warp: 32 x 4 - // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 - // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) - int B_loaded_current = *(B_ptr_local + ax0_ax1_fused_0 * row_stride * (IC / 8)); - int zeros_loaded = *(zeros_ptr_local + ax0_ax1_fused_0 * row_stride * zeros_w); - zeros_loaded >>= ((k_0_0 * 32 / G) % 8) * 4; - float current_zeros = (float)(zeros_loaded & 0xF); - half scaling_factors_loaded = *(scaling_factors_ptr_local + ax0_ax1_fused_0 * row_stride * sf_w); - half B_loaded_fp16[8]; - #pragma unroll - for (int ic_1 = 0; ic_1 < 8; ic_1++){ - float current_single_weight_fp = (float)(B_loaded_current & 0xF); - half dequantized_weight = __float2half(__half2float(scaling_factors_loaded) * (current_single_weight_fp - current_zeros)); - B_loaded_current = B_loaded_current >> 4; - B_loaded_fp16[ic_1] = dequantized_weight; - } - // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (32 + 8)) = *reinterpret_cast(B_loaded_fp16); - } - __syncthreads(); - for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) { - for (int ax0_0 = 0; ax0_0 < 4; ++ax0_0) { - { - unsigned int addr; - asm volatile( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(A_shared[((((((int)threadIdx.y) & 1) * 2560) + (ax0_0 * 640)) + (k_0_1 * 16))])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[3]) - : "r"(addr) - ); - } - } - - for (int ax0_0_1 = 0; ax0_0_1 < 2; ++ax0_0_1) { - { - unsigned int addr; - asm volatile( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(B_shared[((((((int)threadIdx.y) >> 1) * 1280) + (ax0_0_1 * 640)) + (k_0_1 * 16))])) + ((((((int)threadIdx.x) >> 4) * 320) + ((((int)threadIdx.x) & 7) * 40)) + (((((int)threadIdx.x) & 15) >> 3) * 8)))) - ); - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[1]), 
"=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[3]) - : "r"(addr) - ); - } - } - - for (int i_0_3 = 0; i_0_3 < 4; ++i_0_3) { - for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]) - : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8 + 4)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]) - : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8 + 4)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3])); - } -#else - { - asm volatile( - 
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]) - : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3])); - } -#endif - } - } - } - } - -// Haotian: Here (May 29 11:46AM PST) -// TODO: Shang: Hoist loop invariance. 
- for (int ax0_0_2 = 0; ax0_0_2 < 4; ++ax0_0_2) { - for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) { - for (int local_id = 0; local_id < 8; ++local_id) { - int row_offset = (((int)blockIdx_y) / j_factors1) * 128 + (threadIdx.y % 2) * 64 + ax0_0_2 * 16 + (local_id % 4) / 2 * 8 + ((int)threadIdx.x) / 4; - if (row_offset < M) - { - *(C_ptr + ax1_0 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax0_0_2 * 16) + (ax1_0 * 8) + local_id]); - } - } - } - } -} - -// in_feats: M, IC [float16] -// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b] -// scaling_factors: IC // G, OC [float16] -// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] -// assume that batch_size < 16 for now - -torch::Tensor gemmv2_forward_cuda( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int group_size, - int split_k_iters) -{ - int num_in_feats = _in_feats.size(0); - int num_in_channels = _in_feats.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); - - auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); - // for int4, need _kernel.size(1) * 8 - at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(0)}, options); - int num_out_feats = _out_feats.size(-2); - int num_out_channels = _out_feats.size(-1); - - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - - // blockIdx_x: i_factors[0] * j_factors[0] - // blockIdx_y: i_factors[1] * j_factors[1] - - if (num_out_channels % 64 != 0) - throw std::invalid_argument("OC is not multiple of cta_N = 64"); - if (num_out_channels % 8 != 0) - throw std::invalid_argument("OC is not multiple of pack_num = 8"); - int j_factors1 = num_out_channels / 64 / 1; - dim3 num_blocks((num_out_feats + 128 - 1) / 128 * j_factors1 * split_k_iters); - - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 4); - if (group_size == 128) - { - gemmv2_forward_4bit_cuda_m128n64k32<128><<>>( - split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); - } - else if (group_size == 64) - { - gemmv2_forward_4bit_cuda_m128n64k32<64><<>>( - split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); - } - else - { - throw std::invalid_argument("Group size temporarily not supported."); - } - return _out_feats.sum(0); -} - -// in_feats: M, IC [float16] -// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b] -// scaling_factors: IC // G, OC [float16] -// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] -// assume that batch_size < 16 for now - -torch::Tensor gemm_forward_cuda( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters) -{ - int num_in_feats = _in_feats.size(0); - int num_in_channels = _in_feats.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); - - auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); - at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); - int num_out_feats = _out_feats.size(-2); - int 
num_out_channels = _out_feats.size(-1); - - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - int group_size = num_in_channels / _scaling_factors.size(0); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if (num_out_channels % 64 != 0) - throw std::invalid_argument("OC is not multiple of cta_N = 64"); - if (num_out_channels % 8 != 0) - throw std::invalid_argument("OC is not multiple of pack_num = 8"); - if (group_size % 32 != 0) - throw std::invalid_argument("Group size should be a multiple of 32"); - if (num_out_channels % group_size != 0) - throw std::invalid_argument("OC is not multiple of Group size"); - - if (num_out_channels % 128 == 0) - { - int j_factors1 = num_out_channels / 128 / 1; - dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - gemm_forward_4bit_cuda_m16n128k32<<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); - } - else if (num_out_channels % 64 == 0) - { - int j_factors1 = num_out_channels / 64 / 1; - dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - gemm_forward_4bit_cuda_m16n64k32<<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); - } - return _out_feats.sum(0); -} diff --git a/server/awq_kernels/awq_cuda/quantization/gemv_cuda.cu b/server/awq_kernels/awq_cuda/quantization/gemv_cuda.cu deleted file mode 100644 index d4a26a066..000000000 --- a/server/awq_kernels/awq_cuda/quantization/gemv_cuda.cu +++ /dev/null @@ -1,249 +0,0 @@ -// Inspired by https://github.com/ankan-ban/llama_cu_awq -/* - -@article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, - author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, - journal={arXiv}, - year={2023} -} - -*/ - -#include -#include -#include -#include -#include "gemv_cuda.h" -#define VECTORIZE_FACTOR 8 -#define Q_VECTORIZE_FACTOR 8 -#define PACK_FACTOR 8 -#define WARP_SIZE 32 - - -// Reduce sum within the warp using the tree reduction algorithm. -__device__ __forceinline__ float warp_reduce_sum(float sum) { - #pragma unroll - for(int i = 4; i >= 0; i--){ - sum += __shfl_down_sync(0xffffffff, sum, 1<(zeros + oc_idx * zeros_w + packed_group_idx * 2); - uint32_t packed_weights[4]; - // use float4 to load weights, each thread load 32 int4 numbers (1 x float4) - *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4)); - // load scaling factors - // g64: two threads -> 64 numbers -> 1 group; 1 warp = 16 groups. 
- float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 16 + (threadIdx.x / 2)]); - float current_zeros = (float)((packed_zeros >> (threadIdx.x / 2 * 4)) & 0xF); - int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4; - const float4* inputs_ptr = inputs + inputs_ptr_delta; - // multiply 32 weights with 32 inputs - #pragma unroll - for (int ic_0 = 0; ic_0 < 4; ic_0++){ - // iterate over different uint32_t packed_weights in this loop - uint32_t current_packed_weight = packed_weights[ic_0]; - half packed_inputs[PACK_FACTOR]; - // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8) - if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) { - *((float4*)packed_inputs) = *(inputs_ptr + ic_0); - #pragma unroll - for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){ - // iterate over 8 numbers packed within each uint32_t number - float current_single_weight_fp = (float)(current_packed_weight & 0xF); - float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros); - //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros); - psum += dequantized_weight * __half2float(packed_inputs[ic_1]); - current_packed_weight = current_packed_weight >> 4; - } - } - } - } - psum = warp_reduce_sum(psum); - if (threadIdx.x == 0) { - outputs[oc_idx] = __float2half(psum); - } -} - - -/* -Computes GEMV (group_size = 128). - -Args: - inputs: vector of shape [batch_size, IC]; - weight: matrix of shape [OC, IC / 8]; - output: vector of shape [OC]; - zeros: matrix of shape [OC, IC / group_size / 8]; - scaling_factors: matrix of shape [OC, IC / group_size]; - -Notes: - One cannot infer group_size from the shape of scaling factors. - the second dimension is rounded up to a multiple of PACK_FACTOR. -*/ -__global__ void gemv_kernel_g128( - const float4* _inputs, const uint32_t* weight, const uint32_t* zeros, const half* scaling_factors, half* _outputs, - const int IC, const int OC){ - const int group_size = 128; - float psum = 0; - const int batch_idx = blockIdx.z; - const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y; - const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR; - half* outputs = _outputs + batch_idx * OC; - const int num_groups_packed = make_divisible(IC / group_size, PACK_FACTOR); - const int weight_w = IC / PACK_FACTOR; - // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address - const int zeros_w = make_divisible(IC / group_size, PACK_FACTOR); - // consistent with input shape - const int sf_w = make_divisible(IC / group_size, PACK_FACTOR) * PACK_FACTOR; - //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w); - // tile size: 4 OC x 1024 IC per iter - for(int packed_group_idx = 0; packed_group_idx < num_groups_packed; packed_group_idx++){ - // 1024 numbers in one iteration across warp. Need 1024 / group_size zeros. 
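For reference, the zero/scale indexing in these GEMV kernels (and the shared-memory dequantization in the GEMM kernels above) follows one pattern: each thread consumes 4 packed 32-bit words = 32 input channels per warp iteration, so group_size / 32 threads share a quantization group, each group contributes one 4-bit zero-point nibble and one fp16 scale, and a weight is recovered as scale * (q - zero). A hypothetical Python helper mirroring that math for the g128 layout (illustrative only, not part of this patch):

WARP_SIZE = 32
PACK_FACTOR = 8                      # 4-bit values per 32-bit word

def g128_indices(tid, packed_group_idx, group_size=128):
    """Mirror of the scale/zero indexing in gemv_kernel_g128 (illustrative)."""
    channels_per_thread = 4 * PACK_FACTOR                              # 4 packed words per thread
    groups_per_iter = WARP_SIZE * channels_per_thread // group_size    # 1024 / group_size
    group_in_iter = tid // (group_size // channels_per_thread)         # which group this thread serves
    scale_idx = packed_group_idx * groups_per_iter + group_in_iter
    zero_shift = 4 * group_in_iter                                     # nibble offset in the packed zeros word
    return scale_idx, zero_shift

def dequant_nibbles(packed_weights, scale, zero):
    """Unpack eight 4-bit weights from one 32-bit word: w = scale * (q - zero)."""
    return [scale * ((packed_weights >> (4 * i) & 0xF) - zero) for i in range(8)]

# thread 5, first iteration: group 1 of 8, zero nibble starts at bit 4
print(g128_indices(5, 0))                       # (1, 4)
print(dequant_nibbles(0x33333333, 0.1, 1))      # eight values of ~0.2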
- uint32_t packed_zeros = *(zeros + oc_idx * zeros_w + packed_group_idx); - uint32_t packed_weights[4]; - // use float4 to load weights, each thread load 32 int4 numbers (1 x float4) - *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4)); - // load scaling factors - // g128: four threads -> 128 numbers -> 1 group; 1 warp = 8 groups. - float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]); - float current_zeros = (float)((packed_zeros >> (threadIdx.x / 4 * 4)) & 0xF); - int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4; - const float4* inputs_ptr = inputs + inputs_ptr_delta; - // multiply 32 weights with 32 inputs - #pragma unroll - for (int ic_0 = 0; ic_0 < 4; ic_0++){ - // iterate over different uint32_t packed_weights in this loop - uint32_t current_packed_weight = packed_weights[ic_0]; - half packed_inputs[PACK_FACTOR]; - // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8) - if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) { - *((float4*)packed_inputs) = *(inputs_ptr + ic_0); - #pragma unroll - for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){ - // iterate over 8 numbers packed within each uint32_t number - float current_single_weight_fp = (float)(current_packed_weight & 0xF); - float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros); - //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros); - psum += dequantized_weight * __half2float(packed_inputs[ic_1]); - current_packed_weight = current_packed_weight >> 4; - } - } - } - } - psum = warp_reduce_sum(psum); - if (threadIdx.x == 0) { - outputs[oc_idx] = __float2half(psum); - } -} - - -/* -Computes GEMV (PyTorch interface). 
- -Args: - _in_feats: tensor of shape [B, IC]; - _kernel: int tensor of shape [OC, IC // 8]; - _zeros: int tensor of shape [OC, IC // G // 8]; - _scaling_factors: tensor of shape [OC, IC // G]; - blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC; - blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC; - -Returns: - out_feats: tensor of shape [B, OC]; -*/ -torch::Tensor gemv_forward_cuda( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int group_size) -{ - int num_in_feats = _in_feats.size(0); - int num_in_channels = _in_feats.size(1); - // int kernel_volume = _out_in_map.size(1); - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - // auto out_in_map = _out_in_map.data_ptr(); - auto options = - torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); - // kernel is [OC, IC] - at::Tensor _out_feats = torch::empty({num_in_feats, _kernel.size(0)}, options); - int num_out_feats = _out_feats.size(-2); - int num_out_channels = _out_feats.size(-1); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - int blockDim_z = num_out_feats; - dim3 num_blocks(1, num_out_channels / 4, num_out_feats); - dim3 num_threads(32, 4); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (group_size == 64) - { - gemv_kernel_g64<<>>( - // pointers - in_feats, kernel, zeros, scaling_factors, out_feats, - // constants - num_in_channels, num_out_channels - ); - } - else if (group_size == 128) - { - gemv_kernel_g128<<>>( - // pointers - in_feats, kernel, zeros, scaling_factors, out_feats, - // constants - num_in_channels, num_out_channels - ); - } - return _out_feats; -;} - diff --git a/server/awq_kernels/awq_cuda/quantization/gemv_cuda.h b/server/awq_kernels/awq_cuda/quantization/gemv_cuda.h deleted file mode 100644 index 748abc5d1..000000000 --- a/server/awq_kernels/awq_cuda/quantization/gemv_cuda.h +++ /dev/null @@ -1,9 +0,0 @@ -#pragma once -#include - -torch::Tensor gemv_forward_cuda( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int group_size); diff --git a/server/awq_kernels/setup.py b/server/awq_kernels/setup.py deleted file mode 100644 index a53d073a5..000000000 --- a/server/awq_kernels/setup.py +++ /dev/null @@ -1,92 +0,0 @@ -import os -import torch -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -common_setup_kwargs = { -} - - -def get_generator_flag(): - generator_flag = [] - torch_dir = torch.__path__[0] - if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): - generator_flag = ["-DOLD_GENERATOR_PATH"] - - return generator_flag - - -def get_compute_capabilities(): - # Collect the compute capabilities of all available GPUs. 
- for i in range(torch.cuda.device_count()): - major, minor = torch.cuda.get_device_capability(i) - cc = major * 10 + minor - - if cc < 75: - raise RuntimeError("GPUs with compute capability less than 7.5 are not supported.") - - # figure out compute capability - compute_capabilities = {75, 80, 86, 89, 90} - - capability_flags = [] - for cap in compute_capabilities: - capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"] - - return capability_flags - -generator_flags = get_generator_flag() -arch_flags = get_compute_capabilities() - - -extra_compile_args={ - "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"], - "nvcc": [ - "-O3", - "-std=c++17", - "-DENABLE_BF16", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - ] + arch_flags + generator_flags -} - -extensions = [ - CUDAExtension( - "awq_inference_engine", - [ - "awq_cuda/pybind_awq.cpp", - "awq_cuda/quantization/gemm_cuda_gen.cu", - "awq_cuda/layernorm/layernorm.cu", - "awq_cuda/position_embedding/pos_encoding_kernels.cu", - "awq_cuda/quantization/gemv_cuda.cu" - ], extra_compile_args=extra_compile_args - ) -] - -extensions.append( - CUDAExtension( - "ft_inference_engine", - [ - "awq_cuda/pybind_ft.cpp", - "awq_cuda/attention/ft_attention.cpp", - "awq_cuda/attention/decoder_masked_multihead_attention.cu" - ], extra_compile_args=extra_compile_args - ) -) - -additional_setup_kwargs = { - "ext_modules": extensions, - "cmdclass": {'build_ext': BuildExtension} -} - -common_setup_kwargs.update(additional_setup_kwargs) - -setup( - **common_setup_kwargs -) \ No newline at end of file From 8449a9498f73ed64f20f566d4227b1772613ed25 Mon Sep 17 00:00:00 2001 From: flozi00 Date: Tue, 21 May 2024 09:27:29 +0200 Subject: [PATCH 03/32] refactor attention --- .devcontainer/devcontainer.json | 2 +- .../custom_modeling/flash_cohere_modeling.py | 6 +- .../custom_modeling/flash_dbrx_modeling.py | 6 +- .../custom_modeling/flash_gemma_modeling.py | 6 +- .../custom_modeling/flash_gpt2_modeling.py | 6 +- .../custom_modeling/flash_llama_modeling.py | 6 +- .../custom_modeling/flash_mistral_modeling.py | 6 +- .../custom_modeling/flash_mixtral_modeling.py | 6 +- .../custom_modeling/flash_neox_modeling.py | 6 +- .../custom_modeling/flash_phi3_modeling.py | 6 +- .../custom_modeling/flash_phi_modeling.py | 6 +- .../custom_modeling/flash_qwen2_modeling.py | 6 +- .../custom_modeling/flash_qwen_modeling.py | 6 +- .../custom_modeling/flash_rw_modeling.py | 10 +- .../flash_santacoder_modeling.py | 6 +- server/lorax_server/utils/flash_attn.py | 267 ++++-- .../lorax_server/utils/flash_attn_triton.py | 816 ++++++++++++++++++ server/lorax_server/utils/import_utils.py | 50 ++ server/lorax_server/utils/paged_attention.py | 137 +++ 19 files changed, 1265 insertions(+), 95 deletions(-) create mode 100644 server/lorax_server/utils/flash_attn_triton.py create mode 100644 server/lorax_server/utils/import_utils.py create mode 100644 server/lorax_server/utils/paged_attention.py diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index d17986d4e..d8a441b9d 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,6 +1,6 @@ { "build": { - "dockerfile": "../Dockerfile" + "dockerfile": "../Dockerfile.dev" }, "runArgs": [ "--gpus", diff --git 
a/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py b/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py index d4857e480..8e1bd1292 100644 --- a/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py @@ -28,7 +28,7 @@ from transformers.activations import ACT2FN from lorax_server.adapters.weights import AdapterBatchData -from lorax_server.utils import flash_attn, paged_attn +from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( FastLayerNorm, MultiAdapterHead, @@ -277,7 +277,7 @@ def forward( self.rotary_emb(query, key, cos, sin) - paged_attn.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -296,7 +296,7 @@ def forward( ) # Decode else: - paged_attn.single_query_cached_kv_attention( + paged_attention.attention( attn_output, query, kv_cache[0], diff --git a/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py b/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py index ee1685f0a..44c01c8b2 100644 --- a/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py @@ -24,7 +24,7 @@ from transformers.configuration_utils import PretrainedConfig from lorax_server.adapters.weights import AdapterBatchData -from lorax_server.utils import flash_attn, paged_attn +from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( FastLayerNorm, FastLinear, @@ -414,7 +414,7 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - paged_attn.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -433,7 +433,7 @@ def forward( ) # Decode else: - paged_attn.single_query_cached_kv_attention( + paged_attention.attention( attn_output, query, kv_cache[0], diff --git a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py index c471ba414..b036224c9 100644 --- a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py @@ -23,7 +23,7 @@ # Flash attention imports from lorax_server.adapters import AdapterBatchData -from lorax_server.utils import flash_attn, paged_attn +from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( PositionRotaryEmbedding, TensorParallelAdapterRowLinear, @@ -262,7 +262,7 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - paged_attn.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -282,7 +282,7 @@ def forward( # Decode else: # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] - paged_attn.single_query_cached_kv_attention( + paged_attention.attention( attn_output, query, kv_cache[0], diff --git a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py 
b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py index d6cbf4311..77bed2134 100644 --- a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py @@ -28,7 +28,7 @@ from transformers.models.gpt2 import GPT2Config from lorax_server.adapters import AdapterBatchData -from lorax_server.utils import flash_attn, paged_attn +from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( FastLayerNorm, TensorParallelAdapterRowLinear, @@ -155,7 +155,7 @@ def forward( qkv = self.c_attn(hidden_states, adapter_data) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) - paged_attn.reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(qkv[:, 0]) @@ -175,7 +175,7 @@ def forward( # Decode else: # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] - paged_attn.single_query_cached_kv_attention( + paged_attention.attention( attn_output, qkv[:, 0], kv_cache[0], diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py index 4941f1608..810cb8c5b 100644 --- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py @@ -29,7 +29,7 @@ from transformers.configuration_utils import PretrainedConfig from lorax_server.adapters import AdapterBatchData -from lorax_server.utils import flash_attn, paged_attn +from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( MultiAdapterHead, PositionRotaryEmbedding, @@ -306,7 +306,7 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - paged_attn.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -326,7 +326,7 @@ def forward( # Decode else: # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] - paged_attn.single_query_cached_kv_attention( + paged_attention.attention( attn_output, query, kv_cache[0], diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index 4fbeb9885..bc39f6dd6 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -29,7 +29,7 @@ from transformers.configuration_utils import PretrainedConfig from lorax_server.adapters import AdapterBatchData -from lorax_server.utils import flash_attn, paged_attn +from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.flash_attn import HAS_FLASH_ATTN_V2 from lorax_server.utils.layers import ( MultiAdapterHead, @@ -319,7 +319,7 @@ def forward( else: kv_to_cache = kv - paged_attn.reshape_and_cache(kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache(kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -340,7 +340,7 @@ def forward( # Decode else: # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] - 
paged_attn.single_query_cached_kv_attention( + paged_attention.attention( attn_output, query, kv_cache[0], diff --git a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py index ab6859f70..404abd0fe 100644 --- a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py @@ -31,7 +31,7 @@ from transformers.configuration_utils import PretrainedConfig from lorax_server.adapters import AdapterBatchData -from lorax_server.utils import flash_attn, paged_attn +from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.flash_attn import HAS_FLASH_ATTN_V2 from lorax_server.utils.layers import ( FastLinear, @@ -399,7 +399,7 @@ def forward( else: kv_to_cache = kv - paged_attn.reshape_and_cache(kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache(kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -419,7 +419,7 @@ def forward( ) # Decode else: - paged_attn.single_query_cached_kv_attention( + paged_attention.attention( attn_output, query, kv_cache[0], diff --git a/server/lorax_server/models/custom_modeling/flash_neox_modeling.py b/server/lorax_server/models/custom_modeling/flash_neox_modeling.py index 2f042236c..072d3ee79 100644 --- a/server/lorax_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_neox_modeling.py @@ -27,7 +27,7 @@ from transformers.modeling_utils import PreTrainedModel from transformers.models.gpt_neox import GPTNeoXConfig -from lorax_server.utils import flash_attn, paged_attn +from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( FastLayerNorm, PositionRotaryEmbedding, @@ -131,7 +131,7 @@ def forward( self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(qkv[:, 1], cos, sin) - paged_attn.reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(qkv[:, 0]) @@ -151,7 +151,7 @@ def forward( # Decode else: # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] - paged_attn.single_query_cached_kv_attention( + paged_attention.attention( attn_output, qkv[:, 0], kv_cache[0], diff --git a/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py index 41fc573f8..aa8b6a302 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py @@ -29,7 +29,7 @@ from transformers.configuration_utils import PretrainedConfig from lorax_server.adapters import AdapterBatchData -from lorax_server.utils import flash_attn, paged_attn +from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( MultiAdapterHead, PositionRotaryEmbedding, @@ -271,7 +271,7 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - paged_attn.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -291,7 +291,7 @@ def forward( # Decode else: # 
kv_cache[1] => [num_blocks, num_heads, head_size, block_size] - paged_attn.single_query_cached_kv_attention( + paged_attention.attention( attn_output, query, kv_cache[0], diff --git a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py index c52312132..8827fd5fd 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py @@ -17,7 +17,7 @@ from transformers.activations import ACT2FN from lorax_server.adapters import AdapterBatchData -from lorax_server.utils import flash_attn, paged_attn +from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( FastLayerNorm, MultiAdapterHead, @@ -157,7 +157,7 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - paged_attn.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -177,7 +177,7 @@ def forward( # Decode else: # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] - paged_attn.single_query_cached_kv_attention( + paged_attention.attention( attn_output, query, kv_cache[0], diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py index f4d327611..6b6bf0374 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py @@ -14,7 +14,7 @@ from transformers.activations import ACT2FN from lorax_server.adapters import AdapterBatchData -from lorax_server.utils import flash_attn, paged_attn +from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( MultiAdapterHead, PositionRotaryEmbedding, @@ -227,7 +227,7 @@ def forward( else: kv_to_cache = kv - paged_attn.reshape_and_cache(kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache(kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -248,7 +248,7 @@ def forward( # Decode else: # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] - paged_attn.single_query_cached_kv_attention( + paged_attention.attention( attn_output, query, kv_cache[0], diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py index cdf1958ca..0ad3c0cbe 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py @@ -15,7 +15,7 @@ from transformers.configuration_utils import PretrainedConfig from lorax_server.adapters import AdapterBatchData -from lorax_server.utils import flash_attn, paged_attn +from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( MultiAdapterHead, PositionRotaryEmbedding, @@ -244,7 +244,7 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - paged_attn.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) 
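The hunks in these modeling files are a mechanical rename (paged_attn -> paged_attention, single_query_cached_kv_attention -> attention), but they all sit on the same decode-time idea: each sequence's keys and values live in fixed-size cache blocks addressed through a block table, and the single new query attends over the gathered entries. A self-contained toy sketch of that gather-then-attend step in plain PyTorch (one sequence, one head; the real cache layout and kernels differ):

import torch

block_size, num_blocks, head_dim = 4, 3, 8
seq_len = 10                                    # tokens already in the cache

# Toy paged KV cache: [num_blocks, block_size, head_dim] per tensor.
key_cache = torch.randn(num_blocks, block_size, head_dim)
value_cache = torch.randn(num_blocks, block_size, head_dim)
block_table = torch.tensor([2, 0, 1])           # logical block i -> physical block

query = torch.randn(head_dim)                   # the single decode-step query

# Gather this sequence's K/V in logical order, then keep only the filled slots.
keys = key_cache[block_table].reshape(-1, head_dim)[:seq_len]
values = value_cache[block_table].reshape(-1, head_dim)[:seq_len]

scores = keys @ query / head_dim ** 0.5
out = torch.softmax(scores, dim=0) @ values     # [head_dim], the attention output
print(out.shape)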
@@ -264,7 +264,7 @@ def forward( # Decode else: # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] - paged_attn.single_query_cached_kv_attention( + paged_attention.attention( attn_output, query, kv_cache[0], diff --git a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py index d60be3583..1da4b5329 100644 --- a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py @@ -6,7 +6,7 @@ from transformers.configuration_utils import PretrainedConfig from transformers.modeling_utils import PreTrainedModel -from lorax_server.utils import flash_attn, paged_attn +from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( FastLayerNorm, PositionRotaryEmbedding, @@ -171,7 +171,7 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - paged_attn.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output attn_output = torch.empty_like(query) @@ -191,7 +191,7 @@ def forward( # Decode else: # kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size] - paged_attn.single_query_cached_kv_attention( + paged_attention.attention( attn_output, query, kv_cache[0], @@ -279,7 +279,7 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin) - paged_attn.reshape_and_cache( + paged_attention.reshape_and_cache( kv[:, :, 0].contiguous(), kv[:, :, 1].contiguous(), kv_cache[0], @@ -305,7 +305,7 @@ def forward( # Decode else: # kv_cache[1] => [num_blocks, num_groups, head_size, block_size] - paged_attn.single_query_cached_kv_attention( + paged_attention.attention( attn_output, query, kv_cache[0], diff --git a/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py b/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py index cea2319d6..088be68b2 100644 --- a/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py @@ -5,7 +5,7 @@ from torch import nn from transformers.activations import ACT2FN -from lorax_server.utils import flash_attn, paged_attn +from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( FastLayerNorm, TensorParallelColumnLinear, @@ -231,7 +231,7 @@ def forward( query = query.view(-1, self.num_heads, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size) - paged_attn.reshape_and_cache(key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache(key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots) # output attn_output = torch.empty_like(query) @@ -251,7 +251,7 @@ def forward( # Decode else: # kv_cache[1] => [num_blocks, 1, head_size, block_size] - paged_attn.single_query_cached_kv_attention( + paged_attention.attention( attn_output, query, kv_cache[0], diff --git a/server/lorax_server/utils/flash_attn.py b/server/lorax_server/utils/flash_attn.py index 2181c21b4..2287b5b89 100644 --- a/server/lorax_server/utils/flash_attn.py +++ b/server/lorax_server/utils/flash_attn.py @@ -1,60 +1,145 @@ import os - import torch + from loguru import logger +import math + +from lorax_server.utils.import_utils import SYSTEM +from lorax_server.utils.flash_attn_triton 
import triton_attention if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": raise ImportError("`USE_FLASH_ATTENTION` is false.") +HAS_FLASH_ATTN = False +HAS_FLASH_ATTN_V2_CUDA = False +HAS_FLASH_ATTN_V2_ROCM = False +ROCM_USE_FLASH_ATTN_V2_CK = False +ROCM_USE_FLASH_ATTN_V2_TRITON = False -if not torch.cuda.is_available(): - raise ImportError("CUDA is not available") +if SYSTEM == "xpu": + import intel_extension_for_pytorch as ipex -major, minor = torch.cuda.get_device_capability() -is_sm75 = major == 7 and minor == 5 -is_sm8x = major == 8 and minor >= 0 -is_sm90 = major == 9 and minor == 0 + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + ): + if window_size_left <= 0 and window_size_left != -1: + raise ValueError("`window_size_left` must be > 0 or -1") -HAS_FLASH_ATTN = False -HAS_FLASH_ATTN_V2 = False -try: - try: - import flash_attn_2_cuda - except ImportError: - raise ImportError( - "Flash Attention V2 is not installed.\n" - "Use the official Docker image (ghcr.io/predibase/lorax:latest) " - "or install flash attention v2 with `cd server && make install install-flash-attention-v2`" + if window_size_left != -1: + raise ValueError( + f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." + ) + return ipex.llm.functional.varlen_attention( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + True, + False, + None, ) - if not (is_sm8x or is_sm90): - raise ImportError(f"GPU with CUDA capability {major} {minor} is not supported for " "Flash Attention V2") - HAS_FLASH_ATTN_V2 = True -except ImportError as e: + + +if SYSTEM in {"cuda", "rocm"}: + if not torch.cuda.is_available(): + raise ImportError("CUDA is not available") + + major, minor = torch.cuda.get_device_capability() + is_sm75 = major == 7 and minor == 5 + is_sm8x = major == 8 and minor >= 0 + is_sm90 = major == 9 and minor == 0 + is_sm94 = major == 9 and minor == 4 + + if SYSTEM == "rocm": + if ( + os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true" + or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1" + ): + ROCM_USE_FLASH_ATTN_V2_TRITON = True + logger.info("ROCm: using Flash Attention 2 Triton implementation.") + else: + ROCM_USE_FLASH_ATTN_V2_CK = True + logger.info( + "ROCm: using Flash Attention 2 Composable Kernel implementation." 
+ ) + try: - import flash_attn_cuda - except ImportError: - raise ImportError( - "Flash Attention is not installed.\n" - "Use the official Docker image (ghcr.io/predibase/lorax:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" - ) from e - - if not (is_sm75 or is_sm8x or is_sm90): - raise ImportError(f"GPU with CUDA capability {major} {minor} is not supported") from e - logger.warning(f"Unable to use Flash Attention V2: {e}") - HAS_FLASH_ATTN = True - - -def attention( - q, - k, - v, - out, - cu_seqlens, - max_s, - softmax_scale, - window_size_left=-1, -): - if HAS_FLASH_ATTN_V2: + try: + import flash_attn_2_cuda + except ImportError: + architecture_suffix = f"-{SYSTEM}" + raise ImportError( + "Flash Attention V2 is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" + ) + if SYSTEM == "cuda" and not (is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported for " + "Flash Attention V2" + ) + elif SYSTEM == "rocm" and not (is_sm8x or is_sm90 or is_sm94): + raise ImportError( + f"AMD GPU with compute capability {major} {minor} is not supported for " + "Flash Attention V2" + ) + HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda" + HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm" + except ImportError as e: + try: + import flash_attn_cuda + except ImportError: + raise ImportError( + "Flash Attention is not installed.\n" + "Use the official Docker image (ghcr.io/predibase/lorax:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" + ) from e + + if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported" + ) from e + elif SYSTEM == "rocm": + for idx in range(torch.cuda.device_count()): + if "MI210" not in torch.cuda.get_device_name( + idx + ) and "MI250" not in torch.cuda.get_device_name(idx): + raise ImportError( + f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" + ) + + logger.warning(f"Unable to use Flash Attention V2: {e}") + HAS_FLASH_ATTN = True + + +if HAS_FLASH_ATTN_V2_CUDA: + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + causal=True, + ): + if window_size_left <= 0 and window_size_left != -1: + raise ValueError("`window_size_left` must be > 0 or -1") return flash_attn_2_cuda.varlen_fwd( q, k, @@ -62,21 +147,102 @@ def attention( out, cu_seqlens, cu_seqlens, + None, + None, + None, max_s, max_s, 0.0, softmax_scale, False, - True, + causal, window_size_left, 0, False, None, ) - if HAS_FLASH_ATTN: +elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK: + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + causal=True, + ): + if window_size_left <= 0 and window_size_left != -1: + raise ValueError("`window_size_left` must be > 0 or -1") + if window_size_left != -1: + raise ValueError( + f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." + ) + + # RoCm flash API does not take the window_size_left and window_size_right arguments. 
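The import-time checks earlier in this hunk reduce to a small decision table: SYSTEM plus the GPU's (major, minor) compute capability pick flash-attn v2 (CUDA or ROCm CK/Triton), the v1 fallback, or the XPU path, and anything below sm_75 is rejected. A simplified sketch of that selection (illustrative; the real module also depends on which kernel packages actually import):

def select_attention_impl(system: str, major: int, minor: int, use_triton: bool = False) -> str:
    """Rough mirror of the capability gating above (illustrative only)."""
    is_sm75 = (major, minor) == (7, 5)
    is_sm8x = major == 8
    is_sm90 = (major, minor) == (9, 0)
    is_sm94 = (major, minor) == (9, 4)

    if system == "xpu":
        return "ipex varlen_attention"
    if system == "cuda" and (is_sm8x or is_sm90):
        return "flash-attn v2 (CUDA)"
    if system == "rocm" and (is_sm8x or is_sm90 or is_sm94):
        return "flash-attn v2 (ROCm, Triton)" if use_triton else "flash-attn v2 (ROCm, CK)"
    if is_sm75 or is_sm8x or is_sm90:
        return "flash-attn v1"
    raise ImportError(f"compute capability {major}.{minor} is not supported")

print(select_attention_impl("cuda", 8, 6))      # flash-attn v2 (CUDA)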
+ return flash_attn_2_cuda.varlen_fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + causal, + False, + None, + ) + +elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON: + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + causal=True, + ): + output, _ = triton_attention( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + causal, + softmax_scale, + ) + return output + +elif HAS_FLASH_ATTN: + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + ): if window_size_left != -1: - raise NotImplementedError("window_size_left is only available with flash attn v2") + raise NotImplementedError( + "window_size_left is only available with flash attn v2" + ) # Flash attention v1 requires q, k and v to have the same number of heads if k.shape[1] != q.shape[1]: @@ -122,4 +288,5 @@ def attention( None, ) - raise NotImplementedError("flash attention is not installed") +else: + raise NotImplementedError("flash attention is not installed") \ No newline at end of file diff --git a/server/lorax_server/utils/flash_attn_triton.py b/server/lorax_server/utils/flash_attn_triton.py new file mode 100644 index 000000000..3fe322311 --- /dev/null +++ b/server/lorax_server/utils/flash_attn_triton.py @@ -0,0 +1,816 @@ +#!/usr/bin/env python +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao +(https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team, AMD ML Frameworks Triton team + +Features supported: + +1) Fwd with causal masking +2) Any sequence lengths without padding (currently fwd kernel only) +3) Support for different sequence lengths for q and k +4) Nested tensor API currently does not support dropout or bias. 
+ +Not currently supported: + +1) Non power of two head dims + +""" + +import torch +import triton +import triton.language as tl + +torch_dtype: tl.constexpr = torch.float16 + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets( + philox_seed, philox_offset, dropout_p, m, n, stride + ).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) + rng_keep = rng_output > dropout_p + return rng_keep + + +@triton.jit +def load_fn(block_ptr, first, second, pad): + if first and second: + tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) + elif first: + tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad) + elif second: + tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad) + else: + tensor = tl.load(block_ptr) + return tensor + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + actual_seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr, +): + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + k = load_fn( + K_block_ptr, + PADDED_HEAD, + MASK_STEPS and (n_extra_tokens != 0), + "zero", + ) + if PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: # noqa: SIM102 + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps + # if not is_modulo_mn. last step might get wasted but that is okay. + # check if this masking works for that case. 
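The comment block above describes the only place boundary masking is needed: when seqlen_k is not a multiple of BLOCK_N, n_extra_tokens becomes nonzero (the shortfall if seqlen_k < BLOCK_N, the remainder otherwise) and, on the last key block, columns at or beyond seqlen_k are forced to -inf before the softmax. A tiny sketch of that arithmetic with made-up sizes:

BLOCK_N = 64

def boundary_mask(seqlen_k, start_n):
    """n_extra_tokens as computed in attn_fwd, plus which columns of this block are real."""
    if seqlen_k < BLOCK_N:
        n_extra_tokens = BLOCK_N - seqlen_k
    elif seqlen_k % BLOCK_N:
        n_extra_tokens = seqlen_k % BLOCK_N
    else:
        n_extra_tokens = 0
    valid = [start_n + j < seqlen_k for j in range(BLOCK_N)]   # False -> qk set to -inf
    return n_extra_tokens, valid

# seqlen_k = 100: the block starting at 64 has 36 real columns and 28 masked ones
n_extra, valid = boundary_mask(100, 64)
print(n_extra, sum(valid))      # 36 36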
+ if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- + qk += tl.dot(q, k) + if bias_ptr is not None: + bias = load_fn( + bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero" + ) + # While bias is added after multiplying qk with sm_scale, our + # optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. + qk += bias * 1.44269504089 + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = ( + batch_philox_offset + + start_m * BLOCK_M * actual_seqlen_k + + start_n + - BLOCK_N + ) + keep = dropout_mask( + philox_seed, + philox_offset, + dropout_p, + BLOCK_M, + BLOCK_N, + actual_seqlen_k, + ) + if RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty), + ) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + p.to(encoded_softmax_block_ptr.type.element_ty), + ) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance( + encoded_softmax_block_ptr, (0, BLOCK_N) + ) + return acc, l_i, m_i + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_M": 256, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 256, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 3, + "PRE_LOAD_V": True, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 3, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 64, + "waves_per_eu": 4, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 32, + "BLOCK_N": 32, + "waves_per_eu": 4, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + # TODO: This config fails with head_size not pow2 with data mismatches. 
+ # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, + # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config( + { + "BLOCK_M": 16, + "BLOCK_N": 16, + "waves_per_eu": 1, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 1, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + ], + key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"], +) +@triton.jit +def attn_fwd( + Q, + K, + V, + bias, + sm_scale, + L, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_bz, + stride_bh, + stride_bm, + stride_bn, + cu_seqlens_q, + cu_seqlens_k, + dropout_p, + philox_seed, + philox_offset_base, + encoded_softmax, + HQ: tl.constexpr, + HK: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, + VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, +): + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if IS_CAUSAL: + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn + # matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N + ) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is + # part of the blocks that are all 0. We exit early. 
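+    # Illustration: with seqlen_q = 1024, seqlen_k = 128, BLOCK_M = BLOCK_N = 64
+    # and start_m = 0, n_blocks_seqlen = cdiv(64 + 128 - 1024, 64) is negative,
+    # so n_blocks <= 0 and this workgroup takes the early-exit path below instead
+    # of running any GEMMs.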
+ if n_blocks <= 0: + o_offset = ( + off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + ) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + # We still need to write 0s to the result + # tl.store(O_block_ptr, + # acc.to(Out.type.element_ty), boundary_check=(0,1)) + # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + # + offs_m + # We store inf to LSE, not -inf because in the bwd pass, + # we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 + # for these masked blocks. + # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + # tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here? + return + + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + if GROUP_SIZE != 1: + off_h_k = off_h_q // GROUP_SIZE + else: + off_h_k = off_h_q + + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL + + # Compute pointers for all the tensors used in this kernel. + q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + if BIAS_TYPE != 0: + bias_ptr = tl.make_block_ptr( + base=bias + off_h_q * stride_bh, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + bias_ptr = None + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k + ) + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. + # In this case, we return an invalid pointer so indicate the mask is not i + # valid. + # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. 
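+    # (The encoded-softmax buffer below is addressed as one (seqlen_q, seqlen_k)
+    # matrix per query head; as the TODO notes, the batch index off_z is not yet
+    # folded into the base offset.)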
+ if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.make_block_ptr( + base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), + strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + encoded_softmax_block_ptr = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero") + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional + # block. In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, + block_max, + 0, + 0, + 0, + bias_ptr, + # IS_CAUSAL, .... + False, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + False, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + PADDED_HEAD, + ) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if masked_blocks > 0: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance( + encoded_softmax_block_ptr, (0, n_full_blocks) + ) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... 
+ PRE_LOAD_V, + True, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + PADDED_HEAD, + ) + # epilogue + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: # noqa: SIM102 + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full( + (BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32 + ) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last + # few rows. This is only true for the last M block. For others, + # overflow_size will be -ve + # overflow_size = end_m_idx - seqlen_q + # if overflow_size > 0: + # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + # # This is a > check because mask being 0 blocks the store. + # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + # else: + # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # Need boundary check on this to make sure the padding from the + # Q and KV tensors in both dims are not part of what we store back. + # TODO: Do the boundary check optionally. 
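+    # boundary_check=(0, 1) clips the store to the logical
+    # (seqlen_q, ACTUAL_BLOCK_DMODEL) shape, so rows past the end of the query
+    # sequence and the zero-padded head-dimension columns are never written back.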
+ tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + + +def check_args( + q, + k, + v, + o, + varlen=True, + max_seqlens=None, + cu_seqlens_q=None, + cu_seqlens_k=None, +): + assert q.dim() == k.dim() and q.dim() == v.dim() + if varlen: + assert q.dim() == 3 + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + assert cu_seqlens_q is not None + assert cu_seqlens_k is not None + assert len(cu_seqlens_q) == len(cu_seqlens_k) + else: + assert q.dim() == 4 + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert max_seqlens > 0 + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + # TODO: Fix assert to check head size <=256 once supported + assert head_size <= 128 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + o, + cu_seqlens_q, + cu_seqlens_k, + max_seqlens_q, + max_seqlens_k, + causal=False, + sm_scale=1.0, + bias=None, + ): + if o is None: + o = torch.empty_like(q, dtype=v.dtype) + + check_args( + q, + k, + v, + o, + varlen=True, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + if True: # varlen + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + batch = len(cu_seqlens_q) - 1 + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + else: + batch, seqlen_q, nheads_q, head_size = q.shape + _, seqlen_k, nheads_k, _ = k.shape + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + + # Get closest power of 2 over or equal to 32. + padded_d_model = 1 << (head_size - 1).bit_length() + padded_d_model = max(padded_d_model, 16) + + grid = lambda META: ( + triton.cdiv(max_seqlens_q, META["BLOCK_M"]), + nheads_q, + batch, + ) + + encoded_softmax = None + + # Seed the RNG so we get reproducible results for testing. 
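+        # Note: the fixed Philox seed/offset below only matter when dropout is
+        # enabled; this wrapper launches the kernel with dropout_p=0.0 and
+        # ENABLE_DROPOUT=False, so they are effectively unused here.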
+ philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if bias is not None: + bias_strides = ( + bias.stride(0), + bias.stride(1), + bias.stride(2), + bias.stride(3), + ) + else: + bias_strides = (0, 0, 0, 0) + + attn_fwd[grid]( + q, + k, + v, + bias, + sm_scale, + None, + o, + *q_strides, + *k_strides, + *v_strides, + *o_strides, + *bias_strides, + cu_seqlens_q, + cu_seqlens_k, + dropout_p=0.0, + philox_seed=philox_seed, + philox_offset_base=philox_offset, + encoded_softmax=encoded_softmax, + HQ=nheads_q, + HK=nheads_k, + ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=max_seqlens_q, + MAX_SEQLENS_K=max_seqlens_k, + IS_CAUSAL=causal, + VARLEN=True, + BLOCK_DMODEL=padded_d_model, + BIAS_TYPE=0 if bias is None else 1, + ENABLE_DROPOUT=False, + RETURN_ENCODED_SOFTMAX=False, + ) + + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = causal + ctx.dropout_p = 0.0 + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = False + return o, encoded_softmax + + +triton_attention = _attention.apply diff --git a/server/lorax_server/utils/import_utils.py b/server/lorax_server/utils/import_utils.py new file mode 100644 index 000000000..f54987eb1 --- /dev/null +++ b/server/lorax_server/utils/import_utils.py @@ -0,0 +1,50 @@ +import torch + + +def is_xpu_available(): + try: + import intel_extension_for_pytorch + except ImportError: + return False + + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +def get_cuda_free_memory(device, memory_fraction): + total_free_memory, _ = torch.cuda.mem_get_info(device) + total_gpu_memory = torch.cuda.get_device_properties(device).total_memory + free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory) + return free_memory + + +def get_xpu_free_memory(device): + total_gpu_memory = torch.xpu.get_device_properties(device).total_memory + free_memory = int(total_gpu_memory * 0.5) + return free_memory + + +SYSTEM = None +if torch.version.hip is not None: + SYSTEM = "rocm" + empty_cache = torch.cuda.empty_cache + synchronize = torch.cuda.synchronize + get_free_memory = get_cuda_free_memory +elif torch.version.cuda is not None and torch.cuda.is_available(): + SYSTEM = "cuda" + empty_cache = torch.cuda.empty_cache + synchronize = torch.cuda.synchronize + get_free_memory = get_cuda_free_memory +elif is_xpu_available(): + SYSTEM = "xpu" + empty_cache = torch.xpu.empty_cache + synchronize = torch.xpu.synchronize + get_free_memory = get_xpu_free_memory +else: + SYSTEM = "cpu" + + def noop(*args, **kwargs): + pass + + empty_cache = noop + synchronize = noop + get_free_memory = noop diff --git a/server/lorax_server/utils/paged_attention.py b/server/lorax_server/utils/paged_attention.py new file mode 100644 index 000000000..ad18fc65c --- /dev/null +++ b/server/lorax_server/utils/paged_attention.py @@ -0,0 +1,137 @@ +import torch +from lorax_server.utils.import_utils import SYSTEM + +_PARTITION_SIZE = 512 + +if SYSTEM == "xpu": + import intel_extension_for_pytorch as ipex +else: + try: + from vllm._C import cache_ops + from vllm._C import ops + except Exception as e: + raise ImportError( + f"Could not import vllm paged attention. Make sure your installation is correct. 
Complete error: {e}" + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slots: torch.Tensor, +): + if SYSTEM == "xpu": + ipex.llm.modules.PagedAttention.reshape_and_cache( + key, value, key_cache, value_cache, slots + ) + else: + cache_ops.reshape_and_cache( + key, value, key_cache, value_cache, slots, "auto", 1.0 + ) + + +def attention( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + kv_head_mapping: torch.Tensor, + softmax_scale: float, + block_tables: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, +): + # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py + # Copyright 2023 The vLLM team. All rights + # reserved. + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + # + + # value_cache => [num_blocks, num_heads, head_size, block_size] + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE + if SYSTEM == "xpu": + query = query.contiguous() + return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( + out, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + ) + + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) + if use_v1: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + else: + # Run PagedAttention V2. 
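+        # Roughly: V2 splits each sequence's KV range into
+        # ceil(max_s / _PARTITION_SIZE) partitions, accumulates per-partition
+        # results in tmp_output together with per-partition softmax statistics
+        # (exp_sums, max_logits), and reduces them into `out` inside the kernel.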
+ assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=out.dtype, + device=out.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=out.device, + ) + max_logits = torch.empty_like(exp_sums) + + ops.paged_attention_v2( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) From e446812ab9c43eaab4bbeb64d7956e87d49dc58d Mon Sep 17 00:00:00 2001 From: flozi00 Date: Tue, 21 May 2024 14:18:11 +0200 Subject: [PATCH 04/32] more layers --- server/layers/__init__.py | 14 + server/layers/awq/conversion_utils.py | 97 +++ server/layers/awq/quantize/qmodule.py | 50 ++ server/layers/bnb.py | 106 +++ server/layers/conv.py | 41 + server/layers/eetq.py | 25 + server/layers/fp8.py | 43 + server/layers/gptq/__init__.py | 39 + server/layers/gptq/custom_autotune.py | 261 ++++++ server/layers/gptq/exllama.py | 132 +++ server/layers/gptq/exllamav2.py | 234 ++++++ server/layers/gptq/quant_linear.py | 356 ++++++++ server/layers/gptq/quantize.py | 1002 +++++++++++++++++++++++ server/layers/hqq.py | 33 + server/layers/layernorm.py | 185 +++++ server/layers/linear.py | 171 ++++ server/layers/rotary.py | 419 ++++++++++ server/layers/speculative.py | 52 ++ server/layers/tensor_parallel.py | 187 +++++ server/lorax_server/utils/layers.py | 447 +--------- server/lorax_server/utils/paged_attn.py | 99 --- 21 files changed, 3449 insertions(+), 544 deletions(-) create mode 100644 server/layers/__init__.py create mode 100644 server/layers/awq/conversion_utils.py create mode 100644 server/layers/awq/quantize/qmodule.py create mode 100644 server/layers/bnb.py create mode 100644 server/layers/conv.py create mode 100644 server/layers/eetq.py create mode 100644 server/layers/fp8.py create mode 100644 server/layers/gptq/__init__.py create mode 100644 server/layers/gptq/custom_autotune.py create mode 100644 server/layers/gptq/exllama.py create mode 100644 server/layers/gptq/exllamav2.py create mode 100644 server/layers/gptq/quant_linear.py create mode 100644 server/layers/gptq/quantize.py create mode 100644 server/layers/hqq.py create mode 100644 server/layers/layernorm.py create mode 100644 server/layers/linear.py create mode 100644 server/layers/rotary.py create mode 100644 server/layers/speculative.py create mode 100644 server/layers/tensor_parallel.py delete mode 100644 server/lorax_server/utils/paged_attn.py diff --git a/server/layers/__init__.py b/server/layers/__init__.py new file mode 100644 index 000000000..6428366a3 --- /dev/null +++ b/server/layers/__init__.py @@ -0,0 +1,14 @@ +from lorax_server.layers.tensor_parallel import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, + TensorParallelEmbedding, +) +from lorax_server.layers.linear import ( + get_linear, + FastLinear, +) +from lorax_server.layers.speculative import SpeculativeHead + +# Just to add the `load` methods. 
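+# (These imports run the modules' top-level code, which attaches the `load`
+# helpers to the corresponding torch.nn classes; e.g. conv.py assigns
+# torch.nn.Conv2d.load / load_no_bias.)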
+from lorax_server.layers.layernorm import load_layer_norm +from lorax_server.layers.conv import load_conv2d diff --git a/server/layers/awq/conversion_utils.py b/server/layers/awq/conversion_utils.py new file mode 100644 index 000000000..b19eafbbe --- /dev/null +++ b/server/layers/awq/conversion_utils.py @@ -0,0 +1,97 @@ +import torch +from typing import List + + +AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7] +REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] + + +def pack(imatrix: torch.Tensor, direction: str = "column"): + """ + Packs a 4-bit integer matrix into a packed 32-bit integer matrix. + Args: + imatrix (torch.Tensor): matrix of integers + direction (str): direction of packing, either "column" or "row" + Returns: + qmatrix (torch.Tensor): packed matrix of integers + """ + shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device) + + imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow + + if direction == "column": + imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4)) + qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1) + + elif direction == "row": + imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1) + qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1) + + qmatrix = qmatrix.to(torch.int32) + + return qmatrix + + +def unpack(qmatrix: torch.Tensor, direction: str = "column"): + """ + Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix. + Args: + qmatrix (torch.Tensor): matrix of packed integers + direction (str): direction of unpacking, either "column" or "row" + Returns: + imatrix (torch.Tensor): matrix of integers + """ + shifts = torch.arange(0, 32, 4, device=qmatrix.device) + + if direction == "column": + imatrix = torch.bitwise_right_shift( + qmatrix[:, :, None], shifts[None, None, :] + ).view(qmatrix.shape[0], -1) + + elif direction == "row": + imatrix = torch.bitwise_right_shift( + qmatrix[:, None, :], shifts[None, :, None] + ).view(-1, qmatrix.shape[-1]) + + imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow + + return imatrix + + +def apply_order( + imatrix: torch.Tensor, + direction: str = "column", + order: List[int] = AWQ_PACK_ORDER, +): + """ + Applies the order to a 4-bit integer matrix. 
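+    (REVERSE_AWQ_PACK_ORDER is the inverse permutation of AWQ_PACK_ORDER, so
+    applying it in fast_awq_to_gptq below undoes AWQ's interleaved nibble
+    packing and restores sequential order before repacking.)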
+ Args: + imatrix (torch.Tensor): matrix of integers + direction (str): direction of applying order, either "column" or "row" + order (List[int]): order to apply, default is AWQ_PACK_ORDER + Returns: + imatrix (torch.Tensor): matrix of integers + """ + if direction == "column": + imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape) + elif direction == "row": + imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape) + + return imatrix + + +def fast_awq_to_gptq(qweight, qzeros): + # awq uses column packing for both weights and zeros + izeros = unpack(qzeros, direction="column") + iweights = unpack(qweight, direction="column") + + # Reverse the order of the iweight and izeros tensors + izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER) + iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER) + # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros) + izeros = izeros - 1 + # exllama uses row packing for weights and column packing for zeros + qzeros = pack(izeros, direction="column") + qweight = pack(iweights, direction="row") + + return qweight, qzeros diff --git a/server/layers/awq/quantize/qmodule.py b/server/layers/awq/quantize/qmodule.py new file mode 100644 index 000000000..ca8caf508 --- /dev/null +++ b/server/layers/awq/quantize/qmodule.py @@ -0,0 +1,50 @@ +# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py + +import math +import torch +import torch.nn as nn +import awq_inference_engine # with CUDA kernels + + +# class ScaledActivation(nn.Module): +# def __init__(self, module, scales): +# super().__init__() +# self.act = module +# self.scales = nn.Parameter(scales.data) +# +# def forward(self, x): +# return self.act(x) / self.scales.view(1, 1, -1).to(x.device) + + +class WQLinear(nn.Module): + def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias): + super().__init__() + + if w_bit not in [4]: + raise NotImplementedError("Only 4-bit are supported for now.") + + self.in_features = qweight.shape[0] + self.out_features = qweight.shape[1] * 32 // w_bit + + self.w_bit = w_bit + self.group_size = group_size if group_size != -1 else self.in_features + # quick sanity check (make sure aligment) + assert self.in_features % self.group_size == 0 + assert self.out_features % (32 // self.w_bit) == 0 + + self.qweight = qweight + self.qzeros = qzeros + self.scales = scales + if bias: + self.bias = bias + else: + self.bias = None + + @torch.no_grad() + def forward(self, x): + out_shape = x.shape[:-1] + (self.out_features,) + out = awq_inference_engine.gemm_forward_cuda( + x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8 + ) + out = out + self.bias if self.bias is not None else out + return out.reshape(out_shape) diff --git a/server/layers/bnb.py b/server/layers/bnb.py new file mode 100644 index 000000000..ca39919ce --- /dev/null +++ b/server/layers/bnb.py @@ -0,0 +1,106 @@ +import torch +from loguru import logger +from functools import lru_cache +import bitsandbytes as bnb +from bitsandbytes.nn import Int8Params, Params4bit + + +@lru_cache(1) +def warn_deprecate_bnb(): + logger.warning( + "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce" + ) + + +class Linear8bitLt(torch.nn.Module): + def __init__( + self, + weight, + bias, + has_fp16_weights=True, + memory_efficient_backward=False, + threshold=0.0, + index=None, + ): + super().__init__() + assert ( + 
not memory_efficient_backward + ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" + self.state = bnb.MatmulLtState() + self.index = index + + # Necessary for stacked layers + self.state.threshold = threshold + self.state.has_fp16_weights = has_fp16_weights + self.state.memory_efficient_backward = memory_efficient_backward + if threshold > 0.0 and not has_fp16_weights: + self.state.use_pool = True + + self.weight = Int8Params( + weight.data, + has_fp16_weights=has_fp16_weights, + requires_grad=has_fp16_weights, + ) + self.weight.cuda(weight.device) + self.bias = bias + + def init_8bit_state(self): + self.state.CB = self.weight.CB + self.state.SCB = self.weight.SCB + self.weight.CB = None + self.weight.SCB = None + + def forward(self, x: torch.Tensor): + self.state.is_training = self.training + if self.weight.CB is not None: + self.init_8bit_state() + + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) + + if not self.state.has_fp16_weights: + if self.state.CB is not None and self.state.CxB is not None: + # we converted 8-bit row major to turing/ampere format in the first inference pass + # we no longer need the row-major weight + del self.state.CB + self.weight.data = self.state.CxB + return out + + +class Linear4bit(torch.nn.Module): + def __init__(self, weight, bias, quant_type): + super().__init__() + self.weight = Params4bit( + weight.data, + requires_grad=False, + compress_statistics=True, + quant_type=quant_type, + ) + self.compute_dtype = None + self.weight.cuda(weight.device) + self.bias = bias + + def forward(self, x: torch.Tensor): + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + if getattr(self.weight, "quant_state", None) is None: + print( + "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first." 
+ ) + inp_dtype = x.dtype + if self.compute_dtype is not None: + x = x.to(self.compute_dtype) + + bias = None if self.bias is None else self.bias.to(self.compute_dtype) + out = bnb.matmul_4bit( + x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state + ) + + out = out.to(inp_dtype) + + return out diff --git a/server/layers/conv.py b/server/layers/conv.py new file mode 100644 index 000000000..7fb18ab3f --- /dev/null +++ b/server/layers/conv.py @@ -0,0 +1,41 @@ +from accelerate import init_empty_weights +import torch + + +@classmethod +def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride): + weight = weights.get_tensor(f"{prefix}.weight") + bias = weights.get_tensor(f"{prefix}.bias") + with init_empty_weights(): + conv2d = cls( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + ) + + conv2d.weight = torch.nn.Parameter(weight) + conv2d.bias = torch.nn.Parameter(bias) + return conv2d + + +@classmethod +def load_conv2d_no_bias( + cls, prefix, weights, in_channels, out_channels, kernel_size, stride +): + weight = weights.get_tensor(f"{prefix}.weight") + with init_empty_weights(): + conv2d = cls( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + ) + + conv2d.weight = torch.nn.Parameter(weight) + conv2d.bias = None + return conv2d + + +torch.nn.Conv2d.load = load_conv2d +torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias diff --git a/server/layers/eetq.py b/server/layers/eetq.py new file mode 100644 index 000000000..fd22b5c67 --- /dev/null +++ b/server/layers/eetq.py @@ -0,0 +1,25 @@ +import torch +from EETQ import quant_weights, w8_a16_gemm + + +class EETQLinear(torch.nn.Module): + def __init__( + self, + weight, + bias, + ) -> None: + super().__init__() + device = weight.device + if weight.dtype != torch.float16: + weight = weight.to(dtype=torch.float16) + weight = torch.t(weight).contiguous().cpu() + weight, scale = quant_weights(weight, torch.int8, False) + + self.weight = weight.cuda(device) + self.scale = scale.cuda(device) + self.bias = bias.cuda(device) if bias is not None else None + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = w8_a16_gemm(input, self.weight, self.scale) + output = output + self.bias if self.bias is not None else output + return output diff --git a/server/layers/fp8.py b/server/layers/fp8.py new file mode 100644 index 000000000..dd61d0819 --- /dev/null +++ b/server/layers/fp8.py @@ -0,0 +1,43 @@ +import torch + + +def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): + device = weight.device + # weight, scale = quant_weights(weight, torch.int8, False) + finfo = torch.finfo(qdtype) + # Calculate the scale as dtype max divided by absmax + scale = finfo.max / weight.abs().max().clamp(min=1e-12) + # scale and clamp the tensor to bring it to + # the representative range of float8 data type + # (as default cast is unsaturated) + qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max) + # Return both float8 data and the inverse scale (as float), + # as both required as inputs to torch._scaled_mm + qweight = qweight.to(qdtype) + scale = scale.float().reciprocal() + return qweight, scale + + +class Fp8Linear(torch.nn.Module): + def __init__( + self, + weight, + bias, + ) -> None: + super().__init__() + self.dtype = weight.dtype + self.qweight, self.scale = fp8_quantize(weight) + + self.bias = bias if bias is not None else None + + def forward(self, input: torch.Tensor) -> torch.Tensor: + qinput, scale = 
fp8_quantize(input) + output, _ = torch._scaled_mm( + qinput, + self.qweight.t(), + out_dtype=self.dtype, + scale_a=scale, + scale_b=self.scale, + bias=self.bias, + ) + return output diff --git a/server/layers/gptq/__init__.py b/server/layers/gptq/__init__.py new file mode 100644 index 000000000..1c46f4932 --- /dev/null +++ b/server/layers/gptq/__init__.py @@ -0,0 +1,39 @@ +import os +import torch +from text_generation_server.utils.import_utils import ( + SYSTEM, +) + +try: + major, _minor = torch.cuda.get_device_capability() +except Exception: + major = 1 + +HAS_EXLLAMA = False +CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm" +V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" +if os.getenv("DISABLE_EXLLAMA") == "True": + HAS_EXLLAMA = False +elif CAN_EXLLAMA: + try: + if V2: + from text_generation_server.layers.gptq.exllamav2 import ( + QuantLinear as ExllamaQuantLinear, + create_exllama_buffers, + set_device, + ) + + HAS_EXLLAMA = "2" + else: + from text_generation_server.layers.gptq.exllama import ( + Ex4bitLinear as ExllamaQuantLinear, + create_exllama_buffers, + set_device, + ) + + HAS_EXLLAMA = "1" + + except ImportError: + pass + +from text_generation_server.layers.gptq.quant_linear import QuantLinear diff --git a/server/layers/gptq/custom_autotune.py b/server/layers/gptq/custom_autotune.py new file mode 100644 index 000000000..1eb40f1ed --- /dev/null +++ b/server/layers/gptq/custom_autotune.py @@ -0,0 +1,261 @@ +# https://github.com/fpgaminer/GPTQ-triton +""" +Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. +""" + +import builtins +import math +import time +from typing import Dict + +import triton + + +class Autotuner(triton.KernelInterface): + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + prune_configs_by: Dict = None, + nearest_power_of_two: bool = False, + ): + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. + 'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results + """ + if not configs: + self.configs = [triton.Config({}, num_warps=4, num_stages=2)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.nearest_power_of_two = nearest_power_of_two + self.cache = {} + # hook to reset all required tensor to zeros before relaunching a kernel + self.hook = lambda args: 0 + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + + def _hook(args): + for i in self.reset_idx: + args[i].zero_() + + self.hook = _hook + self.arg_names = arg_names + # prune configs + if prune_configs_by: + perf_model, top_k = ( + prune_configs_by["perf_model"], + prune_configs_by["top_k"], + ) + if "early_config_prune" in prune_configs_by: + early_config_prune = prune_configs_by["early_config_prune"] + else: + perf_model, top_k, early_config_prune = None, None, None + self.perf_model, self.configs_top_k = perf_model, top_k + self.early_config_prune = early_config_prune + self.fn = fn + + def _bench(self, *args, config, **meta): + # check for conflicts, i.e. 
meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols." + ) + # augment meta-parameters with tunable ones + current = dict(meta, **config.kwargs) + + def kernel_call(): + if config.pre_hook: + config.pre_hook(self.nargs) + self.hook(args) + self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **current, + ) + + try: + # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses + # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default + return triton.testing.do_bench( + kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40 + ) + except triton.OutOfResources: + return (float("inf"), float("inf"), float("inf")) + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + if len(self.configs) > 1: + key = tuple(args[i] for i in self.key_idx) + + # This reduces the amount of autotuning by rounding the keys to the nearest power of two + # In my testing this gives decent results, and greatly reduces the amount of tuning required + if self.nearest_power_of_two: + key = tuple([2 ** int(math.log2(x) + 0.5) for x in key]) + + if key not in self.cache: + # prune configs + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = { + config: self._bench(*args, config=config, **kwargs) + for config in pruned_configs + } + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.hook(args) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if config.pre_hook is not None: + config.pre_hook(self.nargs) + return self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.kwargs, + num_stages=config.num_stages, + num_warps=config.num_warps, + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[ + :top_k + ] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + for config in self.prune_configs(kwargs): + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + self.nargs = None + + +def autotune( + configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False +): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + .. highlight:: python + .. 
code-block:: python + @triton.autotune(configs=[ + triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple time. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + reset the value of the provided tensor to `zero` before running any configuration. + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + """ + + def decorator(fn): + return Autotuner( + fn, + fn.arg_names, + configs, + key, + reset_to_zero, + prune_configs_by, + nearest_power_of_two, + ) + + return decorator + + +def matmul248_kernel_config_pruner(configs, nargs): + """ + The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. 
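+    For example, with M = 12 every candidate config has BLOCK_SIZE_M clamped to
+    16 (the minimum tile size used here), and configs that become identical
+    after clamping are emitted only once.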
+ """ + m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16) + n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16) + k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16) + + used = set() + for config in configs: + block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"]) + block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"]) + block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"]) + group_size_m = config.kwargs["GROUP_SIZE_M"] + + if ( + block_size_m, + block_size_n, + block_size_k, + group_size_m, + config.num_stages, + config.num_warps, + ) in used: + continue + + used.add( + ( + block_size_m, + block_size_n, + block_size_k, + group_size_m, + config.num_stages, + config.num_warps, + ) + ) + yield triton.Config( + { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + }, + num_stages=config.num_stages, + num_warps=config.num_warps, + ) diff --git a/server/layers/gptq/exllama.py b/server/layers/gptq/exllama.py new file mode 100644 index 000000000..32f817dba --- /dev/null +++ b/server/layers/gptq/exllama.py @@ -0,0 +1,132 @@ +import torch +from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params + +# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension +none_tensor = torch.empty((1, 1), device="meta") + + +def ext_make_q4(qweight, qzeros, scales, g_idx, device): + """Construct Q4Matrix, return handle""" + return make_q4( + qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device + ) + + +def ext_q4_matmul(x, q4, q4_width): + """Matrix multiplication, returns x @ q4""" + outshape = x.shape[:-1] + (q4_width,) + x = x.view(-1, x.shape[-1]) + output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device) + + q4_matmul(x, q4, output) + + return output.view(outshape) + + +MAX_DQ = 1 +MAX_INNER = 1 +ACT_ORDER = False +DEVICE = None + +TEMP_STATE = None +TEMP_DQ = None + + +def set_device(device): + global DEVICE + DEVICE = device + + +def create_exllama_buffers(max_total_tokens: int): + global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ + + assert DEVICE is not None, "call set_device first" + + if not ACT_ORDER: + max_total_tokens = 1 + + # This temp_state buffer is required to reorder X in the act-order case. + temp_state = torch.zeros( + (max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE + ) + temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE) + + # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. 
+ prepare_buffers(DEVICE, temp_state, temp_dq) + + matmul_recons_thd = 8 + matmul_fused_remap = False + matmul_no_half2 = False + set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) + + TEMP_STATE, TEMP_DQ = temp_state, temp_dq + + +class Ex4bitLinear(torch.nn.Module): + """Linear layer implementation with per-group 4-bit quantization of the weights""" + + def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): + super().__init__() + global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE + assert bits == 4 + + self.device = qweight.device + self.qweight = qweight + self.qzeros = qzeros + self.scales = scales + self.g_idx = g_idx.cpu() if g_idx is not None else None + self.bias = bias if bias is not None else None + + if self.g_idx is not None and ( + (self.g_idx == 0).all() + or torch.equal( + g_idx.cpu(), + torch.tensor( + [i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32 + ), + ) + ): + self.empty_g_idx = True + self.g_idx = None + + assert self.device.type == "cuda" + assert self.device.index is not None + + self.q4 = ext_make_q4( + self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index + ) + + self.height = qweight.shape[0] * 8 + self.width = qweight.shape[1] + + # Infer groupsize from height of qzeros + self.groupsize = None + if self.qzeros.shape[0] > 1: + self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0]) + + if self.groupsize is not None: + assert groupsize == self.groupsize + + # Handle act-order matrix + if self.g_idx is not None: + if self.groupsize is None: + raise ValueError("Found group index but no groupsize. What do?") + self.act_order = True + else: + self.act_order = False + + DEVICE = self.qweight.device + + MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8) + + if self.act_order: + MAX_INNER = max(MAX_INNER, self.height, self.width) + + ACT_ORDER = True + + def forward(self, x): + out = ext_q4_matmul(x, self.q4, self.width) + + if self.bias is not None: + out.add_(self.bias) + return out diff --git a/server/layers/gptq/exllamav2.py b/server/layers/gptq/exllamav2.py new file mode 100644 index 000000000..01acbd4dc --- /dev/null +++ b/server/layers/gptq/exllamav2.py @@ -0,0 +1,234 @@ +# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2 + +import torch +import torch.nn as nn + +from loguru import logger + +try: + from exllamav2_kernels import make_q_matrix, gemm_half_q_half +except ImportError: + logger.error("exllamav2_kernels not installed.") + raise + +# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension +none_tensor = torch.empty((1, 1), device="meta") + + +def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): + """Matrix multiplication, returns x @ q4""" + output_shape = x.shape[:-1] + (q4_width,) + x = x.view(-1, x.shape[-1]) + output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device) + gemm_half_q_half(x, q_handle, output, force_cuda) + return output.view(output_shape) + + +# Group map needed for irregular group sizes + + +def make_group_map(q_groups, num_qrows): + + gr = q_groups.tolist() + group_map = [] + num_groups = len(gr) // 2 + + for i in range(num_groups): + bits = gr[i * 2] + if i < num_groups - 1: + qrows = gr[i * 2 + 3] - gr[i * 2 + 1] + else: + qrows = num_qrows - gr[i * 2 + 1] + rows = qrows * 32 // bits + for j in range(rows): + group_map += [i] + group_map += [rows - j] + + return torch.tensor(group_map, dtype=torch.short, device=q_groups.device) + + +# Create Q matrix + + +def 
ext_make_q_matrix(w: dict, temp_dq, key: str = None): + """ + Create Q matrix + """ + # EXL2 + # won't work as the moment because the tensors are not the same. + if "q_weight" in w: + w["q_scale_max"] /= 256 + w["q_perm"] = w["q_perm"].short() + w["q_invperm"] = w["q_invperm"].short() + + if "q_group_map" not in w: + w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0]) + + return make_q_matrix( + w["q_weight"], + w["q_perm"], + w["q_invperm"], + w["q_scale"], + w["q_scale_max"], + w["q_groups"], + w["q_group_map"], + none_tensor, + none_tensor, + none_tensor, + temp_dq, + ) + # GPTQ + elif "qweight" in w: + if w["scales"].dtype == torch.float: + w["scales"] = w["scales"].half() + + # GPTQ with g_idx (act_order) + if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item(): + w["q_perm"] = torch.empty( + (w["qweight"].shape[0] * 8,), + dtype=torch.short, + device=w["qweight"].device, + ) + w["q_invperm"] = torch.empty_like(w["q_perm"]) + # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx. + return make_q_matrix( + w["qweight"], + w["q_perm"], + w["q_invperm"], + none_tensor, + none_tensor, + none_tensor, + none_tensor, + w["qzeros"], + w["scales"], + w["g_idx"].cpu(), + temp_dq, + ) + # GPTQ without g_idx + else: + return make_q_matrix( + w["qweight"], + none_tensor, + none_tensor, + none_tensor, + none_tensor, + none_tensor, + none_tensor, + w["qzeros"], + w["scales"], + none_tensor, + temp_dq, + ) + else: + RuntimeError("Cannot create handle") + + +DEVICE = None +FIXED_BYTES = 0 +LAYERS = [] + + +def set_device(device): + global DEVICE + DEVICE = device + + +def create_exllama_buffers(max_total_tokens: int): + global FIXED_BYTES, LAYERS, DEVICE + temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES) + + for layer in LAYERS: + layer.post_init(temp_dq) + + +class QuantLinear(nn.Module): + QUANT_TYPE = "exllamav2" + + """Linear layer implementation with per-group 4-bit quantization of the weights""" + + # def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs): + def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): + super().__init__() + if bits != 4: + raise ValueError( + f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization." + ) + self.q_handle = None + self.q_tensors = None + self.bits = bits + self.maxq = 2**self.bits - 1 + self.infeatures = qweight.shape[0] // self.bits * 32 + self.outfeatures = qweight.shape[1] + self.padding = -self.outfeatures % 32 + self.outfeatures = self.outfeatures + self.padding + + self.device = qweight.device + self.qweight = qweight + self.qzeros = qzeros + self.scales = scales + self.g_idx = g_idx + self.bias = bias if bias is not None else None + self.group_size = groupsize + + global FIXED_BYTES, LAYERS + FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed()) + LAYERS.append(self) + + def post_init(self, temp_dq): + assert self.qweight.device.type == "cuda" + assert self.qweight.device.index is not None + self.q_tensors = { + "qweight": self.qweight, + "qzeros": self.qzeros, + "scales": self.scales, + "g_idx": self.g_idx, + } + temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) + + # We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us, + # and `Memory access fault by GPU node-2` will EAT you. 
+ self.temp_dq = temp_dq + self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq) + + def forward(self, x, force_cuda=False): + output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda) + + if self.bias is not None: + output.add_(self.bias) + return output + + def temp_dq_size(self): + return self.infeatures * self.outfeatures * 2 + 128 + + def temp_fwd_size(self, max_input_len, max_batch_size): + return self.outfeatures * max_input_len * max_batch_size * 4 + 128 + + def scratch_space_fixed(self, max_input_len=4096, max_batch_size=16): + return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size) + + +class ExLlamaV2DeviceTensors: + + device_idx: int + scratch_bytes: int + scratch_idx: int + scratch: torch.tensor = None + + def __init__(self, device, scratch_bytes): + self.device = device + self.scratch_bytes = scratch_bytes + + def prepare(self): + self.scratch = torch.empty( + (self.scratch_bytes // 2,), dtype=torch.half, device=self.device + ) + + def get_scratch_slice(self, size_bytes): + + if self.scratch is None: + self.prepare() + + size_bytes = ((size_bytes + 127) // 128) * 128 + size_half = size_bytes // 2 + scratch_slice = self.scratch.narrow(0, 0, size_half) + return scratch_slice diff --git a/server/layers/gptq/quant_linear.py b/server/layers/gptq/quant_linear.py new file mode 100644 index 000000000..b52ceb0f7 --- /dev/null +++ b/server/layers/gptq/quant_linear.py @@ -0,0 +1,356 @@ +import math +import numpy as np +import torch +import torch.nn as nn +from torch.cuda.amp import custom_fwd + +import triton +import triton.language as tl +from . import custom_autotune + + +# code based https://github.com/fpgaminer/GPTQ-triton +@custom_autotune.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=4, + ), + ], + key=["M", "N", "K"], + nearest_power_of_two=True, + prune_configs_by={ + "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, + "perf_model": None, + "top_k": None, + }, +) +@triton.jit +def matmul_248_kernel( + a_ptr, + b_ptr, + c_ptr, + scales_ptr, + zeros_ptr, + g_ptr, + M, + N, + K, + bits, + maxq, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_scales, + stride_zeros, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + Compute the matrix multiplication C = 
A x B. + A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = offs_am[:, None] < M + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + ( + (offs_k[:, None] // infearure_per_bits) * stride_bk + + offs_bn[None, :] * stride_bn + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + offs_bn[None, :] + zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, num_pid_k): + g_idx = tl.load(g_ptrs) + + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load( + scales_ptrs + g_idx[:, None] * stride_scales + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load( + zeros_ptrs + g_idx[:, None] * stride_zeros + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) & maxq # eventually avoid overflow + + a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros) * scales # Scale and shift + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_ptrs += BLOCK_SIZE_K + + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): + with torch.cuda.device(input.device): + output = torch.empty( + (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16 + ) + grid = lambda META: ( + triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) + * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), + ) + matmul_248_kernel[grid]( + input, + qweight, + output, + scales, + qzeros, + g_idx, + input.shape[0], + qweight.shape[1], + input.shape[1], + bits, + maxq, + input.stride(0), + input.stride(1), + qweight.stride(0), + qweight.stride(1), + output.stride(0), + output.stride(1), + scales.stride(0), + qzeros.stride(0), + ) + return output + + +class QuantLinearFunction(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, 
input, qweight, scales, qzeros, g_idx, bits, maxq): + output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq) + return output + + +class QuantLinear(nn.Module): + def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): + super().__init__() + self.register_buffer("qweight", qweight) + self.register_buffer("qzeros", qzeros) + self.register_buffer("scales", scales) + self.register_buffer("g_idx", g_idx) + if bias is not None: + self.register_buffer("bias", bias) + else: + self.bias = None + if bits not in [2, 4, 8]: + raise NotImplementedError("Only 2,4,8 bits are supported.") + self.bits = bits + self.maxq = 2**self.bits - 1 + self.groupsize = groupsize + + self.outfeatures = qweight.shape[1] + self.infeatures = qweight.shape[0] * 32 // bits + + @classmethod + def new(cls, bits, groupsize, infeatures, outfeatures, bias): + if bits not in [2, 4, 8]: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32) + qzeros = torch.zeros( + (math.ceil(infeatures / groupsize), outfeatures // 32 * bits), + dtype=torch.int32, + ) + scales = torch.zeros( + (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16 + ) + g_idx = torch.tensor( + [i // groupsize for i in range(infeatures)], dtype=torch.int32 + ) + if bias: + bias = torch.zeros((outfeatures), dtype=torch.float16) + else: + bias = None + return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize) + + def pack(self, linear, scales, zeros, g_idx=None): + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + intweight.append( + torch.round( + (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) + / self.scales[self.g_idx[idx]] + ).to(torch.int)[:, None] + ) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros( + (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 + ) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32 // self.bits + row += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros -= 1 + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros( + (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32 + ) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32 // self.bits + col += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures,) + out = QuantLinearFunction.apply( + x.reshape(-1, x.shape[-1]), + self.qweight, + self.scales, + self.qzeros, + self.g_idx, + self.bits, + self.maxq, + ) + out = out + self.bias if self.bias is not None else out + return out.reshape(out_shape) diff --git a/server/layers/gptq/quantize.py 
b/server/layers/gptq/quantize.py new file mode 100644 index 000000000..ca113d8fe --- /dev/null +++ b/server/layers/gptq/quantize.py @@ -0,0 +1,1002 @@ +import time +import torch.nn as nn +import math +import json +import os +import torch +import transformers + +from texttable import Texttable +from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer +from huggingface_hub import HfApi +from accelerate import init_empty_weights +from text_generation_server.utils import initialize_torch_distributed, Weights +from text_generation_server.utils.hub import weight_files +from text_generation_server.utils.gptq.quant_linear import QuantLinear +from loguru import logger +from typing import Optional + +DEV = torch.device("cuda:0") + + +class Quantizer(nn.Module): + def __init__(self, shape=1): + super(Quantizer, self).__init__() + self.register_buffer("maxq", torch.tensor(0)) + self.register_buffer("scale", torch.zeros(shape)) + self.register_buffer("zero", torch.zeros(shape)) + + def configure( + self, + bits, + perchannel=False, + sym=True, + mse=False, + norm=2.4, + grid=100, + maxshrink=0.8, + trits=False, + ): + self.maxq = torch.tensor(2**bits - 1) + self.perchannel = perchannel + self.sym = sym + self.mse = mse + self.norm = norm + self.grid = grid + self.maxshrink = maxshrink + if trits: + self.maxq = torch.tensor(-1) + self.scale = torch.zeros_like(self.scale) + + def _quantize(self, x, scale, zero, maxq): + if maxq < 0: + return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero + q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) + return scale * (q - zero) + + def find_params(self, x, weight=False): + dev = x.device + self.maxq = self.maxq.to(dev) + + shape = x.shape + if self.perchannel: + if weight: + x = x.flatten(1) + else: + if len(shape) == 4: + x = x.permute([1, 0, 2, 3]) + x = x.flatten(1) + if len(shape) == 3: + x = x.reshape((-1, shape[-1])).t() + if len(shape) == 2: + x = x.t() + else: + x = x.flatten().unsqueeze(0) + + tmp = torch.zeros(x.shape[0], device=dev) + xmin = torch.minimum(x.min(1)[0], tmp) + xmax = torch.maximum(x.max(1)[0], tmp) + + if self.sym: + xmax = torch.maximum(torch.abs(xmin), xmax) + tmp = xmin < 0 + if torch.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + + if self.maxq < 0: + self.scale = xmax + self.zero = xmin + else: + self.scale = (xmax - xmin) / self.maxq + if self.sym: + self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) + else: + self.zero = torch.round(-xmin / self.scale) + + if self.mse: + best = torch.full([x.shape[0]], float("inf"), device=dev) + for i in range(int(self.maxshrink * self.grid)): + p = 1 - i / self.grid + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / self.maxq + zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero + q = self._quantize( + x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq + ) + q -= x + q.abs_() + q.pow_(self.norm) + err = torch.sum(q, 1) + tmp = err < best + if torch.any(tmp): + best[tmp] = err[tmp] + self.scale[tmp] = scale1[tmp] + self.zero[tmp] = zero1[tmp] + if not self.perchannel: + if weight: + tmp = shape[0] + else: + tmp = shape[1] if len(shape) != 3 else shape[2] + self.scale = self.scale.repeat(tmp) + self.zero = self.zero.repeat(tmp) + + if weight: + shape = [-1] + [1] * (len(shape) - 1) + self.scale = self.scale.reshape(shape) + self.zero = self.zero.reshape(shape) + return + if len(shape) == 4: + self.scale = self.scale.reshape((1, -1, 1, 1)) + self.zero = 
self.zero.reshape((1, -1, 1, 1)) + if len(shape) == 3: + self.scale = self.scale.reshape((1, 1, -1)) + self.zero = self.zero.reshape((1, 1, -1)) + if len(shape) == 2: + self.scale = self.scale.unsqueeze(0) + self.zero = self.zero.unsqueeze(0) + + def quantize(self, x): + if self.ready(): + return self._quantize(x, self.scale, self.zero, self.maxq) + + return x + + def enabled(self): + return self.maxq > 0 + + def ready(self): + return torch.all(self.scale != 0) + + +class GPTQ: + def __init__(self, layer, observe=False): + self.layer = layer + self.dev = self.layer.weight.device + W = layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + self.rows = W.shape[0] + self.columns = W.shape[1] + self.H = torch.zeros((self.columns, self.columns), device=self.dev) + self.nsamples = 0 + self.quantizer = Quantizer() + self.observe = observe + + def add_batch(self, inp, out): + # Hessian H = 2 X XT + λ I + if self.observe: + self.inp1 = inp + self.out1 = out + else: + self.inp1 = None + self.out1 = None + + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if isinstance(self.layer, nn.Linear) or isinstance( + self.layer, transformers.Conv1D + ): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + if isinstance(self.layer, nn.Conv2d): + unfold = nn.Unfold( + self.layer.kernel_size, + dilation=self.layer.dilation, + padding=self.layer.padding, + stride=self.layer.stride, + ) + inp = unfold(inp) + inp = inp.permute([1, 0, 2]) + inp = inp.flatten(1) + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + # inp = inp.float() + inp = math.sqrt(2 / self.nsamples) * inp.float() + # self.H += 2 / self.nsamples * inp.matmul(inp.t()) + self.H += inp.matmul(inp.t()) + + def print_loss(self, name, q_weight, weight_error, timecost): + table = Texttable() + length = 28 + name = ( + (name + " " * (length - len(name))) + if len(name) <= length + else name[:length] + ) + + table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"]) + + # assign weight + self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to( + self.layer.weight.data.dtype + ) + + if self.inp1 is not None: + # quantize input to int8 + quantizer = Quantizer() + quantizer.configure(8, perchannel=False, sym=True, mse=False) + quantizer.find_params(self.inp1) + q_in = quantizer.quantize(self.inp1).type(torch.float16) + q_out = self.layer(q_in) + + # get kinds of SNR + q_SNR = torch_snr_error(q_out, self.out1).item() + fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item() + else: + q_SNR = "-" + fp_SNR = "-" + + table.add_row([name, weight_error, fp_SNR, q_SNR, timecost]) + print(table.draw().split("\n")[-2]) + + def fasterquant( + self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name="" + ): + self.layer.to(self.dev) + + W = self.layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + W = W.float() + + tick = time.time() + + if not self.quantizer.ready(): + self.quantizer.find_params(W, weight=True) + + H = self.H + if not self.observe: + del self.H + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + if act_order: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + + Losses = torch.zeros_like(W) + Q = torch.zeros_like(W) + + damp = percdamp * torch.mean(torch.diag(H)) + diag = 
torch.arange(self.columns, device=self.dev) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + try: + H = torch.linalg.cholesky(H, upper=True) + except Exception: + # Addition because Falcon fails on h_to_4h + H = torch.linalg.cholesky( + H + 1e-5 * torch.eye(H.shape[0]).to(H.device), upper=True + ) + Hinv = H + + g_idx = [] + scale = [] + zero = [] + now_idx = 1 + + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if groupsize != -1: + if (i1 + i) % groupsize == 0: + self.quantizer.find_params( + W[:, (i1 + i) : (i1 + i + groupsize)], weight=True + ) + + if ((i1 + i) // groupsize) - now_idx == -1: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + now_idx += 1 + + q = self.quantizer.quantize(w.unsqueeze(1)).flatten() + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d**2 + + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + torch.cuda.synchronize() + error = torch.sum(Losses).item() + + groupsize = groupsize if groupsize != -1 else self.columns + g_idx = [i // groupsize for i in range(self.columns)] + g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) + if act_order: + invperm = torch.argsort(perm) + Q = Q[:, invperm] + g_idx = g_idx[invperm] + + if isinstance(self.layer, transformers.Conv1D): + Q = Q.t() + + self.print_loss( + name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick) + ) + + if scale == []: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + scale = torch.cat(scale, dim=1) + zero = torch.cat(zero, dim=1) + return scale, zero, g_idx, error + + def free(self): + self.inp1 = None + self.out1 = None + self.H = None + self.Losses = None + self.Trace = None + torch.cuda.empty_cache() + + +def get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code): + from datasets import load_dataset + + traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") + testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + + try: + tokenizer = AutoTokenizer.from_pretrained( + model_id, use_fast=False, trust_remote_code=trust_remote_code + ) + except: + tokenizer = AutoTokenizer.from_pretrained( + model_id, use_fast=True, trust_remote_code=trust_remote_code + ) + + trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt") + testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt") + + import random + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code): + from datasets import load_dataset + + traindata = load_dataset("ptb_text_only", "penn_treebank", split="train") + valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation") + + try: + tokenizer = AutoTokenizer.from_pretrained( + model_id, use_fast=False, trust_remote_code=trust_remote_code + ) + except: + tokenizer = 
AutoTokenizer.from_pretrained( + model_id, use_fast=True, trust_remote_code=trust_remote_code + ) + + trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt") + testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt") + + import random + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_c4(nsamples, seed, seqlen, model_id, trust_remote_code): + from datasets import load_dataset + + traindata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, + split="train", + use_auth_token=False, + ) + valdata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, + split="validation", + use_auth_token=False, + ) + + try: + tokenizer = AutoTokenizer.from_pretrained( + model_id, use_fast=False, trust_remote_code=trust_remote_code + ) + except: + tokenizer = AutoTokenizer.from_pretrained( + model_id, use_fast=True, trust_remote_code=trust_remote_code + ) + + import random + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]["text"], return_tensors="pt") + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + import random + + random.seed(0) + valenc = [] + for _ in range(256): + while True: + i = random.randint(0, len(valdata) - 1) + tmp = tokenizer(valdata[i]["text"], return_tensors="pt") + if tmp.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + valenc.append(tmp.input_ids[:, i:j]) + valenc = torch.hstack(valenc) + + class TokenizerWrapper: + def __init__(self, input_ids): + self.input_ids = input_ids + + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + + +def get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code): + from datasets import load_dataset + + traindata = load_dataset("ptb_text_only", "penn_treebank", split="train") + testdata = load_dataset("ptb_text_only", "penn_treebank", split="test") + + try: + tokenizer = AutoTokenizer.from_pretrained( + model_id, use_fast=False, trust_remote_code=trust_remote_code + ) + except: + tokenizer = AutoTokenizer.from_pretrained( + model_id, use_fast=True, trust_remote_code=trust_remote_code + ) + + trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt") + testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt") + + import random + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code): + from datasets import load_dataset + + traindata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, + split="train", + ) + valdata = load_dataset( + "allenai/c4", + "allenai--c4", + 
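+        # only the first validation shard is loaded; roughly the first 1100
+        # documents from it are tokenized further down to build the held-out sequence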
data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, + split="validation", + ) + + try: + tokenizer = AutoTokenizer.from_pretrained( + model_id, use_fast=False, trust_remote_code=trust_remote_code + ) + except: + tokenizer = AutoTokenizer.from_pretrained( + model_id, use_fast=True, trust_remote_code=trust_remote_code + ) + + import random + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]["text"], return_tensors="pt") + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt") + valenc = valenc.input_ids[:, : (256 * seqlen)] + + class TokenizerWrapper: + def __init__(self, input_ids): + self.input_ids = input_ids + + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + + +def get_loaders( + name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False +): + if "wikitext2" in name: + return get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code) + if "ptb" in name: + if "new" in name: + return get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code) + return get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code) + if "c4" in name: + if "new" in name: + return get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code) + return get_c4(nsamples, seed, seqlen, model_id, trust_remote_code) + + +def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""): + # Skip last lm_head linear + # Need isintance Falcon is inheriting Linear. + if isinstance(module, layers) and "lm_head" not in name: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update( + find_layers( + child, layers=layers, name=name + "." 
+ name1 if name != "" else name1 + ) + ) + return res + + +@torch.no_grad() +def sequential( + model, + dataloader, + dev, + nsamples, + bits, + groupsize, + *, + hooks, + percdamp=0.01, + sym: bool = False, + act_order: bool = False, +): + print("Starting ...") + + use_cache = model.config.use_cache + model.config.use_cache = False + try: + layers = model.model.layers + prefix = "model.layers" + except Exception: + layers = model.transformer.h + prefix = "transformer.h" + + dtype = next(iter(model.parameters())).dtype + inps = torch.zeros( + (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev + ) + + cache = {"i": 0} + extra = {} + + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, inp, **kwargs): + inps[cache["i"]] = inp + cache["i"] += 1 + extra.update(kwargs.copy()) + raise ValueError + + layers[0] = Catcher(layers[0]) + for batch in dataloader: + try: + model(batch[0].cuda()) + except ValueError: + pass + layers[0] = layers[0].module + + # layers[0] = layers[0].cpu() + # model.model.embed_tokens = model.model.embed_tokens.cpu() + # model.model.norm = model.model.norm.cpu() + torch.cuda.empty_cache() + for hook in hooks: + hook.remove() + + outs = torch.zeros_like(inps) + + extra = { + k: v.to(dev) if isinstance(v, torch.Tensor) else v for k, v in extra.items() + } + + print("Ready.") + + quantizers = {} + for i in range(len(layers)): + print(f"Quantizing layer {i+1}/{len(layers)}..") + print("+------------------+--------------+------------+-----------+-------+") + print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |") + print("+==================+==============+============+===========+=======+") + + layer = layers[i] + layer.load() + full = find_layers(layer) + sequential = [list(full.keys())] + + for names in sequential: + subset = {n: full[n] for n in names} + gptq = {} + for name in subset: + gptq[name] = GPTQ(subset[name]) + gptq[name].quantizer.configure( + bits, perchannel=True, sym=sym, mse=False + ) + pass + + def add_batch(name): + def tmp(_, inp, out): + gptq[name].add_batch(inp[0].data, out.data) + + return tmp + + handles = [] + for name in subset: + handles.append(subset[name].register_forward_hook(add_batch(name))) + for j in range(nsamples): + outs[j] = layer(inps[j].unsqueeze(0), **extra)[0] + for h in handles: + h.remove() + + for name in subset: + scale, zero, g_idx, error = gptq[name].fasterquant( + percdamp=percdamp, + groupsize=groupsize, + act_order=act_order, + name=name, + ) + quantizers[f"{prefix}.{i}.{name}"] = ( + gptq[name].quantizer.cpu(), + scale.cpu(), + zero.cpu(), + g_idx.cpu(), + bits, + groupsize, + ) + + gptq[name].free() + + for j in range(nsamples): + outs[j] = layer(inps[j].unsqueeze(0), **extra)[0] + + layer.unload() + del layer + del gptq + torch.cuda.empty_cache() + + inps, outs = outs, inps + print("+------------------+--------------+------------+-----------+-------+") + print("\n") + + model.config.use_cache = use_cache + + return quantizers + + +def make_quant_linear(module, names, bits, groupsize, name=""): + if isinstance(module, QuantLinear): + return + for attr in dir(module): + tmp = getattr(module, attr) + name1 = name + "." 
+ attr if name != "" else attr + if name1 in names: + delattr(module, attr) + setattr( + module, + attr, + QuantLinear.new( + bits, + groupsize, + tmp.in_features, + tmp.out_features, + tmp.bias is not None, + ), + ) + for name1, child in module.named_children(): + make_quant_linear( + child, names, bits, groupsize, name + "." + name1 if name != "" else name1 + ) + + +# TODO: perform packing on GPU +def pack(model, quantizers, bits, groupsize): + layers = find_layers(model) + layers = {n: layers[n] for n in quantizers} + make_quant_linear(model, quantizers, bits, groupsize) + qlayers = find_layers(model, (QuantLinear,)) + print("Packing ...") + for name in qlayers: + print(name) + quantizers[name], scale, zero, g_idx, _, _ = quantizers[name] + qlayers[name].pack(layers[name], scale, zero, g_idx) + print("Done.") + return model + + +def setdeepattr(module, full_name, tensor): + current = module + tokens = full_name.split(".") + for token in tokens[:-1]: + current = getattr(current, token) + setattr(current, tokens[-1], tensor) + + +def getdeepattr(module, full_name): + current = module + tokens = full_name.split(".") + for token in tokens: + current = getattr(current, token) + return current + + +def load_weights_pre_hook(module_name, weights, recursive=False): + def inner(module, args): + print(f"Pre hook {module_name}") + local_params = {} + for k, v in module.named_parameters(): + if not recursive and k.count(".") != 1: + continue + local_params[k] = v + for k, v in module.named_buffers(): + if not recursive and k.count(".") != 1: + continue + local_params[k] = v + + for local_param in local_params: + current_tensor = getdeepattr(module, local_param) + if current_tensor.device == torch.device("meta"): + # print(f"Loading {local_param}") + if module_name: + tensor_name = f"{module_name}.{local_param}" + else: + tensor_name = local_param + tensor = weights.get_tensor(tensor_name) + setdeepattr(module, local_param, nn.Parameter(tensor)) + else: + tensor = current_tensor.to(device=torch.device("cuda:0")) + if current_tensor.requires_grad: + tensor = nn.Parameter(tensor) + setdeepattr(module, local_param, tensor) + + return inner + + +def load_weights_post_hook(module_name, weights, recursive=False): + def inner(module, args, output): + print(f"Post hook {module_name}") + local_params = {} + for k, v in module.named_parameters(): + if not recursive and k.count(".") != 1: + continue + local_params[k] = v + for k, v in module.named_buffers(): + if not recursive and k.count(".") != 1: + continue + local_params[k] = v + for local_param in local_params: + # print(f"Unloading {local_param}") + current_tensor = getdeepattr(module, local_param) + setdeepattr( + module, + local_param, + nn.Parameter(current_tensor.to(device=torch.device("cpu"))), + ) + return output + + return inner + + +def quantize( + model_id: str, + bits: int, + groupsize: int, + output_dir: str, + revision: str, + trust_remote_code: bool, + upload_to_model_id: Optional[str], + percdamp: float, + act_order: bool, +): + print("loading model") + config = AutoConfig.from_pretrained( + model_id, + trust_remote_code=trust_remote_code, + ) + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config( + config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code + ) + model = model.eval() + + print("LOADED model") + files = weight_files(model_id, revision, extension=".safetensors") + process_group, _, _ = initialize_torch_distributed() + weights = Weights( + files, + device=torch.device("cuda:0"), + 
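+        # tensors are only materialized on demand: the pre/post forward hooks
+        # registered below pull each layer's weights from the safetensors shards
+        # onto the GPU and move them back to CPU after use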
dtype=torch.float16, + process_group=process_group, + aliases={"embed_tokens.weight": ["lm_head.weight"]}, + ) + hooks = [] + for name, module in model.named_modules(): + + def load(module, name): + def _load(): + load_weights_pre_hook(name, weights, recursive=True)(module, None) + + return _load + + def unload(module, name): + def _unload(): + load_weights_post_hook(name, weights, recursive=True)( + module, None, None + ) + + return _unload + + module.load = load(module, name) + module.unload = unload(module, name) + hooks.append( + module.register_forward_pre_hook(load_weights_pre_hook(name, weights)) + ) + hooks.append( + module.register_forward_hook(load_weights_post_hook(name, weights)) + ) + model.seqlen = 2048 + + dataset = "wikitext2" + nsamples = 128 + seed = None + + dataloader, testloader = get_loaders( + dataset, + nsamples=nsamples, + seed=seed, + model_id=model_id, + seqlen=model.seqlen, + trust_remote_code=trust_remote_code, + ) + + tick = time.time() + quantizers = sequential( + model, + dataloader, + DEV, + nsamples, + bits, + groupsize, + percdamp=percdamp, + act_order=act_order, + hooks=hooks, + ) + print(time.time() - tick) + + pack(model, quantizers, bits, groupsize) + from safetensors.torch import save_file + from transformers.modeling_utils import shard_checkpoint + + state_dict = model.state_dict() + state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} + state_dict["gptq_bits"] = torch.LongTensor([bits]) + state_dict["gptq_groupsize"] = torch.LongTensor([groupsize]) + + max_shard_size = "10GB" + shards, index = shard_checkpoint( + state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors" + ) + os.makedirs(output_dir, exist_ok=True) + for shard_file, shard in shards.items(): + save_file( + shard, + os.path.join(output_dir, shard_file), + metadata={ + "format": "pt", + "quantized": "gptq", + "origin": "text-generation-inference", + }, + ) + if index is None: + path_to_weights = os.path.join(output_dir, "model.safetensors") + logger.info(f"Model weights saved in {path_to_weights}") + else: + save_index_file = "model.safetensors.index.json" + save_index_file = os.path.join(output_dir, save_index_file) + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." 
+ ) + config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code) + config.save_pretrained(output_dir) + logger.info("Saved config") + logger.info("Saving tokenizer") + tokenizer = AutoTokenizer.from_pretrained( + model_id, trust_remote_code=trust_remote_code + ) + tokenizer.save_pretrained(output_dir) + logger.info("Saved tokenizer") + + if upload_to_model_id: + api = HfApi() + + api.upload_folder( + folder_path=output_dir, repo_id=upload_to_model_id, repo_type="model" + ) diff --git a/server/layers/hqq.py b/server/layers/hqq.py new file mode 100644 index 000000000..0c0d59399 --- /dev/null +++ b/server/layers/hqq.py @@ -0,0 +1,33 @@ +import torch +from torch import nn +HAS_HQQ = True +try: + from hqq.core.quantize import BaseQuantizeConfig, HQQLinear + + class HQQLinearLayer(HQQLinear): + @property + def weight(self) -> torch.Tensor: + return self.W_q + +except ImportError: + HAS_HQQ = False + + +def get_hqq_linear(quantize, weight, bias=None) -> HQQLinearLayer: + if quantize == "hqq-4bit": + quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True, quant_scale=False) + elif quantize == "hqq-3bit": + quant_config = BaseQuantizeConfig(nbits=3, group_size=64, quant_zero=True, quant_scale=False) + elif quantize == "hqq-2bit": + quant_config = BaseQuantizeConfig(nbits=2, group_size=16, quant_zero=True, quant_scale=False) + + # init nn.linear from weight and bias + layer = nn.Linear(weight.shape[1], weight.shape[0], bias=bias is not None) + with torch.no_grad(): + layer.weight.data = weight + if bias is not None: + layer.bias.data = bias + + linear = HQQLinearLayer(layer, quant_config, del_orig=True) + + return linear \ No newline at end of file diff --git a/server/layers/layernorm.py b/server/layers/layernorm.py new file mode 100644 index 000000000..7a7f35e38 --- /dev/null +++ b/server/layers/layernorm.py @@ -0,0 +1,185 @@ +import torch +from torch import nn +from accelerate import init_empty_weights +from lorax_server.utils.import_utils import ( + SYSTEM, +) + + +# Monkey patching +@classmethod +def load_layer_norm(cls, prefix, weights, eps): + weight = weights.get_tensor(f"{prefix}.weight") + bias = weights.get_tensor(f"{prefix}.bias") + with init_empty_weights(): + ln = cls(weight.shape, eps=eps) + + ln.weight = torch.nn.Parameter(weight) + ln.bias = torch.nn.Parameter(bias) + return ln + + +@classmethod +def load_layer_norm_no_bias(cls, prefix, weights, eps): + weight = weights.get_tensor(f"{prefix}.weight") + with init_empty_weights(): + ln = cls(weight.shape, eps=eps) + + ln.weight = torch.nn.Parameter(weight) + ln.bias = None + return ln + + +torch.nn.LayerNorm.load = load_layer_norm +torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias + +if SYSTEM == "cuda": + import dropout_layer_norm + + class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + if hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states + + return super(FastLayerNorm, self).forward(hidden_states), residual + else: + ( + normed_hidden_states, + residual, + *rest, + ) = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + self.bias, + None, + None, + None, + None, + 0.0, + self.eps, + 1.0, + 0, + None, + False, + False, + ) + if residual is None: + residual = hidden_states + + return normed_hidden_states, residual + +elif SYSTEM == "rocm": + from vllm._C import ops + + class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + 
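+            # no fused dropout_add_ln kernel on ROCm: add the residual eagerly
+            # and fall back to the stock nn.LayerNorm forward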
if residual is not None: + hidden_states += residual + residual = hidden_states + + return super().forward(hidden_states), residual + +elif SYSTEM == "xpu": + import intel_extension_for_pytorch as ipex + + class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + res_out = hidden_states + out = ipex.llm.functional.add_layer_norm( + residual, hidden_states, self.weight, self.bias, self.eps, True + ) + if residual is not None: + res_out = residual + return out, res_out + + +class FastRMSNorm(nn.Module): + def __init__(self, weight: torch.Tensor, eps: float): + super().__init__() + + self.weight = nn.Parameter(weight) + self.variance_epsilon = eps + + @classmethod + def load(cls, prefix, weights, eps=1e-6): + weight = weights.get_tensor(f"{prefix}.weight") + return cls(weight, eps) + + def forward(self, hidden_states, residual=None): + if SYSTEM == "xpu": + residual_out = hidden_states + out = ipex.llm.functional.add_rms_norm( + residual, + hidden_states, + self.weight, + None, + self.variance_epsilon, + True, + ) + if residual is not None: + residual_out = residual + return out, residual_out + elif hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states + + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states, residual + elif SYSTEM == "cuda": + # faster post attention rms norm + ( + normed_hidden_states, + res, + *rest, + ) = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + None, + None, + None, + None, + None, + 0.0, + self.variance_epsilon, + 1.0, + 0, + None, + False, + True, # Activate RMSNorm + ) + if res is None: + res = hidden_states + + return normed_hidden_states, res + elif SYSTEM == "rocm": + # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. + if residual is not None: + hidden_states += residual + residual = hidden_states + + out = torch.empty_like(hidden_states) + ops.rms_norm( + out, + hidden_states, + self.weight.data, + self.variance_epsilon, + ) + return out, residual + else: + raise ValueError( + "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." + ) diff --git a/server/layers/linear.py b/server/layers/linear.py new file mode 100644 index 000000000..692615f09 --- /dev/null +++ b/server/layers/linear.py @@ -0,0 +1,171 @@ +import torch +from torch.nn import functional as F +from lorax_server.utils.import_utils import SYSTEM +from server.layers.gptq.quant_linear import QuantLinear +from server.layers.gptq.exllamav2 import QuantLinear as exllamav2QuantLinear + +if SYSTEM == "rocm": + try: + from vllm import _custom_C + except Exception as e: + raise ImportError(f"Could not load `vllm._custom_C`. 
Full error: {e}") + + +class FastLinear(torch.nn.Module): + def __init__( + self, + weight, + bias, + ) -> None: + super().__init__() + self.weight = torch.nn.Parameter(weight, requires_grad=False) + if bias is not None: + self.bias = torch.nn.Parameter(bias, requires_grad=False) + else: + self.bias = None + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + weight = weights.get_tensor(f"{prefix}.weight") + if bias: + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + return cls(weight, bias) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight, self.bias) + + +class FastLinearROCm(torch.nn.Module): + def __init__( + self, + weight, + bias, + ) -> None: + super().__init__() + self.weight = torch.nn.Parameter(weight) + if bias is not None: + self.bias = torch.nn.Parameter(bias) + else: + self.bias = None + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + weight = weights.get_tensor(f"{prefix}.weight") + if bias: + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + return cls(weight, bias) + + def forward(self, inp: torch.Tensor) -> torch.Tensor: + weight = self.weight + bias = self.bias + + if SYSTEM == "rocm" and inp.numel() // inp.shape[-1] == 1: + batched = False + inp_shape = inp.shape + + if inp.dim() == 3: + inp = inp.view(-1, inp_shape[-1]) + batched = True + + m, k = weight.shape[0], inp_shape[1] + out = torch.empty( + inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda" + ) + if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192): + _custom_C.LLMM1(weight, inp, out, 8) + elif k <= 8192 and k % 8 == 0 and m % 4 == 0: + _custom_C.LLMM1(weight, inp, out, 4) + else: + out = F.linear(inp, weight) + + if batched: + out.view(*inp_shape[:-1], out.shape[-1]) + + if bias is not None: + out = out + bias + return out + return F.linear(inp, self.weight, self.bias) + + +def get_linear(weight, bias, quantize, fan_in_fan_out=False): + # https://huggingface.co/docs/peft/package_reference/tuners#peft.LoraConfig.fan_in_fan_out + # Set to True if replacing a Conv1D layer with a Linear layer + if fan_in_fan_out: + weight = weight.T.contiguous() + + if quantize is None: + linear = FastLinear(weight, bias) + elif quantize == "fp8": + from server.layers.fp8 import Fp8Linear + linear = Fp8Linear(weight, bias) + + elif quantize == "bitsandbytes": + from server.layers.bnb import Linear8bitLt + + linear = Linear8bitLt( + weight, + bias, + has_fp16_weights=False, + threshold=6.0, + ) + if bias is not None: + linear.bias = nn.Parameter(bias) + elif quantize == "bitsandbytes-nf4": + from server.layers.bnb import Linear4bit + linear = Linear4bit( + weight, + bias, + quant_type="nf4", + ) + elif quantize == "bitsandbytes-fp4": + from server.layers.bnb import Linear4bit + linear = Linear4bit( + weight, + bias, + quant_type="fp4", + ) + elif quantize == "eetq": + from server.layers.eetq import EETQLinear + linear = EETQLinear(weight, bias) + elif quantize == "gptq": + try: + qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight + except Exception: + raise NotImplementedError("The passed weight is not `gptq` compatible, loader needs to be updated.") + + if use_exllama: + linear = exllamav2QuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize) + else: + linear = QuantLinear( + qweight, + qzeros, + scales, + g_idx, + bias, + bits, + groupsize, + ) + elif quantize == "awq": + try: + qweight, qzeros, scales, _, bits, groupsize, _ = weight 
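+            # same 7-tuple layout as the gptq branch above; g_idx and the
+            # exllama flag are simply ignored for AWQ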
+ except Exception: + raise NotImplementedError("The passed weight is not compatible with `awq`") + from server.lorax_server.utils.awq.awq import AWQLinear + linear = AWQLinear( + w_bit=bits, + group_size=groupsize, + qweight=qweight, + qzeros=qzeros, + scales=scales, + bias=bias, + ) + elif "hqq-" in quantize: + from server.layers.hqq import get_hqq_linear + linear = get_hqq_linear(quantize, weight, bias) + else: + raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") + return linear \ No newline at end of file diff --git a/server/layers/rotary.py b/server/layers/rotary.py new file mode 100644 index 000000000..7eba12d0f --- /dev/null +++ b/server/layers/rotary.py @@ -0,0 +1,419 @@ +import os +import torch +from torch import nn + +from lorax_server.utils.import_utils import SYSTEM + +if SYSTEM == "cuda": + from flash_attn.layers.rotary import RotaryEmbedding + import rotary_emb +elif SYSTEM == "rocm": + from vllm._C import ops + + +def _create_inv_freq(dim, base, device): + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) + ) + return inv_freq + + +def _get_rope_config(config): + if os.getenv("ROPE_SCALING", None) is not None: + rope_scaling = { + "type": os.environ["ROPE_SCALING"], + "factor": float(os.environ["ROPE_FACTOR"]), + } + return rope_scaling + return getattr(config, "rope_scaling", None) + + +class PositionRotaryEmbedding(nn.Module): + def __init__(self, inv_freq, scaling_factor): + super().__init__() + self.inv_freq = inv_freq + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + self.scaling_factor = scaling_factor + self.dynamic_args = None + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ): + # Such controlflows may add some overhead. + if SYSTEM == "cuda": + rotary_dim = cos.shape[-1] + q1 = query[..., :rotary_dim] + q2 = query[..., rotary_dim : 2 * rotary_dim] + + rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) + + k1 = key[..., :rotary_dim] + k2 = key[..., rotary_dim : 2 * rotary_dim] + + rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) + elif SYSTEM == "rocm": + # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. + # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 + + head_size = query.shape[-1] + + # Inplace operation, updating query and key. + ops.rotary_embedding(query, key, head_size, cos, sin, True) + elif SYSTEM == "xpu": + ipex.llm.functional.rotary_embedding( + query, key, sin, cos, query.size(-1), True + ) + else: + raise ValueError( + "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." 
+ ) + + @classmethod + def static(cls, config, dim, base, device): + inv_freq = _create_inv_freq(dim, base, device) + scaling_factor = None + rope_scaling = _get_rope_config(config) + if rope_scaling is not None: + if rope_scaling["type"] == "linear": + pass + elif rope_scaling["type"] == "dynamic": + scaling_factor = rope_scaling["factor"] + return DynamicPositionRotaryEmbedding( + dim=dim, + max_position_embeddings=config.max_position_embeddings, + base=base, + device=inv_freq.device, + scaling_factor=scaling_factor, + ) + elif rope_scaling["type"] == "yarn": + scaling_factor = rope_scaling["factor"] + return YarnPositionRotaryEmbedding( + dim=2 * inv_freq.shape[0], + max_position_embeddings=rope_scaling[ + "original_max_position_embeddings" + ], + base=10000.0, + device=inv_freq.device, + scaling_factor=scaling_factor, + extrapolation_factor=1, + attn_factor=1, + beta_fast=32, + beta_slow=1, + ) + elif rope_scaling["type"] == "su": + short_factor = torch.tensor( + rope_scaling["short_factor"], dtype=torch.float32, device=device + ) + short_inv_freq = 1.0 / ( + short_factor + * base + ** ( + torch.arange(0, dim, 2, device=device, dtype=torch.float32) + / dim + ) + ) + long_factor = torch.tensor( + rope_scaling["long_factor"], dtype=torch.float32, device=device + ) + long_inv_freq = 1.0 / ( + long_factor + * base + ** ( + torch.arange(0, dim, 2, device=device, dtype=torch.float32) + / dim + ) + ) + + original_max_position_embeddings = ( + config.original_max_position_embeddings + ) + max_position_embeddings = config.max_position_embeddings + if max_position_embeddings <= original_max_position_embeddings: + scaling_factor = 1.0 + else: + scale = max_position_embeddings / original_max_position_embeddings + scaling_factor = math.sqrt( + 1 + math.log(scale) / math.log(original_max_position_embeddings) + ) + + return SuRotaryEmbedding( + short_inv_freq=short_inv_freq, + long_inv_freq=long_inv_freq, + scaling_factor=scaling_factor, + original_max_position_embeddings=original_max_position_embeddings, + ) + else: + raise NotImplementedError( + f"rope scaling type {rope_scaling['type']} is not implemented or invalid" + ) + return cls(inv_freq, scaling_factor) + + @classmethod + def load(cls, config, prefix, weights): + # XXX: Always load this in float32 ! 
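+        # inv_freq spans several orders of magnitude, so it is read with the
+        # weights dtype temporarily forced to float32 and restored right after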
+ dtype = weights.dtype + weights.dtype = torch.float32 + inv_freq = weights.get_tensor(f"{prefix}.inv_freq") + weights.dtype = dtype + + scaling_factor = None + rope_scaling = _get_rope_config(config) + if rope_scaling is not None: + scaling_factor = rope_scaling["factor"] + if rope_scaling["type"] == "linear": + pass + elif rope_scaling["type"] == "dynamic": + return DynamicPositionRotaryEmbedding( + dim=2 * inv_freq.shape[0], + max_position_embeddings=config.max_position_embeddings, + base=10000.0, + device=inv_freq.device, + scaling_factor=scaling_factor, + ) + elif rope_scaling["type"] == "yarn": + return YarnPositionRotaryEmbedding( + dim=2 * inv_freq.shape[0], + max_position_embeddings=rope_scaling[ + "original_max_position_embeddings" + ], + base=10000.0, + device=inv_freq.device, + scaling_factor=scaling_factor, + extrapolation_factor=1, + attn_factor=1, + beta_fast=32, + beta_slow=1, + ) + else: + raise NotImplementedError( + f"rope scaling type {rope_scaling['type']} is not implemented or invalid" + ) + return cls(inv_freq, scaling_factor) + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + if self.scaling_factor is not None: + t /= self.scaling_factor + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype): + """ + Return cos and sin for the asked position ids + """ + if SYSTEM == "rocm": + # For RoCm, we always use float cos/sin to avoid a cast. + # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26 + # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal. + dtype = torch.float32 + + self._update_cos_sin_cache(dtype, position_ids.device, max_s) + + cos = torch.index_select(self._cos_cached, 0, position_ids) + sin = torch.index_select(self._sin_cached, 0, position_ids) + + # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow. 
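+        # cos/sin go from (num_tokens, rotary_dim // 2) to (num_tokens, 1, rotary_dim // 2)
+        # so they broadcast over the head dimension of query/key downstream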
+ return cos.unsqueeze(1), sin.unsqueeze(1) + + +class SuRotaryEmbedding(PositionRotaryEmbedding): + def __init__( + self, + short_inv_freq, + long_inv_freq, + scaling_factor, + original_max_position_embeddings, + ): + super(PositionRotaryEmbedding, self).__init__() + self.short_inv_freq = short_inv_freq + self.long_inv_freq = long_inv_freq + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + self.dynamic_args = None + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + if seqlen > self.original_max_position_embeddings: + inv_freq = self.long_inv_freq + else: + inv_freq = self.short_inv_freq + t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype) + if self.scaling_factor is not None: + t /= self.scaling_factor + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + +class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): + def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): + inv_freq = _create_inv_freq(dim, base, device) + super().__init__(inv_freq, scaling_factor) + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + if seqlen > self.max_position_embeddings: + newbase = self.base * ( + (self.scaling_factor * seqlen / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + self.inv_freq = _create_inv_freq( + self.dim, newbase, self.inv_freq.device + ) + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + +# Inverse dim formula to find dim based on number of rotations +import math + + +def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = 
(torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def get_mscale(scale=1): + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + +class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings, + base, + device, + scaling_factor, + *, + extrapolation_factor, + attn_factor, + beta_fast, + beta_slow, + ): + inv_freq = _create_inv_freq(dim, base, device) + super().__init__(inv_freq, scaling_factor) + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = float( + get_mscale(self.scaling_factor) * self.attn_factor + ) # Get n-d magnitude scaling corrected for interpolation + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + if seqlen > self.max_position_embeddings: + inv_freq_extrapolation = _create_inv_freq( + self.dim, self.base, self.inv_freq.device + ) + freqs = 1.0 / inv_freq_extrapolation + inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs) + low, high = find_correction_range( + self.beta_fast, + self.beta_slow, + self.dim, + self.base, + self.max_position_embeddings, + ) + inv_freq_mask = ( + 1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device) + ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) + + self.inv_freq = inv_freq + self.mscale = float( + get_mscale(self.scaling_factor) * self.attn_factor + ) # Get n-d magnitude scaling corrected for interpolation + + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype) + self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype) diff --git a/server/layers/speculative.py b/server/layers/speculative.py new file mode 100644 index 000000000..4b977a56a --- /dev/null +++ b/server/layers/speculative.py @@ -0,0 +1,52 @@ +import torch +import json +from typing import Tuple, Optional +from text_generation_server.layers.tensor_parallel import TensorParallelHead +from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2 +from text_generation_server.layers.mlp import MLPSpeculatorHead + + +class SpeculativeHead(torch.nn.Module): + def __init__(self, lm_head, speculator): + super().__init__() + self.head = lm_head + self.speculator = speculator + + @staticmethod + def load(config, prefix: str, weights): + speculator = config.speculator + if speculator: + speculator_path = config.speculator["path"] + speculator_config = str(speculator_path / "config.json") + + with open(speculator_config, "r") as f: + speculator_config = json.load(f) + + config.speculator_config = speculator_config + try: + architecture = speculator_config["architectures"][0] + + if architecture == "MLPSpeculatorPreTrainedModel": + speculator = 
MLPSpeculatorHead.load(config, prefix, weights) + else: + speculator = None + except KeyError: + try: + speculator = MedusaHeadV1.load(config, prefix, weights) + except: + speculator = MedusaHeadV2(config, prefix, weights) + lm_head = None + else: + lm_head = TensorParallelHead.load(config, prefix, weights) + speculator = None + return SpeculativeHead(lm_head, speculator) + + def forward( + self, input: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if self.speculator is not None: + return self.speculator(input) + + assert self.head is not None + logits = self.head(input) + return logits, None diff --git a/server/layers/tensor_parallel.py b/server/layers/tensor_parallel.py new file mode 100644 index 000000000..6027a48da --- /dev/null +++ b/server/layers/tensor_parallel.py @@ -0,0 +1,187 @@ +import torch +from torch.nn import functional as F +from typing import List +from lorax_server.layers.linear import get_linear, FastLinear + +class SuperLayer(torch.nn.Module): + def __init__(self, linear): + super().__init__() + self.linear = linear + + def forward(self, x): + return self.linear.forward(x) + + +class TensorParallelHead(SuperLayer): + def __init__(self, linear, process_group, should_gather: bool): + super().__init__(linear) + self.process_group = process_group + self.should_gather = should_gather + + @staticmethod + def load(config, prefix: str, weights): + if weights.process_group.size() > 1: + try: + weight = weights.get_sharded(f"{prefix}.weight", dim=0) + should_gather = True + except AssertionError: + # If the vocab size is not divisible by number of shards + # just load the entire thing. + weight = weights.get_tensor(f"{prefix}.weight") + should_gather = False + else: + weight = weights.get_tensor(f"{prefix}.weight") + should_gather = False + + # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings) + if config.quantize in ["gptq", "awq", "eetq"]: + quantize = None + else: + quantize = config.quantize + return TensorParallelHead( + get_linear(weight, bias=None, quantize=quantize), + process_group=weights.process_group, + should_gather=should_gather, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not self.should_gather: + return super().forward(input) + + world_size = self.process_group.size() + if len(input.shape) == 2 and isinstance(self.linear, FastLinear): + out_dim = self.linear.weight.shape[0] + + if input.shape[0] == 1: + world_out = input.new_empty(1, out_dim * world_size) + local_out = input.new_empty(1, out_dim) + gather_input = local_out + else: + world_out = input.new_empty(out_dim * world_size, input.shape[0]) + gather_input = input.new_empty(out_dim, input.shape[0]) + local_out = gather_input.T + + torch.mm(input, self.linear.weight.T, out=local_out) + + torch.distributed.all_gather_into_tensor( + world_out, gather_input, group=self.process_group + ) + + if input.shape[0] == 1: + return world_out + return world_out.T + + output = super().forward(input) + world_output = [ + torch.empty_like(output) for _ in range(self.process_group.size()) + ] + torch.distributed.all_gather(world_output, output, group=self.process_group) + world_output = torch.cat(world_output, dim=-1) + return world_output + + +class TensorParallelColumnLinear(SuperLayer): + @classmethod + def load_gate_up(cls, config, prefix: str, weights, bias: bool): + """Specific method when the QKV was joined after the fact""" + weight = weights.get_weights_col_packed_gate_up( + prefix, quantize=config.quantize + ) + if bias: + raise NotImplementedError("packed_gate_up only 
implemented without bias") + else: + bias = None + linear = get_linear(weight, bias, config.quantize) + return cls(linear) + + @classmethod + def load_qkv(cls, config, prefix: str, weights, bias: bool): + """Specific method when the QKV was joined after the fact""" + weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize) + if bias: + raise NotImplementedError("packed_qkv only implemented for baichuan") + else: + bias = None + linear = get_linear(weight, bias, config.quantize) + return cls(linear) + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + return cls.load_multi(config, [prefix], weights, bias, dim=0) + + @classmethod + def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): + weight = weights.get_multi_weights_col( + prefixes, quantize=config.quantize, dim=dim + ) + + if bias: + b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] + bias = torch.cat(b, dim=dim) + else: + bias = None + linear = get_linear(weight, bias, config.quantize) + return cls(linear) + + +class TensorParallelRowLinear(SuperLayer): + def __init__(self, linear, process_group): + super().__init__(linear) + self.process_group = process_group + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + return cls( + get_linear(weight, bias, config.quantize), + process_group=weights.process_group, + ) + + def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor: + out = super().forward(input) + if self.process_group.size() > 1 and reduce: + torch.distributed.all_reduce(out, group=self.process_group) + return out + + +class TensorParallelEmbedding(torch.nn.Module): + def __init__(self, prefix: str, weights, reduce=True): + super().__init__() + weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0) + num_embeddings = weights.get_shape(f"{prefix}.weight")[0] + + process_group = weights.process_group + + world_size = process_group.size() + rank = process_group.rank() + + block_size = (num_embeddings + world_size - 1) // world_size + self.min_id = rank * block_size + self.max_id = min(num_embeddings, (rank + 1) * block_size) + self.null_idx = weight.shape[ + 0 + ] # Usually block_size, might be less in non even vocab_size. 
+ self.process_group = weights.process_group + self.reduce = reduce + + """Additional 0 entry used for masking""" + self.weight = torch.nn.Parameter(F.pad(weight, (0, 0, 0, 1))) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # default all out of bounds values to `self.null_idx` that will then be mapped to 0 + # translate for [0, self.max_id - self.min_id[ + input = torch.where( + (self.min_id > input) | (input >= self.max_id), + self.null_idx, + input - self.min_id, + ) + out = torch.nn.functional.embedding(input, self.weight) + if self.reduce and self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + return out diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 5045967af..3cfb2ebd5 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -9,7 +9,6 @@ from torch.nn import functional as F from lorax_server.adapters.types import LORA, MEDUSA -from lorax_server.utils.gptq.quant_linear import QuantLinear from lorax_server.utils.sgmv import ( add_lora_a_bgmv, add_lora_b_bgmv, @@ -20,47 +19,9 @@ ) from lorax_server.utils.state import is_warmup -HAS_BITS_AND_BYTES = True -try: - import bitsandbytes as bnb - from bitsandbytes.nn import Int8Params, Params4bit - -except ImportError: - HAS_BITS_AND_BYTES = False - -HAS_AWQ = True -try: - from lorax_server.utils.awq.awq import AWQLinear -except ImportError: - HAS_AWQ = False - -HAS_EETQ = False -try: - from EETQ import quant_weights, w8_a16_gemm +from server.layers.linear import get_linear +from server.layers.tensor_parallel import SuperLayer - HAS_EETQ = True -except ImportError: - pass - -HAS_HQQ = True -try: - from hqq.core.quantize import BaseQuantizeConfig, HQQLinear - - class HQQLinearLayer(HQQLinear): - @property - def weight(self) -> torch.Tensor: - return self.W_q - -except ImportError: - HAS_HQQ = False - -HAS_EXLLAMA = True -if os.getenv("DISABLE_EXLLAMA") == "True": - HAS_EXLLAMA = False -try: - from lorax_server.utils.gptq.exllamav2 import QuantLinear as exllamav2QuantLinear -except ImportError: - HAS_EXLLAMA = False if TYPE_CHECKING: from lorax_server.adapters import AdapterBatchData @@ -96,411 +57,7 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps): torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias -class FastLinear(nn.Module): - def __init__( - self, - weight, - bias, - ) -> None: - super().__init__() - self.weight = nn.Parameter(weight) - if bias is not None: - self.bias = nn.Parameter(bias) - else: - self.bias = None - - @classmethod - def load(cls, config, prefix: str, weights, bias: bool): - weight = weights.get_tensor(f"{prefix}.weight") - if bias: - bias = weights.get_tensor(f"{prefix}.bias") - else: - bias = None - return cls(weight, bias) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return F.linear(input, self.weight, self.bias) - - -class FastConv1D(nn.Module): - def __init__( - self, - weight, - bias, - ) -> None: - super().__init__() - self.weight = nn.Parameter(weight) - if bias is not None: - self.bias = nn.Parameter(bias) - else: - self.bias = None - self.nf = weight.shape[1] - - @classmethod - def load(cls, config, prefix: str, weights): - weight = weights.get_tensor(f"{prefix}.weight") - bias = weights.get_tensor(f"{prefix}.bias") - return cls(weight, bias) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - size_out = input.size()[:-1] + (self.nf,) - x = torch.addmm(self.bias, input.view(-1, input.size(-1)), self.weight) - x = x.view(size_out) - 
return x - - -class EETQLinear(nn.Module): - """ - EETQLinear module applies quantized linear transformation to the input tensor. - - Args: - weight (torch.Tensor): The weight tensor for the linear transformation. - bias (torch.Tensor): The bias tensor for the linear transformation. - - Attributes: - weight (torch.Tensor): The weight tensor for the linear transformation. - scale (torch.Tensor): The scale tensor used for quantization. - bias (torch.Tensor): The bias tensor for the linear transformation. - - """ - - def __init__( - self, - weight, - bias, - ) -> None: - super().__init__() - # Get the device where the weight tensor is currently stored. - device = weight.device - - # Transpose the weight tensor and make a contiguous copy of it on the CPU. - # The contiguous() function is used to ensure that the tensor is stored in a contiguous block of memory, - # which can improve performance in some cases. - weight_transposed = torch.t(weight) - weight_contiguous = weight_transposed.contiguous() - weight_cpu = weight_contiguous.cpu() - - # Quantize the weights. The quant_weights function is assumed to perform the quantization. - # The weights are quantized to int8 format, and the quantization is not performed in place (False). - weight_quantized, scale = quant_weights(weight_cpu, torch.int8, False) - - # Move the quantized weights and the scale back to the original device (GPU if available). - # The cuda() function is used to move the tensors to the GPU. - self.weight = weight_quantized.cuda(device) - self.scale = scale.cuda(device) - - # If a bias is present, move it to the GPU as well. If not, set the bias to None. - if bias is not None: - self.bias = bias.cuda(device) - else: - self.bias = None - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - Performs the forward pass of the layer. - - Args: - input (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The output tensor. - """ - # The function w8_a16_gemm performs a matrix multiplication operation between the input and the weight of the layer. - # The result is then scaled by a factor (self.scale). - gemm_output = w8_a16_gemm(input, self.weight, self.scale) - - # If a bias is present (i.e., self.bias is not None), it is added to the output of the matrix multiplication. - # If a bias is not present (i.e., self.bias is None), the output of the matrix multiplication is returned as is. - if self.bias is not None: - final_output = gemm_output + self.bias - else: - final_output = gemm_output - - # The final output is returned. 
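As background on the EETQ linear being removed from this file (the series keeps it under lorax_server/layers/eetq.py, see the linear.py diff further down): it amounts to a one-time int8 weight quantization plus a fused fp16-activation GEMM. A rough sketch using the same two EETQ calls, with illustrative shapes only:

    import torch
    from EETQ import quant_weights, w8_a16_gemm

    weight = torch.randn(4096, 4096, dtype=torch.float16)    # illustrative size
    x = torch.randn(1, 4096, dtype=torch.float16, device="cuda")

    # Quantize once (transposed, contiguous, on CPU, exactly as above), keep the
    # per-channel scales, then run the int8-weight / fp16-activation GEMM on GPU.
    qweight, scale = quant_weights(weight.t().contiguous().cpu(), torch.int8, False)
    y = w8_a16_gemm(x, qweight.cuda(), scale.cuda())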
- return final_output - - -class Linear8bitLt(nn.Module): - def __init__( - self, - weight, - bias, - has_fp16_weights=True, - memory_efficient_backward=False, - threshold=0.0, - index=None, - ): - super().__init__() - assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" - self.state = bnb.MatmulLtState() - self.index = index - - # Necessary for stacked layers - self.state.threshold = threshold - self.state.has_fp16_weights = has_fp16_weights - self.state.memory_efficient_backward = memory_efficient_backward - if threshold > 0.0 and not has_fp16_weights: - self.state.use_pool = True - - self.weight = Int8Params( - weight.data, - has_fp16_weights=has_fp16_weights, - requires_grad=has_fp16_weights, - ) - self.weight.cuda(weight.device) - self.bias = bias - - def init_8bit_state(self): - self.state.CB = self.weight.CB - self.state.SCB = self.weight.SCB - self.weight.CB = None - self.weight.SCB = None - - def forward(self, x: torch.Tensor): - self.state.is_training = self.training - if self.weight.CB is not None: - self.init_8bit_state() - - # weights are cast automatically as Int8Params, but the bias has to be cast manually - if self.bias is not None and self.bias.dtype != x.dtype: - self.bias.data = self.bias.data.to(x.dtype) - - out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) - - if not self.state.has_fp16_weights: - if self.state.CB is not None and self.state.CxB is not None: - # we converted 8-bit row major to turing/ampere format in the first inference pass - # we no longer need the row-major weight - del self.state.CB - self.weight.data = self.state.CxB - return out - - -class Linear4bit(nn.Module): - def __init__(self, weight, bias, quant_type): - super().__init__() - - # Initialize weight with 4-bit quantization - self.weight = Params4bit(weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type) - self.weight.cuda(weight.device) - - # Initialize other attributes - self.compute_dtype = None - self.bias = bias - - def forward(self, x: torch.Tensor): - # Ensure bias has the same dtype as input x - if self.bias is not None and self.bias.dtype != x.dtype: - self.bias.data = self.bias.data.to(x.dtype) - - # Check if quantization state is initialized - if getattr(self.weight, "quant_state", None) is None: - print( - "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first." 
- ) - - # Convert input to compute_dtype if specified - inp_dtype = x.dtype - if self.compute_dtype is not None: - x = x.to(self.compute_dtype) - - # Convert bias to compute_dtype if it exists - bias = None if self.bias is None else self.bias.to(self.compute_dtype) - - # Perform 4-bit matrix multiplication - out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state) - - # Convert output back to the input dtype - out = out.to(inp_dtype) - - return out - -def get_linear(weight, bias, quantize, fan_in_fan_out=False): - # https://huggingface.co/docs/peft/package_reference/tuners#peft.LoraConfig.fan_in_fan_out - # Set to True if replacing a Conv1D layer with a Linear layer - if fan_in_fan_out: - weight = weight.T.contiguous() - - if quantize is None: - linear = FastLinear(weight, bias) - elif quantize == "bitsandbytes": - linear = Linear8bitLt( - weight, - bias, - has_fp16_weights=False, - threshold=6.0, - ) - if bias is not None: - linear.bias = nn.Parameter(bias) - elif quantize == "bitsandbytes-nf4": - linear = Linear4bit( - weight, - bias, - quant_type="nf4", - ) - elif quantize == "bitsandbytes-fp4": - linear = Linear4bit( - weight, - bias, - quant_type="fp4", - ) - elif quantize == "eetq": - if HAS_EETQ: - linear = EETQLinear(weight, bias) - else: - raise ImportError("Please install EETQ from https://github.com/NetEase-FuXi/EETQ") - elif quantize == "gptq": - try: - qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight - except Exception: - raise NotImplementedError("The passed weight is not `gptq` compatible, loader needs to be updated.") - - if use_exllama: - linear = exllamav2QuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize) - else: - linear = QuantLinear( - qweight, - qzeros, - scales, - g_idx, - bias, - bits, - groupsize, - ) - elif quantize == "awq": - try: - qweight, qzeros, scales, _, bits, groupsize, _ = weight - except Exception: - raise NotImplementedError("The passed weight is not compatible with `awq`") - linear = AWQLinear( - w_bit=bits, - group_size=groupsize, - qweight=qweight, - qzeros=qzeros, - scales=scales, - bias=bias, - ) - elif "hqq-" in quantize: - if quantize == "hqq-4bit": - quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True, quant_scale=False) - elif quantize == "hqq-3bit": - quant_config = BaseQuantizeConfig(nbits=3, group_size=64, quant_zero=True, quant_scale=False) - elif quantize == "hqq-2bit": - quant_config = BaseQuantizeConfig(nbits=2, group_size=16, quant_zero=True, quant_scale=False) - - # init nn.linear from weight and bias - layer = nn.Linear(weight.shape[1], weight.shape[0], bias=bias is not None) - with torch.no_grad(): - layer.weight.data = weight - if bias is not None: - layer.bias.data = bias - - linear = HQQLinearLayer(layer, quant_config, del_orig=True) - else: - raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") - return linear - - -class SuperLayer(nn.Module): - def __init__(self, linear): - super().__init__() - self.linear = linear - - def forward(self, x): - return self.linear.forward(x) - - -class TensorParallelHead(SuperLayer): - def __init__(self, linear, process_group, should_gather: bool): - super().__init__(linear) - self.process_group = process_group - self.should_gather = should_gather - - @staticmethod - def load(config, prefix: str, weights): - if weights.process_group.size() > 1: - try: - weight = weights.get_sharded(f"{prefix}.weight", dim=0) - should_gather = True - except AssertionError: - # If the vocab size is 
not divisible by number of shards - # just load the entire thing. - weight = weights.get_tensor(f"{prefix}.weight") - should_gather = False - else: - weight = weights.get_tensor(f"{prefix}.weight") - should_gather = False - - if config.quantize in ["gptq", "awq", "eetq"]: - quantize = None - else: - quantize = config.quantize - return TensorParallelHead( - get_linear(weight, bias=None, quantize=quantize), - process_group=weights.process_group, - should_gather=should_gather, - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - world_size = self.process_group.size() - # Fast branch for single requests - if self.should_gather and len(input.shape) == 2 and isinstance(self.linear, FastLinear) and input.shape[0] == 1: - out_dim = self.linear.weight.shape[0] - - world_out = input.new_empty(1, out_dim * world_size) - local_out = input.new_empty(1, out_dim) - - torch.mm(input, self.linear.weight.T, out=local_out) - - torch.distributed.all_gather_into_tensor(world_out, local_out, group=self.process_group) - return world_out - - output = super().forward(input) - if not self.should_gather: - return output - - world_output = [torch.empty_like(output) for _ in range(world_size)] - torch.distributed.all_gather(world_output, output, group=self.process_group) - world_output = torch.cat(world_output, dim=-1) - return world_output - - -class TensorParallelColumnLinear(SuperLayer): - @classmethod - def load_qkv(cls, config, prefix: str, weights, bias: bool, fan_in_fan_out=False): - """Specific method when the QKV was joined after the fact""" - weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize) - if bias: - raise NotImplementedError("packed_qkv only implemented for baichuan") - else: - bias = None - linear = get_linear(weight, bias, config.quantize, fan_in_fan_out=fan_in_fan_out) - return cls(linear) - - @classmethod - def load(cls, config, prefix: str, weights, bias: bool, fan_in_fan_out: bool = False): - return cls.load_multi(config, [prefix], weights, bias, dim=0, fan_in_fan_out=fan_in_fan_out) - - @classmethod - def load_multi( - cls, - config, - prefixes: List[Union[str, Tuple]], - weights, - bias: bool, - dim: int, - fan_in_fan_out=False, - ): - weight = weights.get_multi_weights_col(prefixes, quantize=config.quantize, dim=dim) - - if bias: - b = weights.get_sharded_list("bias", prefixes, dim=0) - bias = torch.cat(b, dim=dim) - else: - bias = None - linear = get_linear(weight, bias, config.quantize, fan_in_fan_out=fan_in_fan_out) - return cls(linear) class LoraLinear(nn.Module): diff --git a/server/lorax_server/utils/paged_attn.py b/server/lorax_server/utils/paged_attn.py deleted file mode 100644 index 5034def31..000000000 --- a/server/lorax_server/utils/paged_attn.py +++ /dev/null @@ -1,99 +0,0 @@ -import logging - -import torch - -try: - # Hack to ignore ray warnings from vLLM, which are not relevant to us. 
- logging.disable(logging.WARNING) - from vllm._C import cache_ops - from vllm._C import ops as attention_ops -finally: - logging.disable(logging.NOTSET) - - -_PARTITION_SIZE = 512 - - -def reshape_and_cache( - key: torch.Tensor, # [num_tokens, num_heads, head_size] - value: torch.Tensor, # [num_tokens, num_heads, head_size] - key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] - value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] - slot_mapping: torch.Tensor, # [num_tokens] -): - cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, "auto", 1.0) - - -# Source: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/attention.py -def single_query_cached_kv_attention( - output: torch.Tensor, # [num_tokens, num_heads, head_size] - query: torch.Tensor, # [num_tokens, num_heads, head_size] - key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] - value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] - num_key_value_heads: int, - softmax_scale: float, - block_tables: torch.Tensor, # [num_blocks, block_size] - input_lengths: torch.Tensor, # [num_blocks] - max_s: int, -): - block_size = value_cache.shape[3] - num_seqs, num_heads, head_size = query.shape - max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - - # NOTE(woosuk): We use a simple heuristic to decide whether to use - # PagedAttention V1 or V2. If the number of partitions is 1, we use - # V1 to avoid the overhead of reduction. Also, if the number of - # sequences or heads is large, we use V1 since there is enough work - # to parallelize. - # TODO(woosuk): Tune this heuristic. - # For context len > 8192, use V2 kernel to avoid shared memory shortage. - use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) - if use_v1: - # Run PagedAttention V1. - attention_ops.paged_attention_v1( - output, - query, - key_cache, - value_cache, - num_key_value_heads, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) - else: - # Run PagedAttention V2. 
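The V1/V2 dispatch rule used just above is easy to restate as a pure function (the helper name below is made up for illustration; the constants are the ones from this file):

    _PARTITION_SIZE = 512

    def use_paged_attention_v1(max_s: int, num_seqs: int, num_heads: int) -> bool:
        # Prefer V1 when there is a single partition or already enough parallel
        # work; for contexts beyond 8192 tokens always take V2 (shared memory).
        max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
        return max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)

    use_paged_attention_v1(512, num_seqs=1, num_heads=32)      # True  (single partition)
    use_paged_attention_v1(16384, num_seqs=1, num_heads=32)    # False (long context -> V2)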
- assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=output.dtype, - device=output.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=output.device, - ) - max_logits = torch.empty_like(exp_sums) - attention_ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_key_value_heads, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) From ce4417b8d361feeae90c3db0dff2f18aafabdd9d Mon Sep 17 00:00:00 2001 From: flozi00 Date: Tue, 21 May 2024 14:35:23 +0200 Subject: [PATCH 05/32] docker --- .github/workflows/build.yaml | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 116120764..cdbc2debd 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -4,9 +4,10 @@ on: workflow_dispatch: push: branches: - - 'main' + - "main" + - "synctgi" tags: - - 'v*' + - "v*" jobs: build-and-push-image: @@ -41,8 +42,8 @@ jobs: - name: Install soci uses: lerentis/soci-installer@v1.0.1 with: - soci-release: 'v0.4.0' - + soci-release: "v0.4.0" + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v2.10.0 @@ -51,7 +52,7 @@ jobs: with: config-inline: | version = 2 - + # persistent data location root = "/runner/build/containerd" @@ -62,11 +63,8 @@ jobs: images: | ghcr.io/predibase/lorax tags: | - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=sha,prefix=,suffix=,format=short - type=raw,value=main,enable=${{ github.ref == 'refs/heads/main' }} - + type=raw,value=synctgi + - name: Create a hash from tags env: tags: ${{ steps.meta.outputs.tags }} @@ -93,7 +91,7 @@ jobs: uses: docker/build-push-action@v5 with: context: . - file: ./Dockerfile # Path to your Dockerfile + file: ./Dockerfile # Path to your Dockerfile push: false tags: ${{ steps.meta.outputs.tags }} outputs: type=oci,compression=gzip,dest=${{ steps.vars.outputs.image_path }}-${{ steps.vars.outputs.tag_hash }}.tar.gz @@ -124,7 +122,7 @@ jobs: echo "Pushing $tag to GHCR" sudo ctr i push --user "${{ github.repository_owner }}:${{ secrets.GHCR_PAT }}" $tag done - + - name: Create and push soci index env: tags: ${{ steps.meta.outputs.tags }} @@ -151,4 +149,3 @@ jobs: # Delete the SHA image(s) from containerd store sudo ctr i rm $(sudo ctr i ls -q) - From 262b5f48f6abf69d9e1da1db3f078709639bf1b9 Mon Sep 17 00:00:00 2001 From: flozi00 Date: Tue, 21 May 2024 14:48:41 +0200 Subject: [PATCH 06/32] rm awq from docker --- Dockerfile | 6 ------ 1 file changed, 6 deletions(-) diff --git a/Dockerfile b/Dockerfile index 0b49a718e..fd70921ee 100644 --- a/Dockerfile +++ b/Dockerfile @@ -107,12 +107,6 @@ COPY server/exllamav2_kernels/ . # Build specific version of transformers RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build -# Build awq kernels -FROM kernel-builder as awq-kernels-builder -WORKDIR /usr/src -COPY server/awq_kernels/ . 
-# Build specific version of transformers -RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build # Build Transformers CUDA kernels FROM kernel-builder as custom-kernels-builder From 48aecdc556d3df15a92722e0f3f96d5b645429de Mon Sep 17 00:00:00 2001 From: flozi00 Date: Tue, 21 May 2024 15:05:21 +0200 Subject: [PATCH 07/32] ruff --- integration-tests/conftest.py | 2 +- .../scripts/dynamic_adapter_loading.py | 2 -- server/exllamav2_kernels/setup.py | 2 +- server/layers/__init__.py | 20 ++++++++--------- server/layers/awq/conversion_utils.py | 2 +- server/layers/awq/quantize/qmodule.py | 4 +--- server/layers/bnb.py | 5 +++-- server/layers/conv.py | 2 +- server/layers/fp8.py | 1 - server/layers/gptq/__init__.py | 5 +++++ server/layers/gptq/exllama.py | 3 ++- server/layers/gptq/exllamav2.py | 3 +-- server/layers/gptq/quant_linear.py | 11 +++++----- server/layers/gptq/quantize.py | 22 +++++++++---------- server/layers/hqq.py | 1 + server/layers/layernorm.py | 3 ++- server/layers/linear.py | 5 +++-- server/layers/rotary.py | 2 +- server/layers/speculative.py | 7 +++--- server/layers/tensor_parallel.py | 7 ++++-- server/lorax_server/utils/flash_attn.py | 5 ++--- .../lorax_server/utils/flash_attn_triton.py | 7 ++---- server/lorax_server/utils/layers.py | 8 +++---- server/lorax_server/utils/paged_attention.py | 4 ++-- server/tests/adapters/test_medusa.py | 3 +-- server/tests/models/test_bloom.py | 10 ++++----- server/tests/models/test_model.py | 1 - server/tests/models/test_santacoder.py | 2 +- server/tests/models/test_seq2seq_lm.py | 7 +++--- server/tests/utils/test_convert.py | 9 ++++---- server/tests/utils/test_hub.py | 5 ++--- server/tests/utils/test_lora.py | 2 +- server/tests/utils/test_segments.py | 3 +-- server/tests/utils/test_sgmv.py | 3 ++- server/tests/utils/test_tokens.py | 4 ++-- server/tests/utils/test_watermark.py | 3 ++- 36 files changed, 92 insertions(+), 93 deletions(-) diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index f1ffbf29d..93631127a 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -164,7 +164,7 @@ async def health(self, timeout: int = 60): try: await self.client.generate("test") return - except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e: + except (ClientConnectorError, ClientOSError, ServerDisconnectedError): time.sleep(1) raise RuntimeError("Health check failed") diff --git a/integration-tests/scripts/dynamic_adapter_loading.py b/integration-tests/scripts/dynamic_adapter_loading.py index b2502953a..2b9147cc8 100644 --- a/integration-tests/scripts/dynamic_adapter_loading.py +++ b/integration-tests/scripts/dynamic_adapter_loading.py @@ -36,11 +36,9 @@ import collections import concurrent.futures import json -import random import time from urllib.request import Request, urlopen -import numpy as np def query_lorax(args): diff --git a/server/exllamav2_kernels/setup.py b/server/exllamav2_kernels/setup.py index 4a16b546f..1dc1ff335 100644 --- a/server/exllamav2_kernels/setup.py +++ b/server/exllamav2_kernels/setup.py @@ -1,6 +1,6 @@ +import torch from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension -import torch extra_cuda_cflags = ["-lineinfo", "-O3"] diff --git a/server/layers/__init__.py b/server/layers/__init__.py index 6428366a3..20d2676e4 100644 --- a/server/layers/__init__.py +++ b/server/layers/__init__.py @@ -1,14 +1,14 @@ -from lorax_server.layers.tensor_parallel import ( - TensorParallelColumnLinear, - TensorParallelRowLinear, - 
TensorParallelEmbedding, -) +from lorax_server.layers.conv import load_conv2d + +# Just to add the `load` methods. +from lorax_server.layers.layernorm import load_layer_norm from lorax_server.layers.linear import ( - get_linear, FastLinear, + get_linear, ) from lorax_server.layers.speculative import SpeculativeHead - -# Just to add the `load` methods. -from lorax_server.layers.layernorm import load_layer_norm -from lorax_server.layers.conv import load_conv2d +from lorax_server.layers.tensor_parallel import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, +) diff --git a/server/layers/awq/conversion_utils.py b/server/layers/awq/conversion_utils.py index b19eafbbe..57ba58483 100644 --- a/server/layers/awq/conversion_utils.py +++ b/server/layers/awq/conversion_utils.py @@ -1,6 +1,6 @@ -import torch from typing import List +import torch AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7] REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] diff --git a/server/layers/awq/quantize/qmodule.py b/server/layers/awq/quantize/qmodule.py index ca8caf508..6702df786 100644 --- a/server/layers/awq/quantize/qmodule.py +++ b/server/layers/awq/quantize/qmodule.py @@ -1,10 +1,8 @@ # Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py -import math +import awq_inference_engine # with CUDA kernels import torch import torch.nn as nn -import awq_inference_engine # with CUDA kernels - # class ScaledActivation(nn.Module): # def __init__(self, module, scales): diff --git a/server/layers/bnb.py b/server/layers/bnb.py index ca39919ce..9ae5f377b 100644 --- a/server/layers/bnb.py +++ b/server/layers/bnb.py @@ -1,8 +1,9 @@ -import torch -from loguru import logger from functools import lru_cache + import bitsandbytes as bnb +import torch from bitsandbytes.nn import Int8Params, Params4bit +from loguru import logger @lru_cache(1) diff --git a/server/layers/conv.py b/server/layers/conv.py index 7fb18ab3f..517b202c9 100644 --- a/server/layers/conv.py +++ b/server/layers/conv.py @@ -1,5 +1,5 @@ -from accelerate import init_empty_weights import torch +from accelerate import init_empty_weights @classmethod diff --git a/server/layers/fp8.py b/server/layers/fp8.py index dd61d0819..8c7f80563 100644 --- a/server/layers/fp8.py +++ b/server/layers/fp8.py @@ -2,7 +2,6 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): - device = weight.device # weight, scale = quant_weights(weight, torch.int8, False) finfo = torch.finfo(qdtype) # Calculate the scale as dtype max divided by absmax diff --git a/server/layers/gptq/__init__.py b/server/layers/gptq/__init__.py index 1c46f4932..a67ccf9f1 100644 --- a/server/layers/gptq/__init__.py +++ b/server/layers/gptq/__init__.py @@ -1,4 +1,5 @@ import os + import torch from text_generation_server.utils.import_utils import ( SYSTEM, @@ -19,6 +20,8 @@ if V2: from text_generation_server.layers.gptq.exllamav2 import ( QuantLinear as ExllamaQuantLinear, + ) + from text_generation_server.layers.gptq.exllamav2 import ( create_exllama_buffers, set_device, ) @@ -27,6 +30,8 @@ else: from text_generation_server.layers.gptq.exllama import ( Ex4bitLinear as ExllamaQuantLinear, + ) + from text_generation_server.layers.gptq.exllama import ( create_exllama_buffers, set_device, ) diff --git a/server/layers/gptq/exllama.py b/server/layers/gptq/exllama.py index 32f817dba..bb6096e3e 100644 --- a/server/layers/gptq/exllama.py +++ b/server/layers/gptq/exllama.py @@ -1,5 +1,6 @@ import torch -from exllama_kernels import 
make_q4, q4_matmul, prepare_buffers, set_tuning_params + +from exllama_kernels import make_q4, prepare_buffers, q4_matmul, set_tuning_params # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension none_tensor = torch.empty((1, 1), device="meta") diff --git a/server/layers/gptq/exllamav2.py b/server/layers/gptq/exllamav2.py index 01acbd4dc..d5485bd86 100644 --- a/server/layers/gptq/exllamav2.py +++ b/server/layers/gptq/exllamav2.py @@ -2,11 +2,10 @@ import torch import torch.nn as nn - from loguru import logger try: - from exllamav2_kernels import make_q_matrix, gemm_half_q_half + from exllamav2_kernels import gemm_half_q_half, make_q_matrix except ImportError: logger.error("exllamav2_kernels not installed.") raise diff --git a/server/layers/gptq/quant_linear.py b/server/layers/gptq/quant_linear.py index b52ceb0f7..cd583b741 100644 --- a/server/layers/gptq/quant_linear.py +++ b/server/layers/gptq/quant_linear.py @@ -1,11 +1,12 @@ import math + import numpy as np import torch import torch.nn as nn -from torch.cuda.amp import custom_fwd - import triton import triton.language as tl +from torch.cuda.amp import custom_fwd + from . import custom_autotune @@ -206,10 +207,8 @@ def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): output = torch.empty( (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16 ) - grid = lambda META: ( - triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) - * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), - ) + def grid(META): + return triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), matmul_248_kernel[grid]( input, qweight, diff --git a/server/layers/gptq/quantize.py b/server/layers/gptq/quantize.py index ca113d8fe..2cc80f351 100644 --- a/server/layers/gptq/quantize.py +++ b/server/layers/gptq/quantize.py @@ -1,20 +1,20 @@ -import time -import torch.nn as nn -import math import json +import math import os +import time +from typing import Optional + import torch +import torch.nn as nn import transformers - -from texttable import Texttable -from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer -from huggingface_hub import HfApi from accelerate import init_empty_weights -from text_generation_server.utils import initialize_torch_distributed, Weights -from text_generation_server.utils.hub import weight_files -from text_generation_server.utils.gptq.quant_linear import QuantLinear +from huggingface_hub import HfApi from loguru import logger -from typing import Optional +from text_generation_server.utils import Weights, initialize_torch_distributed +from text_generation_server.utils.gptq.quant_linear import QuantLinear +from text_generation_server.utils.hub import weight_files +from texttable import Texttable +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer DEV = torch.device("cuda:0") diff --git a/server/layers/hqq.py b/server/layers/hqq.py index 0c0d59399..5bbdaa56e 100644 --- a/server/layers/hqq.py +++ b/server/layers/hqq.py @@ -1,5 +1,6 @@ import torch from torch import nn + HAS_HQQ = True try: from hqq.core.quantize import BaseQuantizeConfig, HQQLinear diff --git a/server/layers/layernorm.py b/server/layers/layernorm.py index 7a7f35e38..8a4cfac3e 100644 --- a/server/layers/layernorm.py +++ b/server/layers/layernorm.py @@ -1,6 +1,7 @@ import torch -from torch import nn from accelerate import init_empty_weights +from torch import nn + from lorax_server.utils.import_utils import ( SYSTEM, ) diff --git 
a/server/layers/linear.py b/server/layers/linear.py index 692615f09..ae9fbbf86 100644 --- a/server/layers/linear.py +++ b/server/layers/linear.py @@ -1,8 +1,9 @@ import torch +from server.layers.gptq.exllamav2 import QuantLinear as exllamav2QuantLinear +from server.layers.gptq.quant_linear import QuantLinear from torch.nn import functional as F + from lorax_server.utils.import_utils import SYSTEM -from server.layers.gptq.quant_linear import QuantLinear -from server.layers.gptq.exllamav2 import QuantLinear as exllamav2QuantLinear if SYSTEM == "rocm": try: diff --git a/server/layers/rotary.py b/server/layers/rotary.py index 7eba12d0f..e07f48dec 100644 --- a/server/layers/rotary.py +++ b/server/layers/rotary.py @@ -1,11 +1,11 @@ import os + import torch from torch import nn from lorax_server.utils.import_utils import SYSTEM if SYSTEM == "cuda": - from flash_attn.layers.rotary import RotaryEmbedding import rotary_emb elif SYSTEM == "rocm": from vllm._C import ops diff --git a/server/layers/speculative.py b/server/layers/speculative.py index 4b977a56a..ff57a81d2 100644 --- a/server/layers/speculative.py +++ b/server/layers/speculative.py @@ -1,9 +1,10 @@ -import torch import json -from typing import Tuple, Optional -from text_generation_server.layers.tensor_parallel import TensorParallelHead +from typing import Optional, Tuple + +import torch from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2 from text_generation_server.layers.mlp import MLPSpeculatorHead +from text_generation_server.layers.tensor_parallel import TensorParallelHead class SpeculativeHead(torch.nn.Module): diff --git a/server/layers/tensor_parallel.py b/server/layers/tensor_parallel.py index 6027a48da..51904ed3b 100644 --- a/server/layers/tensor_parallel.py +++ b/server/layers/tensor_parallel.py @@ -1,7 +1,10 @@ +from typing import List + import torch from torch.nn import functional as F -from typing import List -from lorax_server.layers.linear import get_linear, FastLinear + +from lorax_server.layers.linear import FastLinear, get_linear + class SuperLayer(torch.nn.Module): def __init__(self, linear): diff --git a/server/lorax_server/utils/flash_attn.py b/server/lorax_server/utils/flash_attn.py index 2287b5b89..00d456c54 100644 --- a/server/lorax_server/utils/flash_attn.py +++ b/server/lorax_server/utils/flash_attn.py @@ -1,11 +1,10 @@ import os -import torch +import torch from loguru import logger -import math -from lorax_server.utils.import_utils import SYSTEM from lorax_server.utils.flash_attn_triton import triton_attention +from lorax_server.utils.import_utils import SYSTEM if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": raise ImportError("`USE_FLASH_ATTENTION` is false.") diff --git a/server/lorax_server/utils/flash_attn_triton.py b/server/lorax_server/utils/flash_attn_triton.py index 3fe322311..3a6f9a730 100644 --- a/server/lorax_server/utils/flash_attn_triton.py +++ b/server/lorax_server/utils/flash_attn_triton.py @@ -747,11 +747,8 @@ def forward( padded_d_model = 1 << (head_size - 1).bit_length() padded_d_model = max(padded_d_model, 16) - grid = lambda META: ( - triton.cdiv(max_seqlens_q, META["BLOCK_M"]), - nheads_q, - batch, - ) + def grid(META): + return triton.cdiv(max_seqlens_q, META["BLOCK_M"]), nheads_q, batch encoded_softmax = None diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 3cfb2ebd5..8ae6fa139 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -1,10 +1,12 @@ import math import os -from 
typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Tuple import torch import torch.distributed from accelerate import init_empty_weights +from server.layers.linear import get_linear +from server.layers.tensor_parallel import SuperLayer from torch import nn from torch.nn import functional as F @@ -19,10 +21,6 @@ ) from lorax_server.utils.state import is_warmup -from server.layers.linear import get_linear -from server.layers.tensor_parallel import SuperLayer - - if TYPE_CHECKING: from lorax_server.adapters import AdapterBatchData from lorax_server.adapters.lora import BatchLoraWeights diff --git a/server/lorax_server/utils/paged_attention.py b/server/lorax_server/utils/paged_attention.py index ad18fc65c..25c501a28 100644 --- a/server/lorax_server/utils/paged_attention.py +++ b/server/lorax_server/utils/paged_attention.py @@ -1,4 +1,5 @@ import torch + from lorax_server.utils.import_utils import SYSTEM _PARTITION_SIZE = 512 @@ -7,8 +8,7 @@ import intel_extension_for_pytorch as ipex else: try: - from vllm._C import cache_ops - from vllm._C import ops + from vllm._C import cache_ops, ops except Exception as e: raise ImportError( f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" diff --git a/server/tests/adapters/test_medusa.py b/server/tests/adapters/test_medusa.py index 120981d99..6614444dc 100644 --- a/server/tests/adapters/test_medusa.py +++ b/server/tests/adapters/test_medusa.py @@ -1,14 +1,13 @@ import torch from lorax_server.adapters.medusa import BatchMedusaWeights, MedusaConfig -from lorax_server.adapters.weights import AdapterBatchMetadata from lorax_server.adapters.utils import download_adapter +from lorax_server.adapters.weights import AdapterBatchMetadata from lorax_server.models.causal_lm import CausalLM from lorax_server.utils.adapter import load_module_map from lorax_server.utils.lora import LM_HEAD from lorax_server.utils.sources import HUB - model_id = "mistralai/Mistral-7B-Instruct-v0.2" adapter_id = "predibase/Mistral-7B-Instruct-v0.2-medusa" diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 5937b9cfe..b656bf921 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -1,13 +1,13 @@ +from copy import copy + import pytest import torch - -from copy import copy from transformers import AutoTokenizer -from lorax_server.pb import generate_pb2 -from lorax_server.models.causal_lm import CausalLMBatch -from lorax_server.utils import weight_hub_files, download_weights from lorax_server.models.bloom import BloomCausalLMBatch, BLOOMSharded +from lorax_server.models.causal_lm import CausalLMBatch +from lorax_server.pb import generate_pb2 +from lorax_server.utils import download_weights, weight_hub_files from lorax_server.utils.tokenizer import TokenizerManager diff --git a/server/tests/models/test_model.py b/server/tests/models/test_model.py index ecd8e52a8..b24d142d9 100644 --- a/server/tests/models/test_model.py +++ b/server/tests/models/test_model.py @@ -1,6 +1,5 @@ import pytest import torch - from transformers import AutoTokenizer from lorax_server.models.model import Model diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py index d6efa4af6..1edd9e1c0 100644 --- a/server/tests/models/test_santacoder.py +++ b/server/tests/models/test_santacoder.py @@ -1,8 +1,8 @@ import pytest -from lorax_server.pb import generate_pb2 from lorax_server.models.causal_lm import 
CausalLMBatch from lorax_server.models.santacoder import SantaCoder +from lorax_server.pb import generate_pb2 from lorax_server.utils.tokenizer import TokenizerManager diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 38acfb451..b3ac1c874 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -1,12 +1,11 @@ -import pytest -import torch - from copy import copy +import pytest +import torch from transformers import AutoTokenizer -from lorax_server.pb import generate_pb2 from lorax_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch +from lorax_server.pb import generate_pb2 from lorax_server.utils.tokenizer import TokenizerManager diff --git a/server/tests/utils/test_convert.py b/server/tests/utils/test_convert.py index 3c1ef049c..1292bb69d 100644 --- a/server/tests/utils/test_convert.py +++ b/server/tests/utils/test_convert.py @@ -1,15 +1,16 @@ from pathlib import Path + import pytest import torch + +from lorax_server.utils.convert import convert_files +from lorax_server.utils.errors import NanWeightsError from lorax_server.utils.sources.hub import ( download_weights, - weight_hub_files, weight_files, + weight_hub_files, ) -from lorax_server.utils.convert import convert_files -from lorax_server.utils.errors import NanWeightsError - def test_convert_files(): model_id = "bigscience/bloom-560m" diff --git a/server/tests/utils/test_hub.py b/server/tests/utils/test_hub.py index 2ec5883ba..6c80c5492 100644 --- a/server/tests/utils/test_hub.py +++ b/server/tests/utils/test_hub.py @@ -1,15 +1,14 @@ import pytest - from huggingface_hub.utils import ( - LocalEntryNotFoundError, EntryNotFoundError, + LocalEntryNotFoundError, RevisionNotFoundError, ) from lorax_server.utils.sources.hub import ( - weight_hub_files, download_weights, weight_files, + weight_hub_files, ) diff --git a/server/tests/utils/test_lora.py b/server/tests/utils/test_lora.py index 852d7e359..f7757f833 100644 --- a/server/tests/utils/test_lora.py +++ b/server/tests/utils/test_lora.py @@ -1,7 +1,7 @@ from typing import List from unittest import mock -import pytest +import pytest import torch from peft import LoraConfig diff --git a/server/tests/utils/test_segments.py b/server/tests/utils/test_segments.py index c81cd0892..a5bf3dca2 100644 --- a/server/tests/utils/test_segments.py +++ b/server/tests/utils/test_segments.py @@ -1,8 +1,7 @@ import pytest import torch -from lorax_server.utils.segments import find_segments, SegmentConcatBuilder - +from lorax_server.utils.segments import SegmentConcatBuilder, find_segments @pytest.mark.parametrize( diff --git a/server/tests/utils/test_sgmv.py b/server/tests/utils/test_sgmv.py index 8433428ea..0c535f1b1 100644 --- a/server/tests/utils/test_sgmv.py +++ b/server/tests/utils/test_sgmv.py @@ -1,12 +1,13 @@ from typing import List, Tuple + import pytest import torch from lorax_server.utils.sgmv import ( get_tmp_tensors, + has_sgmv, lora_a_sgmv_cutlass, lora_b_sgmv_cutlass, - has_sgmv, pad_rank, use_cutlass_shrink, ) diff --git a/server/tests/utils/test_tokens.py b/server/tests/utils/test_tokens.py index 29750accd..929a4e584 100644 --- a/server/tests/utils/test_tokens.py +++ b/server/tests/utils/test_tokens.py @@ -1,10 +1,10 @@ from lorax_server.adapters.weights import AdapterBatchData from lorax_server.models.causal_lm import CausalLMBatch from lorax_server.utils.tokens import ( + FinishReason, NextTokenChooser, - StopSequenceCriteria, StoppingCriteria, - FinishReason, + StopSequenceCriteria, ) diff 
--git a/server/tests/utils/test_watermark.py b/server/tests/utils/test_watermark.py index c354d16e6..481a6dc61 100644 --- a/server/tests/utils/test_watermark.py +++ b/server/tests/utils/test_watermark.py @@ -1,10 +1,11 @@ # test_watermark_logits_processor.py import os + import numpy as np import torch -from lorax_server.utils.watermark import WatermarkLogitsProcessor +from lorax_server.utils.watermark import WatermarkLogitsProcessor GAMMA = os.getenv("WATERMARK_GAMMA", 0.5) DELTA = os.getenv("WATERMARK_DELTA", 2.0) From fe1a883b3dd1a62053d9214dc8fd355b403a4e2a Mon Sep 17 00:00:00 2001 From: flozi00 Date: Tue, 21 May 2024 16:25:37 +0200 Subject: [PATCH 08/32] fix docker awq --- Dockerfile | 3 --- 1 file changed, 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index fd70921ee..cdf2b2e1e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -182,9 +182,6 @@ COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86 COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy build artifacts from exllama kernels builder COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages -# Copy build artifacts from awq kernels builder -COPY --from=awq-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages - # Copy builds artifacts from vllm builder COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages From 679f3386203b47deb604dabb6f3d12564fe6ad57 Mon Sep 17 00:00:00 2001 From: flozi00 Date: Tue, 21 May 2024 18:20:11 +0200 Subject: [PATCH 09/32] fix vllm make --- Dockerfile | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/Dockerfile b/Dockerfile index cdf2b2e1e..a86e85170 100644 --- a/Dockerfile +++ b/Dockerfile @@ -39,8 +39,9 @@ FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 as pytorch-install ARG PYTORCH_VERSION=2.3.0 ARG PYTHON_VERSION=3.10 +# Keep in sync with `server/pyproject.toml ARG CUDA_VERSION=12.1 -ARG MAMBA_VERSION=23.3.1-1 +ARG MAMBA_VERSION=24.3.0-0 ARG CUDA_CHANNEL=nvidia ARG INSTALL_CHANNEL=pytorch # Automatically set by buildx @@ -52,7 +53,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins build-essential \ ca-certificates \ ccache \ - sudo \ curl \ git && \ rm -rf /var/lib/apt/lists/* @@ -73,7 +73,7 @@ RUN chmod +x ~/mambaforge.sh && \ RUN case ${TARGETPLATFORM} in \ "linux/arm64") exit 1 ;; \ *) /opt/conda/bin/conda update -y conda && \ - /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -c anaconda -c conda-forge -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \ + /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' 
-f 1-2)" ;; \ esac && \ /opt/conda/bin/conda clean -ya @@ -117,12 +117,11 @@ RUN python setup.py build # Build vllm CUDA kernels FROM kernel-builder as vllm-builder -RUN /opt/conda/bin/conda install packaging WORKDIR /usr/src +ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX" COPY server/Makefile-vllm Makefile # Build specific version of vllm -ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX" -RUN make build-vllm +RUN make build-vllm-cuda # Build megablocks kernels FROM kernel-builder as megablocks-kernels-builder From 5f0d8b245bb75dc74d87671c4397f2d08540ae48 Mon Sep 17 00:00:00 2001 From: flozi00 Date: Tue, 21 May 2024 18:25:09 +0200 Subject: [PATCH 10/32] dockerfile --- Dockerfile | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/Dockerfile b/Dockerfile index a86e85170..1f6698e35 100644 --- a/Dockerfile +++ b/Dockerfile @@ -90,23 +90,32 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins FROM kernel-builder as flash-att-builder WORKDIR /usr/src COPY server/Makefile-flash-att Makefile - -# Build specific version of flash attention RUN make build-flash-attention + # Build Flash Attention v2 CUDA kernels FROM kernel-builder as flash-att-v2-builder WORKDIR /usr/src COPY server/Makefile-flash-att-v2 Makefile -# Build specific version of flash attention v2 -RUN make build-flash-attention-v2 +RUN make build-flash-attention-v2-cuda -# Build Transformers exllamav2 kernels +# Build Transformers exllama kernels FROM kernel-builder as exllama-kernels-builder WORKDIR /usr/src +COPY server/exllama_kernels/ . +RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build + +# Build Transformers exllama kernels +FROM kernel-builder as exllamav2-kernels-builder +WORKDIR /usr/src COPY server/exllamav2_kernels/ . 
-# Build specific version of transformers RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build +# Build Transformers awq kernels +FROM kernel-builder as awq-kernels-builder +WORKDIR /usr/src +COPY server/Makefile-awq Makefile +RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq + # Build Transformers CUDA kernels FROM kernel-builder as custom-kernels-builder @@ -181,6 +190,10 @@ COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86 COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy build artifacts from exllama kernels builder COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# Copy build artifacts from exllamav2 kernels builder +COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# Copy build artifacts from awq kernels builder +COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy builds artifacts from vllm builder COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages From 5d4848d3e64d55d071e1489244397b214bf46cec Mon Sep 17 00:00:00 2001 From: flozi00 Date: Tue, 21 May 2024 22:04:00 +0200 Subject: [PATCH 11/32] ttf imports --- server/layers/linear.py | 18 +++++++++--------- server/lorax_server/utils/import_utils.py | 2 +- server/lorax_server/utils/layers.py | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/server/layers/linear.py b/server/layers/linear.py index ae9fbbf86..39e35b27a 100644 --- a/server/layers/linear.py +++ b/server/layers/linear.py @@ -1,6 +1,6 @@ import torch -from server.layers.gptq.exllamav2 import QuantLinear as exllamav2QuantLinear -from server.layers.gptq.quant_linear import QuantLinear +from lorax_server.layers.gptq.exllamav2 import QuantLinear as exllamav2QuantLinear +from lorax_server.layers.gptq.quant_linear import QuantLinear from torch.nn import functional as F from lorax_server.utils.import_utils import SYSTEM @@ -101,11 +101,11 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False): if quantize is None: linear = FastLinear(weight, bias) elif quantize == "fp8": - from server.layers.fp8 import Fp8Linear + from lorax_server.layers.fp8 import Fp8Linear linear = Fp8Linear(weight, bias) elif quantize == "bitsandbytes": - from server.layers.bnb import Linear8bitLt + from lorax_server.layers.bnb import Linear8bitLt linear = Linear8bitLt( weight, @@ -116,21 +116,21 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False): if bias is not None: linear.bias = nn.Parameter(bias) elif quantize == "bitsandbytes-nf4": - from server.layers.bnb import Linear4bit + from lorax_server.layers.bnb import Linear4bit linear = Linear4bit( weight, bias, quant_type="nf4", ) elif quantize == "bitsandbytes-fp4": - from server.layers.bnb import Linear4bit + from lorax_server.layers.bnb import Linear4bit linear = Linear4bit( weight, bias, quant_type="fp4", ) elif quantize == "eetq": - from server.layers.eetq import EETQLinear + from lorax_server.layers.eetq import EETQLinear linear = EETQLinear(weight, bias) elif quantize == "gptq": try: @@ -155,7 +155,7 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False): qweight, qzeros, scales, _, bits, groupsize, _ = weight except Exception: raise NotImplementedError("The passed weight is not compatible with 
`awq`") - from server.lorax_server.utils.awq.awq import AWQLinear + from lorax_server.lorax_server.utils.awq.awq import AWQLinear linear = AWQLinear( w_bit=bits, group_size=groupsize, @@ -165,7 +165,7 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False): bias=bias, ) elif "hqq-" in quantize: - from server.layers.hqq import get_hqq_linear + from lorax_server.layers.hqq import get_hqq_linear linear = get_hqq_linear(quantize, weight, bias) else: raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") diff --git a/server/lorax_server/utils/import_utils.py b/server/lorax_server/utils/import_utils.py index f54987eb1..36776528b 100644 --- a/server/lorax_server/utils/import_utils.py +++ b/server/lorax_server/utils/import_utils.py @@ -3,7 +3,7 @@ def is_xpu_available(): try: - import intel_extension_for_pytorch + import intel_extension_for_pytorch # noqa except ImportError: return False diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 8ae6fa139..2ba84e24b 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -5,8 +5,8 @@ import torch import torch.distributed from accelerate import init_empty_weights -from server.layers.linear import get_linear -from server.layers.tensor_parallel import SuperLayer +from lorax_server.layers.linear import get_linear +from lorax_server.layers.tensor_parallel import SuperLayer from torch import nn from torch.nn import functional as F From 893a0ffecb814a1a703a16faa927115697c68509 Mon Sep 17 00:00:00 2001 From: flozi00 Date: Tue, 21 May 2024 22:05:41 +0200 Subject: [PATCH 12/32] ruff --- server/layers/linear.py | 4 ++-- server/lorax_server/utils/layers.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/server/layers/linear.py b/server/layers/linear.py index 39e35b27a..b0d832a2d 100644 --- a/server/layers/linear.py +++ b/server/layers/linear.py @@ -1,8 +1,8 @@ import torch -from lorax_server.layers.gptq.exllamav2 import QuantLinear as exllamav2QuantLinear -from lorax_server.layers.gptq.quant_linear import QuantLinear from torch.nn import functional as F +from lorax_server.layers.gptq.exllamav2 import QuantLinear as exllamav2QuantLinear +from lorax_server.layers.gptq.quant_linear import QuantLinear from lorax_server.utils.import_utils import SYSTEM if SYSTEM == "rocm": diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 2ba84e24b..4be11ade3 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -5,12 +5,12 @@ import torch import torch.distributed from accelerate import init_empty_weights -from lorax_server.layers.linear import get_linear -from lorax_server.layers.tensor_parallel import SuperLayer from torch import nn from torch.nn import functional as F from lorax_server.adapters.types import LORA, MEDUSA +from lorax_server.layers.linear import get_linear +from lorax_server.layers.tensor_parallel import SuperLayer from lorax_server.utils.sgmv import ( add_lora_a_bgmv, add_lora_b_bgmv, From 9caaf0f4fe349b7ed2350fd64540eaf3ae400d3f Mon Sep 17 00:00:00 2001 From: flozi00 Date: Tue, 21 May 2024 22:14:42 +0200 Subject: [PATCH 13/32] fix layers --- server/{ => lorax_server}/layers/__init__.py | 0 server/{ => lorax_server}/layers/awq/conversion_utils.py | 0 server/{ => lorax_server}/layers/awq/quantize/qmodule.py | 0 server/{ => lorax_server}/layers/bnb.py | 0 server/{ => lorax_server}/layers/conv.py | 0 server/{ => lorax_server}/layers/eetq.py | 0 server/{ => 
lorax_server}/layers/fp8.py | 0 server/{ => lorax_server}/layers/gptq/__init__.py | 0 server/{ => lorax_server}/layers/gptq/custom_autotune.py | 0 server/{ => lorax_server}/layers/gptq/exllama.py | 0 server/{ => lorax_server}/layers/gptq/exllamav2.py | 0 server/{ => lorax_server}/layers/gptq/quant_linear.py | 0 server/{ => lorax_server}/layers/gptq/quantize.py | 0 server/{ => lorax_server}/layers/hqq.py | 0 server/{ => lorax_server}/layers/layernorm.py | 0 server/{ => lorax_server}/layers/linear.py | 0 server/{ => lorax_server}/layers/rotary.py | 0 server/{ => lorax_server}/layers/speculative.py | 0 server/{ => lorax_server}/layers/tensor_parallel.py | 0 19 files changed, 0 insertions(+), 0 deletions(-) rename server/{ => lorax_server}/layers/__init__.py (100%) rename server/{ => lorax_server}/layers/awq/conversion_utils.py (100%) rename server/{ => lorax_server}/layers/awq/quantize/qmodule.py (100%) rename server/{ => lorax_server}/layers/bnb.py (100%) rename server/{ => lorax_server}/layers/conv.py (100%) rename server/{ => lorax_server}/layers/eetq.py (100%) rename server/{ => lorax_server}/layers/fp8.py (100%) rename server/{ => lorax_server}/layers/gptq/__init__.py (100%) rename server/{ => lorax_server}/layers/gptq/custom_autotune.py (100%) rename server/{ => lorax_server}/layers/gptq/exllama.py (100%) rename server/{ => lorax_server}/layers/gptq/exllamav2.py (100%) rename server/{ => lorax_server}/layers/gptq/quant_linear.py (100%) rename server/{ => lorax_server}/layers/gptq/quantize.py (100%) rename server/{ => lorax_server}/layers/hqq.py (100%) rename server/{ => lorax_server}/layers/layernorm.py (100%) rename server/{ => lorax_server}/layers/linear.py (100%) rename server/{ => lorax_server}/layers/rotary.py (100%) rename server/{ => lorax_server}/layers/speculative.py (100%) rename server/{ => lorax_server}/layers/tensor_parallel.py (100%) diff --git a/server/layers/__init__.py b/server/lorax_server/layers/__init__.py similarity index 100% rename from server/layers/__init__.py rename to server/lorax_server/layers/__init__.py diff --git a/server/layers/awq/conversion_utils.py b/server/lorax_server/layers/awq/conversion_utils.py similarity index 100% rename from server/layers/awq/conversion_utils.py rename to server/lorax_server/layers/awq/conversion_utils.py diff --git a/server/layers/awq/quantize/qmodule.py b/server/lorax_server/layers/awq/quantize/qmodule.py similarity index 100% rename from server/layers/awq/quantize/qmodule.py rename to server/lorax_server/layers/awq/quantize/qmodule.py diff --git a/server/layers/bnb.py b/server/lorax_server/layers/bnb.py similarity index 100% rename from server/layers/bnb.py rename to server/lorax_server/layers/bnb.py diff --git a/server/layers/conv.py b/server/lorax_server/layers/conv.py similarity index 100% rename from server/layers/conv.py rename to server/lorax_server/layers/conv.py diff --git a/server/layers/eetq.py b/server/lorax_server/layers/eetq.py similarity index 100% rename from server/layers/eetq.py rename to server/lorax_server/layers/eetq.py diff --git a/server/layers/fp8.py b/server/lorax_server/layers/fp8.py similarity index 100% rename from server/layers/fp8.py rename to server/lorax_server/layers/fp8.py diff --git a/server/layers/gptq/__init__.py b/server/lorax_server/layers/gptq/__init__.py similarity index 100% rename from server/layers/gptq/__init__.py rename to server/lorax_server/layers/gptq/__init__.py diff --git a/server/layers/gptq/custom_autotune.py b/server/lorax_server/layers/gptq/custom_autotune.py similarity 
index 100% rename from server/layers/gptq/custom_autotune.py rename to server/lorax_server/layers/gptq/custom_autotune.py diff --git a/server/layers/gptq/exllama.py b/server/lorax_server/layers/gptq/exllama.py similarity index 100% rename from server/layers/gptq/exllama.py rename to server/lorax_server/layers/gptq/exllama.py diff --git a/server/layers/gptq/exllamav2.py b/server/lorax_server/layers/gptq/exllamav2.py similarity index 100% rename from server/layers/gptq/exllamav2.py rename to server/lorax_server/layers/gptq/exllamav2.py diff --git a/server/layers/gptq/quant_linear.py b/server/lorax_server/layers/gptq/quant_linear.py similarity index 100% rename from server/layers/gptq/quant_linear.py rename to server/lorax_server/layers/gptq/quant_linear.py diff --git a/server/layers/gptq/quantize.py b/server/lorax_server/layers/gptq/quantize.py similarity index 100% rename from server/layers/gptq/quantize.py rename to server/lorax_server/layers/gptq/quantize.py diff --git a/server/layers/hqq.py b/server/lorax_server/layers/hqq.py similarity index 100% rename from server/layers/hqq.py rename to server/lorax_server/layers/hqq.py diff --git a/server/layers/layernorm.py b/server/lorax_server/layers/layernorm.py similarity index 100% rename from server/layers/layernorm.py rename to server/lorax_server/layers/layernorm.py diff --git a/server/layers/linear.py b/server/lorax_server/layers/linear.py similarity index 100% rename from server/layers/linear.py rename to server/lorax_server/layers/linear.py diff --git a/server/layers/rotary.py b/server/lorax_server/layers/rotary.py similarity index 100% rename from server/layers/rotary.py rename to server/lorax_server/layers/rotary.py diff --git a/server/layers/speculative.py b/server/lorax_server/layers/speculative.py similarity index 100% rename from server/layers/speculative.py rename to server/lorax_server/layers/speculative.py diff --git a/server/layers/tensor_parallel.py b/server/lorax_server/layers/tensor_parallel.py similarity index 100% rename from server/layers/tensor_parallel.py rename to server/lorax_server/layers/tensor_parallel.py From 0791dfd93b6aabeef0baf0df21e543fd97b91092 Mon Sep 17 00:00:00 2001 From: flozi00 Date: Tue, 21 May 2024 22:23:25 +0200 Subject: [PATCH 14/32] ruff --- clients/python/lorax/__init__.py | 2 +- server/lorax_server/layers/__init__.py | 16 +- server/lorax_server/layers/gptq/__init__.py | 8 +- server/lorax_server/layers/gptq/quantize.py | 1002 ------------------- server/lorax_server/layers/linear.py | 1 + server/lorax_server/layers/rotary.py | 8 +- server/lorax_server/layers/speculative.py | 2 +- 7 files changed, 16 insertions(+), 1023 deletions(-) delete mode 100644 server/lorax_server/layers/gptq/quantize.py diff --git a/clients/python/lorax/__init__.py b/clients/python/lorax/__init__.py index 98851b2f5..1f42bb9bb 100644 --- a/clients/python/lorax/__init__.py +++ b/clients/python/lorax/__init__.py @@ -14,4 +14,4 @@ __version__ = "0.5.1" -from lorax.client import Client, AsyncClient, MergedAdapters +from lorax.client import Client, AsyncClient, MergedAdapters # noqa diff --git a/server/lorax_server/layers/__init__.py b/server/lorax_server/layers/__init__.py index 20d2676e4..6edfeacef 100644 --- a/server/lorax_server/layers/__init__.py +++ b/server/lorax_server/layers/__init__.py @@ -1,14 +1,14 @@ -from lorax_server.layers.conv import load_conv2d +from lorax_server.layers.conv import load_conv2d # noqa # Just to add the `load` methods. 
-from lorax_server.layers.layernorm import load_layer_norm +from lorax_server.layers.layernorm import load_layer_norm # noqa from lorax_server.layers.linear import ( - FastLinear, - get_linear, + FastLinear, # noqa + get_linear, # noqa ) -from lorax_server.layers.speculative import SpeculativeHead +from lorax_server.layers.speculative import SpeculativeHead # noqa from lorax_server.layers.tensor_parallel import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, + TensorParallelColumnLinear, # noqa + TensorParallelEmbedding, # noqa + TensorParallelRowLinear, # noqa ) diff --git a/server/lorax_server/layers/gptq/__init__.py b/server/lorax_server/layers/gptq/__init__.py index a67ccf9f1..8ad0914a1 100644 --- a/server/lorax_server/layers/gptq/__init__.py +++ b/server/lorax_server/layers/gptq/__init__.py @@ -29,11 +29,11 @@ HAS_EXLLAMA = "2" else: from text_generation_server.layers.gptq.exllama import ( - Ex4bitLinear as ExllamaQuantLinear, + Ex4bitLinear as ExllamaQuantLinear, # noqa ) from text_generation_server.layers.gptq.exllama import ( - create_exllama_buffers, - set_device, + create_exllama_buffers, # noqa + set_device, # noqa ) HAS_EXLLAMA = "1" @@ -41,4 +41,4 @@ except ImportError: pass -from text_generation_server.layers.gptq.quant_linear import QuantLinear +from text_generation_server.layers.gptq.quant_linear import QuantLinear # noqa diff --git a/server/lorax_server/layers/gptq/quantize.py b/server/lorax_server/layers/gptq/quantize.py deleted file mode 100644 index 2cc80f351..000000000 --- a/server/lorax_server/layers/gptq/quantize.py +++ /dev/null @@ -1,1002 +0,0 @@ -import json -import math -import os -import time -from typing import Optional - -import torch -import torch.nn as nn -import transformers -from accelerate import init_empty_weights -from huggingface_hub import HfApi -from loguru import logger -from text_generation_server.utils import Weights, initialize_torch_distributed -from text_generation_server.utils.gptq.quant_linear import QuantLinear -from text_generation_server.utils.hub import weight_files -from texttable import Texttable -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer - -DEV = torch.device("cuda:0") - - -class Quantizer(nn.Module): - def __init__(self, shape=1): - super(Quantizer, self).__init__() - self.register_buffer("maxq", torch.tensor(0)) - self.register_buffer("scale", torch.zeros(shape)) - self.register_buffer("zero", torch.zeros(shape)) - - def configure( - self, - bits, - perchannel=False, - sym=True, - mse=False, - norm=2.4, - grid=100, - maxshrink=0.8, - trits=False, - ): - self.maxq = torch.tensor(2**bits - 1) - self.perchannel = perchannel - self.sym = sym - self.mse = mse - self.norm = norm - self.grid = grid - self.maxshrink = maxshrink - if trits: - self.maxq = torch.tensor(-1) - self.scale = torch.zeros_like(self.scale) - - def _quantize(self, x, scale, zero, maxq): - if maxq < 0: - return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero - q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) - return scale * (q - zero) - - def find_params(self, x, weight=False): - dev = x.device - self.maxq = self.maxq.to(dev) - - shape = x.shape - if self.perchannel: - if weight: - x = x.flatten(1) - else: - if len(shape) == 4: - x = x.permute([1, 0, 2, 3]) - x = x.flatten(1) - if len(shape) == 3: - x = x.reshape((-1, shape[-1])).t() - if len(shape) == 2: - x = x.t() - else: - x = x.flatten().unsqueeze(0) - - tmp = torch.zeros(x.shape[0], device=dev) - xmin = 
torch.minimum(x.min(1)[0], tmp) - xmax = torch.maximum(x.max(1)[0], tmp) - - if self.sym: - xmax = torch.maximum(torch.abs(xmin), xmax) - tmp = xmin < 0 - if torch.any(tmp): - xmin[tmp] = -xmax[tmp] - tmp = (xmin == 0) & (xmax == 0) - xmin[tmp] = -1 - xmax[tmp] = +1 - - if self.maxq < 0: - self.scale = xmax - self.zero = xmin - else: - self.scale = (xmax - xmin) / self.maxq - if self.sym: - self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) - else: - self.zero = torch.round(-xmin / self.scale) - - if self.mse: - best = torch.full([x.shape[0]], float("inf"), device=dev) - for i in range(int(self.maxshrink * self.grid)): - p = 1 - i / self.grid - xmin1 = p * xmin - xmax1 = p * xmax - scale1 = (xmax1 - xmin1) / self.maxq - zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero - q = self._quantize( - x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq - ) - q -= x - q.abs_() - q.pow_(self.norm) - err = torch.sum(q, 1) - tmp = err < best - if torch.any(tmp): - best[tmp] = err[tmp] - self.scale[tmp] = scale1[tmp] - self.zero[tmp] = zero1[tmp] - if not self.perchannel: - if weight: - tmp = shape[0] - else: - tmp = shape[1] if len(shape) != 3 else shape[2] - self.scale = self.scale.repeat(tmp) - self.zero = self.zero.repeat(tmp) - - if weight: - shape = [-1] + [1] * (len(shape) - 1) - self.scale = self.scale.reshape(shape) - self.zero = self.zero.reshape(shape) - return - if len(shape) == 4: - self.scale = self.scale.reshape((1, -1, 1, 1)) - self.zero = self.zero.reshape((1, -1, 1, 1)) - if len(shape) == 3: - self.scale = self.scale.reshape((1, 1, -1)) - self.zero = self.zero.reshape((1, 1, -1)) - if len(shape) == 2: - self.scale = self.scale.unsqueeze(0) - self.zero = self.zero.unsqueeze(0) - - def quantize(self, x): - if self.ready(): - return self._quantize(x, self.scale, self.zero, self.maxq) - - return x - - def enabled(self): - return self.maxq > 0 - - def ready(self): - return torch.all(self.scale != 0) - - -class GPTQ: - def __init__(self, layer, observe=False): - self.layer = layer - self.dev = self.layer.weight.device - W = layer.weight.data.clone() - if isinstance(self.layer, nn.Conv2d): - W = W.flatten(1) - if isinstance(self.layer, transformers.Conv1D): - W = W.t() - self.rows = W.shape[0] - self.columns = W.shape[1] - self.H = torch.zeros((self.columns, self.columns), device=self.dev) - self.nsamples = 0 - self.quantizer = Quantizer() - self.observe = observe - - def add_batch(self, inp, out): - # Hessian H = 2 X XT + λ I - if self.observe: - self.inp1 = inp - self.out1 = out - else: - self.inp1 = None - self.out1 = None - - if len(inp.shape) == 2: - inp = inp.unsqueeze(0) - tmp = inp.shape[0] - if isinstance(self.layer, nn.Linear) or isinstance( - self.layer, transformers.Conv1D - ): - if len(inp.shape) == 3: - inp = inp.reshape((-1, inp.shape[-1])) - inp = inp.t() - if isinstance(self.layer, nn.Conv2d): - unfold = nn.Unfold( - self.layer.kernel_size, - dilation=self.layer.dilation, - padding=self.layer.padding, - stride=self.layer.stride, - ) - inp = unfold(inp) - inp = inp.permute([1, 0, 2]) - inp = inp.flatten(1) - self.H *= self.nsamples / (self.nsamples + tmp) - self.nsamples += tmp - # inp = inp.float() - inp = math.sqrt(2 / self.nsamples) * inp.float() - # self.H += 2 / self.nsamples * inp.matmul(inp.t()) - self.H += inp.matmul(inp.t()) - - def print_loss(self, name, q_weight, weight_error, timecost): - table = Texttable() - length = 28 - name = ( - (name + " " * (length - len(name))) - if len(name) <= length - else name[:length] - ) - - 
table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"]) - - # assign weight - self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to( - self.layer.weight.data.dtype - ) - - if self.inp1 is not None: - # quantize input to int8 - quantizer = Quantizer() - quantizer.configure(8, perchannel=False, sym=True, mse=False) - quantizer.find_params(self.inp1) - q_in = quantizer.quantize(self.inp1).type(torch.float16) - q_out = self.layer(q_in) - - # get kinds of SNR - q_SNR = torch_snr_error(q_out, self.out1).item() - fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item() - else: - q_SNR = "-" - fp_SNR = "-" - - table.add_row([name, weight_error, fp_SNR, q_SNR, timecost]) - print(table.draw().split("\n")[-2]) - - def fasterquant( - self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name="" - ): - self.layer.to(self.dev) - - W = self.layer.weight.data.clone() - if isinstance(self.layer, nn.Conv2d): - W = W.flatten(1) - if isinstance(self.layer, transformers.Conv1D): - W = W.t() - W = W.float() - - tick = time.time() - - if not self.quantizer.ready(): - self.quantizer.find_params(W, weight=True) - - H = self.H - if not self.observe: - del self.H - dead = torch.diag(H) == 0 - H[dead, dead] = 1 - W[:, dead] = 0 - - if act_order: - perm = torch.argsort(torch.diag(H), descending=True) - W = W[:, perm] - H = H[perm][:, perm] - - Losses = torch.zeros_like(W) - Q = torch.zeros_like(W) - - damp = percdamp * torch.mean(torch.diag(H)) - diag = torch.arange(self.columns, device=self.dev) - H[diag, diag] += damp - H = torch.linalg.cholesky(H) - H = torch.cholesky_inverse(H) - try: - H = torch.linalg.cholesky(H, upper=True) - except Exception: - # Addition because Falcon fails on h_to_4h - H = torch.linalg.cholesky( - H + 1e-5 * torch.eye(H.shape[0]).to(H.device), upper=True - ) - Hinv = H - - g_idx = [] - scale = [] - zero = [] - now_idx = 1 - - for i1 in range(0, self.columns, blocksize): - i2 = min(i1 + blocksize, self.columns) - count = i2 - i1 - - W1 = W[:, i1:i2].clone() - Q1 = torch.zeros_like(W1) - Err1 = torch.zeros_like(W1) - Losses1 = torch.zeros_like(W1) - Hinv1 = Hinv[i1:i2, i1:i2] - - for i in range(count): - w = W1[:, i] - d = Hinv1[i, i] - - if groupsize != -1: - if (i1 + i) % groupsize == 0: - self.quantizer.find_params( - W[:, (i1 + i) : (i1 + i + groupsize)], weight=True - ) - - if ((i1 + i) // groupsize) - now_idx == -1: - scale.append(self.quantizer.scale) - zero.append(self.quantizer.zero) - now_idx += 1 - - q = self.quantizer.quantize(w.unsqueeze(1)).flatten() - Q1[:, i] = q - Losses1[:, i] = (w - q) ** 2 / d**2 - - err1 = (w - q) / d - W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - Err1[:, i] = err1 - - Q[:, i1:i2] = Q1 - Losses[:, i1:i2] = Losses1 / 2 - - W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) - - torch.cuda.synchronize() - error = torch.sum(Losses).item() - - groupsize = groupsize if groupsize != -1 else self.columns - g_idx = [i // groupsize for i in range(self.columns)] - g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) - if act_order: - invperm = torch.argsort(perm) - Q = Q[:, invperm] - g_idx = g_idx[invperm] - - if isinstance(self.layer, transformers.Conv1D): - Q = Q.t() - - self.print_loss( - name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick) - ) - - if scale == []: - scale.append(self.quantizer.scale) - zero.append(self.quantizer.zero) - scale = torch.cat(scale, dim=1) - zero = torch.cat(zero, dim=1) - return scale, zero, g_idx, error - - def free(self): - 
self.inp1 = None - self.out1 = None - self.H = None - self.Losses = None - self.Trace = None - torch.cuda.empty_cache() - - -def get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code): - from datasets import load_dataset - - traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") - testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") - - try: - tokenizer = AutoTokenizer.from_pretrained( - model_id, use_fast=False, trust_remote_code=trust_remote_code - ) - except: - tokenizer = AutoTokenizer.from_pretrained( - model_id, use_fast=True, trust_remote_code=trust_remote_code - ) - - trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt") - testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt") - - import random - - random.seed(seed) - trainloader = [] - for _ in range(nsamples): - i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = trainenc.input_ids[:, i:j] - tar = inp.clone() - tar[:, :-1] = -100 - trainloader.append((inp, tar)) - return trainloader, testenc - - -def get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code): - from datasets import load_dataset - - traindata = load_dataset("ptb_text_only", "penn_treebank", split="train") - valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation") - - try: - tokenizer = AutoTokenizer.from_pretrained( - model_id, use_fast=False, trust_remote_code=trust_remote_code - ) - except: - tokenizer = AutoTokenizer.from_pretrained( - model_id, use_fast=True, trust_remote_code=trust_remote_code - ) - - trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt") - testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt") - - import random - - random.seed(seed) - trainloader = [] - for _ in range(nsamples): - i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = trainenc.input_ids[:, i:j] - tar = inp.clone() - tar[:, :-1] = -100 - trainloader.append((inp, tar)) - return trainloader, testenc - - -def get_c4(nsamples, seed, seqlen, model_id, trust_remote_code): - from datasets import load_dataset - - traindata = load_dataset( - "allenai/c4", - "allenai--c4", - data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, - split="train", - use_auth_token=False, - ) - valdata = load_dataset( - "allenai/c4", - "allenai--c4", - data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, - split="validation", - use_auth_token=False, - ) - - try: - tokenizer = AutoTokenizer.from_pretrained( - model_id, use_fast=False, trust_remote_code=trust_remote_code - ) - except: - tokenizer = AutoTokenizer.from_pretrained( - model_id, use_fast=True, trust_remote_code=trust_remote_code - ) - - import random - - random.seed(seed) - trainloader = [] - for _ in range(nsamples): - while True: - i = random.randint(0, len(traindata) - 1) - trainenc = tokenizer(traindata[i]["text"], return_tensors="pt") - if trainenc.input_ids.shape[1] >= seqlen: - break - i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = trainenc.input_ids[:, i:j] - tar = inp.clone() - tar[:, :-1] = -100 - trainloader.append((inp, tar)) - - import random - - random.seed(0) - valenc = [] - for _ in range(256): - while True: - i = random.randint(0, len(valdata) - 1) - tmp = tokenizer(valdata[i]["text"], return_tensors="pt") - if tmp.input_ids.shape[1] >= seqlen: - break - i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - 
valenc.append(tmp.input_ids[:, i:j]) - valenc = torch.hstack(valenc) - - class TokenizerWrapper: - def __init__(self, input_ids): - self.input_ids = input_ids - - valenc = TokenizerWrapper(valenc) - - return trainloader, valenc - - -def get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code): - from datasets import load_dataset - - traindata = load_dataset("ptb_text_only", "penn_treebank", split="train") - testdata = load_dataset("ptb_text_only", "penn_treebank", split="test") - - try: - tokenizer = AutoTokenizer.from_pretrained( - model_id, use_fast=False, trust_remote_code=trust_remote_code - ) - except: - tokenizer = AutoTokenizer.from_pretrained( - model_id, use_fast=True, trust_remote_code=trust_remote_code - ) - - trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt") - testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt") - - import random - - random.seed(seed) - trainloader = [] - for _ in range(nsamples): - i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = trainenc.input_ids[:, i:j] - tar = inp.clone() - tar[:, :-1] = -100 - trainloader.append((inp, tar)) - return trainloader, testenc - - -def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code): - from datasets import load_dataset - - traindata = load_dataset( - "allenai/c4", - "allenai--c4", - data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, - split="train", - ) - valdata = load_dataset( - "allenai/c4", - "allenai--c4", - data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, - split="validation", - ) - - try: - tokenizer = AutoTokenizer.from_pretrained( - model_id, use_fast=False, trust_remote_code=trust_remote_code - ) - except: - tokenizer = AutoTokenizer.from_pretrained( - model_id, use_fast=True, trust_remote_code=trust_remote_code - ) - - import random - - random.seed(seed) - trainloader = [] - for _ in range(nsamples): - while True: - i = random.randint(0, len(traindata) - 1) - trainenc = tokenizer(traindata[i]["text"], return_tensors="pt") - if trainenc.input_ids.shape[1] >= seqlen: - break - i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = trainenc.input_ids[:, i:j] - tar = inp.clone() - tar[:, :-1] = -100 - trainloader.append((inp, tar)) - - valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt") - valenc = valenc.input_ids[:, : (256 * seqlen)] - - class TokenizerWrapper: - def __init__(self, input_ids): - self.input_ids = input_ids - - valenc = TokenizerWrapper(valenc) - - return trainloader, valenc - - -def get_loaders( - name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False -): - if "wikitext2" in name: - return get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code) - if "ptb" in name: - if "new" in name: - return get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code) - return get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code) - if "c4" in name: - if "new" in name: - return get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code) - return get_c4(nsamples, seed, seqlen, model_id, trust_remote_code) - - -def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""): - # Skip last lm_head linear - # Need isintance Falcon is inheriting Linear. - if isinstance(module, layers) and "lm_head" not in name: - return {name: module} - res = {} - for name1, child in module.named_children(): - res.update( - find_layers( - child, layers=layers, name=name + "." 
+ name1 if name != "" else name1 - ) - ) - return res - - -@torch.no_grad() -def sequential( - model, - dataloader, - dev, - nsamples, - bits, - groupsize, - *, - hooks, - percdamp=0.01, - sym: bool = False, - act_order: bool = False, -): - print("Starting ...") - - use_cache = model.config.use_cache - model.config.use_cache = False - try: - layers = model.model.layers - prefix = "model.layers" - except Exception: - layers = model.transformer.h - prefix = "transformer.h" - - dtype = next(iter(model.parameters())).dtype - inps = torch.zeros( - (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev - ) - - cache = {"i": 0} - extra = {} - - class Catcher(nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - - def forward(self, inp, **kwargs): - inps[cache["i"]] = inp - cache["i"] += 1 - extra.update(kwargs.copy()) - raise ValueError - - layers[0] = Catcher(layers[0]) - for batch in dataloader: - try: - model(batch[0].cuda()) - except ValueError: - pass - layers[0] = layers[0].module - - # layers[0] = layers[0].cpu() - # model.model.embed_tokens = model.model.embed_tokens.cpu() - # model.model.norm = model.model.norm.cpu() - torch.cuda.empty_cache() - for hook in hooks: - hook.remove() - - outs = torch.zeros_like(inps) - - extra = { - k: v.to(dev) if isinstance(v, torch.Tensor) else v for k, v in extra.items() - } - - print("Ready.") - - quantizers = {} - for i in range(len(layers)): - print(f"Quantizing layer {i+1}/{len(layers)}..") - print("+------------------+--------------+------------+-----------+-------+") - print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |") - print("+==================+==============+============+===========+=======+") - - layer = layers[i] - layer.load() - full = find_layers(layer) - sequential = [list(full.keys())] - - for names in sequential: - subset = {n: full[n] for n in names} - gptq = {} - for name in subset: - gptq[name] = GPTQ(subset[name]) - gptq[name].quantizer.configure( - bits, perchannel=True, sym=sym, mse=False - ) - pass - - def add_batch(name): - def tmp(_, inp, out): - gptq[name].add_batch(inp[0].data, out.data) - - return tmp - - handles = [] - for name in subset: - handles.append(subset[name].register_forward_hook(add_batch(name))) - for j in range(nsamples): - outs[j] = layer(inps[j].unsqueeze(0), **extra)[0] - for h in handles: - h.remove() - - for name in subset: - scale, zero, g_idx, error = gptq[name].fasterquant( - percdamp=percdamp, - groupsize=groupsize, - act_order=act_order, - name=name, - ) - quantizers[f"{prefix}.{i}.{name}"] = ( - gptq[name].quantizer.cpu(), - scale.cpu(), - zero.cpu(), - g_idx.cpu(), - bits, - groupsize, - ) - - gptq[name].free() - - for j in range(nsamples): - outs[j] = layer(inps[j].unsqueeze(0), **extra)[0] - - layer.unload() - del layer - del gptq - torch.cuda.empty_cache() - - inps, outs = outs, inps - print("+------------------+--------------+------------+-----------+-------+") - print("\n") - - model.config.use_cache = use_cache - - return quantizers - - -def make_quant_linear(module, names, bits, groupsize, name=""): - if isinstance(module, QuantLinear): - return - for attr in dir(module): - tmp = getattr(module, attr) - name1 = name + "." 
+ attr if name != "" else attr - if name1 in names: - delattr(module, attr) - setattr( - module, - attr, - QuantLinear.new( - bits, - groupsize, - tmp.in_features, - tmp.out_features, - tmp.bias is not None, - ), - ) - for name1, child in module.named_children(): - make_quant_linear( - child, names, bits, groupsize, name + "." + name1 if name != "" else name1 - ) - - -# TODO: perform packing on GPU -def pack(model, quantizers, bits, groupsize): - layers = find_layers(model) - layers = {n: layers[n] for n in quantizers} - make_quant_linear(model, quantizers, bits, groupsize) - qlayers = find_layers(model, (QuantLinear,)) - print("Packing ...") - for name in qlayers: - print(name) - quantizers[name], scale, zero, g_idx, _, _ = quantizers[name] - qlayers[name].pack(layers[name], scale, zero, g_idx) - print("Done.") - return model - - -def setdeepattr(module, full_name, tensor): - current = module - tokens = full_name.split(".") - for token in tokens[:-1]: - current = getattr(current, token) - setattr(current, tokens[-1], tensor) - - -def getdeepattr(module, full_name): - current = module - tokens = full_name.split(".") - for token in tokens: - current = getattr(current, token) - return current - - -def load_weights_pre_hook(module_name, weights, recursive=False): - def inner(module, args): - print(f"Pre hook {module_name}") - local_params = {} - for k, v in module.named_parameters(): - if not recursive and k.count(".") != 1: - continue - local_params[k] = v - for k, v in module.named_buffers(): - if not recursive and k.count(".") != 1: - continue - local_params[k] = v - - for local_param in local_params: - current_tensor = getdeepattr(module, local_param) - if current_tensor.device == torch.device("meta"): - # print(f"Loading {local_param}") - if module_name: - tensor_name = f"{module_name}.{local_param}" - else: - tensor_name = local_param - tensor = weights.get_tensor(tensor_name) - setdeepattr(module, local_param, nn.Parameter(tensor)) - else: - tensor = current_tensor.to(device=torch.device("cuda:0")) - if current_tensor.requires_grad: - tensor = nn.Parameter(tensor) - setdeepattr(module, local_param, tensor) - - return inner - - -def load_weights_post_hook(module_name, weights, recursive=False): - def inner(module, args, output): - print(f"Post hook {module_name}") - local_params = {} - for k, v in module.named_parameters(): - if not recursive and k.count(".") != 1: - continue - local_params[k] = v - for k, v in module.named_buffers(): - if not recursive and k.count(".") != 1: - continue - local_params[k] = v - for local_param in local_params: - # print(f"Unloading {local_param}") - current_tensor = getdeepattr(module, local_param) - setdeepattr( - module, - local_param, - nn.Parameter(current_tensor.to(device=torch.device("cpu"))), - ) - return output - - return inner - - -def quantize( - model_id: str, - bits: int, - groupsize: int, - output_dir: str, - revision: str, - trust_remote_code: bool, - upload_to_model_id: Optional[str], - percdamp: float, - act_order: bool, -): - print("loading model") - config = AutoConfig.from_pretrained( - model_id, - trust_remote_code=trust_remote_code, - ) - - with init_empty_weights(): - model = AutoModelForCausalLM.from_config( - config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code - ) - model = model.eval() - - print("LOADED model") - files = weight_files(model_id, revision, extension=".safetensors") - process_group, _, _ = initialize_torch_distributed() - weights = Weights( - files, - device=torch.device("cuda:0"), - 
dtype=torch.float16, - process_group=process_group, - aliases={"embed_tokens.weight": ["lm_head.weight"]}, - ) - hooks = [] - for name, module in model.named_modules(): - - def load(module, name): - def _load(): - load_weights_pre_hook(name, weights, recursive=True)(module, None) - - return _load - - def unload(module, name): - def _unload(): - load_weights_post_hook(name, weights, recursive=True)( - module, None, None - ) - - return _unload - - module.load = load(module, name) - module.unload = unload(module, name) - hooks.append( - module.register_forward_pre_hook(load_weights_pre_hook(name, weights)) - ) - hooks.append( - module.register_forward_hook(load_weights_post_hook(name, weights)) - ) - model.seqlen = 2048 - - dataset = "wikitext2" - nsamples = 128 - seed = None - - dataloader, testloader = get_loaders( - dataset, - nsamples=nsamples, - seed=seed, - model_id=model_id, - seqlen=model.seqlen, - trust_remote_code=trust_remote_code, - ) - - tick = time.time() - quantizers = sequential( - model, - dataloader, - DEV, - nsamples, - bits, - groupsize, - percdamp=percdamp, - act_order=act_order, - hooks=hooks, - ) - print(time.time() - tick) - - pack(model, quantizers, bits, groupsize) - from safetensors.torch import save_file - from transformers.modeling_utils import shard_checkpoint - - state_dict = model.state_dict() - state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} - state_dict["gptq_bits"] = torch.LongTensor([bits]) - state_dict["gptq_groupsize"] = torch.LongTensor([groupsize]) - - max_shard_size = "10GB" - shards, index = shard_checkpoint( - state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors" - ) - os.makedirs(output_dir, exist_ok=True) - for shard_file, shard in shards.items(): - save_file( - shard, - os.path.join(output_dir, shard_file), - metadata={ - "format": "pt", - "quantized": "gptq", - "origin": "text-generation-inference", - }, - ) - if index is None: - path_to_weights = os.path.join(output_dir, "model.safetensors") - logger.info(f"Model weights saved in {path_to_weights}") - else: - save_index_file = "model.safetensors.index.json" - save_index_file = os.path.join(output_dir, save_index_file) - with open(save_index_file, "w", encoding="utf-8") as f: - content = json.dumps(index, indent=2, sort_keys=True) + "\n" - f.write(content) - logger.info( - f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " - f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " - f"index located at {save_index_file}." 
- ) - config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code) - config.save_pretrained(output_dir) - logger.info("Saved config") - logger.info("Saving tokenizer") - tokenizer = AutoTokenizer.from_pretrained( - model_id, trust_remote_code=trust_remote_code - ) - tokenizer.save_pretrained(output_dir) - logger.info("Saved tokenizer") - - if upload_to_model_id: - api = HfApi() - - api.upload_folder( - folder_path=output_dir, repo_id=upload_to_model_id, repo_type="model" - ) diff --git a/server/lorax_server/layers/linear.py b/server/lorax_server/layers/linear.py index b0d832a2d..af8510396 100644 --- a/server/lorax_server/layers/linear.py +++ b/server/lorax_server/layers/linear.py @@ -1,4 +1,5 @@ import torch +from torch import nn from torch.nn import functional as F from lorax_server.layers.gptq.exllamav2 import QuantLinear as exllamav2QuantLinear diff --git a/server/lorax_server/layers/rotary.py b/server/lorax_server/layers/rotary.py index e07f48dec..49e4aedd3 100644 --- a/server/lorax_server/layers/rotary.py +++ b/server/lorax_server/layers/rotary.py @@ -1,3 +1,4 @@ +import math import os import torch @@ -67,10 +68,6 @@ def forward( # Inplace operation, updating query and key. ops.rotary_embedding(query, key, head_size, cos, sin, True) - elif SYSTEM == "xpu": - ipex.llm.functional.rotary_embedding( - query, key, sin, cos, query.size(-1), True - ) else: raise ValueError( "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." @@ -314,9 +311,6 @@ def _update_cos_sin_cache(self, dtype, device, seqlen): self._sin_cached = torch.sin(freqs).to(dtype) -# Inverse dim formula to find dim based on number of rotations -import math - def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( diff --git a/server/lorax_server/layers/speculative.py b/server/lorax_server/layers/speculative.py index ff57a81d2..387ae5198 100644 --- a/server/lorax_server/layers/speculative.py +++ b/server/lorax_server/layers/speculative.py @@ -34,7 +34,7 @@ def load(config, prefix: str, weights): except KeyError: try: speculator = MedusaHeadV1.load(config, prefix, weights) - except: + except Exception: speculator = MedusaHeadV2(config, prefix, weights) lm_head = None else: From 573fb32563c8efd7aadb1544d4f5bbb746f9d731 Mon Sep 17 00:00:00 2001 From: Florian Zimmermeister Date: Tue, 21 May 2024 22:32:40 +0200 Subject: [PATCH 15/32] Update __init__.py --- server/lorax_server/layers/gptq/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/lorax_server/layers/gptq/__init__.py b/server/lorax_server/layers/gptq/__init__.py index 8ad0914a1..ea83d195d 100644 --- a/server/lorax_server/layers/gptq/__init__.py +++ b/server/lorax_server/layers/gptq/__init__.py @@ -1,7 +1,7 @@ import os import torch -from text_generation_server.utils.import_utils import ( +from lorax_server.utils.import_utils import ( SYSTEM, ) From fe65875ba030af0e16bae75cfc8cb2e3429316e0 Mon Sep 17 00:00:00 2001 From: Florian Zimmermeister Date: Tue, 21 May 2024 22:34:07 +0200 Subject: [PATCH 16/32] Update __init__.py --- server/lorax_server/layers/gptq/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/server/lorax_server/layers/gptq/__init__.py b/server/lorax_server/layers/gptq/__init__.py index ea83d195d..de9896ca8 100644 --- 
a/server/lorax_server/layers/gptq/__init__.py +++ b/server/lorax_server/layers/gptq/__init__.py @@ -18,17 +18,17 @@ elif CAN_EXLLAMA: try: if V2: - from text_generation_server.layers.gptq.exllamav2 import ( + from lorax_server.layers.gptq.exllamav2 import ( QuantLinear as ExllamaQuantLinear, ) - from text_generation_server.layers.gptq.exllamav2 import ( + from lorax_server.layers.gptq.exllamav2 import ( create_exllama_buffers, set_device, ) HAS_EXLLAMA = "2" else: - from text_generation_server.layers.gptq.exllama import ( + from lorax_server.layers.gptq.exllama import ( Ex4bitLinear as ExllamaQuantLinear, # noqa ) from text_generation_server.layers.gptq.exllama import ( @@ -41,4 +41,4 @@ except ImportError: pass -from text_generation_server.layers.gptq.quant_linear import QuantLinear # noqa +from lorax_server.layers.gptq.quant_linear import QuantLinear # noqa From d32e6a9c472d4bc12c73b2b97dc3c05a651e48d5 Mon Sep 17 00:00:00 2001 From: flozi00 Date: Wed, 22 May 2024 07:34:47 +0200 Subject: [PATCH 17/32] ruff --- server/lorax_server/layers/gptq/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/server/lorax_server/layers/gptq/__init__.py b/server/lorax_server/layers/gptq/__init__.py index de9896ca8..76d749ef8 100644 --- a/server/lorax_server/layers/gptq/__init__.py +++ b/server/lorax_server/layers/gptq/__init__.py @@ -1,6 +1,7 @@ import os import torch + from lorax_server.utils.import_utils import ( SYSTEM, ) @@ -28,14 +29,15 @@ HAS_EXLLAMA = "2" else: - from lorax_server.layers.gptq.exllama import ( - Ex4bitLinear as ExllamaQuantLinear, # noqa - ) from text_generation_server.layers.gptq.exllama import ( create_exllama_buffers, # noqa set_device, # noqa ) + from lorax_server.layers.gptq.exllama import ( + Ex4bitLinear as ExllamaQuantLinear, # noqa + ) + HAS_EXLLAMA = "1" except ImportError: From f420eee12d8fa08bdb3839c8dd38376ccf81592d Mon Sep 17 00:00:00 2001 From: flozi00 Date: Wed, 22 May 2024 07:55:55 +0200 Subject: [PATCH 18/32] imports --- server/lorax_server/layers/linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/lorax_server/layers/linear.py b/server/lorax_server/layers/linear.py index af8510396..d215004c6 100644 --- a/server/lorax_server/layers/linear.py +++ b/server/lorax_server/layers/linear.py @@ -2,8 +2,6 @@ from torch import nn from torch.nn import functional as F -from lorax_server.layers.gptq.exllamav2 import QuantLinear as exllamav2QuantLinear -from lorax_server.layers.gptq.quant_linear import QuantLinear from lorax_server.utils.import_utils import SYSTEM if SYSTEM == "rocm": @@ -140,8 +138,10 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False): raise NotImplementedError("The passed weight is not `gptq` compatible, loader needs to be updated.") if use_exllama: + from lorax_server.layers.gptq.exllamav2 import QuantLinear as exllamav2QuantLinear linear = exllamav2QuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize) else: + from lorax_server.layers.gptq.quant_linear import QuantLinear linear = QuantLinear( qweight, qzeros, From 1bb903069b51ebc2ae48ad16029540bc0d007169 Mon Sep 17 00:00:00 2001 From: flozi00 Date: Wed, 22 May 2024 08:02:13 +0200 Subject: [PATCH 19/32] imports --- server/lorax_server/layers/gptq/__init__.py | 2 +- server/lorax_server/layers/speculative.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/server/lorax_server/layers/gptq/__init__.py b/server/lorax_server/layers/gptq/__init__.py index 76d749ef8..f9aece4ae 100644 --- 
a/server/lorax_server/layers/gptq/__init__.py +++ b/server/lorax_server/layers/gptq/__init__.py @@ -29,7 +29,7 @@ HAS_EXLLAMA = "2" else: - from text_generation_server.layers.gptq.exllama import ( + from lorax_server.layers.gptq.exllama import ( create_exllama_buffers, # noqa set_device, # noqa ) diff --git a/server/lorax_server/layers/speculative.py b/server/lorax_server/layers/speculative.py index 387ae5198..255a7bc59 100644 --- a/server/lorax_server/layers/speculative.py +++ b/server/lorax_server/layers/speculative.py @@ -2,9 +2,9 @@ from typing import Optional, Tuple import torch -from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2 -from text_generation_server.layers.mlp import MLPSpeculatorHead -from text_generation_server.layers.tensor_parallel import TensorParallelHead +from lorax_server.layers.medusa import MedusaHeadV1, MedusaHeadV2 +from lorax_server.layers.mlp import MLPSpeculatorHead +from lorax_server.layers.tensor_parallel import TensorParallelHead class SpeculativeHead(torch.nn.Module): From 357fe3d5356a0f9392b519241a7884a264e80f31 Mon Sep 17 00:00:00 2001 From: flozi00 Date: Wed, 22 May 2024 08:03:00 +0200 Subject: [PATCH 20/32] ruff --- server/lorax_server/layers/gptq/__init__.py | 7 +++---- server/lorax_server/layers/speculative.py | 1 + 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/server/lorax_server/layers/gptq/__init__.py b/server/lorax_server/layers/gptq/__init__.py index f9aece4ae..a00915462 100644 --- a/server/lorax_server/layers/gptq/__init__.py +++ b/server/lorax_server/layers/gptq/__init__.py @@ -30,12 +30,11 @@ HAS_EXLLAMA = "2" else: from lorax_server.layers.gptq.exllama import ( - create_exllama_buffers, # noqa - set_device, # noqa + Ex4bitLinear as ExllamaQuantLinear, # noqa ) - from lorax_server.layers.gptq.exllama import ( - Ex4bitLinear as ExllamaQuantLinear, # noqa + create_exllama_buffers, # noqa + set_device, # noqa ) HAS_EXLLAMA = "1" diff --git a/server/lorax_server/layers/speculative.py b/server/lorax_server/layers/speculative.py index 255a7bc59..01a5ddfc4 100644 --- a/server/lorax_server/layers/speculative.py +++ b/server/lorax_server/layers/speculative.py @@ -2,6 +2,7 @@ from typing import Optional, Tuple import torch + from lorax_server.layers.medusa import MedusaHeadV1, MedusaHeadV2 from lorax_server.layers.mlp import MLPSpeculatorHead from lorax_server.layers.tensor_parallel import TensorParallelHead From 8712ec2de8113f1fbcf544891cf5771e515d46ec Mon Sep 17 00:00:00 2001 From: flozi00 Date: Wed, 22 May 2024 08:07:51 +0200 Subject: [PATCH 21/32] fix --- server/lorax_server/layers/__init__.py | 1 - server/lorax_server/layers/speculative.py | 54 ----------------------- 2 files changed, 55 deletions(-) delete mode 100644 server/lorax_server/layers/speculative.py diff --git a/server/lorax_server/layers/__init__.py b/server/lorax_server/layers/__init__.py index 6edfeacef..48b44513a 100644 --- a/server/lorax_server/layers/__init__.py +++ b/server/lorax_server/layers/__init__.py @@ -6,7 +6,6 @@ FastLinear, # noqa get_linear, # noqa ) -from lorax_server.layers.speculative import SpeculativeHead # noqa from lorax_server.layers.tensor_parallel import ( TensorParallelColumnLinear, # noqa TensorParallelEmbedding, # noqa diff --git a/server/lorax_server/layers/speculative.py b/server/lorax_server/layers/speculative.py deleted file mode 100644 index 01a5ddfc4..000000000 --- a/server/lorax_server/layers/speculative.py +++ /dev/null @@ -1,54 +0,0 @@ -import json -from typing import Optional, Tuple - -import torch - -from 
lorax_server.layers.medusa import MedusaHeadV1, MedusaHeadV2 -from lorax_server.layers.mlp import MLPSpeculatorHead -from lorax_server.layers.tensor_parallel import TensorParallelHead - - -class SpeculativeHead(torch.nn.Module): - def __init__(self, lm_head, speculator): - super().__init__() - self.head = lm_head - self.speculator = speculator - - @staticmethod - def load(config, prefix: str, weights): - speculator = config.speculator - if speculator: - speculator_path = config.speculator["path"] - speculator_config = str(speculator_path / "config.json") - - with open(speculator_config, "r") as f: - speculator_config = json.load(f) - - config.speculator_config = speculator_config - try: - architecture = speculator_config["architectures"][0] - - if architecture == "MLPSpeculatorPreTrainedModel": - speculator = MLPSpeculatorHead.load(config, prefix, weights) - else: - speculator = None - except KeyError: - try: - speculator = MedusaHeadV1.load(config, prefix, weights) - except Exception: - speculator = MedusaHeadV2(config, prefix, weights) - lm_head = None - else: - lm_head = TensorParallelHead.load(config, prefix, weights) - speculator = None - return SpeculativeHead(lm_head, speculator) - - def forward( - self, input: torch.Tensor - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if self.speculator is not None: - return self.speculator(input) - - assert self.head is not None - logits = self.head(input) - return logits, None From 629986e1513517cf2b36c24ea7163a31e65ed8b3 Mon Sep 17 00:00:00 2001 From: flozi00 Date: Wed, 22 May 2024 08:13:15 +0200 Subject: [PATCH 22/32] import --- server/lorax_server/adapters/medusa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/lorax_server/adapters/medusa.py b/server/lorax_server/adapters/medusa.py index 8fb4727b9..31baafbbb 100644 --- a/server/lorax_server/adapters/medusa.py +++ b/server/lorax_server/adapters/medusa.py @@ -7,7 +7,7 @@ from lorax_server.adapters.config import AdapterConfig, ModuleMap from lorax_server.adapters.types import MEDUSA from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights -from lorax_server.utils.layers import FastLinear, TensorParallelColumnLinear +from lorax_server.layers import FastLinear, TensorParallelColumnLinear from lorax_server.utils.segments import find_segments from lorax_server.utils.sgmv import segmented_matmul from lorax_server.utils.state import get_speculative_tokens From 78519ea754f63f590416ee618caecc837909732e Mon Sep 17 00:00:00 2001 From: flozi00 Date: Wed, 22 May 2024 08:23:15 +0200 Subject: [PATCH 23/32] import --- server/lorax_server/utils/layers.py | 2 +- server/lorax_server/utils/paged_attention.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 4be11ade3..784653e98 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -10,7 +10,7 @@ from lorax_server.adapters.types import LORA, MEDUSA from lorax_server.layers.linear import get_linear -from lorax_server.layers.tensor_parallel import SuperLayer +from lorax_server.layers.tensor_parallel import SuperLayer, TensorParallelColumnLinear, TensorParallelHead # noqa: F401 from lorax_server.utils.sgmv import ( add_lora_a_bgmv, add_lora_b_bgmv, diff --git a/server/lorax_server/utils/paged_attention.py b/server/lorax_server/utils/paged_attention.py index 25c501a28..ab6dc655d 100644 --- a/server/lorax_server/utils/paged_attention.py +++ 
b/server/lorax_server/utils/paged_attention.py @@ -14,6 +14,8 @@ f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" ) +fp8_supported = torch.cuda.get_device_capability()[0] >= 9 or (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9) + def reshape_and_cache( key: torch.Tensor, @@ -28,7 +30,7 @@ def reshape_and_cache( ) else: cache_ops.reshape_and_cache( - key, value, key_cache, value_cache, slots, "auto", 1.0 + key, value, key_cache, value_cache, slots, "fp8" if fp8_supported else "auto", 1.0 ) From f207d40ba2629a9fcf096868f476203f4251a80b Mon Sep 17 00:00:00 2001 From: flozi00 Date: Wed, 22 May 2024 08:32:52 +0200 Subject: [PATCH 24/32] linear layer import in layers.py --- server/lorax_server/utils/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 784653e98..cea4e5bd9 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -9,7 +9,7 @@ from torch.nn import functional as F from lorax_server.adapters.types import LORA, MEDUSA -from lorax_server.layers.linear import get_linear +from lorax_server.layers.linear import FastLinear, get_linear # noqa: F401 from lorax_server.layers.tensor_parallel import SuperLayer, TensorParallelColumnLinear, TensorParallelHead # noqa: F401 from lorax_server.utils.sgmv import ( add_lora_a_bgmv, From 9649fad227dcdc25e3e9cf4f23fb99263ffd0e44 Mon Sep 17 00:00:00 2001 From: flozi00 Date: Wed, 22 May 2024 09:24:52 +0200 Subject: [PATCH 25/32] add fp8, update FA --- launcher/src/main.rs | 4 ++++ server/Makefile-flash-att-v2 | 4 ++-- server/lorax_server/cli.py | 1 + 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index d5548030e..4073f72c4 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -30,6 +30,7 @@ enum Quantization { Hqq_4bit, Hqq_3bit, Hqq_2bit, + Fp8, } impl std::fmt::Display for Quantization { @@ -63,6 +64,9 @@ impl std::fmt::Display for Quantization { Quantization::Hqq_2bit => { write!(f, "hqq-2bit") } + Quantization::Fp8 => { + write!(f, "fp8") + } } } } diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index 36ef576ae..bbff00909 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,11 +1,11 @@ -flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9 +flash_att_v2_commit_cuda := v2.5.8 flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6 flash-attention-v2-cuda: # Clone flash attention pip install -U packaging ninja --no-cache-dir - git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2 + git clone https://github.com/Dao-AILab/flash-attention.git flash-attention-v2 build-flash-attention-v2-cuda: flash-attention-v2-cuda cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_cuda) diff --git a/server/lorax_server/cli.py b/server/lorax_server/cli.py index 5fe7532d1..f323a153e 100644 --- a/server/lorax_server/cli.py +++ b/server/lorax_server/cli.py @@ -22,6 +22,7 @@ class Quantization(str, Enum): hqq_4bit = "hqq-4bit" hqq_3bit = "hqq-3bit" hqq_2bit = "hqq-2bit" + fp8 = "fp8" class Dtype(str, Enum): From 98ac61c206b13a07ee70592deafefc32f1ef9df1 Mon Sep 17 00:00:00 2001 From: flozi00 Date: Wed, 22 May 2024 09:26:14 +0200 Subject: [PATCH 26/32] fix FA2 check --- .../models/custom_modeling/flash_mistral_modeling.py | 4 ++-- 
.../models/custom_modeling/flash_mixtral_modeling.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index bc39f6dd6..ab9a69f74 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -30,7 +30,7 @@ from lorax_server.adapters import AdapterBatchData from lorax_server.utils import flash_attn, paged_attention -from lorax_server.utils.flash_attn import HAS_FLASH_ATTN_V2 +from lorax_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA from lorax_server.utils.layers import ( MultiAdapterHead, PositionRotaryEmbedding, @@ -53,7 +53,7 @@ V_PROJ, ) -if not HAS_FLASH_ATTN_V2: +if not HAS_FLASH_ATTN_V2_CUDA: raise ImportError("Mistral model requires flash attn v2") diff --git a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py index 404abd0fe..cd316731d 100644 --- a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py @@ -32,7 +32,7 @@ from lorax_server.adapters import AdapterBatchData from lorax_server.utils import flash_attn, paged_attention -from lorax_server.utils.flash_attn import HAS_FLASH_ATTN_V2 +from lorax_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA from lorax_server.utils.layers import ( FastLinear, MultiAdapterHead, @@ -47,7 +47,7 @@ ) from lorax_server.utils.lora import LM_HEAD -if not HAS_FLASH_ATTN_V2: +if not HAS_FLASH_ATTN_V2_CUDA: raise ImportError("Mixtral model requires flash attn v2") try: From 6d41fbfcc42e0b2c670bdc0c62c050dae3725220 Mon Sep 17 00:00:00 2001 From: flozi00 Date: Wed, 22 May 2024 11:33:28 +0200 Subject: [PATCH 27/32] paged attention --- .../models/custom_modeling/flash_dbrx_modeling.py | 3 ++- .../models/custom_modeling/flash_gemma_modeling.py | 4 ++-- .../models/custom_modeling/flash_mistral_modeling.py | 3 +-- .../models/custom_modeling/flash_mixtral_modeling.py | 2 +- .../models/custom_modeling/flash_phi3_modeling.py | 2 +- .../lorax_server/models/custom_modeling/flash_phi_modeling.py | 2 +- .../models/custom_modeling/flash_qwen2_modeling.py | 2 +- .../models/custom_modeling/flash_qwen_modeling.py | 2 +- 8 files changed, 10 insertions(+), 10 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py b/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py index 44c01c8b2..7e4d16cc4 100644 --- a/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py @@ -430,6 +430,7 @@ def forward( cu_seqlen_prefill, max_s, self.softmax_scale, + window_size_left=self.max_past, ) # Decode else: @@ -438,7 +439,7 @@ def forward( query, kv_cache[0], kv_cache[1], - self.num_key_value_heads, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, diff --git a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py index b036224c9..6948bbc8c 100644 --- a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py @@ -281,19 +281,19 @@ def forward( ) # Decode else: - # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] 
paged_attention.attention( attn_output, query, kv_cache[0], kv_cache[1], - self.num_key_value_heads, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, max_s, ) + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data) diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index ab9a69f74..3f1a0f5ec 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -339,13 +339,12 @@ def forward( ) # Decode else: - # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] paged_attention.attention( attn_output, query, kv_cache[0], kv_cache[1], - self.num_key_value_heads, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, diff --git a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py index cd316731d..e37e31657 100644 --- a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py @@ -424,7 +424,7 @@ def forward( query, kv_cache[0], kv_cache[1], - self.num_key_value_heads, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, diff --git a/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py index aa8b6a302..e3f4c4e8f 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py @@ -296,7 +296,7 @@ def forward( query, kv_cache[0], kv_cache[1], - self.num_key_value_heads, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, diff --git a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py index 8827fd5fd..3339e6068 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py @@ -182,7 +182,7 @@ def forward( query, kv_cache[0], kv_cache[1], - self.num_key_value_heads, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py index 6b6bf0374..9722ee9ad 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py @@ -253,7 +253,7 @@ def forward( query, kv_cache[0], kv_cache[1], - self.num_key_value_heads, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py index 0ad3c0cbe..d9bf7b0e3 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py @@ -269,7 +269,7 @@ def forward( query, kv_cache[0], kv_cache[1], - self.num_key_value_heads, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, From a20bc039b0d6830bc964894609fa8800504cde0c Mon Sep 17 00:00:00 2001 From: flozi00 Date: Wed, 22 May 2024 12:05:42 +0200 Subject: [PATCH 28/32] fix llama PA --- 
.../models/custom_modeling/flash_llama_modeling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py index 810cb8c5b..82d7d736e 100644 --- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py @@ -331,13 +331,14 @@ def forward( query, kv_cache[0], kv_cache[1], - self.num_key_value_heads, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, max_s, ) + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data) From ab80c158f2c45e0ed881005c2829ec677f825214 Mon Sep 17 00:00:00 2001 From: flozi00 Date: Wed, 22 May 2024 12:34:08 +0200 Subject: [PATCH 29/32] awq typo --- server/lorax_server/layers/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/lorax_server/layers/linear.py b/server/lorax_server/layers/linear.py index d215004c6..040612389 100644 --- a/server/lorax_server/layers/linear.py +++ b/server/lorax_server/layers/linear.py @@ -156,7 +156,7 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False): qweight, qzeros, scales, _, bits, groupsize, _ = weight except Exception: raise NotImplementedError("The passed weight is not compatible with `awq`") - from lorax_server.lorax_server.utils.awq.awq import AWQLinear + from lorax_server.utils.awq.awq import AWQLinear linear = AWQLinear( w_bit=bits, group_size=groupsize, From feb06bec683aa051e541cae3edb0a87a54ea9fc1 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 24 May 2024 15:40:32 -0700 Subject: [PATCH 30/32] Revert build.yaml --- .github/workflows/build.yaml | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index cdbc2debd..116120764 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -4,10 +4,9 @@ on: workflow_dispatch: push: branches: - - "main" - - "synctgi" + - 'main' tags: - - "v*" + - 'v*' jobs: build-and-push-image: @@ -42,8 +41,8 @@ jobs: - name: Install soci uses: lerentis/soci-installer@v1.0.1 with: - soci-release: "v0.4.0" - + soci-release: 'v0.4.0' + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v2.10.0 @@ -52,7 +51,7 @@ jobs: with: config-inline: | version = 2 - + # persistent data location root = "/runner/build/containerd" @@ -63,8 +62,11 @@ jobs: images: | ghcr.io/predibase/lorax tags: | - type=raw,value=synctgi - + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=sha,prefix=,suffix=,format=short + type=raw,value=main,enable=${{ github.ref == 'refs/heads/main' }} + - name: Create a hash from tags env: tags: ${{ steps.meta.outputs.tags }} @@ -91,7 +93,7 @@ jobs: uses: docker/build-push-action@v5 with: context: . 
- file: ./Dockerfile # Path to your Dockerfile + file: ./Dockerfile # Path to your Dockerfile push: false tags: ${{ steps.meta.outputs.tags }} outputs: type=oci,compression=gzip,dest=${{ steps.vars.outputs.image_path }}-${{ steps.vars.outputs.tag_hash }}.tar.gz @@ -122,7 +124,7 @@ jobs: echo "Pushing $tag to GHCR" sudo ctr i push --user "${{ github.repository_owner }}:${{ secrets.GHCR_PAT }}" $tag done - + - name: Create and push soci index env: tags: ${{ steps.meta.outputs.tags }} @@ -149,3 +151,4 @@ jobs: # Delete the SHA image(s) from containerd store sudo ctr i rm $(sudo ctr i ls -q) + From 503af2aee8fed1df098485badec52141efcd1bb9 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 24 May 2024 15:44:03 -0700 Subject: [PATCH 31/32] Added back channels --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 1f6698e35..26f8b8712 100644 --- a/Dockerfile +++ b/Dockerfile @@ -73,7 +73,7 @@ RUN chmod +x ~/mambaforge.sh && \ RUN case ${TARGETPLATFORM} in \ "linux/arm64") exit 1 ;; \ *) /opt/conda/bin/conda update -y conda && \ - /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \ + /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -c anaconda -c conda-forge -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \ esac && \ /opt/conda/bin/conda clean -ya From dae793e13b1fed1215285e7312a40368cee072c9 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 24 May 2024 15:45:25 -0700 Subject: [PATCH 32/32] Revert --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 26f8b8712..1f6698e35 100644 --- a/Dockerfile +++ b/Dockerfile @@ -73,7 +73,7 @@ RUN chmod +x ~/mambaforge.sh && \ RUN case ${TARGETPLATFORM} in \ "linux/arm64") exit 1 ;; \ *) /opt/conda/bin/conda update -y conda && \ - /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -c anaconda -c conda-forge -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \ + /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \ esac && \ /opt/conda/bin/conda clean -ya