Skip to content

Commit 1afb939

Browse files
committed
Update on "[quant][graphmode][fx] Add support for qat convbn{relu}1d"
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D24696524](https://our.internmc.facebook.com/intern/diff/D24696524) [ghstack-poisoned]
1 parent 2abea89 commit 1afb939

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

torch/quantization/fuser_method_mappings.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,20 @@ def fuse_conv_bn(conv, bn):
2020
"Conv and BN both must be in the same mode (train or eval)."
2121

2222
fused_module_class_map = {
23-
(nn.Conv1d, nn.BatchNorm1d): nni.ConvBn1d,
24-
(nn.Conv2d, nn.BatchNorm2d): nni.ConvBn2d,
25-
(nn.Conv3d, nn.BatchNorm3d): nni.ConvBn3d,
23+
nn.Conv1d: nni.ConvBn1d,
24+
nn.Conv2d: nni.ConvBn2d,
25+
nn.Conv3d: nni.ConvBn3d,
2626
}
2727

2828
if conv.training:
2929
assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
3030
assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
3131
assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
32-
return fused_module_class_map.get((type(conv), type(bn)))(conv, bn)
32+
fused_module_class = fused_module_class_map.get((type(conv)), None)
33+
if fused_module_class is not None:
34+
return fused_module_class(conv, bn)
35+
else:
36+
raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn)))
3337
else:
3438
return nn.utils.fuse_conv_bn_eval(conv, bn)
3539

0 commit comments

Comments
 (0)