# Environment Setup

In [None]:
%load_ext tensorboard

import tensorflow as tf
from tensorflow import keras

from tensorflow.keras.applications.resnet import ResNet152
from tensorflow.keras.applications.resnet import (preprocess_input,
                                                  decode_predictions)

import numpy as np
from PIL import Image

import matplotlib.pyplot as plt
%matplotlib widget

from datetime import datetime

# Image size fed to ResNet152: (height, width, RGB channels).
INPUT_SHAPE = (256, 256, 3)
# Class labels; with labels="inferred" these must match the sub-directory
# names under ./dataset/.
CLASS_NAMES = ["G", "H", "K", "M"]
# Derived from CLASS_NAMES so the two can never disagree.
NUM_CLASSES = len(CLASS_NAMES)

# Shares of the dataset used for training and validation; the remaining
# 1 - TRAIN_FRAC - VAL_FRAC is held out as the test split.
TRAIN_FRAC, VAL_FRAC = 0.6, 0.2
assert 0 < TRAIN_FRAC and 0 < VAL_FRAC and TRAIN_FRAC + VAL_FRAC < 1
BATCH_SIZE = 4

# Dataset Overview

In [None]:
# Load every image under ./dataset/, one-hot labelled by sub-directory name.
# batch_size=1 keeps one image per element for the per-sample counting and
# preview below. shuffle=True is the Keras default, written out here so the
# shuffling (seeded with 0) that feeds the later splits is visible.
dataset = keras.utils.image_dataset_from_directory(
  directory="./dataset/",
  labels="inferred",
  label_mode="categorical",
  class_names=CLASS_NAMES,
  color_mode="rgb",
  image_size=INPUT_SHAPE[:2],
  interpolation="bilinear",
  batch_size=1,
  shuffle=True,
  seed=0,
)

# Per-class sample counts (float32, one slot per CLASS_NAMES entry).
# Labels are one-hot (label_mode="categorical"), so summing each label batch
# over its batch axis adds exactly one count per image. Unlike an argmax over
# the flattened batch, this stays correct for any batch_size.
class_counter = np.zeros(len(CLASS_NAMES), dtype=np.float32)
for _, label_batch in dataset.as_numpy_iterator():
  class_counter += label_batch.sum(axis=0)

for class_name, count in zip(CLASS_NAMES, class_counter):
  print(f"number of {class_name} samples: {count}")

# Save a 3x3 preview grid of samples titled with their class names.
fig, axes = plt.subplots(3, 3, figsize=(10, 10))
for ax in axes.flat:
  ax.axis("off")
for ax, (image, label) in zip(axes.flat, dataset.take(9)):
  # Each element is a batch holding one image; cast the float pixels to
  # uint8 so imshow renders them on the 0-255 scale.
  ax.imshow(image[0].numpy().astype(np.uint8))
  ax.set_title(CLASS_NAMES[np.argmax(label[0])])

fig.savefig("./figures/rgb_overview.png")
# Close only this figure; other open (widget) figures are left alone.
plt.close(fig)

# Dataset Splits and Data Augmentation

In [None]:
# Geometric augmentations, applied in list order: up to +/-10% zoom,
# +/-0.05 of a full turn of rotation, and up to +/-10% shift. All three
# reflect-pad, resample bilinearly, and are seeded with 0.
_shared_aug_kwargs = dict(fill_mode="reflect", interpolation="bilinear", seed=0)
augmentation_layers = [
  keras.layers.RandomZoom(height_factor=0.1, width_factor=0.1,
                          **_shared_aug_kwargs),
  keras.layers.RandomRotation(factor=0.05, **_shared_aug_kwargs),
  keras.layers.RandomTranslation(height_factor=0.1, width_factor=0.1,
                                 **_shared_aug_kwargs),
]

def data_augmentation(x, layers=None):
  """Pass ``x`` through each augmentation layer in order.

  Args:
    x: image tensor (or batch of images) to augment.
    layers: optional sequence of callables to apply instead of the
      module-level ``augmentation_layers``; ``None`` (the default) uses
      that list, preserving the original behaviour.

  Returns:
    The output of the last layer, or ``x`` unchanged if ``layers`` is empty.
  """
  if layers is None:
    layers = augmentation_layers
  for layer in layers:
    x = layer(x)
  return x

# Hold out the test split first: the left TRAIN_FRAC + VAL_FRAC share is kept
# for training/validation, and the remainder becomes ds_test.
ds_train_val, ds_test = keras.utils.split_dataset(
  dataset, left_size=TRAIN_FRAC + VAL_FRAC, shuffle=False, seed=0)

# Split train from validation BEFORE augmenting. Augmenting first would let
# augmented copies of validation images into training (and vice versa),
# inflating validation scores. left_size is the TRAIN share of what remains.
ds_train, ds_val = keras.utils.split_dataset(
  ds_train_val, left_size=TRAIN_FRAC / (TRAIN_FRAC + VAL_FRAC),
  shuffle=True, seed=0)

# Augment the training split only: every original plus one augmented copy.
# Shuffle so originals and augmented copies are interleaved instead of
# arriving as two consecutive blocks.
ds_train = ds_train.concatenate(
  ds_train.map(lambda x, y: (data_augmentation(x), y)))
ds_train = ds_train.shuffle(buffer_size=len(ds_train), seed=0)

# Pre-trained Model

In [None]:
keras.backend.clear_session()

def create_model(input_shape:tuple[int,int,int]=INPUT_SHAPE,
                 num_classes:int=NUM_CLASSES
                 ) -> tuple[keras.Model, keras.Model]:
  """Build a frozen ResNet152 classifier and its input pre-processor.

  The ImageNet backbone is frozen (``trainable = False``) and saved to
  ./models/base_model.keras. A trainable head (global average pooling,
  dropout 0.2, dense) is put on top.

  Args:
    input_shape: (height, width, channels) of the raw images.
    num_classes: number of output classes.

  Returns:
    ``(pre_processor, model)``:
    - ``pre_processor`` applies ResNet's ``preprocess_input`` to raw RGB pixels.
    - ``model`` maps pre-processed images to ``num_classes`` unnormalised
      logits, intended for a loss built with ``from_logits=True``.
  """
  base_model = ResNet152(weights="imagenet",
                         include_top=False,
                         input_shape=input_shape)
  base_model.trainable = False
  base_model.save("./models/base_model.keras")

  # float32 input: image_dataset_from_directory yields float32 pixels, and a
  # uint8 Input would truncate them on the way in. NOTE(review): uint8 arrays
  # should still be accepted via Keras' input casting -- confirm if relied on.
  raw_inputs = keras.Input(shape=input_shape, dtype=tf.float32)
  proc_inputs = preprocess_input(raw_inputs)

  inputs = keras.Input(shape=proc_inputs.shape[1:],
                       dtype=tf.float32)
  # training=False runs the backbone (including its BatchNorm layers) in
  # inference mode.
  x = base_model(inputs, training=False)
  x = keras.layers.GlobalAveragePooling2D()(x)
  x = keras.layers.Dropout(0.2)(x)
  # No activation: the head must emit raw logits because training uses
  # CategoricalCrossentropy(from_logits=True). A ReLU here would clip every
  # negative logit to 0 and block its gradient.
  outputs = keras.layers.Dense(num_classes)(x)

  return keras.Model(raw_inputs, proc_inputs), keras.Model(inputs, outputs)

pre_processor, model = create_model()
model.summary(show_trainable=True)
model.save("./models/untrained_model.keras")

# Top Layer Training

In [None]:
model.compile(
  optimizer=keras.optimizers.Adam(),
  loss=keras.losses.CategoricalCrossentropy(from_logits=True),
  metrics=[keras.metrics.CategoricalAccuracy()]
)

ds_train = ds_train.map(lambda x,y : (pre_processor(x),y))
ds_val = ds_val.map(lambda x,y : (pre_processor(x),y))

class_weight = {idx: np.sum(class_counter) 
                        /(class_counter[idx]*class_counter.size)
                for idx in range(class_counter.size)}

log_dir = "logs/fit/" + datetime.now().strftime("%Y%m%d_%H%M%S")
tensorboard_callback = \
  keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

model.fit(ds_train, epochs=10,
          class_weight=class_weight,
          validation_data=ds_val,
          callbacks=[tensorboard_callback])

%tensorboard --logdir logs/fit