From c8267ba9885e92159123077551fcc5f9edbf1976 Mon Sep 17 00:00:00 2001
From: Max Ren
Date: Tue, 24 Sep 2024 17:10:21 -0700
Subject: [PATCH] Add Int8DynActInt8WeightLinear module (#5605)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/5605

Add Int8DynActInt8WeightLinear for per-channel dynamically quantized
(DQ) Linear.

Reviewed By: mergennachin

Differential Revision: D63339550
---
 .../llama2/source_transformation/quantize.py | 93 +++++++++++++++++++
 1 file changed, 93 insertions(+)

diff --git a/examples/models/llama2/source_transformation/quantize.py b/examples/models/llama2/source_transformation/quantize.py
index da832f8285a..7ef51ac93c0 100644
--- a/examples/models/llama2/source_transformation/quantize.py
+++ b/examples/models/llama2/source_transformation/quantize.py
@@ -379,6 +379,99 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
     # return F.linear(input, self.weight.to(dtype=input.dtype)) * se...


+def linear_forward_8da8w(
+    x,
+    weight_int8,
+    scales,
+    zeros,
+    out_features,
+    precision,
+):
+    from torchao.quantization.utils import per_token_dynamic_quant
+
+    # Dynamically quantize the activations per token (int8 quantize,
+    # then dequantize back to float) before the matmul.
+    x = per_token_dynamic_quant(x)
+    n_bit = 8
+    quant_min = -(2 ** (n_bit - 1))
+    quant_max = 2 ** (n_bit - 1) - 1
+    # Dequantize the per-channel (axis=0) int8 weights to `precision`.
+    w_dq = torch.ops.quantized_decomposed.dequantize_per_channel(
+        weight_int8,
+        scales,
+        zeros,
+        0,
+        quant_min,
+        quant_max,
+        torch.int8,
+        out_dtype=precision,
+    )
+    c = torch.nn.functional.linear(x, w_dq)
+
+    return c
+
+
+class Int8DynActInt8WeightLinear(torch.nn.Module):
+    """
+    This module implements a dynamically quantized linear layer with
+    per-channel int8 weights.
+
+    Parameters of importance:
+        precision: precision of input and output, e.g. torch.float32
+            means the input activation is float32 and the output is
+            float32.
+    """
+
+    __constants__ = ["in_features", "out_features"]
+
+    in_features: int
+    out_features: int
+    weight: torch.Tensor
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias=True,
+        device=None,
+        dtype=None,
+        precision: torch.dtype = torch.float32,
+    ) -> None:
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        assert not bias, "Int8DynActInt8WeightLinear requires bias=False"
+        self.precision = precision
+
+        if dtype is not None:
+            raise ValueError("Please specify 'precision' instead of 'dtype'")
+
+        # Currently storing unpacked int8 weights.
+        self.register_buffer(
+            "weight",
+            torch.empty((out_features, in_features), dtype=torch.int8),
+        )
+        self.register_buffer(
+            "scales",
+            torch.empty((out_features,), dtype=torch.float32),
+        )
+        self.register_buffer(
+            "zeros",
+            torch.empty((out_features,), dtype=torch.float32),
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        input = input.to(self.precision)
+        return linear_forward_8da8w(
+            input,
+            self.weight,
+            self.scales,
+            self.zeros,
+            self.out_features,
+            self.precision,
+        )
+
+
 #########################################################################
 ##### embedding table quantization ######
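
Note for reviewers: the activation path relies on torchao's
per_token_dynamic_quant, which quantizes each token's activations to
int8 and immediately dequantizes them back to float. Below is a minimal
sketch of that round trip, assuming asymmetric int8 quantization over
the last dimension; the function name, the clamp epsilon, and the exact
rounding are illustrative assumptions, not torchao's implementation.

import torch

def per_token_dynamic_quant_sketch(x: torch.Tensor) -> torch.Tensor:
    # Illustrative only: asymmetric int8 quant/dequant with one
    # (scale, zero_point) pair per token, i.e. per row of the last dim.
    quant_min, quant_max = -128, 127
    amin = x.amin(dim=-1, keepdim=True)
    amax = x.amax(dim=-1, keepdim=True)
    scale = (amax - amin).clamp(min=1e-5) / (quant_max - quant_min)
    zero_point = quant_min - torch.round(amin / scale)
    x_q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
    # Dequantize back to float so the matmul runs in floating point.
    return (x_q - zero_point) * scale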
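
For the weight path, the dequantize_per_channel call with axis 0 is
equivalent to the following eager-mode math; this is a sketch that
ignores the op's dtype validation, and the helper name is made up here.

import torch

def dequant_per_channel_sketch(weight_int8, scales, zeros, precision=torch.float32):
    # One (scale, zero_point) pair per output channel, i.e. per row
    # of the (out_features, in_features) weight matrix.
    w = weight_int8.to(precision)
    return (w - zeros.to(precision).unsqueeze(-1)) * scales.to(precision).unsqueeze(-1)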
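
A stand-alone usage sketch of the new module follows. The random int8
weights, scales, and zero points are placeholders; in the actual flow
the module replaces nn.Linear during source transformation and its
buffers are populated from a quantized checkpoint.

import torch

linear = Int8DynActInt8WeightLinear(
    in_features=64, out_features=32, bias=False, precision=torch.float32
)
linear.weight.copy_(torch.randint(-128, 128, (32, 64), dtype=torch.int8))
linear.scales.copy_(torch.rand(32) * 0.01)  # per-channel scales
linear.zeros.copy_(torch.zeros(32))         # per-channel zero points

x = torch.randn(2, 8, 64)   # (batch, seq_len, in_features)
y = linear(x)               # -> (2, 8, 32), float32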