diff --git a/torchao/quantization/pt2e/convert.py b/torchao/quantization/pt2e/convert.py
index f205e55a6d..7123b0488c 100644
--- a/torchao/quantization/pt2e/convert.py
+++ b/torchao/quantization/pt2e/convert.py
@@ -49,9 +49,7 @@
 )
 from torch.ao.quantization.fx.utils import (
     _get_module,
-    assert_and_get_unique_device,
     collect_producer_nodes,
-    create_getattr_from_value,
     graph_module_from_producer_nodes,
     node_arg_is_weight,
 )
@@ -74,6 +72,8 @@
 from torchao.quantization.pt2e import FROM_NODE_KEY
 from torchao.quantization.pt2e.observer import _is_activation_post_process
+from torchao.quantization.pt2e.utils import create_getattr_from_value
+from torchao.utils import _assert_and_get_unique_device
 
 
 __all__ = [
     "convert",
@@ -129,6 +129,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
     modules: dict[str, torch.nn.Module],
     node_name_to_scope: dict[str, tuple[str, type]],
     node_name_to_qconfig: dict[str, QConfigAny],
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """Replace activation_post_process module call node with quantize and
     dequantize node working with decomposed Tensor
@@ -255,7 +256,11 @@ def add_quantize_dequantize_node_info(qdq_node, original_node):
                     # sure that the default overload can be used.
                     # TODO: maybe need more complex attr name here
                     qparam_node = create_getattr_from_value(
-                        model, graph, module_path + prefix + key, value_or_node
+                        model,
+                        graph,
+                        module_path + prefix + key,
+                        value_or_node,
+                        model_device,
                     )
                     quantize_op_inputs.append(qparam_node)
                 else:
@@ -402,6 +407,7 @@ def _replace_observer_with_quantize_dequantize_node(
     modules: dict[str, torch.nn.Module],
     node_name_to_scope: dict[str, tuple[str, type]],
     node_name_to_qconfig: dict[str, QConfigAny],
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """Replace activation_post_process module call node with quantize and
     dequantize node
@@ -482,7 +488,11 @@ def _replace_observer_with_quantize_dequantize_node(
             # For scale and zero_point values we register them as buffers in the root module.
             # TODO: maybe need more complex attr name here
             qparam_node = create_getattr_from_value(
-                model, graph, module_path + prefix + key, value_or_node
+                model,
+                graph,
+                module_path + prefix + key,
+                value_or_node,
+                model_device,
             )
             quantize_op_inputs.append(qparam_node)
         else:
@@ -780,6 +790,7 @@ def convert_weighted_module(
     backend_config: BackendConfig,
     is_decomposed: bool = False,
     is_reference: bool = False,
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """Convert a weighted module to reference quantized module in the model
     If the QConfig of a QAT module is not set, the module will still be converted to
@@ -868,7 +879,10 @@ def convert_weighted_module(
     is_ptq = weight_post_process is None
     if is_ptq:
         weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
-        device = assert_and_get_unique_device(float_module)
+        if model_device is not None:
+            device = model_device
+        else:
+            device = _assert_and_get_unique_device(float_module)
         if device:
             weight_post_process.to(device)
 
@@ -1071,6 +1085,7 @@
     root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
     qat_module_classes = get_qat_module_classes(backend_config)
     fused_module_classes = get_fused_module_classes(backend_config)
+    model_device = _assert_and_get_unique_device(model)
 
     for node in list(model.graph.nodes):
         if node.op == "placeholder":
@@ -1118,6 +1133,7 @@
                         modules,
                         node_name_to_scope,
                         node_name_to_qconfig,
+                        model_device,
                     )
                 else:
                     _replace_observer_with_quantize_dequantize_node(
@@ -1126,6 +1142,7 @@
                         modules,
                         node_name_to_scope,
                         node_name_to_qconfig,
+                        model_device,
                     )
             elif isinstance(mod, DeQuantStub):
                 _replace_observer_or_dequant_stub_with_dequantize_node(
@@ -1155,6 +1172,7 @@
                 backend_config,
                 is_decomposed,
                 is_reference,
+                model_device,
             )
 
     # remove deadcode after converting observers to quant/dequant ops
diff --git a/torchao/quantization/pt2e/observer.py b/torchao/quantization/pt2e/observer.py
index 60962f8d41..a9e8c38439 100644
--- a/torchao/quantization/pt2e/observer.py
+++ b/torchao/quantization/pt2e/observer.py
@@ -1908,10 +1908,18 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node):
         else:
             scale, zero_point = self.calculate_qparams()
             scale_node = create_getattr_from_value(
-                model, model.graph, "_scale", scale
+                model,
+                model.graph,
+                "_scale",
+                scale,
+                scale.device if isinstance(scale, torch.Tensor) else None,
             )
             zero_point_node = create_getattr_from_value(
-                model, model.graph, "_zero_point", zero_point
+                model,
+                model.graph,
+                "_zero_point",
+                zero_point,
+                zero_point.device if isinstance(zero_point, torch.Tensor) else None,
             )
 
             q_node = model.graph.call_function(
diff --git a/torchao/quantization/pt2e/prepare.py b/torchao/quantization/pt2e/prepare.py
index a1d57062f2..fa9869c915 100644
--- a/torchao/quantization/pt2e/prepare.py
+++ b/torchao/quantization/pt2e/prepare.py
@@ -38,6 +38,7 @@
     SharedQuantizationSpec,
 )
 from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
+from torchao.utils import _assert_and_get_unique_device
 
 # TODO: make pt2e folder private?
 __all__ = [
@@ -408,6 +409,7 @@
     named_modules: dict[str, torch.nn.Module],
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ) -> Argument:
     """
     Given a `node` and an `arg`, inserts an input observer between
@@ -426,6 +428,7 @@
                 named_modules,
                 obs_or_fq_map,
                 is_qat,
+                model_device,
             )
             new_arg_to_return.append(new_inner_arg)
         return type(arg)(new_arg_to_return)
@@ -478,6 +481,7 @@
             return maybe_obs_node
 
     assert isinstance(model.graph, Graph)
+    # TODO: pass in model_device here after https://github.com/pytorch/pytorch/pull/159901
     new_arg = _insert_obs_or_fq(
         arg, input_edge_obs_or_fq, model, named_modules, model.graph
     )
@@ -491,6 +495,7 @@
     named_modules: dict[str, torch.nn.Module],
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """
     If needed, inserts observers to the input args and kwargs of `node`.
@@ -517,6 +522,7 @@
                 named_modules,
                 obs_or_fq_map,
                 is_qat,
+                model_device,
             )
             new_args.append(new_arg)
 
@@ -541,9 +547,11 @@
     graph: Graph,
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ) -> Optional[Node]:
     if node in obs_or_fq_map:
         output_act_obs_or_fq = obs_or_fq_map[node]
+        # TODO: pass in model_device here after https://github.com/pytorch/pytorch/pull/159901
        new_output = _insert_obs_or_fq(
             node, output_act_obs_or_fq, model, named_modules, graph
         )
@@ -563,6 +571,7 @@
     model: torch.fx.GraphModule,
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ):
     this_node_quantization_annotation = (
         node.meta[Q_ANNOTATION_KEY] if Q_ANNOTATION_KEY in node.meta else None
@@ -578,6 +587,7 @@
             named_modules,
             obs_or_fq_map,
             is_qat,
+            model_device,
         )
 
     output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor)
@@ -586,7 +596,13 @@
 
         # this returns the new observer node if it was needed
         maybe_output_obs_node = _maybe_insert_output_observer_for_node(
-            node, model, named_modules, model.graph, obs_or_fq_map, is_qat
+            node,
+            model,
+            named_modules,
+            model.graph,
+            obs_or_fq_map,
+            is_qat,
+            model_device,
         )
 
         if maybe_output_obs_node is None:
@@ -634,11 +650,16 @@
     )
     if obs_or_fq_callback:
         obs_or_fq_callback(model, obs_or_fq_map)
+    model_device = _assert_and_get_unique_device(model)
 
     for node in nodes_before_observation:
         # TODO: simplify logic for inserting observers
         _maybe_insert_input_and_output_observers_for_node(
-            node, model, obs_or_fq_map, is_qat
+            node,
+            model,
+            obs_or_fq_map,
+            is_qat,
+            model_device,
         )
 
     model = GraphModule(model, model.graph)
diff --git a/torchao/quantization/pt2e/utils.py b/torchao/quantization/pt2e/utils.py
index 849493b5fe..7ff1dbc619 100644
--- a/torchao/quantization/pt2e/utils.py
+++ b/torchao/quantization/pt2e/utils.py
@@ -525,7 +525,11 @@ def get_attr_name(i: int):
 
 
 def create_getattr_from_value(
-    module: torch.nn.Module, graph: Graph, prefix: str, value: Any
+    module: torch.nn.Module,
+    graph: Graph,
+    prefix: str,
+    value: Any,
+    device: Optional[torch.device] = None,
 ) -> Node:
     """
     Given a value of any type, creates a getattr node corresponding to the value and
@@ -533,7 +537,8 @@
     """
     get_new_attr_name = get_new_attr_name_with_prefix(prefix)
     attr_name = get_new_attr_name(module)
-    device = _assert_and_get_unique_device(module)
+    if device is None:
+        device = _assert_and_get_unique_device(module)
     new_value = (
         value.detach().clone()
         if isinstance(value, torch.Tensor)
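
A minimal sketch of the call pattern this diff enables. The toy traced `nn.Linear` model and the `_scale` tensor value are illustrative only; the signatures of `create_getattr_from_value` and `_assert_and_get_unique_device` are the ones introduced above. The caller resolves the model's unique device once and threads it through, and `create_getattr_from_value` falls back to its own per-call device lookup only when `device` is None.

import torch
from torch.fx import symbolic_trace

from torchao.quantization.pt2e.utils import create_getattr_from_value
from torchao.utils import _assert_and_get_unique_device

# Toy model, illustrative only; any fx.GraphModule works here.
model = symbolic_trace(torch.nn.Linear(4, 4))

# Resolve the model's (unique) device once for the whole graph pass instead
# of re-scanning every parameter/buffer for each qparam node we create.
# May return None, e.g. for a module with no parameters or buffers.
model_device = _assert_and_get_unique_device(model)

# New trailing `device` argument per this diff: the registered buffer is
# placed on `model_device` directly; passing None keeps the old behavior
# of calling _assert_and_get_unique_device(module) inside the helper.
scale = torch.tensor([0.05])  # made-up qparam value
scale_node = create_getattr_from_value(
    model, model.graph, "_scale", scale, model_device
)
print(scale_node.op, scale_node.target)  # get_attr node for the new buffer

Hoisting the device lookup out of the helper is the point of the change: `convert` and `prepare` now compute `model_device` once up front and pass it down, rather than deriving it again for every quantize/dequantize parameter node they insert.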