Adding uint4 dtype implementation
Summary:
We have seen a lot of interest in int4 dtypes, and we'd like to add the dtype outside of PyTorch core.
This PR adds some preliminary support for uint4 through a tensor subclass; we'll continue to iterate on it.
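
As a rough sketch of the intended usage (based on the helper in the new test file below; the module shape and loop are just for illustration), the weight-only uint4 flow looks roughly like this:

    import torch
    from torch import nn
    from torchao.dtypes.uint4 import PerChannelSymmetricWeightUInt4Tensor

    m = nn.Sequential(nn.Linear(4, 16))
    # swap each Linear weight for the per-channel symmetric uint4 tensor subclass
    for mod in m.modules():
        if isinstance(mod, nn.Linear):
            mod.weight = nn.Parameter(
                PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight),
                requires_grad=False,
            )
    y = m(torch.randn(2, 4))  # forward pass dispatches through the uint4 subclass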

Test Plan:
python test/dtypes/test_uint4.py

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 51cf71730a2cdc17aad2fc70e42e19da9762fc46
Pull Request resolved: #13
jerryzh168 committed Feb 10, 2024
1 parent a349409 commit 9b74f44
Showing 3 changed files with 528 additions and 0 deletions.
229 changes: 229 additions & 0 deletions test/dtypes/test_uint4.py
@@ -0,0 +1,229 @@
import torch
from torchao.dtypes.uint4 import (
    UInt4Tensor,
    PerChannelSymmetricWeightUInt4Tensor,
)
import unittest
from unittest import TestCase, main
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

from torch._export import capture_pre_autograd_graph
from torch._export import dynamic_dim
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
)
from torchao.quantization.utils import (
    compute_error,
)
from torchao.quantization.quant_api import (
    replace_with_custom_fn_if_matches_filter,
)
from torch.ao.quantization.observer import ObserverBase
from torch import nn
from torch.fx import (
    Node,
    GraphModule,
)
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
)
import copy

def _apply_weight_only_uint4_quant(model):
    def fn(mod):
        mod.weight = torch.nn.Parameter(PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight), requires_grad=False)
        return mod

    replace_with_custom_fn_if_matches_filter(
        model,
        lambda mod: fn(mod),
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )

class TestUInt4(QuantizationTestCase):
    def test_basic_tensor_ops(self):
        x = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
        self.assertEqual(x.shape, (3, 16))
        # TODO: make sure this returns torch.uint4
        self.assertIs(x.dtype, torch.uint4)
        # make sure this works
        x.to(torch.uint8)
        expected = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
        self.assertEqual(x[0:1, :], expected)
        expected = UInt4Tensor(torch.tensor([
            [0x23, 0x45],
            [0x23, 0x45],
            [0x23, 0x45],
        ], dtype=torch.uint8))
        self.assertEqual(x[:, 2:6], expected)
        torch.save(x, "uint4_tensor.pt")
        x = torch.load("uint4_tensor.pt")
        self.assertEqual(x[:, 2:6], expected)
        # only test locally
        # print("x:", x[0])

    def test_gpu_quant(self):
        for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
            x = torch.randn(*x_shape)
            m = nn.Sequential(nn.Linear(4, 16))
            y_ref = m(x)
            _apply_weight_only_uint4_quant(m)
            y_wo = m(x)
            # sqnr = compute_error(y_ref, y_wo)
            opt = torch.compile(m, fullgraph=True, mode="max-autotune")
            # make sure it runs
            opt(x)

    def test_pt2e_quant(self):
        from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
            OP_TO_ANNOTATOR,
            QuantizationConfig,
        )

        class Uint4Observer(ObserverBase):
            def __init__(self, *args, **kwargs):
                # just faking a dtype here
                # TODO: make flow work with new dtypes
                super().__init__(dtype=torch.int8)

            def forward(self, x):
                return x

            def calculate_qparams(self, **kwargs):
                pass

            def convert(self, model: GraphModule, observer_node: Node):
                with model.graph.inserting_before(observer_node):
                    q_node = model.graph.call_function(
                        torch.ops.qtensors.quantize_per_tensor_uint4, (observer_node.args[0], 1.0, 0), {})
                    dq_node = model.graph.call_function(
                        torch.ops.qtensors.dequantize_per_tensor_uint4, (q_node, 1.0, 0), {})
                    observer_node.replace_all_uses_with(dq_node)
                    model.graph.erase_node(observer_node)

        from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
            _is_annotated,
            _mark_nodes_as_annotated,
        )

        class Int8ActUint4WeightQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                uint4_qspec = QuantizationSpec(
                    dtype=torch.uint4,
                    quant_min=0,
                    quant_max=2**4 - 1,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=Uint4Observer,
                )
                int8_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-128,
                    quant_max=127,
                    qscheme=torch.per_tensor_symmetric,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=torch.ao.quantization.observer.default_weight_observer,
                )
                quantization_config = QuantizationConfig(
                    input_activation=int8_qspec,
                    weight=uint4_qspec,
                    bias=None,
                    output_activation=int8_qspec,
                )
                for n in model.graph.nodes:
                    if n.op != "call_function" or n.target not in [
                        torch.ops.aten.linear.default,
                    ]:
                        continue
                    linear_node = n

                    input_qspec_map = {}
                    input_act = linear_node.args[0]
                    assert isinstance(input_act, Node)
                    input_qspec_map[input_act] = quantization_config.input_activation

                    weight = linear_node.args[1]
                    assert isinstance(weight, Node)
                    input_qspec_map[weight] = quantization_config.weight

                    partition = [linear_node, linear_node.args[1]]

                    bias = linear_node.args[2] if len(linear_node.args) > 2 else None
                    if isinstance(bias, Node):
                        input_qspec_map[bias] = quantization_config.bias
                        partition.append(bias)

                    if _is_annotated(partition):
                        continue

                    linear_node.meta["quantization_annotation"] = QuantizationAnnotation(
                        input_qspec_map=input_qspec_map,
                        output_qspec=quantization_config.output_activation,
                        _annotated=True,
                    )
                    _mark_nodes_as_annotated(partition)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.linear(x)

        quantizer = Int8ActUint4WeightQuantizer()
        node_occurrence = {
            # for weight
            torch.ops.qtensors.quantize_per_tensor_uint4: 1,
            torch.ops.qtensors.dequantize_per_tensor_uint4: 1,
            # for activation
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
        }
        node_list = [
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.qtensors.dequantize_per_tensor_uint4,
            torch.ops.aten.linear.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
        ]
        example_inputs = (torch.randn(2, 4),)

        # _test_quantizer in PT2EQuantizationTestCase
        # resetting dynamo cache
        export_with_dynamic_shape = False
        torch._dynamo.reset()
        m_eager = M().eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = capture_pre_autograd_graph(
            m,
            example_inputs,
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m, fold_quantize=False)
        pt2_quant_output = m(*example_inputs)

        node_occurrence = {
            ns.call_function(k): v for k, v in node_occurrence.items()
        }
        node_list = [ns.call_function(n) for n in node_list]
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
        )

if __name__ == "__main__":
    main()
5 changes: 5 additions & 0 deletions torchao/dtypes/__init__.py
@@ -0,0 +1,5 @@
from .uint4 import UInt4Tensor

__all__ = [
"UInt4Tensor"
]
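
For reference, a small sketch (mirroring test_basic_tensor_ops above; values and shapes are illustrative) of how the packed uint4 layout behaves: each uint8 element stores two 4-bit values, so a (1, 8) uint8 tensor wraps into a uint4 tensor of shape (1, 16).

    import torch
    from torchao.dtypes import UInt4Tensor

    packed = torch.tensor([
        [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
    ], dtype=torch.uint8)
    x = UInt4Tensor(packed)     # wraps the packed uint8 storage
    print(x.shape)              # torch.Size([1, 16]): two 4-bit values per byte
    x_u8 = x.to(torch.uint8)    # unpack to uint8 for inspection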
