[quant] Add 4-bit embedding_bag prepack/unpack support using quint4x2 #45751
Conversation
Summary: Use the torch.quint4x2 dtype to create 4-bit packed tensors

Test Plan: `python test/test_quantization.py TestEmbeddingBagOps`

[ghstack-poisoned]
Codecov Report

```
@@            Coverage Diff             @@
##    gh/supriyar/191/base    #45751    +/- ##
=============================================
  Coverage           ?      68.20%
=============================================
  Files              ?         410
  Lines              ?       53245
  Branches           ?           0
=============================================
  Hits               ?       36314
  Misses             ?       16931
  Partials           ?           0
```

Continue to review the full report at Codecov.
```cpp
weight_data =
    reinterpret_cast<uint8_t*>(weight_contig.data_ptr<c10::quint8>());
```
I remember data_ptr does not do a type check, so maybe you can just do `weight_contig.data_ptr<uint8_t>()`
I don't think that is applicable to qint types. It throws the error "expected scalar type Byte but found QUInt8".
oh sorry, I meant this: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/templates/TensorBody.h#L354

so:

```cpp
static_cast<uint8_t*>(weight_contig.data_ptr())
```
```cpp
    zero_points.toType(c10::kFloat),
    0, // The output channel axis is 0
    device(c10::kCPU).dtype(c10::kQUInt4x2));
output_data =
```
same here
```cpp
at::parallel_for(
    0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
      for (int64_t row = start_idx; row < end_idx; ++row) {
        const uint8_t* input_row = weight_data + row * embedding_cols;
        std::uint8_t* output_row = output_data + row * output_columns;
        at::Half* output_row_scale_bias =
            reinterpret_cast<at::Half*>(output_row + embedding_cols);
        output_row_scale_bias[0] = weight_scales[row];
        output_row_scale_bias[1] = weight_bias[row];
        for (int64_t col = 0; col < embedding_cols; ++col) {
          // The weight values have already been packed, so here we just
          // store them in the output tensor.
          output_row[col] = input_row[col];
        }
      }
    });
```
optional: seems like this could be reused with the above section if the `at::Half` part is templatized? Definitely optional though, since the LOC is low.
Yes, I think templatizing it would add a similar number of LOC, so I'm planning to skip it for now.
This pull request has been merged in 5c283fa.
Stack from ghstack:
Summary:
Use the torch.quint4x2 dtype to create 4-bit packed tensors
Test Plan:
python test/test_quantization.py TestEmbeddingBagOps
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: D24120997