-
Notifications
You must be signed in to change notification settings - Fork 0
/
save_callback.py
31 lines (26 loc) · 1.06 KB
/
save_callback.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import matplotlib.pyplot as plt
import os
from keras.callbacks import Callback
class CustomCallback(Callback):
    """Keras callback that saves a loss plot and the model weights after every epoch.

    All files are written inside ``save_path``:

    * ``loss.png`` -- training and validation loss per epoch. It is
      overwritten each epoch with the full history so far.
    * ``model_weights_{epoch}_epoch.h5`` -- the model weights after that epoch.

    Args:
        model: Model whose weights are saved and which is evaluated on
            ``val_gen`` at the end of each epoch.
        save_path: Directory the plot and weight files are written to. It is
            created if it does not exist.
        val_gen: Validation data passed to ``model.evaluate_generator``.
            ``len(val_gen)`` is used as the number of evaluation steps, so it
            must support ``len()``.
    """

    def __init__(self, model, save_path, val_gen):
        super().__init__()
        # NOTE(review): Keras normally injects the model through set_model()
        # when the callback is passed to fit(). The explicit assignment is kept
        # so existing callers that rely on the constructor argument still work.
        self.model = model
        self.val_gen = val_gen
        self.save_path = save_path
        self.val_loss = []
        self.train_loss = []
        # Epoch indices actually observed. Plotting against these (instead of
        # range(epoch + 1)) keeps x/y lengths equal when training resumes with
        # initial_epoch > 0.
        self.epochs = []

    @staticmethod
    def _scalar_loss(result):
        """Return the loss from an evaluate result: a scalar, or [loss, *metrics]."""
        if isinstance(result, (list, tuple)):
            return result[0]
        return result

    def on_epoch_end(self, epoch, logs=None):
        """Record losses for ``epoch``, redraw ``loss.png`` and save the weights."""
        logs = logs or {}
        train_loss = logs['loss']
        # NOTE(review): evaluate_generator is deprecated in newer tf.keras in
        # favour of evaluate(); it is kept here for the Keras 2.x API this
        # file targets.
        result = self.model.evaluate_generator(self.val_gen, len(self.val_gen))

        # Append only after both values are available, so the three history
        # lists always stay the same length.
        self.train_loss.append(train_loss)
        self.val_loss.append(self._scalar_loss(result))
        self.epochs.append(epoch)

        os.makedirs(self.save_path, exist_ok=True)

        # Use a dedicated figure rather than the global "current" one, so a
        # caller's open figure is never cleared. Close it to free memory.
        fig, ax = plt.subplots()
        try:
            ax.plot(self.epochs, self.val_loss, label='val')
            ax.plot(self.epochs, self.train_loss, label='train')
            ax.set_ylabel('loss')
            ax.set_xlabel('epoch')
            ax.legend()
            fig.savefig(os.path.join(self.save_path, 'loss.png'))
        finally:
            plt.close(fig)

        self.model.save_weights(
            os.path.join(self.save_path, 'model_weights_{}_epoch.h5'.format(epoch)))