Adding uint4 dtype implementation
Summary:
There is a lot of interest in int4 dtypes, and we would like to add the dtype outside of PyTorch core.
This PR adds preliminary support for uint4 through a tensor subclass; we will continue to iterate on it.
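
For reference, a minimal usage sketch based on the tests added in this PR (the UInt4Tensor API is preliminary and may change as we iterate):

    import torch
    from torchao.dtypes.int4 import UInt4Tensor

    # packed storage: each uint8 byte holds two 4-bit values
    x = UInt4Tensor(torch.tensor([[0x01, 0x23, 0x45, 0x67]], dtype=torch.uint8))
    x.to(torch.uint8)   # convert back to uint8
    x[:, 2:6]           # indexing/slicing is in 4-bit elements, not bytes
    # or build from one-value-per-byte uint8 data
    w = UInt4Tensor.from_unpacked(torch.tensor([[0, 1, 2, 3]], dtype=torch.uint8))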

Test Plan:
python test/dtypes/test_int4.py

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: aa80ed3ab3fb73db5aa985f06cad55f4f1e3c0b0
Pull Request resolved: #13
jerryzh168 committed Dec 7, 2023
1 parent e16898d commit 0c72951
Showing 3 changed files with 452 additions and 0 deletions.
292 changes: 292 additions & 0 deletions test/dtypes/test_int4.py
@@ -0,0 +1,292 @@
import torch
from torchao.dtypes.int4 import UInt4Tensor
import unittest
from unittest import TestCase, main
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization import observer

from torch._export import capture_pre_autograd_graph
from torch._export import dynamic_dim
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
)
from torchao.quantization.utils import (
    compute_error,
)
from torchao.quantization.quant_api import (
    replace_with_custom_fn_if_matches_filter,
)
from torch import nn
import copy

def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype):
    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # get min and max
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scale and zero point based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scale is the same dtype as the original tensor
    scale = torch.clamp(scale, min=eps).to(x.dtype)
    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scale/zp
    # reference: torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x.transpose(0, 1) / scale
    x_round = torch.round(x_div)
    x_zp = x_round + zero_point
    x_zp = x_zp.transpose(0, 1)
    quant = torch.clamp(x_zp, quant_min, quant_max)
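    # for the "int4" target, pack the clamped values into the UInt4Tensor subclass
    # added in this PR; other targets just cast to the requested dtype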
    if target_dtype == "int4":
        quant = UInt4Tensor.from_unpacked(quant.to(torch.uint8)).view(quant.size())
    else:
        quant = quant.to(target_dtype)

    return quant, scale, zero_point

class _WeightOnlyInt4QuantLinear(torch.nn.Linear):
    def __init__(self, *args, **kwargs):
        w_int4 = kwargs.pop("w_int4")
        scales = kwargs.pop("scales")
        super().__init__(*args, **kwargs)
        self.w_int4 = w_int4
        self.scales = scales

    def forward(self, x):
        # if len(x.shape)<=2:
        #     y = torch.mm(x, self.w_int8.to(x.dtype)) * self.scales
        # else: # turn x into 2d tensor, then undo it for y
        x_view = x.view(-1, x.shape[-1])
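        # convert the packed int4 weight back to the activation dtype, matmul, then
        # apply the per-channel scales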
        y = torch.mm(x_view, self.w_int4.to(torch.uint8).to(x.dtype)) * self.scales
        y = y.reshape(*x.shape[:-1], -1)
        if self.bias is not None:
            y += self.bias
        return y

    @classmethod
    def from_float(cls, mod):
        w_fp32 = mod.weight
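        # quantize the weight per output channel to the unsigned 4-bit range [0, 15]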
        w_int4, scales, _zp = _dynamically_quantize_per_channel_int4(
            w_fp32, 0, 15, "int4"
        )
        # create the new module with a toy size to ensure initialization is fast
        fake_in_features, fake_out_features = 8, 8
        new_mod = cls(
            fake_in_features,
            fake_out_features,
            bias=mod.bias is not None,
            w_int4=w_int4.t().contiguous(),
            scales=scales,
        )
        new_mod.in_features = mod.in_features
        new_mod.out_features = mod.out_features
        del new_mod.weight
        new_mod.bias = mod.bias
        device_to_use = next(mod.parameters()).device
        new_mod.to(device_to_use)
        return new_mod

def _apply_weight_only_int4_quant(model):
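    # swap every nn.Linear in the model for the weight-only int4 variant above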
    replace_with_custom_fn_if_matches_filter(
        model,
        _WeightOnlyInt4QuantLinear.from_float,
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )

class TestInt4(QuantizationTestCase):
    def test_basic_tensor_ops(self):
        x = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
        self.assertTrue(x.shape, (3, 8))
        # making sure these work
        x.to(torch.uint8)
        expected = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
        self.assertTrue(x[0:1, :] == expected)
        expected = UInt4Tensor(torch.tensor([
            [0x23, 0x45],
            [0x23, 0x45],
            [0x23, 0x45],
        ], dtype=torch.uint8))
        self.assertTrue(x[:, 2:6] == expected)

    def test_gpu_quant(self):
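        # quantize a small Linear stack and make sure the quantized module runs
        # both eagerly and under torch.compile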
        for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
            x = torch.randn(*x_shape)
            m = nn.Sequential(nn.Linear(4, 16))
            y_ref = m(x)
            _apply_weight_only_int4_quant(m)
            y_wo = m(x)
            # sqnr = compute_error(y_ref, y_wo)
            opt = torch.compile(m, mode="max-autotune")
            # make sure it runs
            opt(x)

    def test_aten_ir(self):
        from torch.library import Library, impl
        test_lib = Library("test_int4", "DEF")
        test_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")
        @impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd")
        def quantize_per_tensor_int4(
            input: torch.Tensor,
            scale: float,
            zero_point: int,
        ) -> torch.Tensor:
            inv_scale = 1.0 / scale
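            # clamp to the uint4 range [0, 15], store as uint8, and reinterpret the
            # bits as bits8 since there is no dedicated uint4 dtype in core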
            return torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8).view(torch.bits8)

        test_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")
        @impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd")
        def dequantize_per_tensor_int4(
            input: torch.Tensor,
            scale: float,
            zero_point: int,
        ) -> torch.Tensor:
            return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale

        # class QuantizePerTensorUInt4(torch.autograd.Function):
        #     @staticmethod
        #     def forward(
        #         ctx,
        #         input: torch.Tensor,
        #         scale: float,
        #         zero_point: int,
        #     ) -> torch.Tensor:
        #         inv_scale = 1.0 / scale
        #         return UInt4Tensor(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8))

        # class DeQuantizePerTensorUInt4(torch.autograd.Function):
        #     @staticmethod
        #     def forward(
        #         ctx,
        #         input: torch.Tensor,
        #         scale: float,
        #         zero_point: int,
        #     ) -> torch.Tensor:
        #         return (input.to(torch.float32) - zero_point) * scale

        class M(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        example_inputs = (torch.randn(1, 2, 3, 3), torch.randn(1, 2, 3, 3),)
        m = M().eval()
        m = capture_pre_autograd_graph(m, example_inputs)
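        # rewrite the captured graph by hand: insert the custom quantize/dequantize
        # ops around the first input of every aten.add node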
        for n in m.graph.nodes:
            if n.target == torch.ops.aten.add.Tensor:
                with m.graph.inserting_before(n):
                    q = m.graph.call_function(torch.ops.test_int4.quantize_per_tensor_int4, (n.args[0], 1.0, 0), {})
                    dq = m.graph.call_function(torch.ops.test_int4.dequantize_per_tensor_int4, (q, 1.0, 0), {})
                    n.replace_input_with(n.args[0], dq)
        m.recompile()

    # TODO: need more extension points from quant flow side
    @unittest.skip("need more extension points from quant flow side")
    def test_pt2e_quant(self):
        from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
            OP_TO_ANNOTATOR,
            QuantizationConfig,
        )

        class Int4ActQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                int4_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-2**3,
                    quant_max=2**3 - 1,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_observer,
                )
                int8_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-128,
                    quant_max=127,
                    qscheme=torch.per_tensor_symmetric,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_weight_observer,
                )
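                # int8 activations, int4 weights; annotate every conv with this config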
                quantization_config = QuantizationConfig(
                    input_activation=int8_qspec,
                    weight=int4_qspec,
                    bias=None,
                    output_activation=int8_qspec,
                )
                OP_TO_ANNOTATOR["conv"](model, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv(x)

        quantizer = Int4ActQuantizer()
        node_occurrence = {
            # one for input of the first conv, one for output for the first conv
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
        }
        node_list = [
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.conv2d.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
        ]
        example_inputs = (torch.randn(1, 3, 3, 3),)

        # _test_quantizer in PT2EQuantizationTestCase
        # resetting dynamo cache
        export_with_dynamic_shape = False
        torch._dynamo.reset()
        m_eager = M().eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = capture_pre_autograd_graph(
            m,
            example_inputs,
            constraints=[dynamic_dim(example_inputs[0], 0)] if export_with_dynamic_shape else [],
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m, fold_quantize=True)

        pt2_quant_output = m(*example_inputs)
        node_occurrence = {
            ns.call_function(k): v for k, v in node_occurrence.items()
        }
        node_list = [ns.call_function(n) for n in node_list]
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
        )

if __name__ == "__main__":
    main()
5 changes: 5 additions & 0 deletions torchao/dtypes/__init__.py
@@ -0,0 +1,5 @@
from .int4 import UInt4Tensor

__all__ = [
"UInt4Tensor"
]
