Skip to content

Commit

Permalink
Enable pickling model prepared with QAT qconfig (#109288)
Browse files Browse the repository at this point in the history
Summary:
Resolving error:

AttributeError: Can't pickle local object '_add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device'

by moving nested function out to the main module

Test Plan: Added test to CI

Reviewed By: andrewor14

Differential Revision: D49187352

Pull Request resolved: #109288
Approved by: https://github.com/andrewor14
  • Loading branch information
sindish authored and pytorchmergebot committed Sep 28, 2023
1 parent c71a64c commit 419ec3b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
14 changes: 13 additions & 1 deletion test/quantization/fx/test_quantize_fx.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Owner(s): ["oncall: quantization"]

import pickle
from collections import OrderedDict
import contextlib
import torch
Expand Down Expand Up @@ -9714,6 +9714,18 @@ def forward(self, x):
out_ref = converted_ref(inp)

torch.testing.assert_close(out, out_ref)

@override_qengines
def test_qat_pickle(self):
model = torch.nn.Sequential(nn.Conv2d(3, 32, 5), nn.BatchNorm2d(32), nn.ReLU())
example_inputs = torch.randn(1, 3, 128, 128)

qengine = torch.backends.quantized.engine
qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig(qengine)}
model = prepare_qat_fx(model, qconfig_dict, example_inputs)
pickle.dumps(model)


if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_quantization.py TESTNAME\n\n"
Expand Down
17 changes: 10 additions & 7 deletions torch/ao/quantization/qconfig.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import namedtuple
from functools import partial
from typing import Optional, Any, Union, Type

import torch
Expand Down Expand Up @@ -458,6 +459,13 @@ def _assert_valid_qconfig(qconfig: Optional[QConfig],
QConfigAny = Optional[QConfig]
QConfigAny.__module__ = "torch.ao.quantization.qconfig"

def _get_factory_kwargs_based_on_module_device(module):
assert isinstance(module, torch.nn.Module)
devices = {p.device for p in module.parameters()} | \
{p.device for p in module.buffers()}
device = next(iter(devices)) if len(devices) > 0 else None
return None if device is None else {'device': device}

def _add_module_to_qconfig_obs_ctr(
qconfig: QConfigAny,
module: Optional[nn.Module]) -> Any:
Expand All @@ -477,19 +485,14 @@ def _add_module_to_qconfig_obs_ctr(
if module is None or qconfig is None or qconfig._fields != ('activation', 'weight'):
return qconfig

def get_factory_kwargs_based_on_module_device():
assert isinstance(module, torch.nn.Module)
devices = {p.device for p in module.parameters()} | \
{p.device for p in module.buffers()}
device = next(iter(devices)) if len(devices) > 0 else None
return None if device is None else {'device': device}

def configure_constructor_to_put_obs_on_module_device(original_constructor):
get_factory_kwargs_based_on_module_device_with_model = partial(_get_factory_kwargs_based_on_module_device, module)
try:
# check if constructor can accept factory_kwargs
check = original_constructor.with_args(factory_kwargs=None)
check()
return original_constructor.with_callable_args(factory_kwargs=get_factory_kwargs_based_on_module_device)
return original_constructor.with_callable_args(factory_kwargs=get_factory_kwargs_based_on_module_device_with_model)
except AttributeError: # qconfig doesn't have activation or weight
return original_constructor
except TypeError: # the class doesn't accept factory_kwargs argument
Expand Down

0 comments on commit 419ec3b

Please sign in to comment.