diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index 7d12a7316896..ba1f58af402e 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -6,15 +6,16 @@ from .fx.utils import graph_pretty_str # noqa: F401 from .fx.utils import get_custom_module_class_keys # noqa: F401 from torch.nn.intrinsic import _FusedModule +from typing import Dict, Any, List, Callable -def _check_is_graph_module(model): +def _check_is_graph_module(model: torch.nn.Module) -> None: if not isinstance(model, GraphModule): raise ValueError( 'input model must be a GraphModule, ' + 'Got type:' + str(type(model)) + ' Please make ' + 'sure to follow the tutorials.') -def _swap_ff_with_fxff(model): +def _swap_ff_with_fxff(model: torch.nn.Module) -> None: r""" Swap FloatFunctional with FXFloatFunctional """ modules_to_swap = [] @@ -28,7 +29,9 @@ def _swap_ff_with_fxff(model): del model._modules[name] model._modules[name] = torch.nn.quantized.FXFloatFunctional() -def _fuse_fx(graph_module, fuse_custom_config_dict=None): +def _fuse_fx( + graph_module: GraphModule, + fuse_custom_config_dict: Dict[str, Any] = None) -> GraphModule: r""" Internal helper function to fuse modules in preparation for quantization Args: @@ -39,7 +42,8 @@ def _fuse_fx(graph_module, fuse_custom_config_dict=None): return fuser.fuse(graph_module, fuse_custom_config_dict) class CustomTracer(Tracer): - def __init__(self, skipped_module_names, skipped_module_classes): + def __init__(self, skipped_module_names: List[str], + skipped_module_classes: List[Callable]): super().__init__() self.skipped_module_names = skipped_module_names self.skipped_module_classes = skipped_module_classes @@ -52,7 +56,9 @@ def is_leaf_module(self, m, module_qualified_name): isinstance(m, _FusedModule) -def _prepare_fx(model, qconfig_dict, prepare_custom_config_dict=None, is_standalone_module=False): +def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any, + prepare_custom_config_dict: Dict[str, Any] = None, + is_standalone_module: bool = False) -> GraphModule: r""" Internal helper function for prepare_fx Args: `model`, `qconfig_dict`, `prepare_custom_config_dict`: see docs for :func:`~torch.quantization.prepare_fx` @@ -93,7 +99,9 @@ def _prepare_fx(model, qconfig_dict, prepare_custom_config_dict=None, is_standal prepare_custom_config_dict=prepare_custom_config_dict, is_standalone_module=is_standalone_module) -def _prepare_standalone_module_fx(model, qconfig_dict, prepare_custom_config_dict=None): +def _prepare_standalone_module_fx( + model: torch.nn.Module, qconfig_dict: Any, + prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule: r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the parent module. standalone_module means it a submodule that is not inlined in parent module, @@ -104,7 +112,8 @@ def _prepare_standalone_module_fx(model, qconfig_dict, prepare_custom_config_dic """ return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module=True) -def fuse_fx(model, fuse_custom_config_dict=None): +def fuse_fx(model: torch.nn.Module, + fuse_custom_config_dict: Dict[str, Any] = None) -> GraphModule: r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode. Fusion rules are defined in torch.quantization.fx.fusion_pattern.py Args: @@ -128,7 +137,9 @@ def fuse_fx(model, fuse_custom_config_dict=None): graph_module = torch.fx.symbolic_trace(model) # type: ignore return _fuse_fx(graph_module, fuse_custom_config_dict) -def prepare_fx(model, qconfig_dict, prepare_custom_config_dict=None): +def prepare_fx( + model: torch.nn.Module, qconfig_dict: Any, + prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule: r""" Prepare a model for post training static quantization Args: @@ -247,7 +258,9 @@ def calibrate(model, data_loader): 'eval mode' return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict) -def prepare_qat_fx(model, qconfig_dict, prepare_custom_config_dict=None): +def prepare_qat_fx( + model: torch.nn.Module, qconfig_dict: Any, + prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule: r""" Prepare a model for quantization aware training Args: `model`: torch.nn.Module model, must be in train mode @@ -282,14 +295,19 @@ def train_loop(model, train_data): 'train mode' return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict) -def _convert_fx(graph_module, debug, convert_custom_config_dict=None, is_standalone_module=False): +def _convert_fx( + graph_module: GraphModule, debug: bool, + convert_custom_config_dict: Dict[str, Any] = None, + is_standalone_module: bool = False) -> GraphModule: """ `is_standalone_module`: see docs in :func:`~torch.quantization.prepare_standalone_module_fx` """ _check_is_graph_module(graph_module) quantizer = Quantizer() return quantizer.convert(graph_module, debug, convert_custom_config_dict, is_standalone_module) -def convert_fx(graph_module, debug=False, convert_custom_config_dict=None): +def convert_fx( + graph_module: GraphModule, debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> GraphModule: r""" Convert a calibrated or trained model to a quantized model Args: `graph_module`: A prepared and calibrated/trained model (GraphModule) @@ -346,7 +364,9 @@ def convert_fx(graph_module, debug=False, convert_custom_config_dict=None): torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx") return _convert_fx(graph_module, debug, convert_custom_config_dict) -def _convert_standalone_module_fx(graph_module, debug=False, convert_custom_config_dict=None): +def _convert_standalone_module_fx( + graph_module: GraphModule, debug: bool = False, + convert_custom_config_dict: Dict[str, Any] = None) -> GraphModule: r""" [Internal use only] Convert a model produced by :func:`~torch.quantization.prepare_standalone_module_fx` and convert it to a quantized model