diff --git a/test/test_models.py b/test/test_models.py index 716de629360..18d01921be4 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -18,6 +18,7 @@ from torchvision import models ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1" +SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1" def get_models_from_module(module): @@ -231,6 +232,7 @@ def _check_input_backprop(model, inputs): "lraspp_mobilenet_v3_large", "maskrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn_v2", + "keypointrcnn_resnet50_fpn", ) # The tests for the following quantized models are flaky possibly due to inconsistent @@ -329,6 +331,12 @@ def _check_input_backprop(model, inputs): _model_params[m] = {"input_shape": (1, 3, 64, 64)} +# skip big models to reduce memory usage on CI test +skipped_big_models = { + "vit_h_14", + "regnet_y_128gf", +} + # The following contains configuration and expected values to be used tests that are model specific _model_tests_values = { "retinanet_resnet50_fpn": { @@ -592,6 +600,8 @@ def test_classification_model(model_fn, dev): "input_shape": (1, 3, 224, 224), } model_name = model_fn.__name__ + if dev == "cuda" and SKIP_BIG_MODEL and model_name in skipped_big_models: + pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model") kwargs = {**defaults, **_model_params.get(model_name, {})} num_classes = kwargs.get("num_classes") input_shape = kwargs.pop("input_shape") @@ -606,7 +616,7 @@ def test_classification_model(model_fn, dev): _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out) _check_fx_compatible(model, x, eager_out=out) - if dev == torch.device("cuda"): + if dev == "cuda": with torch.cuda.amp.autocast(): out = model(x) # See autocast_flaky_numerics comment at top of file. @@ -659,7 +669,7 @@ def check_out(out): _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out) _check_fx_compatible(model, x, eager_out=out) - if dev == torch.device("cuda"): + if dev == "cuda": with torch.cuda.amp.autocast(): out = model(x) # See autocast_flaky_numerics comment at top of file. @@ -757,7 +767,7 @@ def compute_mean_std(tensor): full_validation = check_out(out) _check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out) - if dev == torch.device("cuda"): + if dev == "cuda": with torch.cuda.amp.autocast(): out = model(model_input) # See autocast_flaky_numerics comment at top of file. @@ -823,7 +833,7 @@ def test_video_model(model_fn, dev): _check_fx_compatible(model, x, eager_out=out) assert out.shape[-1] == 50 - if dev == torch.device("cuda"): + if dev == "cuda": with torch.cuda.amp.autocast(): out = model(x) assert out.shape[-1] == 50