-
Notifications
You must be signed in to change notification settings - Fork 152
/
Copy pathdownload_model.py
57 lines (50 loc) · 1.84 KB
/
download_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import argparse
import os
import wget
# Tasks accepted by ``check`` for each value of ``--model``.
_SUPPORTED_TASKS = {
    "pix2pix": (
        "edges2shoes-r",
        "map2sat",
        "cityscapes",
        "cityscapes_fast",
        "edges2shoes-r_fast",
        "map2sat_fast",
    ),
    "cycle_gan": ("horse2zebra", "horse2zebra_fast"),
    "gaugan": ("cityscapes", "cityscapes_fast", "coco_fast"),
    "munit": ("edges2shoes-r_fast",),
}


def check(opt):
    """Validate that ``opt.task`` is a known task for ``opt.model``.

    Args:
        opt: Parsed options exposing ``model`` and ``task`` string attributes.

    Raises:
        NotImplementedError: If ``opt.model`` is not one of the known models.
        AssertionError: If ``opt.task`` is not listed for ``opt.model``.
    """
    tasks = _SUPPORTED_TASKS.get(opt.model)
    if tasks is None:
        raise NotImplementedError("Unsupported model [%s]!" % opt.model)
    if opt.task not in tasks:
        # Raised explicitly rather than via ``assert`` so validation still runs
        # under ``python -O``; AssertionError is kept for backward compatibility.
        raise AssertionError(
            "Unsupported task [%s] for model [%s]; choose one of: %s"
            % (opt.task, opt.model, ", ".join(tasks))
        )
def download(path):
url = "https://huggingface.co/mit-han-lab/gan-compression/resolve/main/" + path
dir = os.path.dirname(path)
os.makedirs(dir, exist_ok=True)
wget.download(url, path)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download a pretrained model.")
    parser.add_argument(
        "--stage",
        type=str,
        default="compressed",
        choices=["full", "mobile", "distill", "supernet", "finetune", "compressed", "legacy"],
        help="specify the stage you want to download",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="pix2pix",
        choices=["pix2pix", "cycle_gan", "gaugan", "munit"],
        help="specify the model you want to download",
    )
    parser.add_argument(
        "--task",
        type=str,
        default=None,
        help="specify the task you want to download (defaults to a task valid for --model)",
    )
    opt = parser.parse_args()
    if opt.task is None:
        # The former fixed default ("horse2zebra") is only valid for cycle_gan,
        # so the default --model (pix2pix) always failed check(). Choose a
        # per-model default instead; explicit --task values behave as before.
        opt.task = {
            "pix2pix": "edges2shoes-r",
            "cycle_gan": "horse2zebra",
            "gaugan": "cityscapes",
            "munit": "edges2shoes-r_fast",
        }[opt.model]
    check(opt)
    stage_dir = os.path.join("pretrained", opt.model, opt.task, opt.stage)
    download(os.path.join(stage_dir, "latest_net_G.pth"))
    # The compressed/legacy stages fetch only the G checkpoint; every other
    # stage also fetches the D checkpoint (presumably the discriminator).
    if opt.stage not in ("compressed", "legacy"):
        download(os.path.join(stage_dir, "latest_net_D.pth"))