-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
151 lines (137 loc) · 4.82 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from experiment import Experiment
import torch
# from hparam_tuning import Hyperparameter_Grid_Search
# ==============================================================================
# Import your models here:
from models.alexnet_cifar import Alexnet_cifar
from models.alexnet_cifar_gr import Alexnet_cifar_gr
from models.alexnet_cifar_gr_compare import Alexnet_cifar_gr_compare
from models.alexnet_cifar_gr_shuffle import Alexnet_cifar_gr_shuffle
from models.alexnet_flower import Alexnet_flower
from models.alexnet_flower_gr import Alexnet_flower_gr
from models.alexnet_flower_gr_compare import Alexnet_flower_gr_compare
from models.alexnet_flower_gr_shuffle import Alexnet_flower_gr_shuffle
from models.alexnet_miniImage import Alexnet_miniImage
from models.alexnet_miniImage_gr import Alexnet_miniImage_gr
from models.alexnet_miniImage_gr_shuffle import Alexnet_miniImage_gr_shuffle
from models.vgg16_cifar_gr_compare import vgg16_cifar_gr_compare
from models.vgg16_miniImage_gr_compare import vgg16_miniImage_gr_compare
from models.vgg16net_flower import vgg16net_flower
from models.vgg16net_flower_gr import vgg16_flowers_gr
from models.vgg16_cifar_gr import vgg16_cifar_gr
from models.vgg16net_cifar import vgg16net_cifar
from models.vgg16net_flower_gr_compare import vgg16_flowers_gr_compare
from models.vgg16net_miniImage import vgg16net_miniImage
from models.vgg16_miniImage_gr import vgg16_miniImage_gr
def _run_experiment(model, model_name, classes, *, batch_size, num_of_epoch,
                    optimizer, learning_rate):
    """Train, validate and evaluate ``model`` with a single ``Experiment``.

    Args:
        model: Network instance to train.
        model_name: Run name passed to ``Experiment`` as ``model_name``.
        classes: Tuple of class-name strings passed as ``classes``.
        batch_size: Passed to ``Experiment`` as ``batch_size``.
        num_of_epoch: Passed to ``Experiment`` as ``num_of_epoch``.
        optimizer: Optimizer name passed as ``optimizer`` (e.g. "Adam", "SGD").
        learning_rate: Passed to ``Experiment`` as ``learning_rate``.
    """
    exp = Experiment(
        classes=classes,
        model=model,
        model_name=model_name,
        batch_size=batch_size,
        num_of_epoch=num_of_epoch,
        optimizer=optimizer,
        learning_rate=learning_rate,
    )
    exp.train_val_evaluate()


def main():
    """Run every experiment listed in ``runs`` one after another."""
    torch.cuda.empty_cache()
    # NOTE(review): 'crocys' looks like a typo for 'crocus'. Left unchanged in
    # case this label is matched against data elsewhere -- confirm before fixing.
    flower = (
        'daffodil', 'snowdrop', 'lilyValley', 'bluebell', 'crocys', 'iris',
        'tigerlily', 'tulip', 'fritillary', 'sunflower', 'daisy', 'colts foot',
        'dandelion', 'cowslip', 'buttercup', 'wind flower', 'pansy'
    )
    cifar = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # ==========================================================================
    # Hyper-parameters shared by every run; override per run in ``runs`` below.
    # Please at least run the following parameters:
    #   batch_size (int): 4, 64, 128
    #   num_of_epoch (int): 100
    #   optimizer (str): "Adam", "SGD"
    #   learning_rate (float): 1e-4, 1e-3, 1e-2
    defaults = {
        "batch_size": 64,
        "num_of_epoch": 100,
        "optimizer": "Adam",
        "learning_rate": 2e-5,
    }

    # Each entry: (model factory, model_name, classes, overrides of ``defaults``).
    # The factory is only called when its run starts, so a model is not built
    # while the previous run's objects are still held by this function.
    # Comment/uncomment entries to choose which experiments run.
    # NOTE(review): the miniImage models are built with 100 outputs but receive
    # the 10 CIFAR class names (and the vgg16 miniImage run the 17 flower names).
    # Confirm how Experiment uses ``classes`` before trusting per-class output.
    runs = [
        (lambda: Alexnet_miniImage(100), "Alexnet_miniImage_4", cifar, {}),
        (lambda: Alexnet_miniImage_gr(100), "Alexnet_miniImage_gr_4", cifar, {}),
        # (lambda: Alexnet_flower_gr_compare(17), "Alexnet_flower_gr_compare", flower, {}),
        # (vgg16_cifar_gr_compare, "vgg16_cifar_gr_compare", cifar, {"num_of_epoch": 50}),
        # (vgg16_flowers_gr_compare, "vgg16_flowers_gr_compare", flower, {"batch_size": 32}),
        # (vgg16_miniImage_gr_compare, "vgg16_miniImage_gr_compare", flower, {}),
    ]

    for make_model, model_name, classes, overrides in runs:
        _run_experiment(make_model(), model_name, classes, **{**defaults, **overrides})
        # main() no longer references the finished model/Experiment here, so
        # cached CUDA blocks they used can be released before the next model is
        # built (assumes Experiment keeps no references of its own -- verify).
        torch.cuda.empty_cache()
    # ==========================================================================


if __name__ == '__main__':
    main()