Adding uint4 dtype implementation #13

Merged (17 commits) on Feb 10, 2024
229 changes: 229 additions & 0 deletions test/dtypes/test_uint4.py
@@ -0,0 +1,229 @@
import torch
from torchao.dtypes.uint4 import (
    UInt4Tensor,
    PerChannelSymmetricWeightUInt4Tensor,
)
import unittest
from unittest import TestCase, main
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

from torch._export import capture_pre_autograd_graph
from torch._export import dynamic_dim
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
)
from torchao.quantization.utils import (
    compute_error,
)
from torchao.quantization.quant_api import (
    replace_with_custom_fn_if_matches_filter,
)
from torch.ao.quantization.observer import ObserverBase
from torch import nn
from torch.fx import (
    Node,
    GraphModule,
)
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
)
import copy

def _apply_weight_only_uint4_quant(model):
    def fn(mod):
        mod.weight = torch.nn.Parameter(PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight), requires_grad=False)
        return mod

    replace_with_custom_fn_if_matches_filter(
        model,
        lambda mod: fn(mod),
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )
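
# Illustrative sketch (added; not part of the original PR): applying the helper
# above to a tiny model. The wrapper name `_demo_uint4_swap` is hypothetical.
def _demo_uint4_swap():
    m = nn.Sequential(nn.Linear(4, 8))
    _apply_weight_only_uint4_quant(m)
    # every nn.Linear weight is now a frozen PerChannelSymmetricWeightUInt4Tensor
    return m(torch.randn(2, 4))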

class TestUInt4(QuantizationTestCase):
    def test_basic_tensor_ops(self):
        x = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
        self.assertEqual(x.shape, (3, 16))
        # TODO: make sure this returns torch.uint4
        self.assertIs(x.dtype, torch.uint4)
        # make sure these work
        x.to(torch.uint8)
        expected = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
        self.assertEqual(x[0:1, :], expected)
        expected = UInt4Tensor(torch.tensor([
            [0x23, 0x45],
            [0x23, 0x45],
            [0x23, 0x45],
        ], dtype=torch.uint8))
        self.assertEqual(x[:, 2:6], expected)
        torch.save(x, "uint4_tensor.pt")
        x = torch.load("uint4_tensor.pt")
        self.assertEqual(x[:, 2:6], expected)
        # only test locally
        # print("x:", x[0])
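        # Note (added): each uint8 storage byte holds two 4-bit values, which is
        # why the (3, 8) uint8 input above reports shape (3, 16) and why
        # x[:, 2:6] corresponds to the packed bytes [0x23, 0x45]. High-nibble-first
        # packing is assumed; the byte-aligned slices here do not pin the order down.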

    def test_gpu_quant(self):
        for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
            x = torch.randn(*x_shape)
            m = nn.Sequential(nn.Linear(4, 16))
            y_ref = m(x)
            _apply_weight_only_uint4_quant(m)
            y_wo = m(x)
            # sqnr = compute_error(y_ref, y_wo)
            opt = torch.compile(m, fullgraph=True, mode="max-autotune")
            # make sure it runs
            opt(x)
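            # Note (added, assumption): compute_error imported above is an
            # SQNR-style metric, roughly 20 * log10(||y_ref|| / ||y_ref - y_wo||);
            # the commented line could be re-enabled to bound the quantization error.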

    def test_pt2e_quant(self):
        from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
            OP_TO_ANNOTATOR,
            QuantizationConfig,
        )

        class Uint4Observer(ObserverBase):
            def __init__(self, *args, **kwargs):
                # just faking a dtype here
                # TODO: make flow work with new dtypes
                super().__init__(dtype=torch.int8)

            def forward(self, x):
                return x

            def calculate_qparams(self, **kwargs):
                pass

            def convert(self, model: GraphModule, observer_node: Node):
                with model.graph.inserting_before(observer_node):
                    q_node = model.graph.call_function(
                        torch.ops.qtensors.quantize_per_tensor_uint4, (observer_node.args[0], 1.0, 0), {})
                    dq_node = model.graph.call_function(
                        torch.ops.qtensors.dequantize_per_tensor_uint4, (q_node, 1.0, 0), {})
                    observer_node.replace_all_uses_with(dq_node)
                    model.graph.erase_node(observer_node)
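
        # Note (added): convert() above rewrites each observer node into a
        # quantize_per_tensor_uint4 / dequantize_per_tensor_uint4 pair with fixed
        # qparams (scale=1.0, zero_point=0); these are the ops counted in
        # node_occurrence below.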

        from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
            _is_annotated,
            _mark_nodes_as_annotated,
        )

        class Int8ActUint4WeightQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                uint4_qspec = QuantizationSpec(
                    dtype=torch.uint4,
                    quant_min=0,
                    quant_max=2**4 - 1,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=Uint4Observer,
                )
                int8_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-128,
                    quant_max=127,
                    qscheme=torch.per_tensor_symmetric,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=torch.ao.quantization.observer.default_weight_observer,
                )
                quantization_config = QuantizationConfig(
                    input_activation=int8_qspec,
                    weight=uint4_qspec,
                    bias=None,
                    output_activation=int8_qspec,
                )
                for n in model.graph.nodes:
                    if n.op != "call_function" or n.target not in [
                        torch.ops.aten.linear.default,
                    ]:
                        continue
                    linear_node = n

                    input_qspec_map = {}
                    input_act = linear_node.args[0]
                    assert isinstance(input_act, Node)
                    input_qspec_map[input_act] = quantization_config.input_activation

                    weight = linear_node.args[1]
                    assert isinstance(weight, Node)
                    input_qspec_map[weight] = quantization_config.weight

                    partition = [linear_node, linear_node.args[1]]

                    bias = linear_node.args[2] if len(linear_node.args) > 2 else None
                    if isinstance(bias, Node):
                        input_qspec_map[bias] = quantization_config.bias
                        partition.append(bias)

                    if _is_annotated(partition):
                        continue

                    linear_node.meta["quantization_annotation"] = QuantizationAnnotation(
                        input_qspec_map=input_qspec_map,
                        output_qspec=quantization_config.output_activation,
                        _annotated=True,
                    )
                    _mark_nodes_as_annotated(partition)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass
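
        # Note (added): the quantizer annotates every aten.linear call with int8
        # activation specs and a uint4 weight spec; prepare_pt2e inserts the
        # corresponding observers and convert_pt2e lowers them to the quantize /
        # dequantize ops checked at the end of this test.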

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.linear(x)

        quantizer = Int8ActUint4WeightQuantizer()
        node_occurrence = {
            # for weight
            torch.ops.qtensors.quantize_per_tensor_uint4: 1,
            torch.ops.qtensors.dequantize_per_tensor_uint4: 1,
            # for activation
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
        }
        node_list = [
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.qtensors.dequantize_per_tensor_uint4,
            torch.ops.aten.linear.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
        ]
        example_inputs = (torch.randn(2, 4),)

        # _test_quantizer in PT2EQuantizationTestCase
        # resetting dynamo cache
        export_with_dynamic_shape = False
        torch._dynamo.reset()
        m_eager = M().eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = capture_pre_autograd_graph(
            m,
            example_inputs,
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m, fold_quantize=False)
        pt2_quant_output = m(*example_inputs)

        node_occurrence = {
            ns.call_function(k): v for k, v in node_occurrence.items()
        }
        node_list = [ns.call_function(n) for n in node_list]
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
        )

if __name__ == "__main__":
    main()
5 changes: 5 additions & 0 deletions torchao/dtypes/__init__.py
@@ -0,0 +1,5 @@
from .uint4 import UInt4Tensor

__all__ = [
    "UInt4Tensor"
]