Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions hgemm/hgemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,22 @@ def get_device_capability():
TOATL_TFLOPS: dict[str, float] = {}
CUBLAS_TOTAL_TFLOPS = 0


def make_block_swizzle_stride(N: int, K: int) -> int:
    """Choose a thread-block swizzle stride for an HGEMM launch.

    The stride is a fraction of N — N/2, N/4, or N/8 — picked by problem
    size, unless the user pins a fraction via the --swizzle-factor CLI
    option. Strides smaller than 256 are not worth swizzling, so 1 is
    returned to signal "disabled" (the caller also checks >= 256 and
    turns the swizzle flag off).

    Note: the resulting stride is NOT forced to a multiple of 256 — only
    required to be at least 256. The kernel-side ``assert(swizzle_stride
    % 256 == 0)`` checks were disabled to match.

    Args:
        N: column count of the output matrix (swizzle stride base).
        K: reduction dimension; large K marks a "huge" problem.

    Returns:
        int: swizzle stride in columns, or 1 when swizzling is disabled.

    NOTE(review): reads the module-level ``args`` argparse namespace
    (``args.swizzle_factor``) rather than taking it as a parameter.
    """
    if args.swizzle_factor is None:
        # Heuristic: shrink the fraction as the problem grows; very large
        # problems (N >= 14848, K > 8192, N divisible by 8) get N/8.
        if N >= 14848 and K > 8192 and N % 8 == 0:
            swizzle_factor = 0.125
        elif N <= 4096:
            swizzle_factor = 0.5
        else:
            swizzle_factor = 0.25
    else:
        swizzle_factor = args.swizzle_factor

    swizzle_stride = int(N * swizzle_factor)
    # Below 256 columns the swizzle hurts more than it helps; 1 disables it.
    return swizzle_stride if swizzle_stride >= 256 else 1


def run_benchmark(perf_func: callable,
a: torch.Tensor, b: torch.Tensor,
tag: str, out: Optional[torch.Tensor] = None,
Expand All @@ -121,13 +137,7 @@ def run_benchmark(perf_func: callable,
if 'tn' in tag:
N = b.size(0)
if swizzle:
# make swizzle stride as N/4 or N/2 and multiples of 256
if args.swizzle_factor is None:
swizzle_factor = 0.5 if N <= 4096 else 0.25
else:
swizzle_factor = args.swizzle_factor
swizzle_stride = int((int(N * swizzle_factor) // 256) * 256)
swizzle_stride = swizzle_stride if swizzle_stride >= 256 else 1
swizzle_stride = make_block_swizzle_stride(N, K)
swizzle = swizzle if swizzle_stride >= 256 else False
else:
swizzle_stride = 1 # means no thread block swizzle
Expand Down Expand Up @@ -187,7 +197,6 @@ def run_benchmark(perf_func: callable,
print(f"{out_info:>42}: {out_val}, time:{mean_time}ms, "
f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}")
if show_matrix: print(out)
time.sleep(args.sleep_duration)
if args.plot_flops:
STATIS_INFO[tag] = STATIS_INFO.get(tag, [])
STATIS_INFO[tag].append(TFLOPS)
Expand All @@ -196,6 +205,9 @@ def run_benchmark(perf_func: callable,
else:
global CUBLAS_TOTAL_TFLOPS
CUBLAS_TOTAL_TFLOPS += TFLOPS

torch.cuda.synchronize()
time.sleep(args.sleep_duration)
return out, mean_time


Expand Down
10 changes: 5 additions & 5 deletions hgemm/hgemm_mma_stage.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2014,7 +2014,7 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(
// s4: 4*128*(16)*2=16KB, 4*16*(128+16)*2=18KB, ~34KB
// s5: 5*128*(16)*2=20KB, 5*16*(128+16)*2=22.5KB, ~43KB
if (swizzle) {
assert(swizzle_stride % 256 == 0);
// assert(swizzle_stride % 256 == 0);
switch (stages)
{
case 2: // ~17KB
Expand Down Expand Up @@ -2144,7 +2144,7 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem(
// s4: 4*128*(16)*2=16KB, 4*16*(128+16)*2=18KB, ~34KB
// s5: 5*128*(16)*2=20KB, 5*16*(128+16)*2=22.5KB, ~43KB
if (swizzle) {
assert(swizzle_stride % 256 == 0);
// assert(swizzle_stride % 256 == 0);
switch (stages)
{
case 2: // ~17KB
Expand Down Expand Up @@ -2275,7 +2275,7 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem(
// s4: 4*128*(32)*2=32KB, 4*32*(128+16)*2=36KB, ~68KB
// s5: 5*128*(32)*2=40KB, 5*32*(128+16)*2=45KB, ~85KB
if (swizzle) {
assert(swizzle_stride % 256 == 0);
// assert(swizzle_stride % 256 == 0);
switch (stages)
{
case 2: // ~35KB
Expand Down Expand Up @@ -2407,7 +2407,7 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4(
// s4: 4*128*(32)*2=32KB, 4*32*(128+16)*2=36KB, ~68KB
// s5: 5*128*(32)*2=40KB, 5*32*(128+16)*2=45KB, ~85KB
if (swizzle) {
assert(swizzle_stride % 256 == 0);
// assert(swizzle_stride % 256 == 0);
switch (stages)
{
case 2: // ~35KB
Expand Down Expand Up @@ -2540,7 +2540,7 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr(
// s4: 4*128*(32)*2=32KB, 4*32*(128+16)*2=36KB, ~68KB
// s5: 5*128*(32)*2=40KB, 5*32*(128+16)*2=45KB, ~85KB
if (swizzle) {
assert(swizzle_stride % 256 == 0);
// assert(swizzle_stride % 256 == 0);
switch (stages)
{
case 2: // ~35KB
Expand Down
2 changes: 1 addition & 1 deletion hgemm/hgemm_mma_stage_tn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(
constexpr int BK = MMA_K;

if (swizzle) {
assert(swizzle_stride % 256 == 0);
// assert(swizzle_stride % 256 == 0);
switch (stages)
{
case 2:
Expand Down
10 changes: 5 additions & 5 deletions hgemm/hgemm_wmma_stage.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,7 @@ void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages(
// s4: 4*128*(16)*2=16KB, 4*16*(128+16)*2=18KB, ~34KB
// s5: 5*128*(16)*2=20KB, 5*16*(128+16)*2=22.5KB, ~43KB
if (swizzle) {
assert(swizzle_stride % 256 == 0);
// assert(swizzle_stride % 256 == 0);
switch (stages)
{
case 2: // ~17KB
Expand Down Expand Up @@ -1158,7 +1158,7 @@ void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem(
// s5: 5*128*(16)*2=20KB, 5*16*(128+16)*2=22.5KB, ~43KB
// s6: 6*128*(16)*2=24KB, 6*16*(128+16)*2=27KB, ~51KB > 48KB
if (swizzle) {
assert(swizzle_stride % 256 == 0);
// assert(swizzle_stride % 256 == 0);
switch (stages)
{
case 2: // ~17KB
Expand Down Expand Up @@ -1293,7 +1293,7 @@ void hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem(
// s3: 3*256*(16)*2=24KB, 3*16*(256+16)*2=25.5KB, ~50KB > 48KB
// s4: 4*256*(16)*2=32KB, 4*16*(256+16)*2=34KB, ~66KB
if (swizzle) {
assert(swizzle_stride % 256 == 0);
// assert(swizzle_stride % 256 == 0);
switch (stages)
{
case 2: // ~33KB
Expand Down Expand Up @@ -1418,7 +1418,7 @@ void hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem(
constexpr int BK = WMMA_K * WARP_TILE_K;

if (swizzle) {
assert(swizzle_stride % 256 == 0);
// assert(swizzle_stride % 256 == 0);
switch (stages)
{
case 2:
Expand Down Expand Up @@ -1457,4 +1457,4 @@ void hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem(
break;
}
}
}
}