Update the way scale is calculated for affine Symmetric #805
Conversation
In the situation where the absolute values of the quantized min and max can be different, like [-8, 7]
@@ -730,8 +730,10 @@ def _choose_qparams_affine(
        max_val_pos = max_val

    if mapping_type == MappingType.SYMMETRIC.name:
        max_val_pos = torch.max(-min_val_neg, max_val_pos)
        scale = max_val_pos / (float(quant_max - quant_min) / 2)
        smin = min_val_neg / float(quant_min)
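The hunk above is cut off at `smin`. Based on the PR description, the new branch presumably continues roughly as below; this is a sketch, not the verbatim torchao source.

```python
# Sketch: per-side scales, keeping the larger one so that neither min_val_neg
# nor max_val_pos lands outside [quant_min, quant_max] after rounding.
smin = min_val_neg / float(quant_min)   # quant_min < 0, e.g. -8
smax = max_val_pos / float(quant_max)   # quant_max > 0, e.g.  7
scale = torch.max(smin, smax)
```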
discussed offline that @iseeyuan will create a new mapping_type for this, and this is not always better than the existing way of computing scale. we can discuss the naming a bit later
In what scenarios would it be worse? I would argue that this is always intended behavior, for any symmetric setting. Essentially, with min-max you never want there to be clipping of your values, which I believe could happen in the current implementation.
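A toy numeric illustration of that clipping, using the int4 range [-8, 7]; the tensor values and variable names are made up for this example, not taken from the PR.

```python
import torch

quant_min, quant_max = -8, 7
w = torch.tensor([-1.0, 0.5, 1.0])
min_val, max_val = w.min().item(), w.max().item()   # -1.0 and 1.0

# Current symmetric scale: max(|min|, max) / ((quant_max - quant_min) / 2) = 1 / 7.5
scale_old = max(-min_val, max_val) / ((quant_max - quant_min) / 2)
q_old = torch.clamp(torch.round(w / scale_old), quant_min, quant_max)
# 1.0 / scale_old == 7.5 -> rounds to 8 -> clamped to 7: the largest value is clipped.

# Per-side scale from this PR: the larger of min/quant_min and max/quant_max.
scale_new = max(min_val / quant_min, max_val / quant_max)   # max(0.125, 1/7) = 1/7
q_new = torch.clamp(torch.round(w / scale_new), quant_min, quant_max)
# 1.0 / scale_new == 7.0 -> 7 and -1.0 / scale_new == -7.0 -> -7: nothing is clipped.
```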
There's another issue not mentioned yet: if you take the current scheme, export the weights that are qdq'ed, then load them back in again with the min-max quantizer, the results change every time you do this. With this PR, the weights will stay the same even if you apply the min-max quantizer multiple times.
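A rough sketch of that idempotency property with the per-side scale; the helper name and the int4 range below are illustrative, not a torchao API.

```python
import torch

def minmax_symmetric_qdq(w, quant_min=-8, quant_max=7):
    # Per-side scale as proposed in this PR, then fake-quantize (quantize + dequantize).
    scale = max(w.min().item() / quant_min, w.max().item() / quant_max)
    q = torch.clamp(torch.round(w / scale), quant_min, quant_max)
    return q * scale

w = torch.randn(256)
once = minmax_symmetric_qdq(w)
twice = minmax_symmetric_qdq(once)
# The dominant extreme maps exactly onto quant_min or quant_max, so re-deriving
# qparams from the qdq'ed tensor reproduces the same scale: f(f(x)) == f(x),
# up to floating-point rounding.
assert torch.allclose(once, twice)
```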
@TiRune why would you never want to clip values? It's always a trade-off between clipping error and rounding error, I think?
For me, it's more about what's expected of a min-max quantizer. Sure, we always need to trade off clipping and rounding error. But that's what MSE-based range setting, HQQ, or those types of algorithms are for. The current choice for symmetric is kinda arbitrary at (q_max - q_min)/2.
To me, a min-max range setter means 1) it always includes both the min and max in the range so there's no clipping error, and 2) if it's applied twice the same result comes out, i.e., f(f(x)) = f(x). We use this e.g. to export fake-quant weights and load them in another library like ExecuTorch :D
Another use-case of the min-max quantizer: Use it as the worst-case scenario for quantization to seed HQQ/MSE based range setting with.
E.g. you start with min-max, then search over shrunken versions of this range like in AWQ, where you search over 0.99^N * the scale factor. In this version of the algorithm, you also expect the min-max quantizer to start off with 0 clipping error.
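A small sketch of that seeding pattern; the 0.99 base, the MSE objective, and all names here are assumptions for illustration, not torchao or AWQ code.

```python
import torch

def qdq(w, scale, quant_min=-8, quant_max=7):
    return torch.clamp(torch.round(w / scale), quant_min, quant_max) * scale

w = torch.randn(4096)
quant_min, quant_max = -8, 7

# Seed with the min-max (zero-clipping) scale, then search over shrunken ranges.
scale0 = max(w.min().item() / quant_min, w.max().item() / quant_max)
best_scale, best_err = scale0, (w - qdq(w, scale0)).pow(2).mean()
for n in range(1, 100):
    s = 0.99 ** n * scale0                # shrink the range: 0.99^N * the seed scale
    err = (w - qdq(w, s)).pow(2).mean()   # trade a little clipping for less rounding error
    if err < best_err:
        best_scale, best_err = s, err
```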
Thanks @TiRune and @jerryzh168! To play it safe in case there's usage of the old mapping, and to quickly unblock the usage (without fixing all tests, especially those old tests that are planned to be updated), I added a new symmetric mapping type.
Later, if we notice that the two symmetric mappings can be merged into one, we can have a PR to merge them.
Please review the code and let me know if it makes sense.
@iseeyuan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@@ -41,12 +41,18 @@ class MappingType(Enum):
    we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7)
    e.g. scale = (10.2 - (-10.2)) / (7 - (-8))

    SYMMETRIC_MAX_POS_NEG is a variant of symmetric mapping, where the scale is the max of smin
how about SYMMETRIC_NO_CLIPPING_ERR? Also cc @TiRune if there is a suggestion for the name.
error is relevant I think, maybe there are some changes to the QAT behavior of 8da4w
torchao/quantization/GPTQ.py
Outdated
@@ -1022,13 +1024,15 @@ def __init__(
        precision: torch.dtype = torch.float32,
        scales_precision: torch.dtype = torch.float32,
        device: torch.device = torch.device("cpu"),
        mapping_type: MappingType = MappingType.SYMMETRIC_MAX_POS_NEG
actually why is this set to the new type? I think it will be safer to keep the old one?
@@ -41,12 +41,18 @@ class MappingType(Enum):
    we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7)
    e.g. scale = (10.2 - (-10.2)) / (7 - (-8))

    SYMMETRIC_NO_CLIPPING_ERR is a variant of symmetric mapping, where the scale is the max of smin
cc @gau-nernst wondering if this will help int8 training as well:
2. Calculate scale: AQT uses `input.abs().amax() / 127.5`, while `input.abs().amax() / 127` is
Possibly, will need to test. All INT8 training papers I see use 127 for both +ve and -ve though.
One potential concern I foresee is about quantization speed. During training we keep doing quantization, so having a fast one is crucial. Again, will need to try out / profile to be sure 😄.
actually it seems even without this change 127 can be achieved with setting quant_min/quant_max to be (-127, 127)
yeah we could benchmark perf if we decide to use AQT for int8 training
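A quick illustrative check of that point (not torchao code; the formula below just mirrors the existing symmetric scale computation): with a (-127, 127) range the existing formula already yields `amax / 127` rather than `amax / 127.5`.

```python
import torch

x = torch.randn(1024)
amax = x.abs().amax()

# Existing symmetric formula with a (-127, 127) range instead of (-128, 127):
quant_min, quant_max = -127, 127
scale = amax / ((quant_max - quant_min) / 2)   # (127 - (-127)) / 2 == 127

assert torch.allclose(scale, amax / 127)        # i.e. input.abs().amax() / 127
```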
test/quantization/test_qat.py
Outdated
@@ -212,7 +213,7 @@ def test_qat_8da4w_quantizer(self):
        m = M()
        m2 = copy.deepcopy(m)
        qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
        ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
        ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size, mapping_type=MappingType.SYMMETRIC)
nit: if this is default now, we can remove the arg here I think
@@ -164,7 +164,7 @@ def test_get_group_qparams_symmetric(self):
        scale_obs = scale_obs.reshape(weight.shape[0], -1)

        # assert that scales are identical
        (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize, precision=torch.float16)
        (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize, precision=torch.float16, mapping_type=MappingType.SYMMETRIC)
same for these
        self.assertTrue(torch.equal(scale, scale_ref))
        self.assertTrue(torch.equal(zero_point, zp_ref))

    def test_choose_qparams_group_sym_pos_neg(self):
nit: please update the name of the test as well
* [test] Update the way scale is calculated for affine Symmetric In the situation where min and max can be different, like [-8, 7] * Update quant_primitives.py * Update test_qat.py * Update test_quant_primitives.py * Update test_quant_api.py
…perplexity (#5163) Summary: Refer to pytorch/ao#805 for the details. With this change, the perplexity of a llama model is improved 4% on wikitext. Reviewed By: mergennachin, helunwencser Differential Revision: D62342523 Pulled By: iseeyuan
@TiRune identified this option. In the situation where the absolute values of the quantized min and max can be different, like [-8, 7], we can calculate the scale factor from the positive and negative sides individually and pick the larger one. It shows a perplexity improvement in llama-like 4-bit weight quantized models.
before:
'word_perplexity,none': 24.198390005931635
after:
'word_perplexity,none': 23.25360136363946
In this PR, one mapping type, SYMMETRIC_MAX_POS_NEG, is added to get the group symmetric quantization scales as mentioned above. Please refer to the inline comments for the reasoning.