torch/ao/quantization/pt2e/prepare.py: 71 changes (40 additions, 31 deletions)
@@ -66,6 +66,7 @@ def _update_shared_with(edge_or_node: EdgeOrNode, qspec: QuantizationSpecBase, s
         # qspec for a = SharedQuantizationSpec(b) means `a` points to `b`
         _union(sharing_with, edge_or_node, shared_with_map)
 
+# TODO: simplify this
 def _find_root_qspec(
     qspec: QuantizationSpecBase,
     edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
@@ -113,6 +114,24 @@ def _get_edge_or_node_to_qspec(model: torch.fx.GraphModule) -> Dict[EdgeOrNode,
             edge_or_node_to_qspec[output_node] = qspec
     return edge_or_node_to_qspec
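For readers coming to this diff without the annotation side in mind: the map returned above is populated from the `quantization_annotation` metadata that a quantizer attaches to nodes, and `SharedQuantizationSpec(b)` is how an annotation says "point me at `b`". A minimal sketch of what such an annotation can look like, using the public `QuantizationSpec` / `SharedQuantizationSpec` / `QuantizationAnnotation` classes; the node names (`x_node`, `y_node`, `add_node`) are hypothetical and this is illustrative only, not code from this PR:

import torch
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    SharedQuantizationSpec,
)

def annotate_add(add_node: torch.fx.Node) -> None:
    # Hypothetical: `add_node` has two tensor inputs, x_node and y_node.
    x_node, y_node = add_node.args[0], add_node.args[1]
    act_qspec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
        is_dynamic=False,
        observer_or_fake_quant_ctr=HistogramObserver,
    )
    # "points to" the (x_node, add_node) input edge: y's input edge and the
    # output of add should reuse whatever observer that edge ends up with.
    share_with_x = SharedQuantizationSpec((x_node, add_node))
    add_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={x_node: act_qspec, y_node: share_with_x},
        output_qspec=share_with_x,
        _annotated=True,
    )

Under the sharing resolution in this file, such an annotation puts `(y_node, add_node)` and the output of `add_node` into the same observer group as `(x_node, add_node)`.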

+def _union_input_edge_with(input_edge, input_edge_root_qspec, edge_or_node, edge_or_node_to_qspec, shared_with_map):
+    root_qspec = None
+    if edge_or_node in edge_or_node_to_qspec:
+        qspec = edge_or_node_to_qspec[edge_or_node]
+        root_qspec = _find_root_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
+    # TODO: add assertions for types of root qspecs
+    if (
+        root_qspec is not None and
+        _has_same_dtype(root_qspec, input_edge_root_qspec) and
+        _has_same_is_dynamic(root_qspec, input_edge_root_qspec)
+    ):
+        # the input arg to the node should reuse the existing output observer for arg
+        # since dtype is the same (we may want to extend this to be a more strict check
+        # in the future)
+        # so we point from `input_edge` to `arg` (output of the argument)
+        _union(edge_or_node, input_edge, shared_with_map)
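To make the pointer bookkeeping above concrete, here is a minimal standalone sketch of the union-find idea behind `_union` / `_find_root_qspec` and `shared_with_map`; the helper names and string keys are stand-ins for illustration, not the actual torch internals:

# Standalone sketch: each edge/node maps to the edge/node it shares observers with;
# following the chain gives the root, which identifies one sharing group.
def find_root(key, shared_with_map):
    # Follow parent pointers until we reach an entry that is its own root.
    while shared_with_map.get(key, key) != key:
        key = shared_with_map[key]
    return key

def union(parent, child, shared_with_map):
    # Point `child`'s root at `parent`'s root so both land in the same group.
    shared_with_map[find_root(child, shared_with_map)] = find_root(parent, shared_with_map)

# Hypothetical graph: conv -> relu. The (conv, relu) input edge is unioned with
# conv's output, mirroring the "sharing with previous output" case above.
shared_with_map = {}
union("conv", ("conv", "relu"), shared_with_map)
assert find_root(("conv", "relu"), shared_with_map) == find_root("conv", shared_with_map)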


 def _get_edge_or_node_to_group_id(edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase]) -> Dict[EdgeOrNode, int]:
     """Map from edge/node to the group ID, generated from quantization annotations,
     edge/node with the same group ID should use the same observer/fake_quant instance
@@ -179,21 +198,23 @@ def _get_edge_or_node_to_group_id(edge_or_node_to_qspec: Dict[EdgeOrNode, Quanti
             assert isinstance(input_edge, tuple)
             arg, n = input_edge
             if n.meta["quantization_annotation"].allow_implicit_sharing:
-                arg_as_output_root_qspec = None
-                if arg in edge_or_node_to_qspec:
-                    arg_as_output_qspec = edge_or_node_to_qspec[arg]
-                    arg_as_output_root_qspec = _find_root_qspec(arg_as_output_qspec, edge_or_node_to_qspec, shared_with_map)
-                # TODO: add assertions for types of root qspecs
-                if (
-                    arg_as_output_root_qspec is not None and
-                    _has_same_dtype(arg_as_output_root_qspec, input_edge_root_qspec) and
-                    _has_same_is_dynamic(arg_as_output_root_qspec, input_edge_root_qspec)
-                ):
-                    # the input arg to the node should reuse the existing output observer for arg
-                    # since dtype is the same (we may want to extend this to be a more strict check
-                    # in the future)
-                    # so we point from `input_edge` to `arg` (output of the argument)
-                    _union(arg, input_edge, shared_with_map)
+                # sharing with previous output
+                _union_input_edge_with(input_edge, input_edge_root_qspec, arg, edge_or_node_to_qspec, shared_with_map)
+
+                # sharing with other users of the previous output
+                # (arg, user)
+                for user in arg.users:
Review thread on the `for user in arg.users:` line:
  Contributor: This is not a refactor change, right?
  jerryzh168 (Contributor, Author), Nov 15, 2023: this is a refactor, we removed this logic from insert observers and added it here
+                    if user is n:
+                        continue
+                    arg_to_user_edge = (arg, user)
+                    _union_input_edge_with(
+                        input_edge,
+                        input_edge_root_qspec,
+                        arg_to_user_edge,
+                        edge_or_node_to_qspec,
+                        shared_with_map
+                    )

             _update_shared_with(input_edge, qspec, shared_with_map)
 
     # now that we get the sharing relations between all edges and nodes, we can assign group ids
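As a rough illustration of that last step (not the PR's code), group ids can be read off the union-find roots: everything that resolves to the same root gets one id, and each id later corresponds to one observer/fake_quant instance. A self-contained sketch with hypothetical string keys:

# Sketch: map every edge/node to a group id derived from its union-find root.
def find_root(key, shared_with_map):
    while shared_with_map.get(key, key) != key:
        key = shared_with_map[key]
    return key

def assign_group_ids(keys, shared_with_map):
    root_to_group_id, group_ids = {}, {}
    for key in keys:
        root = find_root(key, shared_with_map)
        if root not in root_to_group_id:
            root_to_group_id[root] = len(root_to_group_id)  # first time we see this root: new group
        group_ids[key] = root_to_group_id[root]
    return group_ids

# Continuing the earlier hypothetical graph: conv's output and relu's input edge share a group.
shared_with_map = {("conv", "relu"): "conv"}
print(assign_group_ids(["conv", ("conv", "relu"), "relu"], shared_with_map))
# -> {'conv': 0, ('conv', 'relu'): 0, 'relu': 1}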
Expand Down Expand Up @@ -281,10 +302,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
     # otherwise, we'll insert a new observer/fake_quant node
 
     existing_obs_node = None
-    # skip inserting new observers if there is an observer inserted for the arg before
-    # that has the same dtype that we want to insert here
-    # alternatively we could have a dedup pass after we insert all observers to deduplicate
-    # observers
+    # skip inserting new observers if the same observer instance is inserted before for another user
     # Example:
     # conv1 -> obs1 -> existing_obs -> conv2
     #                              \ -> conv3
@@ -296,19 +314,10 @@
         if not _is_activation_post_process_node(maybe_obs_node, named_modules):
             continue
         maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
-        if (
-            type(maybe_obs_mod) == type(input_edge_obs_or_fq) and
-            maybe_obs_mod.dtype == input_edge_obs_or_fq.dtype
-        ):
-            input_edge_obs_or_fq = maybe_obs_mod  # type: ignore[assignment]
-            existing_obs_node = maybe_obs_node
-            break
-
-    if existing_obs_node is None:
-        new_arg = _insert_obs_or_fq(arg, input_edge_obs_or_fq, model, named_modules, model.graph)
-    else:
-        new_arg = existing_obs_node
+        if id(maybe_obs_mod) == id(input_edge_obs_or_fq):
+            return maybe_obs_node
 
+    new_arg = _insert_obs_or_fq(arg, input_edge_obs_or_fq, model, named_modules, model.graph)
     return new_arg
 
 def _maybe_insert_input_observers_for_node(
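The last hunk changes when an already-inserted observer node is reused for another user of the same arg: the old code accepted any existing observer of the same type and dtype, while the new code returns early only when the arg already has an observer node wrapping the exact same observer/fake_quant instance (object identity), and otherwise always inserts a fresh node. A standalone sketch of that check, with hypothetical stand-ins for the graph users and module lookup rather than the real fx APIs:

# Sketch: reuse an existing observer node only if it wraps the exact same
# observer object; otherwise insert a new one.
class FakeObserver:
    def __init__(self, dtype):
        self.dtype = dtype

def maybe_reuse_or_insert(arg_users, named_modules, wanted_obs):
    for user in arg_users:                 # nodes that already consume `arg`
        mod = named_modules.get(user)      # stand-in for looking up user.target
        if mod is not None and id(mod) == id(wanted_obs):
            return user                    # same instance: reuse the existing node
    return "new_observer_node"             # different instance (even same type/dtype): insert new

shared = FakeObserver("int8")
named_modules = {"obs_for_conv2": shared}
print(maybe_reuse_or_insert(["obs_for_conv2"], named_modules, shared))                   # reused
print(maybe_reuse_or_insert(["obs_for_conv2"], named_modules, FakeObserver("int8")))     # new node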