-
Notifications
You must be signed in to change notification settings - Fork 0
/
embeddings_extraction.py
124 lines (104 loc) · 2.99 KB
/
embeddings_extraction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
from tensorflow.keras.callbacks import EarlyStopping
from modules.utils.data_utils.data_handlers import DataGenerator
from modules.utils.model_utils.metrics_losses import smape_k
from modules.utils.general_utils.utilities import load_full_model
from modules.utils.general_utils.utilities import save_full_model
from modules.utils.general_utils.embedding_handlers import extract_embedding
# Directory holding one file per training batch of continuous features.
# NOTE(review): Windows-style separators are hard-coded; assumes the script
# is run on Windows from the project root.
FEATURES_PATH = 'data\\train\\inputs\\continuous_features'

# Names of the input tensors the data generators feed to the model.
INPUTS = [
    'continuous_features',
    'context'
]

# Names of the target tensors the data generators yield as labels.
TARGETS = [
    'tar_delta_sessions',
    'tar_active_time',
    'tar_session_time',
    'tar_activity',
    'tar_sessions',
]

# One batch index per file found in FEATURES_PATH.
BTCH = list(range(len(os.listdir(FEATURES_PATH))))

# Deterministic split: every 5th batch is held out for testing, then every
# 20th of the remaining batches for validation; the rest is training data.
TS_BTCH = BTCH[0::5]
_ts_set = set(TS_BTCH)  # set for O(1) membership instead of O(n) list scans
TR_BTCH = [btch for btch in BTCH if btch not in _ts_set]
VL_BTCH = TR_BTCH[0::20]
_vl_set = set(VL_BTCH)
TR_BTCH = [btch for btch in TR_BTCH if btch not in _vl_set]
###############################################################################
# Map of model name -> name of the layer whose activations are taken as the
# engagement embedding for that model.
encoders = {
    'td_mlp': 'sp_dropout_5_global_features'
}
# Populated below: '<model_name>_eng_emb' -> {'encoder': ..., 'inp_names': [...]}
encoder_objs = {}
###############################################################################
for model_name, out_layer in encoders.items():
    print(f'Extracting embedding for {model_name}')

    # Load and compile the saved model, using SMAPE on every output head for
    # both the training loss and the reported metric.
    model = load_full_model(
        name=model_name,
        optimizer='adam',
        loss={
            'output_absence_act': smape_k,
            'output_active_act': smape_k,
            'output_sess_time_act': smape_k,
            'output_activity_act': smape_k,
            'output_sess_act': smape_k
        },
        metrics={
            'output_absence_act': smape_k,
            'output_active_act': smape_k,
            'output_sess_time_act': smape_k,
            'output_activity_act': smape_k,
            'output_sess_act': smape_k
        }
    )

    # Generators over the training / validation batch splits.
    TR_GEN = DataGenerator(
        list_batches=TR_BTCH,
        inputs=INPUTS,
        targets=TARGETS,
        train=True,
        shuffle=True
    )
    VAL_GEN = DataGenerator(
        list_batches=VL_BTCH,
        inputs=INPUTS,
        targets=TARGETS,
        train=True,
        shuffle=True
    )

    # Stop once val_loss has not improved by >= 1e-4 for 10 epochs and roll
    # back to the best weights seen.
    ES = EarlyStopping(
        monitor='val_loss',
        min_delta=0.0001,
        patience=10,
        verbose=1,
        mode='auto',
        restore_best_weights=True
    )
    model.fit(
        x=TR_GEN,
        validation_data=VAL_GEN,
        epochs=200,
        verbose=2,
        callbacks=[ES],
        workers=8,
        max_queue_size=100
    )

    # Save the trained model.
    # BUG FIX: the original passed the literal 'results\\saved_trained_models\\{}'
    # with the '{}' placeholder never formatted, so every model was written to a
    # directory literally named '{}' (successive models overwriting each other).
    save_full_model(
        model,
        path=f'results\\saved_trained_models\\{model_name}'
    )

    # Truncate the model at `out_layer` to obtain the embedding encoder,
    # persist it, and register it for the extraction step below.
    engagement_encoder = model.get_encoder(
        out_layer=out_layer
    )
    engagement_encoder.save(
        f'results\\saved_encoders\\{model_name}_encoder'
    )
    encoder_objs[f'{model_name}_eng_emb'] = {
        'encoder': engagement_encoder,
        'inp_names': [
            'continuous_features',
            'context',
        ]
    }
# Run the held-out test batches through every collected encoder to produce
# their embeddings. NOTE(review): semantics depend on the project helper
# extract_embedding — presumably it loads the named inputs for each batch
# from `root` and feeds them to each encoder; confirm against its definition.
embeddings = extract_embedding(
    encoder_objs=encoder_objs,
    root='data\\train\\inputs',
    batches=TS_BTCH,
)