Faster int4 gemm #751
Conversation
Force-pushed from 2fa6702 to 7ba6841 (stack-info: PR: #751, branch: PaulZhang12/stack/11)
Force-pushed from 7ba6841 to aa9d60a (stack-info: PR: #751, branch: PaulZhang12/stack/11)
Force-pushed from aa9d60a to 6d0c645 (stack-info: PR: #751, branch: PaulZhang12/stack/11)
How does this make it faster? Can you share the output Triton diff?
examples/int4_gemm.py
Outdated
C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
block_size_k_packed = hl.register_block_size(K // 2)
hl.register_reduction_dim(block_size_k_packed * 2)
What does this do?
block_size_k_packed * 2 is the K dimension, which we want to reduce over in the inner loop for the accumulator.
@jansel I have not been able to get this to work yet due to issue #753, which I was describing in the chat yesterday. Essentially, we don't want to use hl.dot here, as that involves a lot of reshaping/splitting for padding, and tensor cores don't support int4 anymore. We want to just do a basic multiply + reduction, which is also what the reference pt2 implementation does.
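For illustration, here is a minimal plain-PyTorch sketch of the "basic multiply + reduction" idea. The tile shapes and the accumulate_tile helper are hypothetical, not code from this PR:

```python
import torch

# Hypothetical tile shapes: a_tile is a (BLOCK_M, BLOCK_K) bf16 activation tile,
# b_tile is the unpacked (BLOCK_K, BLOCK_N) int4 weight tile stored as int8,
# where BLOCK_K corresponds to block_size_k_packed * 2.
def accumulate_tile(acc, a_tile, b_tile):
    # Broadcast multiply over K, then sum over K (no hl.dot / tensor cores):
    # (BLOCK_M, BLOCK_K, 1) * (1, BLOCK_K, BLOCK_N) -> sum(dim=1) -> (BLOCK_M, BLOCK_N)
    return acc + (a_tile.unsqueeze(2) * b_tile.to(a_tile.dtype).unsqueeze(0)).sum(dim=1)
```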
#809 (currently in review) should unblock this PR.
@PaulZhang12 #809 is merged - please let me know if you still see any errors, thanks!
Force-pushed from 6d0c645 to 9b73baf (stack-info: PR: #751, branch: PaulZhang12/stack/11)
Force-pushed from 9b73baf to d5792a1 (stack-info: PR: #751, branch: PaulZhang12/stack/11)
test/test_examples.py
Outdated
- torch.backends.cuda.matmul.fp32_precision = "tf32"
- torch.backends.cudnn.conv.fp32_precision = "tf32"
+ # torch.backends.cuda.matmul.fp32_precision = "tf32"
+ # torch.backends.cudnn.conv.fp32_precision = "tf32"
Whoops?
Force-pushed from d5792a1 to c9d8521 (stack-info: PR: #751, branch: PaulZhang12/stack/11)
x_2d = x.reshape(-1, x.size(-1))
w_int8 = w.to(torch.int8)
w_reshaped = w_int8.reshape(w.shape[0] // 2, 2, w.shape[1]).permute(1, 0, 2)
w_packed = ((w_reshaped[0] & 0xF) | (w_reshaped[1] << 4)).to(torch.int8)
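For reference, a small sketch of the inverse of this packing. The unpack_int4 helper is hypothetical (not part of the PR) and assumes the same layout: even-k rows in the low nibble, odd-k rows in the high nibble.

```python
import torch

def unpack_int4(w_packed: torch.Tensor) -> torch.Tensor:
    # w_packed: (K // 2, N) int8, two signed int4 values per byte.
    lo = (w_packed << 4) >> 4   # arithmetic right shift sign-extends the low nibble
    hi = w_packed >> 4          # high nibble, sign-extended by the arithmetic shift
    # Interleave back to (K, N): row 2*i from lo, row 2*i + 1 from hi.
    return torch.stack([lo, hi], dim=1).reshape(-1, w_packed.shape[1])
```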
Curious to double-check: do other backends in TritonBench also run this preprocess part outside of the measured kernel?
@yf225 Yes, they do. There are two versions: one preprocessed, one not. The comparisons I did were against the preprocessed pt2 version.
thanks @PaulZhang12!
Stacked PRs:
Faster int4 gemm