[CUDA] FpA IntB Gemm Kernel Test #25109


Open · wants to merge 3 commits into main

Conversation

tianleiwu (Contributor):

Enhance MatMulNBits CUDA kernel testing:
(1) Add a kernel test covering the different CUDA kernels used in MatMulNBits.
(2) Refactor the gemm profiler to use the CUDA allocator.
(3) Add verbose logging macros.
(4) Adjust the build to speed up compilation when sm90 is excluded.

Example kernel test output: (screenshot omitted)

@tianleiwu tianleiwu marked this pull request as draft June 18, 2025 22:03
@tianleiwu tianleiwu marked this pull request as ready for review June 19, 2025 20:38
@@ -39,7 +40,9 @@
#include "contrib_ops/cuda/llm/cutlass_heuristic.h"
#include "contrib_ops/cuda/llm/cutlass_type_conversion.h"
#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h"
#ifndef EXCLUDE_SM_90
Contributor:

Why is sm90 particularly slow during compilation and not newer ones?

Contributor Author:

sm90 uses a new set of files here.

Blackwell GPUs will fall back to the sm80 kernels.

Since the CI pipeline uses sm75 or sm86, there is no need to compile sm90. This skips those files and should speed up the build.

@@ -374,14 +375,18 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, cutlass::arch::Sm89,
QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
#ifndef EXCLUDE_SM_90
Contributor:

Is this needed since you already do a check without macro?

Contributor Author:

This avoids compiling sm90_dispatch_gemm_to_cutlass in the CI pipeline, or when your GPU is not an H100/H200.

@@ -67,6 +58,18 @@ void kernel_launcher(int arch, Params& params, cudaStream_t s) {

EXEC(KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleavedForHopper, true);
EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true);
#endif
} else {
// if (arch >= 89)
Contributor:

nit: remove?

Contributor Author:

This is reserved for a new op that supports alpha.

#define PRETTY_FUNCTION __PRETTY_FUNCTION__
#endif

#define ORT_LLM_VERBOSE 0 // Set to 1 for verbose, 2 for max verbosity
Contributor:

#ifndef ORT_LLM_VERBOSE
#define ORT_LLM_VERBOSE 0
#endif

So we can externally modify it?

3 participants