diff --git a/hgemm/hgemm.py b/hgemm/hgemm.py index a002ae0b..6122c7e9 100644 --- a/hgemm/hgemm.py +++ b/hgemm/hgemm.py @@ -226,14 +226,16 @@ def run_benchmark(perf_func: callable, args.enable_mma, args.enable_mma_all, args.enable_wmma, args.enable_wmma_all, args.enable_cuda, args.enable_cuda_all, args.enable_torch)): run_benchmark(lib.hgemm_cublas_tensor_op_nn, a, b, "(cublas)", c) + if args.enable_torch: + run_benchmark(partial(torch.matmul, out=c), a, b, "(torch)") if args.enable_mma_tn: + MAX_TFLOPS = -1 + print("-" * 68 + "MMA(TN)" + "-" * 55) run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage3+dsmem)", c, stages=3) run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage2+dsmem)", c, stages=2) run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) if not args.disable_cublas_tn: run_benchmark(lib.hgemm_cublas_tensor_op_tn, a, b.transpose(1, 0), "tn(cublas)", c) - if args.enable_torch: - run_benchmark(partial(torch.matmul, out=c), a, b, "(torch)") torch.cuda.synchronize() print("-" * 130)