83 changes: 34 additions & 49 deletions torch/quantization/fx/quantize.py
@@ -412,58 +412,43 @@ def maybe_insert_output_observer_for_node(
     root_node, matched_nodes, pattern, qhandler, qconfig = matches.get(
         node.name, (None, None, None, None, None))
 
-    if qhandler is not None:
-        assert qconfig is not None
+    if qhandler is None:
+        return None
 
-        is_standalone_module = qhandler is not None and \
-            isinstance(qhandler, StandaloneModuleQuantizeHandler)
-
-        should_insert_observer = \
-            qhandler.should_insert_observer_for_output(
-                qconfig, model.training)
-        # TODO(future PR): move the following logic to
-        # should_insert_observer_for_output
-        should_insert_observer = should_insert_observer and \
-            activation_is_statically_quantized(qconfig)
-
-        # we never insert observers to output of standalone module, we assume
-        # if needed, they are inserted inside the standalone module
-        should_insert_observer = should_insert_observer and \
-            (not is_standalone_module)
-
-        if should_insert_observer:
-            act_post_process_ctr = qconfig.activation
-            if activation_is_int8_quantized(qconfig):
-                act_post_process_ctr = \
-                    get_default_output_activation_post_process_map().get(
-                        matched_pattern,
-                        act_post_process_ctr)
-            observer = act_post_process_ctr()
-            new_obs = insert_observer(node, observer, model, modules, graph)
-            # set the type, so the next node can read it
-            node_name_to_target_dtype[new_obs.name] = \
-                node_name_to_target_dtype[node.name]
-            return new_obs
+    assert qconfig is not None
+    assert node.op != 'output', 'observer insertion for outputs is handled elsewhere'
 
-    elif node.op == 'output':
-        prev_node = node.args[0]
-        assert isinstance(prev_node, Node)
-        prev_node_dtype = node_name_to_target_dtype[prev_node.name]
-        node_dtype = node_name_to_target_dtype[node.name]
-        should_insert_observer = (
-            prev_node_dtype == torch.float and
-            node_dtype != torch.float
-        )
-        if should_insert_observer:
-            assert qconfig is not None
-            observer = qconfig.activation()
-            new_obs = insert_observer(
-                prev_node, observer, model, modules, graph)
-            # set the type, so the next node can read it
-            node_name_to_target_dtype[new_obs.name] = node_dtype
-            return new_obs
+    is_standalone_module = qhandler is not None and \
+        isinstance(qhandler, StandaloneModuleQuantizeHandler)
 
-    return None
+    should_insert_observer = \
+        qhandler.should_insert_observer_for_output(
+            qconfig, model.training)
+    # TODO(future PR): move the following logic to
+    # should_insert_observer_for_output
+    should_insert_observer = should_insert_observer and \
+        activation_is_statically_quantized(qconfig)
+
+    # we never insert observers to output of standalone module, we assume
+    # if needed, they are inserted inside the standalone module
+    should_insert_observer = should_insert_observer and \
+        (not is_standalone_module)
+
+    if should_insert_observer:
+        act_post_process_ctr = qconfig.activation
+        if activation_is_int8_quantized(qconfig):
+            act_post_process_ctr = \
+                get_default_output_activation_post_process_map().get(
+                    matched_pattern,
+                    act_post_process_ctr)
+        observer = act_post_process_ctr()
+        new_obs = insert_observer(node, observer, model, modules, graph)
+        # set the type, so the next node can read it
+        node_name_to_target_dtype[new_obs.name] = \
+            node_name_to_target_dtype[node.name]
+        return new_obs
+    else:
+        return None
 
 def maybe_insert_observers_before_graph_output(
     graph_output_node: Node,
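Note on the shape of this hunk: it inverts the `qhandler` None-check into an early return, dedents the main body one level (the removed and added statements are otherwise line-for-line identical), makes the fall-through explicit with a trailing `else: return None`, and moves graph-output handling out of this function; the new assert records that outputs are now covered by `maybe_insert_observers_before_graph_output` below. A minimal sketch of the same guard-clause refactor on a toy function follows; all names in it are hypothetical illustrations, not the PyTorch API.

from typing import Optional

# Before: the real work is nested under a None check, and an unrelated
# 'output' case shares the same function (hypothetical toy example).
def pick_observer_nested(handler: Optional[object], op: str) -> Optional[str]:
    if handler is not None:
        if op != 'output':
            return 'observer'
    elif op == 'output':
        return 'output_observer'
    return None

# After: guard clauses come first and the main path sits at a single
# indentation level, mirroring the structure of the rewritten function.
def pick_observer_flat(handler: Optional[object], op: str) -> Optional[str]:
    if handler is None:
        return None
    # the 'output' case now lives in a dedicated function, mirroring the
    # new assert in the hunk above
    assert op != 'output', 'outputs handled elsewhere'
    return 'observer'

The payoff is readability rather than behavior: every statement after the guards can assume the handler and config are present, so the body needs no further None checks.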