Skip to content

Commit

Permalink
Update on "[quant][graphmode][fx] Support sigmoid/hardsigmoid/tanh in qat"
Browse files Browse the repository at this point in the history

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D24486972](https://our.internmc.facebook.com/intern/diff/D24486972)

[ghstack-poisoned]
  • Loading branch information
jerryzh168 committed Oct 23, 2020
1 parent 85a777e commit 4a1d3b6
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
4 changes: 3 additions & 1 deletion test/quantization/test_quantize_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,17 +1509,19 @@ def forward(self, x):
m.eval()
qconfig = default_qconfig
prepare = prepare_fx
fq_count = 0
else:
m.train()
qconfig = default_qat_qconfig
prepare = prepare_qat_fx
fq_count = 13

# nothing to fuse so skipping the fuse step
qconfig_dict = {'': qconfig}
prepared = prepare(m, qconfig_dict)
# check the correct number of activation_post_process is inserted
count_check = {
ns.call_module(FixedQParamsFakeQuantize) : 13,
ns.call_module(FixedQParamsFakeQuantize) : fq_count,
}
self.checkGraphModuleNodes(
prepared,
Expand Down
6 changes: 4 additions & 2 deletions torch/quantization/fx/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,15 +426,17 @@ def insert_observer(node, observer, device):
if not activation_is_statically_quantized(qconfig):
continue

if isinstance(obj, FixedQParamsOpQuantizeHandler):
if isinstance(obj, FixedQParamsOpQuantizeHandler) and model.training:
# we only insert fake quantize module in qat
activation_post_process_ctr = \
get_default_output_activation_post_process_map().get(pattern, None)
assert activation_post_process_ctr is not None, \
'activation_post_process constructor not provided for ' + \
'pattern:' + str(pattern)
device = assert_and_get_unique_device(model)
insert_observer(node, activation_post_process_ctr(), device)
elif isinstance(obj, CopyNode):
elif (isinstance(obj, FixedQParamsOpQuantizeHandler) and not model.training) \
or isinstance(obj, CopyNode):
# inserting observers for output of observed module, or mark the output
# as observed
assert node.op in [
Expand Down

0 comments on commit 4a1d3b6

Please sign in to comment.