From ea44bf2d58f67927b4f4784f89de992dd88682cb Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 10 Jun 2021 11:42:31 +0100 Subject: [PATCH] Port internet tests to pytest --- .../unittest/linux/scripts/environment.yml | 1 + .../unittest/windows/scripts/environment.yml | 1 + CONTRIBUTING.md | 2 +- test/test_internet.py | 31 +++++++++---------- 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/.circleci/unittest/linux/scripts/environment.yml b/.circleci/unittest/linux/scripts/environment.yml index fcf61a6e2f8..c0d36f95a43 100644 --- a/.circleci/unittest/linux/scripts/environment.yml +++ b/.circleci/unittest/linux/scripts/environment.yml @@ -6,6 +6,7 @@ channels: dependencies: - pytest - pytest-cov + - pytest-mock - pip - libpng # NOTE: Pinned to fix issues with size_t on Windows diff --git a/.circleci/unittest/windows/scripts/environment.yml b/.circleci/unittest/windows/scripts/environment.yml index 9a916a27d07..11de8e089e8 100644 --- a/.circleci/unittest/windows/scripts/environment.yml +++ b/.circleci/unittest/windows/scripts/environment.yml @@ -6,6 +6,7 @@ channels: dependencies: - pytest - pytest-cov + - pytest-mock - pip - libpng # NOTE: Pinned to fix issues with size_t on Windows diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 748dc50df9e..3fa41efec26 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -49,7 +49,7 @@ python setup.py develop # MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py develop # for C++ debugging, please use DEBUG=1 # DEBUG=1 python setup.py develop -pip install flake8 typing mypy pytest scipy +pip install flake8 typing mypy pytest pytest-mock scipy ``` You may also have to install `libpng-dev` and `libjpeg-turbo8-dev` libraries: ```bash diff --git a/test/test_internet.py b/test/test_internet.py index 05496752c7f..772379a2289 100644 --- a/test/test_internet.py +++ b/test/test_internet.py @@ -6,8 +6,7 @@ """ import os -import unittest -import unittest.mock +import pytest import warnings from urllib.error import URLError @@ -15,7 +14,7 @@ from common_utils import get_tmp_dir -class DatasetUtilsTester(unittest.TestCase): +class TestDatasetUtils: def test_get_redirect_url(self): url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz" @@ -26,7 +25,7 @@ def test_get_redirect_url(self): def test_get_redirect_url_max_hops_exceeded(self): url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz" - with self.assertRaises(RecursionError): + with pytest.raises(RecursionError): utils._get_redirect_url(url, max_hops=0) def test_download_url(self): @@ -34,38 +33,38 @@ def test_download_url(self): url = "http://github.com/pytorch/vision/archive/master.zip" try: utils.download_url(url, temp_dir) - self.assertFalse(len(os.listdir(temp_dir)) == 0) + assert len(os.listdir(temp_dir)) != 0 except URLError: - msg = "could not download test file '{}'".format(url) - warnings.warn(msg, RuntimeWarning) - raise unittest.SkipTest(msg) + pytest.skip(f"could not download test file '{url}'") def test_download_url_retry_http(self): with get_tmp_dir() as temp_dir: url = "https://github.com/pytorch/vision/archive/master.zip" try: utils.download_url(url, temp_dir) - self.assertFalse(len(os.listdir(temp_dir)) == 0) + assert len(os.listdir(temp_dir)) != 0 except URLError: - msg = "could not download test file '{}'".format(url) - warnings.warn(msg, RuntimeWarning) - raise unittest.SkipTest(msg) + pytest.skip(f"could not download test file '{url}'") def test_download_url_dont_exist(self): with get_tmp_dir() as temp_dir: url = "http://github.com/pytorch/vision/archive/this_doesnt_exist.zip" - with self.assertRaises(URLError): + with pytest.raises(URLError): utils.download_url(url, temp_dir) - @unittest.mock.patch("torchvision.datasets.utils.download_file_from_google_drive") - def test_download_url_dispatch_download_from_google_drive(self, mock): + def test_download_url_dispatch_download_from_google_drive(self, mocker): url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view" id = "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45" filename = "filename" md5 = "md5" + mocked = mocker.patch('torchvision.datasets.utils.download_file_from_google_drive') with get_tmp_dir() as root: utils.download_url(url, root, filename, md5) - mock.assert_called_once_with(id, root, filename, md5) + mocked.assert_called_once_with(id, root, filename, md5) + + +if __name__ == '__main__': + pytest.main([__file__])