Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add download tests for Caltech(101|256) #2731

Merged
merged 2 commits into from Sep 30, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
65 changes: 52 additions & 13 deletions test/test_datasets_download.py
Expand Up @@ -45,13 +45,23 @@ def inner_wrapper(request, *args, **kwargs):


@contextlib.contextmanager
def log_download_attempts(patch=True):
urls_and_md5s = set()
with unittest.mock.patch("torchvision.datasets.utils.download_url", wraps=None if patch else download_url) as mock:
def log_download_attempts(urls_and_md5s=None, patch=True, patch_auxiliaries=None):
if urls_and_md5s is None:
urls_and_md5s = set()
if patch_auxiliaries is None:
patch_auxiliaries = patch

with contextlib.ExitStack() as stack:
download_url_mock = stack.enter_context(
unittest.mock.patch("torchvision.datasets.utils.download_url", wraps=None if patch else download_url)
)
if patch_auxiliaries:
# download_and_extract_archive
stack.enter_context(unittest.mock.patch("torchvision.datasets.utils.extract_archive"))
try:
yield urls_and_md5s
finally:
for args, kwargs in mock.call_args_list:
for args, kwargs in download_url_mock.call_args_list:
url = args[0]
md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
urls_and_md5s.add((url, md5))
Expand Down Expand Up @@ -105,15 +115,14 @@ def __init__(self, url, md5=None, id=None):
self.md5 = md5
self.id = id or url

def __repr__(self):
return self.id
Comment on lines +118 to +119
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for easier debugging.


def make_parametrize_kwargs(download_configs):
argvalues = []
ids = []
for config in download_configs:
argvalues.append((config.url, config.md5))
ids.append(config.id)

return dict(argnames="url, md5", argvalues=argvalues, ids=ids)
def make_download_configs(urls_and_md5s, name=None):
return [
DownloadConfig(url, md5=md5, id=f"{name}, {url}" if name is not None else None) for url, md5 in urls_and_md5s
]


def places365():
Expand All @@ -124,10 +133,40 @@ def places365():

datasets.Places365(root, split=split, small=small, download=True)

return [DownloadConfig(url, md5=md5, id=f"Places365, {url}") for url, md5 in urls_and_md5s]
return make_download_configs(urls_and_md5s, "Places365")


def caltech101():
try:
with log_download_attempts() as urls_and_md5s:
datasets.Caltech101(".", download=True)
except Exception:
pass

return make_download_configs(urls_and_md5s, "Caltech101")


def caltech256():
try:
with log_download_attempts() as urls_and_md5s:
datasets.Caltech256(".", download=True)
except Exception:
pass

return make_download_configs(urls_and_md5s, "Caltech256")


def make_parametrize_kwargs(download_configs):
argvalues = []
ids = []
for config in download_configs:
argvalues.append((config.url, config.md5))
ids.append(config.id)

return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)


@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(places365(),)))
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(places365(), caltech101(), caltech256())))
def test_url_is_accessible(url, md5):
retry(lambda: assert_url_is_accessible(url))

Expand Down