
fx quant: more typehints, part 3 #48794

Closed · wants to merge 1 commit
44 changes: 32 additions & 12 deletions torch/quantization/quantize_fx.py
@@ -6,15 +6,16 @@
 from .fx.utils import graph_pretty_str # noqa: F401
 from .fx.utils import get_custom_module_class_keys # noqa: F401
 from torch.nn.intrinsic import _FusedModule
+from typing import Dict, Any, List, Callable

-def _check_is_graph_module(model):
+def _check_is_graph_module(model: torch.nn.Module) -> None:
     if not isinstance(model, GraphModule):
         raise ValueError(
             'input model must be a GraphModule, ' +
             'Got type:' + str(type(model)) + ' Please make ' +
             'sure to follow the tutorials.')

-def _swap_ff_with_fxff(model):
+def _swap_ff_with_fxff(model: torch.nn.Module) -> None:
     r""" Swap FloatFunctional with FXFloatFunctional
     """
     modules_to_swap = []
@@ -28,7 +29,9 @@ def _swap_ff_with_fxff(model):
         del model._modules[name]
         model._modules[name] = torch.nn.quantized.FXFloatFunctional()
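For context, `_swap_ff_with_fxff` recursively replaces every `torch.nn.quantized.FloatFunctional` child with `FXFloatFunctional`, which `torch.fx` can trace through. A minimal sketch of the effect (the module `M` and the direct call to this private helper are illustrative; `prepare_fx` performs the swap internally):

```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        # FloatFunctional wraps tensor ops such as add/cat so that
        # observers can be attached to them during quantization.
        self.ff = nn.quantized.FloatFunctional()

    def forward(self, x, y):
        return self.ff.add(x, y)

m = M()
# Illustrative only: this private helper is normally invoked by prepare_fx.
from torch.quantization.quantize_fx import _swap_ff_with_fxff
_swap_ff_with_fxff(m)
assert isinstance(m.ff, torch.nn.quantized.FXFloatFunctional)
```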

-def _fuse_fx(graph_module, fuse_custom_config_dict=None):
+def _fuse_fx(
+        graph_module: GraphModule,
+        fuse_custom_config_dict: Dict[str, Any] = None) -> GraphModule:
     r""" Internal helper function to fuse modules in preparation for quantization

     Args:
@@ -39,7 +42,8 @@ def _fuse_fx(graph_module, fuse_custom_config_dict=None):
     return fuser.fuse(graph_module, fuse_custom_config_dict)

 class CustomTracer(Tracer):
-    def __init__(self, skipped_module_names, skipped_module_classes):
+    def __init__(self, skipped_module_names: List[str],
+                 skipped_module_classes: List[Callable]):
         super().__init__()
         self.skipped_module_names = skipped_module_names
         self.skipped_module_classes = skipped_module_classes
@@ -52,7 +56,9 @@ def is_leaf_module(self, m, module_qualified_name):
             isinstance(m, _FusedModule)
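`CustomTracer` keeps user-listed modules opaque during symbolic tracing: anything named in `skipped_module_names`, typed in `skipped_module_classes`, or already fused is treated as a leaf and not traced into. The full method body is collapsed in this diff; a reconstructed sketch, assuming the override simply extends the base `Tracer` check:

```python
# Reconstructed sketch -- only the trailing isinstance(m, _FusedModule)
# clause is visible in the diff above.
def is_leaf_module(self, m, module_qualified_name):
    return (super().is_leaf_module(m, module_qualified_name) or
            module_qualified_name in self.skipped_module_names or
            type(m) in self.skipped_module_classes or
            isinstance(m, _FusedModule))
```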


-def _prepare_fx(model, qconfig_dict, prepare_custom_config_dict=None, is_standalone_module=False):
+def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any,
+                prepare_custom_config_dict: Dict[str, Any] = None,
+                is_standalone_module: bool = False) -> GraphModule:
     r""" Internal helper function for prepare_fx
     Args:
       `model`, `qconfig_dict`, `prepare_custom_config_dict`: see docs for :func:`~torch.quantization.prepare_fx`
@@ -93,7 +99,9 @@ def _prepare_fx(model, qconfig_dict, prepare_custom_config_dict=None, is_standal
         prepare_custom_config_dict=prepare_custom_config_dict,
         is_standalone_module=is_standalone_module)

-def _prepare_standalone_module_fx(model, qconfig_dict, prepare_custom_config_dict=None):
+def _prepare_standalone_module_fx(
+        model: torch.nn.Module, qconfig_dict: Any,
+        prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule:
     r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the
     parent module.
     standalone_module means it a submodule that is not inlined in parent module,
@@ -104,7 +112,8 @@ def _prepare_standalone_module_fx(model, qconfig_dict, prepare_custom_config_dic
     """
     return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module=True)

-def fuse_fx(model, fuse_custom_config_dict=None):
+def fuse_fx(model: torch.nn.Module,
+            fuse_custom_config_dict: Dict[str, Any] = None) -> GraphModule:
     r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode.
     Fusion rules are defined in torch.quantization.fx.fusion_pattern.py
     Args:
@@ -128,7 +137,9 @@ def fuse_fx(model, fuse_custom_config_dict=None):
     graph_module = torch.fx.symbolic_trace(model) # type: ignore
     return _fuse_fx(graph_module, fuse_custom_config_dict)
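A minimal usage sketch for `fuse_fx` (the model itself is illustrative):

```python
import torch
import torch.nn as nn
from torch.quantization import quantize_fx

# Illustrative model containing a fusable conv + bn + relu pattern.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
).eval()  # fuse_fx requires eval mode

fused = quantize_fx.fuse_fx(model)  # GraphModule with conv/bn/relu fused
```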

-def prepare_fx(model, qconfig_dict, prepare_custom_config_dict=None):
+def prepare_fx(
+        model: torch.nn.Module, qconfig_dict: Any,
+        prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule:
     r""" Prepare a model for post training static quantization

     Args:
@@ -247,7 +258,9 @@ def calibrate(model, data_loader):
         'eval mode'
     return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict)
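The surrounding post-training flow, sketched with the global-qconfig convention of `qconfig_dict` (the model and `calibration_loader` are assumed names):

```python
import torch
from torch.quantization import get_default_qconfig, quantize_fx

model.eval()  # prepare_fx only works on models in eval mode
qconfig_dict = {"": get_default_qconfig("fbgemm")}  # "" = global qconfig
prepared = quantize_fx.prepare_fx(model, qconfig_dict)

# Run representative data through the model so the inserted
# observers can record activation ranges.
with torch.no_grad():
    for data, _ in calibration_loader:  # assumed DataLoader
        prepared(data)
```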

-def prepare_qat_fx(model, qconfig_dict, prepare_custom_config_dict=None):
+def prepare_qat_fx(
+        model: torch.nn.Module, qconfig_dict: Any,
+        prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule:
     r""" Prepare a model for quantization aware training
     Args:
       `model`: torch.nn.Module model, must be in train mode
@@ -282,14 +295,19 @@ def train_loop(model, train_data):
         'train mode'
     return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict)
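The QAT flow mirrors the static one but requires train mode and a QAT qconfig; `train_loop` and `train_data` follow the names used in this function's docstring:

```python
import torch
from torch.quantization import get_default_qat_qconfig, quantize_fx

model.train()  # prepare_qat_fx only works on models in train mode
qconfig_dict = {"": get_default_qat_qconfig("fbgemm")}
prepared = quantize_fx.prepare_qat_fx(model, qconfig_dict)

# Fine-tune with fake-quantization enabled (train_loop as in the docstring).
train_loop(prepared, train_data)
```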

-def _convert_fx(graph_module, debug, convert_custom_config_dict=None, is_standalone_module=False):
+def _convert_fx(
+        graph_module: GraphModule, debug: bool,
+        convert_custom_config_dict: Dict[str, Any] = None,
+        is_standalone_module: bool = False) -> GraphModule:
     """ `is_standalone_module`: see docs in :func:`~torch.quantization.prepare_standalone_module_fx`
     """
     _check_is_graph_module(graph_module)
     quantizer = Quantizer()
     return quantizer.convert(graph_module, debug, convert_custom_config_dict, is_standalone_module)

-def convert_fx(graph_module, debug=False, convert_custom_config_dict=None):
+def convert_fx(
+        graph_module: GraphModule, debug: bool = False,
+        convert_custom_config_dict: Dict[str, Any] = None) -> GraphModule:
     r""" Convert a calibrated or trained model to a quantized model
     Args:
         `graph_module`: A prepared and calibrated/trained model (GraphModule)
@@ -346,7 +364,9 @@ def convert_fx(graph_module, debug=False, convert_custom_config_dict=None):
     torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx")
     return _convert_fx(graph_module, debug, convert_custom_config_dict)
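Continuing the sketches above, `convert_fx` turns the calibrated or trained `GraphModule` into the final quantized model:

```python
quantized = quantize_fx.convert_fx(prepared)

# The result is a GraphModule dispatching to quantized kernels.
with torch.no_grad():
    out = quantized(data)  # `data` as in the calibration sketch
```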

-def _convert_standalone_module_fx(graph_module, debug=False, convert_custom_config_dict=None):
+def _convert_standalone_module_fx(
+        graph_module: GraphModule, debug: bool = False,
+        convert_custom_config_dict: Dict[str, Any] = None) -> GraphModule:
     r""" [Internal use only] Convert a model produced by :func:`~torch.quantization.prepare_standalone_module_fx`
     and convert it to a quantized model
