forked from Yacalis/celeba-classification
-
Notifications
You must be signed in to change notification settings - Fork 0
/
build_model.py
136 lines (122 loc) · 4.62 KB
/
build_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Feb 20 15:14:00 2018
@author: Yacalis
"""
from keras.layers import Dense, Dropout, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.normalization import BatchNormalization
from keras.models import Sequential
from Optimizer import Optimizer
def build_model(input_dim: tuple, config: object, model_type: str) -> object:
    """Build and compile the network architecture selected by ``model_type``.

    Args:
        input_dim: shape of a single input sample. It is forwarded to the
            first ``Conv2D`` layer as ``input_shape``, so it must be a shape
            tuple (not an int). The axis order presumably follows the Keras
            ``image_data_format`` setting; confirm against the data loader.
        config: run configuration, forwarded unchanged to the chosen builder
            (and from there to ``compile_model``).
        model_type: ``'complex'``, ``'simple'`` or ``'celeba'``. Any other
            value falls back to the single-convolution baseline model.

    Returns:
        The compiled Keras model returned by the selected builder.
    """
    print('Building model...')
    # Dispatch table built at call time so it always sees the module-level
    # builder functions; unknown names deliberately fall back to the baseline.
    builders = {
        'complex': build_model_complex,
        'simple': build_model_simple,
        'celeba': build_model_celeba,
    }
    builder = builders.get(model_type, build_model_single_convo)
    return builder(input_dim, config)
def build_model_single_convo(input_dim, config):
    """Build and compile the single-convolution baseline network.

    Layer stack: one 512-filter 32x32 ReLU convolution, an 8x8 / stride-8
    max-pool, a flatten, a 1024-unit ReLU dense layer, and a 2-unit softmax
    output.

    Args:
        input_dim: shape tuple passed to the convolution as ``input_shape``.
        config: run configuration, forwarded to ``compile_model``.

    Returns:
        The compiled Keras ``Sequential`` model.
    """
    net = Sequential()
    stack = (
        Conv2D(512, kernel_size=(32, 32), activation='relu',
               input_shape=input_dim),
        MaxPooling2D(pool_size=(8, 8), strides=(8, 8)),
        Flatten(),
        Dense(units=1024, activation='relu'),
        Dense(units=2, activation='softmax'),
    )
    for layer in stack:
        net.add(layer)
    return compile_model(net, config)
def build_model_simple(input_dim, config):
    """Build and compile a small two-stage convolutional classifier.

    Two identical stages of (32-filter 3x3 ReLU convolution -> 2x2 max-pool)
    are followed by a flatten, a 128-unit ReLU dense layer and a 2-unit
    softmax output.

    Args:
        input_dim: shape tuple passed to the first convolution as
            ``input_shape``.
        config: run configuration, forwarded to ``compile_model``.

    Returns:
        The compiled Keras ``Sequential`` model.
    """
    net = Sequential()
    for stage in range(2):
        # Only the very first layer of a Sequential model declares its input.
        shape_kwargs = {'input_shape': input_dim} if stage == 0 else {}
        net.add(Conv2D(32, kernel_size=(3, 3), activation='relu',
                       **shape_kwargs))
        net.add(MaxPooling2D(pool_size=(2, 2)))
    net.add(Flatten())
    net.add(Dense(units=128, activation='relu'))
    net.add(Dense(units=2, activation='softmax'))
    return compile_model(net, config)
def build_model_complex(input_dim, config):
    """Build and compile the deeper three-stage convolutional classifier.

    Each of three stages is a ReLU convolution (default 'valid' padding),
    a 3x3 / stride-2 max-pool and batch normalisation. The stages are
    followed by two 512-unit ReLU dense layers, each with 50% dropout, and
    a 2-unit softmax output.

    Args:
        input_dim: shape tuple passed to the first convolution as
            ``input_shape``.
        config: run configuration, forwarded to ``compile_model``.

    Returns:
        The compiled Keras ``Sequential`` model.
    """
    net = Sequential()
    # (filters, kernel_size, strides) for each convolutional stage.
    conv_stages = (
        (96, (7, 7), (4, 4)),
        (256, (5, 5), (1, 1)),
        (384, (3, 3), (1, 1)),
    )
    for position, (filters, kernel, stride) in enumerate(conv_stages):
        # Only the very first layer of a Sequential model declares its input.
        shape_kwargs = {'input_shape': input_dim} if position == 0 else {}
        net.add(Conv2D(filters,
                       kernel_size=kernel,
                       strides=stride,
                       activation='relu',
                       **shape_kwargs))
        net.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))
        net.add(BatchNormalization())
    net.add(Flatten())
    for _ in range(2):
        net.add(Dense(512, activation='relu'))
        net.add(Dropout(0.5))
    net.add(Dense(2, activation='softmax'))
    return compile_model(net, config)
def build_model_celeba(input_dim, config):
    """Build and compile the 'celeba' network with a 3-unit sigmoid head.

    Each of three stages is a 'same'-padded ReLU convolution followed by a
    3x3 / stride-2 max-pool (no batch normalisation). The stages are
    followed by two 512-unit ReLU dense layers, each with 50% dropout, and
    three independent sigmoid outputs. ``compile_model`` selects a
    mean-squared-error loss for this head only when
    ``config.complexity == 'celeba'``.

    Args:
        input_dim: shape tuple passed to the first convolution as
            ``input_shape``.
        config: run configuration, forwarded to ``compile_model``.

    Returns:
        The compiled Keras ``Sequential`` model.
    """
    net = Sequential()
    # (filters, kernel_size, strides) for each convolutional stage.
    conv_stages = (
        (96, (7, 7), (4, 4)),
        (256, (5, 5), (1, 1)),
        (384, (3, 3), (1, 1)),
    )
    for position, (filters, kernel, stride) in enumerate(conv_stages):
        # Only the very first layer of a Sequential model declares its input.
        shape_kwargs = {'input_shape': input_dim} if position == 0 else {}
        net.add(Conv2D(filters,
                       kernel_size=kernel,
                       strides=stride,
                       activation='relu',
                       padding='same',
                       **shape_kwargs))
        net.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))
    net.add(Flatten())
    for _ in range(2):
        net.add(Dense(512, activation='relu'))
        net.add(Dropout(0.5))
    net.add(Dense(3, activation='sigmoid'))
    return compile_model(net, config)
def compile_model(model: object, config: object, loss: str = None) -> object:
    """Compile a freshly built Keras model and print its summary.

    Args:
        model: the uncompiled Keras model produced by one of the builders.
        config: run configuration. ``config.optimizer`` is handed to the
            project ``Optimizer`` wrapper to obtain the Keras optimizer.
            ``config.complexity`` selects the default loss when ``loss`` is
            not given.
        loss: optional explicit Keras loss. This lets a caller pair the loss
            with the output head it actually built. When omitted, the
            original rule applies: ``'mean_squared_error'`` if
            ``config.complexity == 'celeba'``, otherwise
            ``'binary_crossentropy'``.

    Returns:
        The same model object, now compiled.
    """
    print('Finished building model')
    print('Compiling model...')
    # set up metrics and optimizer
    # NOTE(review): Keras resolves the 'accuracy' string from the loss and
    # output shape. For the 3-unit sigmoid head trained with MSE it may pick
    # categorical (argmax) accuracy; confirm that is the intended metric.
    metrics = ['accuracy']
    optimizer = Optimizer(config.optimizer).optimizer
    if loss is None:
        # Default loss is keyed on the configured complexity, not on the
        # architecture passed in; callers that build a model independently of
        # config.complexity should pass `loss` explicitly.
        if config.complexity == 'celeba':
            loss = 'mean_squared_error'
        else:
            loss = 'binary_crossentropy'
    # compile model
    model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
    print('Finished compiling')
    model.summary()
    return model