diff --git a/test/expect/AlexnetTester.test_alexnet_expect.pkl b/test/expect/AlexnetTester.test_alexnet_expect.pkl new file mode 100644 index 00000000000..217eed006da Binary files /dev/null and b/test/expect/AlexnetTester.test_alexnet_expect.pkl differ diff --git a/test/expect/DensenetTester.test_densenet121_expect.pkl b/test/expect/DensenetTester.test_densenet121_expect.pkl new file mode 100644 index 00000000000..32127953ee8 Binary files /dev/null and b/test/expect/DensenetTester.test_densenet121_expect.pkl differ diff --git a/test/expect/DensenetTester.test_densenet161_expect.pkl b/test/expect/DensenetTester.test_densenet161_expect.pkl new file mode 100644 index 00000000000..7746061cd2c Binary files /dev/null and b/test/expect/DensenetTester.test_densenet161_expect.pkl differ diff --git a/test/expect/DensenetTester.test_densenet169_expect.pkl b/test/expect/DensenetTester.test_densenet169_expect.pkl new file mode 100644 index 00000000000..fe377f88b05 Binary files /dev/null and b/test/expect/DensenetTester.test_densenet169_expect.pkl differ diff --git a/test/expect/DensenetTester.test_densenet201_expect.pkl b/test/expect/DensenetTester.test_densenet201_expect.pkl new file mode 100644 index 00000000000..2185d458666 Binary files /dev/null and b/test/expect/DensenetTester.test_densenet201_expect.pkl differ diff --git a/test/expect/GooglenetTester.test_googlenet_expect.pkl b/test/expect/GooglenetTester.test_googlenet_expect.pkl new file mode 100644 index 00000000000..f4966407f43 Binary files /dev/null and b/test/expect/GooglenetTester.test_googlenet_expect.pkl differ diff --git a/test/expect/InceptionV3Tester.test_inception_v3_expect.pkl b/test/expect/InceptionV3Tester.test_inception_v3_expect.pkl new file mode 100644 index 00000000000..ba98d92343c Binary files /dev/null and b/test/expect/InceptionV3Tester.test_inception_v3_expect.pkl differ diff --git a/test/expect/MNASNetTester.test_mnasnet0_5_expect.pkl b/test/expect/MNASNetTester.test_mnasnet0_5_expect.pkl new file mode 100644 index 00000000000..596513fcb06 Binary files /dev/null and b/test/expect/MNASNetTester.test_mnasnet0_5_expect.pkl differ diff --git a/test/expect/MNASNetTester.test_mnasnet0_75_expect.pkl b/test/expect/MNASNetTester.test_mnasnet0_75_expect.pkl new file mode 100644 index 00000000000..530c8eeafc6 Binary files /dev/null and b/test/expect/MNASNetTester.test_mnasnet0_75_expect.pkl differ diff --git a/test/expect/MNASNetTester.test_mnasnet1_0_expect.pkl b/test/expect/MNASNetTester.test_mnasnet1_0_expect.pkl new file mode 100644 index 00000000000..842575007d1 Binary files /dev/null and b/test/expect/MNASNetTester.test_mnasnet1_0_expect.pkl differ diff --git a/test/expect/MNASNetTester.test_mnasnet1_3_expect.pkl b/test/expect/MNASNetTester.test_mnasnet1_3_expect.pkl new file mode 100644 index 00000000000..a79038ce3eb Binary files /dev/null and b/test/expect/MNASNetTester.test_mnasnet1_3_expect.pkl differ diff --git a/test/expect/MobilenetTester.test_mobilenet_v2_expect.pkl b/test/expect/MobilenetTester.test_mobilenet_v2_expect.pkl new file mode 100644 index 00000000000..54ee307c01b Binary files /dev/null and b/test/expect/MobilenetTester.test_mobilenet_v2_expect.pkl differ diff --git a/test/expect/MobilenetTester.test_mobilenetv2_residual_setting_expect.pkl b/test/expect/MobilenetTester.test_mobilenetv2_residual_setting_expect.pkl new file mode 100644 index 00000000000..c7733885caf Binary files /dev/null and b/test/expect/MobilenetTester.test_mobilenetv2_residual_setting_expect.pkl differ diff --git a/test/expect/ResnetTester.test_resnet101_expect.pkl b/test/expect/ResnetTester.test_resnet101_expect.pkl new file mode 100644 index 00000000000..ba62eb8e625 Binary files /dev/null and b/test/expect/ResnetTester.test_resnet101_expect.pkl differ diff --git a/test/expect/ResnetTester.test_resnet152_expect.pkl b/test/expect/ResnetTester.test_resnet152_expect.pkl new file mode 100644 index 00000000000..2d10165f546 Binary files /dev/null and b/test/expect/ResnetTester.test_resnet152_expect.pkl differ diff --git a/test/expect/ResnetTester.test_resnet18_expect.pkl b/test/expect/ResnetTester.test_resnet18_expect.pkl new file mode 100644 index 00000000000..e764184eff5 Binary files /dev/null and b/test/expect/ResnetTester.test_resnet18_expect.pkl differ diff --git a/test/expect/ResnetTester.test_resnet34_expect.pkl b/test/expect/ResnetTester.test_resnet34_expect.pkl new file mode 100644 index 00000000000..0a174e5d6ea Binary files /dev/null and b/test/expect/ResnetTester.test_resnet34_expect.pkl differ diff --git a/test/expect/ResnetTester.test_resnet50_expect.pkl b/test/expect/ResnetTester.test_resnet50_expect.pkl new file mode 100644 index 00000000000..1a94550e336 Binary files /dev/null and b/test/expect/ResnetTester.test_resnet50_expect.pkl differ diff --git a/test/expect/ResnetTester.test_resnext101_32x8d_expect.pkl b/test/expect/ResnetTester.test_resnext101_32x8d_expect.pkl new file mode 100644 index 00000000000..b2dd8c42da4 Binary files /dev/null and b/test/expect/ResnetTester.test_resnext101_32x8d_expect.pkl differ diff --git a/test/expect/ResnetTester.test_resnext50_32x4d_expect.pkl b/test/expect/ResnetTester.test_resnext50_32x4d_expect.pkl new file mode 100644 index 00000000000..fd4b9d49c49 Binary files /dev/null and b/test/expect/ResnetTester.test_resnext50_32x4d_expect.pkl differ diff --git a/test/expect/ResnetTester.test_wide_resnet101_2_expect.pkl b/test/expect/ResnetTester.test_wide_resnet101_2_expect.pkl new file mode 100644 index 00000000000..8aef5fb2909 Binary files /dev/null and b/test/expect/ResnetTester.test_wide_resnet101_2_expect.pkl differ diff --git a/test/expect/ResnetTester.test_wide_resnet50_2_expect.pkl b/test/expect/ResnetTester.test_wide_resnet50_2_expect.pkl new file mode 100644 index 00000000000..4a7c8d2a9d6 Binary files /dev/null and b/test/expect/ResnetTester.test_wide_resnet50_2_expect.pkl differ diff --git a/test/expect/ShufflenetTester.test_shufflenet_v2_x0_5_expect.pkl b/test/expect/ShufflenetTester.test_shufflenet_v2_x0_5_expect.pkl new file mode 100644 index 00000000000..313c3722093 Binary files /dev/null and b/test/expect/ShufflenetTester.test_shufflenet_v2_x0_5_expect.pkl differ diff --git a/test/expect/ShufflenetTester.test_shufflenet_v2_x1_0_expect.pkl b/test/expect/ShufflenetTester.test_shufflenet_v2_x1_0_expect.pkl new file mode 100644 index 00000000000..ff3d93dfc6c Binary files /dev/null and b/test/expect/ShufflenetTester.test_shufflenet_v2_x1_0_expect.pkl differ diff --git a/test/expect/ShufflenetTester.test_shufflenet_v2_x1_5_expect.pkl b/test/expect/ShufflenetTester.test_shufflenet_v2_x1_5_expect.pkl new file mode 100644 index 00000000000..a4f1426e95a Binary files /dev/null and b/test/expect/ShufflenetTester.test_shufflenet_v2_x1_5_expect.pkl differ diff --git a/test/expect/ShufflenetTester.test_shufflenet_v2_x2_0_expect.pkl b/test/expect/ShufflenetTester.test_shufflenet_v2_x2_0_expect.pkl new file mode 100644 index 00000000000..208449cd38f Binary files /dev/null and b/test/expect/ShufflenetTester.test_shufflenet_v2_x2_0_expect.pkl differ diff --git a/test/expect/SqueezenetTester.test_squeezenet1_0_expect.pkl b/test/expect/SqueezenetTester.test_squeezenet1_0_expect.pkl new file mode 100644 index 00000000000..9cc5f9a1e18 Binary files /dev/null and b/test/expect/SqueezenetTester.test_squeezenet1_0_expect.pkl differ diff --git a/test/expect/SqueezenetTester.test_squeezenet1_1_expect.pkl b/test/expect/SqueezenetTester.test_squeezenet1_1_expect.pkl new file mode 100644 index 00000000000..0f5fe9c8e77 Binary files /dev/null and b/test/expect/SqueezenetTester.test_squeezenet1_1_expect.pkl differ diff --git a/test/expect/VGGNetTester.test_vgg11_bn_expect.pkl b/test/expect/VGGNetTester.test_vgg11_bn_expect.pkl new file mode 100644 index 00000000000..d48fc986c9e Binary files /dev/null and b/test/expect/VGGNetTester.test_vgg11_bn_expect.pkl differ diff --git a/test/expect/VGGNetTester.test_vgg11_expect.pkl b/test/expect/VGGNetTester.test_vgg11_expect.pkl new file mode 100644 index 00000000000..ef0eecbfb3a Binary files /dev/null and b/test/expect/VGGNetTester.test_vgg11_expect.pkl differ diff --git a/test/expect/VGGNetTester.test_vgg13_bn_expect.pkl b/test/expect/VGGNetTester.test_vgg13_bn_expect.pkl new file mode 100644 index 00000000000..a948f33ba97 Binary files /dev/null and b/test/expect/VGGNetTester.test_vgg13_bn_expect.pkl differ diff --git a/test/expect/VGGNetTester.test_vgg13_expect.pkl b/test/expect/VGGNetTester.test_vgg13_expect.pkl new file mode 100644 index 00000000000..044e160ca44 Binary files /dev/null and b/test/expect/VGGNetTester.test_vgg13_expect.pkl differ diff --git a/test/expect/VGGNetTester.test_vgg16_bn_expect.pkl b/test/expect/VGGNetTester.test_vgg16_bn_expect.pkl new file mode 100644 index 00000000000..7c5f83594f9 Binary files /dev/null and b/test/expect/VGGNetTester.test_vgg16_bn_expect.pkl differ diff --git a/test/expect/VGGNetTester.test_vgg16_expect.pkl b/test/expect/VGGNetTester.test_vgg16_expect.pkl new file mode 100644 index 00000000000..82803be0e23 Binary files /dev/null and b/test/expect/VGGNetTester.test_vgg16_expect.pkl differ diff --git a/test/expect/VGGNetTester.test_vgg19_bn_expect.pkl b/test/expect/VGGNetTester.test_vgg19_bn_expect.pkl new file mode 100644 index 00000000000..260f506eb3e Binary files /dev/null and b/test/expect/VGGNetTester.test_vgg19_bn_expect.pkl differ diff --git a/test/expect/VGGNetTester.test_vgg19_expect.pkl b/test/expect/VGGNetTester.test_vgg19_expect.pkl new file mode 100644 index 00000000000..04c1e9aa37b Binary files /dev/null and b/test/expect/VGGNetTester.test_vgg19_expect.pkl differ diff --git a/test/test_models.py b/test/test_models.py index 1864d233772..a563c35a0d4 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -7,17 +7,28 @@ import unittest import traceback import random +import inspect -def set_rng_seed(seed): +STANDARD_NUM_CLASSES = 50 +STANDARD_INPUT_SHAPE = (1, 3, 224, 224) +STANDARD_SEED = 1729 + + +def set_rng_seed(seed=STANDARD_SEED): torch.manual_seed(seed) random.seed(seed) np.random.seed(seed) -def get_available_classification_models(): - # TODO add a registration mechanism to torchvision.models - return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] +def subsample_tensor(tensor, num_samples=20): + num_elems = tensor.numel() + if num_elems <= num_samples: + return tensor + + flat_tensor = tensor.flatten() + ith_index = num_elems // num_samples + return flat_tensor[ith_index - 1::ith_index] def get_available_segmentation_models(): @@ -40,22 +51,32 @@ def get_available_video_models(): # they are not yet supported in JIT. script_test_models = [ "deeplabv3_resnet101", - "mobilenet_v2", - "resnext50_32x4d", "fcn_resnet101", - "googlenet", - "densenet121", - "resnet18", - "alexnet", - "shufflenet_v2_x1_0", - "squeezenet1_0", - "vgg11", - "inception_v3", 'r3d_18', ] class ModelTester(TestCase): + + # create random tensor with given shape using synced RNG state + # caching because these tests take pretty long already (instantiating models and all) + TEST_INPUTS = {} + + def _get_test_input(self, shape=STANDARD_INPUT_SHAPE): + # NOTE not thread-safe, but should give same results even if multi-threaded testing gave a race condition + # giving consistent results is kind of the point of this helper method + if shape not in self.TEST_INPUTS: + set_rng_seed(STANDARD_SEED) + self.TEST_INPUTS[shape] = torch.rand(shape) + return self.TEST_INPUTS[shape] + + # create a randomly-weighted model w/ synced RNG state + def _get_test_model(self, callable, **kwargs): + set_rng_seed(STANDARD_SEED) + model = callable(**kwargs) + model.eval() + return model + def check_script(self, model, name): if name not in script_test_models: return @@ -69,16 +90,256 @@ def check_script(self, model, name): msg = str(e) + str(tb) self.assertTrue(scriptable, msg) - def _test_classification_model(self, name, input_shape): - # passing num_class equal to a number other than 1000 helps in making the test - # more enforcing in nature - model = models.__dict__[name](num_classes=50) - self.check_script(model, name) - model.eval() - x = torch.rand(input_shape) - out = model(x) - self.assertEqual(out.shape[-1], 50) + def _check_scriptable(self, model, expected): + if expected is None: # we don't check scriptability for all models + return + + actual = True + msg = '' + try: + torch.jit.script(model) + except Exception as e: + tb = traceback.format_exc() + actual = False + msg = str(e) + str(tb) + self.assertEqual(actual, expected, msg) + + +class ClassificationCoverageTester(TestCase): + + # Find all models exposed by torchvision.models factory methods (with assumptions) + def get_available_classification_models(self): + # TODO add a registration mechanism to torchvision.models + return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] + + # Recursively gather test methods from all classification testers + def get_test_methods_for_class(self, klass): + all_methods = inspect.getmembers(klass, predicate=inspect.isfunction) + test_methods = {method[0] for method in all_methods if method[0].startswith('test_')} + for child in klass.__subclasses__(): + test_methods = test_methods.union(self.get_test_methods_for_class(child)) + return test_methods + + # Verify that all models exposed by torchvision.models factory methods + # have corresponding test methods + # NOTE This does not include some of the extra tests (such as Resnet + # dilation) and says nothing about the correctness of the test, nor + # of the model. It just enforces a naming scheme on the tests, and + # verifies that all models have a corresponding test name. + def test_classification_model_coverage(self): + model_names = self.get_available_classification_models() + test_names = self.get_test_methods_for_class(ClassificationModelTester) + + for model_name in model_names: + test_name = 'test_' + model_name + self.assertTrue(test_name in test_names) + + +class ClassificationModelTester(ModelTester): + def _infer_for_test_with(self, model, test_input): + return model(test_input) + + def _check_classification_output_shape(self, test_output, num_classes): + self.assertEqual(test_output.shape, (1, num_classes)) + + # NOTE Depends on presence of test data fixture. See common_utils.py for + # details on creating fixtures. + def _check_model_correctness(self, model, test_input, num_classes=STANDARD_NUM_CLASSES): + test_output = self._infer_for_test_with(model, test_input) + self._check_classification_output_shape(test_output, num_classes) + self.assertExpected(test_output, rtol=1e-5, atol=1e-5) + return test_output + + # NOTE override this in a child class + def _get_input_shape(self): + return STANDARD_INPUT_SHAPE + + def _test_classification_model(self, model_callable, num_classes=STANDARD_NUM_CLASSES, **kwargs): + model = self._get_test_model(model_callable, num_classes=num_classes, **kwargs) + self._check_scriptable(model, True) # currently, all expected to be scriptable + test_input = self._get_test_input(shape=self._get_input_shape()) + test_output = self._check_model_correctness(model, test_input) + return model, test_input, test_output + + +class AlexnetTester(ClassificationModelTester): + def test_alexnet(self): + self._test_classification_model(models.alexnet) + + +# TODO add test for aux_logits arg to factory method +# TODO add test for transform_input arg to factory method +class InceptionV3Tester(ClassificationModelTester): + def _get_input_shape(self): + return (1, 3, 299, 299) + + def test_inception_v3(self): + self._test_classification_model(models.inception_v3) + + +class SqueezenetTester(ClassificationModelTester): + def test_squeezenet1_0(self): + self._test_classification_model(models.squeezenet1_0) + + def test_squeezenet1_1(self): + self._test_classification_model(models.squeezenet1_1) + + +# TODO add test for width_mult arg to factory method +class MobilenetTester(ClassificationModelTester): + def test_mobilenet_v2(self): + self._test_classification_model(models.mobilenet_v2) + + def test_mobilenetv2_residual_setting(self): + self._test_classification_model(models.mobilenet_v2, inverted_residual_setting=[[1, 16, 1, 1], [6, 24, 2, 2]]) + + +# TODO add test for aux_logits arg to factory method +# TODO add test for transform_input arg to factory method +class GooglenetTester(ClassificationModelTester): + def test_googlenet(self): + self._test_classification_model(models.googlenet) + + +class VGGNetTester(ClassificationModelTester): + def test_vgg11(self): + self._test_classification_model(models.vgg11) + + def test_vgg11_bn(self): + self._test_classification_model(models.vgg11_bn) + + def test_vgg13(self): + self._test_classification_model(models.vgg13) + + def test_vgg13_bn(self): + self._test_classification_model(models.vgg13_bn) + + def test_vgg16(self): + self._test_classification_model(models.vgg16) + + def test_vgg16_bn(self): + self._test_classification_model(models.vgg16_bn) + + def test_vgg19(self): + self._test_classification_model(models.vgg19) + + def test_vgg19_bn(self): + self._test_classification_model(models.vgg19_bn) + + +# TODO add test for dropout arg to factory method +class MNASNetTester(ClassificationModelTester): + def test_mnasnet0_5(self): + self._test_classification_model(models.mnasnet0_5) + + def test_mnasnet0_75(self): + self._test_classification_model(models.mnasnet0_75) + + def test_mnasnet1_0(self): + self._test_classification_model(models.mnasnet1_0) + + def test_mnasnet1_3(self): + self._test_classification_model(models.mnasnet1_3) + + +# TODO add test for bn_size arg to factory method +# TODO add test for drop_rate arg to factory method +class DensenetTester(ClassificationModelTester): + def _test_densenet_plus_mem_eff(self, model_callable): + model, test_input, test_output = self._test_classification_model(model_callable) + + # above, we perform the standard correctness test against the test fixture, and capture key test params + # below, we check that memory efficient/time inefficient DenseNet implementation behaves like the "standard" one + me_model = self._get_test_model(model_callable, num_classes=STANDARD_NUM_CLASSES, memory_efficient=True) + me_model.load_state_dict(model.state_dict()) # xfer weights over + me_output = self._infer_for_test_with(me_model, test_input) + test_output.squeeze(0) + me_output.squeeze(0) + # NOTE testing against same memory fixtures as the non-mem-efficient version + self.assertExpected(test_output, rtol=1e-5, atol=1e-5) + + def test_densenet121(self): + self._test_densenet_plus_mem_eff(models.densenet121) + + def test_densenet161(self): + self._test_densenet_plus_mem_eff(models.densenet161) + + def test_densenet169(self): + self._test_densenet_plus_mem_eff(models.densenet169) + + def test_densenet201(self): + self._test_densenet_plus_mem_eff(models.densenet201) + + +class ShufflenetTester(ClassificationModelTester): + def test_shufflenet_v2_x0_5(self): + self._test_classification_model(models.shufflenet_v2_x0_5) + + def test_shufflenet_v2_x1_0(self): + self._test_classification_model(models.shufflenet_v2_x1_0) + + def test_shufflenet_v2_x1_5(self): + self._test_classification_model(models.shufflenet_v2_x1_5) + + def test_shufflenet_v2_x2_0(self): + self._test_classification_model(models.shufflenet_v2_x2_0) + + +# TODO add test for zero_init_residual arg to factory method +# TODO add test for norm_layer arg to factory method +class ResnetTester(ClassificationModelTester): + def _get_scriptability_value(self): + return True + + def test_resnet18(self): + self._test_classification_model(models.resnet18) + + def test_resnet34(self): + self._test_classification_model(models.resnet34) + + def test_resnet50(self): + self._test_classification_model(models.resnet50) + + def test_resnet101(self): + self._test_classification_model(models.resnet101) + + def test_resnet152(self): + self._test_classification_model(models.resnet152) + + def test_resnext50_32x4d(self): + self._test_classification_model(models.resnext50_32x4d) + def test_resnext101_32x8d(self): + self._test_classification_model(models.resnext101_32x8d) + + def test_wide_resnet50_2(self): + self._test_classification_model(models.wide_resnet50_2) + + def test_wide_resnet101_2(self): + self._test_classification_model(models.wide_resnet101_2) + + def _make_sliced_model(self, model, stop_layer): + layers = OrderedDict() + for name, layer in model.named_children(): + layers[name] = layer + if name == stop_layer: + break + new_model = torch.nn.Sequential(layers) + return new_model + + def test_resnet_dilation(self): + # TODO improve tests to also check that each layer has the right dimensionality + for i in product([False, True], [False, True], [False, True]): + model = models.__dict__["resnet50"](replace_stride_with_dilation=i) + model = self._make_sliced_model(model, stop_layer="layer4") + model.eval() + x = self._get_test_input(STANDARD_INPUT_SHAPE) + out = model(x) + f = 2 ** sum(i) + self.assertEqual(out.shape, (1, 2048, 7 * f, 7 * f)) + + +class SegmentationModelTester(ModelTester): def _test_segmentation_model(self, name): # passing num_class equal to a number other than 1000 helps in making the test # more enforcing in nature @@ -90,6 +351,8 @@ def _test_segmentation_model(self, name): out = model(x) self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300)) + +class DetectionModelTester(ModelTester): def _test_detection_model(self, name): set_rng_seed(0) model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False) @@ -102,16 +365,6 @@ def _test_detection_model(self, name): self.assertIs(model_input[0], x) self.assertEqual(len(out), 1) - def subsample_tensor(tensor): - num_elems = tensor.numel() - num_samples = 20 - if num_elems <= num_samples: - return tensor - - flat_tensor = tensor.flatten() - ith_index = num_elems // num_samples - return flat_tensor[ith_index - 1::ith_index] - def compute_mean_std(tensor): # can't compute mean of integral tensor tensor = tensor.to(torch.double) @@ -132,64 +385,6 @@ def compute_mean_std(tensor): self.assertTrue("scores" in out[0]) self.assertTrue("labels" in out[0]) - def _test_video_model(self, name): - # the default input shape is - # bs * num_channels * clip_len * h *w - input_shape = (1, 3, 4, 112, 112) - # test both basicblock and Bottleneck - model = models.video.__dict__[name](num_classes=50) - self.check_script(model, name) - x = torch.rand(input_shape) - out = model(x) - self.assertEqual(out.shape[-1], 50) - - def _make_sliced_model(self, model, stop_layer): - layers = OrderedDict() - for name, layer in model.named_children(): - layers[name] = layer - if name == stop_layer: - break - new_model = torch.nn.Sequential(layers) - return new_model - - def test_memory_efficient_densenet(self): - input_shape = (1, 3, 300, 300) - x = torch.rand(input_shape) - - for name in ['densenet121', 'densenet169', 'densenet201', 'densenet161']: - model1 = models.__dict__[name](num_classes=50, memory_efficient=True) - params = model1.state_dict() - model1.eval() - out1 = model1(x) - out1.sum().backward() - - model2 = models.__dict__[name](num_classes=50, memory_efficient=False) - model2.load_state_dict(params) - model2.eval() - out2 = model2(x) - - max_diff = (out1 - out2).abs().max() - - self.assertTrue(max_diff < 1e-5) - - def test_resnet_dilation(self): - # TODO improve tests to also check that each layer has the right dimensionality - for i in product([False, True], [False, True], [False, True]): - model = models.__dict__["resnet50"](replace_stride_with_dilation=i) - model = self._make_sliced_model(model, stop_layer="layer4") - model.eval() - x = torch.rand(1, 3, 224, 224) - out = model(x) - f = 2 ** sum(i) - self.assertEqual(out.shape, (1, 2048, 7 * f, 7 * f)) - - def test_mobilenetv2_residual_setting(self): - model = models.__dict__["mobilenet_v2"](inverted_residual_setting=[[1, 16, 1, 1], [6, 24, 2, 2]]) - model.eval() - x = torch.rand(1, 3, 224, 224) - out = model(x) - self.assertEqual(out.shape[-1], 1000) - def test_fasterrcnn_double(self): model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False) model.double() @@ -205,16 +400,17 @@ def test_fasterrcnn_double(self): self.assertTrue("labels" in out[0]) -for model_name in get_available_classification_models(): - # for-loop bodies don't define scopes, so we have to save the variables - # we want to close over in some way - def do_test(self, model_name=model_name): - input_shape = (1, 3, 224, 224) - if model_name in ['inception_v3']: - input_shape = (1, 3, 299, 299) - self._test_classification_model(model_name, input_shape) - - setattr(ModelTester, "test_" + model_name, do_test) +class VideoModelTester(ModelTester): + def _test_video_model(self, name): + # the default input shape is + # bs * num_channels * clip_len * h *w + input_shape = (1, 3, 4, 112, 112) + # test both basicblock and Bottleneck + model = models.video.__dict__[name](num_classes=50) + self.check_script(model, name) + x = torch.rand(input_shape) + out = model(x) + self.assertEqual(out.shape[-1], 50) for model_name in get_available_segmentation_models(): @@ -223,7 +419,7 @@ def do_test(self, model_name=model_name): def do_test(self, model_name=model_name): self._test_segmentation_model(model_name) - setattr(ModelTester, "test_" + model_name, do_test) + setattr(SegmentationModelTester, "test_" + model_name, do_test) for model_name in get_available_detection_models(): @@ -232,7 +428,7 @@ def do_test(self, model_name=model_name): def do_test(self, model_name=model_name): self._test_detection_model(model_name) - setattr(ModelTester, "test_" + model_name, do_test) + setattr(DetectionModelTester, "test_" + model_name, do_test) for model_name in get_available_video_models(): @@ -240,7 +436,7 @@ def do_test(self, model_name=model_name): def do_test(self, model_name=model_name): self._test_video_model(model_name) - setattr(ModelTester, "test_" + model_name, do_test) + setattr(VideoModelTester, "test_" + model_name, do_test) if __name__ == '__main__': unittest.main()