In [None]:
#default_exp train

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up automatically before each cell executes.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#export
from unoai.imports import *

In [None]:
#export
def train_model(train_ds: tf.data.Dataset, test_ds: tf.data.Dataset,
                epochs: int, model_fn: Callable,
                opt_fn: Callable, loss_fn: Callable, model: tf.keras.Model=None) -> tf.keras.Model:
    """Train a model with a hand-written loop and evaluate it after every epoch.

    Args:
        train_ds: Iterable yielding `(images, labels)` batches used for optimisation.
        test_ds: Iterable yielding `(images, labels)` batches used for evaluation.
        epochs: Number of full passes over `train_ds`.
        model_fn: Zero-argument factory; called to build the model only when
            `model` is None.
        opt_fn: Zero-argument factory returning the optimizer (it must provide
            `apply_gradients`).
        loss_fn: Called as `loss_fn(labels, predictions, reduction=...)`; it has to
            accept the `reduction` keyword argument.
        model: Optional existing model to continue training instead of building a
            new one via `model_fn`.

    Returns:
        The trained model (the same object as `model` when one was supplied).

    Notes:
        The value passed to the loss metrics is whatever `loss_fn` returns with
        `Reduction.SUM` -- presumably a per-batch summed loss, in which case the
        logged loss is the mean of per-batch sums rather than a per-sample mean
        (TODO confirm against `loss_fn`).
    """
    if model is None: model = model_fn()
    opt = opt_fn()
    # todo: standardize metrics
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

    @tf.function
    def train_step(images, labels):
        with tf.GradientTape() as tape:
            # training=True puts layers whose behaviour differs between training
            # and inference (e.g. Dropout, BatchNormalization) into training mode.
            predictions = model(images, training=True)
            loss = loss_fn(labels, predictions, reduction=tf.compat.v1.losses.Reduction.SUM)
        gradients = tape.gradient(loss, model.trainable_variables)
        opt.apply_gradients(zip(gradients, model.trainable_variables))
        train_loss(loss)
        train_accuracy(labels, predictions)

    @tf.function
    def test_step(images, labels):
        # training=False: evaluate with inference-mode layer behaviour.
        predictions = model(images, training=False)
        t_loss = loss_fn(labels, predictions, reduction=tf.compat.v1.losses.Reduction.SUM)
        test_loss(t_loss)
        test_accuracy(labels, predictions)

    for epoch in tqdm(range(epochs)):
        # Metrics accumulate across calls, so clear them at the start of each epoch.
        train_loss.reset_states()
        train_accuracy.reset_states()
        test_loss.reset_states()
        test_accuracy.reset_states()

        for images, labels in train_ds:
            train_step(images, labels)

        for test_images, test_labels in test_ds:
            test_step(test_images, test_labels)

        template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
        print(template.format(epoch+1,
                            train_loss.result(),
                            train_accuracy.result()*100,
                            test_loss.result(),
                            test_accuracy.result()*100))

    return model