[RFC] Tensor Subclass based Quantization API #391
Comments
Regarding 1, apart from the feedback I gave in #384, I'm starting to think of another alternative:

    quantizer = Int4WeightOnlyQuantizer(groupsize=32)
    quantizer.quantize(model)

But then this feels like the old API. Personally I don't really like a function returning a function, like the current one. Another option is to expose:

    from functools import partial
    quantize(model, partial(apply_int4wo_quant, groupsize=32))

Also, since the quantization is in-place, I think it's good to use …
For the manual API, why have both a string and a …
Is there a tutorial or end-to-end example of how to compose these APIs to implement a non-trivial quantization method (e.g., AWQ, GPTQ, etc.) and a specialized deployment layout (e.g., Marlin)? Basically a reference impl of how these tools can be used to facilitate the translation of research ideas to deployment-ready libraries. If not, happy to work on one.
the quantizer API is actually what I have been thinking about before as a "Unified Quantization API": https://github.com/pytorch/ao/blob/main/torchao/quantization/unified.py, and these two APIs will cover most of the current quant flows. It's also used by the QAT prototype (ao/torchao/quantization/prototype/qat.py, line 22 at d0af941).
the partial function idea has been raised in our meetings before as well, but that also doesn't seem very straightforward to use. For now I'm planning to just use … Also, in the ideal future I think we'd expect modeling users to just use autoquant and not worry about all these details.
so the motivation for the string is that people don't need to import anything to use it; it's just a simple shortcut, and we'll make sure to align the names.
Not yet. My understanding is that this doc talks about how we build the fundamental "dtype" of quantization, which can serve as a building block for more sophisticated quantization methods that use the "dtype" as a data representation. I'm planning to put up an example of static quant (with module swap) that could help demonstrate how these other techniques (e.g. ones that require calibration) can be implemented in similar ways. Please feel free to work on a tutorial that shows what a real-world, end-to-end quantization example looks like using the "dtype" we build with tensor subclasses in this doc. We also plan to build out hqq with this design (#255, cc @HDCharles), although that one also doesn't require calibration.
But they are already importing the …
yeah, we are thinking of just removing these for now; it would be better for people to also see the docstrings for these things, and an extra import doesn't seem to be a big issue.
About subclasses: I hope there will still be a way to (when needed) register custom fused kernels which do e.g. q-someop-dq in a fused way, without separate kernel launches for q and dq. I know this type of graph matching is possible with torch.compile, but I hope that the explicit introduction of subclasses (seemingly used mainly for representational/expressiveness/dispatch purposes) will not make this more complicated. Also, I'm hoping that it will work nicely with profiling/tracing, so we know exactly what kernel is getting invoked and exactly where any q/dq is happening (especially for autoquant regimes). This is kind of similar to what was originally done with the quint8 dtype, right? (Except now it will allow user-powered extension, and dispatch is based on the subclass type instead of the dtype.)
yeah I think we should still be able to register inductor fusion passes. One thing to note, though: q/dq ops are no longer large ops in the torch.compile path; we are planning to keep them as smaller aten ops (sub/mul etc.) so that they can participate in normal inductor optimizations directly, so the optimization story will be a bit different for inductor/torch.compile I think. However, we are preserving q/dq ops as high-level ops for executorch (the export path), since the current executorch backends need to work with patterns like (dq -> fp32 op -> q); this is WIP in #434.
yeah we can definitely provide additional information on what kernel is picked for autoquant, cc @HDCharles
yes, this is similar to quint8, except it's built in Python with the tensor subclass extension point; this allows us to stay out of core and have faster iteration speed as well. For dispatch, I feel it could also continue to use dtype, after we sort out the dtype story: #442.
Summary: Addressing feedback for the `quantize` API from pytorch#391 (comment). This is an API that changes the model in place, so we want to change the name to reflect that. In-place model quantization is especially important for LLMs, since it can be hard to load the whole model into memory; we typically load the model to the meta device and then load the quantized weights.
Test Plan:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py
Based on the example, it seems like it would be the property of DTypeTensor that decides whether to use q-dq or not, right?
So what I understand from this proposal, as far as wrapping LayoutTensor and DTypeTensor is concerned:

A. Static quantization (both activation and weights are quantized)

It is not clear how the proposed API addresses 1, but I presume you have ideas, so I will assume it will work.

Tensor subclass, as I understand it, does/can do two things: 1) override the representation of the tensor, e.g. linear.weight changed from torch.Tensor to DTypeTensor, and 2) change the dispatch behavior to dictate how an op with DTypeTensor should be executed.

On the DTypeLayout: I feel that each backend or kernel that has its own special layout for execution should have its own tensor subclass; however, this can also result in proliferation, e.g. DTypeLayoutCUDA, DTypeLayoutCUDAMySpecialPacking, DTypeLayoutMetalDefault, etc. I actually liked the PT2E workflow in this regard, where the representation was canonical and the execution semantics, arising from weight packing etc., were done as a separate transform. If I were to apply the same thinking here, I would say that for 4-bit there is DTypeTensor and DTypeDefaultLayout, and subsequent transforms can replace the tensor subclass with their backend-specific tensor subclass.

Separate from the above: regarding the comment on using q-dq based dispatch vs. a fused op, I think we can allow overriding behavior where users can plug in their own implementation, including custom fused ops, for a specific DTypeTensor subclass that uses a specific DTypeLayout tensor.
yeah this is correct
yeah working on an example for this right now
I should probably add more docs for this one. Right now it's implemented by applying a LinearActQuantizedTensor (ao/torchao/quantization/quant_api.py, lines 355 to 356 at a895699): when dispatching to the linear op, we'll apply the quantization function input_quant_func to the input and then continue the dispatch (ao/torchao/quantization/subclass.py, line 657 at a895699; ao/torchao/dtypes/affine_quantized_tensor.py, lines 550 to 554 at a895699).
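(A hedged, standalone sketch of the behavior described above; the class below mirrors the described dispatch but is not the actual torchao implementation, and attribute names such as original_weight_tensor and input_quant_func are illustrative.)

```python
import torch
import torch.nn.functional as F

class LinearActQuantizedTensorSketch(torch.Tensor):
    """Wraps an (already quantized) weight plus an activation quantization function."""

    @staticmethod
    def __new__(cls, original_weight_tensor, input_quant_func):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            original_weight_tensor.shape,
            dtype=original_weight_tensor.dtype,
            device=original_weight_tensor.device,
            requires_grad=False,
        )

    def __init__(self, original_weight_tensor, input_quant_func):
        self.original_weight_tensor = original_weight_tensor
        self.input_quant_func = input_quant_func

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            input_tensor, weight_tensor = args[0], args[1]
            bias = args[2] if len(args) > 2 else None
            # quantize the activation first, then continue dispatch with the inner weight
            qinput = weight_tensor.input_quant_func(input_tensor)
            return F.linear(qinput, weight_tensor.original_weight_tensor, bias)
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)
```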
also I want to highlight that dynamic quant and static quant are not considered purely a dtype problem, since they also involve flows (how do I convert my model to use these quantized tensors?). I'm also working on giving more details/examples of how to do that as well.
yeah I think so, the user should be able to customize what they would like to see by implementing a new LayoutTensor type, I think, although I guess the difference here is that the user has to reason through the different dispatch layers to figure out what final representation they will see in the end, like the dynamic quant example.
@jerryzh168 please note that my questions/responses are not motivated by whether it works for executorch or not. My comment on canonical representation was to borrow the same concept from PT2E, where quantization and execution of quantized ops are separated. In the currently proposed APIs this is not the case, and that's what I was highlighting.
And by this I mean for eager mode, not for export. Basically, in the exported graph there is a) quant and b) lowering. What is the equivalent of that in the eager mode subclass-based API, and is it useful to have that?
@kimishpatel, I see. Yeah, I think the separation of quant and lowering makes sense for the executorch stack, but for eager it is not really applicable, since in eager people would just expect to quantize a model and get acceleration; requiring the eager mode use case to do an extra lowering step seems to change the UX for eager mode. What do you think?
Status: Draft
Updated: 06/17/2024
Objective
In this doc we'll talk about the Tensor subclass based quantization API for modeling users and developers.
Modeling User API
Modeling users are people who use quantization APIs to quantize their model for speedups, memory savings, power savings, etc. Our main goal for the modeling user API is for it to be easy to use without requiring a full understanding of the technical details.
We expect users to use two types of APIs: (1) a manual quantization API with a direct API call, and (2) an automatic quantization API based on some objectives. It looks like the following:
1. Manual API call
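A hedged sketch of what the manual API call could look like, based on the quantize function and int4 weight-only option discussed in the comments above (the exact names, e.g. int4_weight_only and group_size, and the available string shortcuts are assumptions and may differ from what actually ships):

```python
import torch
from torchao.quantization import quantize, int4_weight_only  # names assumed, see note above

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).to("cuda")

# apply int4 weight-only quantization to all eligible linear layers, in place
quantize(model, int4_weight_only(group_size=32))

# a string shortcut was also discussed, e.g.:
# quantize(model, "int4_weight_only")
```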
2. autoquantization
autoquant is a tool that automatically quantizes the eligible layers with a type of quantization (int8 weight only, int8 dynamic quant, int4 weight only, or new dtypes) based on the performance of quantizing each individual layer. We'll have APIs for people to add new dtypes to be searched by the tool.
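A hedged usage sketch of autoquant (the torchao.autoquant entry point and its interaction with torch.compile are assumed here and may differ in detail):

```python
import torch
import torchao

# model: an existing nn.Module to be quantized
model = model.to(torch.bfloat16).to("cuda")
example_input = torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda")

# wrap with torch.compile and let autoquant choose a quantization per eligible layer
model = torchao.autoquant(torch.compile(model, mode="max-autotune"))
model(example_input)  # the first call benchmarks the options and finalizes the choice
```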
Developer API
Developers could be people doing research to figure out the best quantization algorithm, or people supporting dtypes for emerging hardware.
Prerequisites
We are relying on tensor subclasses (and also torch.compile) for our developer-facing API; we'll update this section with more publicly available tutorials.
Some externally available resources:
Why Tensor Subclass?
There are multiple ways people can implement quantization techniques or new dtypes; the main motivations for us to recommend the tensor subclass based approach are three things:
(1). It's natural for quantization to be modeled as a dtype conversion, so implementing it with tensor subclasses means we are not introducing new concepts but reusing concepts like dtype and layout that already exist in PyTorch core
(2). Since tensor subclasses intercept computation at the torch function or aten op level, as long as the same function/operator is used, we will be able to quantize the model. This allows models that use variants of native modules (e.g. a slightly modified version of nn.Linear) to still be compatible with quantization
(3). Tensor subclasses are also the approach adopted by other techniques like sparsity and distributed, so implementing quantization or dtype conversion with tensor subclasses makes it easier to compose with these techniques
Example Code for a new Quantization Technique or DType
Please feel free to start with https://colab.research.google.com/drive/1jqC53MwiW9dSiPS-a6hng_yo0ywdc3nH#scrollTo=Aj9ii4darSRA for an end-to-end working example that combines everything we talk about here, and come back to this doc for clarifications and documentation.
Basic Structure
A tensor subclass needs to define a few basic methods: __new__, __init__, __tensor_flatten__, __tensor_unflatten__, and also the dispatch functions for torch functions (__torch_function__) and aten ops (__torch_dispatch__). Here is an example of the basic structure:
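The original code block is not reproduced above, so here is a minimal sketch of that basic structure, assuming a hypothetical MyDTypeTensor that stores quantized int data plus a scale (field names and the toy quantization math are illustrative, not the exact torchao implementation):

```python
import torch

class MyDTypeTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, int_data, scale, shape, **kwargs):
        kwargs["dtype"] = kwargs.get("dtype", scale.dtype)
        kwargs["device"] = int_data.device
        kwargs["requires_grad"] = False
        # the wrapper carries the logical (dequantized) shape and dtype
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

    def __init__(self, int_data, scale, shape, **kwargs):
        self.int_data = int_data
        self.scale = scale

    def __tensor_flatten__(self):
        # names of inner tensors (for tracing/serialization) plus non-tensor metadata
        return ["int_data", "scale"], [self.shape]

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride):
        (shape,) = tensor_attributes
        return cls(tensor_data_dict["int_data"], tensor_data_dict["scale"], shape)

    @classmethod
    def from_float(cls, float_tensor):
        # toy symmetric int8 quantization, for illustration only
        scale = float_tensor.abs().amax() / 127.0
        int_data = torch.clamp(torch.round(float_tensor / scale), -128, 127).to(torch.int8)
        return cls(int_data, scale, float_tensor.shape)

    def dequantize(self):
        return self.int_data.to(self.scale.dtype) * self.scale

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # intercept torch-level functions (e.g. F.linear) here; see Operator Support
        kwargs = kwargs or {}
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        # intercept aten ops (e.g. aten.detach, aten.mm) here; see Operator Support
        raise NotImplementedError(f"MyDTypeTensor: unsupported op {func}")
```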
Operator Support
There are two types of operator support: torch functions and aten ops. For torch functions (e.g. torch.nn.functional.linear), we'll need to overwrite the __torch_function__ callback in the Tensor subclass; for aten ops (e.g. torch.ops.aten.mm), we'll need to overwrite the __torch_dispatch__ callback function.

For a new dtype, we'd like people to define the following decorator and then implement the operator dispatch with it:
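A hedged sketch of that decorator and the dispatch registration (the decorator name implements and the dispatch-table mechanics mirror the pattern described here but are not the exact torchao code):

```python
import torch
import torch.nn.functional as F

_MY_DTYPE_TORCH_FUNCTIONS = {}

def implements(torch_funcs):
    """Register fn as the MyDTypeTensor implementation of the given torch functions."""
    def decorator(fn):
        for f in torch_funcs:
            _MY_DTYPE_TORCH_FUNCTIONS[f] = fn
        return fn
    return decorator

@implements([F.linear])
def _my_dtype_linear(func, types, args, kwargs):
    input_tensor, weight_tensor = args[0], args[1]
    bias = args[2] if len(args) > 2 else None
    # simplest fallback: dequantize and run the float op; optimized kernels can be
    # dispatched here instead (see the Optimized Operators section)
    return F.linear(input_tensor, weight_tensor.dequantize(), bias)

# MyDTypeTensor.__torch_function__ then consults the table:
#
#     @classmethod
#     def __torch_function__(cls, func, types, args=(), kwargs=None):
#         kwargs = kwargs or {}
#         if func in _MY_DTYPE_TORCH_FUNCTIONS:
#             return _MY_DTYPE_TORCH_FUNCTIONS[func](func, types, args, kwargs)
#         with torch._C.DisableTorchFunctionSubclass():
#             return func(*args, **kwargs)
```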
What ops do we need to overwrite? This depends on the model we are trying to quantize; commonly overwritten ops are:
torch_function: torch.nn.functional.linear
torch_dispatch: torch.ops.aten.addmm.default, torch.ops.aten.mm.default, torch.ops.aten.detach.default, torch.ops.aten.t.default
You can also find the ops that can be overwritten in torch_function or torch_dispatch with the following code. You can start with a model that you want to optimize, overwrite just the important ops like linear first, and gradually expand the coverage until the test runs and you get the expected optimized generated code (see the Optimized Operators section for more details). We are still working on a table that describes, for each feature, which operators need to be supported.
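The exact snippet is not reproduced above; a hedged way to do this discovery is to run the model once under logging modes that print every torch function and aten op, for example:

```python
import torch
from torch.overrides import TorchFunctionMode
from torch.utils._python_dispatch import TorchDispatchMode

class TorchFunctionLogger(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        print("torch_function:", func)
        return func(*args, **(kwargs or {}))

class AtenOpLogger(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print("torch_dispatch:", func)
        return func(*args, **(kwargs or {}))

model = torch.nn.Sequential(torch.nn.Linear(16, 16))
x = torch.randn(2, 16)

with TorchFunctionLogger():
    model(x)   # prints torch-level functions, e.g. torch.nn.functional.linear

with AtenOpLogger():
    model(x)   # prints aten ops, e.g. aten.addmm.default, aten.t.default
```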
Optimized Operators
Optimized operators for cpu/cuda/mps can be implemented through https://github.com/pytorch/ao/tree/main/torchao/csrc (e.g. int4 CUDA) and are accessible through torch.ops.my_custom_op.
For dispatching to optimized kernels for cpu/cuda/mps devices, we can check the dispatch conditions in torch_function or torch_dispatch and dispatch to the target operators, for example in ao/torchao/dtypes/aqt.py, lines 348 to 355 at cbc74ee:
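Those lines are not inlined here, so below is a hedged sketch of the kind of condition check they perform; the attributes on weight_tensor and the use of torch.ops.aten._weight_int4pack_mm are assumptions for illustration:

```python
import torch
import torch.nn.functional as F

def _quantized_linear_op(input_tensor, weight_tensor, bias):
    # check dispatch conditions (device, packing format, ...) before calling an
    # optimized kernel, and fall back to dequantize + float linear otherwise
    if (
        weight_tensor.device.type == "cuda"
        and getattr(weight_tensor, "extended_layout", None) == "tensor_core_tiled"  # assumed field
    ):
        out = torch.ops.aten._weight_int4pack_mm(
            input_tensor,
            weight_tensor.packed_weight,      # assumed field
            weight_tensor.group_size,         # assumed field
            weight_tensor.scale_and_zero,     # assumed field
        )
        return out + bias if bias is not None else out
    return F.linear(input_tensor, weight_tensor.dequantize(), bias)
```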
Packing/Layout
Sometimes the quantized weights have to be packed in order to yield optimal performance. For this we want to extend the "layout" concept in Tensor and introduce an indirection for tensor data storage; see #278 for more details.
Native tensors have a hardcoded list of layouts: https://github.com/pytorch/pytorch/blob/647815049ec28a72dc1bb6a977791927bba058d5/c10/core/Layout.h#L11. The most common one is the strided layout, which provides a strided, multi-dimensional view of storage; we also have some sparse and mkldnn layouts.
The idea of packing the tensor into different formats fits nicely with the layout concept, which is why we want to reuse it for packing. The extension of layout can be achieved with Python-level tensor subclasses, without modifying C++ PyTorch core code.
Here is an example (see notebook for full code):
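The notebook code is not inlined here; the following is a hedged, self-contained sketch of the packing idea: the same int4 data stored in a plain (unpacked) layout versus a packed layout that fits two values into each byte. In the actual design these layout classes would themselves be tensor subclasses selected through the layout indirection; plain classes are used here only to keep the illustration short.

```python
import torch

class PlainInt4Layout:
    """Plain layout: one int4 value (range [-8, 7]) per int8 element."""
    def __init__(self, int_data, scale):
        self.int_data = int_data
        self.scale = scale

    def to_plain(self):
        return self.int_data, self.scale

class PackedInt4Layout:
    """Packed layout: two int4 values per uint8 element (assumes an even last dim)."""
    def __init__(self, packed, scale, orig_cols):
        self.packed = packed
        self.scale = scale
        self.orig_cols = orig_cols

    @classmethod
    def from_plain(cls, int_data, scale):
        u = (int_data + 8).to(torch.uint8)            # shift [-8, 7] -> [0, 15]
        packed = (u[..., ::2] << 4) | u[..., 1::2]    # two nibbles per byte
        return cls(packed, scale, int_data.shape[-1])

    def to_plain(self):
        hi = (self.packed >> 4).to(torch.int8) - 8
        lo = (self.packed & 0xF).to(torch.int8) - 8
        out = torch.stack([hi, lo], dim=-1).reshape(*self.packed.shape[:-1], -1)
        return out[..., : self.orig_cols], self.scale
```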
Flow
After the tensor subclass is implemented, we can also wrap that into factory functions, e.g.
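A hedged sketch of such a factory function, assuming the from_float constructor from the MyDTypeTensor sketch earlier in this doc:

```python
import torch

def to_my_dtype(float_weight: torch.Tensor) -> "MyDTypeTensor":
    """Convert a high-precision weight into the quantized tensor subclass."""
    return MyDTypeTensor.from_float(float_weight)

# typical use: swap a linear module's weight in place
# linear.weight = torch.nn.Parameter(to_my_dtype(linear.weight), requires_grad=False)
```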
For the model-level API, people can reuse torchao.quantize, which applies a tensor subclass conversion to the weights of linear layers and accepts a filtering function: https://github.com/pytorch/ao/blob/aeee551b15eebeaabf98ffab9a00addc675a12a9/torchao/quantization/quant_api.py (TODO: replace this with the torchao doc website link when that's ready)
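A hedged usage sketch of that model-level API; the exact signature of quantize (whether it takes a weight-transforming callable or a config object, and the shape of the filter function) should be checked against the linked source:

```python
import torch
from torchao.quantization.quant_api import quantize  # import path assumed

def filter_fn(module: torch.nn.Module, fqn: str) -> bool:
    # quantize only Linear modules, and (for example) skip the output head
    return isinstance(module, torch.nn.Linear) and "lm_head" not in fqn

# apply the tensor subclass conversion (e.g. the factory function above) to the
# weights of every module accepted by filter_fn, modifying the model in place
quantize(model, to_my_dtype, filter_fn)
```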
See the Modeling User API section for examples of weight only / dynamic quant / static quant model-level APIs based on the factory function.
Using torch.compile for Performance
Note: currently, we need to use the following in order to be compatible with torch.compile:

To aim for performance optimization, we should first run through torch.compile with fullgraph mode and remove any unnecessary graph breaks. You can add TORCH_LOGS="output_code" when you run the script in order to see the inductor-generated code, e.g.
TORCH_LOGS="output_code" python example.py
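A hedged sketch of the compile step described above:

```python
import torch

# model: an already quantized nn.Module; example_input: a representative input.
# fullgraph=True makes any graph break an error so it can be found and removed.
model = torch.compile(model, mode="max-autotune", fullgraph=True)
model(example_input)
```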
Serialization
This test shows how we expect save/load to work for a model quantized with the tensor subclass based API:
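The referenced test is not inlined here; below is a hedged sketch of the expected save/load flow with tensor-subclass weights (using the meta-device loading pattern mentioned in the commit message earlier in the thread; build_model is a hypothetical constructor for the same architecture):

```python
import torch

# save: the state_dict already holds the quantized tensor subclass weights
torch.save(model.state_dict(), "quantized_state_dict.pt")

# load: build an uninitialized model skeleton (e.g. on the meta device), then load
# the quantized weights; assign=True keeps the loaded subclass tensors as-is
with torch.device("meta"):
    model2 = build_model()  # hypothetical
state_dict = torch.load("quantized_state_dict.pt", weights_only=False)
model2.load_state_dict(state_dict, assign=True)
```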
What's Next