-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: We have a lot of interest for int4 dtypes, and we'd like to add the dtype out of PyTorch core. This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: be18a647187ac5156affc8884ef94f788321e942 Pull Request resolved: #13
- Loading branch information
1 parent
e16898d
commit 4ea07b9
Showing
3 changed files
with
554 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,228 @@ | ||
import torch | ||
from torchao.dtypes.int4 import ( | ||
UInt4Tensor, | ||
PerChannelSymmetricWeightUInt4Tensor, | ||
) | ||
import unittest | ||
from unittest import TestCase, main | ||
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e | ||
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer | ||
|
||
from torch._export import capture_pre_autograd_graph | ||
from torch._export import dynamic_dim | ||
from torch.testing._internal.common_quantization import ( | ||
NodeSpec as ns, | ||
QuantizationTestCase, | ||
) | ||
from torchao.quantization.utils import ( | ||
compute_error, | ||
) | ||
from torchao.quantization.quant_api import ( | ||
replace_with_custom_fn_if_matches_filter, | ||
) | ||
from torch.ao.quantization.observer import ObserverBase | ||
from torch import nn | ||
from torch.fx import ( | ||
Node, | ||
GraphModule, | ||
) | ||
from torch.ao.quantization.quantizer import ( | ||
QuantizationAnnotation, | ||
) | ||
import copy | ||
|
||
def _apply_weight_only_int4_quant(model): | ||
def fn(mod): | ||
mod.weight = torch.nn.Parameter(PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight), requires_grad=False) | ||
return mod | ||
|
||
replace_with_custom_fn_if_matches_filter( | ||
model, | ||
lambda mod: fn(mod), | ||
lambda mod, fqn: isinstance(mod, torch.nn.Linear), | ||
) | ||
|
||
class TestInt4(QuantizationTestCase): | ||
def test_basic_tensor_ops(self): | ||
x = UInt4Tensor(torch.tensor([ | ||
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], | ||
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], | ||
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], | ||
], dtype=torch.uint8)) | ||
self.assertEqual(x.shape, (3, 16)) | ||
# TODO: make sure this returns torch.uint4 | ||
self.assertIs(x.dtype, torch.uint4) | ||
# making sure these works | ||
x.to(torch.uint8) | ||
expected = UInt4Tensor(torch.tensor([ | ||
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], | ||
], dtype=torch.uint8)) | ||
self.assertEqual(x[0:1, :], expected) | ||
expected = UInt4Tensor(torch.tensor([ | ||
[0x23, 0x45], | ||
[0x23, 0x45], | ||
[0x23, 0x45], | ||
], dtype=torch.uint8)) | ||
self.assertEqual(x[:, 2:6], expected) | ||
torch.save(x, "uint4_tensor.pt") | ||
x = torch.load("uint4_tensor.pt") | ||
self.assertEqual(x[:, 2:6], expected) | ||
print("x:", x[0]) | ||
|
||
def test_gpu_quant(self): | ||
for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: | ||
x = torch.randn(*x_shape) | ||
m = nn.Sequential(nn.Linear(4, 16)) | ||
y_ref = m(x) | ||
_apply_weight_only_int4_quant(m) | ||
y_wo = m(x) | ||
# sqnr = compute_error(y_ref, y_wo) | ||
opt = torch.compile(m, mode="max-autotune") | ||
# make sure it runs | ||
opt(x) | ||
|
||
def test_pt2e_quant(self): | ||
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( | ||
OP_TO_ANNOTATOR, | ||
QuantizationConfig, | ||
) | ||
class Int4Observer(ObserverBase): | ||
def __init__(self, *args, **kwargs): | ||
# just faking a dtype here | ||
# TODO: make flow work with new dtypes | ||
super().__init__(dtype=torch.int8) | ||
|
||
def forward(self, x): | ||
return x | ||
|
||
def calculate_qparams(self, **kwargs): | ||
pass | ||
|
||
def convert(self, model: GraphModule, observer_node: Node): | ||
with model.graph.inserting_before(observer_node): | ||
q_node = model.graph.call_function( | ||
torch.ops.qtensors.quantize_per_tensor_int4, (observer_node.args[0], 1.0, 0), {}) | ||
dq_node = model.graph.call_function( | ||
torch.ops.qtensors.dequantize_per_tensor_int4, (q_node, 1.0, 0), {}) | ||
observer_node.replace_all_uses_with(dq_node) | ||
model.graph.erase_node(observer_node) | ||
|
||
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( | ||
_is_annotated, | ||
_mark_nodes_as_annotated, | ||
) | ||
|
||
class Int8ActInt4WeightQuantizer(Quantizer): | ||
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: | ||
int4_qspec = QuantizationSpec( | ||
dtype=torch.uint4, | ||
quant_min=0, | ||
quant_max=2**4 - 1, | ||
qscheme=torch.per_tensor_affine, | ||
is_dynamic=False, | ||
observer_or_fake_quant_ctr=Int4Observer, | ||
) | ||
int8_qspec = QuantizationSpec( | ||
dtype=torch.int8, | ||
quant_min=-128, | ||
quant_max=127, | ||
qscheme=torch.per_tensor_symmetric, | ||
is_dynamic=False, | ||
observer_or_fake_quant_ctr=torch.ao.quantization.observer.default_weight_observer, | ||
) | ||
quantization_config = QuantizationConfig( | ||
input_activation=int8_qspec, | ||
weight=int4_qspec, | ||
bias=None, | ||
output_activation=int8_qspec, | ||
) | ||
for n in model.graph.nodes: | ||
if n.op != "call_function" or n.target not in [ | ||
torch.ops.aten.linear.default, | ||
]: | ||
continue | ||
linear_node = n | ||
|
||
input_qspec_map = {} | ||
input_act = linear_node.args[0] | ||
assert isinstance(input_act, Node) | ||
input_qspec_map[input_act] = quantization_config.input_activation | ||
|
||
weight = linear_node.args[1] | ||
assert isinstance(weight, Node) | ||
input_qspec_map[weight] = quantization_config.weight | ||
|
||
partition = [linear_node, linear_node.args[1]] | ||
|
||
bias = linear_node.args[2] if len(linear_node.args) > 2 else None | ||
if isinstance(bias, Node): | ||
input_qspec_map[bias] = quantization_config.bias | ||
partition.append(bias) | ||
|
||
if _is_annotated(partition): | ||
continue | ||
|
||
linear_node.meta["quantization_annotation"] = QuantizationAnnotation( | ||
input_qspec_map=input_qspec_map, | ||
output_qspec=quantization_config.output_activation, | ||
_annotated=True, | ||
) | ||
_mark_nodes_as_annotated(partition) | ||
|
||
def validate(self, model: torch.fx.GraphModule) -> None: | ||
pass | ||
|
||
class M(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.linear = torch.nn.Linear(4, 4) | ||
|
||
def forward(self, x): | ||
return self.linear(x) | ||
|
||
quantizer = Int8ActInt4WeightQuantizer() | ||
node_occurrence = { | ||
# for weight | ||
torch.ops.qtensors.quantize_per_tensor_int4: 1, | ||
torch.ops.qtensors.dequantize_per_tensor_int4: 1, | ||
# for activation | ||
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, | ||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, | ||
} | ||
node_list = [ | ||
torch.ops.quantized_decomposed.dequantize_per_tensor.default, | ||
torch.ops.qtensors.dequantize_per_tensor_int4, | ||
torch.ops.aten.linear.default, | ||
torch.ops.quantized_decomposed.quantize_per_tensor.default, | ||
] | ||
example_inputs = (torch.randn(2, 4),) | ||
|
||
# _test_quantizer in PT2EQuantizationTestCase | ||
# resetting dynamo cache | ||
export_with_dynamic_shape = False | ||
torch._dynamo.reset() | ||
m_eager = M().eval() | ||
|
||
# program capture | ||
m = copy.deepcopy(m_eager) | ||
m = capture_pre_autograd_graph( | ||
m, | ||
example_inputs, | ||
) | ||
|
||
m = prepare_pt2e(m, quantizer) | ||
# Calibrate | ||
m(*example_inputs) | ||
m = convert_pt2e(m, fold_quantize=False) | ||
pt2_quant_output = m(*example_inputs) | ||
|
||
node_occurrence = { | ||
ns.call_function(k): v for k, v in node_occurrence.items() | ||
} | ||
node_list = [ns.call_function(n) for n in node_list] | ||
self.checkGraphModuleNodes( | ||
m, expected_node_occurrence=node_occurrence, expected_node_list=node_list | ||
) | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .int4 import UInt4Tensor | ||
|
||
__all__ = [ | ||
"UInt4Tensor" | ||
] |
Oops, something went wrong.