-
Notifications
You must be signed in to change notification settings - Fork 317
Add test case generator for groupwise low bit LUT based quantization #2359
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2359
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 2 PendingAs of commit fdce227 with merge base d72a6d1 ( NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
const int lut_size = 1 << weight_nbit; | ||
|
||
// Generate random quantized indices (this remains the same) | ||
auto weight_qvals = std::vector<uint8_t>(total_weights); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you use get_random_lowbit_vector for this?
std::vector<float> base_codebook(lut_size); | ||
float start_val = -(static_cast<float>(lut_size) / 2.0f) + 0.5f; | ||
for(int i = 0; i < lut_size; ++i) { | ||
base_codebook[i] = start_val + i; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why can't base codebook just be output of get_random_vector?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes we can. I tried to make it simpler as the scale is randomly generated.
} | ||
|
||
// 2c. Create the final LUTs by scaling the base codebook for each group | ||
std::vector<float> weight_luts(num_weight_groups * lut_size); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are two group sizes here. There is a group size for the LUT (e.g., if we have 2 LUTs for 100 values, then the lut_group_size is 50; you could also represent this with n_luts).
There is also the group size for the scale. For example, if we have 4 scales for 100 values, then the scale_group_size is 25).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I used the single LUT for save all the luts. Will update this part.
float activation_val = activations[m_idx * k + k_idx]; | ||
int weight_idx = n_idx * k + k_idx; | ||
int group_idx = weight_idx / weight_group_size; | ||
uint8_t lut_index = weight_qvals[weight_idx]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
weight_qvals looks more like weight_qval_idxs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, gonna do some renaming for consistentcy
int weight_idx = n_idx * k + k_idx; | ||
int group_idx = weight_idx / weight_group_size; | ||
uint8_t lut_index = weight_qvals[weight_idx]; | ||
float weight_dequant_val = weight_luts[group_idx * lut_size + lut_index]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where is the scale applied?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
scale applied on line 672. I can move it here.
*/ | ||
static groupwise_lowbit_weight_lut_test_case generate_with_grouping( | ||
int m, int k, int n, | ||
int weight_group_size, int scale_group_size, int lut_group_size, int weight_nbit, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is weight_group_size?
for (int i = 0; i < num_weight_groups; ++i) group_to_lut_map[i] = lut_map_dis(gen); | ||
|
||
// 2c. Generate random quantized indices for each weight. | ||
auto weight_qval_indices = std::vector<uint8_t>(total_weights); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why can't we use get_random_lowbit_vector?
…control the frequency of group change.
// 2b. Generate random quantized indices for each weight. | ||
auto weight_qval_indices = std::vector<uint8_t>(total_weights); | ||
std::uniform_int_distribution<int> qval_dis(0, lut_size - 1); | ||
for (int i = 0; i < total_weights; ++i) weight_qval_indices[i] = static_cast<uint8_t>(qval_dis(gen)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@szyszyzys why can't we use get_random_lowbit_vector here?
for (int k_idx = 0; k_idx < k; ++k_idx) { | ||
float activation_val = activations[m_idx * k + k_idx]; | ||
int weight_idx = n_idx * k + k_idx; | ||
uint8_t qval = weight_qval_indices[weight_idx]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: qval_idx
*/ | ||
static groupwise_lowbit_weight_lut_test_case generate_per_group( | ||
int m, int k, int n, | ||
int group_size, // The size of the block for both scales and LUTs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If they are the same, we'd just integrate the scales with the LUTs, no?
/** | ||
* @brief OVERLOAD 2: Advanced generator with separate grouping for scales and LUTs. | ||
*/ | ||
static groupwise_lowbit_weight_lut_test_case generate_with_decoupled_grouping( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add a flag for has_scales. When set to false, make all the scales 1.0.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks great! Approving PR, left a few comments.
Looks like there are some CI errors as well |
fca5d8c
to
2c9b99b
Compare
2c9b99b
to
fdce227
Compare
No description provided.