diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py
index dbc9cf2a6b4..4ef9daa4fed 100644
--- a/torchvision/datasets/utils.py
+++ b/torchvision/datasets/utils.py
@@ -11,6 +11,7 @@
import urllib
import urllib.error
import urllib.request
+import warnings
import zipfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
from urllib.parse import urlparse
@@ -24,22 +25,31 @@
_is_remote_location_available,
)
-
USER_AGENT = "pytorch/vision"
-def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
- with open(filename, "wb") as fh:
- with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
- with tqdm(total=response.length) as pbar:
- for chunk in iter(lambda: response.read(chunk_size), ""):
- if not chunk:
- break
- pbar.update(chunk_size)
- fh.write(chunk)
+def _save_response_content(
+ content: Iterator[bytes],
+ destination: str,
+ length: Optional[int] = None,
+) -> None:
+ with open(destination, "wb") as fh, tqdm(total=length) as pbar:
+ for chunk in content:
+ # filter out keep-alive new chunks
+ if not chunk:
+ continue
+
+ fh.write(chunk)
+ pbar.update(len(chunk))
+
+
+def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
+ with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
+ _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)
def gen_bar_updater() -> Callable[[int, int, int], None]:
+    warnings.warn("The function `gen_bar_updater` is deprecated since 0.13 and will be removed in 0.15.")
pbar = tqdm(total=None)
def bar_update(count, block_size, total_size):
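
The rewritten _urlretrieve fixes two subtle issues in the loop it replaces: the old sentinel "" (a str) never compares equal to the bytes chunks returned by response.read, so termination depended entirely on the explicit `if not chunk: break`, and pbar.update(chunk_size) over-counted the final, usually shorter, chunk. With a b"" sentinel the two-argument form of iter stops at EOF on its own, and pbar.update(len(chunk)) reports exact byte counts. A minimal sketch of the sentinel pattern, using an in-memory stream as a stand-in for the HTTP response:

    import io

    response = io.BytesIO(b"example payload")       # stand-in for urllib's response object
    chunks = iter(lambda: response.read(4), b"")    # call response.read(4) until it returns b"" (EOF)
    assert b"".join(chunks) == b"example payload"
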
@@ -184,11 +194,20 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
return files
-def _quota_exceeded(first_chunk: bytes) -> bool:
+def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[Optional[str], Iterator[bytes]]:
+ content = response.iter_content(chunk_size)
+ first_chunk = None
+ # filter out keep-alive new chunks
+ while not first_chunk:
+ first_chunk = next(content)
+ content = itertools.chain([first_chunk], content)
+
try:
- return "Google Drive - Quota exceeded" in first_chunk.decode()
+        match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
+        api_response = match["api_response"] if match is not None else None
except UnicodeDecodeError:
- return False
+ api_response = None
+ return api_response, content
def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
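
_extract_gdrive_api_response has to look at the payload before deciding whether it is the requested file or a Google Drive interstitial page, so it pulls the first non-empty chunk off response.iter_content, sniffs the HTML <title> for the API message, and then stitches the consumed chunk back onto the stream with itertools.chain so the caller still receives the complete content. A self-contained sketch of that peek-and-rechain idiom, with a plain iterator standing in for response.iter_content:

    import itertools
    from typing import Iterator, Tuple

    def peek(content: Iterator[bytes]) -> Tuple[bytes, Iterator[bytes]]:
        first = None
        while not first:  # skip empty keep-alive chunks, as in the diff above
            first = next(content)
        return first, itertools.chain([first], content)

    chunks = iter([b"", b"<title>Google Drive - Quota exceeded</title>", b"...rest of page"])
    first, content = peek(chunks)
    assert b"".join(content).startswith(first)  # the peeked chunk was not lost
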
@@ -202,8 +221,6 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
"""
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
- url = "https://docs.google.com/uc?export=download"
-
root = os.path.expanduser(root)
if not filename:
filename = file_id
@@ -211,61 +228,34 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
os.makedirs(root, exist_ok=True)
- if os.path.isfile(fpath) and check_integrity(fpath, md5):
- print("Using downloaded and verified file: " + fpath)
- else:
- session = requests.Session()
-
- response = session.get(url, params={"id": file_id}, stream=True)
- token = _get_confirm_token(response)
-
- if token:
- params = {"id": file_id, "confirm": token}
- response = session.get(url, params=params, stream=True)
-
- # Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
- # with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
- # Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
- # the first_chunk of the payload
- response_content_generator = response.iter_content(32768)
- first_chunk = None
- while not first_chunk: # filter out keep-alive new chunks
- first_chunk = next(response_content_generator)
-
- if _quota_exceeded(first_chunk):
- msg = (
- f"The daily quota of the file {filename} is exceeded and it "
- f"can't be downloaded. This is a limitation of Google Drive "
- f"and can only be overcome by trying again later."
- )
- raise RuntimeError(msg)
-
- _save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath)
- response.close()
+ if check_integrity(fpath, md5):
+ print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
+ url = "https://drive.google.com/uc"
+ params = dict(id=file_id, export="download")
+ with requests.Session() as session:
+ response = session.get(url, params=params, stream=True)
-def _get_confirm_token(response: requests.models.Response) -> Optional[str]:
- for key, value in response.cookies.items():
- if key.startswith("download_warning"):
- return value
+ for key, value in response.cookies.items():
+ if key.startswith("download_warning"):
+ token = value
+ break
+ else:
+ api_response, content = _extract_gdrive_api_response(response)
+ token = "t" if api_response == "Virus scan warning" else None
- return None
+ if token is not None:
+ response = session.get(url, params=dict(params, confirm=token), stream=True)
+ api_response, content = _extract_gdrive_api_response(response)
+ if api_response == "Quota exceeded":
+ raise RuntimeError(
+ f"The daily quota of the file {filename} is exceeded and it "
+ f"can't be downloaded. This is a limitation of Google Drive "
+ f"and can only be overcome by trying again later."
+ )
-def _save_response_content(
- response_gen: Iterator[bytes],
- destination: str,
-) -> None:
- with open(destination, "wb") as f:
- pbar = tqdm(total=None)
- progress = 0
-
- for chunk in response_gen:
- if chunk: # filter out keep-alive new chunks
- f.write(chunk)
- progress += len(chunk)
- pbar.update(progress - pbar.n)
- pbar.close()
+ _save_response_content(content, fpath)
def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
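
With the refactor in place, the download flow is: one GET against https://drive.google.com/uc; if Google Drive answers with a download_warning cookie or a "Virus scan warning" page instead of the file, a second GET with confirm=<token> retrieves the real content, and a "Quota exceeded" page raises immediately instead of being written to disk. A hedged usage sketch; the file id and filename below are placeholders, not real values:

    from torchvision.datasets.utils import download_file_from_google_drive

    download_file_from_google_drive(
        file_id="1xYzPLACEHOLDERfileid",  # hypothetical Drive file id
        root="./data",                    # download directory, created if missing
        filename="archive.tar.gz",        # defaults to file_id when omitted
        md5=None,                         # pass an md5 hex digest to verify integrity
    )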