1 change: 1 addition & 0 deletions torch/quantization/fx/prepare.py
@@ -1114,6 +1114,7 @@ def prepare(
node_name_to_scope: Dict[str, Tuple[str, type]],
prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
backend_config_dict: Optional[Dict[str, Any]] = None,
is_standalone_module: bool = False) -> ObservedGraphModule:
""" standalone_module means it a submodule that is not inlined in
parent module, and will be quantized separately as one unit.
28 changes: 19 additions & 9 deletions torch/quantization/quantize_fx.py
@@ -140,8 +140,9 @@ def create_node(self, kind : str, target : Target,
return node

def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any,
prepare_custom_config_dict: Dict[str, Any] = None,
equalization_qconfig_dict: Dict[str, Any] = None,
prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
backend_config_dict: Optional[Dict[str, Any]] = None,
is_standalone_module: bool = False) -> ObservedGraphModule:
r""" Internal helper function for prepare_fx
Args:
@@ -203,7 +204,8 @@ def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any,
def _prepare_standalone_module_fx(
model: torch.nn.Module,
qconfig_dict: Any,
prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule:
prepare_custom_config_dict: Dict[str, Any] = None,
backend_config_dict: Dict[str, Any] = None) -> GraphModule:
r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the
parent module.
standalone_module means it is a submodule that is not inlined in the parent module,
@@ -224,7 +226,7 @@ def _prepare_standalone_module_fx(
same as input_quantized_idxs configuration provided
for the standalone module
"""
return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module=True)
return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, backend_config_dict, is_standalone_module=True)
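For orientation, a minimal sketch of how this internal helper would be called with the new keyword; it is normally invoked by the parent prepare flow rather than directly, and `SubModule` plus the fbgemm qconfig below are placeholders, not part of this PR:

from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import _prepare_standalone_module_fx

# `SubModule` is a hypothetical nn.Module standing in for the real standalone submodule.
observed_sub = _prepare_standalone_module_fx(
    SubModule().eval(),
    {"": get_default_qconfig("fbgemm")},
    prepare_custom_config_dict=None,
    backend_config_dict=None,  # new argument; simply forwarded to _prepare_fx
)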

def fuse_fx(model: torch.nn.Module,
fuse_custom_config_dict: Dict[str, Any] = None) -> GraphModule:
@@ -265,8 +267,9 @@ def fuse_fx(model: torch.nn.Module,

def prepare_fx(
model: torch.nn.Module, qconfig_dict: Any,
prepare_custom_config_dict: Dict[str, Any] = None,
equalization_qconfig_dict: Dict[str, Any] = None) -> ObservedGraphModule:
prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
backend_config_dict: Optional[Dict[str, Any]] = None) -> ObservedGraphModule:
r""" Prepare a model for post training static quantization

Args:
@@ -392,6 +395,11 @@ def prepare_fx(
with a similar structure as qconfig_dict except it will contain
configurations specific to equalization techniques such as input-weight
equalization.
`backend_config_dict`: a dictionary that specifies how operators are quantized
in a backend; this includes how the operators are observed,
supported fusion patterns, how quantize/dequantize ops are
inserted, supported dtypes, etc. The structure of the dictionary is still WIP
and will change in the future, so please don't use it right now.

Review thread on this docstring:
Contributor: "Can we update this to what is supported currently? Right now, this PR is just introducing the argument, should we have some functionality with this dict before adding it?"
Contributor (author): "ah, that is introduced in the next PR.. sorry, will probably merge these things next time I add a new argument"


Return:
@@ -420,16 +428,18 @@ def calibrate(model, data_loader):
torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx")
assert not model.training, 'prepare_fx only works for models in ' + \
'eval mode'
return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, equalization_qconfig_dict)
return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, equalization_qconfig_dict, backend_config_dict)
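A hedged usage sketch of the updated prepare_fx signature; `MyModel` and the fbgemm qconfig are placeholder choices, and `backend_config_dict` is left at its default since its structure is still WIP:

from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx

float_model = MyModel().eval()  # `MyModel` is a hypothetical nn.Module; prepare_fx requires eval mode
qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepared = prepare_fx(
    float_model,
    qconfig_dict,
    backend_config_dict=None,  # new in this PR; its contents are defined in the follow-up PR
)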

def prepare_qat_fx(
model: torch.nn.Module, qconfig_dict: Any,
prepare_custom_config_dict: Dict[str, Any] = None) -> ObservedGraphModule:
prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
backend_config_dict: Optional[Dict[str, Any]] = None) -> ObservedGraphModule:
r""" Prepare a model for quantization aware training
Args:
`model`: torch.nn.Module model, must be in train mode
`qconfig_dict`: see :func:`~torch.quantization.prepare_fx`
`prepare_custom_config_dict`: see :func:`~torch.quantization.prepare_fx`
`backend_config_dict`: see :func:`~torch.quantization.prepare_fx`

Return:
A GraphModule with fake quant modules (configured by qconfig_dict), ready for
@@ -457,7 +467,7 @@ def train_loop(model, train_data):
torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx")
assert model.training, 'prepare_qat_fx only works for models in ' + \
'train mode'
return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict)
return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, backend_config_dict)
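Likewise, a hedged sketch for prepare_qat_fx with the forwarded keyword; `MyModel` is again a placeholder, and the model must be in train mode:

from torch.quantization import get_default_qat_qconfig
from torch.quantization.quantize_fx import prepare_qat_fx

qat_model = MyModel().train()  # `MyModel` is a hypothetical nn.Module; prepare_qat_fx asserts train mode
qat_qconfig_dict = {"": get_default_qat_qconfig("fbgemm")}
prepared_qat = prepare_qat_fx(
    qat_model,
    qat_qconfig_dict,
    backend_config_dict=None,  # forwarded to _prepare_fx; same caveat as for prepare_fx
)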

def _convert_fx(
graph_module: GraphModule, is_reference: bool,