[moe training] refactor configs, recipes; support converting linears + grouped gemms in a single quantize_() call #3862
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3862
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 6df211d with merge base 920c502. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Force-pushed from 326d30b to 62159aa
Force-pushed from c7556e0 to 49ab389
torchao/quantization/quant_api.py
Outdated
# Once we've updated all the transform functions to take in a custom_param kwarg, we can delete this object and the subsequent check
# TODO see https://github.com/pytorch/ao/issues/3252 for more details
from torchao.prototype.moe_training.config import (
    FP8GroupedMMConfig,
hmm prototype in main api
cc @jerryzh168 is that okay?
I think we shouldn't do this. There are some imports that started here and later moved to prototype, but we shouldn't add new ones here.
@jerryzh168 any thoughts on how to allowlist the training configs to use quantize_() + FqnToConfig, if we are hesitant about importing from prototype here?
also, before this we could drop some of the configs that have moved to prototype
Yeah, I was planning on taking a stab at #3252 tonight, will tag you guys in the PRs.
If you have a list of configs that have been moved to prototype @jerryzh168, please lmk.
yeah check this: #2752, we can remove code for prototype configs, and also remove code for v1 config (basically everything implemented with AQT), and in the end we can remove AQT
sounds good, thanks @jcaip lmk when that's ready and i'll rebase on top of it
cc @danielvegamyhre if you rebase on top of #3894 you should no longer need to import for CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS
rebased on top of #3894 now that it has landed, and removed the prototype import in quant_api.py.
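For readers following the thread: the allowlist mentioned above (CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS) can be sketched roughly as follows. The config classes and the helper function here are illustrative stand-ins, not the actual torchao definitions:

```python
# Rough sketch of the allowlist check discussed in this thread: quantize_()
# applies a config to custom (non-linear) params only if the config's type
# is explicitly allowed. All names below are illustrative stand-ins.
class Float8LinearConfig:
    """Stand-in for a linear-only quantization config."""

class FP8GroupedMMConfig:
    """Stand-in for the grouped-GEMM training config from this PR."""

# Allowlist of config types permitted for custom-param quantization.
CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS = {FP8GroupedMMConfig}

def supports_custom_param(config) -> bool:
    # Membership test by exact type, mirroring an allowlist-style check.
    return type(config) in CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS

print(supports_custom_param(FP8GroupedMMConfig()))  # True
print(supports_custom_param(Float8LinearConfig()))  # False
```

Rebasing on #3894 removed the need to import this allowlist from prototype code, which resolved the concern above.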
planning to land this once CI is green
Force-pushed from 49ab389 to ecc59d0
Force-pushed from ecc59d0 to 6df211d
Summary
To prepare the MoE training prototype for promotion to stable, we need the dev experience to be consistent with the rest of torchao, namely the fp8/mxfp8 linear training.

This PR:
- refactors configs and recipes
- supports converting linears and grouped GEMMs in a single `quantize_()` call

These files contain the important changes; everything else is peripheral:
- torchao/prototype/moe_training/config.py
- torchao/prototype/moe_training/conversion_utils.py
- torchao/prototype/moe_training/mxfp8_grouped_mm.py
- torchao/prototype/moe_training/fp8_grouped_mm.py
- torchao/prototype/moe_training/tensor.py
- torchao/quantization/quant_api.py

New model conversion API
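A self-contained sketch of how an FqnToConfig-driven `quantize_()` call can route params to configs by fully qualified name. `quantize_` and `FqnToConfig` below are simplified stand-ins for the torchao APIs, and the matching details are assumptions:

```python
# Self-contained sketch of the fqn-pattern dispatch behind a quantize_()-
# style call. Stand-in classes only; the real torchao API may differ.
import fnmatch
from dataclasses import dataclass, field

@dataclass
class FP8GroupedMMConfig:
    recipe: str = "rowwise"  # assumed scaling recipe name

@dataclass
class FqnToConfig:
    # Maps fqn glob patterns to quantization configs.
    fqn_to_config: dict = field(default_factory=dict)

def quantize_(params: dict, config: FqnToConfig) -> None:
    """Tag each param whose fully qualified name matches a pattern."""
    for fqn in list(params):
        for pattern, cfg in config.fqn_to_config.items():
            if fnmatch.fnmatch(fqn, pattern):
                # Stand-in for swapping in a quantized tensor subclass.
                params[fqn] = (params[fqn], cfg)

# Toy "model": param fqn -> placeholder weight.
params = {
    "layers.0.attn.wq.weight": "w",
    "layers.0.experts.w1": "w",
}
quantize_(params, FqnToConfig({"*experts*": FP8GroupedMMConfig()}))
print(params["layers.0.experts.w1"][1].recipe)  # rowwise
print(params["layers.0.attn.wq.weight"])        # unchanged: 'w'
```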
New recipes
New configs
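The config definitions don't render in this view, but a plausible shape, based on the FP8GroupedMMConfig name seen in the review thread, would be something like the following (the enum, its members, and the field names are assumptions):

```python
# Hypothetical shape for the new MoE training configs. Only the name
# FP8GroupedMMConfig appears in the review thread; everything else
# (the recipe enum, field names) is an assumption for illustration.
from dataclasses import dataclass
from enum import Enum

class GroupedMMRecipe(Enum):  # assumed recipe enum
    FP8_ROWWISE = "fp8_rowwise"
    MXFP8 = "mxfp8"

@dataclass(frozen=True)
class FP8GroupedMMConfig:
    recipe: GroupedMMRecipe = GroupedMMRecipe.FP8_ROWWISE

cfg = FP8GroupedMMConfig()
print(cfg.recipe.value)  # fp8_rowwise
```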
Testing
Added new test: `pytest test/prototype/moe_training/test_fqn_to_config.py`

Existing tests passing: `./test/prototype/moe_training/test_everything.sh`

Detailed change list (optional read)
Configuration Refactoring
Created new config.py file containing all MoE training configurations:
`torch._dynamo.nonstrict_trace` mode, which is in turn needed to accept pre-quantized MXTensor inputs and regular torch.Tensor outputs and backward input gradients

Code Organization
Split grouped GEMM autograd func implementations:
Refactored conversion_utils.py:
Updates to tensor.py:
Support linear and grouped MM conversion in a single quantize_() call:
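To illustrate the idea named in this heading with stand-in logic (not the torchao implementation): a single fqn-keyed config mapping can cover both regular 2D linear weights and 3D grouped-GEMM expert weights in one pass:

```python
# Illustrative only: one pass over params plans conversion for both linear
# weights and grouped-GEMM expert weights, routed by fqn pattern. The
# dict-based "model" and config names are stand-ins for this sketch.
import fnmatch

def convert_all(param_shapes: dict, fqn_to_config: dict) -> dict:
    """Return fqn -> config for every param matching a pattern."""
    plan = {}
    for fqn in param_shapes:
        for pattern, cfg in fqn_to_config.items():
            if fnmatch.fnmatch(fqn, pattern):
                plan[fqn] = cfg
                break  # first matching pattern wins (assumed precedence)
    return plan

param_shapes = {
    "attn.wq.weight": (4096, 4096),  # 2D linear weight
    "experts.w1": (8, 4096, 14336),  # 3D grouped-GEMM expert weight
}
plan = convert_all(param_shapes, {
    "attn.*": "LinearConfig",        # placeholder config names
    "experts.*": "GroupedMMConfig",
})
print(plan)
```

The point of the PR is that both entries are handled by one `quantize_()` call rather than separate conversion helpers per layer type.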
API Changes: