[quant][graphmode][fx][api] Remove inplace option from prepare_fx #46954

Closed · wants to merge 2 commits
6 changes: 2 additions & 4 deletions test/quantization/test_quantize_fx.py
@@ -278,13 +278,11 @@ def forward(self, x):

         model = M().eval()
         qconfig_dict = {'': default_qconfig}
-        prepared = prepare_fx(
-            model, qconfig_dict, inplace=False)
+        prepared = prepare_fx(model, qconfig_dict)
         test_only_eval_fn(model, self.img_data_2d)
         non_inplace_model = convert_fx(prepared, inplace=True)

-        prepared = prepare_fx(
-            model, qconfig_dict, inplace=True)
+        prepared = prepare_fx(model, qconfig_dict)
         test_only_eval_fn(model, self.img_data_2d)
         inplace_model = convert_fx(prepared, inplace=True)

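For context, a minimal sketch of the call pattern this test exercises after the change; the toy module and the random calibration tensor below are illustrative stand-ins, not code from this PR:

```python
import torch
import torch.nn as nn
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class M(nn.Module):
    """Toy stand-in for the module used by the test."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv(x)

model = M().eval()                          # prepare_fx asserts eval mode
qconfig_dict = {'': default_qconfig}
prepared = prepare_fx(model, qconfig_dict)  # the inplace kwarg is gone
prepared(torch.randn(1, 1, 4, 4))           # stand-in for calibration
# convert_fx still accepts inplace at this point in the stack
quantized = convert_fx(prepared, inplace=True)
```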
8 changes: 3 additions & 5 deletions torch/quantization/fx/quantize.py
@@ -308,7 +308,7 @@ def get_qconfig(module_name):
             self.modules[node.target].qconfig = module_qconfig
             self.qconfig_map[node.name] = module_qconfig

-    def _prepare(self, model, qconfig_dict, inplace, prepare_custom_config_dict, is_standalone_module):
+    def _prepare(self, model, qconfig_dict, prepare_custom_config_dict, is_standalone_module):
         """ standalone_module means it a submodule that is not inlined in parent module,
         and will be quantized separately as one unit.

@@ -324,8 +324,6 @@ def _prepare(self, model, qconfig_dict, inplace, prepare_custom_config_dict, is_
         """
         if prepare_custom_config_dict is None:
             prepare_custom_config_dict = {}
-        if not inplace:
-            model = copy.deepcopy(model)
         additional_quant_patterns = prepare_custom_config_dict.get("additional_quant_pattern", {})
         self.patterns = get_default_quant_patterns().copy()
         for k, v in additional_quant_patterns.items():
@@ -534,8 +532,8 @@ def restore_state(self, observed):
         self.patterns = observed._patterns
         self.qconfig_map = observed._qconfig_map

-    def prepare(self, model, qconfig_dict, inplace=False, prepare_custom_config_dict=None, is_standalone_module=False):
-        return self._prepare(model, qconfig_dict, inplace, prepare_custom_config_dict, is_standalone_module)
+    def prepare(self, model, qconfig_dict, prepare_custom_config_dict=None, is_standalone_module=False):
+        return self._prepare(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module)

     def _run_weight_observers(self, observed):
         r''' Extract the subgraph that produces the weight for dynamic quant
21 changes: 8 additions & 13 deletions torch/quantization/quantize_fx.py
@@ -49,10 +49,10 @@ def is_leaf_module(self, m, module_qualified_name):
            type(m) in self.skipped_module_classes


-def _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict=None, is_standalone_module=False):
+def _prepare_fx(model, qconfig_dict, prepare_custom_config_dict=None, is_standalone_module=False):
    r""" Internal helper function for prepare_fx
    Args:
-      `model`, `qconfig_dict`, `inplace` `prepare_custom_config_dict`: see docs for :func:`~torch.quantization.prepare_fx`
+      `model`, `qconfig_dict`, `prepare_custom_config_dict`: see docs for :func:`~torch.quantization.prepare_fx`
      `is_standalone_module`: a boolean flag indicates whether we are
          quantizing a standalone module or not, a standalone module
          is a submodule of the parent module that is not inlined in the
@@ -84,11 +84,10 @@ def _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict=None, i
    return quantizer.prepare(
        graph_module,
        qconfig_dict,
-        inplace=True,
        prepare_custom_config_dict=prepare_custom_config_dict,
        is_standalone_module=is_standalone_module)

-def _prepare_standalone_module_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=None):
+def _prepare_standalone_module_fx(model, qconfig_dict, prepare_custom_config_dict=None):
    r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the
    parent module.
    standalone_module means it a submodule that is not inlined in parent module,
@@ -104,7 +103,7 @@ def _prepare_standalone_module_fx(model, qconfig_dict, inplace=False, prepare_cu

    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx._prepare_standalone_module_fx")
-    return _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict, is_standalone_module=True)
+    return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module=True)


def fuse_fx(model, fuse_custom_config_dict=None):
@@ -131,7 +130,7 @@ def fuse_fx(model, fuse_custom_config_dict=None):
    graph_module = torch.fx.symbolic_trace(model)
    return _fuse_fx(graph_module, fuse_custom_config_dict)

-def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=None):
+def prepare_fx(model, qconfig_dict, prepare_custom_config_dict=None):
    r""" Prepare a model for post training static quantization

    Args:
@@ -164,8 +163,6 @@ def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=No
        # qconfig == None means fusion and quantization should be skipped for anything
        # matching the rule
        }
-      `inplace`: flag for carry out model transformations in-place,
-      the original module is mutated
      `prepare_custom_config_dict`: customization configuration dictionary for
        quantization tool:
        prepare_custom_config_dict = {
@@ -241,15 +238,13 @@ def calibrate(model, data_loader):
    torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx")
    assert not model.training, 'prepare_fx only works for models in' + \
        'eval mode'
-    return _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict)
+    return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict)

-def prepare_qat_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=None):
+def prepare_qat_fx(model, qconfig_dict, prepare_custom_config_dict=None):
    r""" Prepare a model for quantization aware training
    Args:
      `model`: torch.nn.Module model, must be in train mode
      `qconfig_dict`: see :func:`~torch.quantization.prepare_fx`
-      `inplace`: flag for carry out model transformations in-place,
-      the original module is mutated
      `prepare_custom_config_dict`: see :func:`~torch.quantization.prepare_fx`

    Return:
@@ -278,7 +273,7 @@ def train_loop(model, train_data):
    torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx")
    assert model.training, 'prepare_qat_fx only works for models in ' + \
        'train mode'
-    return _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict)
+    return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict)

def _convert_fx(graph_module, inplace, debug, convert_custom_config_dict=None, is_standalone_module=False):
    """ `is_standalone_module`: see docs in :func:`~torch.quantization.prepare_standalone_module_fx`
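For completeness, a rough sketch of the QAT entry point after this change; the toy model, the fbgemm qconfig choice, and the skipped training loop are assumptions for illustration, not code from this PR:

```python
import torch
import torch.nn as nn
from torch.quantization import get_default_qat_qconfig
from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx

class M(nn.Module):
    """Toy stand-in model."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

model = M().train()                              # prepare_qat_fx asserts train mode
qconfig_dict = {'': get_default_qat_qconfig('fbgemm')}
prepared = prepare_qat_fx(model, qconfig_dict)   # inplace kwarg removed here too
# ... a real training loop would run here, updating `prepared` ...
prepared.eval()
quantized = convert_fx(prepared)
```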