Enable FP6-LLM kernel build on Windows #305

Closed · wants to merge 1 commit
44 changes: 31 additions & 13 deletions setup.py
@@ -32,6 +32,7 @@ def read_requirements(file_path):
CUDAExtension,
BuildExtension,
CUDA_HOME,
IS_WINDOWS
)


@@ -49,21 +50,38 @@ def get_extensions():
use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
extension = CUDAExtension if use_cuda else CppExtension

extra_link_args = ["-fopenmp"]
extra_compile_args = {
"cxx": [
"-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
"-fopenmp",
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
]
}
if not IS_WINDOWS:
Member: Yeah, feels like we should be testing this in CI. It shouldn't be too hard to use a Windows machine for CPU, but I'm not sure how abundant CUDA-enabled Windows machines are in the GitHub org.

extra_link_args = ["-fopenmp"]
extra_compile_args = {
"cxx": [
"-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
"-fopenmp",
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
]
}
else:
extra_link_args = []
extra_compile_args = {
"cxx": [
"/O2" if not debug_mode else "/Od",
"/openmp",
"/permissive-"
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
"-t=0"
]
}

if debug_mode:
extra_compile_args["cxx"].append("-g")
extra_compile_args["cxx"].append("-g" if not IS_WINDOWS else "/ZI")
extra_compile_args["nvcc"].append("-g")
extra_link_args.extend(["-O0", "-g"])
extra_link_args.extend(["-O0", "-g"] if not IS_WINDOWS else ["/DEBUG"])



this_dir = os.path.dirname(os.path.curdir)
extensions_dir = os.path.join(this_dir, "torchao", "csrc")
6 changes: 3 additions & 3 deletions torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh
@@ -133,11 +133,11 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
uint32_t* __restrict__ write_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
uint32_t* __restrict__ write_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
// Trible-Buffer for B Tile
half __restrict__ (*read_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
#ifdef PIPELINE_LEVEL_SMEM
half __restrict__ (*read2_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
half (* __restrict__ read2_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
#endif
half __restrict__ (*write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
half (* __restrict__ write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
//
bool GlobalCopy = (tile_id_k+PIPELINE_LEVEL_GMEM-1) < NumIter;
// Copying A tile from Global to Register, Bypassing L1, using double-buffer
8 changes: 4 additions & 4 deletions torchao/csrc/cuda/fp6_llm/ptx_mma.cuh
@@ -38,9 +38,9 @@

#ifdef PIPELINE_LEVEL_SMEM
template <typename TilingConfig>
__device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[][4],
half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
int slice_id) {
__device__ __forceinline__ void B_FromSharedToReg(uint32_t Reg[][4],
half (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
int slice_id) {
Collaborator: Missing some __restrict__ here. For Reg[][4], I don't know if we can add __restrict__ directly; otherwise, maybe we need to change it to a pointer (so we can add __restrict__ back). From what I know, Reg[][4] is still passed as a pointer, but it allows us to do 2D indexing (the last dimension is a compile-time constant, so it translates to 4 * first_index + second_index).
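As a side note, here is a minimal standalone sketch (not from this PR) of how the qualifier can still be expressed once the parameter is spelled as a pointer to a fixed-size row, which is what Reg[][4] decays to anyway:

// Hypothetical illustration only: both functions accept the same argument,
// because an array-of-arrays parameter decays to a pointer to its first row,
// i.e. uint32_t (*)[4]. Two-dimensional indexing keeps working in both cases.
#include <cstdint>

__device__ void consume_plain(uint32_t Reg[][4]) {
    // Reg[i][j] addresses element i*4 + j of the caller's flat register block.
    Reg[0][0] = 0u;
}

__device__ void consume_restrict(uint32_t (* __restrict__ Reg)[4]) {
    // Same indexing; the qualifier now attaches to the decayed pointer, so the
    // compiler may assume Reg does not alias other pointers in this scope.
    Reg[0][0] = 0u;
}

Whether restoring __restrict__ this way changes the generated code for these kernels would still need to be checked; the sketch only shows the syntax.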

#ifdef DEBUG_MODE
static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) );
#endif
@@ -113,7 +113,7 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[
#endif

__device__ __forceinline__ void
MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b)
MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b)
{
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{ %0, %1, %2, %3},"
4 changes: 2 additions & 2 deletions torchao/csrc/cuda/fp6_llm/utils_core.cuh
@@ -40,7 +40,7 @@ __device__ __forceinline__ void initialize_mma_slice(uint32_t (
uint32_t (*b)[4],
uint32_t* __restrict__ A1_SPTR_read,
uint32_t* __restrict__ A2_SPTR_read,
half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
uint32_t* RPTR_Scales)
{
// Writing registers
@@ -59,7 +59,7 @@ __device__ __forceinline__ void core_mma_slice(float c[][REG
uint32_t (*b)[4],
uint32_t* __restrict__ A1_SPTR_read,
uint32_t* __restrict__ A2_SPTR_read,
half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
uint32_t* RPTR_Scales,
int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching
{
10 changes: 5 additions & 5 deletions torchao/csrc/cuda/fp6_llm/utils_gmem.cuh
@@ -63,11 +63,11 @@ __device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantSc
* 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8 Threads
*/
template<int MaxNumOfLinesToCopy, int BLOCK_WARPS>
__device__ __forceinline__ void CopyFromGlobalToShared(half __restrict__ (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
const half* GlobalPTR,
const int GlobalStride,
const int NumOfLinesLeft, // To support arbitrary N dimensions.
bool Pred = true) {
__device__ __forceinline__ void CopyFromGlobalToShared(half (* __restrict__ SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
const half* GlobalPTR,
const int GlobalStride,
const int NumOfLinesLeft, // To support arbitrary N dimensions.
bool Pred = true) {
// static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time
const int NumOfThreads = BLOCK_WARPS * WARP_SIZE;
const int NumOfGroups = NumOfThreads / 8;
24 changes: 12 additions & 12 deletions torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh
@@ -26,7 +26,7 @@
* Outputs: R1, R2
* Note: Simplified Exponent calculation is applied.
*/
__device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t *R1, u_int32_t *R2) {
__device__ __forceinline__ void FP6_FP16_Cast_4Way(uint32_t *R1, uint32_t *R2) {
*R2 = *R1 & 0x80808080;
*R1 = *R1 >> 2;
*R1 = *R1 & 0x1f1f1f1f;
@@ -41,7 +41,7 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t *R1, u_int32_t *R2)
* Outputs: R1, R2
* Note: Simplified Exponent calculation is NOT applied.
*/
__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t *R1, u_int32_t *R2) {
__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(uint32_t *R1, uint32_t *R2) {
//*R2 = *R1 & 0x80808080;
*R2 = *R1 & 0xc0c0c0c0;
*R1 = *R1 >> 2;
@@ -63,7 +63,7 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t *R1, u_int32_
//*R2 = 0x3c003c00;
}

__device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale) {
__device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scale) {
half* FP16_1 = reinterpret_cast<half*>(&PackedFP16Pair);
half* FP16_2 = FP16_1 + 1;
uint32_t output;
@@ -73,16 +73,16 @@ __device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Sc
return output;
}

__device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__ Reg[][4],
u_int32_t __restrict__ *read_RPTR_Frag1,
u_int32_t __restrict__ *read_RPTR_Frag2,
u_int32_t *Scales) {
u_int32_t *OutputRegs = reinterpret_cast<u_int32_t*> (Reg);
u_int32_t *Frag1_PTR = read_RPTR_Frag1;
u_int32_t *Frag2_PTR = read_RPTR_Frag2;
__device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t Reg[][4],
Collaborator: Missing __restrict__ here.

uint32_t * __restrict__ read_RPTR_Frag1,
Member: TIL about uint32_t vs u_int32_t, lol.
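For context (my understanding, not something stated in the PR): uint32_t is the standard fixed-width type from <cstdint>/<stdint.h> and is available with MSVC, GCC, Clang, and nvcc, whereas u_int32_t is a non-standard alias that glibc and the BSDs expose via <sys/types.h>, which MSVC does not provide. A minimal sketch of the portable spelling:

#include <cstdint>        // uint32_t: standard fixed-width type, works under MSVC
// #include <sys/types.h> // u_int32_t: glibc/BSD-only alias, absent on Windows/MSVC

static_assert(sizeof(uint32_t) == 4, "uint32_t is exactly 32 bits wide");

// Hypothetical helper in the style of the dequant code above, written against
// the portable type name.
__device__ __forceinline__ uint32_t keep_sign_bits(uint32_t packed_fp6) {
    return packed_fp6 & 0x80808080u;
}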

uint32_t * __restrict__ read_RPTR_Frag2,
uint32_t * Scales) {
uint32_t *OutputRegs = reinterpret_cast<uint32_t*> (Reg);
uint32_t *Frag1_PTR = read_RPTR_Frag1;
uint32_t *Frag2_PTR = read_RPTR_Frag2;
half *Scale_RPTR = reinterpret_cast<half*>(Scales);
u_int32_t Packed_FP6 = 0;
u_int32_t tmp = 0;
uint32_t Packed_FP6 = 0;
uint32_t tmp = 0;
// Dequantizing 32 FP6, each Loop dequantizing 4 FP6
#pragma unroll(8)
for(int i=0; i<8; i++) {
84 changes: 42 additions & 42 deletions torchao/utils.py
@@ -1,46 +1,46 @@
import torch
import torch
import torch.utils.benchmark as benchmark
def benchmark_model(model, num_runs, input_tensor):
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
# benchmark
for _ in range(num_runs):
with torch.autograd.profiler.record_function("timed region"):
model(input_tensor)
end_event.record()
torch.cuda.synchronize()
return start_event.elapsed_time(end_event) / num_runs
def profiler_runner(path, fn, *args, **kwargs):
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA],
record_shapes=True) as prof:
result = fn(*args, **kwargs)
prof.export_chrome_trace(path)
return result
def get_compute_capability():
if torch.cuda.is_available():
capability = torch.cuda.get_device_capability()
return float(f"{capability[0]}.{capability[1]}")
return 0.0
def skip_if_compute_capability_less_than(min_capability):
import unittest
def decorator(test_func):
def wrapper(*args, **kwargs):
if get_compute_capability() < min_capability:
raise unittest.SkipTest(f"Compute capability is less than {min_capability}")
return test_func(*args, **kwargs)
return wrapper
return decorator


def benchmark_model(model, num_runs, input_tensor):
torch.cuda.synchronize()
Member: These look like linting changes? I can't quite see the difference.

Collaborator: Maybe a change in end-of-line symbols? IIRC, Windows uses different symbols for end of line.
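If the suspicion is that this file only changed in its line endings, one rough way to confirm it (a standalone sketch, not part of this PR or the repo's tooling) is to count CRLF versus bare-LF terminators in the before/after versions of the file:

#include <cstddef>
#include <cstdio>
#include <fstream>
#include <iterator>
#include <string>

// Counts CRLF ("\r\n") and bare LF ("\n") line terminators in a file, e.g. to
// check whether a whole-file diff is only an end-of-line change.
int main(int argc, char** argv) {
    if (argc < 2) {
        std::fprintf(stderr, "usage: %s <file>\n", argv[0]);
        return 1;
    }
    std::ifstream in(argv[1], std::ios::binary);
    std::string data((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
    long crlf = 0, lf = 0;
    for (std::size_t i = 0; i < data.size(); ++i) {
        if (data[i] != '\n') continue;
        if (i > 0 && data[i - 1] == '\r') ++crlf; else ++lf;
    }
    std::printf("CRLF: %ld  LF: %ld\n", crlf, lf);
    return 0;
}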

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()

# benchmark
for _ in range(num_runs):
with torch.autograd.profiler.record_function("timed region"):
model(input_tensor)

end_event.record()
torch.cuda.synchronize()
return start_event.elapsed_time(end_event) / num_runs

def profiler_runner(path, fn, *args, **kwargs):
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA],
record_shapes=True) as prof:
result = fn(*args, **kwargs)
prof.export_chrome_trace(path)
return result

def get_compute_capability():
if torch.cuda.is_available():
capability = torch.cuda.get_device_capability()
return float(f"{capability[0]}.{capability[1]}")
return 0.0

def skip_if_compute_capability_less_than(min_capability):
import unittest
def decorator(test_func):
def wrapper(*args, **kwargs):
if get_compute_capability() < min_capability:
raise unittest.SkipTest(f"Compute capability is less than {min_capability}")
return test_func(*args, **kwargs)
return wrapper
return decorator


def benchmark_torch_function_in_microseconds(f, *args, **kwargs):