-
Notifications
You must be signed in to change notification settings - Fork 160
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable FP6-LLM kernel build on Windows #305
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,9 +38,9 @@ | |
|
||
#ifdef PIPELINE_LEVEL_SMEM | ||
template <typename TilingConfig> | ||
__device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[][4], | ||
half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], | ||
int slice_id) { | ||
__device__ __forceinline__ void B_FromSharedToReg(uint32_t Reg[][4], | ||
half (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], | ||
int slice_id) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing some |
||
#ifdef DEBUG_MODE | ||
static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) ); | ||
#endif | ||
|
@@ -113,7 +113,7 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[ | |
#endif | ||
|
||
__device__ __forceinline__ void | ||
MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b) | ||
MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b) | ||
{ | ||
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" | ||
"{ %0, %1, %2, %3}," | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,7 @@ | |
* Outputs: R1, R2 | ||
* Note: Simplified Exponent calculation is applied. | ||
*/ | ||
__device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t *R1, u_int32_t *R2) { | ||
__device__ __forceinline__ void FP6_FP16_Cast_4Way(uint32_t *R1, uint32_t *R2) { | ||
*R2 = *R1 & 0x80808080; | ||
*R1 = *R1 >> 2; | ||
*R1 = *R1 & 0x1f1f1f1f; | ||
|
@@ -41,7 +41,7 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t *R1, u_int32_t *R2) | |
* Outputs: R1, R2 | ||
* Note: Simplified Exponent calculation is NOT applied. | ||
*/ | ||
__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t *R1, u_int32_t *R2) { | ||
__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(uint32_t *R1, uint32_t *R2) { | ||
//*R2 = *R1 & 0x80808080; | ||
*R2 = *R1 & 0xc0c0c0c0; | ||
*R1 = *R1 >> 2; | ||
|
@@ -63,7 +63,7 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t *R1, u_int32_ | |
//*R2 = 0x3c003c00; | ||
} | ||
|
||
__device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale) { | ||
__device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scale) { | ||
half* FP16_1 = reinterpret_cast<half*>(&PackedFP16Pair); | ||
half* FP16_2 = FP16_1 + 1; | ||
uint32_t output; | ||
|
@@ -73,16 +73,16 @@ __device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Sc | |
return output; | ||
} | ||
|
||
__device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__ Reg[][4], | ||
u_int32_t __restrict__ *read_RPTR_Frag1, | ||
u_int32_t __restrict__ *read_RPTR_Frag2, | ||
u_int32_t *Scales) { | ||
u_int32_t *OutputRegs = reinterpret_cast<u_int32_t*> (Reg); | ||
u_int32_t *Frag1_PTR = read_RPTR_Frag1; | ||
u_int32_t *Frag2_PTR = read_RPTR_Frag2; | ||
__device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t Reg[][4], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. missing |
||
uint32_t * __restrict__ read_RPTR_Frag1, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TIL about |
||
uint32_t * __restrict__ read_RPTR_Frag2, | ||
uint32_t * Scales) { | ||
uint32_t *OutputRegs = reinterpret_cast<uint32_t*> (Reg); | ||
uint32_t *Frag1_PTR = read_RPTR_Frag1; | ||
uint32_t *Frag2_PTR = read_RPTR_Frag2; | ||
half *Scale_RPTR = reinterpret_cast<half*>(Scales); | ||
u_int32_t Packed_FP6 = 0; | ||
u_int32_t tmp = 0; | ||
uint32_t Packed_FP6 = 0; | ||
uint32_t tmp = 0; | ||
// Dequantizing 32 FP6, each Loop dequantizing 4 FP6 | ||
#pragma unroll(8) | ||
for(int i=0; i<8; i++) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,46 @@ | ||
import torch | ||
import torch | ||
import torch.utils.benchmark as benchmark | ||
def benchmark_model(model, num_runs, input_tensor): | ||
torch.cuda.synchronize() | ||
start_event = torch.cuda.Event(enable_timing=True) | ||
end_event = torch.cuda.Event(enable_timing=True) | ||
start_event.record() | ||
# benchmark | ||
for _ in range(num_runs): | ||
with torch.autograd.profiler.record_function("timed region"): | ||
model(input_tensor) | ||
end_event.record() | ||
torch.cuda.synchronize() | ||
return start_event.elapsed_time(end_event) / num_runs | ||
def profiler_runner(path, fn, *args, **kwargs): | ||
with torch.profiler.profile( | ||
activities=[torch.profiler.ProfilerActivity.CPU, | ||
torch.profiler.ProfilerActivity.CUDA], | ||
record_shapes=True) as prof: | ||
result = fn(*args, **kwargs) | ||
prof.export_chrome_trace(path) | ||
return result | ||
def get_compute_capability(): | ||
if torch.cuda.is_available(): | ||
capability = torch.cuda.get_device_capability() | ||
return float(f"{capability[0]}.{capability[1]}") | ||
return 0.0 | ||
def skip_if_compute_capability_less_than(min_capability): | ||
import unittest | ||
def decorator(test_func): | ||
def wrapper(*args, **kwargs): | ||
if get_compute_capability() < min_capability: | ||
raise unittest.SkipTest(f"Compute capability is less than {min_capability}") | ||
return test_func(*args, **kwargs) | ||
return wrapper | ||
return decorator | ||
|
||
|
||
def benchmark_model(model, num_runs, input_tensor): | ||
torch.cuda.synchronize() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these look like linting changes? can't quite see the difference There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe change in end of line symbols? iirc, Windows use different symbols for end of line. |
||
start_event = torch.cuda.Event(enable_timing=True) | ||
end_event = torch.cuda.Event(enable_timing=True) | ||
start_event.record() | ||
|
||
# benchmark | ||
for _ in range(num_runs): | ||
with torch.autograd.profiler.record_function("timed region"): | ||
model(input_tensor) | ||
|
||
end_event.record() | ||
torch.cuda.synchronize() | ||
return start_event.elapsed_time(end_event) / num_runs | ||
|
||
def profiler_runner(path, fn, *args, **kwargs): | ||
with torch.profiler.profile( | ||
activities=[torch.profiler.ProfilerActivity.CPU, | ||
torch.profiler.ProfilerActivity.CUDA], | ||
record_shapes=True) as prof: | ||
result = fn(*args, **kwargs) | ||
prof.export_chrome_trace(path) | ||
return result | ||
|
||
def get_compute_capability(): | ||
if torch.cuda.is_available(): | ||
capability = torch.cuda.get_device_capability() | ||
return float(f"{capability[0]}.{capability[1]}") | ||
return 0.0 | ||
|
||
def skip_if_compute_capability_less_than(min_capability): | ||
import unittest | ||
def decorator(test_func): | ||
def wrapper(*args, **kwargs): | ||
if get_compute_capability() < min_capability: | ||
raise unittest.SkipTest(f"Compute capability is less than {min_capability}") | ||
return test_func(*args, **kwargs) | ||
return wrapper | ||
return decorator | ||
|
||
|
||
def benchmark_torch_function_in_microseconds(f, *args, **kwargs): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah feels like we should be testing this in CI, shouldn't be too hard to use windows machine for cpu but I'm not sure how abundant cuda enabled windows machines are in the github org