Autoquant
Summary: Adding autoquantization functionality. Using the autoquant api
we can test kernel speeds and pick the best quantization type (or no
quantization) for each layer.

Test Plan: python test/test.py -k "autoquant"

also tested on SAM and SDXL
pytorch-labs/segment-anything-fast#114
HDCharles/sdxl-fast@8d9942a

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: fddbaf2c203a1745e8a84980f778c45162576cbc
Pull Request resolved: #38
HDCharles committed Mar 19, 2024
1 parent 969038f commit 71300c3
Showing 8 changed files with 514 additions and 13 deletions.
41 changes: 31 additions & 10 deletions README.md
@@ -43,29 +43,50 @@ The following apis use quantized [tensor subclasses](https://pytorch.org/docs/st

This tensor subclass method of quantization is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible.

-### A8W8 Dynamic Quantization
+### Autoquantization

-The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this
-converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul.
-
-Example
+The `autoquant` api can be used to quickly and accurately quantize your model. When used as in the example below, the api first identifies the shapes
+of the activations that each linear layer sees. It then benchmarks those shapes across the different types of quantized and non-quantized layers and picks the fastest one, attempting to take fusions into account where possible. Finally, once the best class is found for each layer, it swaps in that class for the layer's weight. Currently this api chooses between no quantization, int8 dynamic quantization, and int8 weight-only quantization for each layer.

```
import torch
-from torchao.quantization import quant_api
+import torchao
# inductor settings which improve torch.compile runtime for quantized modules
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True
# some user model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
-# convert linear modules to quantized linear modules
-quant_api.change_linear_weights_to_int8_dqtensors(model)
+# perform autoquantization
+torchao.autoquant(model, (input))
# compile the model to improve performance
model = torch.compile(model, mode='max-autotune')
model(input)
```
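To make the "benchmark and pick the fastest" step concrete, here is a rough, editor-added sketch of the idea. It times whole candidate models rather than individual layers, so it is much coarser than what `autoquant` does internally, and `fastest_variant` plus the `transforms` dict are illustrative names, not torchao apis.

```
import copy
import torch
from torch.utils import benchmark

def fastest_variant(model, example_input, transforms):
    # transforms maps a label to a function that rewrites the model's linear
    # weights in place (e.g. the change_linear_weights_to_* apis), or None
    # for the unquantized baseline.
    timings = {}
    for name, transform in transforms.items():
        candidate = copy.deepcopy(model)
        if transform is not None:
            transform(candidate)
        candidate = torch.compile(candidate, mode="max-autotune")
        candidate(example_input)  # warm up and trigger compilation
        timer = benchmark.Timer(
            stmt="candidate(x)",
            globals={"candidate": candidate, "x": example_input},
        )
        timings[name] = timer.blocked_autorange().median
    return min(timings, key=timings.get), timings
```

`autoquant` goes further than this sketch: it records the activation shapes each linear layer actually sees and benchmarks the candidate weight subclasses per layer, so different layers can end up with different quantization (or none at all).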


### A8W8 Dynamic Quantization

The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this
converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul.

Example

```
# some user model and example input
...
# convert linear modules to quantized linear modules
torchao.change_linear_weights_to_int8_dqtensors(model)
# compile the model to improve performance
...
```

This technique works best when the torch._inductor.config.force_fuse_int_mm_with_mul option is enabled. This allows fusion of the int8*int8 -> int32 matmul and subsequent mul op, thereby avoiding materialization of the int32 intermediary tensor.
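For readers who want a runnable version of the abbreviated example above, a self-contained sketch might look like the following; the toy model, shapes, and dtype are placeholders chosen for illustration.

```
import torch
import torchao

# enable the int8 matmul + mul fusion described above before compiling
torch._inductor.config.force_fuse_int_mm_with_mul = True

# some toy model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
x = torch.randn(32, 32, dtype=torch.bfloat16, device="cuda")

# convert linear weights to the dynamically quantized int8 tensor subclass
torchao.change_linear_weights_to_int8_dqtensors(model)

# compile the model to improve performance
model = torch.compile(model, mode="max-autotune")
model(x)
```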


@@ -81,7 +102,7 @@ Example
...
# convert linear modules to quantized linear modules
-quant_api.change_linear_weights_to_int8_woqtensors(model)
+torchao.change_linear_weights_to_int8_woqtensors(model)
# compile the model to improve performance
...
@@ -102,7 +123,7 @@ Example
...
# convert linear modules to quantized linear modules
-quant_api.change_linear_weights_to_int4_woqtensors(model)
+torchao.change_linear_weights_to_int4_woqtensors(model)
# compile the model to improve performance
...
Empty file added __init__.py
66 changes: 66 additions & 0 deletions test/test.py
@@ -24,6 +24,7 @@
change_linear_weights_to_int8_woqtensors,
change_linear_weights_to_int4_woqtensors,
_replace_with_custom_fn_if_matches_filter,
do_autoquant
)
from torchao.quantization.quant_primitives import (
dequantize_per_channel,
@@ -54,6 +55,13 @@
_fqn_to_op_to_shape_to_count,
LoggingTensorMode,
)
from torchao.quantization.autoquant import (
AQInt8DynamicallyQuantizedLinearWeight,
AQWeightOnlyQuantizedLinearWeight,
AQWeightOnlyQuantizedLinearWeight2,
    AQWeightOnlyQuantizedLinearWeight3,
)
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
import os

@@ -880,6 +888,36 @@ def test_int8_weight_only_quant_subclass(self):
Int8WeightOnlyQuantizedLinearWeight.from_float, 40, test_dtype
)

def test_aq_int8_dynamic_quant_subclass(self):
for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
self._test_lin_weight_subclass_impl(
AQInt8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype
)

def test_aq_int8_weight_only_quant_subclass(self):
for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
self._test_lin_weight_subclass_impl(
AQWeightOnlyQuantizedLinearWeight.from_float, 35, test_dtype
)

def test_aq_int8_weight_only_quant_2_subclass(self):
for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
self._test_lin_weight_subclass_impl(
AQWeightOnlyQuantizedLinearWeight2.from_float, 35, test_dtype
)

def test_aq_int8_weight_only_quant_3_subclass(self):
for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
self._test_lin_weight_subclass_impl(
AQWeightOnlyQuantizedLinearWeight3.from_float, 35, test_dtype
)

def test_int4_weight_only_quant_subclass(self):
self._test_lin_weight_subclass_impl(
Int4WeightOnlyQuantizedLinearWeight.from_float, 10, test_shape=[1, 1024, 8]
@@ -1195,6 +1233,34 @@ def test_on_dummy_distilbert(self):
print("sqnr_pt_quant", sqnr_pt_quant)
self.assertTrue(sqnr_sq >= 8.0)

class TestAutoQuant(unittest.TestCase):
def test_autoquant(self):
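# inductor / dynamo configuration for this test; automatic dynamic shapes are disabled so each (m, k, n) case below is compiled at its exact shape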
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.use_mixed_mm = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._dynamo.config.automatic_dynamic_shapes = False

for m,k,n in [
(1, 1024, 1024),
(64, 1024, 1024),
(2**15, 1024, 1024),
(1, 1024, 4096),
(64, 1024, 4096),
(1, 4096, 1024),
(64, 4096, 1024),
(4096, 4096, 1024),
]:
example_input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
model = torch.nn.Sequential(
torch.nn.ReLU(),
torch.nn.Linear(k,n),
torch.nn.ReLU(),
).to("cuda").to(torch.bfloat16)
out = model(example_input)
do_autoquant(model, example_input)
out2 = model(example_input)
sqnr = SQNR(out, out2)
self.assertTrue(sqnr >= 30)

if __name__ == "__main__":
unittest.main()
24 changes: 24 additions & 0 deletions torchao/__init__.py
@@ -0,0 +1,24 @@
from torchao.quantization import (
apply_weight_only_int8_quant,
apply_dynamic_quant,
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,
change_linear_weights_to_int4_woqtensors,
swap_conv2d_1x1_to_linear,
autoquant,
change_linears_to_autoquantizable,
change_autoquantizable_to_quantized,
)

__all__ = [
"apply_weight_only_int8_quant",
"apply_dynamic_quant",
"change_linear_weights_to_int8_dqtensors",
"change_linear_weights_to_int8_woqtensors",
"change_linear_weights_to_int4_woqtensors",
"swap_conv2d_1x1_to_linear"
"safe_int_mm",
"autoquant",
"change_linears_to_autoquantizable",
"change_autoquantizable_to_quantized",
]
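The two `change_*` helpers exported above are the lower-level pieces behind `autoquant`. Based on their names and the README description, and purely as an editor's reading rather than a documented contract, the flow they suggest is roughly the following; the toy model, input, and the no-argument calls are assumptions.

```
import torch
import torchao

# hypothetical usage sketch; exact signatures are assumptions
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda().to(torch.bfloat16)
example_input = torch.randn(16, 64, dtype=torch.bfloat16, device="cuda")

# 1. wrap linear weights so the activation shapes they see can be recorded
torchao.change_linears_to_autoquantizable(model)

# 2. run representative inputs through the model to log those shapes
model(example_input)

# 3. benchmark the candidate subclasses for each layer and swap in the fastest
torchao.change_autoquantizable_to_quantized(model)
```

For most users the single `torchao.autoquant(model, input)` call shown in the README is the more convenient entry point.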
3 changes: 3 additions & 0 deletions torchao/quantization/__init__.py
@@ -25,6 +25,9 @@
"dynamically_quantize_per_channel",
"dequantize_per_tensor",
"dequantize_per_channel",
"autoquant",
"change_linears_to_autoquantizable",
"change_autoquantizable_to_quantized",
"quant_int8_dynamic_linear",
"quant_int8_matmul",
"quant_int8_dynamic_per_token_linear",