In [None]:
import click

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.losses import CategoricalCrossentropy
import matplotlib.pyplot as plt
import numpy as np
import dataloader

from loguru import logger

import importlib

import tf_models  # First, import the library
importlib.reload(tf_models)  # Now, reload it
from tf_models import *

import os
import datetime


# Dataset key: selects the `dataloader.load_<dataset>` function and is used as
# the prefix of the checkpoint file name and of the TensorBoard run name.
dataset = 'cifar10'

def make_tb(name):
    """Create a TensorBoard callback that writes to a new, timestamped run directory.

    Args:
        name: Run label (a string). The log directory becomes
            ``logs/<name>-<YYYYmmdd-HHMMSS>``, stamped with the current local time.

    Returns:
        A ``tf.keras.callbacks.TensorBoard`` instance configured with
        ``histogram_freq=1`` and ``update_freq="batch"``.
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    # str-only join: a non-string `name` raises TypeError rather than being coerced.
    run_dir = os.path.join("logs", "-".join((name, stamp)))
    tb_options = {
        "log_dir": run_dir,
        "histogram_freq": 1,
        "update_freq": "batch",
    }
    return tf.keras.callbacks.TensorBoard(**tb_options)

# Look up the loader for the configured dataset by name on the `dataloader`
# module and request the labels with `onehot=True`.
loader = getattr(dataloader, f"load_{dataset}")
x_train, y_train, x_test, y_test = loader(onehot=True)

# Samples that are plain 2-D arrays have no channel axis; append a trailing one
# to both splits so each sample becomes 3-D.
if x_train[0].ndim == 2:
    x_train, x_test = x_train[..., np.newaxis], x_test[..., np.newaxis]

# Per-sample input shape handed to the model constructor.
# NOTE(review): an earlier comment here said (28, 28); for cifar10 this is
# presumably (32, 32, 3) — confirm against the loader.
image_shape = x_train[0].shape

# Build the network for the detected input shape; the number of classes is the
# width of the one-hot label matrix.
model = ResNet(image_shape, num_classes=y_train.shape[1], l2_lambda=4e-3, augmentation=True)

model.compile(
    optimizer="adam", loss=CategoricalCrossentropy(), metrics=["accuracy"]
)

# Warm-start from an earlier checkpoint when one exists. On a fresh checkout the
# file has not been written yet (the training cell creates it), so loading
# unconditionally would abort this cell before any training could happen.
# skip_mismatch=True is kept from the original call so that layers whose saved
# weights no longer fit are skipped instead of failing the whole load.
checkpoint_path = f'{dataset}-viz.keras'
if os.path.exists(checkpoint_path):
    model.load_weights(checkpoint_path, skip_mismatch=True)
else:
    logger.warning(f"no checkpoint at {checkpoint_path}; starting from freshly initialised weights")
logger.info('done')

In [None]:
# Display the architecture summary of the model built above.
model.summary()

In [None]:
stage = 'prod'

# Training schedule: 'test' is a short run (4 rounds x 8 epochs); any other
# value selects the long run (2 rounds x 100 epochs).
# The round count is deliberately NOT called `iter`: at notebook scope that
# name would replace the builtin `iter()` for every cell executed afterwards.
if stage == 'test':
    n_rounds = 4
    epo = 8
else:
    n_rounds = 2
    epo = 100

for i in range(n_rounds):
    with tf.device("/GPU:0"):
        # A new TensorBoard callback (and therefore a new timestamped log
        # directory) is created for every round.
        # Validation uses the test split, so the reported val_* metrics are not
        # an independent hold-out from the final evaluation below.
        history = model.fit(
            x_train, y_train, epochs=epo, batch_size=64,
            # validation_split=0.02,
            validation_data=(x_test, y_test),
            callbacks=[make_tb("nb-resnet-" + dataset + "-augmented")],
        )
        # Checkpoint after every round so an interrupted run can be resumed.
        model.save(f'{dataset}-viz.keras')
        logger.info("saved model")

    # Evaluate the model on the test set
    loss, accuracy = model.evaluate(x_test, y_test)
    logger.info(f"iter {i}")
    logger.info(f"Test loss: {loss:.4f}")
    logger.info(f"Test accuracy: {accuracy:.4f}")

In [None]:
# Manually write the current model to the same checkpoint file the training
# loop saves to.
model.save(f'{dataset}-viz.keras')

In [None]:
# Restore weights from the saved checkpoint into the existing model object.
# NOTE(review): unlike the load in the setup cell, this call does not pass
# skip_mismatch=True — confirm that a strict load is intended here.
model.load_weights(f'{dataset}-viz.keras')