diff --git a/test/test_models.py b/test/test_models.py index b70314f17c5..7645dc419ff 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -3,6 +3,7 @@ import operator import os import pkgutil +import platform import sys import warnings from collections import OrderedDict @@ -343,12 +344,25 @@ def _check_input_backprop(model, inputs): _model_params[m] = {"input_shape": (1, 3, 64, 64)} -# skip big models to reduce memory usage on CI test +# skip big models to reduce memory usage on CI test. We can exclude combinations of (platform-system, device). skipped_big_models = { - "vit_h_14", - "regnet_y_128gf", + "vit_h_14": {("Windows", "cpu"), ("Windows", "cuda")}, + "regnet_y_128gf": {("Windows", "cpu"), ("Windows", "cuda")}, + "mvit_v1_b": {("Windows", "cuda")}, + "mvit_v2_s": {("Windows", "cuda")}, } + +def is_skippable(model_name, device): + if model_name not in skipped_big_models: + return False + + platform_system = platform.system() + device_name = str(device).split(":")[0] + + return (platform_system, device_name) in skipped_big_models[model_name] + + # The following contains configuration and expected values to be used tests that are model specific _model_tests_values = { "retinanet_resnet50_fpn": { @@ -612,7 +626,7 @@ def test_classification_model(model_fn, dev): "input_shape": (1, 3, 224, 224), } model_name = model_fn.__name__ - if SKIP_BIG_MODEL and model_name in skipped_big_models: + if SKIP_BIG_MODEL and is_skippable(model_name, dev): pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model") kwargs = {**defaults, **_model_params.get(model_name, {})} num_classes = kwargs.get("num_classes") @@ -841,7 +855,7 @@ def test_video_model(model_fn, dev): "num_classes": 50, } model_name = model_fn.__name__ - if SKIP_BIG_MODEL and model_name in skipped_big_models: + if SKIP_BIG_MODEL and is_skippable(model_name, dev): pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model") kwargs = {**defaults, **_model_params.get(model_name, {})} num_classes = kwargs.get("num_classes")