Skip to content

Commit

Permalink
Update on "[quant][graphmode][fx] Add support for dynamic quant for R…
Browse files Browse the repository at this point in the history
…NN and RNNCell"

Summary:

Test Plan:
python test/test_quantization.py TestQuantizeFxOps.test_rnn
python test/test_quantization.py TestQuantizeFxOps.test_rnn_cell

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D25449047](https://our.internmc.facebook.com/intern/diff/D25449047)

[ghstack-poisoned]
  • Loading branch information
jerryzh168 committed Dec 10, 2020
1 parent 2dbfead commit ff4353e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
1 change: 1 addition & 0 deletions test/quantization/test_quantize_fx.py
Expand Up @@ -2132,6 +2132,7 @@ def _test_rnn_impl(self, qconfigs, M, module_type_strs, module_types, sample_inp
model_graph = prepare_fx(model_graph, graph_qconfig_dict)
model_graph = convert_fx(model_graph)
self.assertEqual(model_eager(sample_input), model_graph(sample_input))
self.checkScriptable(model_graph, [[sample_input]], True)

def test_rnn_cell(self):
qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig]
Expand Down
6 changes: 3 additions & 3 deletions torch/quantization/fx/quantization_patterns.py
Expand Up @@ -525,10 +525,10 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))

module = quantizer.modules[node.target]
qmodule = get_dynamic_quant_module_class(type(module))
quantized = qmodule.from_float(module)
qmodule_cls = get_dynamic_quant_module_class(type(module))
qmodule = qmodule_cls.from_float(module)
parent_name, name = _parent_name(node.target)
setattr(quantizer.modules[parent_name], name, quantized)
setattr(quantizer.modules[parent_name], name, qmodule)
return quantizer.quantized_graph.create_node(
'call_module',
node.target,
Expand Down

0 comments on commit ff4353e

Please sign in to comment.