Update on "[quant][graphmode][fx] Add support for qat convbn{relu}1d"
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D24696524](https://our.internmc.facebook.com/intern/diff/D24696524)

[ghstack-poisoned]
jerryzh168 committed Nov 4, 2020
1 parent 2abea89 commit 1afb939
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions torch/quantization/fuser_method_mappings.py
@@ -20,16 +20,20 @@ def fuse_conv_bn(conv, bn):
         "Conv and BN both must be in the same mode (train or eval)."
 
     fused_module_class_map = {
-        (nn.Conv1d, nn.BatchNorm1d): nni.ConvBn1d,
-        (nn.Conv2d, nn.BatchNorm2d): nni.ConvBn2d,
-        (nn.Conv3d, nn.BatchNorm3d): nni.ConvBn3d,
+        nn.Conv1d: nni.ConvBn1d,
+        nn.Conv2d: nni.ConvBn2d,
+        nn.Conv3d: nni.ConvBn3d,
     }
 
     if conv.training:
         assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
         assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
         assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
-        return fused_module_class_map.get((type(conv), type(bn)))(conv, bn)
+        fused_module_class = fused_module_class_map.get((type(conv)), None)
+        if fused_module_class is not None:
+            return fused_module_class(conv, bn)
+        else:
+            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn)))
     else:
         return nn.utils.fuse_conv_bn_eval(conv, bn)
 
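
For context, a minimal sketch of how the new type(conv)-keyed mapping is exercised for the 1d case in training mode (the module shapes below are illustrative, not taken from the commit):

import torch.nn as nn
from torch.quantization.fuser_method_mappings import fuse_conv_bn

# Illustrative train-mode Conv1d/BatchNorm1d pair; with this change the
# lookup keys on type(conv) and resolves to torch.nn.intrinsic.ConvBn1d
# instead of missing the old (Conv1d, BatchNorm1d) tuple key.
conv = nn.Conv1d(4, 8, kernel_size=3).train()
bn = nn.BatchNorm1d(8).train()
fused = fuse_conv_bn(conv, bn)  # returns nni.ConvBn1d(conv, bn)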

