diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py
index dbc9cf2a6b4..4ef9daa4fed 100644
--- a/torchvision/datasets/utils.py
+++ b/torchvision/datasets/utils.py
@@ -11,6 +11,7 @@
 import urllib
 import urllib.error
 import urllib.request
+import warnings
 import zipfile
 from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
 from urllib.parse import urlparse
@@ -24,22 +25,31 @@
     _is_remote_location_available,
 )
 
-
 USER_AGENT = "pytorch/vision"
 
 
-def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
-    with open(filename, "wb") as fh:
-        with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
-            with tqdm(total=response.length) as pbar:
-                for chunk in iter(lambda: response.read(chunk_size), ""):
-                    if not chunk:
-                        break
-                    pbar.update(chunk_size)
-                    fh.write(chunk)
+def _save_response_content(
+    content: Iterator[bytes],
+    destination: str,
+    length: Optional[int] = None,
+) -> None:
+    with open(destination, "wb") as fh, tqdm(total=length) as pbar:
+        for chunk in content:
+            # filter out keep-alive new chunks
+            if not chunk:
+                continue
+
+            fh.write(chunk)
+            pbar.update(len(chunk))
+
+
+def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
+    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
+        _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)
 
 
 def gen_bar_updater() -> Callable[[int, int, int], None]:
+    warnings.warn("The function `gen_bar_updater` is deprecated since 0.13 and will be removed in 0.15.")
     pbar = tqdm(total=None)
 
     def bar_update(count, block_size, total_size):
@@ -184,11 +194,20 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
     return files
 
 
-def _quota_exceeded(first_chunk: bytes) -> bool:
+def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[Optional[str], Iterator[bytes]]:
+    content = response.iter_content(chunk_size)
+    first_chunk = None
+    # filter out keep-alive new chunks
+    while not first_chunk:
+        first_chunk = next(content)
+    content = itertools.chain([first_chunk], content)
+
     try:
-        return "Google Drive - Quota exceeded" in first_chunk.decode()
+        match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
+        api_response = match["api_response"] if match is not None else None
     except UnicodeDecodeError:
-        return False
+        api_response = None
+    return api_response, content
 
 
 def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
@@ -202,8 +221,6 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional
     """
     # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
 
-    url = "https://docs.google.com/uc?export=download"
-
     root = os.path.expanduser(root)
     if not filename:
         filename = file_id
@@ -211,61 +228,34 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional
 
     os.makedirs(root, exist_ok=True)
 
-    if os.path.isfile(fpath) and check_integrity(fpath, md5):
-        print("Using downloaded and verified file: " + fpath)
-    else:
-        session = requests.Session()
-
-        response = session.get(url, params={"id": file_id}, stream=True)
-        token = _get_confirm_token(response)
-
-        if token:
-            params = {"id": file_id, "confirm": token}
-            response = session.get(url, params=params, stream=True)
-
-        # Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
-        # with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
-        # Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
-        # the first_chunk of the payload
-        response_content_generator = response.iter_content(32768)
-        first_chunk = None
-        while not first_chunk:  # filter out keep-alive new chunks
-            first_chunk = next(response_content_generator)
-
-        if _quota_exceeded(first_chunk):
-            msg = (
-                f"The daily quota of the file {filename} is exceeded and it "
-                f"can't be downloaded. This is a limitation of Google Drive "
-                f"and can only be overcome by trying again later."
-            )
-            raise RuntimeError(msg)
-
-        _save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath)
-        response.close()
+    if check_integrity(fpath, md5):
+        print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
+        return
 
+    url = "https://drive.google.com/uc"
+    params = dict(id=file_id, export="download")
+    with requests.Session() as session:
+        response = session.get(url, params=params, stream=True)
 
-def _get_confirm_token(response: requests.models.Response) -> Optional[str]:
-    for key, value in response.cookies.items():
-        if key.startswith("download_warning"):
-            return value
+        for key, value in response.cookies.items():
+            if key.startswith("download_warning"):
+                token = value
+                break
+        else:
+            api_response, content = _extract_gdrive_api_response(response)
+            token = "t" if api_response == "Virus scan warning" else None
 
-    return None
+        if token is not None:
+            response = session.get(url, params=dict(params, confirm=token), stream=True)
+            api_response, content = _extract_gdrive_api_response(response)
 
+        if api_response == "Quota exceeded":
+            raise RuntimeError(
+                f"The daily quota of the file {filename} is exceeded and it "
+                f"can't be downloaded. This is a limitation of Google Drive "
+                f"and can only be overcome by trying again later."
+            )
 
-def _save_response_content(
-    response_gen: Iterator[bytes],
-    destination: str,
-) -> None:
-    with open(destination, "wb") as f:
-        pbar = tqdm(total=None)
-        progress = 0
-
-        for chunk in response_gen:
-            if chunk:  # filter out keep-alive new chunks
-                f.write(chunk)
-                progress += len(chunk)
-                pbar.update(progress - pbar.n)
-        pbar.close()
+        _save_response_content(content, fpath)
 
 
 def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
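A minimal usage sketch of the new streaming helper introduced by this patch (illustration only, not part of the diff; it assumes a torchvision build that already contains this change, and `_save_response_content` is a private helper whose signature may change). Empty keep-alive chunks are skipped, and the progress bar advances by the number of bytes actually written:

    import os
    import tempfile

    from torchvision.datasets.utils import _save_response_content

    # Stand-in for response.iter_content(...): the empty chunk simulates a
    # keep-alive packet that the helper is expected to filter out.
    chunks = iter([b"hello ", b"", b"world"])

    with tempfile.TemporaryDirectory() as tmp:
        destination = os.path.join(tmp, "payload.bin")  # hypothetical file name
        # length seeds the tqdm total; pass None when the size is unknown.
        _save_response_content(chunks, destination, length=11)
        with open(destination, "rb") as fh:
            assert fh.read() == b"hello world"

Because both `_urlretrieve` and the Google Drive path now funnel through this one helper, progress reporting and keep-alive filtering behave identically for plain HTTP downloads and Drive downloads.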