55 changes: 0 additions & 55 deletions test/quantization/fx/test_quantize_fx.py
@@ -121,10 +121,6 @@
StandaloneModuleConfigEntry,
)

from torch.ao.quantization.fx.prepare import (
is_activation_post_process_node,
)

from torch.ao.quantization.fx.qconfig_utils import (
maybe_adjust_qconfig_for_module_name_object_type_order,
)
@@ -6544,57 +6540,6 @@ def forward(self, x):
])
m3(*example_inputs)

def test_getitem_wrapped_in_observers(self):
"""
Test that, for cases when there are observers around a getitem node:
(1) These observers are the same, and
(2) The pattern (dequant - getitem - quant) will be fused during the lowering step
"""
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv1d(in_channels=5, out_channels=5, kernel_size=5, padding=0)
self.conv2 = torch.nn.Conv1d(in_channels=5, out_channels=5, kernel_size=5, padding=0)

def forward(self, inputs):
# inputs: [1, 5, 10]
x1 = self.conv1(inputs)
# x: [1, 5, 6]
x1 = x1 + inputs[:, :, -6:]
x2 = self.conv2(x1)
x2 = x2 + x1[:, :, -2:]
return x2

m = M()
m.eval()
qconfig_mapping = get_default_qconfig_mapping()
m = prepare_fx(m, qconfig_mapping, example_inputs=torch.rand(1, 5, 10))

# Input and output observers of getitem should be the same
modules = dict(m.named_modules(remove_duplicate=False))
for n in m.graph.nodes:
if not is_activation_post_process_node(n, modules):
continue
if n.args[0].op != "call_function" or n.args[0].target != operator.getitem:
continue
getitem_node = n.args[0]
input_observer_node = getitem_node.args[0]
output_observer_node = n
if not is_activation_post_process_node(input_observer_node, modules):
continue
input_observer = getattr(m, input_observer_node.name)
output_observer = getattr(m, output_observer_node.name)
self.assertTrue(input_observer is output_observer,
"Input observer %s for %s is not the same as output observer %s" %
(input_observer_node.name, getitem_node.name, output_observer_node.name))

m(torch.rand(1, 5, 10))
m = convert_fx(m)

# There should only be one dequantize at the end
self.checkGraphModuleNodes(m, expected_node_occurrence={
ns.call_method("dequantize") : 1,
})

@skipIfNoFBGEMM
def test_fixed_qparams_ops(self):
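For context only (this block is not part of the diff): a minimal sketch of the behavior the removed test_getitem_wrapped_in_observers exercised, using an illustrative TinyM module in place of the test's M. The entry points shown (get_default_qconfig_mapping, prepare_fx, convert_fx) are the same ones the removed test calls, and the node count at the end approximates its checkGraphModuleNodes assertion.

```python
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

class TinyM(torch.nn.Module):
    """Illustrative module with the same shape of pattern as the removed test's M:
    a conv output added to a slice (getitem) of the conv's input."""
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(in_channels=5, out_channels=5, kernel_size=5, padding=0)

    def forward(self, x):
        y = self.conv(x)          # [1, 5, 6] for a [1, 5, 10] input
        return y + x[:, :, -6:]   # the slice shows up as an operator.getitem node in the FX graph

example_inputs = (torch.rand(1, 5, 10),)
m = prepare_fx(TinyM().eval(), get_default_qconfig_mapping(), example_inputs=example_inputs)
m(*example_inputs)  # calibrate the inserted observers
m = convert_fx(m)

# The removed test asserted that only one dequantize survives lowering when the
# (dequant - getitem - quant) pattern around the slice is fused.
num_dequantize = sum(
    1 for n in m.graph.nodes
    if n.op == "call_method" and n.target == "dequantize"
)
print("dequantize nodes after convert_fx:", num_dequantize)
```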
5 changes: 2 additions & 3 deletions torch/ao/quantization/fx/_lower_to_native_backend.py
@@ -111,7 +111,6 @@ def is_copy_node(node, modules):
torch.flatten,
torch.mean,
operator.floordiv,
operator.getitem
]
method_list = [
"clamp",
@@ -932,7 +931,7 @@ def special_pattern_replacement(model: QuantizedGraphModule):

return model

def _lower_getattr_tensor_metadata_op(model: QuantizedGraphModule):
def _lower_getattr_tensor_metadta_op(model: QuantizedGraphModule):
""" Modified the graph of the model inplace, to skip extra dequantize op before
the general tensor shape ops when possible
"""
@@ -961,7 +960,7 @@ def _lower_to_native_backend(
_lower_static_weighted_ref_functional(model, qconfig_map)
_lower_dynamic_weighted_ref_functional(model, qconfig_map)
_lower_quantized_binary_op(model, qconfig_map)
_lower_getattr_tensor_metadata_op(model)
_lower_getattr_tensor_metadta_op(model)
special_pattern_replacement(model)
model = fold_weight(model, node_name_to_scope)
model.graph.eliminate_dead_code()
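For context only (not part of the diff, and not the verbatim source of _lower_to_native_backend.py): a sketch of how membership in the func_list / method_list shown above typically drives a copy-node check. The list contents and the helper name looks_like_copy_node are assumptions; the real lists are longer, and the real is_copy_node also handles call_module nodes. With operator.getitem dropped from func_list by this diff, getitem is no longer treated as a copy node during lowering.

```python
import operator
import torch
from torch.fx import Node

# Only the entries visible in the hunk above; the real lists contain more ops.
func_list = [torch.flatten, torch.mean, operator.floordiv]  # operator.getitem removed by this diff
method_list = ["clamp"]

def looks_like_copy_node(node: Node) -> bool:
    """A "copy" node only rearranges or selects data, so lowering can let it
    share quantization parameters with its input instead of keeping a
    dequantize/quantize pair around it."""
    if node.op == "call_function" and node.target in func_list:
        return True
    if node.op == "call_method" and node.target in method_list:
        return True
    return False
```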
25 changes: 1 addition & 24 deletions torch/ao/quantization/fx/prepare.py
@@ -1314,12 +1314,7 @@ def insert_observers_for_model(
node_name_to_target_dtype, qconfig_map,
model, modules, graph)

# Second pass: Look for getitem nodes and make the input and output observers the same.
# Note: This is meant to be a workaround for the lack of dtype propagation. In the future,
# we should remove this pass if we can differentiate between tensors and non-tensors
# (e.g. dictionaries, lists) as getitem arguments.
_make_getitem_share_input_output_observers(model)

#
# After this point, the current node has input and output observers
# that it needs for itself inserted.
#
@@ -1334,24 +1329,6 @@

return results_node

def _make_getitem_share_input_output_observers(model: GraphModule):
"""
For patterns (obs0 - getitem - obs1), make the output observer the same as the input observer,
such that the new pattern becomes (obs0 - getitem - obs0). Note that this does not handle
patterns with multiple nodes between the two observers, e.g. (obs0 - reshape - getitem - obs1).
"""
modules = dict(model.named_modules(remove_duplicate=False))
for node in model.graph.nodes:
if not is_activation_post_process_node(node, modules):
continue
if node.args[0].op != "call_function" or node.args[0].target != operator.getitem:
continue
getitem_node = node.args[0]
assert(isinstance(getitem_node, Node))
if not is_activation_post_process_node(getitem_node.args[0], modules):
continue
maybe_make_input_output_share_observers(getitem_node, model, modules)

def _validate_fixed_qparams_qconfigs(model: GraphModule, qconfig_map: Dict[str, QConfigAny]):
"""
Validate whether the correct observers are configured for fixed qparams ops in the model, if any.
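For context only (not part of the diff, and not the library's implementation): a rough reconstruction of the (obs0 - getitem - obs1) -> (obs0 - getitem - obs0) rewrite performed by the removed _make_getitem_share_input_output_observers pass. The helper name share_getitem_observers is hypothetical, and is_observer is a heuristic stand-in for is_activation_post_process_node that only recognizes ObserverBase submodules.

```python
import operator
import torch
from torch.fx import GraphModule, Node

def share_getitem_observers(model: GraphModule) -> None:
    """Hypothetical helper: for every (observer -> getitem -> observer) chain,
    retarget the output observer node at the input observer submodule so both
    ends of the getitem share the same observer instance (and hence qparams)."""
    modules = dict(model.named_modules(remove_duplicate=False))

    def is_observer(n: object) -> bool:
        # Heuristic stand-in for is_activation_post_process_node.
        return (
            isinstance(n, Node)
            and n.op == "call_module"
            and isinstance(modules.get(str(n.target)),
                           torch.ao.quantization.ObserverBase)
        )

    for node in model.graph.nodes:
        if not is_observer(node):
            continue
        getitem = node.args[0]
        if not (isinstance(getitem, Node)
                and getitem.op == "call_function"
                and getitem.target == operator.getitem):
            continue
        if not is_observer(getitem.args[0]):
            continue
        # Point the output observer at the same submodule as the input observer,
        # turning (obs0 - getitem - obs1) into (obs0 - getitem - obs0).
        node.target = getitem.args[0].target
    model.recompile()
```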