support confirming no virus scan on GDrive download #5645
Merged

Changes from all commits (5 commits):
3bffa8e  support confirming no virus scan on GDrive download (pmeier)
7fe6562  put gen_bar_updater back (pmeier)
af3cbe6  Merge branch 'main' into gdrive-virus-scan (pmeier)
46a5512  Update torchvision/datasets/utils.py (pmeier)
4d0980d  Merge branch 'main' into gdrive-virus-scan (pmeier)
torchvision/datasets/utils.py
@@ -11,6 +11,7 @@
 import urllib
 import urllib.error
 import urllib.request
+import warnings
 import zipfile
 from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
 from urllib.parse import urlparse
@@ -24,22 +25,31 @@
     _is_remote_location_available,
 )

 USER_AGENT = "pytorch/vision"


-def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
-    with open(filename, "wb") as fh:
-        with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
-            with tqdm(total=response.length) as pbar:
-                for chunk in iter(lambda: response.read(chunk_size), ""):
-                    if not chunk:
-                        break
-                    pbar.update(chunk_size)
-                    fh.write(chunk)
+def _save_response_content(
+    content: Iterator[bytes],
+    destination: str,
+    length: Optional[int] = None,
+) -> None:
+    with open(destination, "wb") as fh, tqdm(total=length) as pbar:
+        for chunk in content:
+            # filter out keep-alive new chunks
+            if not chunk:
+                continue
+
+            fh.write(chunk)
+            pbar.update(len(chunk))
+
+
+def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
+    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
+        _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)


 def gen_bar_updater() -> Callable[[int, int, int], None]:
+    warnings.warn("The function `gen_bar_update` is deprecated since 0.13 and will be removed in 0.15.")
     pbar = tqdm(total=None)

     def bar_update(count, block_size, total_size):
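For context, a standalone sketch of the pattern this hunk settles on: one writer consumes an iterator of byte chunks and drives the progress bar, while the URL opener only produces chunks. The helper names (save_chunks, fetch) and the example URL are placeholders for this illustration, not part of the diff. Note that the new code uses the bytes sentinel b"" for iter(), whereas the old code's "" sentinel never matched a bytes chunk, so termination relied entirely on the explicit break.

# Standalone sketch of the chunked-download pattern above; names and URL are
# placeholders chosen for this example, not part of the PR.
import urllib.request
from typing import Iterator, Optional

from tqdm import tqdm


def save_chunks(content: Iterator[bytes], destination: str, length: Optional[int] = None) -> None:
    # Single writer: stores every non-empty chunk and advances the progress bar
    # by the actual number of bytes written.
    with open(destination, "wb") as fh, tqdm(total=length) as pbar:
        for chunk in content:
            if not chunk:  # filter out keep-alive chunks
                continue
            fh.write(chunk)
            pbar.update(len(chunk))


def fetch(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
    with urllib.request.urlopen(url) as response:
        # iter(callable, sentinel) keeps calling read() until it returns b"" (end of stream)
        save_chunks(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)


if __name__ == "__main__":
    fetch("https://www.example.com/", "example.html")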
@@ -184,11 +194,20 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
     return files


-def _quota_exceeded(first_chunk: bytes) -> bool:
+def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]:
+    content = response.iter_content(chunk_size)
+    first_chunk = None
+    # filter out keep-alive new chunks
+    while not first_chunk:
+        first_chunk = next(content)
+    content = itertools.chain([first_chunk], content)
+
     try:
-        return "Google Drive - Quota exceeded" in first_chunk.decode()
+        match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
+        api_response = match["api_response"] if match is not None else None
     except UnicodeDecodeError:
-        return False
+        api_response = None
+    return api_response, content


 def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
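To illustrate the new detection logic, here is a small example of what the title regex above extracts from Google Drive's interstitial pages. The HTML inputs below are invented stand-ins, not captured responses, and classify_first_chunk is a local name for this sketch only.

# Illustration of the <title> parsing used by _extract_gdrive_api_response.
# The HTML inputs are invented stand-ins for Google Drive's interstitial pages.
import re
from typing import Optional


def classify_first_chunk(first_chunk: bytes) -> Optional[str]:
    try:
        match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
        return match["api_response"] if match is not None else None
    except UnicodeDecodeError:
        # Not text at all: the payload is already the requested file.
        return None


print(classify_first_chunk(b"<html><title>Google Drive - Virus scan warning</title></html>"))  # Virus scan warning
print(classify_first_chunk(b"<html><title>Google Drive - Quota exceeded</title></html>"))      # Quota exceeded
print(classify_first_chunk(b"\x89PNG\r\n\x1a\n binary payload"))                               # None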
@@ -202,70 +221,41 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
     """
     # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url

Review comment: Maybe we can drop this comment at some point, because our implementation no longer resembles the answer in the link.

-    url = "https://docs.google.com/uc?export=download"
-
     root = os.path.expanduser(root)
     if not filename:
         filename = file_id
     fpath = os.path.join(root, filename)

     os.makedirs(root, exist_ok=True)

-    if os.path.isfile(fpath) and check_integrity(fpath, md5):
-        print("Using downloaded and verified file: " + fpath)
-    else:
-        session = requests.Session()
-
-        response = session.get(url, params={"id": file_id}, stream=True)
-        token = _get_confirm_token(response)
-
-        if token:
-            params = {"id": file_id, "confirm": token}
-            response = session.get(url, params=params, stream=True)
-
-        # Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
-        # with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
-        # Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
-        # the first_chunk of the payload
-        response_content_generator = response.iter_content(32768)
-        first_chunk = None
-        while not first_chunk:  # filter out keep-alive new chunks
-            first_chunk = next(response_content_generator)
-
-        if _quota_exceeded(first_chunk):
-            msg = (
-                f"The daily quota of the file {filename} is exceeded and it "
-                f"can't be downloaded. This is a limitation of Google Drive "
-                f"and can only be overcome by trying again later."
-            )
-            raise RuntimeError(msg)
-
-        _save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath)
-        response.close()
+    if check_integrity(fpath, md5):
+        print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")

+    url = "https://drive.google.com/uc"
+    params = dict(id=file_id, export="download")
+    with requests.Session() as session:
+        response = session.get(url, params=params, stream=True)

-def _get_confirm_token(response: requests.models.Response) -> Optional[str]:
-    for key, value in response.cookies.items():
-        if key.startswith("download_warning"):
-            return value
+        for key, value in response.cookies.items():
+            if key.startswith("download_warning"):
+                token = value
+                break
+        else:
+            api_response, content = _extract_gdrive_api_response(response)
+            token = "t" if api_response == "Virus scan warning" else None

-    return None
+        if token is not None:
+            response = session.get(url, params=dict(params, confirm=token), stream=True)
+            api_response, content = _extract_gdrive_api_response(response)

+        if api_response == "Quota exceeded":
+            raise RuntimeError(
+                f"The daily quota of the file {filename} is exceeded and it "
+                f"can't be downloaded. This is a limitation of Google Drive "
+                f"and can only be overcome by trying again later."
+            )

-def _save_response_content(
-    response_gen: Iterator[bytes],
-    destination: str,
-) -> None:
-    with open(destination, "wb") as f:
-        pbar = tqdm(total=None)
-        progress = 0
-
-        for chunk in response_gen:
-            if chunk:  # filter out keep-alive new chunks
-                f.write(chunk)
-                progress += len(chunk)
-                pbar.update(progress - pbar.n)
-        pbar.close()
+        _save_response_content(content, fpath)


 def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
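The core of the PR is the confirm-token decision in this hunk: if Google Drive sets a download_warning cookie, its value is used as the token; if it instead serves the "Virus scan warning" interstitial (files too large to scan), the download is retried with confirm="t". Below is a hypothetical, simplified restatement of just that decision with the network calls stripped out; choose_confirm_token is a name local to this sketch, not a function in the PR.

# Hypothetical, simplified restatement of the confirm-token decision from the
# new download_file_from_google_drive above; no network I/O is performed.
from typing import Dict, Optional


def choose_confirm_token(cookies: Dict[str, str], api_response: Optional[str]) -> Optional[str]:
    # A download_warning* cookie carries the confirm token directly.
    for key, value in cookies.items():
        if key.startswith("download_warning"):
            return value
    # Files too large to be scanned only return a "Virus scan warning" page;
    # answering with "t" confirms the download anyway.
    return "t" if api_response == "Virus scan warning" else None


assert choose_confirm_token({"download_warning_123": "abc"}, None) == "abc"
assert choose_confirm_token({}, "Virus scan warning") == "t"
assert choose_confirm_token({}, None) is None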
Review comment: Both this function and _save_response_content below had the same functionality, so I factored it out.