[quant][graphmode][fx] Support quantizing FloatFunctional
Summary:
FX graph mode quantization could not previously handle models that use
torch.nn.quantized.FloatFunctional. Before symbolic tracing, FloatFunctional
modules are now swapped with a new FXFloatFunctional module that calls the
underlying torch ops without an activation_post_process, and torch.add /
torch.mul (including their ReLU-fused variants) are registered as quantization
patterns so the resulting call_function nodes get quantized.

Test Plan:
New test: TestQuantizeFxOps.test_float_functional in
test/quantization/test_quantize_fx.py

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: c44920c68672590c7464451635b6ea24ac3e6797
Pull Request resolved: #46634
jerryzh168 committed Oct 22, 2020
1 parent 6de619e commit e91476b
Showing 6 changed files with 158 additions and 7 deletions.
82 changes: 79 additions & 3 deletions test/quantization/test_quantize_fx.py
@@ -12,9 +12,6 @@
prepare_fx,
convert_fx,
prepare_qat_fx,
)

from torch.quantization import (
default_qconfig,
default_dynamic_qconfig,
float16_dynamic_qconfig,
@@ -814,6 +811,7 @@ def forward(self, x):
# make sure these modules are not traced
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)

@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
"""Unit tests for individual ops
"""
@@ -1486,6 +1484,84 @@ def forward(self, x):
expected_node_occurrence=count_check,
expected_node_list=order_check)

def test_float_functional(self):
class TorchAdd(nn.Module):
"""Wrapper around torch.add so that all ops can be found at build"""
def __init__(self):
super().__init__()
self.add_func = nnq.FloatFunctional()

def forward(self, x, y):
return self.add_func.add(x, y)

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.ff1 = TorchAdd()
self.ff2 = nnq.FloatFunctional()
self.ff3 = nnq.FloatFunctional()
self.ff4 = nnq.FloatFunctional()
self.ff5 = nnq.FloatFunctional()
self.ff6 = nnq.FloatFunctional()

def forward(self, x):
x = self.ff1(x, x)
x = self.ff2.add_scalar(x, 3)
x = self.ff3.mul(x, x)
x = self.ff4.mul_scalar(x, 3)
x = self.ff5.add_relu(x, x)
x = self.ff6.cat([x])
return x

data = torch.rand(3, 3)
        # Note: the QAT test succeeded by chance; to make it actually work
        # we need to fix eager mode FloatFunctional by removing
        # activation_post_process in add_scalar and mul_scalar
for quant_type in self.static_quant_types:
m = M()
ref_m = torch.quantization.QuantWrapper(M())
is_qat = quant_type == QuantType.QAT
if is_qat:
m.train()
ref_m.train()
qconfig = default_qat_qconfig
expected_act_post_process = torch.quantization.FakeQuantize
else:
m.eval()
ref_m.eval()
qconfig = default_qconfig
expected_act_post_process = torch.quantization.MinMaxObserver

prepare_fx_function = prepare_qat_fx if is_qat else prepare_fx
qconfig_dict = {"": qconfig}
m = prepare_fx_function(m, qconfig_dict)
node_occurrence = {
ns.call_module(expected_act_post_process): 5,
ns.call_module(torch.nn.quantized.FloatFunctional): 0
}
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
m(data)
node_list = [
ns.call_function(torch.quantize_per_tensor),
ns.call_function(torch.ops.quantized.add),
ns.call_function(torch.ops.quantized.add),
ns.call_function(torch.ops.quantized.mul),
ns.call_function(torch.ops.quantized.mul),
ns.call_function(torch.ops.quantized.add_relu),
ns.call_function(torch.ops.quantized.cat),
ns.call_method('dequantize')
]
m = convert_fx(m)
self.checkGraphModuleNodes(m, expected_node_list=node_list)

# make sure numerics match with eager mode
ref_m.qconfig = qconfig
prepare_function = prepare_qat if is_qat else prepare
ref_m = prepare_function(ref_m)
ref_m(data)
ref_m = convert(ref_m)
self.assertEqual(m(data), ref_m(data))

class TestQuantizeFxModels(QuantizationTestCase):
def _test_model_impl(
self, mode, name, model, eager_quantizable_model,
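For reference, a minimal sketch of the end-to-end flow the new test exercises, using the usual FX graph mode entry points; the FFModule name, data shapes, and print call are illustrative and not part of the patch:

import torch
import torch.nn as nn
import torch.nn.quantized as nnq
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class FFModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.ff = nnq.FloatFunctional()

    def forward(self, x):
        # the FloatFunctional is swapped for FXFloatFunctional inside prepare_fx,
        # so this add is observed at the graph level rather than inside the module
        return self.ff.add(x, x)

m = FFModule().eval()
m = prepare_fx(m, {"": default_qconfig})
m(torch.rand(3, 3))     # calibration
m = convert_fx(m)       # the add is lowered to torch.ops.quantized.add
print(m(torch.rand(3, 3)))
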
3 changes: 2 additions & 1 deletion torch/nn/quantized/modules/__init__.py
@@ -11,7 +11,7 @@
from .linear import Linear
from .embedding_ops import Embedding, EmbeddingBag

from .functional_modules import FloatFunctional, QFunctional
from .functional_modules import FloatFunctional, FXFloatFunctional, QFunctional


class Quantize(torch.nn.Module):
@@ -110,5 +110,6 @@ def from_float(mod):
'Sigmoid',
# Wrapper modules
'FloatFunctional',
'FXFloatFunctional',
'QFunctional',
]
52 changes: 52 additions & 0 deletions torch/nn/quantized/modules/functional_modules.py
@@ -82,6 +82,58 @@ def add_relu(self, x, y):
r = self.activation_post_process(r)
return r

class FXFloatFunctional(torch.nn.Module):
r""" module to replace FloatFunctional module before FX graph mode quantization,
since activation_post_process will be inserted in top level module directly
Valid operation names:
- add
- cat
- mul
- add_relu
- add_scalar
- mul_scalar
"""
def forward(self, x):
raise RuntimeError("FloatFunctional is not intended to use the " +
"'forward'. Please use the underlying operation")

r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
def add(self, x, y):
# type: (Tensor, Tensor) -> Tensor
r = torch.add(x, y)
return r

r"""Operation equivalent to ``torch.add(Tensor, float)``"""
def add_scalar(self, x, y):
# type: (Tensor, float) -> Tensor
r = torch.add(x, y)
return r

r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
def mul(self, x, y):
# type: (Tensor, Tensor) -> Tensor
r = torch.mul(x, y)
return r

r"""Operation equivalent to ``torch.mul(Tensor, float)``"""
def mul_scalar(self, x, y):
# type: (Tensor, float) -> Tensor
r = torch.mul(x, y)
return r

r"""Operation equivalent to ``torch.cat``"""
def cat(self, x, dim=0):
# type: (List[Tensor], int) -> Tensor
r = torch.cat(x, dim=dim)
return r

r"""Operation equivalent to ``relu(torch.add(x,y))``"""
def add_relu(self, x, y):
# type: (Tensor, Tensor) -> Tensor
r = torch.add(x, y)
r = torch.nn.functional.relu(r)
return r

class QFunctional(torch.nn.Module):
r"""Wrapper class for quantized operations.
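For context, a rough sketch of why the observer-free variant matters; the Toy module is made up, and this assumes torch.fx.symbolic_trace is available. Because add is a plain method rather than forward, tracing inlines it, so the graph ends up with a bare call_function node targeting torch.add that the patterns registered in quantization_patterns.py can then match:

import torch
import torch.fx
from torch.nn.quantized import FXFloatFunctional

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ff = FXFloatFunctional()

    def forward(self, x):
        # traced as call_function(torch.add), with no embedded observer
        return self.ff.add(x, x)

gm = torch.fx.symbolic_trace(Toy())
print(gm.graph)  # should show a call_function node whose target is torch.add
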
10 changes: 8 additions & 2 deletions torch/quantization/fx/quantization_patterns.py
@@ -52,8 +52,11 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
return NotImplemented

@register_quant_pattern(operator.add)
@register_quant_pattern(torch.add)
@register_quant_pattern((torch.nn.ReLU, operator.add))
@register_quant_pattern((torch.nn.ReLU, torch.add))
@register_quant_pattern((torch.nn.functional.relu, operator.add))
@register_quant_pattern((torch.nn.functional.relu, torch.add))
class Add(QuantizeHandler):
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
@@ -62,7 +65,7 @@ def __init__(self, quantizer, node):
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0]
assert node.op == 'call_function' and node.target == operator.add
assert node.op == 'call_function' and node.target in [operator.add, torch.add]
self.add_node = node
self.all_nodes = all([isinstance(a, Node) for a in self.add_node.args[:2]])

@@ -90,8 +93,11 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
'call_function', op, load_arg(quantized=True)(self.add_node.args), kwargs)

@register_quant_pattern(operator.mul)
@register_quant_pattern(torch.mul)
@register_quant_pattern((torch.nn.ReLU, operator.mul))
@register_quant_pattern((torch.nn.ReLU, torch.mul))
@register_quant_pattern((torch.nn.functional.relu, operator.mul))
@register_quant_pattern((torch.nn.functional.relu, torch.mul))
class Mul(QuantizeHandler):
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
@@ -100,7 +106,7 @@ def __init__(self, quantizer, node):
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0]
assert node.op == 'call_function' and node.target == operator.mul
assert node.op == 'call_function' and node.target in [operator.mul, torch.mul]
self.mul_node = node
self.all_nodes = all([isinstance(a, Node) for a in self.mul_node.args[:2]])

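Sketched on a toy module (AddRelu and the shapes are illustrative, not from this patch), the extra registrations mean a traced torch.add followed by relu is now picked up by the Add handler and fused:

import torch
import torch.nn.functional as F
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class AddRelu(torch.nn.Module):
    def forward(self, x, y):
        # matched by the (torch.nn.functional.relu, torch.add) pattern registered above
        return F.relu(torch.add(x, y))

m = AddRelu().eval()
m = prepare_fx(m, {"": default_qconfig})
m(torch.rand(2, 2), torch.rand(2, 2))   # calibration
m = convert_fx(m)                       # expected to lower to torch.ops.quantized.add_relu
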
1 change: 0 additions & 1 deletion torch/quantization/fx/quantize.py
@@ -117,7 +117,6 @@ def load_arg(a):
graph_module = GraphModule(root, graph)
return graph_module


def assert_and_get_unique_device(module):
"""
Returns the unique device for a module, or None if no device is found.
17 changes: 17 additions & 0 deletions torch/quantization/quantize_fx.py
@@ -12,6 +12,20 @@ def _check_is_graph_module(model):
'Got type:' + str(type(model)) + ' Please make ' +
'sure to follow the tutorials.')

def _swap_ff_with_fxff(model):
r""" Swap FloatFunctional with FXFloatFunctional
"""
modules_to_swap = []
for name, module in model.named_children():
if isinstance(module, torch.nn.quantized.FloatFunctional):
modules_to_swap.append(name)
else:
_swap_ff_with_fxff(module)

for name in modules_to_swap:
del model._modules[name]
model._modules[name] = torch.nn.quantized.FXFloatFunctional()

def _fuse_fx(graph_module, inplace=False):
r""" Internal helper function to fuse modules in preparation for quantization
@@ -52,6 +66,9 @@ def _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict=None, i
skipped_module_names = prepare_custom_config_dict.get("non_traceable_module_name", [])
skipped_module_classes = prepare_custom_config_dict.get("non_traceable_module_class", [])

# swap FloatFunctional with FXFloatFunctional
_swap_ff_with_fxff(model)

# symbolically trace the model
if not is_standalone_module:
# standalone module and custom module config are applied in top level module
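A rough illustration of what the recursive swap does before tracing; _swap_ff_with_fxff is an internal helper rather than public API, and the Inner/Outer modules here are made up:

import torch.nn as nn
import torch.nn.quantized as nnq
from torch.quantization.quantize_fx import _swap_ff_with_fxff  # internal helper

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.ff = nnq.FloatFunctional()

class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Inner()
        self.ff = nnq.FloatFunctional()

m = Outer()
_swap_ff_with_fxff(m)
# both the top-level and the nested FloatFunctional are replaced in place
assert isinstance(m.ff, nnq.FXFloatFunctional)
assert isinstance(m.sub.ff, nnq.FXFloatFunctional)
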
