In [2]:
import click

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.losses import CategoricalCrossentropy
import matplotlib.pyplot as plt
import numpy as np
import dataloader

from loguru import logger

import importlib

import tf_models  # First, import the library
importlib.reload(tf_models)  # Now, reload it
from tf_models import *


# Dataset to train on; `dataloader` must expose a matching `load_<dataset>` function.
dataset = 'cifar10'

# L2 weight-decay strength passed to the ResNet builder.
L2_LAMBDA = 4e-4

# Set to True to start from the weights saved at CHECKPOINT_PATH (written by the
# training cell) instead of a freshly initialised model. Replaces the old `if False:` toggle.
RESUME_FROM_CHECKPOINT = False
CHECKPOINT_PATH = f'{dataset}-viz.keras'


loader = getattr(dataloader, f"load_{dataset}")
# onehot=True -> labels are one-hot, so y_train.shape[1] is the class count and the
# loss below is CategoricalCrossentropy (not the Sparse variant).
x_train, y_train, x_test, y_test = loader(onehot=True)

# Images without a channel axis (ndim == 2, presumably grayscale) get one appended so
# the Conv2D input always receives (H, W, C).
if x_train[0].ndim == 2:
    x_train = x_train[..., np.newaxis]
    x_test = x_test[..., np.newaxis]
image_shape = x_train[0].shape  # (32, 32, 3) for cifar10; (H, W, 1) after the expansion above

model = ResNet(image_shape, num_classes=y_train.shape[1], l2_lambda=L2_LAMBDA)

model.compile(
    optimizer="adam", loss=CategoricalCrossentropy(), metrics=["accuracy"]
)
if RESUME_FROM_CHECKPOINT:
    model.load_weights(CHECKPOINT_PATH)
logger.info('done')

[32m2023-10-10 02:36:18.934[0m | [1mINFO    [0m | [36mdataloader[0m:[36m_load_keras[0m:[36m30[0m - [1mUsing cifar10 dataset[0m
[32m2023-10-10 02:36:18.935[0m | [1mINFO    [0m | [36mdataloader[0m:[36m_load_keras[0m:[36m33[0m - [1mThe size is 50000[0m
[32m2023-10-10 02:36:18.936[0m | [1mINFO    [0m | [36mdataloader[0m:[36m_load_keras[0m:[36m34[0m - [1mThe shape is: (32, 32, 3)[0m
[32m2023-10-10 02:36:20.547[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m40[0m - [1mdone[0m


In [3]:
# Print the layer table (output shapes, parameter counts, inbound connections) of the compiled model.
model.summary(print_fn=None)

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_2 (InputLayer)        [(None, 32, 32, 3)]          0         []                            
                                                                                                  
 conv2d_6 (Conv2D)           (None, 32, 32, 64)           1792      ['input_2[0][0]']             
                                                                                                  
 batch_normalization_6 (Bat  (None, 32, 32, 64)           256       ['conv2d_6[0][0]']            
 chNormalization)                                                                                 
                                                                                                  
 re_lu_4 (ReLU)              (None, 32, 32, 64)           0         ['batch_normalization_6[0]

In [4]:
# Run profile: 'test' is a short run (4 rounds x 8 epochs); any other value is the
# full run (32 rounds x 16 epochs).
stage = 'test'

# Named n_rounds / epochs_per_round rather than `iter` / `epo`: a module-level `iter`
# shadows the builtin iter() for every later cell in the kernel session.
n_rounds, epochs_per_round = (4, 8) if stage == 'test' else (32, 16)
BATCH_SIZE = 64
# Fraction of x_train Keras holds out for validation (per the Keras docs, the last
# samples, selected before shuffling).
VALIDATION_SPLIT = 0.02

histories = []  # one Keras History per round; `history` still refers to the latest
for i in range(n_rounds):
    with tf.device("/GPU:0"):
        history = model.fit(
            x_train,
            y_train,
            epochs=epochs_per_round,
            batch_size=BATCH_SIZE,
            validation_split=VALIDATION_SPLIT,
        )
        # Overwrite the same checkpoint file after every round.
        model.save(f'{dataset}-viz.keras')
    histories.append(history)

    # Evaluate the model on the test set
    logger.info(f"iter {i}")
    loss, accuracy = model.evaluate(x_test, y_test)
    logger.info(f"Test loss: {loss:.4f}")
    logger.info(f"Test accuracy: {accuracy:.4f}")

Epoch 1/8


2023-10-10 02:36:43.181415: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape inmodel/dropout/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer
2023-10-10 02:36:45.664719: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:432] Loaded cuDNN version 8900
2023-10-10 02:36:46.712192: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x55ea4db87200 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-10-10 02:36:46.712236: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Tesla V100-SXM2-16GB, Compute Capability 7.0
2023-10-10 02:36:46.719816: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:255] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2023-10-10 02:36:46.877562: I ./tensorflow/compiler/jit/device_compiler.h:186] Compiled clu

Epoch 2/8
Epoch 3/8
Epoch 4/8
Epoch 5/8
Epoch 6/8
Epoch 7/8
Epoch 8/8


[32m2023-10-10 02:40:42.901[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m17[0m - [1miter 0[0m




[32m2023-10-10 02:40:45.609[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m19[0m - [1mTest loss: 1.5547[0m
[32m2023-10-10 02:40:45.610[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m20[0m - [1mTest accuracy: 0.6192[0m


Epoch 1/8
Epoch 2/8
Epoch 3/8
Epoch 4/8
Epoch 5/8
Epoch 6/8
Epoch 7/8
Epoch 8/8


[32m2023-10-10 02:44:19.007[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m17[0m - [1miter 1[0m




[32m2023-10-10 02:44:21.509[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m19[0m - [1mTest loss: 1.0638[0m
[32m2023-10-10 02:44:21.510[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m20[0m - [1mTest accuracy: 0.7582[0m


Epoch 1/8
Epoch 2/8
Epoch 3/8
Epoch 4/8
Epoch 5/8
Epoch 6/8
Epoch 7/8
Epoch 8/8


[32m2023-10-10 02:47:56.344[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m17[0m - [1miter 2[0m




[32m2023-10-10 02:47:58.873[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m19[0m - [1mTest loss: 1.0505[0m
[32m2023-10-10 02:47:58.874[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m20[0m - [1mTest accuracy: 0.7613[0m


Epoch 1/8
Epoch 2/8
Epoch 3/8
Epoch 4/8
Epoch 5/8
Epoch 6/8
Epoch 7/8
Epoch 8/8


[32m2023-10-10 02:51:32.568[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m17[0m - [1miter 3[0m




[32m2023-10-10 02:51:35.081[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m19[0m - [1mTest loss: 0.8314[0m
[32m2023-10-10 02:51:35.082[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m20[0m - [1mTest accuracy: 0.8211[0m


In [None]:
# Save the current model to '<dataset>-viz.keras'; the .keras extension selects Keras' native single-file format.
model.save('{}-viz.keras'.format(dataset))

In [None]:
# Restore weights from '<dataset>-viz.keras', the file the training and save cells write.
model.load_weights('{}-viz.keras'.format(dataset))