From ade34e0b4615840e7dbca9c41c8caf16b354738f Mon Sep 17 00:00:00 2001
From: Naveen Suda
Date: Sun, 12 Oct 2025 21:04:15 -0700
Subject: [PATCH] call named_modules once per model prepare

Summary:
Previously, dict(model.named_modules()) was called for every node in the
model, which is very slow, especially for LLMs with a large number of
nodes. Build the dict once per model prepare and pass it into the
per-node helper instead.

Differential Revision: D84389318
---
 torchao/quantization/pt2e/prepare.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torchao/quantization/pt2e/prepare.py b/torchao/quantization/pt2e/prepare.py
index ef9bb297de..fe869373d4 100644
--- a/torchao/quantization/pt2e/prepare.py
+++ b/torchao/quantization/pt2e/prepare.py
@@ -585,6 +585,7 @@ def _maybe_insert_input_and_output_observers_for_node(
     node: Node,
     model: torch.fx.GraphModule,
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
+    named_modules: dict[str, torch.nn.Module],
     is_qat: bool,
     model_device: Optional[torch.device] = None,
 ):
@@ -594,7 +595,6 @@ def _maybe_insert_input_and_output_observers_for_node(
     if this_node_quantization_annotation is None:
         return
 
-    named_modules = dict(model.named_modules(remove_duplicate=False))
     _maybe_insert_input_observers_for_node(
         node,
         None,  # qconfig
@@ -666,6 +666,7 @@ def prepare(
     if obs_or_fq_callback:
         obs_or_fq_callback(model, obs_or_fq_map)
     model_device = _assert_and_get_unique_device(model)
+    named_modules = dict(model.named_modules(remove_duplicate=False))
 
     for node in nodes_before_observation:
         # TODO: simplify logic for inserting observers
@@ -673,6 +674,7 @@ def prepare(
             node,
             model,
             obs_or_fq_map,
+            named_modules,
             is_qat,
             model_device,
         )
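
Note (outside the patch): a minimal sketch of the pattern this change applies, using hypothetical names (prepare_sketch, _process_node_sketch); the real functions are prepare and _maybe_insert_input_and_output_observers_for_node in torchao/quantization/pt2e/prepare.py. The name-to-module dict is built once per prepare and passed to the per-node helper, turning an O(num_nodes * num_modules) cost into roughly O(num_nodes + num_modules).

import torch
import torch.fx


def prepare_sketch(model: torch.fx.GraphModule) -> None:
    # Build the name -> module mapping once per prepare (the patched behavior),
    # instead of rebuilding it inside the per-node helper for every node.
    named_modules = dict(model.named_modules(remove_duplicate=False))
    for node in model.graph.nodes:
        _process_node_sketch(node, named_modules)


def _process_node_sketch(
    node: torch.fx.Node, named_modules: dict[str, torch.nn.Module]
) -> None:
    # Hypothetical stand-in for _maybe_insert_input_and_output_observers_for_node:
    # it only needs to look modules up by name, so the prebuilt dict is sufficient.
    if node.op == "call_module":
        _ = named_modules[node.target]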