# Environment Setup

In [1]:
%load_ext tensorboard

import os
from random import random,seed
seed(0)

from typing import Optional

import tensorflow as tf
from tensorflow import keras

from keras.applications.resnet import ResNet152

import numpy as np
from PIL import Image

import matplotlib.pyplot as plt
%matplotlib widget

from datetime import datetime

# Image geometry as (height, width, channels). The first two entries are used
# as the resize target when loading images, and the full tuple as the model's
# input shape.
INPUT_SHAPE = (256,256,3)
# Class labels, in the order used for the one-hot ("categorical") encoding.
CLASS_NAMES = ["G","H","K","M"]
# Derived from CLASS_NAMES so the classifier head can never disagree with the
# label list.
NUM_CLASSES = len(CLASS_NAMES)

# TRAIN_FRAC + VAL_FRAC of the dataset is kept for training/validation (the
# remainder is the test split); that pool is divided in the ratio
# TRAIN_FRAC : VAL_FRAC.
TRAIN_FRAC, VAL_FRAC = 0.5, 0.2
# Number of augmented copies of the training images appended to the originals.
AUG_ITR = 5
# NOTE(review): not referenced by the cells visible in this notebook (the
# training cell rebatches to a hard-coded 32) — confirm whether it should be.
BATCH_SIZE=4

# When True and ./models/trained_model.keras exists, load it instead of
# training the classification head again.
USE_TRAINED = True
# When True and ./models/tuned_model.keras exists, load it instead of
# fine-tuning again.
USE_TUNED = True

2023-12-11 14:35:50.998474: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-12-11 14:35:51.023735: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-11 14:35:51.023755: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-11 14:35:51.024421: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-11 14:35:51.028913: I tensorflow/core/platform/cpu_feature_guar

# Dataset Overview

In [2]:
# Load every image under ./dataset/ as a one-image batch with a one-hot label.
# Class sub-directories are mapped to label indices in CLASS_NAMES order.
dataset = keras.utils.image_dataset_from_directory(
  "./dataset/",
  labels="inferred",
  label_mode="categorical",
  class_names=CLASS_NAMES,
  color_mode="rgb",
  batch_size=1,
  image_size=INPUT_SHAPE[:2],
  seed=0,
  interpolation="bilinear"
)

# Per-class sample counts. Kept as float32 because the training cell reuses
# this array to compute class weights.
class_counter = np.zeros(len(CLASS_NAMES), dtype=np.float32)
for (_,label) in dataset.as_numpy_iterator():
  # Each label is a (1, num_classes) one-hot row; argmax over the flattened
  # array is therefore the class index.
  class_counter[label.argmax()] += 1

for (class_name, count) in zip(CLASS_NAMES, class_counter):
  print(f"number of {class_name} samples: {count}")

# savefig does not create missing directories, so make sure the target exists.
os.makedirs("./figures", exist_ok=True)

# 3x3 grid of sample images titled with their class name.
fig = plt.figure(figsize=(10,10))
for i,(image,label) in enumerate(dataset.take(9)):
  # Drop the batch axis and convert to uint8 so imshow treats it as 0-255 RGB.
  image_array = image[0,:,:,:].numpy().astype(np.uint8)
  ax = fig.add_subplot(3,3,i+1)
  ax.imshow(image_array)
  ax.set_title(CLASS_NAMES[np.argmax(label[0,:])])
  ax.axis("off")

fig.savefig("./figures/rgb_overview.png")
plt.close("all")

Found 187 files belonging to 4 classes.


2023-12-11 14:35:54.020968: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error
2023-12-11 14:35:54.020989: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:129] retrieving CUDA diagnostic information for host: Jorogumo
2023-12-11 14:35:54.020991: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:136] hostname: Jorogumo
2023-12-11 14:35:54.021104: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:159] libcuda reported version is: 535.129.3
2023-12-11 14:35:54.021113: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:163] kernel reported version is: 535.129.3
2023-12-11 14:35:54.021114: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:241] kernel version seems to match DSO: 535.129.3


number of G samples: 68.0
number of H samples: 48.0
number of K samples: 30.0
number of M samples: 41.0


# Data Augmentation

In [3]:
# Settings shared by every augmentation layer: constant fill for pixels that
# move in from outside the frame, bilinear resampling, and a fixed seed.
_AUG_COMMON = dict(fill_mode="constant",
                   interpolation="bilinear",
                   seed=0)

# Random zoom, rotation and translation, each configured with a factor of
# 0.05. They are applied in this order by data_augmentation().
augmentation_layers = [
  keras.layers.RandomZoom(height_factor=0.05,
                          width_factor=0.05,
                          **_AUG_COMMON),
  keras.layers.RandomRotation(factor=0.05,
                              **_AUG_COMMON),
  keras.layers.RandomTranslation(height_factor=0.05,
                                 width_factor=0.05,
                                 **_AUG_COMMON),
]

def data_augmentation(x):
  """Pass `x` through every entry of `augmentation_layers`, in list order.

  Each layer receives the output of the previous one; the output of the last
  layer is returned. With an empty layer list, `x` is returned unchanged.
  """
  augmented = x
  for aug_layer in augmentation_layers:
    augmented = aug_layer(augmented)
  return augmented

# Hold out the test split first: the leading TRAIN_FRAC+VAL_FRAC of the
# dataset is kept for training/validation, the remainder becomes the test set.
ds_train, ds_test = keras.utils.split_dataset(dataset, TRAIN_FRAC+VAL_FRAC,
                                              seed=0, shuffle=False)

# savefig does not create missing directories, so make sure the target exists.
os.makedirs("./figures", exist_ok=True)

# Preview figure: nine augmented training images titled with their class name.
ds_train_aux = ds_train.map(lambda x,y : (data_augmentation(x),y))
fig = plt.figure(figsize=(10,10))
for i,(image,label) in enumerate(ds_train_aux.take(9)):
  # Drop the batch axis and convert to uint8 so imshow treats it as 0-255 RGB.
  image_array = image[0,:,:,:].numpy().astype(np.uint8)
  ax = fig.add_subplot(3,3,i+1)
  ax.imshow(image_array)
  ax.set_title(CLASS_NAMES[np.argmax(label[0,:])])
  ax.axis("off")
fig.savefig("./figures/aug_overview.png")
plt.close("all")

del ds_train_aux

# Split into training and validation parts BEFORE augmenting. Splitting the
# augmented pool instead would put augmented copies of the same source image
# on both sides, so validation loss (and early stopping) would be measured on
# near-duplicates of the training data.
# The validation split keeps its historical name `ds_val_aug` because later
# cells reference it; it now holds un-augmented images only.
ds_fit, ds_val_aug = keras.utils.split_dataset(ds_train,
                                               TRAIN_FRAC/(TRAIN_FRAC+VAL_FRAC),
                                               seed=0,
                                               shuffle=True)

# Training set = the original images followed by AUG_ITR augmented copies.
ds_train_aug = ds_fit.map(lambda x,y : (x,y))

for _ in range(AUG_ITR):
  ds_train_aug = ds_train_aug.concatenate(
    ds_fit.map(lambda x,y : (data_augmentation(x),y)))

# cache() freezes the augmented samples after the first pass, so the training
# set stays fixed across epochs; shuffle() mixes originals and augmented
# copies, which would otherwise arrive in concatenation order.
ds_train_aug = ds_train_aug.cache()
num_train_samples = int(ds_train_aug.cardinality().numpy())
ds_train_aug = ds_train_aug.shuffle(
  # cardinality() is negative when the size is unknown/infinite; fall back to
  # a fixed buffer in that case.
  buffer_size=num_train_samples if num_train_samples > 0 else 1024,
  seed=0)

# Pre-trained Model

In [4]:
# Clear Keras' global state before (re)building the model, so re-running this
# cell starts from a clean session.
keras.backend.clear_session()

def create_model(input_shape:tuple[int,int,int]=INPUT_SHAPE,
                  num_classes:int=NUM_CLASSES
                ) -> tuple[keras.Model,Optional[keras.Model]]:
  """Build the preprocessing model and, unless a trained model is reused, the classifier.

  Args:
    input_shape: (height, width, channels) of the raw uint8 input images.
    num_classes: Number of output units of the classification head.

  Returns:
    A `(pre_processor, classifier)` pair.

    * `pre_processor` casts the raw input to float32 and applies the ResNet
      `preprocess_input` transform.
    * `classifier` is a frozen ImageNet ResNet152 backbone (without its top)
      followed by global average pooling, dropout and a dense head that emits
      raw logits. It is `None` when `USE_TRAINED` is set and
      `./models/trained_model.keras` exists, because the caller loads the
      trained model from disk in that case.
  """
  raw_inputs = keras.Input(shape=input_shape, dtype=tf.uint8)
  x = tf.cast(raw_inputs, tf.float32)

  # The preprocessing step lives in its own model because it could not be
  # serialized together with the classifier.
  proc_inputs = keras.applications.resnet.preprocess_input(x)
  pre_processor = keras.Model(raw_inputs, proc_inputs)

  if USE_TRAINED and os.path.exists("./models/trained_model.keras"):
    return pre_processor, None

  base_model = ResNet152(weights="imagenet",
                    include_top=False,
                    input_shape=input_shape)
  base_model.trainable = False

  if not os.path.exists("./models/base_model.keras"):
    # Make sure the target directory exists before writing the snapshot.
    os.makedirs("./models", exist_ok=True)
    base_model.save("./models/base_model.keras")

  inputs = keras.Input(shape=proc_inputs.shape[1:],
                       dtype=tf.float32)
  # training=False keeps the backbone in inference mode even if it is later
  # made trainable for fine-tuning.
  x = base_model(inputs, training=False)
  x = keras.layers.GlobalAveragePooling2D()(x)
  x = keras.layers.Dropout(0.2)(x)
  # No activation: the model is compiled with
  # CategoricalCrossentropy(from_logits=True), which expects unbounded logits.
  # A ReLU here would clamp every logit to >= 0 and block the gradient of any
  # unit whose pre-activation is negative.
  outputs = keras.layers.Dense(num_classes, activation=None)(x)

  return pre_processor, keras.Model(inputs, outputs)

if USE_TRAINED and os.path.exists("./models/trained_model.keras"):
  # Reuse the trained classifier from disk; only the preprocessing model is
  # taken from create_model() (its second return value is None here).
  pre_processor, _ = create_model()
  model = keras.models.load_model("./models/trained_model.keras")
else:
  pre_processor, model = create_model()

  # Snapshot the untrained model only on this branch: on the other branch
  # `model` holds trained weights, and saving it under the "untrained" name
  # would mislabel the file.
  if not os.path.exists("./models/untrained_model.keras"):
    os.makedirs("./models", exist_ok=True)
    model.save("./models/untrained_model.keras")

model.summary(show_trainable=True)

Model: "model_1"
____________________________________________________________________________
 Layer (type)                Output Shape              Param #   Trainable  
 input_3 (InputLayer)        [(None, 256, 256, 3)]     0         Y          
                                                                            
 resnet152 (Functional)      (None, 8, 8, 2048)        5837094   Y          
                                                       4                    
                                                                            
 global_average_pooling2d (  (None, 2048)              0         Y          
 GlobalAveragePooling2D)                                                    
                                                                            
 dropout (Dropout)           (None, 2048)              0         Y          
                                                                            
 dense (Dense)               (None, 4)                 8196

# Top Layer Training

In [None]:
def _with_preprocessing(x, y):
  """Run the images through `pre_processor`; labels pass through unchanged."""
  return pre_processor(x), y

# Inverse-frequency class weights:
#   weight[k] = total_samples / (samples_in_class_k * number_of_classes)
class_weight = {}
for idx in range(class_counter.size):
  class_weight[idx] = (np.sum(class_counter)
                       /(class_counter[idx]*class_counter.size))

# Preprocessed training / validation streams, regrouped into batches of 32.
ds_train_proc = ds_train_aug.map(_with_preprocessing).rebatch(32)
ds_val_proc = ds_val_aug.map(_with_preprocessing).rebatch(32)

# Train the classification head unless a previously trained model is being
# reused from disk.
if not USE_TRAINED or not os.path.exists("./models/trained_model.keras"):
  model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.CategoricalAccuracy()]
  )

  # One TensorBoard run directory per launch, named by timestamp.
  log_dir = "logs/fit/" + datetime.now().strftime("%Y%m%d_%H%M%S")
  tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir,
                                                     histogram_freq=1)
  early_stopping = keras.callbacks.EarlyStopping(patience=5,
                                                 restore_best_weights=True)

  model.fit(ds_train_proc,
            epochs=50,
            class_weight=class_weight,
            validation_data=ds_val_proc,
            callbacks=[tensorboard_callback, early_stopping])

  model.save("./models/trained_model.keras")

%tensorboard --logdir logs/fit

# Initial Benchmark

In [6]:
# Preprocess the held-out test split and score the trained model on it.
ds_test_proc = ds_test.map(lambda x,y : (pre_processor(x),y))
model.evaluate(ds_test_proc)

y_hat = model.predict(ds_test_proc)

# 3x3 grid of test images with true vs. predicted class. Each sample is kept
# or skipped by a draw from random(), until nine panels are filled or the
# test set runs out.
fig = plt.figure(figsize=(10,10))
i = 0
for (image,label),prediction in zip(ds_test, y_hat):
  if i == 9:
    break
  if random() <= 0.5:
    # Drop the batch axis; uint8 so the image is interpreted as 0-255 RGB.
    image_array = image[0,:,:,:].numpy().astype(np.uint8)
    pil_image = Image.fromarray(image_array, mode="RGB")
    ax = fig.add_subplot(3,3,i+1)
    ax.imshow(pil_image)
    ax.set_title(f"True: {CLASS_NAMES[np.argmax(label[0,:])]}\n"
                 f"Predicted: {CLASS_NAMES[np.argmax(prediction)]}")
    ax.axis("off")
    i += 1

fig.savefig("./figures/trained_model_results.png")
plt.close("all")



# Fine-Tuning

In [None]:
if USE_TUNED and os.path.exists("./models/tuned_model.keras"):
  # A fine-tuned model is already on disk and reuse is enabled: load it.
  model = keras.models.load_model("./models/tuned_model.keras")
else:
  # Unfreeze the ResNet backbone for fine-tuning.
  backbone = model.get_layer(name="resnet152")
  backbone.trainable = True

  # Recompile after changing `trainable`, with a small learning rate (1e-5).
  model.compile(
    optimizer=keras.optimizers.Adam(1e-5),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.CategoricalAccuracy()]
  )

  # One TensorBoard run directory per launch, named by timestamp.
  log_dir = "logs/tune/" + datetime.now().strftime("%Y%m%d_%H%M%S")
  tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir,
                                                     histogram_freq=1)
  early_stopping = keras.callbacks.EarlyStopping(patience=5,
                                                 restore_best_weights=True)

  model.fit(ds_train_proc,
            epochs=50,
            class_weight=class_weight,
            validation_data=ds_val_proc,
            callbacks=[tensorboard_callback, early_stopping])

  # Freeze the backbone again before saving the fine-tuned model.
  backbone.trainable = False

  model.save("./models/tuned_model.keras")

%tensorboard --logdir logs/tune

# Final Benchmark

In [9]:
# Score the fine-tuned model on the preprocessed test split.
model.evaluate(ds_test_proc)
y_hat = model.predict(ds_test_proc)

# Same 3x3 true-vs-predicted grid as the initial benchmark: samples are kept
# or skipped by a draw from random() until nine panels are filled or the
# test set runs out.
fig = plt.figure(figsize=(10,10))
i = 0
for pair,prediction in zip(ds_test, y_hat):
  if i == 9:
    break
  if random() <= 0.5:
    image,label = pair
    # Drop the batch axis; uint8 so the image is interpreted as 0-255 RGB.
    image_array = image[0,:,:,:].numpy().astype(np.uint8)
    pil_image = Image.fromarray(image_array, mode="RGB")
    true_name = CLASS_NAMES[np.argmax(label[0,:])]
    predicted_name = CLASS_NAMES[np.argmax(prediction)]
    ax = fig.add_subplot(3,3,i+1)
    ax.imshow(pil_image)
    ax.set_title(f"True: {true_name}\n" +
                 f"Predicted: {predicted_name}")
    ax.axis("off")
    i += 1

fig.savefig("./figures/tuned_model_results.png")
plt.close("all")

