-
Notifications
You must be signed in to change notification settings - Fork 22.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Quantization] Add metadata porting for nodes added by quantization (#…
…107107) Summary: This diff adds adding metadata to q-dq nodes by inferring the quatization intent from node annotations. Annotations on the node are way for user to specify how a node or subgraph is supposed to be quantized. We continue to use that information to copy metadata on Q/DQ node from appropriate nodes. Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D48488416](https://our.internmc.facebook.com/intern/diff/D48488416) Pull Request resolved: #107107 Approved by: https://github.com/jerryzh168 ghstack dependencies: #107105, #107106, #107899, #107900
- Loading branch information
1 parent
d6a9c2b
commit ffc0c46
Showing
3 changed files
with
551 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,367 @@ | ||
# Owner(s): ["oncall: quantization"] | ||
import copy | ||
|
||
import unittest | ||
from typing import List | ||
|
||
import torch | ||
import torch._export as export | ||
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e | ||
from torch.ao.quantization.quantizer import Quantizer | ||
from torch.ao.quantization.quantizer.xnnpack_quantizer import ( | ||
get_symmetric_quantization_config, | ||
) | ||
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR | ||
|
||
from torch.fx import Node | ||
|
||
from torch.testing._internal.common_quantization import QuantizationTestCase | ||
|
||
|
||
class TestHelperModules: | ||
class Conv2dWithObsSharingOps(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.conv = torch.nn.Conv2d(3, 3, 3) | ||
self.hardtanh = torch.nn.Hardtanh() | ||
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) | ||
self.linear = torch.nn.Linear(3, 3) | ||
|
||
def forward(self, x): | ||
x = self.conv(x) | ||
x = self.adaptive_avg_pool2d(x) | ||
x = self.hardtanh(x) | ||
x = x.view(-1, 3) | ||
x = self.linear(x) | ||
return x | ||
|
||
|
||
def _tag_partitions( | ||
backend_name: str, op_name: str, annotated_partitions: List[List[Node]] | ||
): | ||
for index, partition_nodes in enumerate(annotated_partitions): | ||
tag_name = backend_name + "_" + op_name + "_" + str(index) | ||
for node in partition_nodes: | ||
assert "quantization_tag" not in node.meta, f"{node} is already tagged" | ||
node.meta["quantization_tag"] = tag_name | ||
|
||
|
||
_QUANT_OPS = { | ||
torch.ops.quantized_decomposed.quantize_per_tensor.default, | ||
torch.ops.quantized_decomposed.dequantize_per_tensor.default, | ||
torch.ops.quantized_decomposed.quantize_per_tensor.tensor, | ||
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, | ||
torch.ops.quantized_decomposed.quantize_per_channel.default, | ||
torch.ops.quantized_decomposed.dequantize_per_channel.default, | ||
torch.ops.quantized_decomposed.choose_qparams.tensor, | ||
} | ||
|
||
|
||
class TestMetaDataPorting(QuantizationTestCase): | ||
def _test_metadata_porting( | ||
self, | ||
model, | ||
example_inputs, | ||
quantizer, | ||
node_tags=None, | ||
): | ||
m_eager = model.eval() | ||
|
||
# program capture | ||
m = copy.deepcopy(m_eager) | ||
m = export.capture_pre_autograd_graph( | ||
m, | ||
example_inputs, | ||
) | ||
|
||
m = prepare_pt2e(m, quantizer) | ||
# Calibrate | ||
m(*example_inputs) | ||
m = convert_pt2e(m) | ||
|
||
pt2_quant_output = m(*example_inputs) | ||
recorded_node_tags = {} | ||
for n in m.graph.nodes: | ||
if ( | ||
n.op == "call_function" | ||
and n.target in _QUANT_OPS | ||
and "quantization_tag" in n.meta | ||
): | ||
if n.target not in recorded_node_tags: | ||
recorded_node_tags[n.target] = set() | ||
if n.meta["quantization_tag"] in recorded_node_tags[n.target]: | ||
raise ValueError( | ||
f"{n} has tag {n.meta['quantization_tag']} that is associated with another node of the same type" | ||
) | ||
recorded_node_tags[n.target].add(n.meta["quantization_tag"]) | ||
self.assertEqual(set(recorded_node_tags.keys()), set(node_tags.keys())) | ||
for k, v in recorded_node_tags.items(): | ||
self.assertEqual(v, node_tags[k]) | ||
|
||
def test_simple_metadata_porting(self): | ||
""" | ||
Model under test | ||
conv2d -> avgpool -> hardtanh -> linear | ||
Check quantization tags on conv2d, avgpool and linear are correctly set | ||
""" | ||
|
||
class BackendAQuantizer(Quantizer): | ||
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: | ||
backend_string = "BackendA" | ||
quantization_config = get_symmetric_quantization_config( | ||
is_per_channel=True | ||
) | ||
annotated_partitions = OP_TO_ANNOTATOR["linear"]( | ||
gm, quantization_config | ||
) | ||
_tag_partitions(backend_string, "linear", annotated_partitions) | ||
annotated_partitions = OP_TO_ANNOTATOR["conv2d"]( | ||
gm, quantization_config | ||
) | ||
_tag_partitions(backend_string, "conv2d", annotated_partitions) | ||
annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"]( | ||
gm, quantization_config | ||
) | ||
_tag_partitions( | ||
backend_string, "adaptive_avg_pool2d", annotated_partitions | ||
) | ||
|
||
def validate(self, model: torch.fx.GraphModule) -> None: | ||
pass | ||
|
||
example_inputs = (torch.randn(1, 3, 5, 5),) | ||
quantize_per_tensor_tags = { | ||
"BackendA_conv2d_0", | ||
"BackendA_adaptive_avg_pool2d_0", | ||
"BackendA_linear_0", | ||
} | ||
dequantize_per_tensor_tags = { | ||
"BackendA_adaptive_avg_pool2d_0", | ||
"BackendA_conv2d_0", | ||
"BackendA_linear_0", | ||
} | ||
dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"} | ||
node_tags = { | ||
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags, | ||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags, | ||
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags, | ||
} | ||
self._test_metadata_porting( | ||
TestHelperModules.Conv2dWithObsSharingOps(), | ||
example_inputs, | ||
BackendAQuantizer(), | ||
node_tags, | ||
) | ||
|
||
def test_metadata_porting_with_no_quant_inbetween(self): | ||
""" | ||
Model under test | ||
conv2d -> avgpool -> hardtanh -> linear | ||
Dont quantize avgpool | ||
Check quantization tags on conv2d and linear are correctly set | ||
""" | ||
|
||
class BackendAQuantizer(Quantizer): | ||
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: | ||
backend_string = "BackendA" | ||
quantization_config = get_symmetric_quantization_config( | ||
is_per_channel=True | ||
) | ||
annotated_partitions = OP_TO_ANNOTATOR["linear"]( | ||
gm, quantization_config | ||
) | ||
_tag_partitions(backend_string, "linear", annotated_partitions) | ||
annotated_partitions = OP_TO_ANNOTATOR["conv2d"]( | ||
gm, quantization_config | ||
) | ||
_tag_partitions(backend_string, "conv2d", annotated_partitions) | ||
|
||
def validate(self, model: torch.fx.GraphModule) -> None: | ||
pass | ||
|
||
example_inputs = (torch.randn(1, 3, 5, 5),) | ||
quantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"} | ||
dequantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"} | ||
dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"} | ||
node_tags = { | ||
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags, | ||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags, | ||
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags, | ||
} | ||
self._test_metadata_porting( | ||
TestHelperModules.Conv2dWithObsSharingOps(), | ||
example_inputs, | ||
BackendAQuantizer(), | ||
node_tags, | ||
) | ||
|
||
@unittest.skip("Temporarily disabled") | ||
def test_metadata_porting_for_dq(self): | ||
""" | ||
Model under test | ||
conv2d -> avgpool -> hardtanh -> linear | ||
Quantize all except linear. | ||
Quantize linear with dynamic quantization | ||
Check quantization tags on conv2d, avgpool and linear are correctly set | ||
""" | ||
|
||
class BackendAQuantizer(Quantizer): | ||
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: | ||
backend_string = "BackendA" | ||
# static quantiazation | ||
quantization_config = get_symmetric_quantization_config( | ||
is_per_channel=True | ||
) | ||
annotated_partitions = OP_TO_ANNOTATOR["conv2d"]( | ||
gm, quantization_config | ||
) | ||
_tag_partitions(backend_string, "conv2d", annotated_partitions) | ||
annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"]( | ||
gm, quantization_config | ||
) | ||
_tag_partitions( | ||
backend_string, "adaptive_avg_pool2d", annotated_partitions | ||
) | ||
|
||
# dynamic quantization | ||
quantization_config_dynamic = get_symmetric_quantization_config( | ||
is_per_channel=True, is_dynamic=True | ||
) | ||
annotated_partitions = OP_TO_ANNOTATOR["linear"]( | ||
gm, quantization_config_dynamic | ||
) | ||
_tag_partitions(backend_string, "linear_dynamic", annotated_partitions) | ||
|
||
def validate(self, model: torch.fx.GraphModule) -> None: | ||
pass | ||
|
||
example_inputs = (torch.randn(1, 3, 5, 5),) | ||
quantize_per_tensor_tags = { | ||
"BackendA_conv2d_0", | ||
"BackendA_adaptive_avg_pool2d_0", | ||
} | ||
quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"} | ||
choose_qparams_tensor_tensor_tags = {"BackendA_linear_dynamic_0"} | ||
dequantize_per_tensor_tags = { | ||
"BackendA_adaptive_avg_pool2d_0", | ||
"BackendA_conv2d_0", | ||
} | ||
dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"} | ||
dequantize_per_channel_tags = { | ||
"BackendA_conv2d_0", | ||
"BackendA_linear_dynamic_0", | ||
} | ||
node_tags = { | ||
torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags, | ||
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags, | ||
torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags, | ||
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags, | ||
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags, | ||
torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tensor_tags, | ||
} | ||
self._test_metadata_porting( | ||
TestHelperModules.Conv2dWithObsSharingOps(), | ||
example_inputs, | ||
BackendAQuantizer(), | ||
node_tags, | ||
) | ||
|
||
def test_metadata_porting_for_two_dq(self): | ||
""" | ||
Model under test | ||
conv2d -> avgpool -> hardtanh -> linear | ||
Quantize linear and conv with dynamic quantization | ||
Check quantization tags on conv2d, avgpool and linear are correctly set | ||
""" | ||
|
||
class BackendAQuantizer(Quantizer): | ||
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: | ||
backend_string = "BackendA" | ||
|
||
# dynamic quantization | ||
quantization_config_dynamic = get_symmetric_quantization_config( | ||
is_per_channel=True, is_dynamic=True | ||
) | ||
annotated_partitions = OP_TO_ANNOTATOR["conv2d"]( | ||
gm, quantization_config_dynamic | ||
) | ||
_tag_partitions(backend_string, "conv2d_dynamic", annotated_partitions) | ||
annotated_partitions = OP_TO_ANNOTATOR["linear"]( | ||
gm, quantization_config_dynamic | ||
) | ||
_tag_partitions(backend_string, "linear_dynamic", annotated_partitions) | ||
|
||
def validate(self, model: torch.fx.GraphModule) -> None: | ||
pass | ||
|
||
example_inputs = (torch.randn(1, 3, 5, 5),) | ||
choose_qparams_tensor_tags = { | ||
"BackendA_conv2d_dynamic_0", | ||
"BackendA_linear_dynamic_0", | ||
} | ||
quantize_per_tensor_tensor_tags = { | ||
"BackendA_conv2d_dynamic_0", | ||
"BackendA_linear_dynamic_0", | ||
} | ||
dequantize_per_tensor_tensor_tags = { | ||
"BackendA_conv2d_dynamic_0", | ||
"BackendA_linear_dynamic_0", | ||
} | ||
dequantize_per_channel_tags = { | ||
"BackendA_conv2d_dynamic_0", | ||
"BackendA_linear_dynamic_0", | ||
} | ||
node_tags = { | ||
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags, | ||
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags, | ||
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags, | ||
torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags, | ||
} | ||
self._test_metadata_porting( | ||
TestHelperModules.Conv2dWithObsSharingOps(), | ||
example_inputs, | ||
BackendAQuantizer(), | ||
node_tags, | ||
) | ||
|
||
def test_metadata_porting_for_dq_no_static_q(self): | ||
""" | ||
Model under test | ||
conv2d -> avgpool -> hardtanh -> linear | ||
Dont quantize anything except linear. | ||
Quantize linear with dynamic quantization | ||
Check quantization tags on conv2d, avgpool and linear are correctly set | ||
""" | ||
|
||
class BackendAQuantizer(Quantizer): | ||
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: | ||
backend_string = "BackendA" | ||
# dynamic quantization | ||
quantization_config_dynamic = get_symmetric_quantization_config( | ||
is_per_channel=True, is_dynamic=True | ||
) | ||
annotated_partitions = OP_TO_ANNOTATOR["linear"]( | ||
gm, quantization_config_dynamic | ||
) | ||
_tag_partitions(backend_string, "linear_dynamic", annotated_partitions) | ||
|
||
def validate(self, model: torch.fx.GraphModule) -> None: | ||
pass | ||
|
||
example_inputs = (torch.randn(1, 3, 5, 5),) | ||
choose_qparams_tensor_tags = {"BackendA_linear_dynamic_0"} | ||
quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"} | ||
dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"} | ||
dequantize_per_channel_tags = {"BackendA_linear_dynamic_0"} | ||
node_tags = { | ||
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags, | ||
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags, | ||
torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags, | ||
torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags, | ||
} | ||
self._test_metadata_porting( | ||
TestHelperModules.Conv2dWithObsSharingOps(), | ||
example_inputs, | ||
BackendAQuantizer(), | ||
node_tags, | ||
) |
Oops, something went wrong.