test/quantization/fx/test_quantize_pt2e.py (2 changes: 0 additions & 2 deletions)
@@ -265,10 +265,8 @@ def forward(self, x):
         )

         m = prepare_pt2e_quantizer(m, quantizer)
-        print("after prepare:", m)
         m(*example_inputs)
         m = convert_pt2e(m)
-        print("m:", m)
         node_occurrence = {
             # input and output are using quantize_per_tensor and weight is using quantize_per_channel
             ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 5,
torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py (8 changes: 0 additions & 8 deletions)
@@ -258,14 +258,6 @@ def set_spec_for_operator_type(
     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
         """ just handling global spec for now
         """
-        # initialize default target_dtype_info
-        _DEFAULT_TARGET_DTYPE_INFO = {
-            "input_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
-            "output_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
-        }
-        for node in model.graph.nodes:
-            node.meta["target_dtype_info"] = copy.deepcopy(_DEFAULT_TARGET_DTYPE_INFO)
-
         global_spec = self.operator_spec_config.global_spec
         ops = self.get_supported_operator_for_operator_spec(global_spec)
         # annotate the nodes from last to first since the matching is in the reversed order
torch/ao/quantization/fx/prepare.py (9 changes: 7 additions & 2 deletions)
@@ -115,6 +115,8 @@
 # list of dtypes to not add observers to
 _DO_NOT_OBS_DTYPE_LIST = [int, float, torch.bool, None]

+_DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float)
Contributor: Does this also work as FQ?

Contributor Author: Yeah, this works for FQ as well. It is not inserted into the graph; only the dtype is used here.
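As a follow-up note (not part of the diff): a minimal sketch of that point, assuming the standard PlaceholderObserver API. The _peek_dtype helper is hypothetical; it only illustrates that the factory is instantiated to read its configured dtype and is never attached to the graph, which is why an observer or a FakeQuantize factory behaves the same here.

import torch
from torch.ao.quantization.observer import PlaceholderObserver

# Same default as the one added to prepare.py in this diff.
_DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float)

def _peek_dtype(obs_or_fq_ctr):
    # Hypothetical helper: build a throwaway instance only to read the
    # configured dtype; nothing is inserted into the model graph.
    return obs_or_fq_ctr().dtype

assert _peek_dtype(_DEFAULT_FP32_OBS_OR_FQ_CTR) == torch.float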


 # note: the following default target dtype info dicts are temporary,
 # should be moved to the new programmable API class soon
 _DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO = {

@@ -497,8 +499,11 @@ def _get_arg_target_dtype_as_output(
         assert isinstance(observed_arg, Node), "Currently we only support observing Node"
         output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
     else:
-        output_act_obs_or_fq_ctr = \
-            arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
+        if "target_dtype_info" in arg.meta:
+            output_act_obs_or_fq_ctr = \
+                arg.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
+        else:
+            output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR
     output_act_dtype, _ = _get_dtype_and_is_dynamic(output_act_obs_or_fq_ctr)
     # TODO: should support is_dynamic here as well
     return output_act_dtype
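For readers skimming the new else-branch in _get_arg_target_dtype_as_output, here is a hedged, self-contained sketch of the fallback behavior. The _ToyNode class and _output_obs_or_fq_ctr helper are hypothetical stand-ins (the real code reads the meta dict of torch.fx.Node); they only mirror the branch added above.

import torch
from torch.ao.quantization.observer import PlaceholderObserver

_DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float)

class _ToyNode:
    # Hypothetical stand-in for torch.fx.Node; only the .meta dict matters here.
    def __init__(self, meta):
        self.meta = meta

def _output_obs_or_fq_ctr(arg):
    # Mirrors the new else-branch: a node with no annotation falls back to
    # the fp32 placeholder instead of raising a KeyError.
    if "target_dtype_info" in arg.meta:
        return arg.meta["target_dtype_info"].get(
            "output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
    return _DEFAULT_FP32_OBS_OR_FQ_CTR

# An unannotated node resolves to fp32 ...
assert _output_obs_or_fq_ctr(_ToyNode(meta={}))().dtype == torch.float
# ... and an explicitly annotated node is still respected.
annotated = _ToyNode(meta={"target_dtype_info": {
    "output_act_obs_or_fq_ctr": PlaceholderObserver.with_args(dtype=torch.quint8)}})
assert _output_obs_or_fq_ctr(annotated)().dtype == torch.quint8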