-
Notifications
You must be signed in to change notification settings - Fork 0
/
models_optimization.py
104 lines (83 loc) · 2.28 KB
/
models_optimization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import numpy as np
from tensorflow.keras.callbacks import EarlyStopping
from kerastuner.tuners import Hyperband
from modules.models.supervised.engagement_estimators import MelchiorModel
from modules.models.supervised.baselines import TimeDistributedENet
from modules.models.supervised.baselines import TimeDistributedMLP
from modules.utils.data_utils.data_handlers import DataGenerator
from modules.utils.general_utils.utilities import save_full_model
# Make the Graphviz binaries reachable via PATH (presumably for Keras model
# plotting — TODO confirm). This is a Windows install location, so it is only
# appended when the directory actually exists on this machine.
GRAPHVIZ_BIN = 'C:\\Program Files (x86)\\Graphviz2.38\\bin'
if os.path.isdir(GRAPHVIZ_BIN):
    os.environ['PATH'] += os.pathsep + GRAPHVIZ_BIN
##############################################################################
# Hyperparameter-tuning configuration.
# Directory whose entries are enumerated as tuning batches; built with
# os.path.join so the script is not tied to Windows path separators.
TUN_PATH = os.path.join('data', 'test', 'inputs', 'context')
VAL_FRAC = 0.2  # Fraction of batches held out as tuner validation data.
MAX_EPOCHS = 40  # Epoch budget handed to the tuner.
HB_ITERATIONS = 1  # Number of Hyperband iterations handed to the tuner.
N_FEATURES = 4  # Shared n_features argument for every candidate model.

# Randomly permute one index per directory entry.
# NOTE(review): assumes each entry in TUN_PATH corresponds to exactly one
# batch — confirm against DataGenerator (stray files would inflate the count).
BTCH = np.random.permutation(len(os.listdir(TUN_PATH)))
if len(BTCH) < 2:
    raise ValueError(
        'Need at least 2 batches in {!r} to build a train/validation split; '
        'found {}.'.format(TUN_PATH, len(BTCH))
    )
# Number of validation batches, clamped to [1, len(BTCH) - 1] so neither split
# is empty (the tuner optimizes 'val_loss', which needs validation data).
VAL_CUT = min(max(int(VAL_FRAC * len(BTCH)), 1), len(BTCH) - 1)
# Split on a positive index: a negative slice such as BTCH[:-VAL_CUT] would
# silently yield an empty training set whenever VAL_CUT is 0.
SPLIT = len(BTCH) - VAL_CUT
TU_BTCH = BTCH[:SPLIT]
VAL_TU_BTCH = BTCH[SPLIT:]

# Input and target keys requested from every DataGenerator.
INPUTS = [
    'continuous_features',
    'context'
]
TARGETS = [
    'tar_delta_sessions',
    'tar_active_time',
    'tar_session_time',
    'tar_activity',
    'tar_sessions'
]

# Candidate models to tune, keyed by a short name; the first three characters
# of each key are used to build the tuner project name.
MODELS = {
    'enet_td': TimeDistributedENet(
        n_features=N_FEATURES,
        adjust_for_env=False
    ),
    'mlp_td': TimeDistributedMLP(
        n_features=N_FEATURES,
        adjust_for_env=False
    ),
    'melchior': MelchiorModel(
        n_features=N_FEATURES,
        adjust_for_env=False
    )
}

# Generators over the two disjoint index subsets produced above.
TU_GEN = DataGenerator(
    list_batches=TU_BTCH,
    inputs=INPUTS,
    targets=TARGETS,
    train=True,
    shuffle=True
)
# NOTE(review): validation uses the same train/shuffle flags as the training
# generator; confirm that is intended for DataGenerator's validation use.
VAL_TU_GEN = DataGenerator(
    list_batches=VAL_TU_BTCH,
    inputs=INPUTS,
    targets=TARGETS,
    train=True,
    shuffle=True
)
##############################################################################
# Tune every candidate model with Hyperband, then persist it.
#
# Tuner project names are derived from the first three characters of each
# model key. Verify they are distinct before any (long-running) tuning starts:
# two models sharing a project name would share one trial directory.
# Presumably these kwargs reach the Keras Tuner constructor, which by default
# reloads an existing project — TODO confirm in model.tune.
PROJECT_NAMES = {name: '{}_hb'.format(name[:3]) for name in MODELS}
if len(set(PROJECT_NAMES.values())) != len(PROJECT_NAMES):
    raise ValueError(
        'Model keys must have distinct 3-character prefixes; '
        'got project names {}'.format(PROJECT_NAMES)
    )

for name, model in MODELS.items():
    # A new EarlyStopping instance for each model. It watches val_loss and
    # rolls back to the best weights once 5 epochs pass without improvement.
    ES = EarlyStopping(
        monitor='val_loss',
        min_delta=0.0001,
        patience=5,
        verbose=1,
        mode='auto',
        restore_best_weights=True
    )
    model.tune(
        tuner=Hyperband,
        generator=TU_GEN,
        verbose=2,
        validation_data=VAL_TU_GEN,
        epochs=MAX_EPOCHS,
        max_epochs=MAX_EPOCHS,
        hyperband_iterations=HB_ITERATIONS,
        objective='val_loss',
        callbacks=[ES],
        directory='o',
        project_name=PROJECT_NAMES[name]
    )
    save_full_model(model=model)