-
Notifications
You must be signed in to change notification settings - Fork 317
Add test case generator for groupwise low bit LUT based quantization #2359
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b437e39
e2a3bce
d1c825c
fdce227
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -575,6 +575,192 @@ struct lowbit_embedding_test_case { | |
} | ||
}; | ||
|
||
struct groupwise_lowbit_weight_lut_test_case { | ||
//-------------------------------------------------------------------------- | ||
// Parameters | ||
//-------------------------------------------------------------------------- | ||
int m, k, n; | ||
int scale_group_size; | ||
int lut_group_size; | ||
int weight_nbit; | ||
bool has_scales, has_bias, has_clamp; | ||
float clamp_min, clamp_max; | ||
|
||
//-------------------------------------------------------------------------- | ||
// Data Tensors | ||
//-------------------------------------------------------------------------- | ||
std::vector<float> expected_output; | ||
std::vector<float> activations; | ||
std::vector<float> bias; | ||
std::vector<uint8_t> weight_qval_indices; // Indices into a LUT for each weight | ||
std::vector<float> weight_luts; // The pool of unique LUTs | ||
std::vector<float> weight_scales; // The pool of unique scales | ||
|
||
//-------------------------------------------------------------------------- | ||
// Constructor | ||
//-------------------------------------------------------------------------- | ||
groupwise_lowbit_weight_lut_test_case( | ||
int m_, int k_, int n_, int scale_group_size_, int lut_group_size_, int weight_nbit_, bool has_scales_, bool has_bias_, bool has_clamp_, | ||
float clamp_min_, float clamp_max_, | ||
std::vector<float> expected_output_, std::vector<float> activations_, | ||
std::vector<float> bias_, std::vector<uint8_t> weight_qval_indices_, | ||
std::vector<float> weight_luts_, std::vector<float> weight_scales_) | ||
: m(m_), k(k_), n(n_), | ||
scale_group_size(scale_group_size_), lut_group_size(lut_group_size_), weight_nbit(weight_nbit_), | ||
has_scales(has_scales_), | ||
has_bias(has_bias_), has_clamp(has_clamp_), clamp_min(clamp_min_), clamp_max(clamp_max_), | ||
expected_output(expected_output_), | ||
activations(activations_), | ||
bias(bias_), | ||
weight_qval_indices(weight_qval_indices_), | ||
weight_luts(weight_luts_), | ||
weight_scales(weight_scales_) | ||
{} | ||
|
||
//-------------------------------------------------------------------------- | ||
// Generator Functions (Factories) | ||
//-------------------------------------------------------------------------- | ||
|
||
private: | ||
/** | ||
* @brief The private "master" generator that provides maximum flexibility. | ||
* | ||
* This function is the core engine. It takes the exact number of scales and LUTs | ||
* to generate and constructs the test case. All other public generators are | ||
* wrappers around this one. | ||
*/ | ||
static groupwise_lowbit_weight_lut_test_case _generate_master( | ||
int m, int k, int n, | ||
int scale_group_size, // Directly controls scale change frequency | ||
int lut_group_size, // Directly controls LUT change frequency | ||
int weight_nbit, bool has_scales, | ||
bool has_bias, bool has_clamp) { | ||
|
||
// --- 0. Validation and Setup --- | ||
const int total_weights = n * k; | ||
// Frequencies are controlled by their group sizes. | ||
assert(total_weights % scale_group_size == 0); | ||
assert(total_weights % lut_group_size == 0); | ||
|
||
// The number of unique scales/LUTs is derived directly from their group size. | ||
const int num_scales = total_weights / scale_group_size; | ||
const int num_luts = total_weights / lut_group_size; | ||
const int lut_size = 1 << weight_nbit; | ||
std::mt19937 gen(std::random_device{}()); | ||
|
||
// --- 1. Generate Primary Inputs --- | ||
auto activations = get_random_vector(m * k, -1.0f, 1.0f); | ||
std::vector<float> bias_vec(n, 0.0f); | ||
if (has_bias) bias_vec = get_random_vector(n, -0.5f, 0.5f); | ||
float clamp_min = -std::numeric_limits<float>::infinity(), clamp_max = std::numeric_limits<float>::infinity(); | ||
if (has_clamp) { | ||
auto r = get_random_vector(2, -5.0f, 5.0f); | ||
clamp_min = std::min(r[0], r[1]); clamp_max = std::max(r[0], r[1]); | ||
} | ||
|
||
// --- 2. Generate Quantization Data --- | ||
// 2a. Generate the pools of unique scales and LUTs. | ||
std::vector<float> weight_scales; | ||
if (has_scales) { | ||
// Normal case: generate random scales. | ||
weight_scales = get_random_vector(num_scales, 0.001f, 0.1f); | ||
} else { | ||
// LUT-only case: create a vector where every scale is 1.0f. | ||
weight_scales.assign(num_scales, 1.0f); | ||
} | ||
|
||
auto weight_luts = get_random_vector(num_luts * lut_size, -0.2f, 0.2f); // Independent random LUTs | ||
|
||
// 2b. Generate random quantized indices for each weight. | ||
auto weight_qval_indices = std::vector<uint8_t>(total_weights); | ||
std::uniform_int_distribution<int> qval_dis(0, lut_size - 1); | ||
for (int i = 0; i < total_weights; ++i) weight_qval_indices[i] = static_cast<uint8_t>(qval_dis(gen)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @szyszyzys why can't we use get_random_lowbit_vector here? |
||
|
||
// --- 3. Compute Expected Output using the IMPLICIT mappings --- | ||
std::vector<float> expected_output(m * n); | ||
for (int m_idx = 0; m_idx < m; ++m_idx) { | ||
for (int n_idx = 0; n_idx < n; ++n_idx) { | ||
float res = 0.0f; | ||
for (int k_idx = 0; k_idx < k; ++k_idx) { | ||
float activation_val = activations[m_idx * k + k_idx]; | ||
int weight_idx = n_idx * k + k_idx; | ||
uint8_t qval_idx = weight_qval_indices[weight_idx]; | ||
|
||
int32_t scale_idx = weight_idx / scale_group_size; | ||
int32_t lut_idx = weight_idx / lut_group_size; | ||
|
||
// Dequantize: scale * LUT_value | ||
float scale = weight_scales[scale_idx]; | ||
float lut_val = weight_luts[lut_idx * lut_size + qval_idx]; | ||
res += activation_val * (scale * lut_val); | ||
} | ||
res += bias_vec[n_idx]; | ||
if (has_clamp) { res = std::clamp(res, clamp_min, clamp_max); } | ||
expected_output[m_idx * n + n_idx] = res; | ||
} | ||
} | ||
|
||
// --- 4. Construct and Return --- | ||
return groupwise_lowbit_weight_lut_test_case( | ||
m, k, n, scale_group_size, lut_group_size, weight_nbit, has_scales, | ||
has_bias, has_clamp, clamp_min, clamp_max, | ||
expected_output, | ||
activations, | ||
bias_vec, | ||
weight_qval_indices, | ||
weight_luts, | ||
weight_scales); | ||
|
||
} | ||
|
||
public: | ||
/** | ||
* @brief OVERLOAD 1: Simple generator where scales and LUTs share the same grouping. | ||
* | ||
* This is for the simplest case where a block of weights gets one scale and one LUT, | ||
* and this pattern repeats. | ||
*/ | ||
static groupwise_lowbit_weight_lut_test_case generate_per_group( | ||
int m, int k, int n, | ||
int group_size, // The size of the block for both scales and LUTs | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If they are the same, we'd just integrate the scales with the LUTs, no? |
||
int weight_nbit, bool has_scales, | ||
bool has_bias, bool has_clamp) { | ||
|
||
std::cout << "[Generator Info] Using 'Per-Group' model.\n" | ||
<< " - Both scales and LUTs will switch every " << group_size << " weights." << std::endl; | ||
|
||
// Just call the decoupled generator with the same group size for both. | ||
return _generate_master( | ||
m, k, n, | ||
group_size, /* scale_group_size */ | ||
group_size, /* lut_group_size */ | ||
weight_nbit, | ||
has_scales, | ||
has_bias, has_clamp | ||
); | ||
} | ||
|
||
/** | ||
* @brief OVERLOAD 2: Advanced generator with separate grouping for scales and LUTs. | ||
*/ | ||
static groupwise_lowbit_weight_lut_test_case generate_with_decoupled_grouping( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's add a flag for has_scales. When set to false, make all the scales 1.0. |
||
int m, int k, int n, | ||
int scale_group_size, int lut_group_size, int weight_nbit, bool has_scales, | ||
bool has_bias, bool has_clamp) { | ||
|
||
std::cout << "[Generator Info] Using 'Decoupled Grouping' model.\n" | ||
<< " - Scales will switch every " << scale_group_size << " weights.\n" | ||
<< " - LUTs will switch every " << lut_group_size << " weights." << std::endl; | ||
|
||
return _generate_master( | ||
m, k, n, | ||
scale_group_size, lut_group_size, | ||
weight_nbit, has_scales, | ||
has_bias, has_clamp | ||
); | ||
} | ||
}; | ||
|
||
} // namespace torchao | ||
|
||
#endif // defined(__aarch64__) || defined(__ARM_NEON) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why can't we use get_random_lowbit_vector?