diff --git a/torchao/quantization/pt2e/prepare.py b/torchao/quantization/pt2e/prepare.py index ef9bb297de..fe869373d4 100644 --- a/torchao/quantization/pt2e/prepare.py +++ b/torchao/quantization/pt2e/prepare.py @@ -585,6 +585,7 @@ def _maybe_insert_input_and_output_observers_for_node( node: Node, model: torch.fx.GraphModule, obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + named_modules: dict[str, torch.nn.Module], is_qat: bool, model_device: Optional[torch.device] = None, ): @@ -594,7 +595,6 @@ def _maybe_insert_input_and_output_observers_for_node( if this_node_quantization_annotation is None: return - named_modules = dict(model.named_modules(remove_duplicate=False)) _maybe_insert_input_observers_for_node( node, None, # qconfig @@ -666,6 +666,7 @@ def prepare( if obs_or_fq_callback: obs_or_fq_callback(model, obs_or_fq_map) model_device = _assert_and_get_unique_device(model) + named_modules = dict(model.named_modules(remove_duplicate=False)) for node in nodes_before_observation: # TODO: simplify logic for inserting observers @@ -673,6 +674,7 @@ def prepare( node, model, obs_or_fq_map, + named_modules, is_qat, model_device, )