diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index 153fe74ba913..d5843b6a0162 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -103,14 +103,15 @@ def forward(self, x):
         # train mode fuse_fx is called in prepare_qat_fx
         m = prepare_qat_fx(m, {})
         expected_nodes = [
+            ns.call_module(nni.ConvBn1d),
             ns.call_module(nni.ConvBn2d),
             ns.call_module(nni.ConvBn3d),
+            ns.call_module(nni.ConvBnReLU1d),
             ns.call_module(nni.ConvBnReLU2d),
             ns.call_module(nni.ConvBnReLU3d),
         ]
-        # ConvBnRelu1d is not fused
         expected_occurrence = {
-            ns.call_module(nn.ReLU): 1
+            ns.call_module(nn.ReLU): 0
         }
         self.checkGraphModuleNodes(
             m,
@@ -123,14 +124,16 @@ def forward(self, x):
         # fuse_fx is a top level api and only supports eval mode
         m = fuse_fx(m)
         expected_nodes = [
+            ns.call_module(nn.Conv1d),
             ns.call_module(nn.Conv2d),
             ns.call_module(nn.Conv3d),
+            ns.call_module(nni.ConvReLU1d),
             ns.call_module(nni.ConvReLU2d),
             ns.call_module(nni.ConvReLU3d),
         ]
         # ConvBnRelu1d is not fused
         expected_occurrence = {
-            ns.call_module(nn.ReLU): 1
+            ns.call_module(nn.ReLU): 0
         }
         self.checkGraphModuleNodes(
             m,
@@ -301,6 +304,55 @@ def forward(self, x):
             ref_qparams = weight_obs.calculate_qparams()
             self.assertEqual(qparams, ref_qparams)
 
+    def test_conv_bn_relu(self):
+        convs = {
+            1: nn.Conv1d,
+            2: nn.Conv2d,
+            3: nn.Conv3d,
+        }
+        bns = {
+            1: nn.BatchNorm1d,
+            2: nn.BatchNorm2d,
+            3: nn.BatchNorm3d,
+        }
+        quantized_convs = {
+            1: nnq.Conv1d,
+            2: nnq.Conv2d,
+            3: nnq.Conv3d,
+        }
+        quantized_conv_relus = {
+            1: nniq.ConvReLU1d,
+            2: nniq.ConvReLU2d,
+            3: nniq.ConvReLU3d,
+        }
+
+        class M(torch.nn.Module):
+            def __init__(self, dim, has_relu):
+                super().__init__()
+                self.conv = convs[dim](3, 3, 3)
+                self.bn = bns[dim](3)
+                self.relu = nn.ReLU()
+                self.has_relu = has_relu
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.bn(x)
+                if self.has_relu:
+                    x = self.relu(x)
+                return x
+
+        options = itertools.product([1, 2], [True, False], self.static_quant_types)
+        for dim, has_relu, quant_type in options:
+            expected_node = ns.call_module(
+                quantized_conv_relus[dim] if has_relu
+                else quantized_convs[dim])
+            self.checkGraphModeFxOp(
+                M(dim, has_relu),
+                self.img_data_dict[dim],
+                quant_type,
+                expected_node=expected_node,
+            )
+
     @skipIfNoFBGEMM
     def test_dynamic_quant_fp16(self):
         class Linear(torch.nn.Module):
diff --git a/torch/nn/intrinsic/qat/__init__.py b/torch/nn/intrinsic/qat/__init__.py
index d46ca956685c..270dcebaa5f4 100644
--- a/torch/nn/intrinsic/qat/__init__.py
+++ b/torch/nn/intrinsic/qat/__init__.py
@@ -1,14 +1 @@
-from .modules import LinearReLU
-from .modules import ConvReLU2d
-from .modules import ConvBn2d
-from .modules import ConvBnReLU2d
-from .modules import update_bn_stats, freeze_bn_stats
-
-__all__ = [
-    'ConvBn2d',
-    'ConvBnReLU2d',
-    'ConvReLU2d',
-    'LinearReLU',
-    'update_bn_stats',
-    'freeze_bn_stats'
-]
+from .modules import *
diff --git a/torch/nn/intrinsic/qat/modules/__init__.py b/torch/nn/intrinsic/qat/modules/__init__.py
index bcbb865a5649..f0876e8ded56 100644
--- a/torch/nn/intrinsic/qat/modules/__init__.py
+++ b/torch/nn/intrinsic/qat/modules/__init__.py
@@ -1,11 +1,13 @@
-
 from .linear_relu import LinearReLU
-from .conv_fused import ConvBn2d, ConvBnReLU2d, ConvReLU2d, update_bn_stats, freeze_bn_stats
+from .conv_fused import ConvBn1d, ConvBn2d, ConvBnReLU1d, ConvBnReLU2d, ConvReLU2d, \
+    update_bn_stats, freeze_bn_stats
 
 __all__ = [
     'LinearReLU',
     'ConvReLU2d',
+    'ConvBn1d',
     'ConvBn2d',
+    'ConvBnReLU1d',
     'ConvBnReLU2d',
     'update_bn_stats',
     'freeze_bn_stats'
diff --git a/torch/nn/intrinsic/qat/modules/conv_fused.py b/torch/nn/intrinsic/qat/modules/conv_fused.py
index 5a8b0f042db1..ed109e6df5d9 100644
--- a/torch/nn/intrinsic/qat/modules/conv_fused.py
+++ b/torch/nn/intrinsic/qat/modules/conv_fused.py
@@ -5,9 +5,16 @@
 import torch.nn.qat as nnqat
 import torch.nn.functional as F
 from torch.nn import init
-from torch.nn.modules.utils import _pair
+from torch.nn.modules.utils import _single, _pair
 from torch.nn.parameter import Parameter
 
+_BN_CLASS_MAP = {
+    1: nn.BatchNorm1d,
+    2: nn.BatchNorm2d,
+    3: nn.BatchNorm3d,
+}
+
+
 class _ConvBnNd(nn.modules.conv._ConvNd):
 
     _version = 2
@@ -26,14 +33,15 @@ def __init__(self,
                  # track_running_stats: True
                  # Args for this module
                  freeze_bn=False,
-                 qconfig=None):
+                 qconfig=None,
+                 dim=2):
         nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
                                          stride, padding, dilation, transposed,
                                          output_padding, groups, False, padding_mode)
         assert qconfig, 'qconfig must be provided for QAT module'
         self.qconfig = qconfig
         self.freeze_bn = freeze_bn if self.training else True
-        self.bn = nn.BatchNorm2d(out_channels, eps, momentum, True, True)
+        self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
         self.weight_fake_quant = self.qconfig.weight()
         if bias:
             self.bias = Parameter(torch.Tensor(out_channels))
@@ -80,12 +88,16 @@ def freeze_bn_stats(self):
     def _forward(self, input):
         running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
         scale_factor = self.bn.weight / running_std
-        scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape([-1, 1, 1, 1]))
+        weight_shape = [1] * len(self.weight.shape)
+        weight_shape[0] = -1
+        bias_shape = [1] * len(self.weight.shape)
+        bias_shape[1] = -1
+        scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
         # this does not include the conv bias
         conv = self._conv_forward(input, scaled_weight)
-        conv_orig = conv / scale_factor.reshape([1, -1, 1, 1])
+        conv_orig = conv / scale_factor.reshape(bias_shape)
         if self.bias is not None:
-            conv_orig = conv_orig + self.bias.reshape([1, -1, 1, 1])
+            conv_orig = conv_orig + self.bias.reshape(bias_shape)
         conv = self.bn(conv_orig)
         return conv
 
@@ -190,6 +202,92 @@ def from_float(cls, mod):
             qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked
         return qat_convbn
 
+class ConvBn1d(_ConvBnNd, nn.Conv1d):
+    r"""
+    A ConvBn1d module is a module fused from Conv1d and BatchNorm1d,
+    attached with FakeQuantize modules for weight,
+    used in quantization aware training.
+
+    We combined the interface of :class:`torch.nn.Conv1d` and
+    :class:`torch.nn.BatchNorm1d`.
+
+    Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized
+    to default.
+
+    Attributes:
+        freeze_bn:
+        weight_fake_quant: fake quant module for weight
+
+    """
+    _FLOAT_MODULE = torch.nn.intrinsic.ConvBn1d
+
+    def __init__(self,
+                 # Conv1d args
+                 in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1,
+                 bias=None,
+                 padding_mode='zeros',
+                 # BatchNorm1d args
+                 # num_features: out_channels
+                 eps=1e-05, momentum=0.1,
+                 # affine: True
+                 # track_running_stats: True
+                 # Args for this module
+                 freeze_bn=False,
+                 qconfig=None):
+        kernel_size = _single(kernel_size)
+        stride = _single(stride)
+        padding = _single(padding)
+        dilation = _single(dilation)
+        _ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
+                           padding, dilation, False, _single(0), groups, bias, padding_mode,
+                           eps, momentum, freeze_bn, qconfig, dim=1)
+
+class ConvBnReLU1d(ConvBn1d):
+    r"""
+    A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
+    attached with FakeQuantize modules for weight,
+    used in quantization aware training.
+
+    We combined the interface of :class:`torch.nn.Conv1d` and
+    :class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`.
+
+    Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
+    default.
+
+    Attributes:
+        weight_fake_quant: fake quant module for weight
+
+    """
+    _FLOAT_MODULE = torch.nn.intrinsic.ConvBnReLU1d
+
+    def __init__(self,
+                 # Conv1d args
+                 in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1,
+                 bias=None,
+                 padding_mode='zeros',
+                 # BatchNorm1d args
+                 # num_features: out_channels
+                 eps=1e-05, momentum=0.1,
+                 # affine: True
+                 # track_running_stats: True
+                 # Args for this module
+                 freeze_bn=False,
+                 qconfig=None):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias,
+                         padding_mode, eps, momentum,
+                         freeze_bn,
+                         qconfig)
+
+    def forward(self, input):
+        return F.relu(ConvBn1d._forward(self, input))
+
+    @classmethod
+    def from_float(cls, mod):
+        return super(ConvBnReLU1d, cls).from_float(mod)
+
 class ConvBn2d(_ConvBnNd, nn.Conv2d):
     r"""
     A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
@@ -199,8 +297,6 @@ class ConvBn2d(_ConvBnNd, nn.Conv2d):
     We combined the interface of :class:`torch.nn.Conv2d` and
     :class:`torch.nn.BatchNorm2d`.
 
-    Implementation details: https://arxiv.org/pdf/1806.08342.pdf section 3.2.2
-
     Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized
     to default.
 
@@ -231,7 +327,7 @@ def __init__(self,
         dilation = _pair(dilation)
         _ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
                            padding, dilation, False, _pair(0), groups, bias, padding_mode,
-                           eps, momentum, freeze_bn, qconfig)
+                           eps, momentum, freeze_bn, qconfig, dim=2)
 
 class ConvBnReLU2d(ConvBn2d):
     r"""
@@ -242,8 +338,6 @@ class ConvBnReLU2d(ConvBn2d):
     We combined the interface of :class:`torch.nn.Conv2d` and
     :class:`torch.nn.BatchNorm2d` and :class:`torch.nn.ReLU`.
 
-    Implementation details: https://arxiv.org/pdf/1806.08342.pdf
-
     Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
     default.
 
@@ -281,8 +375,7 @@ def from_float(cls, mod):
         return super(ConvBnReLU2d, cls).from_float(mod)
 
 class ConvReLU2d(nnqat.Conv2d):
-    r"""
-    A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
+    r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
     FakeQuantize modules for weight for
     quantization aware training.
 
@@ -316,9 +409,9 @@ def from_float(cls, mod):
         return super(ConvReLU2d, cls).from_float(mod)
 
 def update_bn_stats(mod):
-    if type(mod) in set([ConvBnReLU2d, ConvBn2d]):
+    if type(mod) in set([ConvBnReLU1d, ConvBnReLU2d, ConvBn1d, ConvBn2d]):
         mod.update_bn_stats()
 
 def freeze_bn_stats(mod):
-    if type(mod) in set([ConvBnReLU2d, ConvBn2d]):
+    if type(mod) in set([ConvBnReLU1d, ConvBnReLU2d, ConvBn1d, ConvBn2d]):
         mod.freeze_bn_stats()
diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py
index 20a1d49619b0..7fae70b601b4 100644
--- a/torch/nn/modules/conv.py
+++ b/torch/nn/modules/conv.py
@@ -245,14 +245,17 @@ def __init__(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _single(0), groups, bias, padding_mode)
 
-    def forward(self, input: Tensor) -> Tensor:
+    def _conv_forward(self, input, weight):
         if self.padding_mode != 'zeros':
             return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
-                            self.weight, self.bias, self.stride,
+                            weight, self.bias, self.stride,
                             _single(0), self.dilation, self.groups)
-        return F.conv1d(input, self.weight, self.bias, self.stride,
+        return F.conv1d(input, weight, self.bias, self.stride,
                         self.padding, self.dilation, self.groups)
 
+    def forward(self, input: Tensor) -> Tensor:
+        return self._conv_forward(input, self.weight)
+
 
 class Conv2d(_ConvNd):
     __doc__ = r"""Applies a 2D convolution over an input signal composed of several input
diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py
index e9f4a4c701eb..6a2d98f0a8fe 100644
--- a/torch/nn/quantized/modules/conv.py
+++ b/torch/nn/quantized/modules/conv.py
@@ -163,6 +163,32 @@ def get_qconv(cls, mod, activation_post_process, weight_post_process=None):
         qconv.zero_point = int(act_zp)
         return qconv
 
+    @staticmethod
+    def from_float(cls, mod):
+        if hasattr(mod, "weight_fake_quant"):
+            # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \
+            # ".from_float only works for " + cls.__QAT_MODULE.__name__
+            if type(mod) == cls._NNIQAT_CONV_BN_MODULE:
+                mod.weight, mod.bias = fuse_conv_bn_weights(
+                    mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
+                    mod.bn.eps, mod.bn.weight, mod.bn.bias)
+            assert hasattr(mod, "activation_post_process"), \
+                "Input QAT module must have observer attached"
+            weight_post_process = mod.weight_fake_quant
+            activation_post_process = mod.activation_post_process
+        else:
+            assert type(mod) == cls._FLOAT_MODULE, \
+                " nnq." + cls.__name__ + ".from_float only works for " + \
+                cls._FLOAT_MODULE.__name__
+            assert hasattr(mod, "qconfig"), \
+                "Input float module must have qconfig defined."
+            if type(mod) == cls._NNI_CONV_RELU_MODULE:
+                activation_post_process = mod[1].activation_post_process
+                mod = mod[0]
+            else:
+                activation_post_process = mod.activation_post_process
+            weight_post_process = mod.qconfig.weight()
+        return cls.get_qconv(mod, activation_post_process, weight_post_process)
 
 class Conv1d(_ConvNd):
     r"""Applies a 1D convolution over a quantized input signal composed of
@@ -198,6 +224,8 @@ class Conv1d(_ConvNd):
     """
 
     _FLOAT_MODULE = nn.Conv1d
+    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d
+    _NNI_CONV_RELU_MODULE = nni.ConvReLU1d
 
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                  padding=0, dilation=1, groups=1, bias=True,
@@ -244,17 +272,7 @@ def from_float(cls, mod):
             mod (Module): a float module, either produced by torch.quantization
             utilities or provided by the user
         """
-        assert type(mod) == cls._FLOAT_MODULE, \
-            ' nnq.' + cls.__name__ + '.from_float only works for ' + \
-            cls._FLOAT_MODULE.__name__
-        assert hasattr(mod, 'qconfig'), \
-            'Input float module must have qconfig defined.'
-        if type(mod) == nni.ConvReLU1d:
-            activation_post_process = mod[1].activation_post_process
-            mod = mod[0]
-        else:
-            activation_post_process = mod.activation_post_process
-        return cls.get_qconv(mod, activation_post_process)
+        return _ConvNd.from_float(cls, mod)
 
 
 class Conv2d(_ConvNd):
@@ -294,6 +312,8 @@ class Conv2d(_ConvNd):
     """
 
     _FLOAT_MODULE = nn.Conv2d
+    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn2d
+    _NNI_CONV_RELU_MODULE = nni.ConvReLU2d
 
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                  padding=0, dilation=1, groups=1, bias=True,
@@ -339,33 +359,7 @@ def from_float(cls, mod):
             mod (Module): a float module, either produced by torch.quantization
             utilities or provided by the user
         """
-        if hasattr(mod, 'weight_fake_quant'):
-            # assert type(mod) == cls.__QAT_MODULE, ' nnq.' + cls.__name__ + \
-            # '.from_float only works for ' + cls.__QAT_MODULE.__name__
-            if type(mod) == nniqat.ConvBn2d:
-                mod.weight, mod.bias = fuse_conv_bn_weights(
-                    mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
-                    mod.bn.eps, mod.bn.weight, mod.bn.bias)
-            assert hasattr(mod, 'activation_post_process'), \
-                'Input QAT module must have observer attached'
-            weight_post_process = mod.weight_fake_quant
-            activation_post_process = mod.activation_post_process
-        else:
-            assert type(mod) == cls._FLOAT_MODULE, \
-                ' nnq.' + cls.__name__ + '.from_float only works for ' + \
-                cls._FLOAT_MODULE.__name__
-            assert hasattr(mod, 'qconfig'), \
-                'Input float module must have qconfig defined.'
-            # workaround for sequential, ConvReLU2d should probably
-            # inherit from Conv2d instead
-            if type(mod) == nni.ConvReLU2d:
-                activation_post_process = mod[1].activation_post_process
-                mod = mod[0]
-            else:
-                activation_post_process = mod.activation_post_process
-            weight_post_process = mod.qconfig.weight()
-
-        return cls.get_qconv(mod, activation_post_process, weight_post_process)
+        return _ConvNd.from_float(cls, mod)
 
 
 class Conv3d(_ConvNd):
diff --git a/torch/quantization/fuser_method_mappings.py b/torch/quantization/fuser_method_mappings.py
index 9f3e031ee50b..0b72f5485231 100644
--- a/torch/quantization/fuser_method_mappings.py
+++ b/torch/quantization/fuser_method_mappings.py
@@ -19,14 +19,21 @@ def fuse_conv_bn(conv, bn):
     assert(conv.training == bn.training),\
         "Conv and BN both must be in the same mode (train or eval)."
 
-    is_3d = isinstance(conv, nn.Conv3d)
+    fused_module_class_map = {
+        nn.Conv1d: nni.ConvBn1d,
+        nn.Conv2d: nni.ConvBn2d,
+        nn.Conv3d: nni.ConvBn3d,
+    }
 
     if conv.training:
         assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
         assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
-        return nni.ConvBn3d(conv, bn) if is_3d \
-            else nni.ConvBn2d(conv, bn)
+        fused_module_class = fused_module_class_map.get((type(conv)), None)
+        if fused_module_class is not None:
+            return fused_module_class(conv, bn)
+        else:
+            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn)))
     else:
         return nn.utils.fuse_conv_bn_eval(conv, bn)
 
@@ -48,13 +55,14 @@ def fuse_conv_bn_relu(conv, bn, relu):
     fused_module : Optional[Type[nn.Sequential]] = None
     if conv.training:
         map_to_fused_module_train = {
+            nn.Conv1d: nni.ConvBnReLU1d,
             nn.Conv2d: nni.ConvBnReLU2d,
             nn.Conv3d: nni.ConvBnReLU3d,
         }
         assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
         assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
         assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
-        fused_module = map_to_fused_module_train.get(type(conv))
+        fused_module = map_to_fused_module_train.get(type(conv), None)
         if fused_module is not None:
             return fused_module(conv, bn, relu)
         else:
@@ -65,7 +73,7 @@ def fuse_conv_bn_relu(conv, bn, relu):
             nn.Conv2d: nni.ConvReLU2d,
             nn.Conv3d: nni.ConvReLU3d,
         }
-        fused_module = map_to_fused_module_eval[type(conv)]
+        fused_module = map_to_fused_module_eval.get(type(conv), None)
         if fused_module is not None:
             fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
             return fused_module(fused_conv, relu)
diff --git a/torch/quantization/fx/fusion_patterns.py b/torch/quantization/fx/fusion_patterns.py
index 4c92192dc5be..70dc59923bb6 100644
--- a/torch/quantization/fx/fusion_patterns.py
+++ b/torch/quantization/fx/fusion_patterns.py
@@ -15,10 +15,13 @@
 @register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv1d))
 @register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
 @register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv3d))
+@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Conv1d))
 @register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
 @register_fusion_pattern((torch.nn.BatchNorm3d, torch.nn.Conv3d))
+@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
 @register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
 @register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
+@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
 @register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
 @register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
 class ConvBNReLUFusion():
diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index e69c6c5ea33f..8b3c96306324 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -174,7 +174,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
 @register_quant_pattern(torch.nn.intrinsic.ConvReLU1d)
 @register_quant_pattern(torch.nn.intrinsic.ConvReLU2d)
 @register_quant_pattern(torch.nn.intrinsic.ConvReLU3d)
+@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn1d)
 @register_quant_pattern(torch.nn.intrinsic.qat.ConvBn2d)
+@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU1d)
 @register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU2d)
 @register_quant_pattern(torch.nn.intrinsic.qat.ConvReLU2d)
 @register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv2d))
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index 732f2efdedfe..4733491baf4e 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -749,7 +749,7 @@ def is_quantized(node):
                 quant_env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
             else:
                 # copy quantized or non-quantized node
-                env[node.name] = self.quantized_graph.node_copy(node, load_x)
+                env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
 
         # remove activation post process
         act_post_process_removed_graph = Graph()
diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py
index 82340a49309c..683855dee71b 100644
--- a/torch/quantization/quantization_mappings.py
+++ b/torch/quantization/quantization_mappings.py
@@ -48,7 +48,9 @@
     nni.ConvReLU2d: nniq.ConvReLU2d,
     nni.ConvReLU3d: nniq.ConvReLU3d,
     nni.LinearReLU: nniq.LinearReLU,
+    nniqat.ConvBn1d: nnq.Conv1d,
     nniqat.ConvBn2d: nnq.Conv2d,
+    nniqat.ConvBnReLU1d: nniq.ConvReLU1d,
     nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
     nniqat.ConvReLU2d: nniq.ConvReLU2d,
     nniqat.LinearReLU: nniq.LinearReLU,
@@ -62,7 +64,9 @@
     nn.Conv2d: nnqat.Conv2d,
     nn.Linear: nnqat.Linear,
     # Intrinsic modules:
+    nni.ConvBn1d: nniqat.ConvBn1d,
     nni.ConvBn2d: nniqat.ConvBn2d,
+    nni.ConvBnReLU1d: nniqat.ConvBnReLU1d,
     nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
     nni.ConvReLU2d: nniqat.ConvReLU2d,
     nni.LinearReLU: nniqat.LinearReLU
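
A minimal sketch of how the new 1d path is exercised end to end, assuming a build that includes this patch; it mirrors the updated prepare_qat_fx fusion test above, and the model class name here is illustrative, not part of the patch.

# Sketch only: Conv1d + BatchNorm1d + ReLU is now fused during QAT prepare.
import torch
import torch.nn as nn
import torch.nn.intrinsic as nni
from torch.quantization.quantize_fx import prepare_qat_fx

class ConvBnReLU1dModel(nn.Module):  # illustrative name
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(3, 3, 3)
        self.bn = nn.BatchNorm1d(3)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

m = ConvBnReLU1dModel().train()
# As in the test above, prepare_qat_fx runs train-mode fuse_fx internally,
# so the three eager modules collapse into a single nni.ConvBnReLU1d submodule.
m = prepare_qat_fx(m, {})
assert any(isinstance(mod, nni.ConvBnReLU1d) for mod in m.modules())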