Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions csrc/quantization/fp8/per_token_group_quant.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
#include "../vectorization_utils.cuh"
#include "../../dispatch_utils.h"

__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
unsigned mask = 0xffff;
__device__ __forceinline__ float GroupReduceMax(float val) {
unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;

val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
Expand Down Expand Up @@ -86,7 +86,7 @@ __global__ void per_token_group_quant_8bit_kernel(
threads_per_group, // stride in group
scalar_op_cache); // scalar handler

local_absmax = GroupReduceMax(local_absmax, lane_id);
local_absmax = GroupReduceMax(local_absmax);

float y_s = local_absmax / max_8bit;
if constexpr (SCALE_UE8M0) {
Expand Down