Autoquant
Summary: Adds autoquantization functionality. Using the do_quant API,
we can test kernel speeds and pick the best quantization type (or no
quantization) for each layer.

Test Plan: python test/test.py -k "autoquant"

also tested on SAM and SDXL
pytorch-labs/segment-anything-fast#114
HDCharles/sdxl-fast@8d9942a

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 94089f74edf54f8e2122e91498b25306d322f3ab
Pull Request resolved: #38
HDCharles committed Mar 19, 2024
1 parent 9c048eb commit f66419e
Showing 8 changed files with 514 additions and 14 deletions.
34 changes: 23 additions & 11 deletions README.md
@@ -1,8 +1,8 @@
# torchao: PyTorch Architecture Optimization
# torchao: PyTorch Architecture Optimization

**Note: This repository is currently under heavy development. If you have suggestions on the API or use cases you'd like to be covered, please open a GitHub issue**

The `torchao` package allows you to quantize and prune your models using native PyTorch.
The `torchao` package allows you to quantize and prune your models using native PyTorch.

The repo hosts both
1. lower precision [dtypes](./torchao/dtypes) such as nf4, uint4
@@ -38,31 +38,43 @@ pip install -e .

Typically quantization algorithms will have different schemes for how the activations and weights are quantized, so A16W8, for instance, means the activations are quantized to 16 bits whereas the weights are quantized to 8 bits. Trying out different quantization schemes in `torchao` is generally a one-line change.

### A8W8 Dynamic Quantization
### Autoquantization

```Python
The `autoquant` API can be used to quickly and accurately quantize your model. When used as in the example below, the API first identifies the shapes
of the activations that the different linear layers see. It then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take fusions into account where possible. Finally, once the best class is found for each layer, it swaps the linear layer. Currently this API chooses between no quantization, int8 dynamic quantization, and int8 weight-only quantization for each layer.

```python
import torch
from torchao.quantization import quant_api
import torchao

# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor
torch._inductor.config.force_fuse_int_mm_with_mul = True
# inductor settings which improve torch.compile performance for quantized modules
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

# Plug in your model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')

# convert linear modules to quantized linear modules
quant_api.change_linear_weights_to_int8_dqtensors(model)
# perform autoquantization
torchao.autoquant(model, (input))

# compile the model to improve performance
model = torch.compile(model, mode='max-autotune')
model(input)
```
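
The selection step described in the Autoquantization paragraph can be sketched in plain Python. This is a minimal illustration under stated assumptions, not torchao's implementation: `pick_fastest`, the candidate names, and the `bench` callable are hypothetical, and the timings are hard-coded stand-ins for real kernel benchmarks so that the example is deterministic. Weighting each shape by how often it was seen is also an assumption.

```python
from typing import Callable, Dict, Hashable, Iterable, Tuple

Shape = Tuple[int, ...]


def pick_fastest(
    candidates: Iterable[Hashable],
    shapes_seen: Dict[Shape, int],
    bench: Callable[[Hashable, Shape], float],
) -> Hashable:
    """Return the candidate with the lowest total estimated time.

    shapes_seen maps each observed activation shape to how many times it
    was seen, so frequently seen shapes weigh more in the decision.
    """
    def total_time(candidate: Hashable) -> float:
        return sum(bench(candidate, shape) * count
                   for shape, count in shapes_seen.items())

    return min(candidates, key=total_time)


# Hypothetical timings in milliseconds; a real run would measure kernels.
FAKE_TIMES = {
    ("no_quant", (1, 1024)): 0.10,
    ("int8_dynamic", (1, 1024)): 0.12,
    ("int8_weight_only", (1, 1024)): 0.06,
    ("no_quant", (64, 1024)): 0.20,
    ("int8_dynamic", (64, 1024)): 0.15,
    ("int8_weight_only", (64, 1024)): 0.22,
}


def fake_bench(candidate: Hashable, shape: Shape) -> float:
    """Look up the stand-in timing for one candidate and one shape."""
    return FAKE_TIMES[(candidate, shape)]


CANDIDATES = ["no_quant", "int8_dynamic", "int8_weight_only"]

# A layer that only ever sees batch size 1.
small_batch_choice = pick_fastest(CANDIDATES, {(1, 1024): 1}, fake_bench)

# A layer that mostly sees batch size 64.
mixed_choice = pick_fastest(
    CANDIDATES, {(1, 1024): 1, (64, 1024): 4}, fake_bench
)
```

With these made-up numbers, the choice differs per layer depending on which shapes dominate, which is the reason for benchmarking per layer rather than picking one scheme for the whole model.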


### A8W8 Dynamic Quantization

```python
# convert linear modules to quantized linear modules
torchao.change_linear_weights_to_int8_dqtensors(model)
```

### A16W8 WeightOnly Quantization

```python
quant_api.change_linear_weights_to_int8_woqtensors(model)
torchao.change_linear_weights_to_int8_woqtensors(model)
```

This technique works best when the torch._inductor.config.use_mixed_mm option is enabled. This avoids dequantizing the weight tensor before the matmul, instead fusing the dequantization into the matmul, thereby avoiding materialization of a large floating point weight tensor.
@@ -71,7 +83,7 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is
### A16W4 WeightOnly Quantization

```python
quant_api.change_linear_weights_to_int4_woqtensors(model)
torchao.change_linear_weights_to_int4_woqtensors(model)
```

Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
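
To give a feel for why int4 error can be significant, here is a small plain-Python sketch that round-trips a handful of weights through 8-bit and 4-bit grids and compares the worst-case error. It assumes simple symmetric per-tensor quantization, which is not necessarily the scheme torchao uses for int4; the function names and the sample weights are made up for illustration.

```python
def quantize_symmetric(values, bits):
    """Round-trip values through a symmetric signed integer grid with `bits` bits."""
    qmax = 2 ** (bits - 1) - 1              # 127 for int8, 7 for int4
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    dequantized = []
    for v in values:
        # Quantize: scale, round to nearest integer, clamp to the signed range.
        q = max(-qmax - 1, min(qmax, round(v / scale)))
        # Dequantize back to floating point.
        dequantized.append(q * scale)
    return dequantized


def max_abs_error(values, bits):
    """Largest absolute difference between the values and their round trip."""
    dequantized = quantize_symmetric(values, bits)
    return max(abs(a - b) for a, b in zip(values, dequantized))


# Made-up weights for illustration only.
weights = [0.013, -0.42, 0.27, 0.08, -0.19, 0.35, -0.05, 0.5]

err_int8 = max_abs_error(weights, 8)
err_int4 = max_abs_error(weights, 4)
```

The int4 grid has only 16 levels against 256 for int8, so its step size, and with it the rounding error, is roughly 18 times larger for the same value range.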
Empty file added __init__.py
Empty file.
77 changes: 77 additions & 0 deletions test/test.py
@@ -12,6 +12,7 @@
import torch.nn as nn
from torch._inductor.utils import run_and_get_code
from torch._dynamo import config
import torchao
from torch.ao.quantization import MinMaxObserver, QConfigMapping

from torchao.quantization.dynamic_quant import (
@@ -54,6 +55,13 @@
    _fqn_to_op_to_shape_to_count,
    LoggingTensorMode,
)
from torchao.quantization.autoquant import (
    AQInt8DynamicallyQuantizedLinearWeight,
    AQWeightOnlyQuantizedLinearWeight,
    AQWeightOnlyQuantizedLinearWeight2,
    AQWeightOnlyQuantizedLinearWeight3

)
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
import os

@@ -880,6 +888,30 @@ def test_int8_weight_only_quant_subclass(self):
                Int8WeightOnlyQuantizedLinearWeight.from_float, 40, test_dtype
            )

    def test_aq_int8_dynamic_quant_subclass(self):
        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
            self._test_lin_weight_subclass_impl(
                AQInt8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype
            )

    def test_aq_int8_weight_only_quant_subclass(self):
        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
            self._test_lin_weight_subclass_impl(
                AQWeightOnlyQuantizedLinearWeight.from_float, 35, test_dtype
            )

    def test_aq_int8_weight_only_quant_2_subclass(self):
        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
            self._test_lin_weight_subclass_impl(
                AQWeightOnlyQuantizedLinearWeight2.from_float, 35, test_dtype
            )

    def test_aq_int8_weight_only_quant_3_subclass(self):
        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
            self._test_lin_weight_subclass_impl(
                AQWeightOnlyQuantizedLinearWeight3.from_float, 35, test_dtype
            )

    def test_int4_weight_only_quant_subclass(self):
        self._test_lin_weight_subclass_impl(
            Int4WeightOnlyQuantizedLinearWeight.from_float, 10, test_shape=[1, 1024, 8]
@@ -1195,6 +1227,51 @@ def test_on_dummy_distilbert(self):
        print("sqnr_pt_quant", sqnr_pt_quant)
        self.assertTrue(sqnr_sq >= 8.0)

class TestAutoQuant(unittest.TestCase):
    def test_autoquant_one_input(self):
        torch._inductor.config.epilogue_fusion = False
        torch._inductor.config.use_mixed_mm = True
        torch._inductor.config.force_fuse_int_mm_with_mul = True
        torch._dynamo.config.automatic_dynamic_shapes = False

        for m,k,n in [
            (1, 1024, 1024),
            (64, 1024, 1024),
            (2**15, 1024, 1024),
            (1, 1024, 4096),
            (64, 1024, 4096),
            (1, 4096, 1024),
            (64, 4096, 1024),
            (4096, 4096, 1024),
        ]:
            example_input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
            model = torch.nn.Sequential(
                torch.nn.ReLU(),
                torch.nn.Linear(k,n),
                torch.nn.ReLU(),
            ).to("cuda").to(torch.bfloat16)
            out = model(example_input)
            torchao.autoquant(model, example_input)
            out2 = model(example_input)
            sqnr = SQNR(out, out2)
            self.assertTrue(sqnr >= 30)

    def test_autoquant_multi_input(self):
        m1, m2, k, n = 1, 8, 1024, 1024
        model = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.Linear(k,n),
            torch.nn.ReLU(),
        ).cuda().to(torch.bfloat16)
        example_input = torch.randn(m1, k, device="cuda", dtype=torch.bfloat16)
        example_input2 = torch.randn(m2, k, device="cuda", dtype=torch.bfloat16)
        torchao.change_linears_to_autoquantizable(model)
        out=model(example_input)
        model(example_input2)
        torchao.change_autoquantizable_to_quantized(model)
        out2 = model(example_input)
        sqnr = SQNR(out, out2)
        self.assertTrue(sqnr >= 30)

if __name__ == "__main__":
    unittest.main()
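
The autoquant tests above accept a quantized model when `SQNR(out, out2) >= 30`. The `SQNR` helper itself is not part of this diff, so the sketch below assumes the common definition of signal-to-quantization-noise ratio in decibels, 20·log10(‖x‖ / ‖x − y‖), written in plain Python over lists. The function name `sqnr_db` and the sample vectors are hypothetical.

```python
import math


def sqnr_db(reference, approx):
    """Signal-to-quantization-noise ratio in dB (assumed definition)."""
    signal = math.sqrt(sum(r * r for r in reference))
    noise = math.sqrt(sum((r - a) ** 2 for r, a in zip(reference, approx)))
    if noise == 0:
        return math.inf  # identical outputs: no quantization noise at all
    return 20 * math.log10(signal / noise)


reference = [1.0, -2.0, 3.0, -4.0]
close = [1.01, -2.01, 2.99, -4.01]   # small perturbation of the reference
far = [1.5, -1.0, 3.5, -3.0]         # large perturbation of the reference

close_db = sqnr_db(reference, close)
far_db = sqnr_db(reference, far)
```

Under this definition, every additional 20 dB corresponds to a tenfold drop in the norm of the error relative to the norm of the reference output, so a 30 dB threshold allows a relative error of about 3%.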
23 changes: 22 additions & 1 deletion torchao/__init__.py
@@ -1,5 +1,26 @@
from torchao.quantization import (
    apply_weight_only_int8_quant,
    apply_dynamic_quant,
    change_linear_weights_to_int8_dqtensors,
    change_linear_weights_to_int8_woqtensors,
    change_linear_weights_to_int4_woqtensors,
    swap_conv2d_1x1_to_linear,
    autoquant,
    change_linears_to_autoquantizable,
    change_autoquantizable_to_quantized,
)
from . import dtypes

__all__ = [
    "dtypes"
    "apply_weight_only_int8_quant",
    "apply_dynamic_quant",
    "change_linear_weights_to_int8_dqtensors",
    "change_linear_weights_to_int8_woqtensors",
    "change_linear_weights_to_int4_woqtensors",
    "swap_conv2d_1x1_to_linear",
    "safe_int_mm",
    "autoquant",
    "change_linears_to_autoquantizable",
    "change_autoquantizable_to_quantized",
    "dtypes"
]
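
A quick consistency check for an `__all__` list like this one can be sketched as follows. It verifies that every exported name is actually an attribute of the module, which catches names that were never imported as well as adjacent string literals that Python silently concatenates when a comma is missing. The helper `missing_exports` and the `demo` module are hypothetical and are not part of torchao.

```python
import types


def missing_exports(module):
    """Return the names listed in module.__all__ that the module does not define."""
    return [
        name
        for name in getattr(module, "__all__", [])
        if not hasattr(module, name)
    ]


# Build a throwaway module so the example does not depend on torchao.
demo = types.ModuleType("demo")
demo.autoquant = lambda model, example_input: model
demo.dtypes = object()

# Without a comma, the two adjacent string literals below are joined into
# the single entry "autoquantdtypes".
demo.__all__ = [
    "autoquant"
    "dtypes",
]

problems = missing_exports(demo)
```

Running such a check against the real package in a unit test would be one way to keep the export list and the imports in sync.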
3 changes: 3 additions & 0 deletions torchao/quantization/__init__.py
@@ -25,6 +25,9 @@
    "dynamically_quantize_per_channel",
    "dequantize_per_tensor",
    "dequantize_per_channel",
    "autoquant",
    "change_linears_to_autoquantizable",
    "change_autoquantizable_to_quantized",
    "quant_int8_dynamic_linear",
    "quant_int8_matmul",
    "quant_int8_dynamic_per_token_linear",