
Conversation

iseeyuan
Contributor

@iseeyuan commented Sep 4, 2024

@TiRune identified this option. When the absolute values of the quantized min and max differ, as with [-8, 7], we can compute the scale factor from the positive and negative sides individually and pick the larger one. This shows a perplexity improvement in llama-like 4-bit weight-quantized models.

before:
'word_perplexity,none': 24.198390005931635
after:
'word_perplexity,none': 23.25360136363946

This PR adds one new mapping type, SYMMETRIC_MAX_POS_NEG, to compute the group symmetric quantization scales as described above.

Please refer to the inline comments for the reasoning.
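For illustration, here is a minimal sketch of the per-group scale computation described above (a hedged sketch only; the function name `symmetric_scale` and the edge-case handling are assumptions, not the torchao implementation):

```python
import torch

def symmetric_scale(w: torch.Tensor, quant_min: int = -8, quant_max: int = 7) -> torch.Tensor:
    # Sketch (not the torchao code): derive a candidate scale from each side
    # of the range and keep the larger one, so neither side can clip.
    min_val_neg = torch.clamp(w.min(), max=0.0)  # most negative value (or 0)
    max_val_pos = torch.clamp(w.max(), min=0.0)  # most positive value (or 0)
    smin = min_val_neg / float(quant_min)        # scale that maps the min to quant_min
    smax = max_val_pos / float(quant_max)        # scale that maps the max to quant_max
    return torch.max(smin, smax)
```

For a [-8, 7] range this gives scale = max(|min| / 8, max / 7) instead of max(|min|, max) / 7.5.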

In the situation where min and max can be different, like [-8, 7]

pytorch-bot bot commented Sep 4, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/805

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 24c0873 with merge base 144445a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Sep 4, 2024
@cpuhrsch requested a review from jerryzh168 September 4, 2024 19:36
@@ -730,8 +730,10 @@ def _choose_qparams_affine(
    max_val_pos = max_val

    if mapping_type == MappingType.SYMMETRIC.name:
        max_val_pos = torch.max(-min_val_neg, max_val_pos)
        scale = max_val_pos / (float(quant_max - quant_min) / 2)
        smin = min_val_neg / float(quant_min)
Contributor

We discussed offline that @iseeyuan will create a new mapping_type for this, since it's not always better than the existing way of computing the scale. We can discuss the naming a bit later.


@TiRune Sep 5, 2024

In what scenarios would it be worse? I would argue this is always the intended behavior for any symmetric setting. Essentially, with min-max you never want your values to be clipped, which I believe can happen in the current implementation.

There's another issue not mentioned yet: with the current scheme, if you export the quantize-dequantized (QDQ) weights and then load them back in with the min-max quantizer, the results change every time you do this. With this PR, the weights stay the same even if you apply the min-max quantizer multiple times.

Contributor

@TiRune why would you never want to clip values? It's always a trade-off between clipping error and rounding error, I think.

For me, it's more about what's expected of a min-max quantizer. Sure, we always need to trade off clipping and rounding error, but that's what MSE-based range setting, HQQ, or those types of algorithms are for. The current symmetric choice of dividing by (q_max - q_min) / 2 is kind of arbitrary.

To me, a min-max range setter means 1) it always includes both the min and the max in the range, so there's no clipping error, and 2) applying it twice gives the same result, i.e., f(f(x)) = f(x). We use this, e.g., to export fake-quantized weights and load them into another library like ExecuTorch :D
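To make property (2) concrete, here is a small self-contained check (an illustrative sketch, not the torchao code path; `minmax_sym_qdq` and its defaults are assumptions):

```python
import torch

def minmax_sym_qdq(w, quant_min=-8, quant_max=7, no_clipping=True):
    # Sketch of symmetric min-max quantize-dequantize with the two scale
    # choices discussed in this thread (not the actual torchao implementation).
    min_val_neg = torch.clamp(w.min(), max=0.0)
    max_val_pos = torch.clamp(w.max(), min=0.0)
    if no_clipping:
        # this PR: pick the scale so that neither the min nor the max can clip
        scale = torch.max(min_val_neg / quant_min, max_val_pos / quant_max)
    else:
        # existing SYMMETRIC: divide the larger magnitude by (quant_max - quant_min) / 2
        scale = torch.max(-min_val_neg, max_val_pos) / ((quant_max - quant_min) / 2)
    q = torch.clamp(torch.round(w / scale), quant_min, quant_max)
    return q * scale

w = torch.randn(256)
once = minmax_sym_qdq(w)
twice = minmax_sym_qdq(once)
print(torch.allclose(once, twice))  # expected True: f(f(x)) == f(x) for the no-clipping scale
```

With the existing (quant_max - quant_min) / 2 divisor, the extreme value can be clipped on the first pass, so the second pass may land on a different scale and the same check generally does not hold.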

Another use case of the min-max quantizer: use it as the worst-case scenario for quantization, to seed HQQ/MSE-based range setting.

E.g., you start with min-max and then search over shrunken versions of this range; like in AWQ, you search over 0.99^N * the scale factor. In this kind of algorithm you also expect the min-max quantizer to start off with zero clipping error. A rough sketch of that seeding idea follows below.
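The sketch below is illustrative only: the helper names are hypothetical, and the shrink factor of 0.99 and step count are arbitrary example values, not AWQ's or HQQ's actual search settings.

```python
import torch

def qdq_error(w, scale, quant_min=-8, quant_max=7):
    # mean-squared quantize-dequantize error for a candidate scale
    q = torch.clamp(torch.round(w / scale), quant_min, quant_max)
    return ((q * scale - w) ** 2).mean()

def mse_search_from_minmax(w, quant_min=-8, quant_max=7, steps=100, shrink=0.99):
    # Seed with the min-max (no-clipping) scale, then sweep shrunken ranges
    # and keep whichever scale gives the lowest MSE.
    min_val_neg = torch.clamp(w.min(), max=0.0)
    max_val_pos = torch.clamp(w.max(), min=0.0)
    seed = torch.max(min_val_neg / quant_min, max_val_pos / quant_max)
    best_scale, best_err = seed, qdq_error(w, seed, quant_min, quant_max)
    for n in range(1, steps):
        s = seed * (shrink ** n)
        err = qdq_error(w, s, quant_min, quant_max)
        if err < best_err:
            best_scale, best_err = s, err
    return best_scale
```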

Contributor Author

Thanks @TiRune and @jerryzh168! To play it safe in case there are existing uses of the old mapping, and to quickly unblock this use case (without fixing all the tests, especially the old tests that are already planned for updates), I added a new symmetric mapping type.

Later, if we find that the two symmetric mappings can be merged into one, we can do that in a follow-up PR.

Please review the code and let me know if it makes sense.

Contributor

@jerryzh168 Sep 6, 2024

@TiRune I see, yeah, that makes sense I think. As @iseeyuan mentioned, the current implementation is used by all the existing code, so changing its behavior would be BC-breaking; it's better to add a new mapping type.

@iseeyuan changed the title from "[test] Update the way scale is calculated for affine Symmetric" to "Update the way scale is calculated for affine Symmetric" Sep 5, 2024
@facebook-github-bot
Contributor

@iseeyuan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@@ -41,12 +41,18 @@ class MappingType(Enum):
    we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7)
    e.g. scale = (10.2 - (-10.2)) / (7 - (-8))

    SYMMETRIC_MAX_POS_NEG is a variant of symmetric mapping, where the scale is the max of smin
Contributor

How about SYMMETRIC_NO_CLIPPING_ERR? Also cc @TiRune in case there are suggestions for the name.

@jerryzh168
Contributor

Error is relevant, I think; maybe there are some changes to the QAT behavior of 8da4w.

@@ -1022,13 +1024,15 @@ def __init__(
        precision: torch.dtype = torch.float32,
        scales_precision: torch.dtype = torch.float32,
        device: torch.device = torch.device("cpu"),
        mapping_type: MappingType = MappingType.SYMMETRIC_MAX_POS_NEG
Contributor

Actually, why is this set to the new type? I think it would be safer to keep the old one.

@iseeyuan force-pushed the iseeyuan-patch-1 branch 2 times, most recently from 347f6d6 to c8675aa, September 7, 2024 00:05
@@ -41,12 +41,18 @@ class MappingType(Enum):
    we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7)
    e.g. scale = (10.2 - (-10.2)) / (7 - (-8))

    SYMMETRIC_NO_CLIPPING_ERR is a variant of symmetric mapping, where the scale is the max of smin
Contributor

cc @gau-nernst wondering if this will help int8 training as well:

2. Calculate scale: AQT uses `input.abs().amax() / 127.5`, while `input.abs().amax() / 127` is

Collaborator

Possibly; we'll need to test. All the INT8 training papers I've seen use 127 for both +ve and -ve, though.

One potential concern I foresee is quantization speed: during training we keep quantizing, so having a fast implementation is crucial. Again, we'll need to try it out and profile to be sure 😄.

Contributor

Actually, it seems that even without this change, 127 can be achieved by setting quant_min/quant_max to (-127, 127).

Yeah, we could benchmark perf if we decide to use AQT for int8 training.
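For reference, a quick check of that arithmetic (plain torch; not AQT code):

```python
import torch

x = torch.randn(1024)
amax = x.abs().amax()
quant_min, quant_max = -127, 127
# the existing symmetric formula divides by (quant_max - quant_min) / 2 = 127,
# so with a (-127, 127) range it already yields amax / 127
scale = amax / (float(quant_max - quant_min) / 2)
assert torch.allclose(scale, amax / 127)
```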

@@ -212,7 +213,7 @@ def test_qat_8da4w_quantizer(self):
        m = M()
        m2 = copy.deepcopy(m)
        qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
        ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
        ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size, mapping_type=MappingType.SYMMETRIC)
Contributor

nit: if this is default now, we can remove the arg here I think

@@ -164,7 +164,7 @@ def test_get_group_qparams_symmetric(self):
        scale_obs = scale_obs.reshape(weight.shape[0], -1)

        # assert that scales are identical
        (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize, precision=torch.float16)
        (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize, precision=torch.float16, mapping_type=MappingType.SYMMETRIC)
Contributor

Same for these.


        self.assertTrue(torch.equal(scale, scale_ref))
        self.assertTrue(torch.equal(zero_point, zp_ref))

    def test_choose_qparams_group_sym_pos_neg(self):
Contributor

nit: please update the name of the test as well

Contributor

@jerryzh168 left a comment

LGTM, thanks @iseeyuan and @TiRune

@jerryzh168 merged commit c6abf2b into main Sep 7, 2024
20 checks passed
@jerryzh168 deleted the iseeyuan-patch-1 branch September 7, 2024 06:17
facebook-github-bot pushed a commit to pytorch/executorch that referenced this pull request Sep 7, 2024
…perplexity (#5163)

Summary:
Refer to pytorch/ao#805 for the details.
With this change, the perplexity of a llama model is improved by ~4% on wikitext.


Differential Revision: D62342523

Pulled By: iseeyuan
jainapurva pushed a commit that referenced this pull request Sep 9, 2024
* [test] Update the way scale is calculated for affine Symmetric

In the situation where min and max can be different, like [-8, 7]

* Update quant_primitives.py

* Update test_qat.py

* Update test_quant_primitives.py

* Update test_quant_api.py
facebook-github-bot pushed a commit to pytorch/executorch that referenced this pull request Sep 11, 2024
…perplexity (#5163)

Summary:
Refer to pytorch/ao#805 for the details.
With this change, the perplexity of a llama model is improved by ~4% on wikitext.


Reviewed By: mergennachin, helunwencser

Differential Revision: D62342523

Pulled By: iseeyuan
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024