Skip to content

Commit

Permalink
fix: make downloading from GDrive more robust (#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
tilman151 authored Jan 23, 2024
1 parent 76ff646 commit f7edaa6
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
2 changes: 1 addition & 1 deletion rul_datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
)
from .baseline import BaselineDataModule, PretrainingBaselineDataModule
from .core import RulDataModule
from .reader import CmapssReader, FemtoReader, XjtuSyReader
from .reader import CmapssReader, FemtoReader, XjtuSyReader, NCmapssReader
from .reader.data_root import get_data_root, set_data_root
from .ssl import SemiSupervisedDataModule

Expand Down
2 changes: 1 addition & 1 deletion rul_datasets/reader/ncmapss.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,12 +363,12 @@ def _calc_default_window_size(self):


def _download_ncmapss(data_root):
os.makedirs(data_root)
with tempfile.TemporaryDirectory() as tmp_path:
print("Download N-C-MAPSS dataset from Google Drive")
download_path = os.path.join(tmp_path, "data.zip")
utils.download_gdrive_file(NCMAPSS_DRIVE_ID, download_path)
print("Extract N-C-MAPSS dataset")
os.makedirs(data_root)
with zipfile.ZipFile(download_path, mode="r") as f:
for zipinfo in f.infolist():
zipinfo.filename = os.path.basename(zipinfo.filename)
Expand Down
9 changes: 9 additions & 0 deletions rul_datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ def download_gdrive_file(file_id: str, save_path: str) -> None:
if response.text.startswith("<!DOCTYPE html>"):
params = {"id": file_id, "confirm": "t"}
response = session.post(GDRIVE_URL_BASE, params=params, stream=True)
if response.status_code == 429:
raise RuntimeError(
"Download failed. Server returned 429. "
"This is usually caused by too many requests. "
"Please try again later."
)
elif not response.status_code == 200:
raise RuntimeError(f"Download failed. Server returned {response.status_code}")
_write_content(response, save_path)


Expand All @@ -112,6 +120,7 @@ def _write_content(response: requests.Response, save_path: str) -> None:
if chunk:
pbar.update(len(chunk))
f.write(chunk)
f.flush()
pbar.close()


Expand Down

0 comments on commit f7edaa6

Please sign in to comment.