diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py index ec945f4f58f..71c9b3a52f4 100644 --- a/references/classification/train_quantization.py +++ b/references/classification/train_quantization.py @@ -1,14 +1,14 @@ +import copy import datetime import os import time -import copy import torch +import torch.ao.quantization import torch.utils.data -from torch import nn import torchvision -import torch.quantization import utils +from torch import nn from train import train_one_epoch, evaluate, load_data @@ -52,8 +52,8 @@ def main(args): if not (args.test_only or args.post_training_quantize): model.fuse_model() - model.qconfig = torch.quantization.get_default_qat_qconfig(args.backend) - torch.quantization.prepare_qat(model, inplace=True) + model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend) + torch.ao.quantization.prepare_qat(model, inplace=True) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -90,12 +90,12 @@ def main(args): pin_memory=True) model.eval() model.fuse_model() - model.qconfig = torch.quantization.get_default_qconfig(args.backend) - torch.quantization.prepare(model, inplace=True) + model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend) + torch.ao.quantization.prepare(model, inplace=True) # Calibrate first print("Calibrating") evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1) - torch.quantization.convert(model, inplace=True) + torch.ao.quantization.convert(model, inplace=True) if args.output_dir: print('Saving quantized model') if utils.is_main_process(): @@ -109,8 +109,8 @@ def main(args): evaluate(model, criterion, data_loader_test, device=device) return - model.apply(torch.quantization.enable_observer) - model.apply(torch.quantization.enable_fake_quant) + model.apply(torch.ao.quantization.enable_observer) + model.apply(torch.ao.quantization.enable_fake_quant) start_time = time.time() for epoch in range(args.start_epoch, args.epochs): if args.distributed: @@ -122,7 +122,7 @@ def main(args): with torch.no_grad(): if epoch >= args.num_observer_update_epochs: print('Disabling observer for subseq epochs, epoch = ', epoch) - model.apply(torch.quantization.disable_observer) + model.apply(torch.ao.quantization.disable_observer) if epoch >= args.num_batch_norm_update_epochs: print('Freezing BN for subseq epochs, epoch = ', epoch) model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) @@ -132,7 +132,7 @@ def main(args): quantized_eval_model = copy.deepcopy(model_without_ddp) quantized_eval_model.eval() quantized_eval_model.to(torch.device('cpu')) - torch.quantization.convert(quantized_eval_model, inplace=True) + torch.ao.quantization.convert(quantized_eval_model, inplace=True) print('Evaluate Quantized model') evaluate(quantized_eval_model, criterion, data_loader_test, diff --git a/references/classification/utils.py b/references/classification/utils.py index fad607636e5..46e46893acd 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -1,14 +1,14 @@ -from collections import defaultdict, deque, OrderedDict import copy import datetime +import errno import hashlib +import os import time +from collections import defaultdict, deque, OrderedDict + import torch import torch.distributed as dist -import errno -import os - class SmoothedValue(object): """Track a series of values and provide access to smoothed values over a @@ -352,8 +352,8 @@ def store_model_weights(model, checkpoint_path, checkpoint_key='model', strict=T # Quantized Classification model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False) model.fuse_model() - model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack') - _ = torch.quantization.prepare_qat(model, inplace=True) + model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack') + _ = torch.ao.quantization.prepare_qat(model, inplace=True) print(store_model_weights(model, './qat.pth')) # Object Detection diff --git a/test/test_models.py b/test/test_models.py index 9e376bedce5..b0d122816d7 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -696,19 +696,19 @@ def test_quantized_classification_model(model_name): model = torchvision.models.quantization.__dict__[model_name](**kwargs) if eval_mode: model.eval() - model.qconfig = torch.quantization.default_qconfig + model.qconfig = torch.ao.quantization.default_qconfig else: model.train() - model.qconfig = torch.quantization.default_qat_qconfig + model.qconfig = torch.ao.quantization.default_qat_qconfig model.fuse_model() if eval_mode: - torch.quantization.prepare(model, inplace=True) + torch.ao.quantization.prepare(model, inplace=True) else: - torch.quantization.prepare_qat(model, inplace=True) + torch.ao.quantization.prepare_qat(model, inplace=True) model.eval() - torch.quantization.convert(model, inplace=True) + torch.ao.quantization.convert(model, inplace=True) try: torch.jit.script(model) diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index 685815ac676..028ca2c8bfd 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -96,7 +96,7 @@ def forward(self, x: Tensor) -> Tensor: return x def fuse_model(self) -> None: - torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True) + torch.ao.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True) class QuantizableInception(Inception): @@ -148,8 +148,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: *args, **kwargs ) - self.quant = torch.quantization.QuantStub() - self.dequant = torch.quantization.DeQuantStub() + self.quant = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() def forward(self, x: Tensor) -> GoogLeNetOutputs: x = self._transform_input(x) diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index 6c6384c295a..b1413a00e94 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -104,7 +104,7 @@ def forward(self, x: Tensor) -> Tensor: return x def fuse_model(self) -> None: - torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True) + torch.ao.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True) class QuantizableInceptionA(inception_module.InceptionA): @@ -236,8 +236,8 @@ def __init__( QuantizableInceptionAux ] ) - self.quant = torch.quantization.QuantStub() - self.dequant = torch.quantization.DeQuantStub() + self.quant = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() def forward(self, x: Tensor) -> InceptionOutputs: x = self._transform_input(x) diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index 2349afff447..e890cb02ef5 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -1,14 +1,13 @@ -from torch import nn -from torch import Tensor - -from ..._internally_replaced_utils import load_state_dict_from_url - from typing import Any +from torch import Tensor +from torch import nn +from torch.ao.quantization import QuantStub, DeQuantStub, fuse_modules from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls -from torch.quantization import QuantStub, DeQuantStub, fuse_modules -from .utils import _replace_relu, quantize_model + +from ..._internally_replaced_utils import load_state_dict_from_url from ...ops.misc import ConvNormActivation +from .utils import _replace_relu, quantize_model __all__ = ['QuantizableMobileNetV2', 'mobilenet_v2'] diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 8655a9b0a45..9c89448e18d 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -1,11 +1,18 @@ +from typing import Any, List, Optional + import torch from torch import nn, Tensor +from torch.ao.quantization import QuantStub, DeQuantStub, fuse_modules + from ..._internally_replaced_utils import load_state_dict_from_url from ...ops.misc import ConvNormActivation, SqueezeExcitation -from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3,\ - model_urls, _mobilenet_v3_conf -from torch.quantization import QuantStub, DeQuantStub, fuse_modules -from typing import Any, List, Optional +from ..mobilenetv3 import ( + InvertedResidual, + InvertedResidualConfig, + MobileNetV3, + model_urls, + _mobilenet_v3_conf, +) from .utils import _replace_relu @@ -141,13 +148,13 @@ def _mobilenet_v3_model( backend = 'qnnpack' model.fuse_model() - model.qconfig = torch.quantization.get_default_qat_qconfig(backend) - torch.quantization.prepare_qat(model, inplace=True) + model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend) + torch.ao.quantization.prepare_qat(model, inplace=True) if pretrained: _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress) - torch.quantization.convert(model, inplace=True) + torch.ao.quantization.convert(model, inplace=True) model.eval() else: if pretrained: diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index 8f87e40ec3d..0d586d3d37e 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -1,11 +1,12 @@ +from typing import Any, Type, Union, List + import torch -from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls import torch.nn as nn from torch import Tensor -from typing import Any, Type, Union, List +from torch.ao.quantization import fuse_modules +from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls from ..._internally_replaced_utils import load_state_dict_from_url -from torch.quantization import fuse_modules from .utils import _replace_relu, quantize_model __all__ = ['QuantizableResNet', 'resnet18', 'resnet50', @@ -45,10 +46,10 @@ def forward(self, x: Tensor) -> Tensor: return out def fuse_model(self) -> None: - torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu'], - ['conv2', 'bn2']], inplace=True) + torch.ao.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu'], + ['conv2', 'bn2']], inplace=True) if self.downsample: - torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True) + torch.ao.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True) class QuantizableBottleneck(Bottleneck): @@ -81,7 +82,7 @@ def fuse_model(self) -> None: ['conv2', 'bn2', 'relu2'], ['conv3', 'bn3']], inplace=True) if self.downsample: - torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True) + torch.ao.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True) class QuantizableResNet(ResNet): @@ -89,8 +90,8 @@ class QuantizableResNet(ResNet): def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableResNet, self).__init__(*args, **kwargs) - self.quant = torch.quantization.QuantStub() - self.dequant = torch.quantization.DeQuantStub() + self.quant = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() def forward(self, x: Tensor) -> Tensor: x = self.quant(x) diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index 4f0861dcb30..734c67bbd34 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -46,8 +46,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: inverted_residual=QuantizableInvertedResidual, **kwargs ) - self.quant = torch.quantization.QuantStub() - self.dequant = torch.quantization.DeQuantStub() + self.quant = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() def forward(self, x: Tensor) -> Tensor: x = self.quant(x) @@ -65,14 +65,14 @@ def fuse_model(self) -> None: for name, m in self._modules.items(): if name in ["conv1", "conv5"]: - torch.quantization.fuse_modules(m, [["0", "1", "2"]], inplace=True) + torch.ao.quantization.fuse_modules(m, [["0", "1", "2"]], inplace=True) for m in self.modules(): if type(m) == QuantizableInvertedResidual: if len(m.branch1._modules.items()) > 0: - torch.quantization.fuse_modules( + torch.ao.quantization.fuse_modules( m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True ) - torch.quantization.fuse_modules( + torch.ao.quantization.fuse_modules( m.branch2, [["0", "1", "2"], ["3", "4"], ["5", "6", "7"]], inplace=True, diff --git a/torchvision/models/quantization/utils.py b/torchvision/models/quantization/utils.py index c195d162482..24987d4abe0 100644 --- a/torchvision/models/quantization/utils.py +++ b/torchvision/models/quantization/utils.py @@ -24,18 +24,18 @@ def quantize_model(model: nn.Module, backend: str) -> None: model.eval() # Make sure that weight qconfig matches that of the serialized models if backend == 'fbgemm': - model.qconfig = torch.quantization.QConfig( # type: ignore[assignment] - activation=torch.quantization.default_observer, - weight=torch.quantization.default_per_channel_weight_observer) + model.qconfig = torch.ao.quantization.QConfig( # type: ignore[assignment] + activation=torch.ao.quantization.default_observer, + weight=torch.ao.quantization.default_per_channel_weight_observer) elif backend == 'qnnpack': - model.qconfig = torch.quantization.QConfig( # type: ignore[assignment] - activation=torch.quantization.default_observer, - weight=torch.quantization.default_weight_observer) + model.qconfig = torch.ao.quantization.QConfig( # type: ignore[assignment] + activation=torch.ao.quantization.default_observer, + weight=torch.ao.quantization.default_weight_observer) # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 model.fuse_model() # type: ignore[operator] - torch.quantization.prepare(model, inplace=True) + torch.ao.quantization.prepare(model, inplace=True) model(_dummy_input_data) - torch.quantization.convert(model, inplace=True) + torch.ao.quantization.convert(model, inplace=True) return