In [None]:
!pip install tensorflow tensorflow_datasets prometheus_client --progress-bar off

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import warnings
# Silence only deprecation-style noise. A blanket `filterwarnings('ignore')`
# also hides genuine warnings from TensorFlow/Keras (for example one Keras may
# emit when a softmax output is combined with a `from_logits=True` loss), so
# other warning categories are left visible.
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

# Hide TFDS progress bars.
tfds.disable_progress_bar()

# Load the MNIST train and test splits as (image, label) tuples
# (`as_supervised=True`) together with the dataset metadata (`with_info=True`).
# `shuffle_files=True` shuffles the order in which the underlying files are read.
(mnist_train, mnist_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)


def normalize_img(image, label):
    """Convert a ``uint8`` image to ``float32`` scaled into [0, 1].

    The label is passed through unchanged so this can be mapped directly
    over a dataset of ``(image, label)`` pairs.
    """
    as_float = tf.cast(image, tf.float32)
    return as_float / 255.0, label


# Training pipeline: normalize in parallel, cache the normalized examples,
# shuffle with a buffer covering the whole train split, batch, and prefetch
# so input preparation overlaps with training steps.
mnist_train = (
    mnist_train
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(128)
    .prefetch(tf.data.AUTOTUNE)
)

# Evaluation pipeline: normalize, batch, then cache the batches and prefetch.
# No shuffling is applied to the test split.
mnist_test = (
    mnist_test
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)


# Small CNN classifier for 28x28 single-channel images with 10 output classes.
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    # NOTE(review): this convolution has no activation (it is linear); kept
    # as in the original design.
    tf.keras.layers.Conv2D(filters=32, kernel_size=(2, 2)),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=128, activation=tf.nn.relu),
    tf.keras.layers.Dropout(rate=0.2),
    # The output layer applies softmax, so the model emits class
    # probabilities, not logits.
    tf.keras.layers.Dense(10, activation=tf.nn.softmax),
])


model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    # from_logits must be False because the final layer already outputs
    # softmax probabilities; with from_logits=True the loss would treat those
    # probabilities as logits and apply softmax a second time.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.summary()

import time
from prometheus_client import CollectorRegistry, Gauge, push_to_gateway
class SendMetrics(tf.keras.callbacks.Callback):
    """Keras callback that times epochs and training steps and pushes the
    timings to a Prometheus Pushgateway.

    At the end of every epoch it prints the epoch wall time and the mean
    training-step time for that epoch. It records both in two gauges, labelled
    with model name, framework, the time this callback was created, and the
    epoch index, and then pushes the whole registry to the Pushgateway.

    Args:
        pushgateway_url: Address of the Pushgateway to push metrics to.
        job: Job name used for the push.
    """

    def __init__(self,
                 pushgateway_url="http://prom-push-as-pushgateway.apps.zero.massopen.cloud",
                 job='jupyterhub_load'):
        # Callback.__init__ sets up attributes Keras relies on (e.g. the model
        # reference), so it must run before our own state is added.
        super().__init__()
        self.epoch_start_time = 0
        self.batchstart = 0
        self.batchtimes = []  # per-batch durations (seconds) for the current epoch
        # Wall-clock creation time; used as the "date" label on every sample.
        self.current_time = time.time()

        self.pushgateway_url = pushgateway_url
        self.job = job
        self.registry = CollectorRegistry()
        self.epoch_gauge = Gauge(name='epoch_duration_seconds', documentation='epoch_value is the metric itself, the stuff in the {}s are tags',labelnames=["model","framework","date","epoch"],registry=self.registry)
        self.step_gauge = Gauge(name='step_during_milliseconds', documentation='step_during_milliseconds is the metric itself, the stuff in the {}s are tags',labelnames=["model","framework","date","epoch"],registry=self.registry)

    def on_epoch_begin(self, epoch, logs=None):
        # Reset per-epoch state so the reported step time covers this epoch only.
        self.batchtimes = []
        self.epoch_start_time = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        epoch_seconds = time.perf_counter() - self.epoch_start_time
        if self.batchtimes:
            step_ms = sum(self.batchtimes) / len(self.batchtimes) * 1000
        else:
            # No training batches ran this epoch; avoid a ZeroDivisionError.
            step_ms = float('nan')

        value_formatter = "{0:.2f}"
        print("epoch={}, epoch time={}s, step time={}ms".format(
            epoch, value_formatter.format(epoch_seconds), value_formatter.format(step_ms)))
        self.add_metrics_data(epoch, epoch_seconds, step_ms)

    def on_train_batch_begin(self, batch, logs=None):
        self.batchstart = time.perf_counter()

    def on_train_batch_end(self, batch, logs=None):
        self.batchtimes.append(time.perf_counter() - self.batchstart)

    def add_metrics_data(self, epoch_num, epoch_value, step_value):
        """Record one epoch's timings in the gauges and push them.

        Args:
            epoch_num: Epoch index, used as the "epoch" label.
            epoch_value: Epoch duration in seconds (number or numeric string).
            step_value: Mean step duration in milliseconds (number or numeric string).
        """
        labels = ("mnist-minimal", "Tensorflow", self.current_time, epoch_num)
        self.epoch_gauge.labels(*labels).set(float(epoch_value))
        self.step_gauge.labels(*labels).set(float(step_value))
        try:
            push_to_gateway(self.pushgateway_url, job=self.job, registry=self.registry)
        except OSError as err:
            # Metrics reporting is best-effort: a network or Pushgateway error
            # should not abort training.
            print("warning: could not push metrics to {}: {}".format(self.pushgateway_url, err))
        
# Train for 5 epochs, validating on the test split, while SendMetrics reports
# epoch and step timings.
model.fit(
    x=mnist_train,
    validation_data=mnist_test,
    epochs=5,
    callbacks=[SendMetrics()],
)
