From cd57f547ab265a9b27e4bac8fdc4bffb9b3c8982 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Tue, 5 Aug 2025 14:22:23 -0700
Subject: [PATCH] [pt2e] Avoid getting model device once per node

**Summary:** Previously, we called `assert_and_get_unique_device` once per
node in both prepare and convert. This is expensive and unnecessary since the
model device is the same across all nodes, so we should just call it once at
the beginning and reuse the same model device across all nodes. (A minimal
sketch of this pattern, with placeholder names, is appended after the diff.)

torchao version of https://github.com/pytorch/pytorch/pull/159901

Note: The prepare path is not completely done yet, since we are blocked on
the pytorch PR being merged. It differs from convert in that it still calls
utility functions from `torch.ao.quantization.fx`.

**Test Plan:**
```
python test/quantization/pt2e/test_quantize_pt2e.py
```
---
 torchao/quantization/pt2e/convert.py  | 28 ++++++++++++++++++++++-----
 torchao/quantization/pt2e/observer.py | 12 ++++++++++--
 torchao/quantization/pt2e/prepare.py  | 25 ++++++++++++++++++++++--
 torchao/quantization/pt2e/utils.py    |  9 +++++++--
 4 files changed, 63 insertions(+), 11 deletions(-)

diff --git a/torchao/quantization/pt2e/convert.py b/torchao/quantization/pt2e/convert.py
index f205e55a6d..7123b0488c 100644
--- a/torchao/quantization/pt2e/convert.py
+++ b/torchao/quantization/pt2e/convert.py
@@ -49,9 +49,7 @@
 )
 from torch.ao.quantization.fx.utils import (
     _get_module,
-    assert_and_get_unique_device,
     collect_producer_nodes,
-    create_getattr_from_value,
     graph_module_from_producer_nodes,
     node_arg_is_weight,
 )
@@ -74,6 +72,8 @@
 from torchao.quantization.pt2e import FROM_NODE_KEY
 from torchao.quantization.pt2e.observer import _is_activation_post_process
+from torchao.quantization.pt2e.utils import create_getattr_from_value
+from torchao.utils import _assert_and_get_unique_device
 
 __all__ = [
     "convert",
@@ -129,6 +129,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
     modules: dict[str, torch.nn.Module],
     node_name_to_scope: dict[str, tuple[str, type]],
     node_name_to_qconfig: dict[str, QConfigAny],
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """Replace activation_post_process module call node with quantize and
     dequantize node working with decomposed Tensor
@@ -255,7 +256,11 @@ def add_quantize_dequantize_node_info(qdq_node, original_node):
                 # sure that the default overload can be used.
                 # TODO: maybe need more complex attr name here
                 qparam_node = create_getattr_from_value(
-                    model, graph, module_path + prefix + key, value_or_node
+                    model,
+                    graph,
+                    module_path + prefix + key,
+                    value_or_node,
+                    model_device,
                 )
                 quantize_op_inputs.append(qparam_node)
             else:
@@ -402,6 +407,7 @@ def _replace_observer_with_quantize_dequantize_node(
     modules: dict[str, torch.nn.Module],
     node_name_to_scope: dict[str, tuple[str, type]],
     node_name_to_qconfig: dict[str, QConfigAny],
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """Replace activation_post_process module call node with quantize and
     dequantize node
@@ -482,7 +488,11 @@ def _replace_observer_with_quantize_dequantize_node(
                 # For scale and zero_point values we register them as buffers in the root module.
                 # TODO: maybe need more complex attr name here
                 qparam_node = create_getattr_from_value(
-                    model, graph, module_path + prefix + key, value_or_node
+                    model,
+                    graph,
+                    module_path + prefix + key,
+                    value_or_node,
+                    model_device,
                 )
                 quantize_op_inputs.append(qparam_node)
             else:
@@ -780,6 +790,7 @@ def convert_weighted_module(
     backend_config: BackendConfig,
     is_decomposed: bool = False,
     is_reference: bool = False,
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """Convert a weighted module to reference quantized module in the model
     If the QConfig of a QAT module is not set, the module will still be converted to
@@ -868,7 +879,10 @@ def convert_weighted_module(
     is_ptq = weight_post_process is None
     if is_ptq:
         weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
-        device = assert_and_get_unique_device(float_module)
+        if model_device is not None:
+            device = model_device
+        else:
+            device = _assert_and_get_unique_device(float_module)
         if device:
             weight_post_process.to(device)
@@ -1071,6 +1085,7 @@ def convert(
     root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
     qat_module_classes = get_qat_module_classes(backend_config)
     fused_module_classes = get_fused_module_classes(backend_config)
+    model_device = _assert_and_get_unique_device(model)
 
     for node in list(model.graph.nodes):
         if node.op == "placeholder":
@@ -1118,6 +1133,7 @@ def convert(
                     modules,
                     node_name_to_scope,
                     node_name_to_qconfig,
+                    model_device,
                 )
             else:
                 _replace_observer_with_quantize_dequantize_node(
@@ -1126,6 +1142,7 @@ def convert(
                     modules,
                     node_name_to_scope,
                     node_name_to_qconfig,
+                    model_device,
                 )
         elif isinstance(mod, DeQuantStub):
             _replace_observer_or_dequant_stub_with_dequantize_node(
@@ -1155,6 +1172,7 @@ def convert(
                 backend_config,
                 is_decomposed,
                 is_reference,
+                model_device,
             )
 
     # remove deadcode after converting observers to quant/dequant ops
diff --git a/torchao/quantization/pt2e/observer.py b/torchao/quantization/pt2e/observer.py
index 60962f8d41..a9e8c38439 100644
--- a/torchao/quantization/pt2e/observer.py
+++ b/torchao/quantization/pt2e/observer.py
@@ -1908,10 +1908,18 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node):
         else:
             scale, zero_point = self.calculate_qparams()
             scale_node = create_getattr_from_value(
-                model, model.graph, "_scale", scale
+                model,
+                model.graph,
+                "_scale",
+                scale,
+                scale.device if isinstance(scale, torch.Tensor) else None,
             )
             zero_point_node = create_getattr_from_value(
-                model, model.graph, "_zero_point", zero_point
+                model,
+                model.graph,
+                "_zero_point",
+                zero_point,
+                zero_point.device if isinstance(zero_point, torch.Tensor) else None,
             )
 
             q_node = model.graph.call_function(
diff --git a/torchao/quantization/pt2e/prepare.py b/torchao/quantization/pt2e/prepare.py
index a1d57062f2..fa9869c915 100644
--- a/torchao/quantization/pt2e/prepare.py
+++ b/torchao/quantization/pt2e/prepare.py
@@ -38,6 +38,7 @@
     SharedQuantizationSpec,
 )
 from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
+from torchao.utils import _assert_and_get_unique_device
 
 # TODO: make pt2e folder private?
 __all__ = [
@@ -408,6 +409,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
     named_modules: dict[str, torch.nn.Module],
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ) -> Argument:
     """
     Given a `node` and an `arg`, inserts an input observer between
@@ -426,6 +428,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
                 named_modules,
                 obs_or_fq_map,
                 is_qat,
+                model_device,
             )
             new_arg_to_return.append(new_inner_arg)
         return type(arg)(new_arg_to_return)
@@ -478,6 +481,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
             return maybe_obs_node
 
     assert isinstance(model.graph, Graph)
+    # TODO: pass in model_device here after https://github.com/pytorch/pytorch/pull/159901
     new_arg = _insert_obs_or_fq(
         arg, input_edge_obs_or_fq, model, named_modules, model.graph
     )
@@ -491,6 +495,7 @@ def _maybe_insert_input_observers_for_node(
     named_modules: dict[str, torch.nn.Module],
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """
     If needed, inserts observers to the input args and kwargs of `node`.
@@ -517,6 +522,7 @@ def _maybe_insert_input_observers_for_node(
             named_modules,
             obs_or_fq_map,
             is_qat,
+            model_device,
         )
         new_args.append(new_arg)
@@ -541,9 +547,11 @@ def _maybe_insert_output_observer_for_node(
     graph: Graph,
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ) -> Optional[Node]:
     if node in obs_or_fq_map:
         output_act_obs_or_fq = obs_or_fq_map[node]
+        # TODO: pass in model_device here after https://github.com/pytorch/pytorch/pull/159901
         new_output = _insert_obs_or_fq(
             node, output_act_obs_or_fq, model, named_modules, graph
         )
@@ -563,6 +571,7 @@ def _maybe_insert_input_and_output_observers_for_node(
     model: torch.fx.GraphModule,
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ):
     this_node_quantization_annotation = (
         node.meta[Q_ANNOTATION_KEY] if Q_ANNOTATION_KEY in node.meta else None
     )
@@ -578,6 +587,7 @@ def _maybe_insert_input_and_output_observers_for_node(
         named_modules,
         obs_or_fq_map,
         is_qat,
+        model_device,
     )
 
     output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor)
@@ -586,7 +596,13 @@ def _maybe_insert_input_and_output_observers_for_node(
     # this returns the new observer node if it was needed
     maybe_output_obs_node = _maybe_insert_output_observer_for_node(
-        node, model, named_modules, model.graph, obs_or_fq_map, is_qat
+        node,
+        model,
+        named_modules,
+        model.graph,
+        obs_or_fq_map,
+        is_qat,
+        model_device,
     )
 
     if maybe_output_obs_node is None:
@@ -634,11 +650,16 @@ def prepare(
     )
     if obs_or_fq_callback:
         obs_or_fq_callback(model, obs_or_fq_map)
+    model_device = _assert_and_get_unique_device(model)
 
     for node in nodes_before_observation:
         # TODO: simplify logic for inserting observers
         _maybe_insert_input_and_output_observers_for_node(
-            node, model, obs_or_fq_map, is_qat
+            node,
+            model,
+            obs_or_fq_map,
+            is_qat,
+            model_device,
         )
 
     model = GraphModule(model, model.graph)
diff --git a/torchao/quantization/pt2e/utils.py b/torchao/quantization/pt2e/utils.py
index 849493b5fe..7ff1dbc619 100644
--- a/torchao/quantization/pt2e/utils.py
+++ b/torchao/quantization/pt2e/utils.py
@@ -525,7 +525,11 @@ def get_attr_name(i: int):
 
 
 def create_getattr_from_value(
-    module: torch.nn.Module, graph: Graph, prefix: str, value: Any
+    module: torch.nn.Module,
+    graph: Graph,
+    prefix: str,
+    value: Any,
+    device: Optional[torch.device] = None,
 ) -> Node:
     """
     Given a value of any type, creates a getattr node corresponding to the value and
@@ -533,7 +537,8 @@ def create_getattr_from_value(
     """
     get_new_attr_name = get_new_attr_name_with_prefix(prefix)
     attr_name = get_new_attr_name(module)
-    device = _assert_and_get_unique_device(module)
+    if device is None:
+        device = _assert_and_get_unique_device(module)
     new_value = (
         value.detach().clone()
         if isinstance(value, torch.Tensor)
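
For context beyond the diff, here is a minimal, self-contained sketch of the pattern this patch applies: resolve the model's unique device once before the per-node loop and pass it down, instead of re-scanning the model's parameters and buffers for every node. The names below (`get_unique_device`, `process_node`, `transform`) are illustrative placeholders, not torchao APIs; `get_unique_device` only mimics what `torchao.utils._assert_and_get_unique_device` does.

```
from typing import Optional

import torch
from torch.fx import GraphModule, Node


def get_unique_device(module: torch.nn.Module) -> Optional[torch.device]:
    # Placeholder mimicking _assert_and_get_unique_device: gather the devices
    # of all parameters and buffers and require that there is at most one.
    devices = {p.device for p in module.parameters()} | {
        b.device for b in module.buffers()
    }
    assert len(devices) <= 1, f"expected at most one device, got {devices}"
    return next(iter(devices)) if devices else None


def process_node(node: Node, model_device: Optional[torch.device] = None) -> None:
    # Hypothetical per-node transform: any new tensor (e.g. a qparam buffer)
    # is created on the cached device instead of re-deriving it from the model.
    scale = torch.tensor(1.0)
    if model_device is not None:
        scale = scale.to(model_device)
    # ... register `scale` on the model / insert graph nodes here ...


def transform(model: GraphModule) -> None:
    # Before this patch: every node paid for a full parameter/buffer scan.
    # After: the device is resolved once and reused for all nodes.
    model_device = get_unique_device(model)
    for node in model.graph.nodes:
        process_node(node, model_device)
```

The actual patch threads `model_device` through existing keyword arguments of the prepare/convert helpers rather than introducing new functions; the sketch only isolates the caching idea.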