Skip to content

Conversation

PaulZhang12
Copy link
Contributor

@PaulZhang12 PaulZhang12 commented Oct 1, 2025

Stacked PRs:


Faster int4 gemm

             x_val    preprocessed_eager_int4_gemm-speedup    preprocessed_torch_compile_int4_gemm-speedup    helion_int4_gemm_tritonbench-speedup
------------------  --------------------------------------  ----------------------------------------------  --------------------------------------
(1, 1, 1280, 8192)                                 1.64983                                        14.0875                                 12.9307
(1, 1, 8192, 1024)                                 1.68524                                        16.9107                                 12.9251
(1, 1, 7168, 8192)                                 1.70041                                        19.4632                                 35.083
(1, 1, 8192, 3584)                                 1.67347                                        18.1199                                 28.2511
(4, 1, 1280, 8192)                                 1.64907                                         8.99873                                 9.14968
(4, 1, 8192, 1024)                                 1.69727                                        10.2491                                 11.6143
           average                                 1.67588                                        14.6382                                 18.3256

PaulZhang12 added a commit that referenced this pull request Oct 1, 2025
stack-info: PR: #751, branch: PaulZhang12/stack/11
@PaulZhang12 PaulZhang12 force-pushed the PaulZhang12/stack/11 branch from 2fa6702 to 7ba6841 Compare October 1, 2025 21:28
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 1, 2025
PaulZhang12 added a commit that referenced this pull request Oct 1, 2025
stack-info: PR: #751, branch: PaulZhang12/stack/11
@PaulZhang12 PaulZhang12 force-pushed the PaulZhang12/stack/11 branch from 7ba6841 to aa9d60a Compare October 1, 2025 21:32
PaulZhang12 added a commit that referenced this pull request Oct 1, 2025
stack-info: PR: #751, branch: PaulZhang12/stack/11
@PaulZhang12 PaulZhang12 force-pushed the PaulZhang12/stack/11 branch from aa9d60a to 6d0c645 Compare October 1, 2025 21:38
Copy link
Contributor

@jansel jansel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this make this faster? Can you share the output Triton diff?


C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
block_size_k_packed = hl.register_block_size(K // 2)
hl.register_reduction_dim(block_size_k_packed * 2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this do?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

block_size_k_packed * 2 is the K dimension, of which we want to reduce over in the inner loop for the accumulator

@PaulZhang12
Copy link
Contributor Author

@jansel I have not been able to get this to work yet due to this issue #753 that I was describing in the chat yesterday. Essentially, we don't want to hl.dot here, as that involves a lot of reshaping/splitting for padding, and tensor cores don't support int4 anymore. We want to just do a basic multiply + reduction, which is also done in the reference pt2 implementation

@yf225
Copy link
Contributor

yf225 commented Oct 7, 2025

#809 (currently in review) should unblock this PR.

@yf225
Copy link
Contributor

yf225 commented Oct 7, 2025

@PaulZhang12 #809 is merged - please let me know if you still see any error, thanks!

PaulZhang12 added a commit that referenced this pull request Oct 9, 2025
stack-info: PR: #751, branch: PaulZhang12/stack/11
@PaulZhang12 PaulZhang12 force-pushed the PaulZhang12/stack/11 branch from 6d0c645 to 9b73baf Compare October 9, 2025 20:18
PaulZhang12 added a commit that referenced this pull request Oct 9, 2025
stack-info: PR: #751, branch: PaulZhang12/stack/11
@PaulZhang12 PaulZhang12 force-pushed the PaulZhang12/stack/11 branch from 9b73baf to d5792a1 Compare October 9, 2025 20:19
Comment on lines 19 to 20
torch.backends.cuda.matmul.fp32_precision = "tf32"
torch.backends.cudnn.conv.fp32_precision = "tf32"
# torch.backends.cuda.matmul.fp32_precision = "tf32"
# torch.backends.cudnn.conv.fp32_precision = "tf32"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops?

stack-info: PR: #751, branch: PaulZhang12/stack/11
@PaulZhang12 PaulZhang12 force-pushed the PaulZhang12/stack/11 branch from d5792a1 to c9d8521 Compare October 9, 2025 20:37
@PaulZhang12 PaulZhang12 requested review from jansel and yf225 October 9, 2025 21:04
x_2d = x.reshape(-1, x.size(-1))
w_int8 = w.to(torch.int8)
w_reshaped = w_int8.reshape(w.shape[0] // 2, 2, w.shape[1]).permute(1, 0, 2)
w_packed = ((w_reshaped[0] & 0xF) | (w_reshaped[1] << 4)).to(torch.int8)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious to double-check: do other backends in TritonBench also run this preprocess part outside of the measured kernel?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yf225 yes they do. There are two versions, one preprocessed, one not. The comparisons I did were against the preprocessed pt2 version

@PaulZhang12 PaulZhang12 requested a review from yf225 October 10, 2025 14:46
Copy link
Contributor

@yf225 yf225 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @PaulZhang12 !

@PaulZhang12 PaulZhang12 merged commit 6624d6d into main Oct 10, 2025
12 of 13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants