|
| 1 | +""" |
| 2 | +Traditional calibration flow has the following flow (see static_quant.py for code examples): |
| 3 | +
|
| 4 | +(1). insert input/output observers to the modules |
| 5 | +(2). run the model with calibration data so the observers in the model can record the statistics of the data flowing through them, observation does not change the output of a layer |
| 6 | +(3). convert the observed module to quantized module (or quantize the weights with the quantization parameters based on the observer statistics) |
| 7 | +
|
| 8 | +By GPTQ like flow we mean a flow that does not fit into the above flow very well because |
| 9 | +(1) optimize (quantize) one layer (module) at a time and the the output of each layer is calculated based on the optimized (quantized) module, and then pass down to the next layer, this means layers are not independent |
| 10 | +(2) with each optimization step, we need to use all the input data for that layer instead of just some derived statistics like min_val/max_val |
| 11 | +
|
| 12 | +To use the traditional flow, we'd need to |
| 13 | +(1). insert observers for the layer we want to optimize, that will record all the inputs |
| 14 | +(2). each time, run the entire model upto layer N, then optimize layer N, and then |
| 15 | +continue the process for layer N+1, this means we'll need to run O(N^2) layers in total. |
| 16 | +
|
| 17 | +So we'd like to use a flow that only runs each layer a constant time so we get O(N) time complexity. |
| 18 | +
|
| 19 | +In this tutorial we mainly use two things: |
| 20 | +(1) MultiTensor subclass https://gist.github.com/HDCharles/a1b575bbf8875f994af8a01b225e1227 |
| 21 | +
|
| 22 | +It stores a list of Tensors (calibration data). This is used to pass around all the calibration data to a layer, we can optimize the layer, and then output another MultiTensor object for future layers. |
| 23 | +
|
| 24 | +(2) Module forward pre hooks (https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_pre_hook) |
| 25 | +
|
| 26 | +This is used for modifying the behavior of the forward function of `module`, it allows [modification](https://discuss.pytorch.org/t/use-forward-pre-hook-to-modify-nn-module-parameters/108498/2) of the module itself, and also allows modifying the input of the module. |
| 27 | +
|
| 28 | +This can be used when we try to optimize (quantize) the layer, and then want the next layer to consume the output of the optimized layer directly. |
| 29 | +""" |
| 30 | +import torch |
| 31 | +import torch.nn as nn |
| 32 | +from torch.utils._pytree import tree_flatten, tree_unflatten |
| 33 | +import gc |
| 34 | +from typing import Tuple, Dict, Any |
| 35 | +from torchao.quantization.utils import compute_error |
| 36 | +from torchao.dtypes import to_affine_quantized_static |
| 37 | +from torchao.quantization import quantize_ |
| 38 | +from torchao.quantization import to_linear_activation_quantized |
| 39 | +from torchao.quantization import LinearActivationQuantizedTensor |
| 40 | +from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter |
| 41 | +from torchao.quantization.observer import ( |
| 42 | + AffineQuantizedMinMaxObserver, |
| 43 | + PerTensor, |
| 44 | +) |
| 45 | +from torchao.quantization.quant_primitives import ( |
| 46 | + MappingType, |
| 47 | + fake_quantize_affine, |
| 48 | +) |
| 49 | + |
| 50 | +torch.manual_seed(0) |
| 51 | + |
| 52 | +class MultiTensor(torch.Tensor): |
| 53 | + @staticmethod |
| 54 | + def __new__(cls, input, **kwargs): |
| 55 | + if isinstance(input, (list, tuple)): |
| 56 | + input = input[0] |
| 57 | + kwargs["dtype"]=kwargs.get("dtype", input.dtype) |
| 58 | + shape = kwargs.pop("shape", input.shape) |
| 59 | + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) |
| 60 | + |
| 61 | + def __init__(self, input, **kwargs): |
| 62 | + self.values = [] |
| 63 | + self.count = 0 |
| 64 | + self.add_tensors(input) |
| 65 | + self.debug = True |
| 66 | + |
| 67 | + def __repr__(self): |
| 68 | + return ( |
| 69 | + f"{self.__class__.__name__}(data={self.values})" |
| 70 | + ) |
| 71 | + |
| 72 | + def __iter__(self): |
| 73 | + for v in self.values: |
| 74 | + yield v |
| 75 | + |
| 76 | + def add_tensors(self, input): |
| 77 | + if isinstance(input, (tuple, list)): |
| 78 | + for inp in input: |
| 79 | + self.add_tensors(inp) |
| 80 | + else: |
| 81 | + assert isinstance(input, torch.Tensor), f"MultiTensor can only use add_tensors for Tensors or lists of tensors but got {type(input)}" |
| 82 | + self.count += 1 |
| 83 | + self.values.append(input) |
| 84 | + return self |
| 85 | + |
| 86 | + def pad_to_length(self, length): |
| 87 | + if self.count > length: |
| 88 | + return self |
| 89 | + self.add_tensors([self.values[-1]]*(length-self.count)) |
| 90 | + return self |
| 91 | + |
| 92 | + @classmethod |
| 93 | + def __torch_function__(cls, func, types, args=(), kwargs=None, skip_gptq=False): |
| 94 | + def flat_to_grouped(flat): |
| 95 | + # size of biggest MultiTensor |
| 96 | + multi_tensor_size = max( |
| 97 | + [x.count if isinstance(x, MultiTensor) else 1 for x in flat] |
| 98 | + ) |
| 99 | + # convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2] [A,b3,c3]] |
| 100 | + grouped = list( |
| 101 | + zip( |
| 102 | + *[ |
| 103 | + x.pad_to_length(multi_tensor_size).values if isinstance(x, MultiTensor) else [x] * multi_tensor_size for x in flat] |
| 104 | + ) |
| 105 | + ) |
| 106 | + return grouped |
| 107 | + |
| 108 | + # convert [[A,b1,c1], [A,b2,c2] [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] |
| 109 | + # where A is nontensor, b's,c's are tensors |
| 110 | + def grouped_to_flat(grouped): |
| 111 | + # convert [[A,b1,c1], [A,b2,c2] [A,b3,c3]] => [(A,A,A), (b1,b2,b3), (c1,c2,c3)] |
| 112 | + flat_tups = list(zip(*grouped)) |
| 113 | + # convert [(A,A,A), (b1,b2,b3), (c1,c2,c3)] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] |
| 114 | + flattened = [ |
| 115 | + cls(tup).cpu() if isinstance(tup[0], torch.Tensor) else tup[0] for tup in flat_tups |
| 116 | + ] |
| 117 | + # need to check that getting rid of all but one from each nonTensor tuple is OK |
| 118 | + non_tensors_equal=min([True]+[ |
| 119 | + min([True]+[ # handle situation where tuples have size 0 |
| 120 | + tup[0]==x for x in tup # check all elements match |
| 121 | + ]) for tup in flat_tups if not isinstance(tup[0], torch.Tensor) # look at tuples of nonTensors |
| 122 | + ]) |
| 123 | + return flattened, non_tensors_equal |
| 124 | + |
| 125 | + kwargs = {} if kwargs is None else kwargs |
| 126 | + # combine args and kwargs and remove lists and tuples |
| 127 | + flat_args, spec = tree_flatten((args, kwargs)) |
| 128 | + # convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2] [A,b3,c3]] |
| 129 | + grouped_args = flat_to_grouped(flat_args) |
| 130 | + # run function for each of the multitensors and return a multitensor |
| 131 | + outputs = [] |
| 132 | + with torch._C.DisableTorchFunctionSubclass(): |
| 133 | + for inp in grouped_args: |
| 134 | + # inp = tensors_to_cuda(inp) |
| 135 | + cur_args, cur_kwargs = tree_unflatten(inp, spec) |
| 136 | + out = func(*cur_args, **cur_kwargs) |
| 137 | + # outputs.append(out.cpu() if isinstance(out, torch.Tensor) else out) |
| 138 | + outputs.append(out) |
| 139 | + grouped_outputs = [tree_flatten(x)[0] for x in outputs] |
| 140 | + out_spec = tree_flatten(outputs[0])[1] |
| 141 | + # convert [[A,b1,c1], [A,b2,c2] [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] |
| 142 | + flat_outputs, non_tensors_equal = grouped_to_flat(grouped_outputs) |
| 143 | + assert non_tensors_equal, ( |
| 144 | + f"ERR: found a function in model: {func} which " |
| 145 | + +"caused an error in MultiInput, the function dispatch only works for functions" |
| 146 | + +" with Tensor outputs or that have the same non-Tensor output value for all across all inputs" |
| 147 | + ) |
| 148 | + return tree_unflatten(flat_outputs, out_spec) |
| 149 | + |
| 150 | + @classmethod |
| 151 | + def __torch_dispatch__(cls, func, types, args=(), kwargs={}, skip_gptq=False): |
| 152 | + pass |
| 153 | + |
| 154 | + def __tensor_flatten__(self): |
| 155 | + return ["values"], None |
| 156 | + |
| 157 | + @classmethod |
| 158 | + def __tensor_unflatten__( |
| 159 | + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride |
| 160 | + ): |
| 161 | + return cls(tensor_data_dict["values"]) |
| 162 | + |
| 163 | +class M(torch.nn.Module): |
| 164 | + def __init__(self): |
| 165 | + super().__init__() |
| 166 | + self.linear = torch.nn.Linear(64, 128) |
| 167 | + |
| 168 | + def forward(self, x): |
| 169 | + x = self.linear(x) |
| 170 | + return x |
| 171 | + |
| 172 | +def _is_linear(mod, fqn): |
| 173 | + return isinstance(mod, torch.nn.Linear) |
| 174 | + |
| 175 | +# Adapted from https://github.com/pytorch/ao/pull/581 |
| 176 | +def prepare_model_for_optimization_(model): |
| 177 | + def forward_pre_hook( |
| 178 | + module, |
| 179 | + args: Tuple[MultiTensor], |
| 180 | + kwargs: Dict[str, Any], |
| 181 | + ): |
| 182 | + # remove the hook to avoid recursive calls |
| 183 | + module._forward_pre_hook_handle.remove() |
| 184 | + # we'll have a single MultiTensor as argument, that contains a list of activation Tensors |
| 185 | + # from previous layer |
| 186 | + |
| 187 | + # we can use the MultiTensor to calculate the quantization parameters for each input Tensor |
| 188 | + act_obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.int32) |
| 189 | + for inp in args[0]: |
| 190 | + act_obs(inp) |
| 191 | + |
| 192 | + input_scale, input_zp = act_obs.calculate_qparams() |
| 193 | + |
| 194 | + # we can optimize/modify the module here |
| 195 | + module.input_scale = input_scale |
| 196 | + module.input_zp = input_zp |
| 197 | + |
| 198 | + # rerun the module with quantized and dequantized inputs |
| 199 | + new_input = [] |
| 200 | + for inp in args[0]: |
| 201 | + new_input.append(fake_quantize_affine(inp, inp.shape, input_scale, input_zp, torch.uint8)) |
| 202 | + |
| 203 | + mt = MultiTensor(new_input) |
| 204 | + |
| 205 | + # tuple of modified args and kwargs |
| 206 | + return ((mt,), {}) |
| 207 | + |
| 208 | + def _register_forward_pre_hook(module: torch.nn.Module): |
| 209 | + """Adds a forward pre hook for the module, that runs before module.forward is run that can |
| 210 | + modify the module and the input of the module |
| 211 | + docs for `module.register_forward_pre_hook` can be found in https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_pre_hook |
| 212 | + """ |
| 213 | + forward_pre_hook_handle = module.register_forward_pre_hook( |
| 214 | + forward_pre_hook, with_kwargs=True |
| 215 | + ) |
| 216 | + module._forward_pre_hook_handle = forward_pre_hook_handle |
| 217 | + return module |
| 218 | + |
| 219 | + _replace_with_custom_fn_if_matches_filter( |
| 220 | + model, _register_forward_pre_hook, _is_linear |
| 221 | + ) |
| 222 | + |
| 223 | +# using a function to align with the API in quant_api |
| 224 | +def apply_activation_static_quant(): |
| 225 | + def _apply_activation_static_quant(observed_linear): |
| 226 | + target_dtype = torch.uint8 |
| 227 | + |
| 228 | + # we can quantize the weight here as well |
| 229 | + |
| 230 | + # activation quantization |
| 231 | + act_scale, act_zero_point = observed_linear.input_scale, observed_linear.input_zp |
| 232 | + input_quant_func = lambda x: to_affine_quantized_static(x, act_scale, act_zero_point, x.shape, target_dtype) |
| 233 | + observed_linear.weight = torch.nn.Parameter(to_linear_activation_quantized(observed_linear.weight, input_quant_func), requires_grad=False) |
| 234 | + |
| 235 | + del observed_linear.input_scale |
| 236 | + del observed_linear.input_zp |
| 237 | + return observed_linear |
| 238 | + |
| 239 | + return _apply_activation_static_quant |
| 240 | + |
| 241 | + |
| 242 | +example_inputs = (torch.randn(32, 64),) |
| 243 | +m = M().eval() |
| 244 | +before_quant = m(*example_inputs) |
| 245 | +prepare_model_for_optimization_(m) |
| 246 | +inputs = [] |
| 247 | +for _ in range(10): |
| 248 | + inputs.append(torch.randn(32, 64)) |
| 249 | + |
| 250 | +mt_input = MultiTensor(inputs) |
| 251 | + |
| 252 | +out = m(mt_input) |
| 253 | + |
| 254 | +# just quantizing activation since we only observed quantization, this could be extended to support |
| 255 | +# quantizing weight as well |
| 256 | +quantize_(m, apply_activation_static_quant(), _is_linear) |
| 257 | +for l in m.modules(): |
| 258 | + if isinstance(l, torch.nn.Linear): |
| 259 | + assert isinstance(l.weight, LinearActivationQuantizedTensor) |
| 260 | + |
| 261 | +after_quant = m(*example_inputs) |
| 262 | +print("sqnr:", compute_error(before_quant, after_quant)) |
| 263 | +assert compute_error(before_quant, after_quant) > 35 |
0 commit comments