-
Notifications
You must be signed in to change notification settings - Fork 7
/
model.py
117 lines (107 loc) · 7.14 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from __future__ import absolute_import
import torch
from torch import nn
from models import resnet, pre_act_resnet, wide_resnet, resnext, densenet
def generate_model(opt):
    """Build a 3D CNN described by ``opt`` and return it with its fine-tuning parameters.

    Args:
        opt: options object (e.g. an argparse namespace).
            Read for every model: ``mode`` ('score' or 'feature'),
            ``model_name``, ``model_depth``, ``n_classes``, ``sample_size``
            and ``sample_duration``.
            Family-specific: ``resnet_shortcut`` (every family except
            'densenet'), ``wide_resnet_k`` ('wideresnet') and
            ``resnext_cardinality`` ('resnext').
            ``ft_begin_index`` is read when the family module defines
            ``get_fine_tuning_parameters``.

    Returns:
        ``(model, policies)``. ``policies`` is the result of the family
        module's ``get_fine_tuning_parameters(model, opt.ft_begin_index)``.
        If that module does not define the helper, it is
        ``model.parameters()`` instead.

    Raises:
        ValueError: if ``opt.mode``, ``opt.model_name`` or ``opt.model_depth``
            is not supported. Explicit raises are used instead of ``assert``
            so the checks survive ``python -O``.
    """
    if opt.mode not in ('score', 'feature'):
        raise ValueError(
            "opt.mode must be 'score' or 'feature', got {!r}".format(opt.mode))
    # Forwarded to every constructor as last_fc: True for 'score', False for 'feature'.
    last_fc = opt.mode == 'score'

    # Depths for which each family module provides a constructor.
    supported_depths = {
        'resnet': (10, 18, 34, 50, 101, 152, 200),
        'preresnet': (18, 34, 50, 101, 152, 200),
        'wideresnet': (50,),
        'resnext': (50, 101, 152),
        'densenet': (121, 169, 201, 264),
    }
    if opt.model_name not in supported_depths:
        raise ValueError(
            'unsupported opt.model_name {!r}; expected one of {}'.format(
                opt.model_name, sorted(supported_depths)))
    depths = supported_depths[opt.model_name]
    if opt.model_depth not in depths:
        raise ValueError(
            'unsupported opt.model_depth {!r} for {!r}; expected one of {}'.format(
                opt.model_depth, opt.model_name, list(depths)))

    # Keyword arguments shared by every family's constructor.
    kwargs = {
        'num_classes': opt.n_classes,
        'sample_size': opt.sample_size,
        'sample_duration': opt.sample_duration,
        'last_fc': last_fc,
    }
    if opt.model_name == 'densenet':
        family = densenet
        builder_name = 'densenet{}'.format(opt.model_depth)
    else:
        # The ResNet-style families all name their constructors resnet<depth>
        # and all take a shortcut_type argument.
        builder_name = 'resnet{}'.format(opt.model_depth)
        kwargs['shortcut_type'] = opt.resnet_shortcut
        if opt.model_name == 'resnet':
            family = resnet
        elif opt.model_name == 'preresnet':
            family = pre_act_resnet
        elif opt.model_name == 'wideresnet':
            family = wide_resnet
            kwargs['k'] = opt.wide_resnet_k
        else:  # 'resnext' -- the only remaining validated name
            family = resnext
            kwargs['cardinality'] = opt.resnext_cardinality
    model = getattr(family, builder_name)(**kwargs)

    # Compute the parameter groups for every family, not only 'resnet', so the
    # return value is always bound.
    get_ft_params = getattr(family, 'get_fine_tuning_parameters', None)
    if get_ft_params is not None:
        policies = get_ft_params(model, opt.ft_begin_index)
    else:
        # NOTE(review): assumes model is a torch nn.Module. This fallback ignores
        # opt.ft_begin_index and exposes every parameter -- confirm intended default.
        policies = model.parameters()
    return model, policies