24 changes: 12 additions & 12 deletions references/classification/train_quantization.py
@@ -1,14 +1,14 @@
+import copy
 import datetime
 import os
 import time
-import copy

[Reviewer comment, Contributor] nit: would be nice to separate unrelated changes into a separate PR, there are a lot in this PR

 import torch
+import torch.ao.quantization
 import torch.utils.data
-from torch import nn
 import torchvision
-import torch.quantization
 import utils
+from torch import nn
 from train import train_one_epoch, evaluate, load_data


@@ -52,8 +52,8 @@ def main(args):

     if not (args.test_only or args.post_training_quantize):
         model.fuse_model()
-        model.qconfig = torch.quantization.get_default_qat_qconfig(args.backend)
-        torch.quantization.prepare_qat(model, inplace=True)
+        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
+        torch.ao.quantization.prepare_qat(model, inplace=True)

     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -90,12 +90,12 @@ def main(args):
             pin_memory=True)
         model.eval()
         model.fuse_model()
-        model.qconfig = torch.quantization.get_default_qconfig(args.backend)
-        torch.quantization.prepare(model, inplace=True)
+        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
+        torch.ao.quantization.prepare(model, inplace=True)
         # Calibrate first
         print("Calibrating")
         evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
-        torch.quantization.convert(model, inplace=True)
+        torch.ao.quantization.convert(model, inplace=True)
         if args.output_dir:
             print('Saving quantized model')
             if utils.is_main_process():
@@ -109,8 +109,8 @@ def main(args):
         evaluate(model, criterion, data_loader_test, device=device)
         return

-    model.apply(torch.quantization.enable_observer)
-    model.apply(torch.quantization.enable_fake_quant)
+    model.apply(torch.ao.quantization.enable_observer)
+    model.apply(torch.ao.quantization.enable_fake_quant)
     start_time = time.time()
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
@@ -122,7 +122,7 @@ def main(args):
         with torch.no_grad():
             if epoch >= args.num_observer_update_epochs:
                 print('Disabling observer for subseq epochs, epoch = ', epoch)
-                model.apply(torch.quantization.disable_observer)
+                model.apply(torch.ao.quantization.disable_observer)
             if epoch >= args.num_batch_norm_update_epochs:
                 print('Freezing BN for subseq epochs, epoch = ', epoch)
                 model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
@@ -132,7 +132,7 @@ def main(args):
             quantized_eval_model = copy.deepcopy(model_without_ddp)
             quantized_eval_model.eval()
             quantized_eval_model.to(torch.device('cpu'))
-            torch.quantization.convert(quantized_eval_model, inplace=True)
+            torch.ao.quantization.convert(quantized_eval_model, inplace=True)

             print('Evaluate Quantized model')
             evaluate(quantized_eval_model, criterion, data_loader_test,
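Note: the hunks above only swap namespaces; the eager-mode QAT recipe itself is unchanged. For reference, a minimal sketch of that recipe under torch.ao.quantization (the model is assumed to be a torchvision quantizable model exposing fuse_model(); the loader, criterion, and optimizer are illustrative placeholders, not from this PR):

import copy

import torch
import torch.ao.quantization


def qat_sketch(model, data_loader, criterion, optimizer, num_epochs=3):
    # Fuse conv/bn/relu patterns first (torchvision quantizable models expose fuse_model()).
    model.fuse_model()
    # Attach a QAT qconfig and insert fake-quant/observer modules in place.
    model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
    torch.ao.quantization.prepare_qat(model, inplace=True)

    model.train()
    for _ in range(num_epochs):
        for images, targets in data_loader:
            optimizer.zero_grad()
            loss = criterion(model(images), targets)
            loss.backward()
            optimizer.step()

    # Convert a CPU copy to a true int8 model for evaluation, as the script above does.
    quantized = copy.deepcopy(model).eval().to('cpu')
    torch.ao.quantization.convert(quantized, inplace=True)
    return quantized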
12 changes: 6 additions & 6 deletions references/classification/utils.py
@@ -1,14 +1,14 @@
-from collections import defaultdict, deque, OrderedDict
 import copy
 import datetime
+import errno
 import hashlib
+import os
 import time
+from collections import defaultdict, deque, OrderedDict

 import torch
 import torch.distributed as dist

-import errno
-import os

class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
@@ -352,8 +352,8 @@ def store_model_weights(model, checkpoint_path, checkpoint_key='model', strict=T
         # Quantized Classification
         model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False)
         model.fuse_model()
-        model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
-        _ = torch.quantization.prepare_qat(model, inplace=True)
+        model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
+        _ = torch.ao.quantization.prepare_qat(model, inplace=True)
         print(store_model_weights(model, './qat.pth'))

# Object Detection
10 changes: 5 additions & 5 deletions test/test_models.py
@@ -696,19 +696,19 @@ def test_quantized_classification_model(model_name):
     model = torchvision.models.quantization.__dict__[model_name](**kwargs)
     if eval_mode:
         model.eval()
-        model.qconfig = torch.quantization.default_qconfig
+        model.qconfig = torch.ao.quantization.default_qconfig
     else:
         model.train()
-        model.qconfig = torch.quantization.default_qat_qconfig
+        model.qconfig = torch.ao.quantization.default_qat_qconfig

     model.fuse_model()
     if eval_mode:
-        torch.quantization.prepare(model, inplace=True)
+        torch.ao.quantization.prepare(model, inplace=True)
     else:
-        torch.quantization.prepare_qat(model, inplace=True)
+        torch.ao.quantization.prepare_qat(model, inplace=True)
         model.eval()

-    torch.quantization.convert(model, inplace=True)
+    torch.ao.quantization.convert(model, inplace=True)

     try:
         torch.jit.script(model)
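Note: this test exercises both the PTQ path (prepare in eval mode) and the QAT path (prepare_qat in train mode) before converting and scripting. A standalone sketch of the PTQ branch, with an illustrative model and input shape, assuming the pretrained/quantize keyword API of the torchvision version this PR targets:

import torch
import torch.ao.quantization
import torchvision

model = torchvision.models.quantization.resnet18(pretrained=False, quantize=False)
model.eval()
model.qconfig = torch.ao.quantization.default_qconfig
model.fuse_model()
torch.ao.quantization.prepare(model, inplace=True)
model(torch.rand(1, 3, 224, 224))  # one calibration pass with dummy data
torch.ao.quantization.convert(model, inplace=True)
torch.jit.script(model)  # the test checks the converted model scripts cleanly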
6 changes: 3 additions & 3 deletions torchvision/models/quantization/googlenet.py
@@ -96,7 +96,7 @@ def forward(self, x: Tensor) -> Tensor:
         return x

     def fuse_model(self) -> None:
-        torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
+        torch.ao.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)


class QuantizableInception(Inception):
@@ -148,8 +148,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
             *args,
             **kwargs
         )
-        self.quant = torch.quantization.QuantStub()
-        self.dequant = torch.quantization.DeQuantStub()
+        self.quant = torch.ao.quantization.QuantStub()
+        self.dequant = torch.ao.quantization.DeQuantStub()

     def forward(self, x: Tensor) -> GoogLeNetOutputs:
         x = self._transform_input(x)
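Note: the QuantStub/DeQuantStub pair above marks where tensors cross the float/int8 boundary; prepare() attaches observers there, and convert() replaces the stubs with real (de)quantize ops. A toy example of the pattern, not from this PR:

import torch
from torch import nn
from torch.ao.quantization import QuantStub, DeQuantStub


class QuantizableToy(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # float -> quantized at the model input
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()
        self.dequant = DeQuantStub()  # quantized -> float at the model output

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.conv(x))
        return self.dequant(x)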
6 changes: 3 additions & 3 deletions torchvision/models/quantization/inception.py
@@ -104,7 +104,7 @@ def forward(self, x: Tensor) -> Tensor:
         return x

     def fuse_model(self) -> None:
-        torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
+        torch.ao.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)


class QuantizableInceptionA(inception_module.InceptionA):
@@ -236,8 +236,8 @@ def __init__(
                 QuantizableInceptionAux
             ]
         )
-        self.quant = torch.quantization.QuantStub()
-        self.dequant = torch.quantization.DeQuantStub()
+        self.quant = torch.ao.quantization.QuantStub()
+        self.dequant = torch.ao.quantization.DeQuantStub()

     def forward(self, x: Tensor) -> InceptionOutputs:
         x = self._transform_input(x)
13 changes: 6 additions & 7 deletions torchvision/models/quantization/mobilenetv2.py
@@ -1,14 +1,13 @@
-from torch import nn
-from torch import Tensor
-
-from ..._internally_replaced_utils import load_state_dict_from_url
-
 from typing import Any

+from torch import Tensor
+from torch import nn
+from torch.ao.quantization import QuantStub, DeQuantStub, fuse_modules
 from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls
-from torch.quantization import QuantStub, DeQuantStub, fuse_modules
-from .utils import _replace_relu, quantize_model

+from ..._internally_replaced_utils import load_state_dict_from_url
 from ...ops.misc import ConvNormActivation
+from .utils import _replace_relu, quantize_model


__all__ = ['QuantizableMobileNetV2', 'mobilenet_v2']
21 changes: 14 additions & 7 deletions torchvision/models/quantization/mobilenetv3.py
@@ -1,11 +1,18 @@
+from typing import Any, List, Optional
+
 import torch
 from torch import nn, Tensor
+from torch.ao.quantization import QuantStub, DeQuantStub, fuse_modules
+
 from ..._internally_replaced_utils import load_state_dict_from_url
 from ...ops.misc import ConvNormActivation, SqueezeExcitation
-from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3,\
-    model_urls, _mobilenet_v3_conf
-from torch.quantization import QuantStub, DeQuantStub, fuse_modules
-from typing import Any, List, Optional
+from ..mobilenetv3 import (
+    InvertedResidual,
+    InvertedResidualConfig,
+    MobileNetV3,
+    model_urls,
+    _mobilenet_v3_conf,
+)
 from .utils import _replace_relu


@@ -141,13 +148,13 @@ def _mobilenet_v3_model(
         backend = 'qnnpack'

         model.fuse_model()
-        model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
-        torch.quantization.prepare_qat(model, inplace=True)
+        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
+        torch.ao.quantization.prepare_qat(model, inplace=True)

         if pretrained:
             _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress)

-        torch.quantization.convert(model, inplace=True)
+        torch.ao.quantization.convert(model, inplace=True)
         model.eval()
     else:
         if pretrained:
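Note: this code path backs the quantize=True constructors. A typical call, assuming the pretrained/quantize keyword API of the torchvision release this PR targets:

import torch
import torchvision

# quantize=True walks the path above: fuse -> qconfig -> prepare_qat ->
# load the published int8 weights -> convert -> eval.
model = torchvision.models.quantization.mobilenet_v3_large(pretrained=True, quantize=True)
model.eval()
with torch.no_grad():
    out = model(torch.rand(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])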
19 changes: 10 additions & 9 deletions torchvision/models/quantization/resnet.py
@@ -1,11 +1,12 @@
+from typing import Any, Type, Union, List
+
 import torch
-from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls
 import torch.nn as nn
 from torch import Tensor
-from typing import Any, Type, Union, List
+from torch.ao.quantization import fuse_modules
+from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls

 from ..._internally_replaced_utils import load_state_dict_from_url
-from torch.quantization import fuse_modules
 from .utils import _replace_relu, quantize_model

__all__ = ['QuantizableResNet', 'resnet18', 'resnet50',
@@ -45,10 +46,10 @@ def forward(self, x: Tensor) -> Tensor:
         return out

     def fuse_model(self) -> None:
-        torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu'],
-                                               ['conv2', 'bn2']], inplace=True)
+        torch.ao.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu'],
+                                                  ['conv2', 'bn2']], inplace=True)
         if self.downsample:
-            torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)
+            torch.ao.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)


class QuantizableBottleneck(Bottleneck):
@@ -81,16 +82,16 @@ def fuse_model(self) -> None:
                                             ['conv2', 'bn2', 'relu2'],
                                             ['conv3', 'bn3']], inplace=True)
         if self.downsample:
-            torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)
+            torch.ao.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)


 class QuantizableResNet(ResNet):

     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super(QuantizableResNet, self).__init__(*args, **kwargs)

-        self.quant = torch.quantization.QuantStub()
-        self.dequant = torch.quantization.DeQuantStub()
+        self.quant = torch.ao.quantization.QuantStub()
+        self.dequant = torch.ao.quantization.DeQuantStub()

     def forward(self, x: Tensor) -> Tensor:
         x = self.quant(x)
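Note: fuse_modules is behaviorally identical under the new namespace; only the import path changes. A toy fusion, where the string names refer to attribute names of the module passed in:

import torch
import torch.ao.quantization
from torch import nn

m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
m.eval()  # eval-mode fusion folds the BatchNorm into the conv weights
fused = torch.ao.quantization.fuse_modules(m, [['0', '1', '2']])
print(type(fused[0]).__name__)  # ConvReLU2d; slots '1' and '2' become Identity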
10 changes: 5 additions & 5 deletions torchvision/models/quantization/shufflenetv2.py
@@ -46,8 +46,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
             inverted_residual=QuantizableInvertedResidual,
             **kwargs
         )
-        self.quant = torch.quantization.QuantStub()
-        self.dequant = torch.quantization.DeQuantStub()
+        self.quant = torch.ao.quantization.QuantStub()
+        self.dequant = torch.ao.quantization.DeQuantStub()

     def forward(self, x: Tensor) -> Tensor:
         x = self.quant(x)
@@ -65,14 +65,14 @@ def fuse_model(self) -> None:

         for name, m in self._modules.items():
             if name in ["conv1", "conv5"]:
-                torch.quantization.fuse_modules(m, [["0", "1", "2"]], inplace=True)
+                torch.ao.quantization.fuse_modules(m, [["0", "1", "2"]], inplace=True)
         for m in self.modules():
             if type(m) == QuantizableInvertedResidual:
                 if len(m.branch1._modules.items()) > 0:
-                    torch.quantization.fuse_modules(
+                    torch.ao.quantization.fuse_modules(
                         m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True
                     )
-                torch.quantization.fuse_modules(
+                torch.ao.quantization.fuse_modules(
                     m.branch2,
                     [["0", "1", "2"], ["3", "4"], ["5", "6", "7"]],
                     inplace=True,
16 changes: 8 additions & 8 deletions torchvision/models/quantization/utils.py
@@ -24,18 +24,18 @@ def quantize_model(model: nn.Module, backend: str) -> None:
     model.eval()
     # Make sure that weight qconfig matches that of the serialized models
     if backend == 'fbgemm':
-        model.qconfig = torch.quantization.QConfig(  # type: ignore[assignment]
-            activation=torch.quantization.default_observer,
-            weight=torch.quantization.default_per_channel_weight_observer)
+        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
+            activation=torch.ao.quantization.default_observer,
+            weight=torch.ao.quantization.default_per_channel_weight_observer)
     elif backend == 'qnnpack':
-        model.qconfig = torch.quantization.QConfig(  # type: ignore[assignment]
-            activation=torch.quantization.default_observer,
-            weight=torch.quantization.default_weight_observer)
+        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
+            activation=torch.ao.quantization.default_observer,
+            weight=torch.ao.quantization.default_weight_observer)

     # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
     model.fuse_model()  # type: ignore[operator]
-    torch.quantization.prepare(model, inplace=True)
+    torch.ao.quantization.prepare(model, inplace=True)
     model(_dummy_input_data)
-    torch.quantization.convert(model, inplace=True)
+    torch.ao.quantization.convert(model, inplace=True)

     return
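
Note: quantize_model is a one-shot post-training quantization flow under the new namespace: set a backend-matched qconfig, fuse, prepare, run one calibration batch, convert. A standalone sketch of the same flow, using the stock get_default_qconfig helper rather than the hand-built QConfig above (illustrative; the model is assumed to be a torchvision-style quantizable model that wraps its input/output in QuantStub/DeQuantStub and exposes fuse_model()):

import torch
import torch.ao.quantization
from torch import nn


def quantize_sketch(model: nn.Module, backend: str = 'fbgemm') -> nn.Module:
    torch.backends.quantized.engine = backend
    model.eval()
    model.qconfig = torch.ao.quantization.get_default_qconfig(backend)
    model.fuse_model()  # assumes a torchvision-style quantizable model
    torch.ao.quantization.prepare(model, inplace=True)
    model(torch.rand(1, 3, 224, 224))  # single calibration pass on dummy data
    torch.ao.quantization.convert(model, inplace=True)
    return model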