/
cached_download.py
137 lines (114 loc) · 3.25 KB
/
cached_download.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from __future__ import print_function
import hashlib
import os
import os.path as osp
import shutil
import sys
import tempfile
import filelock
from .download import download
# Per-user directory that holds cached downloads and the download lock file.
cache_root = osp.join(osp.expanduser("~"), ".cache/gdown")
try:
    if not osp.exists(cache_root):
        os.makedirs(cache_root)
except OSError:
    # Creation is best-effort (e.g. another process may have created it
    # concurrently); code that later writes into cache_root will fail loudly
    # if the directory really is unusable.
    pass
def md5sum(filename, blocksize=None):
    """Return the hexadecimal MD5 digest of a file's contents.

    Parameters
    ----------
    filename: str
        Path of the file to hash; it is opened in binary mode.
    blocksize: int, optional
        Number of bytes read per chunk. Default (None) is 65536.

    Returns
    -------
    str
        Lowercase hex digest as produced by ``hashlib``'s ``hexdigest()``.
    """
    chunk_size = 65536 if blocksize is None else blocksize
    digest = hashlib.md5()
    with open(filename, "rb") as stream:
        # Read chunk-by-chunk so large files are never loaded whole;
        # an empty bytes object signals end of file.
        chunk = stream.read(chunk_size)
        while chunk:
            digest.update(chunk)
            chunk = stream.read(chunk_size)
    return digest.hexdigest()
def assert_md5sum(filename, md5, quiet=False, blocksize=None):
    """Verify that a file's MD5 digest equals an expected value.

    Parameters
    ----------
    filename: str
        Path of the file to check.
    md5: str
        Expected MD5 as a 32-character hex string (compared
        case-insensitively).
    quiet: bool
        Suppress terminal output. Default is False.
    blocksize: int, optional
        Read chunk size forwarded to ``md5sum``; None uses its default.

    Returns
    -------
    bool
        True when the digest matches.

    Raises
    ------
    ValueError
        If ``md5`` is not a 32-character string.
    AssertionError
        If the computed MD5 differs from ``md5``.
    """
    if not (isinstance(md5, str) and len(md5) == 32):
        raise ValueError("MD5 must be 32 chars: {}".format(md5))

    if not quiet:
        print("Computing MD5: {}".format(filename))
    # Forward blocksize so the caller's chunk size is actually honoured.
    md5_actual = md5sum(filename, blocksize=blocksize)
    # hexdigest() is always lowercase; normalise the expected value so an
    # uppercase hex string is not reported as a mismatch.
    if md5_actual == md5.lower():
        if not quiet:
            print("MD5 matches: {}".format(filename))
        return True

    raise AssertionError(
        "MD5 doesn't match:\nactual: {}\nexpected: {}".format(md5_actual, md5)
    )
def cached_download(
    url,
    path=None,
    md5=None,
    quiet=False,
    postprocess=None,
    proxy=None,
    speed=None,
):
    """Cached download from URL.

    If ``path`` already exists (and, when ``md5`` is given, its MD5 matches),
    it is returned without downloading. Otherwise the file is downloaded into
    a temporary directory under ``cache_root`` and moved to ``path`` while
    holding a file lock.

    Parameters
    ----------
    url: str
        URL. Google Drive URL is also supported.
    path: str, optional
        Output filename. Default is a file under ``cache_root`` named after
        the URL, with ``/``, ``:``, ``=`` and ``?`` replaced by ``-SLASH-``,
        ``-COLON-``, ``-EQUAL-`` and ``-QUESTION-`` respectively.
    md5: str, optional
        Expected MD5 for specified file.
    quiet: bool
        Suppress terminal output. Default is False.
    postprocess: callable
        Function called with filename as postprocess.
    proxy: str
        Proxy.
    speed: float
        Download byte size per second (e.g., 256KB/s = 256 * 1024).

    Returns
    -------
    path: str
        Output filename.

    Raises
    ------
    AssertionError
        If ``md5`` is given and the freshly downloaded file does not match it.
    """
    if path is None:
        path = (
            url.replace("/", "-SLASH-")
            .replace(":", "-COLON-")
            .replace("=", "-EQUAL-")
            .replace("?", "-QUESTION-")
        )
        path = osp.join(cache_root, path)

    # check existence
    if osp.exists(path) and not md5:
        if not quiet:
            print("File exists: {}".format(path))
        return path
    elif osp.exists(path) and md5:
        try:
            assert_md5sum(path, md5, quiet=quiet)
            return path
        except AssertionError as e:
            # Cached copy is stale/corrupt: report it and re-download below.
            print(e, file=sys.stderr)

    # download
    lock_path = osp.join(cache_root, "_dl_lock")
    try:
        os.makedirs(osp.dirname(path))
    except OSError:
        # Best-effort: e.g. the directory already exists.
        pass

    temp_root = tempfile.mkdtemp(dir=cache_root)
    try:
        temp_path = osp.join(temp_root, "dl")
        if not quiet:
            msg = "Cached Downloading"
            if path:
                msg = "{}: {}".format(msg, path)
            else:
                msg = "{}...".format(msg)
            print(msg, file=sys.stderr)
        download(url, temp_path, quiet=quiet, proxy=proxy, speed=speed)
        with filelock.FileLock(lock_path):
            shutil.move(temp_path, path)
    finally:
        # mkdtemp leaves deletion to the caller: remove the scratch directory
        # on success (emptied by the move) as well as on failure (may hold a
        # partial download). ignore_errors keeps a cleanup problem from
        # masking the original exception.
        shutil.rmtree(temp_root, ignore_errors=True)

    if md5:
        assert_md5sum(path, md5, quiet=quiet)

    # postprocess
    if postprocess is not None:
        postprocess(path)

    return path