diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index 7dec874f41b1..ab4729b2ec39 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -814,6 +814,7 @@ def forward(self, x):
         # make sure these modules are not traced
         self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
 
+@skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
     """Unit tests for individual ops
     """
@@ -1486,6 +1487,39 @@ def forward(self, x):
             expected_node_occurrence=count_check,
             expected_node_list=order_check)
 
+    def test_float_functional(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.ff1 = nnq.FloatFunctional()
+                self.ff2 = nnq.FloatFunctional()
+                self.ff3 = nnq.FloatFunctional()
+
+            def forward(self, x):
+                x = self.ff1.add(x, x)
+                x = self.ff2.mul(x, x)
+                x = self.ff3.add_relu(x, x)
+                return x
+
+        m = M()
+        m.eval()
+        qconfig_dict = {"": default_qconfig}
+        m = prepare_fx(m, qconfig_dict)
+        node_occurrence = {
+            ns.call_module(torch.quantization.MinMaxObserver): 4,
+            ns.call_module(torch.nn.quantized.FloatFunctional): 0
+        }
+        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
+        node_list = [
+            ns.call_function(torch.quantize_per_tensor),
+            ns.call_function(torch.ops.quantized.add),
+            ns.call_function(torch.ops.quantized.mul),
+            ns.call_function(torch.ops.quantized.add_relu),
+            ns.call_method('dequantize')
+        ]
+        m = convert_fx(m)
+        self.checkGraphModuleNodes(m, expected_node_list=node_list)
+
 class TestQuantizeFxModels(QuantizationTestCase):
     def _test_model_impl(
             self, mode, name, model, eager_quantizable_model,
diff --git a/torch/nn/quantized/modules/__init__.py b/torch/nn/quantized/modules/__init__.py
index 248095e19033..a40a3e3fbcac 100644
--- a/torch/nn/quantized/modules/__init__.py
+++ b/torch/nn/quantized/modules/__init__.py
@@ -11,7 +11,7 @@
 from .linear import Linear
 from .embedding_ops import Embedding, EmbeddingBag
 
-from .functional_modules import FloatFunctional, QFunctional
+from .functional_modules import FloatFunctional, FXFloatFunctional, QFunctional
 
 class Quantize(torch.nn.Module):
@@ -110,5 +110,6 @@ def from_float(mod):
     'Sigmoid',
     # Wrapper modules
     'FloatFunctional',
+    'FXFloatFunctional',
     'QFunctional',
 ]
diff --git a/torch/nn/quantized/modules/functional_modules.py b/torch/nn/quantized/modules/functional_modules.py
index d3fa7189e056..9adb3af92f1d 100644
--- a/torch/nn/quantized/modules/functional_modules.py
+++ b/torch/nn/quantized/modules/functional_modules.py
@@ -82,6 +82,58 @@ def add_relu(self, x, y):
         r = self.activation_post_process(r)
         return r
 
+class FXFloatFunctional(torch.nn.Module):
+    r""" module to replace FloatFunctional module before FX graph mode quantization,
+    since activation_post_process will be inserted in top level module directly
+
+    Valid operation names:
+        - add
+        - cat
+        - mul
+        - add_relu
+        - add_scalar
+        - mul_scalar
+    """
+    def forward(self, x):
+        raise RuntimeError("FloatFunctional is not intended to use the " +
+                           "'forward'. Please use the underlying operation")
+
+    r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
+    def add(self, x, y):
+        # type: (Tensor, Tensor) -> Tensor
+        r = torch.add(x, y)
+        return r
+
+    r"""Operation equivalent to ``torch.add(Tensor, float)``"""
+    def add_scalar(self, x, y):
+        # type: (Tensor, float) -> Tensor
+        r = torch.add(x, y)
+        return r
+
+    r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
+    def mul(self, x, y):
+        # type: (Tensor, Tensor) -> Tensor
+        r = torch.mul(x, y)
+        return r
+
+    r"""Operation equivalent to ``torch.mul(Tensor, float)``"""
+    def mul_scalar(self, x, y):
+        # type: (Tensor, float) -> Tensor
+        r = torch.mul(x, y)
+        return r
+
+    r"""Operation equivalent to ``torch.cat``"""
+    def cat(self, x, dim=0):
+        # type: (List[Tensor], int) -> Tensor
+        r = torch.cat(x, dim=dim)
+        return r
+
+    r"""Operation equivalent to ``relu(torch.add(x,y))``"""
+    def add_relu(self, x, y):
+        # type: (Tensor, Tensor) -> Tensor
+        r = torch.add(x, y)
+        r = torch.nn.functional.relu(r)
+        return r
+
 class QFunctional(torch.nn.Module):
     r"""Wrapper class for quantized operations.
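Note: the methods on FXFloatFunctional deliberately skip activation_post_process and reduce to bare torch.add / torch.mul / torch.cat calls, so FX tracing inlines them as call_function nodes instead of recording an opaque call_module. A minimal sketch of that effect (assuming a build with torch.fx and this patch applied; the module M is illustrative, not part of the patch):

# Minimal sketch, assuming torch.fx is available and this patch is applied.
import torch
from torch.fx import symbolic_trace
from torch.nn.quantized import FXFloatFunctional

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ff = FXFloatFunctional()

    def forward(self, x):
        # A direct method call on a submodule is inlined by the tracer,
        # producing call_function nodes for torch.add and F.relu.
        return self.ff.add_relu(x, x)

print(symbolic_trace(M()).graph)

This is also why the pattern table below has to match torch.add and torch.mul, not just operator.add and operator.mul.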
diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index 70aba15f931f..5b7621b9d635 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -52,8 +52,11 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
         return NotImplemented
 
 @register_quant_pattern(operator.add)
+@register_quant_pattern(torch.add)
 @register_quant_pattern((torch.nn.ReLU, operator.add))
+@register_quant_pattern((torch.nn.ReLU, torch.add))
 @register_quant_pattern((torch.nn.functional.relu, operator.add))
+@register_quant_pattern((torch.nn.functional.relu, torch.add))
 class Add(QuantizeHandler):
     def __init__(self, quantizer, node):
         super().__init__(quantizer, node)
@@ -62,7 +65,7 @@ def __init__(self, quantizer, node):
            (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
             self.relu_node = node
             node = node.args[0]
-        assert node.op == 'call_function' and node.target == operator.add
+        assert node.op == 'call_function' and node.target in [operator.add, torch.add]
         self.add_node = node
         self.all_nodes = all([isinstance(a, Node) for a in self.add_node.args[:2]])
 
@@ -90,8 +93,11 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
                 'call_function', op, load_arg(quantized=True)(self.add_node.args), kwargs)
 
 @register_quant_pattern(operator.mul)
+@register_quant_pattern(torch.mul)
 @register_quant_pattern((torch.nn.ReLU, operator.mul))
+@register_quant_pattern((torch.nn.ReLU, torch.mul))
 @register_quant_pattern((torch.nn.functional.relu, operator.mul))
+@register_quant_pattern((torch.nn.functional.relu, torch.mul))
 class Mul(QuantizeHandler):
     def __init__(self, quantizer, node):
         super().__init__(quantizer, node)
@@ -100,7 +106,7 @@ def __init__(self, quantizer, node):
            (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
             self.relu_node = node
             node = node.args[0]
-        assert node.op == 'call_function' and node.target == operator.mul
+        assert node.op == 'call_function' and node.target in [operator.mul, torch.mul]
         self.mul_node = node
         self.all_nodes = all([isinstance(a, Node) for a in self.mul_node.args[:2]])
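For context, torch.ops.quantized.add (the op these patterns lower to on convert) takes an explicit output scale and zero point, which is why the prepared model carries an observer for each arithmetic output. A hedged sketch of the target op, with arbitrary stand-in quantization parameters in place of what the observers would compute:

# Sketch of the lowered op; the scale/zero_point arguments are stand-ins.
import torch

qa = torch.quantize_per_tensor(torch.randn(4), 0.1, 0, torch.quint8)
qb = torch.quantize_per_tensor(torch.randn(4), 0.1, 0, torch.quint8)
# quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor
out = torch.ops.quantized.add(qa, qb, 0.2, 0)
print(out.dequantize())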
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index 92b243ef7148..923cd6d54bb5 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -117,7 +117,6 @@ def load_arg(a):
     graph_module = GraphModule(root, graph)
     return graph_module
 
-
 def assert_and_get_unique_device(module):
     """
     Returns the unique device for a module, or None if no device is found.
diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py
index 0ae395e3cf57..b6eb32c25e46 100644
--- a/torch/quantization/quantize_fx.py
+++ b/torch/quantization/quantize_fx.py
@@ -12,6 +12,16 @@ def _check_is_graph_module(model):
             'Got type:' + str(type(model)) + ' Please make ' +
             'sure to follow the tutorials.')
 
+def _swap_ff_with_fxff(model):
+    r""" Swap FloatFunctional with FXFloatFunctional
+    """
+    modules_to_swap = []
+    for name, module in model.named_modules():
+        if isinstance(module, torch.nn.quantized.FloatFunctional):
+            modules_to_swap.append(name)
+    for name in modules_to_swap:
+        model._modules[name] = torch.nn.quantized.FXFloatFunctional()
+
 def _fuse_fx(graph_module, inplace=False):
     r""" Internal helper function to fuse modules in preparation for quantization
 
@@ -52,6 +62,9 @@ def _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict=None, i
     skipped_module_names = prepare_custom_config_dict.get("non_traceable_module_name", [])
     skipped_module_classes = prepare_custom_config_dict.get("non_traceable_module_class", [])
 
+    # swap FloatFunctional with FXFloatFunctional
+    _swap_ff_with_fxff(model)
+
     # symbolically trace the model
     if not is_standalone_module:
         # standalone module and custom module config are applied in top level module
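End to end, the swap stays invisible to user code: models keep instantiating nn.quantized.FloatFunctional, and _prepare_fx replaces it before tracing. A workflow sketch under the same assumptions as above (illustrative module; a quantized backend such as FBGEMM available):

# Workflow sketch, assuming this patch and a quantized CPU backend.
import torch
import torch.nn.quantized as nnq
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ff = nnq.FloatFunctional()  # swapped to FXFloatFunctional by _prepare_fx

    def forward(self, x):
        return self.ff.add(x, x)

m = M().eval()
m = prepare_fx(m, {"": default_qconfig})  # observers inserted in the top level module
m(torch.randn(4, 4))                      # calibration pass
m = convert_fx(m)                         # add is lowered to torch.ops.quantized.add
print(m.code)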