diff --git a/test/test_datasets_download.py b/test/test_datasets_download.py index 56119edf2db..d03a76a79ff 100644 --- a/test/test_datasets_download.py +++ b/test/test_datasets_download.py @@ -46,7 +46,12 @@ def inner_wrapper(request, *args, **kwargs): @contextlib.contextmanager -def log_download_attempts(urls_and_md5s=None, patch=True, patch_auxiliaries=None): +def log_download_attempts( + urls_and_md5s=None, + patch=True, + download_url_target="torchvision.datasets.utils.download_url", + patch_auxiliaries=None, +): if urls_and_md5s is None: urls_and_md5s = set() if patch_auxiliaries is None: @@ -54,7 +59,7 @@ def log_download_attempts(urls_and_md5s=None, patch=True, patch_auxiliaries=None with contextlib.ExitStack() as stack: download_url_mock = stack.enter_context( - unittest.mock.patch("torchvision.datasets.utils.download_url", wraps=None if patch else download_url) + unittest.mock.patch(download_url_target, wraps=None if patch else download_url) ) if patch_auxiliaries: # download_and_extract_archive @@ -127,13 +132,9 @@ def make_download_configs(urls_and_md5s, name=None): ] -def collect_download_configs(dataset_loader, name): - try: - with log_download_attempts() as urls_and_md5s: - dataset_loader() - except Exception: - pass - +def collect_download_configs(dataset_loader, name, **kwargs): + with contextlib.suppress(Exception), log_download_attempts(**kwargs) as urls_and_md5s: + dataset_loader() return make_download_configs(urls_and_md5s, name) @@ -164,6 +165,17 @@ def cifar100(): return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), "CIFAR100") +def voc(): + download_configs = [] + for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012"): + with contextlib.suppress(Exception), log_download_attempts( + download_url_target="torchvision.datasets.voc.download_url" + ) as urls_and_md5s: + datasets.VOCSegmentation(".", year=year, download=True) + download_configs.extend(make_download_configs(urls_and_md5s, f"VOC, {year}")) + return download_configs + + def make_parametrize_kwargs(download_configs): argvalues = [] ids = [] @@ -175,7 +187,16 @@ def make_parametrize_kwargs(download_configs): @pytest.mark.parametrize( - **make_parametrize_kwargs(itertools.chain(places365(), caltech101(), caltech256(), cifar10(), cifar100())) + **make_parametrize_kwargs( + itertools.chain( + places365(), + caltech101(), + caltech256(), + cifar10(), + cifar100(), + voc(), + ) + ) ) def test_url_is_accessible(url, md5): retry(lambda: assert_url_is_accessible(url))