[quantization] fix run_arg tiny bug (#48537)
Summary:
This fix allows the calibration function to take more than one positional argument: quantize() and quantize_qat() now call run_fn(model, *run_args) instead of run_fn(model, run_args), so callers pass their calibration data wrapped in a list.

Pull Request resolved: #48537

Reviewed By: zhangguanheng66

Differential Revision: D25255764

Pulled By: jerryzh168

fbshipit-source-id: 3ce20dbed95fd26664a186bd4a992ab406bba827
leimao authored and facebook-github-bot committed Dec 2, 2020
1 parent f61de25 commit 0db7346
Showing 4 changed files with 30 additions and 30 deletions.
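In essence, quantize() and quantize_qat() now unpack run_args into positional arguments when invoking the calibration function. A minimal, self-contained sketch of the new calling convention (quantize_sketch, calibrate, and the toy model below are illustrative stand-ins, not part of the patch):

```python
import copy

def quantize_sketch(model, run_fn, run_args):
    # Simplified stand-in for torch.quantization.quantize(); the real
    # function also inserts observers before calibration and converts
    # the model afterwards.
    model = copy.deepcopy(model)
    run_fn(model, *run_args)  # the fix: run_args is unpacked
    return model

def calibrate(model, dataset, num_batches):
    # Calibration functions may now take any number of positional
    # arguments after the model.
    for data in dataset[:num_batches]:
        model(data)

toy_model = lambda x: 2 * x  # toy stand-in for an nn.Module
quantize_sketch(toy_model, calibrate, [[1.0, 2.0, 3.0], 2])
```

The price of the extra flexibility is that run_args must always be a sequence, which is why every call site below wraps its dataset in a list.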
14 changes: 7 additions & 7 deletions test/quantization/test_numeric_suite.py
@@ -104,7 +104,7 @@ def compare_and_validate_results(float_model, q_model):
         model.eval()
         if hasattr(model, "fuse_model"):
             model.fuse_model()
-        q_model = quantize(model, test_only_eval_fn, self.img_data_2d)
+        q_model = quantize(model, test_only_eval_fn, [self.img_data_2d])
         compare_and_validate_results(model, q_model)

     @override_qengines
@@ -126,7 +126,7 @@ def compare_and_validate_results(float_model, q_model):
         model.eval()
         if hasattr(model, "fuse_model"):
             model.fuse_model()
-        q_model = quantize(model, test_only_eval_fn, self.calib_data)
+        q_model = quantize(model, test_only_eval_fn, [self.calib_data])
         compare_and_validate_results(model, q_model)

     @override_qengines
@@ -197,7 +197,7 @@ def compare_and_validate_results(float_model, q_model, module_swap_list, data):
         model.eval()
         if hasattr(model, "fuse_model"):
             model.fuse_model()
-        q_model = quantize(model, test_only_eval_fn, self.img_data_2d)
+        q_model = quantize(model, test_only_eval_fn, [self.img_data_2d])
         compare_and_validate_results(
             model, q_model, module_swap_list, self.img_data_2d[0][0]
         )
@@ -223,7 +223,7 @@ def compare_and_validate_results(float_model, q_model, module_swap_list, data):
         model.eval()
         if hasattr(model, "fuse_model"):
             model.fuse_model()
-        q_model = quantize(model, test_only_eval_fn, self.calib_data)
+        q_model = quantize(model, test_only_eval_fn, [self.calib_data])
         compare_and_validate_results(model, q_model, module_swap_list, linear_data)

     @override_qengines
@@ -233,7 +233,7 @@ def test_compare_model_stub_submodule_static(self):
         qengine = torch.backends.quantized.engine

         model = ModelWithSubModules().eval()
-        q_model = quantize(model, test_only_eval_fn, self.img_data_2d)
+        q_model = quantize(model, test_only_eval_fn, [self.img_data_2d])
         module_swap_list = [SubModule]
         ob_dict = compare_model_stub(
             model, q_model, module_swap_list, self.img_data_2d[0][0]
@@ -350,7 +350,7 @@ def compare_and_validate_results(float_model, q_model, data):
         model.eval()
         if hasattr(model, "fuse_model"):
             model.fuse_model()
-        q_model = quantize(model, test_only_eval_fn, self.img_data_2d)
+        q_model = quantize(model, test_only_eval_fn, [self.img_data_2d])
         compare_and_validate_results(model, q_model, self.img_data_2d[0][0])

     @override_qengines
@@ -376,7 +376,7 @@ def compare_and_validate_results(float_model, q_model, data):
         model.eval()
         if hasattr(model, "fuse_model"):
             model.fuse_model()
-        q_model = quantize(model, test_only_eval_fn, self.calib_data)
+        q_model = quantize(model, test_only_eval_fn, [self.calib_data])
         compare_and_validate_results(model, q_model, linear_data)

     @override_qengines
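All of the test-side changes follow this one pattern: since run_args is now unpacked, wrapping the dataset in a list keeps the eval helper's signature unchanged. A sketch under the assumption that test_only_eval_fn (defined in PyTorch's internal test utilities) looks roughly like this:

```python
def test_only_eval_fn(model, calib_data):
    # Assumed simplification of the real helper: run every calibration
    # batch through the model so the observers see representative data.
    for data, _target in calib_data:
        model(data)

# Before the fix: quantize(model, test_only_eval_fn, self.calib_data)
#   called run_fn(model, run_args)  -> test_only_eval_fn(model, self.calib_data)
# After the fix:  quantize(model, test_only_eval_fn, [self.calib_data])
#   calls run_fn(model, *run_args)  -> test_only_eval_fn(model, self.calib_data)
```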
28 changes: 14 additions & 14 deletions test/quantization/test_quantize.py
@@ -120,15 +120,15 @@ def checkQuantized(model):
         base = AnnotatedSingleLayerLinearModel(qengine)
         base.qconfig = qconfig
         keys_before = set(list(base.state_dict().keys()))
-        model = quantize(base, test_only_eval_fn, self.calib_data)
+        model = quantize(base, test_only_eval_fn, [self.calib_data])
         checkQuantized(model)
         keys_after = set(list(base.state_dict().keys()))
         self.assertEqual(keys_before, keys_after)  # simple check that nothing changed

         # in-place version
         model = AnnotatedSingleLayerLinearModel(qengine)
         model.qconfig = qconfig
-        quantize(model, test_only_eval_fn, self.calib_data, inplace=True)
+        quantize(model, test_only_eval_fn, [self.calib_data], inplace=True)
         checkQuantized(model)

     @skipIfNoFBGEMM
@@ -162,7 +162,7 @@ def checkQuantized(model):

         # test one line API
         model = quantize(AnnotatedTwoLayerLinearModel(), test_only_eval_fn,
-                         self.calib_data)
+                         [self.calib_data])
         checkQuantized(model)

     def test_nested1(self):
@@ -204,7 +204,7 @@ def checkQuantized(model):

         # test one line API
         model = quantize(AnnotatedNestedModel(qengine), test_only_eval_fn,
-                         self.calib_data)
+                         [self.calib_data])
         checkQuantized(model)


@@ -245,7 +245,7 @@ def checkQuantized(model):

         # test one line API
         model = quantize(AnnotatedSubNestedModel(), test_only_eval_fn,
-                         self.calib_data)
+                         [self.calib_data])
         checkQuantized(model)

     def test_nested3(self):
@@ -287,7 +287,7 @@ def checkQuantized(model):

         # test one line API
         model = quantize(AnnotatedCustomConfigNestedModel(), test_only_eval_fn,
-                         self.calib_data)
+                         [self.calib_data])
         checkQuantized(model)

     def test_skip_quant(self):
@@ -315,7 +315,7 @@ def checkQuantized(model):
         checkQuantized(model)

         # test one line API
-        model = quantize(AnnotatedSkipQuantModel(qengine), test_only_eval_fn, self.calib_data)
+        model = quantize(AnnotatedSkipQuantModel(qengine), test_only_eval_fn, [self.calib_data])
         checkQuantized(model)

     @skipIfNoFBGEMM
@@ -341,7 +341,7 @@ def checkQuantized(model):
         checkQuantized(model)

         # test one line API
-        model = quantize(QuantStubModel(), test_only_eval_fn, self.calib_data)
+        model = quantize(QuantStubModel(), test_only_eval_fn, [self.calib_data])
         checkQuantized(model)

     def test_resnet_base(self):
@@ -400,7 +400,7 @@ def checkQuantized(model):
         checkQuantized(model)

         model_oneline = quantize(
-            NormalizationTestModel(), test_only_eval_fn, self.calib_data)
+            NormalizationTestModel(), test_only_eval_fn, [self.calib_data])
         checkQuantized(model)

     def test_save_load_state_dict(self):
@@ -463,7 +463,7 @@ def checkQuantized(model):

         # test one line API
         model_oneline = quantize(ActivationsTestModel(), test_only_eval_fn,
-                                 self.calib_data)
+                                 [self.calib_data])
         checkQuantized(model_oneline)

     @override_qengines
@@ -1083,7 +1083,7 @@ def checkQuantized(model):
         checkQuantized(model)

         model = quantize_qat(ManualLinearQATModel(qengine), test_only_train_fn,
-                             self.train_data)
+                             [self.train_data])
         checkQuantized(model)

     def test_eval_only_fake_quant(self):
@@ -1123,7 +1123,7 @@ def checkQuantized(model):
         checkQuantized(model)

         model = ManualConvLinearQATModel()
-        model = quantize_qat(model, test_only_train_fn, self.img_data_2d_train)
+        model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train])
         checkQuantized(model)

     def test_train_save_load_eval(self):
@@ -1434,7 +1434,7 @@ def checkQuantized(model):
         model = ModelForFusion(default_qat_qconfig).train()
         model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
                                      ['sub1.conv', 'sub1.bn']])
-        model = quantize_qat(model, test_only_train_fn, self.img_data_1d_train)
+        model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
         with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
             checkQuantized(model)

@@ -1514,7 +1514,7 @@ def checkQuantized(model):
                                      ['bn2', 'relu3'],
                                      ['sub1.conv', 'sub1.bn'],
                                      ['conv3', 'bn3', 'relu4']])
-        model = quantize(model, test_only_eval_fn, self.img_data_1d)
+        model = quantize(model, test_only_eval_fn, [self.img_data_1d])
         checkQuantized(model)

     def test_fusion_sequential_model_train(self):
14 changes: 7 additions & 7 deletions test/quantization/test_quantize_jit.py
@@ -3095,7 +3095,7 @@ def test_single_linear(self):
         # compare the result of the two quantized models later
         linear_model.fc1.weight = torch.nn.Parameter(annotated_linear_model.fc1.module.weight.detach())
         linear_model.fc1.bias = torch.nn.Parameter(annotated_linear_model.fc1.module.bias.detach())
-        model_eager = quantize(annotated_linear_model, test_only_eval_fn, self.calib_data)
+        model_eager = quantize(annotated_linear_model, test_only_eval_fn, [self.calib_data])

         qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
         model_traced = torch.jit.trace(linear_model, self.calib_data[0][0])
@@ -3135,7 +3135,7 @@ def test_observer_with_ignored_function(self):
         linear_model.fc1.weight = torch.nn.Parameter(annotated_linear_model.fc1.module.weight.detach())
         linear_model.fc1.bias = torch.nn.Parameter(annotated_linear_model.fc1.module.bias.detach())
         model_eager = quantize(annotated_linear_model, test_only_eval_fn,
-                               self.calib_data)
+                               [self.calib_data])

         qconfig_dict = {'': qconfig}
         model_traced = torch.jit.trace(linear_model, self.calib_data[0][0])
@@ -3161,7 +3161,7 @@ def test_conv(self):
         # copy the weight from eager mode so that we can
         # compare the result of the two quantized models later
         conv_model.conv.weight = torch.nn.Parameter(annotated_conv_model.conv.weight.detach())
-        model_eager = quantize(annotated_conv_model, test_only_eval_fn, self.img_data_2d)
+        model_eager = quantize(annotated_conv_model, test_only_eval_fn, [self.img_data_2d])
         qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
         model_traced = torch.jit.trace(conv_model, self.img_data_2d[0][0])
         model_script = torch.jit.script(conv_model)
@@ -3189,7 +3189,7 @@ def test_conv_transpose(self):
         # copy the weight from eager mode so that we can
         # compare the result of the two quantized models later
         conv_model.conv.weight = torch.nn.Parameter(annotated_conv_model.conv.weight.detach())
-        model_eager = quantize(annotated_conv_model, test_only_eval_fn, self.img_data_2d)
+        model_eager = quantize(annotated_conv_model, test_only_eval_fn, [self.img_data_2d])
         qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
         model_traced = torch.jit.trace(conv_model, self.img_data_2d[0][0])
         model_script = torch.jit.script(conv_model)
@@ -3217,7 +3217,7 @@ def test_conv_bn(self):
         conv_model_to_script.conv.weight = torch.nn.Parameter(conv_model.conv.weight.detach())
         fuse_modules(conv_model, ['conv', 'bn'], inplace=True)
         model_eager = quantize(conv_model, test_only_eval_fn,
-                               self.img_data_2d)
+                               [self.img_data_2d])
         qconfig_dict = {
             '': default_qconfig
         }
@@ -3248,7 +3248,7 @@ def test_nested(self):
         script_model.fc3.weight = torch.nn.Parameter(eager_model.fc3.module.weight.detach())
         script_model.fc3.bias = torch.nn.Parameter(eager_model.fc3.module.bias.detach())

-        model_eager = quantize(eager_model, test_only_eval_fn, self.calib_data)
+        model_eager = quantize(eager_model, test_only_eval_fn, [self.calib_data])
         qconfig_dict = {
             'sub2.fc1': default_per_channel_qconfig if qengine_is_fbgemm() else default_qconfig,
             'fc3': default_qconfig
@@ -3284,7 +3284,7 @@ def test_skip_quant(self):

         eager_model.fuse_modules()

-        model_eager = quantize(eager_model, test_only_eval_fn, self.calib_data)
+        model_eager = quantize(eager_model, test_only_eval_fn, [self.calib_data])
         qconfig_dict = {
             '': get_default_qconfig(torch.backends.quantized.engine),
             'fc': None
4 changes: 2 additions & 2 deletions torch/quantization/quantize.py
@@ -295,7 +295,7 @@ def quantize(model, run_fn, run_args, mapping=None, inplace=False):
         model = copy.deepcopy(model)
     model.eval()
     prepare(model, inplace=True)
-    run_fn(model, run_args)
+    run_fn(model, *run_args)
     convert(model, mapping, inplace=True)
     return model

@@ -422,7 +422,7 @@ def quantize_qat(model, run_fn, run_args, inplace=False):
         model = copy.deepcopy(model)
     model.train()
     prepare_qat(model, inplace=True)
-    run_fn(model, run_args)
+    run_fn(model, *run_args)
     convert(model, inplace=True)
     return model
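With run_fn(model, *run_args) in place, a calibration function can take extra positional arguments end to end. A hypothetical usage sketch (TinyModel, calibrate, and the random data are illustrative assumptions; eager-mode quantization still requires QuantStub/DeQuantStub and a qconfig):

```python
import torch
import torch.nn as nn
from torch.quantization import DeQuantStub, QuantStub, default_qconfig, quantize

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # marks where quantization begins
        self.fc = nn.Linear(4, 2)
        self.dequant = DeQuantStub()  # marks where it ends

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

def calibrate(model, data, num_batches):
    # Two positional arguments besides the model, forwarded by
    # quantize() via run_fn(model, *run_args).
    model.eval()
    with torch.no_grad():
        for x in data[:num_batches]:
            model(x)

float_model = TinyModel()
float_model.qconfig = default_qconfig
calib_data = [torch.randn(1, 4) for _ in range(8)]
quantized = quantize(float_model, calibrate, [calib_data, 4])
```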
