-
Notifications
You must be signed in to change notification settings - Fork 0
/
_keras_callbacks.py
115 lines (100 loc) · 4.06 KB
/
_keras_callbacks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Private Keras callbacks that make :py:class:`~scalarstop.model.KerasModel` work."""
from typing import Any, Dict, Mapping, Optional
import tensorflow as tf
def logs_as_floats(logs: Optional[Mapping[str, Any]]) -> Dict[str, float]:
    """Convert Keras metric log values to floats.

    Metric values arriving from Keras callbacks may be unserializable
    types such as ``tf.Tensor``; coercing through :py:func:`float`
    yields plain Python floats.

    Args:
        logs: Mapping of metric name to metric value. Keras documents
            that the ``logs`` argument of callback hooks may be ``None``,
            so ``None`` is accepted and treated as an empty mapping.

    Returns:
        A new dict mapping each metric name to ``float(value)``.
    """
    if logs is None:
        return {}
    return {name: float(value) for name, value in logs.items()}
class BatchLoggingCallback(tf.keras.callbacks.Callback):
    """A Keras callback to handle some of the bookkeeping."""

    def __init__(self, *, scalarstop_model, logger):
        """Store the ScalarStop model wrapper and the logger used for per-batch messages."""
        super().__init__()
        self._scalarstop_model = scalarstop_model
        self._logger = logger

    def on_train_batch_end(  # pylint: disable=signature-differs
        self, batch: int, logs: Dict[str, Any]
    ) -> None:
        """Enable issuing log messages at the end of every batch."""
        super().on_train_batch_end(batch=batch, logs=logs)
        # Hoist repeated attribute lookups; both the message args and the
        # structured `extra` payload reference the same two values.
        model = self._scalarstop_model
        self._logger.info(
            "Trained batch %s for epoch %s for model %s",
            batch,
            model.current_epoch,
            model.name,
            extra=dict(
                current_batch=batch,
                current_epoch=model.current_epoch,
                model_name=model.name,
                # Metric values are coerced to floats so the payload is serializable.
                training_metrics=logs_as_floats(logs),
            ),
        )
class EpochCallback(tf.keras.callbacks.Callback):
    """A Keras callback to handle some of the bookkeeping."""

    def __init__(
        self,
        *,
        scalarstop_model,
        logger,
        steps_per_epoch: Optional[int] = None,
        validation_steps_per_epoch: Optional[int] = None,
        models_directory: Optional[str] = None,
        train_store=None,
        log_epochs: bool = False,
    ):
        """Capture everything needed for end-of-epoch bookkeeping.

        Args:
            scalarstop_model: The ScalarStop model wrapper whose history,
                name, and epoch counter are read and updated.
            logger: Logger used when ``log_epochs`` is enabled.
            steps_per_epoch: Forwarded to the train store, if any.
            validation_steps_per_epoch: Forwarded to the train store, if any.
            models_directory: When set, the model is saved here after
                every epoch.
            train_store: When set, per-epoch metrics are inserted into it.
            log_epochs: When ``True``, an info-level message is emitted
                at the end of every epoch.
        """
        super().__init__()
        self._scalarstop_model = scalarstop_model
        self._logger = logger
        self._models_directory = models_directory
        self._train_store = train_store
        self._log_epochs = log_epochs
        self._steps_per_epoch = steps_per_epoch
        self._validation_steps_per_epoch = validation_steps_per_epoch

    def on_epoch_end(  # pylint: disable=signature-differs
        self, epoch: int, logs: Dict[str, Any]
    ) -> None:
        """
        Enable various tasks at the end of every epoch, such as:
        - saving the model to the filesystem.
        - saving epoch metrics to the TrainStore.
        - logging epoch metrics to a Python logger.
        """
        super().on_epoch_end(epoch=epoch, logs=logs)
        # Coerce metric values to floats -- they may arrive as
        # unserializable types like tf.Tensor.
        float_logs = logs_as_floats(logs)
        model = self._scalarstop_model
        # Append this epoch's value of every metric to the model history,
        # creating the metric's list on first sight.
        for metric_name, metric_value in float_logs.items():
            model._history.setdefault(metric_name, []).append(metric_value)
        # Persist the model to the filesystem, when configured to.
        if self._models_directory:
            model.save(self._models_directory)
        # Record this epoch's metrics in the train store, when configured to.
        if self._train_store:
            self._train_store.insert_model_epoch(
                epoch_num=model.current_epoch,
                model_name=model.name,
                metrics=float_logs,
                steps_per_epoch=self._steps_per_epoch,
                validation_steps_per_epoch=self._validation_steps_per_epoch,
                ignore_existing=True,
            )
        # Emit a structured log line for the epoch, when configured to.
        if self._log_epochs:
            self._logger.info(
                "Trained epoch %s for model %s",
                model.current_epoch,
                model.name,
                extra=dict(
                    current_epoch=model.current_epoch,
                    model_name=model.name,
                    training_metrics=float_logs,
                ),
            )