[quant][graphmode][api] Add backend_config_dict to prepare_fx api #64135
Changes from all commits
@@ -140,8 +140,9 @@ def create_node(self, kind : str, target : Target,
         return node

 def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any,
-                prepare_custom_config_dict: Dict[str, Any] = None,
-                equalization_qconfig_dict: Dict[str, Any] = None,
+                prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
+                equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
+                backend_config_dict: Optional[Dict[str, Any]] = None,
                 is_standalone_module: bool = False) -> ObservedGraphModule:
     r""" Internal helper function for prepare_fx
     Args:
@@ -203,7 +204,8 @@ def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any,
 def _prepare_standalone_module_fx(
         model: torch.nn.Module,
         qconfig_dict: Any,
-        prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule:
+        prepare_custom_config_dict: Dict[str, Any] = None,
+        backend_config_dict: Dict[str, Any] = None) -> GraphModule:
     r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the
     parent module.
     standalone_module means it a submodule that is not inlined in parent module,
@@ -224,7 +226,7 @@ def _prepare_standalone_module_fx(
         same as input_quantized_idxs configuration provided
         for the standalone module
     """
-    return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module=True)
+    return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, backend_config_dict, is_standalone_module=True)

 def fuse_fx(model: torch.nn.Module,
             fuse_custom_config_dict: Dict[str, Any] = None) -> GraphModule:
@@ -265,8 +267,9 @@ def fuse_fx(model: torch.nn.Module,

 def prepare_fx(
         model: torch.nn.Module, qconfig_dict: Any,
-        prepare_custom_config_dict: Dict[str, Any] = None,
-        equalization_qconfig_dict: Dict[str, Any] = None) -> ObservedGraphModule:
+        prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
+        equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
+        backend_config_dict: Optional[Dict[str, Any]] = None) -> ObservedGraphModule:
     r""" Prepare a model for post training static quantization

     Args:
@@ -392,6 +395,11 @@ def prepare_fx(
       with a similar structure as qconfig_dict except it will contain
       configurations specific to equalization techniques such as input-weight
       equalization.
+      `backend_config_dict`: a dictionary that specifies how operators are quantized
+       in a backend, this includes how the operators are observed,
+       supported fusion patterns, how quantize/dequantize ops are
+       inserted, supported dtypes etc. The structure of the dictionary is still WIP
+       and will change in the future, please don't use right now.


     Return:

Review thread on the new `backend_config_dict` docstring:

Reviewer: Can we update this to what is supported currently? Right now, this PR is just introducing the argument; should we have some functionality with this dict before adding it?

Author: Ah, that is introduced in the next PR. Sorry, will probably merge these things next time I add a new argument.
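For orientation, here is a minimal sketch of how the new argument flows through the public API after this PR. The toy model, the qconfig choice, and the calibration input below are placeholders that are not part of this PR; `backend_config_dict` is left at its default of `None` because, as the docstring above notes, the dictionary format is still WIP.

```python
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

# Placeholder model for illustration only.
class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

model = ToyModel().eval()  # prepare_fx requires eval mode
qconfig_dict = {"": get_default_qconfig("fbgemm")}

# The new keyword introduced by this PR; pass None until the format is finalized.
prepared = prepare_fx(
    model,
    qconfig_dict,
    backend_config_dict=None,
)

# Calibrate with representative data, then convert to a quantized module.
with torch.no_grad():
    prepared(torch.randn(4, 8))
quantized = convert_fx(prepared)
```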
@@ -420,16 +428,18 @@ def calibrate(model, data_loader):
     torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx")
     assert not model.training, 'prepare_fx only works for models in ' + \
         'eval mode'
-    return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, equalization_qconfig_dict)
+    return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, equalization_qconfig_dict, backend_config_dict)

 def prepare_qat_fx(
         model: torch.nn.Module, qconfig_dict: Any,
-        prepare_custom_config_dict: Dict[str, Any] = None) -> ObservedGraphModule:
+        prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
+        backend_config_dict: Optional[Dict[str, Any]] = None) -> ObservedGraphModule:
     r""" Prepare a model for quantization aware training
     Args:
       `model`: torch.nn.Module model, must be in train mode
       `qconfig_dict`: see :func:`~torch.quantization.prepare_fx`
       `prepare_custom_config_dict`: see :func:`~torch.quantization.prepare_fx`
+      `backend_config_dict`: see :func:`~torch.quantization.prepare_fx`

     Return:
       A GraphModule with fake quant modules (configured by qconfig_dict), ready for
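The QAT entry point gains the same argument. A hedged usage sketch follows; the sequential model, QAT qconfig, and single training step are illustrative assumptions, not part of this PR, and `backend_config_dict` is again left at its default.

```python
import torch
from torch.quantization import get_default_qat_qconfig
from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx

# Placeholder model; prepare_qat_fx requires train mode.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).train()
qconfig_dict = {"": get_default_qat_qconfig("fbgemm")}

# backend_config_dict defaults to None; passing it explicitly for clarity.
prepared = prepare_qat_fx(model, qconfig_dict, backend_config_dict=None)

# One illustrative fake-quant training step.
optimizer = torch.optim.SGD(prepared.parameters(), lr=0.01)
loss = prepared(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()

# Convert after training; convert_fx expects the model in eval mode.
quantized = convert_fx(prepared.eval())
```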
@@ -457,7 +467,7 @@ def train_loop(model, train_data):
     torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx")
     assert model.training, 'prepare_qat_fx only works for models in ' + \
         'train mode'
-    return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict)
+    return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, backend_config_dict)

 def _convert_fx(
         graph_module: GraphModule, is_reference: bool,