@@ -20,16 +20,20 @@ def fuse_conv_bn(conv, bn):
         "Conv and BN both must be in the same mode (train or eval)."

     fused_module_class_map = {
-        (nn.Conv1d, nn.BatchNorm1d): nni.ConvBn1d,
-        (nn.Conv2d, nn.BatchNorm2d): nni.ConvBn2d,
-        (nn.Conv3d, nn.BatchNorm3d): nni.ConvBn3d,
+        nn.Conv1d: nni.ConvBn1d,
+        nn.Conv2d: nni.ConvBn2d,
+        nn.Conv3d: nni.ConvBn3d,
     }

     if conv.training:
         assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
         assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
         assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
-        return fused_module_class_map.get((type(conv), type(bn)))(conv, bn)
+        fused_module_class = fused_module_class_map.get((type(conv)), None)
+        if fused_module_class is not None:
+            return fused_module_class(conv, bn)
+        else:
+            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn)))
     else:
         return nn.utils.fuse_conv_bn_eval(conv, bn)

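For context, a minimal usage sketch of the two branches above. The import path is an assumption (fuse_conv_bn has lived under torch.quantization.fuse_modules and later torch.ao.quantization depending on the release), and the layer sizes are made up for illustration.

import torch.nn as nn
import torch.nn.intrinsic as nni
from torch.quantization.fuse_modules import fuse_conv_bn  # assumed import path

conv = nn.Conv2d(3, 8, kernel_size=3)  # out_channels must equal bn.num_features
bn = nn.BatchNorm2d(8)                 # affine and track_running_stats default to True

# Both modules are in training mode by default, so the train branch runs and the
# type(conv) -> nni.ConvBn2d lookup wraps the pair in a fused intrinsic module.
fused_train = fuse_conv_bn(conv, bn)
assert isinstance(fused_train, nni.ConvBn2d)

# In eval mode the BN statistics are folded into the conv weights instead,
# via nn.utils.fuse_conv_bn_eval, which returns a plain Conv2d.
fused_eval = fuse_conv_bn(conv.eval(), bn.eval())
assert isinstance(fused_eval, nn.Conv2d)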