-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: We have a lot of interest for int4 dtypes, and we'd like to add the dtype out of PyTorch core. This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 51cf71730a2cdc17aad2fc70e42e19da9762fc46 Pull Request resolved: #13
- Loading branch information
1 parent
a349409
commit 9b74f44
Showing
3 changed files
with
528 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,229 @@ | ||
import torch | ||
from torchao.dtypes.uint4 import ( | ||
UInt4Tensor, | ||
PerChannelSymmetricWeightUInt4Tensor, | ||
) | ||
import unittest | ||
from unittest import TestCase, main | ||
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e | ||
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer | ||
|
||
from torch._export import capture_pre_autograd_graph | ||
from torch._export import dynamic_dim | ||
from torch.testing._internal.common_quantization import ( | ||
NodeSpec as ns, | ||
QuantizationTestCase, | ||
) | ||
from torchao.quantization.utils import ( | ||
compute_error, | ||
) | ||
from torchao.quantization.quant_api import ( | ||
replace_with_custom_fn_if_matches_filter, | ||
) | ||
from torch.ao.quantization.observer import ObserverBase | ||
from torch import nn | ||
from torch.fx import ( | ||
Node, | ||
GraphModule, | ||
) | ||
from torch.ao.quantization.quantizer import ( | ||
QuantizationAnnotation, | ||
) | ||
import copy | ||
|
||
def _apply_weight_only_uint4_quant(model): | ||
def fn(mod): | ||
mod.weight = torch.nn.Parameter(PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight), requires_grad=False) | ||
return mod | ||
|
||
replace_with_custom_fn_if_matches_filter( | ||
model, | ||
lambda mod: fn(mod), | ||
lambda mod, fqn: isinstance(mod, torch.nn.Linear), | ||
) | ||
|
||
class TestUInt4(QuantizationTestCase): | ||
def test_basic_tensor_ops(self): | ||
x = UInt4Tensor(torch.tensor([ | ||
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], | ||
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], | ||
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], | ||
], dtype=torch.uint8)) | ||
self.assertEqual(x.shape, (3, 16)) | ||
# TODO: make sure this returns torch.uint4 | ||
self.assertIs(x.dtype, torch.uint4) | ||
# making sure these works | ||
x.to(torch.uint8) | ||
expected = UInt4Tensor(torch.tensor([ | ||
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], | ||
], dtype=torch.uint8)) | ||
self.assertEqual(x[0:1, :], expected) | ||
expected = UInt4Tensor(torch.tensor([ | ||
[0x23, 0x45], | ||
[0x23, 0x45], | ||
[0x23, 0x45], | ||
], dtype=torch.uint8)) | ||
self.assertEqual(x[:, 2:6], expected) | ||
torch.save(x, "uint4_tensor.pt") | ||
x = torch.load("uint4_tensor.pt") | ||
self.assertEqual(x[:, 2:6], expected) | ||
# only test locally | ||
# print("x:", x[0]) | ||
|
||
def test_gpu_quant(self): | ||
for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: | ||
x = torch.randn(*x_shape) | ||
m = nn.Sequential(nn.Linear(4, 16)) | ||
y_ref = m(x) | ||
_apply_weight_only_uint4_quant(m) | ||
y_wo = m(x) | ||
# sqnr = compute_error(y_ref, y_wo) | ||
opt = torch.compile(m, fullgraph=True, mode="max-autotune") | ||
# make sure it runs | ||
opt(x) | ||
|
||
def test_pt2e_quant(self): | ||
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( | ||
OP_TO_ANNOTATOR, | ||
QuantizationConfig, | ||
) | ||
class Uint4Observer(ObserverBase): | ||
def __init__(self, *args, **kwargs): | ||
# just faking a dtype here | ||
# TODO: make flow work with new dtypes | ||
super().__init__(dtype=torch.int8) | ||
|
||
def forward(self, x): | ||
return x | ||
|
||
def calculate_qparams(self, **kwargs): | ||
pass | ||
|
||
def convert(self, model: GraphModule, observer_node: Node): | ||
with model.graph.inserting_before(observer_node): | ||
q_node = model.graph.call_function( | ||
torch.ops.qtensors.quantize_per_tensor_uint4, (observer_node.args[0], 1.0, 0), {}) | ||
dq_node = model.graph.call_function( | ||
torch.ops.qtensors.dequantize_per_tensor_uint4, (q_node, 1.0, 0), {}) | ||
observer_node.replace_all_uses_with(dq_node) | ||
model.graph.erase_node(observer_node) | ||
|
||
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( | ||
_is_annotated, | ||
_mark_nodes_as_annotated, | ||
) | ||
|
||
class Int8ActUint4WeightQuantizer(Quantizer): | ||
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: | ||
uint4_qspec = QuantizationSpec( | ||
dtype=torch.uint4, | ||
quant_min=0, | ||
quant_max=2**4 - 1, | ||
qscheme=torch.per_tensor_affine, | ||
is_dynamic=False, | ||
observer_or_fake_quant_ctr=Uint4Observer, | ||
) | ||
int8_qspec = QuantizationSpec( | ||
dtype=torch.int8, | ||
quant_min=-128, | ||
quant_max=127, | ||
qscheme=torch.per_tensor_symmetric, | ||
is_dynamic=False, | ||
observer_or_fake_quant_ctr=torch.ao.quantization.observer.default_weight_observer, | ||
) | ||
quantization_config = QuantizationConfig( | ||
input_activation=int8_qspec, | ||
weight=uint4_qspec, | ||
bias=None, | ||
output_activation=int8_qspec, | ||
) | ||
for n in model.graph.nodes: | ||
if n.op != "call_function" or n.target not in [ | ||
torch.ops.aten.linear.default, | ||
]: | ||
continue | ||
linear_node = n | ||
|
||
input_qspec_map = {} | ||
input_act = linear_node.args[0] | ||
assert isinstance(input_act, Node) | ||
input_qspec_map[input_act] = quantization_config.input_activation | ||
|
||
weight = linear_node.args[1] | ||
assert isinstance(weight, Node) | ||
input_qspec_map[weight] = quantization_config.weight | ||
|
||
partition = [linear_node, linear_node.args[1]] | ||
|
||
bias = linear_node.args[2] if len(linear_node.args) > 2 else None | ||
if isinstance(bias, Node): | ||
input_qspec_map[bias] = quantization_config.bias | ||
partition.append(bias) | ||
|
||
if _is_annotated(partition): | ||
continue | ||
|
||
linear_node.meta["quantization_annotation"] = QuantizationAnnotation( | ||
input_qspec_map=input_qspec_map, | ||
output_qspec=quantization_config.output_activation, | ||
_annotated=True, | ||
) | ||
_mark_nodes_as_annotated(partition) | ||
|
||
def validate(self, model: torch.fx.GraphModule) -> None: | ||
pass | ||
|
||
class M(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.linear = torch.nn.Linear(4, 4) | ||
|
||
def forward(self, x): | ||
return self.linear(x) | ||
|
||
quantizer = Int8ActUint4WeightQuantizer() | ||
node_occurrence = { | ||
# for weight | ||
torch.ops.qtensors.quantize_per_tensor_uint4: 1, | ||
torch.ops.qtensors.dequantize_per_tensor_uint4: 1, | ||
# for activation | ||
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, | ||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, | ||
} | ||
node_list = [ | ||
torch.ops.quantized_decomposed.dequantize_per_tensor.default, | ||
torch.ops.qtensors.dequantize_per_tensor_uint4, | ||
torch.ops.aten.linear.default, | ||
torch.ops.quantized_decomposed.quantize_per_tensor.default, | ||
] | ||
example_inputs = (torch.randn(2, 4),) | ||
|
||
# _test_quantizer in PT2EQuantizationTestCase | ||
# resetting dynamo cache | ||
export_with_dynamic_shape = False | ||
torch._dynamo.reset() | ||
m_eager = M().eval() | ||
|
||
# program capture | ||
m = copy.deepcopy(m_eager) | ||
m = capture_pre_autograd_graph( | ||
m, | ||
example_inputs, | ||
) | ||
|
||
m = prepare_pt2e(m, quantizer) | ||
# Calibrate | ||
m(*example_inputs) | ||
m = convert_pt2e(m, fold_quantize=False) | ||
pt2_quant_output = m(*example_inputs) | ||
|
||
node_occurrence = { | ||
ns.call_function(k): v for k, v in node_occurrence.items() | ||
} | ||
node_list = [ns.call_function(n) for n in node_list] | ||
self.checkGraphModuleNodes( | ||
m, expected_node_occurrence=node_occurrence, expected_node_list=node_list | ||
) | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .uint4 import UInt4Tensor | ||
|
||
__all__ = [ | ||
"UInt4Tensor" | ||
] |
Oops, something went wrong.