# Deep Learning Project

**Group:** Songbird  
**Members:** Charlotte de Vries, Jiazhen Tang, Paulo Zirlis

In [None]:
# Setup block
import os
import time  # Added this because ConvergenceTimer needs it

# --- 1. FORCE LEGACY KERAS (CRITICAL FOR ViT) ---
# TensorFlow reads these variables when it is first imported, so they have to
# be set before any keras/tensorflow import further down.
#   TF_USE_LEGACY_KERAS=1 -> `tf.keras` uses legacy Keras 2 (tf_keras) on newer TF
#   TF_CPP_MIN_LOG_LEVEL   -> 0 = all, 1 = hide INFO, 2 = hide INFO + WARNING,
#                             3 = also hide ERROR
os.environ.update({
    "TF_USE_LEGACY_KERAS": "1",
    "TF_CPP_MIN_LOG_LEVEL": "2",
})

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns

# --- 2. IMPORTS ---
# We prioritize tensorflow.keras to ensure compatibility
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, Callback
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import transformers
from transformers import TFAutoModelForImageClassification
from transformers import TFViTModel

from sklearn.utils import resample
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.model_selection import train_test_split

# --- 3. APPLY THE "INPUT LAYER" PATCH (CRITICAL FOR CNN) ---
# Keras 3 spells the input dimensions `shape=`; the legacy Keras 2 InputLayer
# only understands `input_shape=`. The subclass below accepts both spellings,
# so `layers.InputLayer(shape=...)` in the CNN cell works unchanged.

OriginalInputLayer = layers.InputLayer


class FriendlyInputLayer(OriginalInputLayer):
    """InputLayer that converts a Keras-3 style ``shape=`` keyword to ``input_shape=``.

    When the caller already supplied ``input_shape``, all keywords are
    forwarded unchanged to the original Keras 2 InputLayer.
    """

    def __init__(self, *args, **kwargs):
        uses_keras3_keyword = 'shape' in kwargs
        has_keras2_keyword = 'input_shape' in kwargs
        if uses_keras3_keyword and not has_keras2_keyword:
            kwargs['input_shape'] = kwargs.pop('shape')
        super().__init__(*args, **kwargs)


# Rebind the name in both module namespaces so every later
# `layers.InputLayer(...)` / `tf.keras.layers.InputLayer(...)` builds the patched class.
for _layers_namespace in (layers, tf.keras.layers):
    _layers_namespace.InputLayer = FriendlyInputLayer
# -----------------------------------------------------------

# Setup timer
class ConvergenceTimer(Callback):
    """Keras callback that reports how long a single ``fit()`` call took.

    ``start_time`` / ``end_time`` are ``time.time()`` timestamps. After training
    ends, ``total_time`` holds the elapsed wall-clock time in seconds.
    """

    # `logs=None` rather than a mutable `{}` default: a default dict would be a
    # single object shared by every call of every instance.
    def on_train_begin(self, logs=None):
        self.start_time = time.time()
        print("Training started...")

    def on_train_end(self, logs=None):
        self.end_time = time.time()
        self.total_time = self.end_time - self.start_time
        print("\nTraining finished.")
        print(f"Time to converge: {self.total_time:.2f} seconds ({self.total_time/60:.2f} minutes)")


print("Setup OK: Legacy mode enabled & InputLayer patched.")

In [None]:
# List the GPUs TensorFlow can see (an empty list means training runs on CPU)
import tensorflow as tf
visible_gpus = tf.config.list_physical_devices('GPU')
print(visible_gpus)

## 1. Project Overview

### 1.1 Objective

Apply different deep learning architectures to the visual classification task of identifying brain tumors in MRI images and compare them based on accuracy and time to converge.


### 1.2 Neural Network Architectures

We will implement and compare the following architectures:
- Custom Convolutional Neural Network (CNN) with keras sequential
- Custom Residual Network (ResNet)
- Pre-trained Residual Network (ResNet50)
- Pre-trained Residual Network (ResNet50) with fine-tuning
- Pre-trained Visual Transformer (ViT)
- Pre-trained Visual Transformer (ViT) with fine-tuning


### 1.3 Dataset Description

The dataset includes high-resolution CT and MRI images captured from multiple patients, with each image labeled with the corresponding tumor type (e.g., glioma, meningioma, etc.). For this project we will focus solely on the **MRI** images for simplicity. The dataset's creator collected these data from different sources to assist researchers and healthcare professionals in developing AI models for the automatic detection, classification, and segmentation of brain tumors.

The images are divided as follows:
- Healthy images: 2000
- Tumor images: 3000
    - Meningioma: 1112
    - Glioma: 672
    - Pituitary: 629
    - Tumor: 587
- **Total of images:** 5000

Source: [Brain tumor multimodal image (Kaggle)](https://www.kaggle.com/datasets/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri/data)

***

## 2. Data Preprocessing

### 2.1 Load data

In [None]:
# Load data
#
# Builds `df` with one row per MRI image: `filepath` and `label`.
# Images in a folder whose name contains "healthy" are labelled 'Healthy';
# images in any other folder get a tumour subtype parsed from the file name.

# 1. SETUP PATHS
dataset_path = 'Data/Brain Tumor MRI images'
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif')

print(f"Checking contents of: {dataset_path}")
if not os.path.isdir(dataset_path):
    # Fail fast with a clear message. Previously a bare `except:` printed an
    # error and execution then crashed on the next os.listdir() call anyway.
    raise FileNotFoundError(f"Error: The dataset_path does not exist: {dataset_path}")

# Sort the listings: os.listdir() order depends on the filesystem, and the row
# order of `df` feeds the seeded resampling / splitting later on. Without
# sorting, "random_state=42" would still give different splits on different machines.
all_items = sorted(os.listdir(dataset_path))
print("Found these items:", all_items)


def _tumour_subtype(filename):
    """Return the tumour subtype encoded in *filename* (case-insensitive match)."""
    name_lower = filename.lower()
    for keyword, subtype in (('glioma', 'Glioma'),
                             ('meningioma', 'Meningioma'),
                             ('pituitary', 'Pituitary')):
        if keyword in name_lower:
            return subtype
    return 'Tumor (Unspecified)'


filepaths = []
labels = []

for item in all_items:
    item_path = os.path.join(dataset_path, item)
    if not os.path.isdir(item_path):
        continue  # only class folders hold images

    # CASE A: the 'Healthy' folder; CASE B: any other folder holds tumours
    is_healthy = 'healthy' in item.lower()
    print(f"Processing {'Healthy' if is_healthy else 'Tumour'} folder: {item}")

    for filename in sorted(os.listdir(item_path)):
        if not filename.lower().endswith(IMAGE_EXTENSIONS):
            continue
        filepaths.append(os.path.join(item_path, filename))
        labels.append('Healthy' if is_healthy else _tumour_subtype(filename))

# Create DataFrame
df = pd.DataFrame({'filepath': filepaths, 'label': labels})

# Check results
print(f"Total images found: {len(df)}")
print(df['label'].value_counts())


### 2.2 Oversampling

In [None]:
# Balance the classes by oversampling: every class smaller than the largest one
# is resampled WITH replacement (duplicating rows) up to the majority count.
# NOTE(review): the duplicates created here must not end up in different
# train/validation/test splits, otherwise identical images leak into evaluation.
# Split on unique file paths and oversample only the training portion.
max_count = df['label'].value_counts().max()
print(f"Target count per class: {max_count}")

balanced_dfs = []
for label, group_df in df.groupby('label'):
    if len(group_df) >= max_count:
        # Majority class (Healthy): keep unchanged
        balanced_dfs.append(group_df)
        print(f"Kept {label} at {len(group_df)}")
        continue

    oversampled_group = resample(group_df, replace=True, n_samples=max_count, random_state=42)
    balanced_dfs.append(oversampled_group)
    print(f"Oversampled {label} from {len(group_df)} to {len(oversampled_group)}")

# Concatenate, then shuffle so the classes are not stored in contiguous runs
df_balanced = (
    pd.concat(balanced_dfs)
    .reset_index(drop=True)
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)

print("\nNew Label Distribution:")
print(df_balanced['label'].value_counts())

# From here on, `df` is the balanced frame
df = df_balanced

### 2.3 Train-test split

In [None]:
# Split the UNIQUE images first, then oversample only the training portion.
# `df` may already contain oversampled duplicates. Splitting those directly puts
# copies of the same image into train, validation AND test, so val/test accuracy
# partly measures memorisation (data leakage).
unique_df = df.drop_duplicates(subset='filepath').reset_index(drop=True)

# 1. Split: 80% Train, 20% Test (stratified to keep class proportions)
train_df, test_df = train_test_split(
    unique_df, test_size=0.2, shuffle=True, random_state=42, stratify=unique_df['label']
)

# 2. Split Train again to get a Validation set (0.125 * 80% = 10% of total)
train_df, val_df = train_test_split(
    train_df, test_size=0.125, shuffle=True, random_state=42, stratify=train_df['label']
)

# 3. Balance the TRAINING set only: oversample each minority class (with
#    replacement) up to the largest class. Validation and test keep the real distribution.
train_target = train_df['label'].value_counts().max()
train_df = (
    pd.concat([
        group if len(group) >= train_target
        else resample(group, replace=True, n_samples=train_target, random_state=42)
        for _, group in train_df.groupby('label')
    ])
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)

print(f"Train size: {len(train_df)}")
print(f"Val size:   {len(val_df)}")
print(f"Test size:  {len(test_df)}")
print("\nTrain label distribution (after oversampling):")
print(train_df['label'].value_counts())

# 4. Visualize to confirm labels are correct

plt.figure(figsize=(14, 8))

# Random sample of distinct images to check
sample_df = unique_df.sample(10)

for i, (index, row) in enumerate(sample_df.iterrows()):
    plt.subplot(2, 5, i + 1)
    img = mpimg.imread(row['filepath'])
    plt.imshow(img, cmap='gray')
    plt.title(f"{row['label']}\n{os.path.basename(row['filepath'])[:10]}...", fontsize=9) # Show label + part of filename
    plt.axis('off')

plt.tight_layout()
plt.show()

### 2.4 Build Keras generators

In [None]:
### Images to correct format

# Resolution and batch size used by the custom CNN
IMG_SIZE = (256, 256)
BATCH_SIZE = 32

# Each split only rescales pixel values from [0, 255] to [0, 1]
train_datagen, val_datagen, test_datagen = (
    ImageDataGenerator(rescale=1./255) for _ in range(3)
)


def _cnn_flow(datagen, frame, shuffle):
    """Return a grayscale, one-hot-labelled batch iterator over the rows of *frame*."""
    return datagen.flow_from_dataframe(
        dataframe=frame,
        x_col="filepath",
        y_col="label",
        target_size=IMG_SIZE,
        color_mode="grayscale",
        batch_size=BATCH_SIZE,
        class_mode="categorical",
        shuffle=shuffle,
    )


# Only the training iterator is shuffled; val/test keep a fixed order
train_gen = _cnn_flow(train_datagen, train_df, shuffle=True)
val_gen = _cnn_flow(val_datagen, val_df, shuffle=False)
test_gen = _cnn_flow(test_datagen, test_df, shuffle=False)

***

## 3. Custom Convolutional Neural Network

The Convolutional Neural Network (CNN) model designed for this project consists of six convolutional blocks followed by a classifier head with global average pooling, dense and dropout layers. Each convolutional block applies convolution, batch normalization and a ReLU activation (twice in blocks 1–5, once in block 6) and ends with max pooling. Early stopping was added to halt training once the validation loss stops improving, which limits overfitting. Batch normalization was used to improve training speed and stability. Spatial dropout in the last two convolutional blocks and dropout in the classifier head were included to further prevent overfitting. The model was compiled with the Adam optimizer, categorical cross-entropy loss function, and accuracy as the evaluation metric.

<br>

### 3.1 CNN Architecture

**Input and Data Augmentation**
- Input layer: shape (256, 256, 1)
- Data Augmentation: Random rotations and horizontal flips

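**First Convolutional Block**
- Conv2D layer: 16 filters, 3x3 kernel, stride of 1, same padding
- Batch Normalization
- ReLU Activation
- A second Conv2D + Batch Normalization + ReLU with the same settings
- MaxPooling2D layer: 2x2 pool size, stride of 2

**Other Convolutional Blocks**
- Blocks 2–5: same as the first block but with an increasing number of filters (32, 64, 128, 256); block 5 adds SpatialDropout2D (0.2) before pooling
- Block 6: a single Conv2D (512 filters) + Batch Normalization + ReLU, followed by SpatialDropout2D (0.2) and max pooling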

**Classifier Head**
- Global Average Pooling layer
- Dense layer: 128 units, ReLU activation
- Dense layer: 64 units, ReLU activation
- Dropout layer: 0.3 dropout rate
- Dense layer: 5 units (number of classes), Softmax activation

In [None]:
### CNN Architecture
#
# Six convolutional blocks (16 -> 512 filters) followed by a classifier head.
# Input: 256x256 grayscale images. Output: softmax over the 5 classes.

# Seeds for reproducibility. np.random.seed alone does not seed TensorFlow,
# which draws the weight initialisation and the augmentation randomness.
np.random.seed(42)
tf.random.set_seed(42)

# Custom CNN
CNN = keras.Sequential([

    # Input
    layers.InputLayer(shape=[256, 256, 1]),

    # Data Augmentation
    layers.RandomFlip("horizontal"), # flip images horizontally
    layers.RandomRotation(0.1),      # rotate by up to +/-10% of a full turn

    # 1st Convolutional Block
    layers.Conv2D(filters=16, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Conv2D(filters=16, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.MaxPool2D(pool_size=2, strides=2),

    # 2nd Convolutional Block
    layers.Conv2D(filters=32, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Conv2D(filters=32, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.MaxPool2D(pool_size=2, strides=2),

    # 3rd Convolutional Block
    layers.Conv2D(filters=64, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Conv2D(filters=64, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.MaxPool2D(pool_size=2, strides=2),

    # 4th Convolutional Block
    layers.Conv2D(filters=128, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Conv2D(filters=128, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.MaxPool2D(pool_size=2, strides=2),

    # 5th Convolutional Block
    layers.Conv2D(filters=256, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Conv2D(filters=256, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.SpatialDropout2D(0.2),
    layers.MaxPool2D(pool_size=2, strides=2),

    # 6th Convolutional Block
    layers.Conv2D(filters=512, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.SpatialDropout2D(0.2),
    layers.MaxPool2D(pool_size=2, strides=2),

    # Classifier Head
    # Without an activation, the two hidden Dense layers were linear and
    # collapsed into a single linear map. ReLU restores their capacity, and
    # Dropout(0.3) adds the regularisation described in section 3.1.
    layers.GlobalAveragePooling2D(),
    layers.Dense(units=128, activation='relu'),
    layers.Dense(units=64, activation='relu'),
    layers.Dropout(0.3),
    layers.Dense(units=5, activation='softmax')  # 5 classes
])

CNN.summary()


# Compile the model
CNN.compile(
    optimizer = Adam(0.001),
    loss = 'categorical_crossentropy',
    metrics = ['accuracy']
)

### 3.2 Train

In [None]:
### ============== Callbacks ==============

# Early stopping: stop once val_loss has not improved by >= 0.001 for
# 15 epochs, and roll back to the weights of the best epoch.
early_stopping = EarlyStopping(
    monitor='val_loss',
    min_delta=0.001,
    patience=15,
    restore_best_weights=True,
)

# Reduce LR on plateau
lr_scheduler = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,       # multiply LR by 0.2 (i.e. cut it by 80%) when stuck
    patience=5,       # wait 5 epochs without improvement before reducing
    min_lr=1e-6,
    verbose=1,
)

# Checkpoint: keep only the model with the best validation accuracy
checkpoint = ModelCheckpoint(
    filepath='best_CNN.keras',
    monitor='val_accuracy',
    mode='max',             # 'max' because higher accuracy is better
    save_best_only=True,
    verbose=0,              # 0 = silent; 1 would print a message on each save
)

# Timer
timer_cnn = ConvergenceTimer()



### ============== Training ==============

# No `batch_size` argument: the generator already yields batches of BATCH_SIZE,
# and Keras documents that batch_size must not be given for generator input.
hist_CNN = CNN.fit(
    train_gen,
    validation_data=val_gen,
    epochs=30,
    callbacks=[early_stopping, lr_scheduler, checkpoint, timer_cnn],
    verbose=2,
)

print("Training complete!")

### 3.3 Evaluate

In [None]:
### ============== Evaluation ==============

# One figure per metric: training curve vs. validation curve over epochs
for metric, axis_label in (('loss', 'Loss'), ('accuracy', 'Accuracy')):
    plt.plot(hist_CNN.history[metric], label=f'train_{metric}')
    plt.plot(hist_CNN.history[f'val_{metric}'], label=f'val_{metric}')
    plt.xlabel('Epoch')
    plt.ylabel(axis_label)
    plt.legend()
    plt.show()

### 3.4 Test

In [None]:
test_loss_CNN, accuracy_CNN = CNN.evaluate(test_gen)  # evaluate() returns [loss, accuracy]

***

## 4 Residual Networks
3 ResNet models:
- ResNet from scratch
- ResNet pretrained without finetuning
- ResNet pretrained with finetuning

#### 4.1 **ResNet from scratch**

#### 4.1.1 Preprocessing data
To speed up training and reduce overfitting, the input size is reduced to (128, 128, 1).

In [None]:
### Images to correct format

# Smaller 128x128 grayscale inputs for the ResNet trained from scratch
IMG_SIZE = (128, 128)
BATCH_SIZE = 32

# Each split only rescales pixel values from [0, 255] to [0, 1]
train_rn_scratch_, val_rn_scratch_, test_rn_scratch_ = (
    ImageDataGenerator(rescale=1./255) for _ in range(3)
)


def _rn_scratch_flow(datagen, frame, shuffle):
    """Return a 128x128 grayscale, one-hot-labelled batch iterator over *frame*."""
    return datagen.flow_from_dataframe(
        dataframe=frame,
        x_col="filepath",
        y_col="label",
        target_size=IMG_SIZE,
        color_mode="grayscale",
        batch_size=BATCH_SIZE,
        class_mode="categorical",
        shuffle=shuffle,
    )


train_rn_scratch = _rn_scratch_flow(train_rn_scratch_, train_df, shuffle=True)
val_rn_scratch = _rn_scratch_flow(val_rn_scratch_, val_df, shuffle=False)
test_rn_scratch = _rn_scratch_flow(test_rn_scratch_, test_df, shuffle=False)

#### 4.1.2 Create ResNet block 

In [None]:
def resnet_block(x, filters, stride=1):
    """Basic residual block: two 3x3 Conv + BatchNorm layers plus a skip connection.

    The first convolution uses ``stride`` (stride 2 halves the spatial size).
    When the block changes the resolution or the channel count, the skip path
    is projected with a strided 1x1 Conv + BatchNorm so it can be added to the
    main path.

    Args:
        x: input Keras tensor; its last axis is the channel axis.
        filters: number of output channels.
        stride: stride of the first 3x3 convolution and of the projection.

    Returns:
        ReLU(main_path + shortcut) as a Keras tensor with ``filters`` channels.
    """
    identity = x
    needs_projection = stride != 1 or identity.shape[-1] != filters

    # Main path: conv(stride) -> BN -> ReLU -> conv -> BN
    residual = layers.Conv2D(filters, 3, strides=stride, padding="same")(x)
    residual = layers.BatchNormalization()(residual)
    residual = layers.ReLU()(residual)
    residual = layers.Conv2D(filters, 3, padding="same")(residual)
    residual = layers.BatchNormalization()(residual)

    # Skip path: project only when the shapes would otherwise not line up
    if needs_projection:
        identity = layers.Conv2D(filters, 1, strides=stride, padding="same")(identity)
        identity = layers.BatchNormalization()(identity)

    merged = layers.Add()([residual, identity])
    return layers.ReLU()(merged)

#### 4.1.3 Create the model

In [None]:
def build_resnet(input_shape, num_classes):
    """Build a small ResNet-style classifier to be trained from scratch.

    Layout: augmentation (random horizontal flip + rotation) -> 3x3 conv stem
    with 16 filters -> three stages of two residual blocks (16, 32, 64 filters;
    the 32- and 64-filter stages halve the resolution in their first block) ->
    global average pooling -> softmax over ``num_classes``.

    Args:
        input_shape: shape of one input image, e.g. ``(128, 128, 1)``.
        num_classes: number of output classes.

    Returns:
        An uncompiled ``keras.Model``.
    """
    inputs = keras.Input(shape=input_shape)

    augment = keras.Sequential([
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
    ])
    x = augment(inputs)

    # Stem
    x = layers.Conv2D(16, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # Residual stages: (filters, stride of the stage's first block)
    for filters, first_stride in ((16, 1), (32, 2), (64, 2)):
        x = resnet_block(x, filters, stride=first_stride)
        x = resnet_block(x, filters)

    pooled = layers.GlobalAveragePooling2D()(x)
    probabilities = layers.Dense(num_classes, activation="softmax")(pooled)
    return keras.Model(inputs, probabilities)

In [None]:
# Five target classes: Glioma, Healthy, Meningioma, Pituitary, Tumor (Unspecified)
num_classes = 5

# 128x128 single-channel inputs, matching the ResNet-from-scratch generators
model_rn_scratch = build_resnet(input_shape=(128, 128, 1), num_classes=num_classes)

#### 4.1.4 Train the model

In [None]:
# Define early stopping
from tensorflow.keras.callbacks import EarlyStopping

# Stop once val_loss has not improved by >= 0.001 for 20 epochs, and restore
# the weights of the best epoch.
early_stopping = EarlyStopping(
    monitor='val_loss',
    min_delta=0.001,
    patience=20,
    restore_best_weights=True,
)

# Checkpoint
from tensorflow.keras.callbacks import ModelCheckpoint

# NOTE: a ModelCheckpoint keeps its best score across fit() calls. Create a new
# instance (with its own file) for every other model instead of reusing this one.
checkpoint = ModelCheckpoint(
    filepath='best_ResNet.keras',
    monitor='val_accuracy',
    mode='max',             # 'max' because higher accuracy is better
    save_best_only=True,
    verbose=0,              # 0 = silent; 1 would print a message on each save
)

# Reduce learning rate when needed
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,            # halve the learning rate when val_loss stops improving
    patience=3,            # ...for 3 epochs
    verbose=1,
    min_lr=1e-6,
)

In [None]:
## Compile model
model_rn_scratch.compile(
    loss="categorical_crossentropy",
    optimizer=Adam(1e-3),
    metrics=["accuracy"],
)

## Train model (the timer records wall-clock training time)
timer_rn_scratch = ConvergenceTimer()
callbacks_rn_scratch = [checkpoint, early_stopping, reduce_lr, timer_rn_scratch]

history_rn_scratch = model_rn_scratch.fit(
    train_rn_scratch,
    epochs=25,
    validation_data=val_rn_scratch,
    callbacks=callbacks_rn_scratch,
    verbose=2,
)

#### 4.1.5 Visualize the performance on the validation set

In [None]:
# Training vs. validation curves: one figure for loss, one for accuracy
history_frame = pd.DataFrame(history_rn_scratch.history)
for curve_pair in (['loss', 'val_loss'], ['accuracy', 'val_accuracy']):
    history_frame[curve_pair].plot()

#### 4.1.6 Evaluate model on the test set

In [None]:
test_loss_rn_scratch, accuracy_rn_scratch = model_rn_scratch.evaluate(test_rn_scratch)  # [loss, accuracy]

#### 4.2 **Pretrained ResNet without finetuning**


#### 4.2.1 Preprocessing data
The pretrained ResNet-50 model expects the same input format as ImageNet: (224, 224, 3). We therefore build new train, validation and test generators that produce RGB images of this size, preprocessed with the ResNet `preprocess_input` function.

In [None]:
from tensorflow.keras.applications.resnet import preprocess_input
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# ImageNet-sized RGB inputs for the pretrained ResNet50
IMG_SIZE = (224, 224)
BATCH_SIZE = 32

# Every split applies ResNet's own `preprocess_input`; no generator-side augmentation
train_rn_pretrained_, val_rn_pretrained_, test_rn_pretrained_ = (
    ImageDataGenerator(preprocessing_function=preprocess_input) for _ in range(3)
)


def _rn_pretrained_flow(datagen, frame, shuffle):
    """Return a 224x224 RGB, one-hot-labelled batch iterator over *frame*."""
    return datagen.flow_from_dataframe(
        dataframe=frame,
        x_col="filepath",
        y_col="label",
        target_size=IMG_SIZE,
        color_mode="rgb",
        batch_size=BATCH_SIZE,
        class_mode="categorical",
        shuffle=shuffle,
    )


train_rn_pretrained = _rn_pretrained_flow(train_rn_pretrained_, train_df, shuffle=True)
val_rn_pretrained = _rn_pretrained_flow(val_rn_pretrained_, val_df, shuffle=False)
test_rn_pretrained = _rn_pretrained_flow(test_rn_pretrained_, test_df, shuffle=False)


#### 4.2.2 Create the model

In [None]:
from tensorflow.keras.applications import ResNet50
from tensorflow.keras import layers, models

## Pretrained base: ImageNet weights, no classification top, weights frozen
base_model = ResNet50(
    include_top=False,
    weights='imagenet',
    input_shape=(224, 224, 3),
)
base_model.trainable = False

## Full model: augmentation -> frozen ResNet50 -> global pooling -> 5-way softmax
rn_pt_stack = [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(5, activation='softmax'),
]
model_rn_pt = models.Sequential(rn_pt_stack)

#### 4.2.3 Train the model

In [None]:
## Compile model
model_rn_pt.compile(
    optimizer=Adam(1e-3),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Dedicated checkpoint for this model. A ModelCheckpoint remembers its best
# score across fit() calls, so reusing the ResNet-from-scratch `checkpoint`
# would only save when this model beat THAT model's val_accuracy, and would
# overwrite its file.
checkpoint_rn_pt = ModelCheckpoint(
    filepath='best_ResNet_pretrained.keras',
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=0
)

## Train model
timer_rn_pt = ConvergenceTimer() # keep track of training time

history_rn_pt = model_rn_pt.fit(
    train_rn_pretrained,
    validation_data=val_rn_pretrained,
    epochs=25,
    callbacks=[checkpoint_rn_pt, early_stopping, timer_rn_pt],
    verbose=2
)

#### 4.2.4 Visualizing the loss and accuracy for the validation set

In [None]:
# One figure per metric: training curve vs. validation curve over epochs
for metric, axis_label in (('loss', 'Loss'), ('accuracy', 'Accuracy')):
    plt.plot(history_rn_pt.history[metric], label=f'train_{metric}')
    plt.plot(history_rn_pt.history[f'val_{metric}'], label=f'val_{metric}')
    plt.xlabel('Epoch')
    plt.ylabel(axis_label)
    plt.legend()
    plt.show()

### 4.2.5 Test

In [None]:
test_loss_rn_pt, accuracy_rn_pt = model_rn_pt.evaluate(test_rn_pretrained)  # [loss, accuracy]

### 4.3 ResNet **with finetuning**

#### 4.3.1 Create the model with fine tuning
Here we unfreeze the top layers of ResNet50 (the layers after the first 140) so that their weights can be fine-tuned; the first 140 layers stay frozen (not trained).

In [None]:
# Unfreeze the model
base_model.trainable = True

# Freeze the first 140 layers again (generic low-level features)
for layer in base_model.layers[:140]:
    layer.trainable = False

# Keep the BatchNormalization layers in the unfrozen part frozen too. With
# trainable=False, Keras runs BN in inference mode, so small fine-tuning batches
# cannot overwrite the moving statistics learned on ImageNet.
for layer in base_model.layers[140:]:
    if isinstance(layer, layers.BatchNormalization):
        layer.trainable = False

# Reuse the classifier already trained in 4.2 instead of a fresh random Dense.
# A random head trained at LR 1e-5 learns slowly and sends noisy gradients
# into the newly unfrozen layers. NOTE: the layer is shared, so model_rn_pt
# changes too (as it already does through the shared `base_model`).
trained_head = model_rn_pt.layers[-1]

## attach head
model_rn_pt_finetuned = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    trained_head,
])

#### 4.3.2 Train the model

In [None]:
## Compile model
# Very low learning rate so the unfrozen pretrained layers change only slightly
model_rn_pt_finetuned.compile(
    optimizer=Adam(1e-5),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Dedicated checkpoint: a reused ModelCheckpoint would carry over the best
# val_accuracy of an earlier model and overwrite that model's file.
checkpoint_rn_pt_finetuned = ModelCheckpoint(
    filepath='best_ResNet_finetuned.keras',
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=0
)

## Train model
timer_rn_pt_finetuned = ConvergenceTimer()

history_rn_pt_finetuned = model_rn_pt_finetuned.fit(
    train_rn_pretrained,
    validation_data=val_rn_pretrained,
    epochs=10,
    callbacks=[checkpoint_rn_pt_finetuned, early_stopping, timer_rn_pt_finetuned]
)

#### 4.3.3 Visualize the performance after fine tuning for the validation set

In [None]:
import pandas as pd

# Training vs. validation curves after fine-tuning: loss figure, then accuracy figure
history_frame = pd.DataFrame(history_rn_pt_finetuned.history)
for curve_pair in (['loss', 'val_loss'], ['accuracy', 'val_accuracy']):
    history_frame[curve_pair].plot()

#### 4.3.4 Evaluate the model on the test set

In [None]:
test_loss_rn_pt_finetuned, accuracy_rn_pt_finetuned = model_rn_pt_finetuned.evaluate(test_rn_pretrained)  # [loss, accuracy]

***

## 5. Pre-trained Vision Transformer

### 5.1 ViT **without finetuning**

#### 5.1.1 Preprocessing data

In [None]:
# 1. Setup generators for the Hugging Face ViT
# - data_format='channels_first' yields (3, 224, 224) arrays, the `pixel_values`
#   layout the TF ViT models take.
# - Pixels are mapped to [-1, 1]. google/vit-base-patch16-224 was pre-trained on
#   images rescaled to [0, 1] and then normalised with mean 0.5 / std 0.5, i.e.
#   x / 127.5 - 1. Plain 1/255 rescaling feeds it inputs from a shifted range.


def _vit_normalize(x):
    """Map raw [0, 255] pixel values to [-1, 1] (ViT mean/std = 0.5 normalisation)."""
    return x / 127.5 - 1.0


train_datagen = ImageDataGenerator(
    preprocessing_function=_vit_normalize,
    rotation_range=15,
    horizontal_flip=True,
    zoom_range=0.1,
    data_format='channels_first'
)

test_datagen = ImageDataGenerator(
    preprocessing_function=_vit_normalize,
    data_format='channels_first'
)


# 2. Flow from DataFrame (the channel layout comes from the generator itself,
#    so no data_format argument is needed here)
train_generator = train_datagen.flow_from_dataframe(
    dataframe=train_df,
    x_col='filepath',
    y_col='label',
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    shuffle=True
)

val_generator = test_datagen.flow_from_dataframe(
    dataframe=val_df,
    x_col='filepath',
    y_col='label',
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    shuffle=False
)

test_generator = test_datagen.flow_from_dataframe(
    dataframe=test_df,
    x_col='filepath',
    y_col='label',
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    shuffle=False
)

print("Generators recreated with Channels First format.")

# 3. Callbacks for the ViT training run below
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True
)

timer_vit = ConvergenceTimer()

#### 5.1.2 Model Construction

In [None]:
# Number of target classes, taken from the generator's label mapping (5 here)
num_classes = len(train_generator.class_indices)

# ViT-Base, patch size 16, 224x224 inputs
vit_id = "google/vit-base-patch16-224"

# Load the checkpoint with a freshly initialised `num_classes`-way head;
# ignore_mismatched_sizes lets loading proceed despite the different head shape.
vit = TFAutoModelForImageClassification.from_pretrained(
    vit_id,
    num_labels=num_classes,
    ignore_mismatched_sizes=True,
    use_safetensors=False,
)

# Small learning rate for the pre-trained transformer (legacy Keras Adam).
# The model returns raw logits, hence from_logits=True.
optimizer = Adam(learning_rate=5e-5)
loss = CategoricalCrossentropy(from_logits=True)

# jit_compile=True asks TensorFlow to XLA-compile the training step
vit.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=['accuracy'],
    jit_compile=True,
)

print(f"Successfully loaded and compiled {vit_id}")

#### 5.1.3 Train

In [None]:
# Early stopping on val_loss plus the wall-clock timer
vit_callbacks = [early_stopping, timer_vit]

history_vit = vit.fit(
    train_generator,
    validation_data=val_generator,
    epochs=10,
    callbacks=vit_callbacks,
)

### 5.1.4 Test

In [None]:
test_loss_vit, accuracy_vit = vit.evaluate(test_generator)  # [loss, accuracy]

### 5.2 ViT **with finetuning**

- **Stage 1 (Frozen)**: Set backbone.trainable = False. Train for 5-10 epochs. 
- Goal: train only the new classification head on top of the frozen backbone, so it learns basic patterns without disturbing the pre-trained weights.

<br>

- **Stage 2 (Unfrozen)**: Set backbone.trainable = True. Re-compile with 1e-5 LR. Train for 5 more epochs. 
- Goal: once the new head is stable, unfreeze the backbone and fine-tune the whole network end to end at a low learning rate.

<br>

#### 5.2.1 Stage 1

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model, Input
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.callbacks import EarlyStopping, Callback
from transformers import TFViTModel


class CumulativeTimer(tf.keras.callbacks.Callback):
    """Callback that sums wall-clock training time over several ``fit()`` calls.

    Re-use ONE instance for both stages: ``total_time`` (seconds) accumulates
    across calls, while ``stage_start_time`` is reset at the start of each call.
    """

    def __init__(self):
        super().__init__()
        self.total_time = 0.0
        self.stage_start_time = 0.0

    # `logs=None` instead of a shared mutable `{}` default
    def on_train_begin(self, logs=None):
        self.stage_start_time = time.time()
        print("Starting training stage...")

    def on_train_end(self, logs=None):
        end_time = time.time()
        duration = end_time - self.stage_start_time
        self.total_time += duration
        print(f"Stage finished in {duration:.2f} seconds.")
        print(f"Total Cumulative Time so far: {self.total_time:.2f} seconds ({self.total_time/60:.2f} minutes)")

# Instantiate ONE timer object to re-use across both .fit() calls
global_timer = CumulativeTimer()


class ViTCLSEncoder(layers.Layer):
    """Run a Hugging Face TF ViT backbone and return its [CLS] token embedding.

    The backbone is stored as a layer attribute, so Keras tracks it as a
    sub-layer. Its weights then appear in the model's weights, and toggling
    ``backbone.trainable`` (then re-compiling) actually decides whether the
    optimizer updates them. A ``Lambda`` that merely closes over the backbone
    does not track its variables, so the backbone would never be fine-tuned.
    """

    def __init__(self, vit_backbone, **kwargs):
        super().__init__(**kwargs)
        self.vit_backbone = vit_backbone

    def call(self, pixel_values, training=None):
        hidden_states = self.vit_backbone(
            pixel_values=pixel_values, training=training
        ).last_hidden_state
        return hidden_states[:, 0, :]  # CLS token


# 1. SETUP & FREEZE BACKBONE
backbone = TFViTModel.from_pretrained(
    "google/vit-base-patch16-224",
    use_safetensors=False
)
backbone.trainable = False  # <--- FREEZE HERE

# 2. BUILD MODEL: ViT [CLS] embedding -> MLP head -> logits
inputs = Input(shape=(3, 224, 224), name="input_image")
x = ViTCLSEncoder(backbone, name="vit_cls_encoder")(inputs)

# Custom Head
x = layers.Dense(512, activation='relu')(x)
x = layers.Dropout(0.3)(x)
x = layers.Dense(256, activation='relu')(x)
x = layers.Dropout(0.2)(x)
outputs = layers.Dense(num_classes, name="prediction_head")(x)  # logits, no softmax

model_two_stage = Model(inputs=inputs, outputs=outputs, name="ViT_Two_Stage")

# 3. COMPILE & TRAIN (STAGE 1)
# A relatively high learning rate (1e-3) to learn the new head quickly
optimizer_stage1 = Adam(learning_rate=1e-3)
loss = CategoricalCrossentropy(from_logits=True)

model_two_stage.compile(optimizer=optimizer_stage1, loss=loss, metrics=['accuracy'])

# One step per generator batch (= ceil(n / 32)), so the last partial batch is
# also used. Floor division (len(df) // 32) skipped it every epoch.
train_steps = len(train_generator)
val_steps = len(val_generator)

print("\n--- STAGE 1: Training Head Only (Frozen Backbone) ---")
# No `batch_size` argument: the generator already yields batches
history_stage1 = model_two_stage.fit(
    train_generator,
    epochs=5,
    validation_data=val_generator,
    steps_per_epoch=train_steps,
    validation_steps=val_steps,
    # No early stopping needed here usually, just a quick warm-up
    callbacks=[global_timer]
)

#### 5.2.2 Stage 2

In [None]:
# 1. UNFREEZE BACKBONE
backbone.trainable = True # <--- UNFREEZE HERE

# count_params() reports ALL weights, trainable or not, so count the trainable
# ones explicitly. If this equals only the head size, the backbone is not
# tracked by the model and would not be fine-tuned.
n_trainable_params = sum(int(np.prod(w.shape)) for w in model_two_stage.trainable_weights)
print("\nBackbone unfrozen. The model now has", n_trainable_params, "trainable parameters.")

# 2. RE-COMPILE (required for the new trainable flags to take effect)
# Crucial: use a very low learning rate (1e-5) to avoid destroying the pre-trained weights
optimizer_stage2 = Adam(learning_rate=1e-5)

model_two_stage.compile(optimizer=optimizer_stage2, loss=loss, metrics=['accuracy'])

# 3. TRAIN (STAGE 2)
print("\n--- STAGE 2: Fine-Tuning Whole Model (Low LR) ---")

# Early stopping to prevent overfitting the backbone
early_stopping_stage2 = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

history_stage2 = model_two_stage.fit(
    train_generator,
    epochs=10, # Add more epochs if needed
    validation_data=val_generator,
    steps_per_epoch=train_steps,
    validation_steps=val_steps,
    callbacks=[early_stopping_stage2, global_timer]
)

### 5.2.3 Test

In [None]:
test_loss_vit_finetuned, accuracy_vit_finetuned = model_two_stage.evaluate(test_generator)  # [loss, accuracy]

***

## 6. Evaluation

### 6.1 Compile results

In [None]:
### DF with all accuracies and times

# (display name, test accuracy, timer callback) per model, in presentation order
results = [
    ('Custom CNN', accuracy_CNN, timer_cnn),
    ('Custom ResNet', accuracy_rn_scratch, timer_rn_scratch),
    ('Pretrained ResNet', accuracy_rn_pt, timer_rn_pt),
    ('Fine-tuned ResNet', accuracy_rn_pt_finetuned, timer_rn_pt_finetuned),
    ('Pretrained ViT', accuracy_vit, timer_vit),
    ('Fine-tuned ViT', accuracy_vit_finetuned, global_timer),
]

all_models = pd.DataFrame({
    'Model': [name for name, _, _ in results],
    'Test Accuracy': [acc for _, acc, _ in results],
    # The timers measure SECONDS. Convert to minutes to match the
    # "Time (minutes)" axis of the comparison plot and the backup data below.
    'Time to Converge': [timer.total_time / 60 for _, _, timer in results],
})

# ============================= COMMENT OUT ================================
# BACKUP DATA in case the notebook has crashed (times in minutes).
# backup_data = {
#     'Model': ['Custom CNN', 'Custom ResNet', 'Pretrained ResNet', 'Fine-tuned ResNet', 'Pretrained ViT', 'Fine-tuned ViT'],
#     'Test Accuracy': [0.9020, 0.8685, 0.8630, 0.9390, 0.9150, 0.7615],
#     'Time to Converge': [10.58, 4.38, 8.95, 4.45, 36.41, 22.59]
# }
# all_models = pd.DataFrame(backup_data)
# ==========================================================================


# Sort by Accuracy
all_models = all_models.sort_values("Test Accuracy", ascending=False)

### 6.2 Comparison Plots

In [None]:
# ============== COMPARISON PLOTTING ==============
#
# Two horizontal bar charts from `all_models`: test accuracy and time to
# converge. The best model in each chart is drawn in a strong colour.

### COLORS & LOGIC

# Define plot colors
acc_color_light = "#b3cde0"  # Light Blue
acc_color_strong = "#005b96" # Strong Blue

time_color_light = "#fbb4ae" # Light Pink
time_color_strong = "#e31a1c" # Strong Red

# Find the best models (highest accuracy, shortest training time)
best_acc_model = all_models.loc[all_models['Test Accuracy'].idxmax(), 'Model']
best_time_model = all_models.loc[all_models['Time to Converge'].idxmin(), 'Model']

# Map highlighted color to best models
acc_palette = {model: acc_color_strong if model == best_acc_model else acc_color_light
               for model in all_models['Model']}

time_palette = {model: time_color_strong if model == best_time_model else time_color_light
                for model in all_models['Model']}


### PLOTTING

sns.set_theme(style="whitegrid")                 # Seaborn theme
fig, axes = plt.subplots(1, 2, figsize=(16, 6))  # 2 side-by-side plots

# ============= Plot 1: Accuracy =============
sns.barplot(
    data=all_models,
    x="Test Accuracy",
    y="Model",
    ax=axes[0],
    hue="Model",
    palette=acc_palette,   # Pass custom accuracy color map
    legend=False
)
axes[0].set_title("Model Accuracy", fontsize=18, fontweight='bold')
axes[0].set_xlabel("Accuracy", fontsize=14)
axes[0].set_ylabel("")
axes[0].set_xlim(0, 1.1)
axes[0].tick_params(axis='y', which='major', labelsize=16)

# Add labels
for container in axes[0].containers:
    axes[0].bar_label(container, fmt='%.2f', padding=3, fontsize=16)


# ============= Plot 2: Time =============
sns.barplot(
    data=all_models,
    x="Time to Converge",
    y="Model",
    ax=axes[1],
    hue="Model",
    palette=time_palette,  # Pass custom time color map
    legend=False
)
axes[1].set_title("Time to Convergence", fontsize=18, fontweight='bold')
axes[1].set_xlabel("Time (minutes)", fontsize=14)
axes[1].set_ylabel("")
axes[1].tick_params(axis='y', which='major', labelsize=16)
# A fixed 0-40 range clipped longer runs. Keep 40 as the minimum, but widen the
# axis (with ~15% headroom for the bar labels) when a model took longer.
longest_time = all_models['Time to Converge'].max()
axes[1].set_xlim(0, max(40, longest_time * 1.15))

# Add labels
for container in axes[1].containers:
    axes[1].bar_label(container, fmt='%.1f', padding=3, fontsize=16)

plt.tight_layout()
plt.show()