
Commit

Merge pull request #472 from jiewlmrh/master
For int8 models, add support for fp16 inputs.
ztxz16 committed Jul 1, 2024
2 parents 73b6d27 + fe97b29 commit dc8e853
Showing 3 changed files with 87 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/devices/cuda/fastllm-cuda.cuh
@@ -81,6 +81,7 @@ bool FastllmCudaBatchMatMulBatch(void **i0s, void **i1s, void **os,
bool FastllmCudaHalfAttention(const fastllm::Data &q, const fastllm::Data &k, const fastllm::Data &v,
const fastllm::Data &mask, const fastllm::Data &output, int group, float scale);
bool FastllmCudaHalfMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k);
bool FastllmCudaHalfMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k);

void FastllmCudaSetDevice(int gpu_id);
#ifdef __cplusplus
2 changes: 2 additions & 0 deletions src/devices/cuda/cudadevice.cpp
@@ -294,6 +294,8 @@ namespace fastllm {
if (input.dataType == DataType::FLOAT16) {
if (weight.dataType == DataType::FLOAT16) {
FastllmCudaHalfMatMulFloat16(input, weight, bias, output, n, m, k);
} else if (weight.dataType == DataType::INT8) {
FastllmCudaHalfMatMulFloatInt8(input, weight, bias, output, n, m, k);
} else {
ErrorInFastLLM("Linear error: unsupported weight dataType.\n");
}
85 changes: 84 additions & 1 deletion src/devices/cuda/fastllm-cuda.cu
@@ -1794,7 +1794,6 @@ bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) {
checkCudaErrors("Error: CUDA error when moving bias to device!", state);
weight.extraCudaData.push_back((void*)cudaBiasData);
}

float *cudaScales = (float*)weight.extraCudaData[0];
uint8_t *cudaZeropoints = (uint8_t*)weight.extraCudaData[1];
float *cudaBiasData = (float*)weight.extraCudaData[2];
@@ -3768,6 +3767,90 @@ bool FastllmCudaHalfMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) {
return true;
}

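// New in this commit: FP16 activations x INT8 weights. The weight matrix is
// dequantized to fp16 on the GPU, then multiplied with the same fp16 cuBLAS
// GEMM used by the pure-fp16 path.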
bool FastllmCudaHalfMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) {
if (weight.cudaData == nullptr || weight.extraCudaHalfData.size() == 0) {
float *cudaScales;
cudaError_t state = cudaSuccess;
state = cudaMalloc(&cudaScales, k * sizeof(float));
state = cudaMemcpy(cudaScales, weight.scales.data(), k * sizeof(float), cudaMemcpyHostToDevice);
weight.extraCudaHalfData.push_back((void*)cudaScales);

uint8_t *cudaZeropoints;
state = cudaMalloc(&cudaZeropoints, k);
uint8_t *zeropoints = new uint8_t[k];
for (int i = 0; i < k; i++) {
zeropoints[i] = weight.perChannelsConfigs[i].zeroPoint;
}
state = cudaMemcpy(cudaZeropoints, zeropoints, k, cudaMemcpyHostToDevice);
delete[] zeropoints;
weight.extraCudaHalfData.push_back((void*)cudaZeropoints);

half *cudaBiasData;
state = cudaMalloc(&cudaBiasData, k * sizeof(half));
if (bias.dims.size() > 0) {
float *tempBiasData;
state = cudaMalloc(&tempBiasData, k * sizeof(float));
state = cudaMemcpy(tempBiasData, (uint8_t*)bias.cudaData, k * sizeof(float), cudaMemcpyDeviceToDevice);
int threadPerBlock = std::min(256, k);
FastllmCudaFloat2HalfKernel <<< (k - 1) / threadPerBlock + 1, threadPerBlock>>>(tempBiasData, cudaBiasData, k);
state = cudaFree(tempBiasData);
} else {
state = cudaMemset(cudaBiasData, 0, k * sizeof(half));
}
checkCudaErrors("Error: CUDA error when moving bias to device!", state);
weight.extraCudaHalfData.push_back((void*)cudaBiasData);
}
float *cudaScales = (float*)weight.extraCudaHalfData[0];
uint8_t *cudaZeropoints = (uint8_t*)weight.extraCudaHalfData[1];

half *cudaInput = (half*)FastllmCudaPrepareInput(input);
half *cudaOutput = (half*)FastllmCudaPrepareOutput(output);

auto fastllmCublasHandle = getFastllmCublasHandle();
half *cudaFp16Weight;

cudaFp16Weight = (half *) FastllmCudaMalloc(k * m * sizeof(half));

__half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0);
cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_16F, ComputeType = CUDA_R_16F;
cublasStatus_t status;

int len = n * m;
int threadPerBlock = std::min(256, len);

len = k * m;

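// Dequantize all k * m int8 weight codes into the fp16 scratch buffer; each
// element goes through its output channel's scale and zeropoint (a sketch of
// such a kernel follows this function).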
FastllmCudaInt82HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight.cudaData,
cudaScales,
cudaZeropoints,
cudaFp16Weight, len, m);

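// cuBLAS is column-major, so OP_T on the (k x m) fp16 weight and OP_N on the
// fp16 input yield output[n x k] = input[n x m] * weight^T (a CPU reference
// appears after the diff).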
status = cublasGemmEx(fastllmCublasHandle,
CUBLAS_OP_T, CUBLAS_OP_N,
k, n, m,
&h_alpha, cudaFp16Weight, AType,
m, cudaInput, BType,
m, &h_beta,
cudaOutput, CType,
k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));

if (status != CUBLAS_STATUS_SUCCESS) {
printf("Error: cublas error.\n");
throw("cublas error");
exit(0);
}

if (bias.dims.size() > 0) {
half *cudaBiasData = (half*)weight.extraCudaHalfData[2];
FastllmCudaBiasKernel <<< n, 256 >>> (cudaOutput, cudaBiasData, k);
}

FastllmCudaFree(cudaFp16Weight);
FastllmCudaFinishInput(input, cudaInput);
FastllmCudaFinishOutput(output, cudaOutput);
return true;
}
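The body of FastllmCudaInt82HalfKernel is not part of this diff. For reference, here is a minimal sketch of a per-channel int8-to-fp16 dequantization kernel with the same call shape; the name Int8ToHalfSketch and the exact layout are assumptions for illustration, not the repository's implementation:

#include <cuda_fp16.h>
#include <cstdint>

// Hypothetical sketch: dequantize len = k * m int8 codes into fp16, using one
// scale/zeropoint per output channel, where per = m is the row length.
__global__ void Int8ToHalfSketch(const uint8_t *a, const float *scales,
                                 const uint8_t *zeros, half *b,
                                 int len, int per) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < len) {
        int channel = idx / per;  // output row this weight element belongs to
        b[idx] = __float2half(scales[channel] *
                              ((float)a[idx] - (float)zeros[channel]));
    }
}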

void FastllmCudaSetDevice(int gpu_id) {
cudaSetDevice(gpu_id);
}
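To sanity-check the new fp16 x int8 path end to end, a plain C++ reference of what the dequantize-then-GEMM sequence computes is handy. This is a sketch under the same per-channel quantization assumption; RefHalfInt8Linear is a hypothetical helper, not part of fastllm:

#include <cstdint>
#include <vector>

// Reference: output[n x k] = input[n x m] * dequant(weight)^T + bias, with
// weight stored row-major as [k x m] int8 codes and per-channel scale/zeropoint.
void RefHalfInt8Linear(const std::vector<float> &input,
                       const std::vector<uint8_t> &weight,
                       const std::vector<float> &scales,
                       const std::vector<uint8_t> &zeros,
                       const std::vector<float> &bias,   // empty means no bias
                       std::vector<float> &output,
                       int n, int m, int k) {
    output.assign((size_t)n * k, 0.0f);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < k; j++) {
            float sum = bias.empty() ? 0.0f : bias[j];
            for (int l = 0; l < m; l++) {
                float w = scales[j] * ((float)weight[(size_t)j * m + l] - (float)zeros[j]);
                sum += input[(size_t)i * m + l] * w;
            }
            output[(size_t)i * k + j] = sum;
        }
    }
}

Agreement with FastllmCudaHalfMatMulFloatInt8's output, within fp16 rounding tolerance, exercises both the dequantization and the transposed-GEMM index mapping.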
