Add QAT support for distributed finetuning #980

Merged 1 commit into pytorch:main on Jun 27, 2024

Conversation

@andrewor14 (Contributor) commented May 14, 2024

Summary: This commit adds the option to run quantization-aware training (QAT) during finetuning. QAT refers to "fake quantizing" the weights and activations during training, which performs the following transformation on the inputs but still keeps all intermediate values in floating point:

x_q = clamp((x_bf16 / scale) + zp)
x_fq = (x_q - zp) * scale
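
For illustration only, here is a minimal fake-quant round trip in plain PyTorch following the two lines above. The explicit round() and the signed 4-bit qmin/qmax bounds are assumptions added for this sketch; the actual torchao 8da4w implementation uses per-token and per-group scales and differs in details:

```python
import torch

def fake_quantize(x_bf16: torch.Tensor, scale: float, zp: int,
                  qmin: int = -8, qmax: int = 7) -> torch.Tensor:
    # Snap to the integer grid (round + clamp), then map back to floating point.
    x_q = torch.clamp(torch.round(x_bf16 / scale) + zp, qmin, qmax)
    x_fq = (x_q - zp) * scale
    return x_fq.to(x_bf16.dtype)

w = torch.randn(32, 32, dtype=torch.bfloat16)
w_fq = fake_quantize(w, scale=w.abs().max().item() / 7, zp=0)
```

During QAT the model trains on the fake-quantized values instead of the raw weights, so the weights learn to compensate for quantization error while everything stays in bf16.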

Currently only 8-bit per-token dynamic activations + 4-bit grouped per-channel weights (8da4w) is supported. Users can enable this by specifying a QAT quantizer in their config files:

tune run --nnodes 1 --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full

# or add this to your config file
# quantizer:
#   _component_: torchtune.utils.quantization.Int8DynActInt4WeightQATQuantizer
#   groupsize: 256

Test Plan:

Initial results for Llama2 demonstrate that QAT recovers roughly half of the accuracy lost to quantization on some tasks (last two rows):

|                 | hellaswag |          | wikitext        |                 |               | arc_easy |          | arc_challenge |          |
|-----------------|:---------:|----------|:---------------:|-----------------|---------------|:--------:|----------|:-------------:|----------|
|                 | acc       | acc_norm | word_perplexity | byte_perplexity | bits_per_byte | acc      | acc_norm | acc           | acc_norm |
| No quantization | 59.659%   | 76.927%  | 12.183          | 1.596           | 0.674         | 76.010%  | 72.054%  | 48.720%       | 47.867%  |
| PTQ             | 57.150%   | 74.945%  | 12.995          | 1.615           | 0.692         | 75.968%  | 70.118%  | 46.416%       | 45.904%  |
| QAT (bf16)      | 58.435%   | 76.190%  | 12.200          | 1.596           | 0.675         | 76.810%  | 72.180%  | 47.270%       | 47.184%  |
| QAT (quantized) | 58.504%   | 76.170%  | 12.199          | 1.596           | 0.675         | 76.431%  | 71.928%  | 47.184%       | 48.123%  |
| PTQ degradation | -2.509%   | -1.982%  | +0.812          | +0.019          | +0.018        | -0.042%  | -1.936%  | -2.304%       | -1.963%  |
| QAT degradation | -1.155%   | -0.757%  | +0.016          | +0.000          | +0.001        | +0.421%  | -0.126%  | -1.536%       | +0.256%  |

pytorch-bot (bot) commented May 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/980

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c110f45 with merge base c1c9f30:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on May 14, 2024
@andrewor14 (Contributor Author) commented:

By the way, the tests are failing due to torchtune's dependency on an old version of torchao, which doesn't have QAT support yet. We're about to release 0.2.0 on the torchao side (aiming early next week), so the tests won't pass until then.

@kartikayk (Contributor) left a comment:

Thanks for adding this @andrewor14! A few high-level comments apart from the one on teasing this out into a separate recipe:

Currently only 8-bit per token dynamic activations + 4-bit grouped per channel weights (8da4w) is supported

A couple of questions:

  • Do we plan on adding other quantization methods? Or what's the long-term support plan for this?
  • We need a lot of documentation here to make sure users understand what this actually means. Is this reasonably well understood by the general audience? For example, I'm not sure about all of the details here.

We also need to consider how to lower the bar for adoption here. I think this needs a tutorial or a deep dive added to the torchtune docs.

cc: @ebsmothers

@@ -116,6 +116,8 @@ def __init__(self, cfg: DictConfig) -> None:
# Training cfg
self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
self._qat_enable_fake_quant_step = cfg.get("qat_enable_fake_quant_step", None)
Contributor:

I think this can be a bit more user friendly i.e. something like enable_qat? Or does that not make sense?

Contributor Author:

This config is about when to enable fake quant, not whether to enable QAT itself. E.g. setting this to 1000 means we will run regular finetuning for the first 1000 steps, and only enable fake quant after 1000 steps. I'll add better comments/docs about this. Do you think the name makes sense or do you have suggestions?
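
For reference, here is a minimal sketch (not the actual recipe code) of how a delayed-start flag like this could be wired into a training loop. The `enable_fake_quant` / `disable_fake_quant` callables are hypothetical stand-ins for whatever toggles the quantizer exposes:

```python
from typing import Callable, Iterable, Optional

import torch

def train_with_delayed_fake_quant(
    model: torch.nn.Module,
    batches: Iterable[dict],
    optimizer: torch.optim.Optimizer,
    loss_fn: Callable[[torch.nn.Module, dict], torch.Tensor],
    enable_fake_quant: Callable[[torch.nn.Module], None],
    disable_fake_quant: Callable[[torch.nn.Module], None],
    fake_quant_after_n_steps: Optional[int] = None,
) -> None:
    # Start with plain bf16 finetuning if a delay is configured.
    if fake_quant_after_n_steps is not None:
        disable_fake_quant(model)
    for step, batch in enumerate(batches):
        # Turn fake quantization on once the configured step is reached.
        if fake_quant_after_n_steps is not None and step == fake_quant_after_n_steps:
            enable_fake_quant(model)
        loss = loss_fn(model, batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```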

Contributor:

Ah, then I grossly misunderstood :) Maybe something like quant_after_n_steps?

Contributor:

What is the motivation of delaying the fake quantization until after N steps? Are there issues with stability (and if so, how does training without fake quantization first mitigate them)?

@@ -288,6 +292,18 @@ def _setup_model(
ac_option,
)

# Optionally apply quantization-aware training during finetuning
Contributor:

This change is intrusive enough that I'd prefer this to be a separate recipe where you can remove all of the non-QAT-related code paths. Generally we don't want to have recipes with a bunch of if-else blocks, since this:

  • reduces readability of code
  • significantly increases the chances of bugs as recipes become more complicated
  • makes maintenance really hard

Contributor Author:

Sure, I can move this out to a separate recipe. However, this will require copying and pasting all the non-QAT-related training code, and over time the new recipe will likely diverge from the full_finetune_distributed recipe. If you think that is preferable to complicating this existing recipe, then I'll go ahead and separate it.

Contributor:

We've been pretty good at making sure we update all of the recipes with new features. I do think QAT is something we want to publicize heavily and so having its own recipe opens up avenues for future work as well

@andrewor14 (Contributor Author) commented:

Hi @kartikayk, thanks for the comments, responding inline:

  • Do we plan on adding other quantization methods? Or what's the long-term support plan for this?

Yes, in the long term we do plan to support other QAT configurations (e.g. 2- or 3-bit weight only if we can get good results), that's why I kept the quantizer specification general.

  • We need a lot of documentation here to make sure users understand what this actually means. Is this reasonably well understood by the general audience? For example, I'm not sure about all of the details here. We also need to consider how to lower the bar for adoption here. I think this needs a tutorial or a deep dive added to the torchtune docs.

For sure. Should I add the README in this PR and add the tutorial separately?

@kartikayk (Contributor) commented:

Should I add the README in this PR and add the tutorial separately?

I think a tutorial/deep-dive in the docs would be really helpful. You can add some details on what the quantization methods mean as well; I think this will be very useful and make the flow more noob-friendly.

Here are some pointers:

I don't mind this as a follow-up.

@ebsmothers let me know if you have differing thoughts on this.

@ebsmothers (Contributor) commented:

Thanks for the PR! I think I'm in agreement with most everything that's been said already: (1) separate recipe for QAT (I agree this will make it easier to scale once we add other quant techniques anyways) and (2) add some kind of tutorial but as a follow-up (also happy to provide any pointers or guidance you need in advance).

@@ -0,0 +1,78 @@
# Config for multi-device QAT finetuning in qat_distributed.py
Contributor:

We should figure out where we wanna put QAT configs. I guess it's somewhat different from our current configs layout: right now we segment by model at the top level. For now we only have one technique, but if we plan to support more we should think about how to split it up. I can think of 3 ways to do this:

  1. Provide a single QAT config per technique and keep them all at the top level (e.g. qat_full.yaml, qat_lora.yaml, ...)
  2. Provide QAT configs per model (so llama3/8B_qat_full_finetune.yaml etc.)
  3. Provide a separate QAT folder and put everything in there with defaults chosen for a canonical model (basically (1) but under a qat config directory)

That's not a major blocker for this PR, but lmk which one makes most sense to you. I would, however, consider at least renaming qat.yaml -> qat_full.yaml or something like that.

Contributor Author:

Sounds good. I mostly just followed quantization.yaml, eleuther_evaluation.yaml, and generation.yaml so far in this PR. I think these configs have the same problem. Both (2) and (3) make sense to me. What do you think?

Contributor Author:

Ok, I did (2) for now (put it in respective llama2 and llama3 dirs). Let me know if this sounds reasonable to you

Contributor:

Yeah I think this is reasonable. It is a bit of a pain to override all the necessary config fields to change models from the command line, so I think it makes sense to separate out QAT configs by model (since for (3) it'd get really verbose to change the model from the default config anyways). The downside is the configs are slightly less visible, but given we have a standalone top-level recipe this is OK imo.

pyproject.toml Outdated
@@ -25,7 +25,8 @@ dependencies = [
"omegaconf",

# Quantization
"torchao==0.1",
# TODO: update to 0.3
"torchao==0.2",
Contributor:

What's the plan for merging here? Will we wait until 0.3 is available? Or will we merge sooner on a nightly?

Contributor Author:

I think we'll wait till 0.3 is available

Resolved review threads on: recipes/qat_distributed.py, recipes/quantization.md, recipes/quantize.py
Comment on lines +107 to +112
if "qat" in self._quantization_mode:
self._model = self._quantizer.convert(self._model)
else:
self._model = self._quantizer.quantize(self._model)
Contributor:

So we need to gate on quantization mode because the QAT checkpoints are in bf16, right? Noob question but if we are not doing any subsequent training why can't we just call .quantize directly and infer all the quantizer params from the checkpoint?

Contributor Author:

Technically we can, and the numerics may be the same, but officially the torchao QAT flow is:

quantizer = 8da4wQATQuantizer()
model = quantizer.prepare(model)
train(model)
model = quantizer.convert(model)

If we just called quantize here, we would have to introduce a different quantizer:

quantizer = 8da4wQATQuantizer()
model = quantizer.prepare(model)
train(model)
ptq_quantizer = 8da4wQuantizer()
ptq_quantizer.quantize(model)

I feel it's better to call the complete QAT flow rather than switch quantizers in the middle.
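
For concreteness, a minimal sketch of that flow using the quantizer named in the config earlier in this thread; the import path is the torchtune alias shown in that config, and `train_fn` is a placeholder for your finetuning loop:

```python
from torch import nn
from torchtune.utils.quantization import Int8DynActInt4WeightQATQuantizer

def qat_flow(model: nn.Module, train_fn) -> nn.Module:
    # prepare(): insert fake-quantization logic into the model before training.
    quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)
    model = quantizer.prepare(model)
    # Finetune with fake quantization enabled.
    train_fn(model)
    # convert(): replace the fake-quant modules with actually quantized ones.
    return quantizer.convert(model)
```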

@andrewor14 (Contributor Author) commented:

Update: I think I've addressed all the comments and this PR is ready from my side. Do you have other comments @ebsmothers @kartikayk?

Note that landing is blocked right now on the torchao 0.3 release (currently scheduled for 6/26). This is because QAT was only added in torchao 0.2, but the following error was not fixed until torchao 0.3, so there's no other way to get the QAT feature unless we want to rely on nightlies, which we don't.

  File "/__w/_temp/conda_environment_9553124795/lib/python3.8/site-packages/torchtune/modules/common_utils.py", line 12, in <module>
    from torchao.dtypes.nf4tensor import NF4Tensor
  File "/__w/_temp/conda_environment_9553124795/lib/python3.8/site-packages/torchao/__init__.py", line 14, in <module>
    from . import _C
ImportError: /__w/_temp/conda_environment_9553124795/lib/python3.8/site-packages/torchao/_C.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZN5torch3jit11parseSchemaERKSs

@codecov-commenter commented Jun 27, 2024

Codecov Report

Attention: Patch coverage is 9.32836% with 243 lines in your changes missing coverage. Please review.

Project coverage is 64.89%. Comparing base (52e3283) to head (c110f45).
Report is 1 commits behind head on main.

| Files | Patch % | Lines |
|---|---|---|
| recipes/qat_distributed.py | 0.00% | 212 Missing ⚠️ |
| tests/recipes/test_qat_distributed.py | 48.57% | 18 Missing ⚠️ |
| torchtune/utils/quantization.py | 46.15% | 7 Missing ⚠️ |
| recipes/quantize.py | 0.00% | 5 Missing ⚠️ |
| tests/recipes/test_configs.py | 66.66% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #980      +/-   ##
==========================================
- Coverage   66.69%   64.89%   -1.81%     
==========================================
  Files         184      186       +2     
  Lines        8578     8838     +260     
==========================================
+ Hits         5721     5735      +14     
- Misses       2857     3103     +246     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@ebsmothers merged commit fd7c15f into pytorch:main on Jun 27, 2024
29 checks passed
maximegmd pushed a commit to maximegmd/torchtune that referenced this pull request Jul 13, 2024