Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions examples/models/llama2/source_transformation/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,99 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
# return F.linear(input, self.weight.to(dtype=input.dtype)) * se...


def linear_forward_8da8w(
x,
weight_int8,
scales,
zeros,
out_features,
precision,
):
from torchao.quantization.utils import per_token_dynamic_quant

x = per_token_dynamic_quant(x)
n_bit = 8
quant_min = -(2 ** (n_bit - 1))
quant_max = 2 ** (n_bit - 1) - 1
w_dq = torch.ops.quantized_decomposed.dequantize_per_channel(
weight_int8,
scales,
zeros,
0,
quant_min,
quant_max,
torch.int8,
out_dtype=precision,
)
c = torch.nn.functional.linear(x, w_dq)

return c


class Int8DynActInt8WeightLinear(torch.nn.Module):
__constants__ = ["in_features", "out_features"]

in_features: int
out_features: int
weight: torch.Tensor

"""
This module implements a dynamic quantized linear layer with int8 weight.
Weights are per channel quantized. Parameters of importance
precision: precision of input and output. e.g. torch.float32 means input
activation is float32 and output is float32.
"""

def __init__(
self,
in_features: int,
out_features: int,
bias=True,
device=None,
dtype=None,
precision: torch.dtype = torch.float32,
) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
assert not bias, "require bias=False"
self.precision = precision

if dtype is not None:
raise ValueError("Please specify 'precision' instead of 'dtype'")

# currently storing unpacked int8 weights
self.register_buffer(
"weight",
torch.empty((out_features, in_features), dtype=torch.int8),
)
self.register_buffer(
"scales",
torch.empty(
(out_features),
dtype=torch.float32,
),
)
self.register_buffer(
"zeros",
torch.empty(
(out_features),
dtype=torch.float32,
),
)

def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(self.precision)
return linear_forward_8da8w(
input,
self.weight,
self.scales,
self.zeros,
self.out_features,
self.precision,
)


#########################################################################
##### embedding table quantization ######

Expand Down
Loading