
Commit 7e69ee3

Example for GPTQ-like calibration flow (#721)
* Example for GPTQ-like calibration flow

  Summary:
  This PR adds an example for a GPTQ-like calibration flow, where we
  (1) optimize (quantize) one module at a time,
  (2) need the full set of calibration data at each optimization step, and
  (3) compute the output of each module from the optimized (quantized) module and pass it down to the next module.

  In this tutorial we mainly use two things:
  (1) MultiTensor subclass https://gist.github.com/HDCharles/a1b575bbf8875f994af8a01b225e1227
  (2) Module forward hooks

  Potential use cases: GPTQ, Auto-Round (#581), SpinQuant (according to Tijmen)

  Test Plan: python tutorials/calibration_flow/gptq_like.py

* simplified example and add some docs for linear activation quantized tensor
1 parent 614c667 commit 7e69ee3

File tree

2 files changed: +275 -2 lines


torchao/quantization/linear_activation_quantized_tensor.py

Lines changed: 12 additions & 2 deletions
@@ -20,7 +20,17 @@
 
 class LinearActivationQuantizedTensor(TorchAOBaseTensor):
     """
-    Applies activation quantization for linear operator
+    Applies activation quantization for the linear operator. This is used to support
+    dynamic or static quantization: the user passes in an `input_quant_func`
+    that is used to quantize the activation.
+
+    Args:
+      `original_weight_tensor`: the weight tensor. If the weight needs to be quantized as well,
+       apply quantization to the weight first; e.g. for int8 dynamic activation int8 weight quantization
+       we first apply int8 quantization to the weight and then apply LinearActivationQuantizedTensor
+       on top of it.
+      `input_quant_func` (Callable[[torch.Tensor], torch.Tensor]): a function that takes a high precision
+       floating point tensor and returns a quantized tensor; this is used to quantize the input.
     """
     def __new__(
         cls,
@@ -38,7 +48,7 @@ def __new__(
     def __init__(
         self,
         original_weight_tensor: torch.Tensor,
-        input_quant_func: Callable,
+        input_quant_func: Callable[[torch.Tensor], torch.Tensor],
     ):
         self.original_weight_tensor = original_weight_tensor
         self.input_quant_func = input_quant_func
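For context (not part of this diff), here is a minimal usage sketch of the tensor subclass documented above: the `input_quant_func` runs on the activation before the linear op. The scale/zero_point values are made-up placeholders; `to_affine_quantized_static` and `to_linear_activation_quantized` are the same helpers used in the tutorial file added by this commit.

import torch
from torchao.dtypes import to_affine_quantized_static
from torchao.quantization import to_linear_activation_quantized

# illustrative, pre-computed activation quantization parameters
act_scale = torch.tensor(0.02)
act_zero_point = torch.tensor(128, dtype=torch.int32)

def input_quant_func(x: torch.Tensor) -> torch.Tensor:
    # statically quantize the incoming activation with the pre-computed qparams
    return to_affine_quantized_static(x, act_scale, act_zero_point, x.shape, torch.uint8)

linear = torch.nn.Linear(64, 128)
linear.weight = torch.nn.Parameter(
    to_linear_activation_quantized(linear.weight, input_quant_func),
    requires_grad=False,
)
out = linear(torch.randn(32, 64))  # activation is quantized by input_quant_func before the matmul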
tutorials/calibration_flow/gptq_like.py (new file)

Lines changed: 263 additions & 0 deletions
@@ -0,0 +1,263 @@
"""
The traditional calibration flow has the following steps (see static_quant.py for code examples):

(1). insert input/output observers into the modules
(2). run the model with calibration data so the observers in the model can record the statistics of the data flowing through them; observation does not change the output of a layer
(3). convert the observed module to a quantized module (or quantize the weights with the quantization parameters based on the observer statistics)

By a GPTQ-like flow we mean a flow that does not fit the above very well, because
(1) we optimize (quantize) one layer (module) at a time, and the output of each layer is calculated based on the optimized (quantized) module and then passed down to the next layer, which means layers are not independent
(2) with each optimization step, we need to use all the input data for that layer instead of just some derived statistics like min_val/max_val

To use the traditional flow, we'd need to
(1). insert observers for the layer we want to optimize, that record all the inputs
(2). each time, run the entire model up to layer N, then optimize layer N, and then
continue the process for layer N+1, which means we'd need to run O(N^2) layers in total.

So we'd like to use a flow that only runs each layer a constant number of times, giving O(N) time complexity.

In this tutorial we mainly use two things:
(1) MultiTensor subclass https://gist.github.com/HDCharles/a1b575bbf8875f994af8a01b225e1227

It stores a list of Tensors (calibration data). This is used to pass all the calibration data to a layer, so we can optimize the layer and then output another MultiTensor object for future layers.

(2) Module forward pre hooks (https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_pre_hook)

These are used for modifying the behavior of the forward function of `module`; they allow [modification](https://discuss.pytorch.org/t/use-forward-pre-hook-to-modify-nn-module-parameters/108498/2) of the module itself, and also allow modifying the input of the module.

This can be used when we want to optimize (quantize) a layer and then have the next layer consume the output of the optimized layer directly.
"""
import torch
import torch.nn as nn
from torch.utils._pytree import tree_flatten, tree_unflatten
import gc
from typing import Tuple, Dict, Any
from torchao.quantization.utils import compute_error
from torchao.dtypes import to_affine_quantized_static
from torchao.quantization import quantize_
from torchao.quantization import to_linear_activation_quantized
from torchao.quantization import LinearActivationQuantizedTensor
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.quantization.observer import (
    AffineQuantizedMinMaxObserver,
    PerTensor,
)
from torchao.quantization.quant_primitives import (
    MappingType,
    fake_quantize_affine,
)

torch.manual_seed(0)

class MultiTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, input, **kwargs):
        if isinstance(input, (list, tuple)):
            input = input[0]
        kwargs["dtype"] = kwargs.get("dtype", input.dtype)
        shape = kwargs.pop("shape", input.shape)
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

    def __init__(self, input, **kwargs):
        self.values = []
        self.count = 0
        self.add_tensors(input)
        self.debug = True

    def __repr__(self):
        return f"{self.__class__.__name__}(data={self.values})"

    def __iter__(self):
        for v in self.values:
            yield v

    def add_tensors(self, input):
        if isinstance(input, (tuple, list)):
            for inp in input:
                self.add_tensors(inp)
        else:
            assert isinstance(input, torch.Tensor), f"MultiTensor can only use add_tensors for Tensors or lists of tensors but got {type(input)}"
            self.count += 1
            self.values.append(input)
        return self

    def pad_to_length(self, length):
        if self.count > length:
            return self
        self.add_tensors([self.values[-1]] * (length - self.count))
        return self

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None, skip_gptq=False):
        def flat_to_grouped(flat):
            # size of biggest MultiTensor
            multi_tensor_size = max(
                [x.count if isinstance(x, MultiTensor) else 1 for x in flat]
            )
            # convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2], [A,b3,c3]]
            grouped = list(
                zip(
                    *[
                        x.pad_to_length(multi_tensor_size).values if isinstance(x, MultiTensor) else [x] * multi_tensor_size
                        for x in flat
                    ]
                )
            )
            return grouped

        # convert [[A,b1,c1], [A,b2,c2], [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
        # where A is a non-Tensor, b's and c's are Tensors
        def grouped_to_flat(grouped):
            # convert [[A,b1,c1], [A,b2,c2], [A,b3,c3]] => [(A,A,A), (b1,b2,b3), (c1,c2,c3)]
            flat_tups = list(zip(*grouped))
            # convert [(A,A,A), (b1,b2,b3), (c1,c2,c3)] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
            flattened = [
                cls(tup).cpu() if isinstance(tup[0], torch.Tensor) else tup[0] for tup in flat_tups
            ]
            # need to check that getting rid of all but one from each non-Tensor tuple is OK
            non_tensors_equal = min([True] + [
                min([True] + [  # handle situation where tuples have size 0
                    tup[0] == x for x in tup  # check all elements match
                ]) for tup in flat_tups if not isinstance(tup[0], torch.Tensor)  # look at tuples of non-Tensors
            ])
            return flattened, non_tensors_equal

        kwargs = {} if kwargs is None else kwargs
        # combine args and kwargs and remove lists and tuples
        flat_args, spec = tree_flatten((args, kwargs))
        # convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2], [A,b3,c3]]
        grouped_args = flat_to_grouped(flat_args)
        # run the function for each group of inputs and return a MultiTensor
        outputs = []
        with torch._C.DisableTorchFunctionSubclass():
            for inp in grouped_args:
                # inp = tensors_to_cuda(inp)
                cur_args, cur_kwargs = tree_unflatten(inp, spec)
                out = func(*cur_args, **cur_kwargs)
                # outputs.append(out.cpu() if isinstance(out, torch.Tensor) else out)
                outputs.append(out)
            grouped_outputs = [tree_flatten(x)[0] for x in outputs]
            out_spec = tree_flatten(outputs[0])[1]
            # convert [[A,b1,c1], [A,b2,c2], [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
            flat_outputs, non_tensors_equal = grouped_to_flat(grouped_outputs)
            assert non_tensors_equal, (
                f"ERR: found a function in model: {func} which "
                + "caused an error in MultiTensor, the function dispatch only works for functions"
                + " with Tensor outputs or that have the same non-Tensor output value across all inputs"
            )
            return tree_unflatten(flat_outputs, out_spec)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs={}, skip_gptq=False):
        # intentionally a no-op in this example
        pass

    def __tensor_flatten__(self):
        return ["values"], None

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        return cls(tensor_data_dict["values"])

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 128)

    def forward(self, x):
        x = self.linear(x)
        return x

def _is_linear(mod, fqn):
    return isinstance(mod, torch.nn.Linear)

# Adapted from https://github.com/pytorch/ao/pull/581
def prepare_model_for_optimization_(model):
    def forward_pre_hook(
        module,
        args: Tuple[MultiTensor],
        kwargs: Dict[str, Any],
    ):
        # remove the hook to avoid recursive calls
        module._forward_pre_hook_handle.remove()
        # we'll have a single MultiTensor as argument, which contains a list of activation Tensors
        # from the previous layer

        # we can use the MultiTensor to calculate the quantization parameters for each input Tensor
        act_obs = AffineQuantizedMinMaxObserver(
            MappingType.ASYMMETRIC,
            torch.uint8,
            granularity_type=PerTensor(),
            eps=torch.finfo(torch.float32).eps,
            scale_dtype=torch.float32,
            zero_point_dtype=torch.int32,
        )
        for inp in args[0]:
            act_obs(inp)

        input_scale, input_zp = act_obs.calculate_qparams()

        # we can optimize/modify the module here
        module.input_scale = input_scale
        module.input_zp = input_zp

        # rerun the module with quantized and dequantized inputs
        new_input = []
        for inp in args[0]:
            new_input.append(fake_quantize_affine(inp, inp.shape, input_scale, input_zp, torch.uint8))

        mt = MultiTensor(new_input)

        # tuple of modified args and kwargs
        return ((mt,), {})

    def _register_forward_pre_hook(module: torch.nn.Module):
        """Adds a forward pre hook for the module that runs before module.forward and can
        modify the module and the input of the module.
        Docs for `module.register_forward_pre_hook` can be found at https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_pre_hook
        """
        forward_pre_hook_handle = module.register_forward_pre_hook(
            forward_pre_hook, with_kwargs=True
        )
        module._forward_pre_hook_handle = forward_pre_hook_handle
        return module

    _replace_with_custom_fn_if_matches_filter(
        model, _register_forward_pre_hook, _is_linear
    )

# using a function to align with the API in quant_api
def apply_activation_static_quant():
    def _apply_activation_static_quant(observed_linear):
        target_dtype = torch.uint8

        # we could quantize the weight here as well

        # activation quantization
        act_scale, act_zero_point = observed_linear.input_scale, observed_linear.input_zp
        input_quant_func = lambda x: to_affine_quantized_static(x, act_scale, act_zero_point, x.shape, target_dtype)
        observed_linear.weight = torch.nn.Parameter(
            to_linear_activation_quantized(observed_linear.weight, input_quant_func), requires_grad=False
        )

        del observed_linear.input_scale
        del observed_linear.input_zp
        return observed_linear

    return _apply_activation_static_quant


example_inputs = (torch.randn(32, 64),)
m = M().eval()
before_quant = m(*example_inputs)
prepare_model_for_optimization_(m)
inputs = []
for _ in range(10):
    inputs.append(torch.randn(32, 64))

mt_input = MultiTensor(inputs)

out = m(mt_input)

# just quantizing the activation since that is all we observed; this could be
# extended to quantize the weight as well
quantize_(m, apply_activation_static_quant(), _is_linear)
for l in m.modules():
    if isinstance(l, torch.nn.Linear):
        assert isinstance(l.weight, LinearActivationQuantizedTensor)

after_quant = m(*example_inputs)
print("sqnr:", compute_error(before_quant, after_quant))
assert compute_error(before_quant, after_quant) > 35
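The tutorial notes that this flow could be extended to quantize the weight as well. Below is a minimal sketch of what that extension might look like, assuming `to_affine_quantized_intx` is importable from `torchao.dtypes` with a (mapping_type, block_size, target_dtype) signature (the exact name and arguments may differ across torchao versions); the rest reuses helpers already imported in the tutorial.

# Hypothetical extension of _apply_activation_static_quant: also quantize the weight
# before wrapping it with LinearActivationQuantizedTensor.
from torchao.dtypes import to_affine_quantized_intx  # assumed API; may differ across versions

def apply_static_act_and_int8_weight_quant():
    def _apply(observed_linear):
        target_dtype = torch.uint8
        # weight: symmetric per-output-channel int8 quantization (sketch)
        weight = observed_linear.weight
        quantized_weight = to_affine_quantized_intx(
            weight, MappingType.SYMMETRIC, (1, weight.shape[1]), torch.int8
        )
        # activation: reuse the statically observed scale/zero_point recorded by the pre hook
        act_scale, act_zero_point = observed_linear.input_scale, observed_linear.input_zp
        input_quant_func = lambda x: to_affine_quantized_static(x, act_scale, act_zero_point, x.shape, target_dtype)
        observed_linear.weight = torch.nn.Parameter(
            to_linear_activation_quantized(quantized_weight, input_quant_func), requires_grad=False
        )
        del observed_linear.input_scale
        del observed_linear.input_zp
        return observed_linear

    return _apply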
