-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_template.py
73 lines (59 loc) · 2.46 KB
/
model_template.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
A class that builds models with hyperparameters.

A :py:class:`ModelTemplate` is a description of how to create a
compiled machine learning model and the hyperparameters that the
model depends on.

Subclass :py:class:`ModelTemplate` to describe the *architecture*
of your model and the hyperparameters that are used to construct
the model (and the model's optimizer).

Then, pass an instance of your :py:class:`ModelTemplate` subclass
to a subclass of :py:class:`~scalarstop.model.Model` to train an
instance of a machine learning model created from your
:py:class:`ModelTemplate`.

>>> import tensorflow as tf
>>> import scalarstop as sp
>>>
>>> class small_dense_10_way_classifier_v1(sp.ModelTemplate):
...     @sp.dataclass
...     class Hyperparams(sp.HyperparamsType):
...         hidden_units: int
...         optimizer: str = "adam"
...
...     def new_model(self):
...         model = tf.keras.Sequential(
...             layers=[
...                 tf.keras.layers.Flatten(input_shape=(28, 28)),
...                 tf.keras.layers.Dense(
...                     units=self.hyperparams.hidden_units,
...                     activation="relu",
...                 ),
...                 tf.keras.layers.Dense(units=10)
...             ],
...             name=self.name,
...         )
...         model.compile(
...             optimizer=self.hyperparams.optimizer,
...             loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
...             metrics=["accuracy"],
...         )
...         return model
>>> model_template = small_dense_10_way_classifier_v1(hyperparams=dict(hidden_units=20))
>>> model_template.name
'small_dense_10_way_classifier_v1-zc9r3do1baeeffafanjnjmou'
"""
from typing import Any
from scalarstop._single_namespace import SingleNamespace
from scalarstop.exceptions import IsNotImplemented
class ModelTemplate(SingleNamespace):
    """
    Describes machine learning model architectures and hyperparameters.

    Used to generate new machine learning model objects that are passed
    into :py:class:`~scalarstop.model.Model` objects.
    """

    # Class-level default of ``None``. Nothing in this class reads or
    # writes it; presumably a subclass or a collaborator stores a model
    # instance here -- NOTE(review): confirm against the rest of the package.
    _model = None

    def __repr__(self) -> str:
        """Return ``<sp.ModelTemplate NAME>``, where NAME is ``self.name``."""
        # ``name`` is not defined in this class; it presumably comes from
        # the ``SingleNamespace`` base -- TODO confirm.
        template_name = self.name
        return f"<sp.ModelTemplate {template_name}>"

    def new_model(self) -> Any:
        """
        Build and return a compiled model from the current hyperparameters.

        Subclasses must override this method. Every call has to construct
        a brand-new model object rather than handing back one that was
        created earlier.

        Raises:
            IsNotImplemented: Always, because this base implementation is
                only a placeholder for subclasses to replace.
        """
        raise IsNotImplemented("ModelTemplate.new_model()")