[quant][graphmode][fx] Add support for dynamic quant for RNN and RNNCell #49126

Closed · wants to merge 3 commits · Changes from 2 commits
46 changes: 46 additions & 0 deletions test/quantization/test_quantize_fx.py
@@ -28,6 +28,7 @@
    default_qconfig,
    default_dynamic_qconfig,
    default_qat_qconfig,
    per_channel_dynamic_qconfig,
    float16_dynamic_qconfig,
    float_qparams_weight_only_qconfig,
    get_default_qconfig,
@@ -36,6 +37,7 @@
    prepare,
    prepare_qat,
    convert,
    quantize_dynamic,
    default_placeholder_observer,
    PerChannelMinMaxObserver,
    QConfigDynamic,
@@ -57,6 +59,8 @@
from torch.testing._internal.common_quantization import (
    LinearModelWithSubmodule,
    ResNetBase,
    RNNDynamicModel,
    RNNCellDynamicModel,
)

from torch.testing._internal.common_quantized import (
@@ -2107,6 +2111,48 @@ def forward(self, indices, offsets):
        # make sure it runs
        m(*inputs)

    def _test_rnn_impl(self, qconfigs, M, module_type_strs, module_types, sample_input):
        options = itertools.product(qconfigs, module_type_strs)
        for qconfig, module_type_str in options:
            model_eager = M(module_type_str).eval()
            model_graph = copy.deepcopy(model_eager)
            # fp16 dynamic quant is not supported for qnnpack
            if torch.backends.quantized.engine == 'qnnpack' and \
                    qconfig is float16_dynamic_qconfig:
                continue

            eager_qconfig_dict = {x: qconfig for x in module_types}
            model_eager = quantize_dynamic(model_eager, qconfig_spec=eager_qconfig_dict)

            graph_qconfig_dict = {
                "object_type": [
                    (x, qconfig) for x in module_types
                ]
            }
            model_graph = prepare_fx(model_graph, graph_qconfig_dict)
            model_graph = convert_fx(model_graph)
            self.assertEqual(model_eager(sample_input), model_graph(sample_input))
Contributor: Do we need to test for serialization here?

Contributor Author: This should be the same as the eager mode module. I'm not very familiar with this; are we using state_dict?

Contributor Author: Or are you referring to checkScriptable?

Contributor Author: Added checkScriptable here, but in general we'll do e2e tests in TestQuantizeFxModels.
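For context, checkScriptable is a QuantizationTestCase helper that scripts the quantized model, optionally round-trips it through save/load, and compares numerics against the reference. A minimal sketch in that spirit, reusing model_graph and sample_input from _test_rnn_impl above:

import io

scripted = torch.jit.script(model_graph)
self.assertEqual(scripted(sample_input), model_graph(sample_input))

# serialization round-trip via TorchScript
buf = io.BytesIO()
torch.jit.save(scripted, buf)
buf.seek(0)
loaded = torch.jit.load(buf)
self.assertEqual(loaded(sample_input), model_graph(sample_input))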


    def test_rnn_cell(self):
        qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig]
        module_type_strs = ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU']
        module_types = [torch.nn.LSTMCell, torch.nn.GRUCell, torch.nn.RNNCell]
        sample_input = torch.tensor([[100, -155],
                                     [-155, 100],
                                     [100, -155]], dtype=torch.float)
        self._test_rnn_impl(qconfigs, RNNCellDynamicModel, module_type_strs, module_types, sample_input)

    def test_rnn(self):
        qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig]
        module_type_strs = ['LSTM']
        module_types = [torch.nn.LSTM]
        niter = 10
        sample_input = torch.tensor([[100, -155],
                                     [-155, 100],
                                     [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)
        self._test_rnn_impl(qconfigs, RNNDynamicModel, module_type_strs, module_types, sample_input)
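For readers following the tests, a minimal end-to-end sketch of the workflow they exercise (module and dimensions invented for illustration; the "object_type" qconfig_dict convention matches _test_rnn_impl above):

import torch
from torch.quantization import default_dynamic_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(2, 2)

    def forward(self, x):
        return self.lstm(x)

model = MyModel().eval()
qconfig_dict = {"object_type": [(torch.nn.LSTM, default_dynamic_qconfig)]}
model = prepare_fx(model, qconfig_dict)  # dynamic quant needs no calibration
model = convert_fx(model)                # self.lstm is now a dynamically quantized LSTM
out, (h, c) = model(torch.randn(10, 1, 2))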


class TestQuantizeFxModels(QuantizationTestCase):
    def _test_model_impl(
            self, mode, name, model, eager_quantizable_model,
45 changes: 44 additions & 1 deletion torch/quantization/fx/quantization_patterns.py
@@ -11,6 +11,7 @@

from ..quantization_mappings import (
    get_static_quant_module_class,
    get_dynamic_quant_module_class,
    get_quantized_operator,
)
from ..utils import (
@@ -471,7 +472,6 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
        ]
        assert node.op == 'call_module'
        emb_node = node
        qconfig = quantizer.qconfig_map[node.name]
        dtypes = get_qconfig_dtypes(qconfig)
        if dtypes not in supported_dtypes:
@@ -481,6 +481,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
"supported dtype combinations are: {}".format(dtypes, supported_dtypes))
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))

emb = quantizer.modules[emb_node.target]
qemb = get_static_quant_module_class(type(emb))
quantized = qemb.from_float(emb)
parent_name, name = _parent_name(emb_node.target)
@@ -491,6 +492,48 @@
            load_arg(quantized=False)(emb_node.args),
            load_arg(quantized=False)(emb_node.kwargs))

# TODO (maybe): merge with embedding quantize handler
@register_quant_pattern(torch.nn.GRUCell)
@register_quant_pattern(torch.nn.LSTMCell)
@register_quant_pattern(torch.nn.RNNCell)
@register_quant_pattern(torch.nn.LSTM)
@mark_input_output_not_observed()
class RNNDynamic(QuantizeHandler):
    def __init__(self, quantizer: QuantizerCls, node: Node):
        super().__init__(quantizer, node)

    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                debug: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        # Supported combinations are:
        # quant_type | activation | weight  | activation_compute_type
        # dynamic    | float32    | qint8   | quint8
        # dynamic    | float16    | float16 | None
        # tuple (activation_dtype, weight_dtype, compute_dtype)
        supported_dtypes = [
            (torch.float32, torch.qint8, torch.quint8),
            (torch.float16, torch.float16, None),
        ]
        assert node.op == 'call_module'
        qconfig = quantizer.qconfig_map[node.name]
        dtypes = get_qconfig_dtypes(qconfig)
        if dtypes not in supported_dtypes:
            warnings.warn(
                "dtype combination: {} is not "
                "supported by dynamically quantized RNN modules, "
                "supported dtype combinations are: {}".format(dtypes, supported_dtypes))
            return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))

        module = quantizer.modules[node.target]
        qmodule = get_dynamic_quant_module_class(type(module))
        quantized = qmodule.from_float(module)
        parent_name, name = _parent_name(node.target)
        setattr(quantizer.modules[parent_name], name, quantized)
        return quantizer.quantized_graph.create_node(
            'call_module',
            node.target,
            load_arg(quantized=False)(node.args),
            load_arg(quantized=False)(node.kwargs))
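For reference, the rows of the table above line up with the stock dynamic qconfigs exercised by the new tests (a sketch; the tuples follow the (activation_dtype, weight_dtype, compute_dtype) convention noted in the comment):

import torch
from torch.quantization import (
    default_dynamic_qconfig,      # -> (torch.float32, torch.qint8, torch.quint8)
    per_channel_dynamic_qconfig,  # same dtypes, with per-channel weight observers
    float16_dynamic_qconfig,      # -> (torch.float16, torch.float16, None)
)

# Any other qconfig falls into the warning branch above, and the RNN module
# is copied into the quantized graph unquantized.
qconfig_dict = {"object_type": [(torch.nn.LSTM, per_channel_dynamic_qconfig)]}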

ARGS_TO_SKIP = {
    torch._ops.ops.quantized.hardswish: ['inplace'],
13 changes: 13 additions & 0 deletions torch/quantization/quantization_mappings.py
@@ -124,6 +124,19 @@ def get_static_quant_module_class(float_module_class, additional_static_quant_mapping=None):
" does not have a corresponding quantized module class"
return static_quant_module_class

def get_dynamic_quant_module_class(float_module_class, additional_dynamic_quant_mapping=None):
Contributor: It would be great to add types to the function I/O.

Contributor Author: We can add them in a separate PR, I think; all the other functions in this file are not typed yet.

Contributor: It's not blocking this PR, but it would be awesome if we started adding these as we go, at least to function I/O. We don't have to wait for a file to have existing type annotations before adding more, and this also distributes the cost of adding them across everyone, as opposed to one person.

Contributor Author: Yeah, fully agree that we should add types as we change code. I'm saying I plan to add them in a separate PR; or are you suggesting we add type annotations for the functions in this file in this PR?
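As a sketch of what the suggested annotations could look like for this function's I/O (types assumed here, not part of this diff):

from typing import Any, Dict, Optional, Type

import torch

def get_dynamic_quant_module_class(
    float_module_class: Type[torch.nn.Module],
    additional_dynamic_quant_mapping: Optional[Dict[Type[torch.nn.Module], Any]] = None,
) -> Any:
    ...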

r"""n Get the dynamically quantized module class corresponding to
the floating point module class
"""
if additional_dynamic_quant_mapping is None:
additional_dynamic_quant_mapping = {}
all_mappings = get_combined_dict(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping)
dynamic_quant_module_class = all_mappings.get(float_module_class, None)
assert dynamic_quant_module_class is not None, \
"Floating point module class {}".format(str(float_module_class)) + \
" does not have a corresponding quantized module class"
return dynamic_quant_module_class

def get_default_qat_module_mappings():
    ''' Get default module mapping for quantization aware training
    '''