From 1afb93906cbd8662c14fd0e345d1cb5da0a0ace2 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 3 Nov 2020 16:15:33 -0800 Subject: [PATCH] Update on "[quant][graphmode][fx] Add support for qat convbn{relu}1d" Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D24696524](https://our.internmc.facebook.com/intern/diff/D24696524) [ghstack-poisoned] --- torch/quantization/fuser_method_mappings.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torch/quantization/fuser_method_mappings.py b/torch/quantization/fuser_method_mappings.py index 9b2ec3620ccf..0b72f5485231 100644 --- a/torch/quantization/fuser_method_mappings.py +++ b/torch/quantization/fuser_method_mappings.py @@ -20,16 +20,20 @@ def fuse_conv_bn(conv, bn): "Conv and BN both must be in the same mode (train or eval)." fused_module_class_map = { - (nn.Conv1d, nn.BatchNorm1d): nni.ConvBn1d, - (nn.Conv2d, nn.BatchNorm2d): nni.ConvBn2d, - (nn.Conv3d, nn.BatchNorm3d): nni.ConvBn3d, + nn.Conv1d: nni.ConvBn1d, + nn.Conv2d: nni.ConvBn2d, + nn.Conv3d: nni.ConvBn3d, } if conv.training: assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d' assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True' assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True' - return fused_module_class_map.get((type(conv), type(bn)))(conv, bn) + fused_module_class = fused_module_class_map.get((type(conv)), None) + if fused_module_class is not None: + return fused_module_class(conv, bn) + else: + raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn))) else: return nn.utils.fuse_conv_bn_eval(conv, bn)