120 changes: 55 additions & 65 deletions torchvision/datasets/utils.py
@@ -11,6 +11,7 @@
import urllib
import urllib.error
import urllib.request
import warnings
import zipfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
from urllib.parse import urlparse
@@ -24,22 +25,31 @@
    _is_remote_location_available,
)


USER_AGENT = "pytorch/vision"


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
Collaborator Author: Both this function and _save_response_content below had the same functionality, so I factored it out.

    with open(filename, "wb") as fh:
        with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
            with tqdm(total=response.length) as pbar:
                for chunk in iter(lambda: response.read(chunk_size), ""):
                    if not chunk:
                        break
                    pbar.update(chunk_size)
                    fh.write(chunk)
def _save_response_content(
    content: Iterator[bytes],
    destination: str,
    length: Optional[int] = None,
) -> None:
    with open(destination, "wb") as fh, tqdm(total=length) as pbar:
        for chunk in content:
            # filter out keep-alive new chunks
            if not chunk:
                continue

            fh.write(chunk)
            pbar.update(len(chunk))


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
        _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)
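
Note (illustration, not part of the diff): the two-argument iter(callable, sentinel) form stops as soon as the callable returns the sentinel. The old code passed the str sentinel "", which a bytes chunk never equals, so it relied on the explicit break; the new code uses b"" and terminates naturally. A minimal standalone sketch, assuming any reachable HTTP URL:

import urllib.request

with urllib.request.urlopen("https://example.com") as response:
    # iter(callable, sentinel) calls response.read(8192) repeatedly
    # and stops once it returns the empty-bytes sentinel b"".
    for chunk in iter(lambda: response.read(8192), b""):
        pass  # each chunk here is a non-empty bytes object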


def gen_bar_updater() -> Callable[[int, int, int], None]:
    warnings.warn("The function `gen_bar_updater` is deprecated since 0.13 and will be removed in 0.15.")
    pbar = tqdm(total=None)

    def bar_update(count, block_size, total_size):
@@ -184,11 +194,20 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
    return files


def _quota_exceeded(first_chunk: bytes) -> bool:
def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[Optional[str], Iterator[bytes]]:
    content = response.iter_content(chunk_size)
    first_chunk = None
    # filter out keep-alive new chunks
    while not first_chunk:
        first_chunk = next(content)
    content = itertools.chain([first_chunk], content)

    try:
        return "Google Drive - Quota exceeded" in first_chunk.decode()
        match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
        api_response = match["api_response"] if match is not None else None
    except UnicodeDecodeError:
        return False
        api_response = None
    return api_response, content
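
Note (illustration, not part of the diff): _extract_gdrive_api_response peeks at the first non-empty chunk to read the <title> of Google's interstitial HTML page, then uses itertools.chain to put the consumed chunk back so the caller still sees the complete stream. A minimal sketch with an in-memory stand-in for response.iter_content:

import itertools
import re

chunks = iter([b"", b"<title>Google Drive - Quota exceeded</title>"])
first_chunk = None
while not first_chunk:  # skip empty keep-alive chunks
    first_chunk = next(chunks)
content = itertools.chain([first_chunk], chunks)  # re-attach the peeked chunk

match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
print(match["api_response"] if match is not None else None)  # Quota exceeded
print(b"".join(content))  # the full payload is still intact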


def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
@@ -202,70 +221,41 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
"""
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
Collaborator Author: Maybe we can drop this comment at some point, because our implementation no longer resembles the answer in the link.

url = "https://docs.google.com/uc?export=download"

root = os.path.expanduser(root)
if not filename:
filename = file_id
fpath = os.path.join(root, filename)

os.makedirs(root, exist_ok=True)

if os.path.isfile(fpath) and check_integrity(fpath, md5):
print("Using downloaded and verified file: " + fpath)
else:
session = requests.Session()

response = session.get(url, params={"id": file_id}, stream=True)
token = _get_confirm_token(response)

if token:
params = {"id": file_id, "confirm": token}
response = session.get(url, params=params, stream=True)

# Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
# with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
# Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
# the first_chunk of the payload
response_content_generator = response.iter_content(32768)
first_chunk = None
while not first_chunk: # filter out keep-alive new chunks
first_chunk = next(response_content_generator)

if _quota_exceeded(first_chunk):
msg = (
f"The daily quota of the file {filename} is exceeded and it "
f"can't be downloaded. This is a limitation of Google Drive "
f"and can only be overcome by trying again later."
)
raise RuntimeError(msg)

_save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath)
response.close()
    if check_integrity(fpath, md5):
        print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
        return

    url = "https://drive.google.com/uc"
    params = dict(id=file_id, export="download")
    with requests.Session() as session:
        response = session.get(url, params=params, stream=True)

def _get_confirm_token(response: requests.models.Response) -> Optional[str]:
    for key, value in response.cookies.items():
        if key.startswith("download_warning"):
            return value
        for key, value in response.cookies.items():
            if key.startswith("download_warning"):
                token = value
                break
        else:
            api_response, content = _extract_gdrive_api_response(response)
            token = "t" if api_response == "Virus scan warning" else None

    return None
        if token is not None:
            response = session.get(url, params=dict(params, confirm=token), stream=True)
            api_response, content = _extract_gdrive_api_response(response)

        if api_response == "Quota exceeded":
            raise RuntimeError(
                f"The daily quota of the file {filename} is exceeded and it "
                f"can't be downloaded. This is a limitation of Google Drive "
                f"and can only be overcome by trying again later."
            )

def _save_response_content(
    response_gen: Iterator[bytes],
    destination: str,
) -> None:
    with open(destination, "wb") as f:
        pbar = tqdm(total=None)
        progress = 0

        for chunk in response_gen:
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
                progress += len(chunk)
                pbar.update(progress - pbar.n)
        pbar.close()
        _save_response_content(content, fpath)
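
Note (illustration, not part of the diff): the public behavior is unchanged. A call like the following (hypothetical file id and paths) downloads to root/filename, transparently confirms Google's virus-scan interstitial, and raises RuntimeError if the quota is exceeded:

download_file_from_google_drive("1abcDEFghiJKLmnop", root="./data", filename="archive.zip", md5=None)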


def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None: