/
fixtures.py
90 lines (73 loc) · 2.39 KB
/
fixtures.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Fixtures for ScalarStop tests"""
import os
import unittest
import tensorflow as tf
import scalarstop as sp
# Skip decorator for tests that need a real, external train-store database.
# The test runs only when TRAIN_STORE_CONNECTION_STRING is set to a non-empty
# value; the variable is read once, at import time of this module.
requires_external_database = unittest.skipUnless(
    condition=os.getenv("TRAIN_STORE_CONNECTION_STRING", False),
    reason="External database connection string was not supplied.",
)
# Skip decorator for tests that rely on SQLite's JSON1 extension. Availability
# is probed once, at import time, via scalarstop's train_store helper; tests
# run only when that probe reports the extension as enabled.
requires_sqlite_json = unittest.skipUnless(
    condition=sp.train_store._sqlite_json_enabled(),
    reason="The SQLite3 JSON1 extension is not enabled in this Python installation.",
)
class MyDataBlob(sp.DataBlob):
    """A toy DataBlob whose training/validation/test splits are random noise.

    Every split is produced by a separate call to ``_tfdata``, so each call
    draws new random values rather than sharing one underlying sample.
    """

    @sp.dataclass
    class Hyperparams:
        """Dimensions of the generated data."""

        rows: int  # number of (features, label) elements in each split
        cols: int  # length of each feature vector

    def _tfdata(self):
        """Return a ``tf.data.Dataset`` of ``(features, label)`` pairs.

        Features are ``float32`` vectors of length ``cols`` and labels are
        ``float32`` scalars; both come from ``tf.random.uniform`` with its
        default bounds. The dataset yields ``rows`` elements.
        """
        num_rows = self.hyperparams.rows
        num_cols = self.hyperparams.cols

        # Draw features before labels so the random-op order matches the
        # original implementation.
        features = tf.random.uniform(shape=(num_rows, num_cols), dtype=tf.float32)
        labels = tf.random.uniform(shape=(num_rows,), dtype=tf.float32)

        feature_ds = tf.data.Dataset.from_tensor_slices(features)
        label_ds = tf.data.Dataset.from_tensor_slices(labels)
        return tf.data.Dataset.zip((feature_ds, label_ds))

    def set_training(self):
        """Return a freshly generated random dataset for the training split."""
        return self._tfdata()

    def set_validation(self):
        """Return a freshly generated random dataset for the validation split."""
        return self._tfdata()

    def set_test(self):
        """Return a freshly generated random dataset for the test split."""
        return self._tfdata()
class MyModelTemplate(sp.ModelTemplate):
    """A minimal two-layer Keras classifier template used by the tests."""

    @sp.dataclass
    class Hyperparams:
        """Settings forwarded to layer construction and ``model.compile``."""

        layer_1_units: int  # width of the first Dense layer
        optimizer: str = "adam"  # passed as-is to model.compile(optimizer=...)
        loss: str = "binary_crossentropy"  # passed as-is to model.compile(loss=...)

    def new_model(self):
        """Build and compile a new, zero-initialized two-layer Sequential model.

        The first Dense layer has ``layer_1_units`` units and no explicit
        activation. The second is a single sigmoid unit. Both layers use
        all-zero kernel and bias initializers, presumably so the starting
        weights are identical on every run (NOTE(review): confirm intent).

        Returns:
            A compiled ``tf.keras.Sequential`` model. It tracks binary
            accuracy, precision and recall metrics.
        """
        zero_init = dict(kernel_initializer="zeros", bias_initializer="zeros")

        # Layers are created in the same order as before, so Keras'
        # auto-generated layer names are unchanged.
        hidden_layer = tf.keras.layers.Dense(
            units=self.hyperparams.layer_1_units,
            **zero_init,
        )
        output_layer = tf.keras.layers.Dense(
            units=1,
            activation="sigmoid",
            **zero_init,
        )
        model = tf.keras.Sequential(layers=[hidden_layer, output_layer])

        tracked_metrics = [
            tf.keras.metrics.BinaryAccuracy(name="binary_accuracy"),
            tf.keras.metrics.Precision(name="precision"),
            tf.keras.metrics.Recall(name="recall"),
        ]
        model.compile(
            optimizer=self.hyperparams.optimizer,
            loss=self.hyperparams.loss,
            metrics=tracked_metrics,
        )
        return model