diff --git a/test/test_models.py b/test/test_models.py index 4f021d323b2..b52b7ecc690 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -8,9 +8,11 @@ import operator import torch import torch.nn as nn +import torchvision from torchvision import models import pytest import warnings +import traceback ACCEPT = os.getenv('EXPECTTEST_ACCEPT', '0') == '1' @@ -36,6 +38,11 @@ def get_available_video_models(): return [k for k, v in models.video.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] +def get_available_quantizable_models(): + # TODO add a registration mechanism to torchvision.models + return [k for k, v in models.quantization.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] + + def _get_expected_file(name=None): # Determine expected file based on environment expected_file_base = get_relative_path(os.path.realpath(__file__), "expect") @@ -617,5 +624,49 @@ def test_video_model(model_name, dev): assert out.shape[-1] == 50 +@pytest.mark.skipif(not ('fbgemm' in torch.backends.quantized.supported_engines and + 'qnnpack' in torch.backends.quantized.supported_engines), + reason="This Pytorch Build has not been built with fbgemm and qnnpack") +@pytest.mark.parametrize('model_name', get_available_quantizable_models()) +def test_quantized_classification_model(model_name): + defaults = { + 'input_shape': (1, 3, 224, 224), + 'pretrained': False, + 'quantize': True, + } + kwargs = {**defaults, **_model_params.get(model_name, {})} + input_shape = kwargs.pop('input_shape') + + # First check if quantize=True provides models that can run with input data + model = torchvision.models.quantization.__dict__[model_name](**kwargs) + x = torch.rand(input_shape) + model(x) + + kwargs['quantize'] = False + for eval_mode in [True, False]: + model = torchvision.models.quantization.__dict__[model_name](**kwargs) + if eval_mode: + model.eval() + model.qconfig = torch.quantization.default_qconfig + else: + model.train() + model.qconfig = torch.quantization.default_qat_qconfig + + model.fuse_model() + if eval_mode: + torch.quantization.prepare(model, inplace=True) + else: + torch.quantization.prepare_qat(model, inplace=True) + model.eval() + + torch.quantization.convert(model, inplace=True) + + try: + torch.jit.script(model) + except Exception as e: + tb = traceback.format_exc() + raise AssertionError(f"model cannot be scripted. Traceback = {str(tb)}") from e + + if __name__ == '__main__': pytest.main([__file__]) diff --git a/test/test_quantized_models.py b/test/test_quantized_models.py deleted file mode 100644 index d8fd5325755..00000000000 --- a/test/test_quantized_models.py +++ /dev/null @@ -1,93 +0,0 @@ -import torchvision -from common_utils import TestCase, map_nested_tensor_object -from collections import OrderedDict -from itertools import product -import torch -import numpy as np -from torchvision import models -import unittest -import traceback -import random - - -def set_rng_seed(seed): - torch.manual_seed(seed) - random.seed(seed) - np.random.seed(seed) - - -def get_available_quantizable_models(): - # TODO add a registration mechanism to torchvision.models - return [k for k, v in models.quantization.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] - - -# list of models that are not scriptable -scriptable_quantizable_models_blacklist = [] - - -@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines and - 'qnnpack' in torch.backends.quantized.supported_engines, - "This Pytorch Build has not been built with fbgemm and qnnpack") -class ModelTester(TestCase): - def check_quantized_model(self, model, input_shape): - x = torch.rand(input_shape) - model(x) - return - - def check_script(self, model, name): - if name in scriptable_quantizable_models_blacklist: - return - scriptable = True - msg = "" - try: - torch.jit.script(model) - except Exception as e: - tb = traceback.format_exc() - scriptable = False - msg = str(e) + str(tb) - self.assertTrue(scriptable, msg) - - def _test_classification_model(self, name, input_shape): - # First check if quantize=True provides models that can run with input data - - model = torchvision.models.quantization.__dict__[name](pretrained=False, quantize=True) - self.check_quantized_model(model, input_shape) - - for eval_mode in [True, False]: - model = torchvision.models.quantization.__dict__[name](pretrained=False, quantize=False) - if eval_mode: - model.eval() - model.qconfig = torch.quantization.default_qconfig - else: - model.train() - model.qconfig = torch.quantization.default_qat_qconfig - - model.fuse_model() - if eval_mode: - torch.quantization.prepare(model, inplace=True) - else: - torch.quantization.prepare_qat(model, inplace=True) - model.eval() - - torch.quantization.convert(model, inplace=True) - - self.check_script(model, name) - - -for model_name in get_available_quantizable_models(): - # for-loop bodies don't define scopes, so we have to save the variables - # we want to close over in some way - def do_test(self, model_name=model_name): - input_shape = (1, 3, 224, 224) - if model_name in ['inception_v3']: - input_shape = (1, 3, 299, 299) - self._test_classification_model(model_name, input_shape) - - # inception_v3 was causing timeouts on circleci - # See https://github.com/pytorch/vision/issues/1857 - if model_name not in ['inception_v3']: - setattr(ModelTester, "test_" + model_name, do_test) - - -if __name__ == '__main__': - unittest.main()