# Tumor Classification in MRI: a Comparison of Neural Network Architectures

**Group:** Songbird  
**Members:** Charlotte de Vries, Jiazhen Tang, Paulo Zirlis  
**Course**: Deep Learning in Python - UvA 2025

In [None]:
# Setup block

import os
import time  # ConvergenceTimer

# --- 1. FORCE LEGACY KERAS (CRITICAL FOR ViT) ---
# This must run before any other keras/tensorflow imports
# To ensure the correct version for the models
os.environ["TF_USE_LEGACY_KERAS"] = "1"

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model, Input, models
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, Callback
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.applications.resnet import preprocess_input
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from transformers import TFAutoModelForImageClassification
from transformers import TFViTModel

from sklearn.utils import resample
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.model_selection import train_test_split

# --- 2. APPLY THE "INPUT LAYER" PATCH (CRITICAL FOR CNN) ---
# This allows the Keras 2 InputLayer to accept the 'shape' argument
# preventing the crash in CNN code.

OriginalInputLayer = layers.InputLayer

class FriendlyInputLayer(OriginalInputLayer):
    def __init__(self, *args, **kwargs):
        # If 'shape' is used (Keras 3 style), swap it to 'input_shape' (Keras 2 style)
        if 'shape' in kwargs and 'input_shape' not in kwargs:
            kwargs['input_shape'] = kwargs.pop('shape')
        super().__init__(*args, **kwargs)

# Overwrite the class in the 'layers' module so code uses it automatically
layers.InputLayer = FriendlyInputLayer
tf.keras.layers.InputLayer = FriendlyInputLayer
# -----------------------------------------------------------

# Setup timer callback
class ConvergenceTimer(Callback):
    """Measure the wall-clock time between the start and end of training."""

    def on_train_begin(self, logs=None):
        self.start_time = time.time()
        print("Training started...")

    def on_train_end(self, logs=None):
        self.end_time = time.time()
        self.total_time = self.end_time - self.start_time
        print("Training finished.")
        print(f"Time to converge: {self.total_time:.2f} seconds ({self.total_time/60:.2f} minutes)")


print("Setup OK: Legacy mode enabled & InputLayer patched.")

In [None]:
# Check if GPU is available
import tensorflow as tf
print(tf.config.list_physical_devices('GPU'))

## 1. Introduction

Medical imaging has been a prominent application area for machine learning in recent years. With the rise of Deep Learning, researchers have increasingly tackled complex imaging tasks as the techniques have become more advanced and refined. As novice data scientists, we were curious about the trade-offs between different model architectures, in particular the complexity of their networks, the accuracy they can achieve, and the time they take to train and fine-tune. With that in mind, our group decided to apply different deep learning architectures to the visual classification task of identifying brain tumors in MRI images and to compare them on accuracy and time to converge. This work may be helpful to researchers and health professionals who question the applicability and required complexity of these models for image classification in the health industry. As individuals, we hope this project enhances our data science skills and lets us delve deeper into the field.


We wished to implement and compare the following NN architectures:
- Custom Convolutional Neural Network (CNN)
- Custom Residual Network
- Pre-trained Residual Network (ResNet50)
- Pre-trained Residual Network (ResNet50) with fine-tuning
- Pre-trained Vision Transformer (ViT)
- Pre-trained Vision Transformer (ViT) with fine-tuning

<br>

### Outline

1. Introduction
2. Methods
    - Data
    - Custom Convolutional Neural Network (CNN)
    - Custom Residual Network
    - Pre-trained Residual Network (ResNet50)
    - Pre-trained Residual Network (ResNet50) with fine-tuning
    - Pre-trained Vision Transformer (ViT)
    - Pre-trained Vision Transformer (ViT) with fine-tuning
3. Results
    - Compile Results
    - Comparison Plots
    - Model Insights
4. Discussion
5. References
6. Division of Labour

***

## 2. Methods

### 2.1 Data

The dataset includes high-resolution CT and MRI images captured from multiple patients, each labeled with the corresponding tumor type (e.g., glioma, meningioma). For this project we focus solely on the **MRI** images for simplicity. The dataset's creator collected these data from different articles to assist researchers and healthcare professionals in developing AI models for the automatic detection, classification, and segmentation of brain tumors. The images were originally organized in two folders, Healthy images and Tumor images; the specific tumor classes were given in the file names.

The class sizes are as follows:
- Healthy images: 2000
- Tumor images: 3000
    - Meningioma: 1112
    - Glioma: 672
    - Pituitary: 629
    - Tumor (unspecified subtype): 587
- **Total number of images:** 5000

Source: [Brain tumor multimodal image (Kaggle)](https://www.kaggle.com/datasets/murtozalikhon/brain-tumor-multimodal-image-ct-and-mri/data)

#### 2.1.1 Load Data

In [None]:
### Load data

# Setup paths
dataset_path = 'Data/Brain Tumor MRI images'

print(f"Checking contents of: {dataset_path}")
try:
    items = os.listdir(dataset_path)
    print("Found these items:", items)
except FileNotFoundError:
    print("Error: The dataset_path does not exist.")

filepaths = []
labels = []



# Get list of all folders in the main directory
all_items = os.listdir(dataset_path)

for item in all_items:
    item_path = os.path.join(dataset_path, item)
    
    # Select only folders (directory)
    if os.path.isdir(item_path):
        
        # --- CASE A: The 'Healthy' Folder ---
        if 'healthy' in item.lower():
            print(f"Processing Healthy folder: {item}")
            for filename in os.listdir(item_path):
                if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.tif')):
                    filepaths.append(os.path.join(item_path, filename))
                    labels.append('Healthy')
                    
        # --- CASE B: The 'Tumour' Folder (Anything that isn't Healthy) ---
        else:
            print(f"Processing Tumour folder: {item}")
            for filename in os.listdir(item_path):
                if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.tif')):
                    full_path = os.path.join(item_path, filename)
                    name_lower = filename.lower()
                    
                    # Determine Subtype based on filename
                    if 'glioma' in name_lower:
                        label = 'Glioma'
                    elif 'meningioma' in name_lower:
                        label = 'Meningioma'
                    elif 'pituitary' in name_lower:
                        label = 'Pituitary'
                    else:
                        label = 'Tumor (Unspecified)' 
                    
                    filepaths.append(full_path)
                    labels.append(label)

# Create DataFrame
df = pd.DataFrame({'filepath': filepaths, 'label': labels})

# Check results
print(f"Total images found: {len(df)}")
print(df['label'].value_counts())


#### 2.1.2 Oversampling

Since the healthy class had between roughly two and three and a half times as many observations as each of the tumor classes, we oversampled the smaller classes (sampling with replacement) so that every class reaches 2000 examples.
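
The effect of oversampling with replacement can be sketched with the standard library alone. This is an illustrative sketch, not the notebook's pipeline: it uses the class counts listed in section 2.1, the image IDs are hypothetical placeholders, and `oversample` is a made-up helper that stands in for `sklearn.utils.resample`.

```python
import random

# Class counts taken from the list in section 2.1; the IDs below are placeholders.
class_counts = {
    "Healthy": 2000,
    "Meningioma": 1112,
    "Glioma": 672,
    "Pituitary": 629,
    "Tumor (Unspecified)": 587,
}

def oversample(samples_by_class, seed=42):
    """Draw every minority class up to the majority size, with replacement."""
    rng = random.Random(seed)
    target = max(len(items) for items in samples_by_class.values())
    balanced = {}
    for label, items in samples_by_class.items():
        if len(items) < target:
            # Sampling with replacement duplicates existing images.
            balanced[label] = rng.choices(items, k=target)
        else:
            balanced[label] = list(items)
    return balanced

samples = {
    label: [f"{label}_{i}" for i in range(n)] for label, n in class_counts.items()
}
balanced = oversample(samples)

balanced_sizes = {label: len(items) for label, items in balanced.items()}
unique_sizes = {label: len(set(items)) for label, items in balanced.items()}
print(balanced_sizes)  # entries per class after balancing
print(unique_sizes)    # distinct images per class after balancing
```

Every class ends up with 2000 entries, but the number of distinct images per class cannot exceed its original size. Because the duplicates are created before the train-test split, copies of one image may land in more than one split; oversampling only the training split is a common way to avoid that.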

In [None]:
# 1. Identify the majority class count
max_count = df['label'].value_counts().max()
print(f"Target count per class: {max_count}")

# 2. Separate the dataframe by class
groups = df.groupby('label')

# 3. Create a list to hold the balanced dataframes
balanced_dfs = []

for label, group_df in groups:
    # If this group is smaller than the max, oversample it
    if len(group_df) < max_count:
        
        # Resample logic:
        # replace=True: This allows duplication (essential for oversampling)
        # n_samples=max_count: Target number of samples
        oversampled_group = resample(group_df, 
                                     replace=True, 
                                     n_samples=max_count, 
                                     random_state=42)
        balanced_dfs.append(oversampled_group)
        print(f"Oversampled {label} from {len(group_df)} to {len(oversampled_group)}")
        
    else:
        # If it's the majority class (Healthy), just keep it as is
        balanced_dfs.append(group_df)
        print(f"Kept {label} at {len(group_df)}")

# 4. Concatenate all back into one DataFrame
df_balanced = pd.concat(balanced_dfs).reset_index(drop=True)

# 5. Shuffle the dataset so classes aren't clustered together
df_balanced = df_balanced.sample(frac=1, random_state=42).reset_index(drop=True)

# 6. Verify the new counts
print("\nNew Label Distribution:")
print(df_balanced['label'].value_counts())

# 7. Update your main df variable
df = df_balanced

#### 2.1.3 Train-test split

For the split, we used 70% of the data for training, 10% for validation and 20% for testing.
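
The two-step split works because the second `test_size` is a fraction of what remains after the first split. A minimal sketch of that arithmetic, assuming the 10,000 images left after oversampling (the helper name `split_fractions` is hypothetical):

```python
from fractions import Fraction

def split_fractions(test_size, val_size_of_remainder):
    """Return (train, val, test) as fractions of the full dataset for a two-step split."""
    test = Fraction(test_size)
    remainder = 1 - test
    val = remainder * Fraction(val_size_of_remainder)
    train = remainder - val
    return train, val, test

# Values used in the two train_test_split calls: 0.2, then 0.125 of the remainder.
train_frac, val_frac, test_frac = split_fractions(Fraction(1, 5), Fraction(1, 8))
print(train_frac, val_frac, test_frac)

n_images = 10_000  # assumption: 5 classes x 2000 images after oversampling
sizes = {
    name: int(frac * n_images)
    for name, frac in [("train", train_frac), ("val", val_frac), ("test", test_frac)]
}
print(sizes)
```

Taking 12.5% of the remaining 80% gives 10% of the full dataset, which is why `test_size=0.125` appears in the second call.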

In [None]:
# Split: 80% Train, 20% Test (using stratify to keep classes balanced)
train_df, test_df = train_test_split(df, test_size=0.2, shuffle=True, random_state=42, stratify=df['label'])

# Split Train again to get Validation set (e.g. 10% of total with stratify)
train_df, val_df = train_test_split(train_df, test_size=0.125, shuffle=True, random_state=42, stratify=train_df['label'])

print(f"Train size: {len(train_df)}")
print(f"Val size:   {len(val_df)}")
print(f"Test size:  {len(test_df)}")


# Visualize to confirm labels are correct
plt.figure(figsize=(14, 8))

# Random sample to check
sample_df = df.sample(10)

for i, (index, row) in enumerate(sample_df.iterrows()):
    plt.subplot(2, 5, i + 1)
    img = mpimg.imread(row['filepath'])
    plt.imshow(img, cmap='gray')
    plt.title(f"{row['label']}\n{os.path.basename(row['filepath'])[:10]}...", fontsize=9) # Show label + part of filename
    plt.axis('off')

plt.tight_layout()
plt.show()

#### 2.1.4 CNN Keras Generators

In [None]:
### Images to correct format for CNN

# Define image size and batch size
IMG_SIZE = (256, 256) 
BATCH_SIZE = 32

# Create ImageDataGenerators
train_datagen = ImageDataGenerator(rescale=1./255)
val_datagen   = ImageDataGenerator(rescale=1./255)
test_datagen  = ImageDataGenerator(rescale=1./255)

# Build generators FROM DATAFRAMES
train_gen = train_datagen.flow_from_dataframe(
    dataframe=train_df,
    x_col="filepath",
    y_col="label",
    target_size=IMG_SIZE,
    color_mode="grayscale",
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    shuffle=True
)

val_gen = val_datagen.flow_from_dataframe(
    dataframe=val_df,
    x_col="filepath",
    y_col="label",
    target_size=IMG_SIZE,
    color_mode="grayscale",
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    shuffle=False
)

test_gen = test_datagen.flow_from_dataframe(
    dataframe=test_df,
    x_col="filepath",
    y_col="label",
    target_size=IMG_SIZE,
    color_mode="grayscale",
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    shuffle=False
)

#### 2.1.5 Custom ResNet Keras Generators

For faster training and less overfitting, the input size for the custom Residual Network is (128, 128, 1), so we use the same data frames but with a different target size.

In [None]:
### Images to correct input size for ResNet

# 1. Define image size and batch size
IMG_SIZE = (128, 128) 
BATCH_SIZE = 32

# 2. Create ImageDataGenerators
train_rn_scratch_ = ImageDataGenerator(rescale=1./255)
val_rn_scratch_  = ImageDataGenerator(rescale=1./255)
test_rn_scratch_  = ImageDataGenerator(rescale=1./255)

# 3. Build generators FROM DATAFRAMES
train_rn_scratch = train_rn_scratch_.flow_from_dataframe(
    dataframe=train_df,
    x_col="filepath",
    y_col="label",
    target_size=IMG_SIZE,
    color_mode="grayscale",
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    shuffle=True
)

val_rn_scratch = val_rn_scratch_.flow_from_dataframe(
    dataframe=val_df,
    x_col="filepath",
    y_col="label",
    target_size=IMG_SIZE,
    color_mode="grayscale",
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    shuffle=False
)

test_rn_scratch = test_rn_scratch_.flow_from_dataframe(
    dataframe=test_df,
    x_col="filepath",
    y_col="label",
    target_size=IMG_SIZE,
    color_mode="grayscale",
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    shuffle=False
)

#### 2.1.6 ResNet50 Keras Generators

The pretrained ResNet-50 model expects the same input format as the ImageNet data it was trained on: (224, 224, 3). Therefore we preprocess the data again, loading the images as RGB and applying ResNet's `preprocess_input`, so that they match what the pretrained model expects.
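
For reference, `preprocess_input` from `tensorflow.keras.applications.resnet` applies "caffe"-style preprocessing: it reorders the channels from RGB to BGR and subtracts the ImageNet per-channel means, without rescaling to [0, 1]. A minimal pure-Python sketch of that operation on a single pixel (the helper name `caffe_preprocess_pixel` is hypothetical):

```python
# ImageNet channel means in BGR order, as used by 'caffe'-mode preprocessing.
IMAGENET_MEAN_BGR = (103.939, 116.779, 123.68)

def caffe_preprocess_pixel(rgb):
    """Hypothetical helper: RGB pixel in [0, 255] -> zero-centred BGR pixel."""
    r, g, b = rgb
    bgr = (b, g, r)
    return tuple(value - mean for value, mean in zip(bgr, IMAGENET_MEAN_BGR))

pure_red = caffe_preprocess_pixel((255, 0, 0))
print(pure_red)  # the red value ends up in the last (R) position of the BGR tuple
```

This is why the ResNet50 generators pass `preprocessing_function=preprocess_input` instead of `rescale=1./255`.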

In [None]:
### image format same as ImageNet

IMG_SIZE = (224, 224)
BATCH_SIZE = 32

train_rn_pretrained_ = ImageDataGenerator(
    preprocessing_function=preprocess_input
)

val_rn_pretrained_ = ImageDataGenerator(
    preprocessing_function=preprocess_input
)

test_rn_pretrained_ = ImageDataGenerator(
    preprocessing_function=preprocess_input
)

train_rn_pretrained = train_rn_pretrained_.flow_from_dataframe(
    dataframe=train_df,
    x_col="filepath",
    y_col="label",
    target_size=IMG_SIZE,
    color_mode="rgb",
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    shuffle=True
)

val_rn_pretrained = val_rn_pretrained_.flow_from_dataframe(
    dataframe=val_df,
    x_col="filepath",
    y_col="label",
    target_size=IMG_SIZE,
    color_mode="rgb",
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    shuffle=False
)

test_rn_pretrained = test_rn_pretrained_.flow_from_dataframe(
    dataframe=test_df,
    x_col="filepath",
    y_col="label",
    target_size=IMG_SIZE,
    color_mode="rgb",
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    shuffle=False
)

#### 2.1.7 ViT Keras Generators

The Hugging Face ViT model expects channels-first input, i.e. (channels, height, width), so we create separate generators with `data_format='channels_first'`.
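
What the channels-first layout means can be shown on a tiny nested-list image. This is only an illustration of the axis reordering; the helper `to_channels_first` is hypothetical, and in the notebook the generators handle this through `data_format`.

```python
def to_channels_first(image_hwc):
    """Convert a nested-list image from (H, W, C) to (C, H, W)."""
    height = len(image_hwc)
    width = len(image_hwc[0])
    channels = len(image_hwc[0][0])
    return [
        [[image_hwc[y][x][c] for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]

# Tiny 2x3 RGB image; pixel values encode their position so the transpose is easy to check.
image = [
    [[10 * y + x, 100 + 10 * y + x, 200 + 10 * y + x] for x in range(3)]
    for y in range(2)
]
chw = to_channels_first(image)
shape = (len(chw), len(chw[0]), len(chw[0][0]))
print(shape)
```

A (224, 224, 3) image becomes (3, 224, 224) in the same way.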

In [None]:
# Setup Generators with 'channels_first'
# We add data_format='channels_first' to match the Hugging Face ViT requirements
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=15,
    horizontal_flip=True,
    zoom_range=0.1,
    data_format='channels_first'
)

test_datagen = ImageDataGenerator(
    rescale=1./255,
    data_format='channels_first'
)


# Flow from DataFrame
train_vit = train_datagen.flow_from_dataframe(
    dataframe=train_df,
    x_col='filepath',
    y_col='label',
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    shuffle=True
)

val_vit = test_datagen.flow_from_dataframe(
    dataframe=val_df,
    x_col='filepath',
    y_col='label',
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    shuffle=False
)

test_vit = test_datagen.flow_from_dataframe(
    dataframe=test_df,
    x_col='filepath',
    y_col='label',
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    shuffle=False  # data_format is already set on test_datagen above
)

print("Generators recreated with Channels First format.")

***

### 2.2 Custom Convolutional Neural Network

The Convolutional Neural Network (CNN) model designed for this project consists of six convolutional blocks followed by a final block with pooling and fully connected layers. Convolutional blocks 1 through 5 each have two instances of a convolutional layer, batch normalization and an activation function (ReLU), followed by max pooling, while the sixth block has only one instance of a convolutional layer, batch normalization and activation before pooling. Early stopping was added to halt training once the validation loss stops improving, which limits overfitting. Batch normalization was used to improve training speed and stability. Regular dropout is problematic for images because it drops individual activations at random, which disturbs the relations between adjacent pixels, so we used Spatial dropout, which drops entire feature maps and thereby preserves those spatial relations while still reducing overfitting. The model was compiled with the Adam optimizer, the categorical cross-entropy loss function, and accuracy as the evaluation metric. Additionally, a learning rate scheduler was used so that the model learns quickly at the start and then stabilizes with smaller steps once the validation loss stops improving. This, along with Spatial dropout, proved to be very effective in training.
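
The difference between the two dropout variants can be sketched on a small nested-list feature map. This is a simplified illustration under our own assumptions (a 4×4 map with 8 channels filled with ones, hypothetical helper names), not the Keras implementation.

```python
import random

def element_dropout(fmap, rate, rng):
    """Standard dropout: every activation is kept or zeroed independently."""
    keep = 1.0 - rate
    return [[[v / keep if rng.random() < keep else 0.0 for v in px] for px in row]
            for row in fmap]

def spatial_dropout(fmap, rate, rng):
    """Spatial dropout: one keep/drop decision per channel (feature map)."""
    keep = 1.0 - rate
    n_channels = len(fmap[0][0])
    mask = [1.0 / keep if rng.random() < keep else 0.0 for _ in range(n_channels)]
    return [[[v * m for v, m in zip(px, mask)] for px in row] for row in fmap]

def values_per_channel(dropped):
    """Set of distinct values found in each channel."""
    return [{dropped[y][x][c] for y in range(4) for x in range(4)} for c in range(8)]

rng = random.Random(0)
fmap = [[[1.0] * 8 for _ in range(4)] for _ in range(4)]  # 4x4 map with 8 channels

elementwise = element_dropout(fmap, rate=0.2, rng=rng)
spatial = spatial_dropout(fmap, rate=0.2, rng=rng)

print(values_per_channel(elementwise))  # channels usually mix zeros and kept values
print(values_per_channel(spatial))      # each channel is entirely zero or entirely kept
```

With spatial dropout, neighbouring positions within a channel are always kept or dropped together, which is the property the paragraph above relies on.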

<br>

#### 2.2.1 Architecture

**Input and Data Augmentation**
- Input layer: shape (256, 256, 1);
- Batch size: 32;
- Data Augmentation: Random rotations and horizontal flips;

**1st to 4th Convolutional Blocks**
- Two 3×3 convolutional layers with Batch Normalization and ReLU;
- Max pooling with kernel 2 and stride 2 for downsampling;
- 1st block has 16 filters, and this doubles with each block (32, 64, 128);

**5th Convolutional Block**
- Two 3×3 convolutional layers with 256 filters, Batch Normalization and ReLU;
- Spatial Dropout 2D of 20%;
- Max pooling with kernel 2 and stride 2 for downsampling;

**6th Convolutional Block**
- One 3×3 convolutional layer with 512 filters, Batch Normalization and ReLU;
- Spatial Dropout 2D of 20%;
- Max pooling with kernel 2 and stride 2 for downsampling;

**Classification Head**
- Global Average Pooling layer;
- Fully connected layer with 128 units;
- Fully connected layer with 64 units;
- Fully connected layer with Softmax activation.
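
The feature-map sizes implied by this architecture can be traced with a few lines of arithmetic. A minimal sketch, assuming `'same'` padding with stride 1 in the convolutions (so only the pooling changes the spatial size); the helper name `trace_cnn_shapes` is hypothetical.

```python
def trace_cnn_shapes(input_size=256, filters=(16, 32, 64, 128, 256, 512)):
    """Track (height, width, channels) after each block's 2x2, stride-2 max pooling."""
    size = input_size
    shapes = []
    for n_filters in filters:
        # 'same' padding with stride 1 keeps the spatial size; pooling halves it.
        size //= 2
        shapes.append((size, size, n_filters))
    return shapes

shapes = trace_cnn_shapes()
for block, shape in enumerate(shapes, start=1):
    print(f"Block {block}: {shape}")
```

Global Average Pooling then reduces the final 4×4×512 map to a 512-dimensional vector that feeds the fully connected layers.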

In [None]:
### CNN Architecture

# Seeds for reproducibility (NumPy and TensorFlow)
np.random.seed(42)
tf.random.set_seed(42)

# Custom CNN
CNN = keras.Sequential([
    
    # Input
    layers.InputLayer(shape = [256, 256, 1]),
    
    # Data Augmentation
    layers.RandomFlip("horizontal"), # flip images horizontally
    layers.RandomRotation(0.1),      # rotate randomly by up to ±10% of a full turn (±36°)


    # 1st Convolutional Block
    layers.Conv2D(filters=16, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Conv2D(filters=16, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.MaxPool2D(pool_size=2, strides=2),

    # 2nd Convolutional Block
    layers.Conv2D(filters=32, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Conv2D(filters=32, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.MaxPool2D(pool_size=2, strides=2),

    # 3rd Convolutional Block
    layers.Conv2D(filters=64, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Conv2D(filters=64, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.MaxPool2D(pool_size=2, strides=2),

    # 4th Convolutional Block
    layers.Conv2D(filters=128, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Conv2D(filters=128, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.MaxPool2D(pool_size=2, strides=2),

    # 5th Convolutional Block
    layers.Conv2D(filters=256, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Conv2D(filters=256, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.SpatialDropout2D(0.2),
    layers.MaxPool2D(pool_size=2, strides=2),

    # 6th Convolutional Block
    layers.Conv2D(filters=512, kernel_size=3, strides=1, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.SpatialDropout2D(0.2),
    layers.MaxPool2D(pool_size=2, strides=2),

    # Classifier Head
    layers.GlobalAveragePooling2D(),
    layers.Dense(units=128),  # no activation given, so this layer is linear
    layers.Dense(units=64),   # no activation given, so this layer is linear
    layers.Dense(units=5, activation='softmax')  # 5 classes
])

CNN.summary()


# Compile the model
CNN.compile(
    optimizer = Adam(0.001),
    loss = 'categorical_crossentropy',
    metrics = ['accuracy']
)

#### 2.2.2 Train

In [None]:
### ============== Callbacks ==============

# Early stopping
early_stopping = EarlyStopping(
    min_delta = 0.001,
    patience = 15,
    restore_best_weights = True
)

# Reduce LR on plateau
lr_scheduler = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,       # Reduce LR by 80% when stuck
    patience=5,       # Wait 5 epochs before reducing
    min_lr=1e-6,      
    verbose=1
)

# Checkpoint
checkpoint = ModelCheckpoint(
    filepath='best_CNN.keras',    # Naming the file
    monitor='val_accuracy',       # What to monitor
    mode='max',                   # 'max' for accuracy, 'min' for loss
    save_best_only=True,          
    verbose=0
)

# Timer
timer_cnn = ConvergenceTimer()



### ============== Training ==============

# Fit the model
hist_CNN = CNN.fit(
    train_gen,
    validation_data = val_gen,
    # batch_size is not passed here: the generators already yield batches of 32
    epochs = 30,
    callbacks = [early_stopping, lr_scheduler, checkpoint, timer_cnn],
    verbose = 2
)

print("Training complete!")
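
How the learning rate scheduler behaves can be sketched with a simplified re-implementation. This is not the Keras `ReduceLROnPlateau` code: it ignores `min_delta` and `cooldown`, the validation losses are hypothetical, and only `factor`, `patience` and `min_lr` mirror the values used above.

```python
def simulate_reduce_on_plateau(val_losses, lr=1e-3, factor=0.2, patience=5, min_lr=1e-6):
    """Simplified plateau scheduler: returns the learning rate in effect after each epoch."""
    best = float("inf")
    wait = 0
    history = []
    for loss in val_losses:
        if loss < best:
            # Improvement: remember it and reset the patience counter.
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                # Plateau: shrink the learning rate, but never below min_lr.
                lr = max(lr * factor, min_lr)
                wait = 0
        history.append(lr)
    return history

# Hypothetical validation losses: 3 improving epochs, then a 10-epoch plateau.
losses = [1.0, 0.8, 0.7] + [0.7] * 10
lrs = simulate_reduce_on_plateau(losses)
print(lrs)
```

In this made-up run the rate drops by a factor of five after every five epochs without improvement, so a long plateau quickly moves training to much smaller steps.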

#### 2.2.3 Evaluate

In [None]:
### ============== Evaluation ==============

# Define function for train loss and accuracy
def eval_train(hist):
    # Loss
    plt.plot(hist.history['loss'], label='train_loss')
    plt.plot(hist.history['val_loss'], label='val_loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

    # Accuracy
    plt.plot(hist.history['accuracy'], label='train_accuracy')
    plt.plot(hist.history['val_accuracy'], label='val_accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.show()

# Evaluate
eval_train(hist_CNN)

#### 2.2.4 Test

In [None]:
_, accuracy_CNN = CNN.evaluate(test_gen)

***

### 2.3 Custom Residual Network

#### 2.3.1 Architecture

**Input and Data Augmentation**
- Input layer: shape (128, 128, 1)
- Batch size: 32
- Data Augmentation: Random rotations and horizontal flips

**ResNet Block (function)**
- Two 3×3 convolutional layers with Batch Normalization and ReLU
- Optional 1×1 convolution in the shortcut path to match dimensions
- Residual (skip) connection followed by ReLU activation

**Backbone Architecture**
- Initial convolutional layer followed by multiple residual blocks
- The number of filters increases across stages (16 → 32 → 64)
- Downsampling is performed by using a stride of 2 in the first block of each stage

**Classification Head**
- Global Average Pooling layer
- Fully connected layer with Softmax activation
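
The stages described above can be traced to see where the spatial size changes and where the shortcut needs a 1×1 projection. A minimal sketch: the block list mirrors the calls in `build_resnet`, while the helper name `trace_resnet` is hypothetical.

```python
import math

def trace_resnet(input_size=128, stem_filters=16,
                 blocks=((16, 1), (16, 1), (32, 2), (32, 1), (64, 2), (64, 1))):
    """Trace each residual block: output size, channels, and whether the shortcut
    needs a 1x1 projection (stride != 1 or a change in channel count)."""
    size, channels = input_size, stem_filters
    trace = []
    for filters, stride in blocks:
        needs_projection = stride != 1 or channels != filters
        size = math.ceil(size / stride)  # 'same' padding: out = ceil(in / stride)
        channels = filters
        trace.append((size, channels, needs_projection))
    return trace

trace = trace_resnet()
for i, (size, channels, proj) in enumerate(trace, start=1):
    print(f"block {i}: {size}x{size}x{channels}, projection shortcut: {proj}")
```

Only the first block of the 32-filter and 64-filter stages needs the projection, because that is where both the stride and the channel count change.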

In [None]:
### =============== RESNET BLOCK ===============

def resnet_block(x, filters, stride=1):
    shortcut = x

    x = layers.Conv2D(filters, 3, strides=stride, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    x = layers.Conv2D(filters, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)

    if stride != 1 or shortcut.shape[-1] != filters:
        shortcut = layers.Conv2D(filters, 1, strides=stride, padding="same")(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)

    x = layers.Add()([x, shortcut])
    x = layers.ReLU()(x)
    return x

In [None]:
## Function to build the model
def build_resnet(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)

    data_augmentation = keras.Sequential([
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
    ])

    x = data_augmentation(inputs)

    x = layers.Conv2D(16, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    x = resnet_block(x, 16)
    x = resnet_block(x, 16)

    x = resnet_block(x, 32, stride=2)
    x = resnet_block(x, 32)

    x = resnet_block(x, 64, stride=2)
    x = resnet_block(x, 64)

    # An extra 128-filter stage resulted in overfitting on the training data, so it was left out
    # x = resnet_block(x, 128, stride=2)
    # x = resnet_block(x, 128)

    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    return keras.Model(inputs, outputs)

In [None]:
### =============== MODEL INPUTS ===============

num_classes= 5

model_rn_scratch = build_resnet(
    input_shape=(128, 128, 1),
    num_classes=num_classes
)

#### 2.3.2 Train

In [None]:
### =================== CALLBACKS ===================

# Define early stopping
early_stopping_rn = EarlyStopping(
    min_delta = 0.001,
    patience = 8,
    restore_best_weights = True
)

# Checkpoint
checkpoint = ModelCheckpoint(
    filepath='best_rn_scratch.keras',    # Naming the file
    monitor='val_accuracy',       # What to monitor
    mode='max',                   # 'max' for accuracy, 'min' for loss
    save_best_only=True,          
    verbose=0                     # 0 = silent; set to 1 to print a message when saving
)

# Reduce learning rate when needed
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',    
    factor=0.5,            # multiply the learning rate by 0.5 if there is no improvement
    patience=3,            
    verbose=1,            
    min_lr=1e-6            
)

In [None]:
### =================== TRAINING ===================

## Compile model
model_rn_scratch.compile(
    optimizer=Adam(1e-3),
    loss="categorical_crossentropy",
    metrics=["accuracy"]
)

## Train model
timer_rn_scratch = ConvergenceTimer()

history_rn_scratch = model_rn_scratch.fit(
    train_rn_scratch,
    validation_data=val_rn_scratch,
    epochs=100,
    callbacks=[checkpoint, early_stopping_rn, reduce_lr, timer_rn_scratch],
    verbose=2
)

#### 2.3.3 Evaluate

In [None]:
eval_train(history_rn_scratch)

#### 2.3.4 Test

In [None]:
_, accuracy_rn_scratch = model_rn_scratch.evaluate(test_rn_scratch)

## Find the epoch with the best training accuracy
# and the validation accuracy at that same epoch
train_accs = history_rn_scratch.history['accuracy']
val_accs   = history_rn_scratch.history['val_accuracy']
best_train_epoch = train_accs.index(max(train_accs))
val_acc_at_best_train = val_accs[best_train_epoch]

print(f"Best train accuracy: {train_accs[best_train_epoch]:.4f}")
print(f"Validation accuracy at that epoch: {val_acc_at_best_train:.4f}")
print(f"Test accuracy: {accuracy_rn_scratch:.4f}")
print(f"Epoch: {best_train_epoch+1}")

**Conclusion:** The model converged after 50 epochs, reaching an accuracy of 0.91 on the training set, 0.87 on the validation set, and 0.88 on the test set. The gap between training and validation accuracy suggests slight overfitting, but it is small, and the test accuracy is in line with the validation accuracy, so the model generalizes reasonably well to unseen data.

### 2.4 Pretrained ResNet **without finetuning**

#### 2.4.1 Architecture

**Input and Data Augmentation**
- Input layer: shape (224, 224, 3)
- Batch size: 32
- Data Augmentation: Random rotations and horizontal flips

**Backbone Architecture**
- Pretrained base from the ResNet-50, trained on ImageNet

**Classification Head**
- Global Average Pooling layer
- Fully connected layer with Softmax activation


In [None]:
from tensorflow.keras.applications import ResNet50

## pretrained base
base_model = ResNet50(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3)
)
base_model.trainable = False   # Freeze weights

## attach head
model_rn_pt = models.Sequential([

    layers.RandomFlip("horizontal"), # data augmentation
    layers.RandomRotation(0.1),

    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(5, activation='softmax')
])

#### 2.4.2 Train

In [None]:
## Callbacks
# Checkpoint
checkpoint = ModelCheckpoint(
    filepath='best_rn_pt.keras',    # Naming the file
    monitor='val_accuracy',       # What to monitor
    mode='max',                   # 'max' for accuracy, 'min' for loss
    save_best_only=True,          
    verbose=0                     # 0 = silent; set to 1 to print a message when saving
)

## Compile model
model_rn_pt.compile(
    optimizer=Adam(1e-3),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

## Train model
timer_rn_pt = ConvergenceTimer() # keep track of training time

history_rn_pt = model_rn_pt.fit(
    train_rn_pretrained,
    validation_data=val_rn_pretrained,
    epochs=50,
    callbacks=[checkpoint, reduce_lr, early_stopping, timer_rn_pt],
    verbose=2
)

#### 2.4.3 Evaluate

In [None]:
eval_train(history_rn_pt)

#### 2.4.4 Test

In [None]:
## Evaluate the model on the test data
_, accuracy_rn_pt = model_rn_pt.evaluate(test_rn_pretrained)

## Find the epoch with the best training accuracy
# and the validation accuracy at that same epoch
train_accs = history_rn_pt.history['accuracy']
val_accs   = history_rn_pt.history['val_accuracy']
best_train_epoch = train_accs.index(max(train_accs))
val_acc_at_best_train = val_accs[best_train_epoch]

print(f"Best train accuracy: {train_accs[best_train_epoch]:.4f}")
print(f"Validation accuracy at that epoch: {val_acc_at_best_train:.4f}")
print(f"Test accuracy: {accuracy_rn_pt:.4f}")
print(f"Epoch: {best_train_epoch+1}")

**Conclusion:** The model has an accuracy of 0.91 on the training set, 0.87 on the validation set, and 0.88 on the test set. The gap between training and validation accuracy indicates slight overfitting, but it is small, and the test accuracy is in line with the validation accuracy, so the model generalizes reasonably well to unseen data. The model performs similarly to the custom ResNet, but the pretrained model converged at epoch 37 instead of 50.

***

### 2.5 Pretrained ResNet **with finetuning**

#### 2.5.1 Architecture

**Input and Data Augmentation**
- Input layer: shape (224, 224, 3)
- Batch size: 32
- Data Augmentation: Random rotations and horizontal flips

**Backbone Architecture**
- Pretrained base from the ResNet-50, trained on ImageNet
- The first 150 layers of the base are frozen
- The last 27 layers of the base are unfrozen, so their weights can be trained on our data

**Classification Head**
- Global Average Pooling layer
- Fully connected layer with Softmax activation

In [None]:
# Unfreeze the model
base_model.trainable = True

# Freeze the first 150 layers again
for layer in base_model.layers[:150]:
    layer.trainable = False

## attach head
model_rn_pt_finetuned = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(5, activation='softmax')
])

Here we unfreeze every layer of the ResNet50 base after the first 150, so that the weights of these later layers can be trained on our data. The first 150 layers stay frozen, matching `base_model.layers[:150]` in the code above.

#### 2.5.2 Train

In [None]:
## Callbacks
# Checkpoint

checkpoint = ModelCheckpoint(
    filepath='best_rn_pt_finetuned.keras',    # Naming the file
    monitor='val_accuracy',       # What to monitor
    mode='max',                   # 'max' for accuracy, 'min' for loss
    save_best_only=True,          
    verbose=0                     # 0 = silent; set to 1 to print a message when saving
)

## Compile model
model_rn_pt_finetuned.compile(
    optimizer=Adam(1e-5),    
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

## Train model
timer_rn_pt_finetuned = ConvergenceTimer()

history_rn_pt_finetuned = model_rn_pt_finetuned.fit(
    train_rn_pretrained,
    validation_data=val_rn_pretrained,
    epochs=50,
    callbacks=[checkpoint, early_stopping_rn, reduce_lr, timer_rn_pt_finetuned],
    verbose=2
    )

#### 2.5.3 Evaluate

In [None]:
eval_train(history_rn_pt_finetuned)

#### 2.5.4 Test

In [None]:
## Evaluate the model on the test set
_, accuracy_rn_pt_finetuned = model_rn_pt_finetuned.evaluate(test_rn_pretrained)

## Find the epoch with the best training accuracy
# and the validation accuracy at that same epoch
train_accs = history_rn_pt_finetuned.history['accuracy']
val_accs   = history_rn_pt_finetuned.history['val_accuracy']
best_train_epoch = train_accs.index(max(train_accs))
val_acc_at_best_train = val_accs[best_train_epoch]

print(f"Best train accuracy: {train_accs[best_train_epoch]:.4f}")
print(f"Validation accuracy at that epoch: {val_acc_at_best_train:.4f}")
print(f"Test accuracy: {accuracy_rn_pt_finetuned:.4f}")
print(f"Epoch: {best_train_epoch+1}")

**Conclusion:** The model converged after 13 epochs, reaching an accuracy of 0.97 on the training set, 0.95 on the validation set, and 0.94 on the test set. This indicates slight overfitting to the training data, but the small gap between training, validation, and test performance suggests that the model generalizes reasonably well to unseen data. The model is more accurate than both the custom ResNet and the pretrained ResNet without fine-tuning, indicating that fine-tuning improves performance.
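
The test cell in this section picks the epoch with the highest training accuracy, while the checkpoint callback monitors `val_accuracy`. As a possible alternative, the sketch below selects the epoch by a chosen monitored metric from a history-style dictionary. The helper `best_epoch` and the numbers in `demo_history` are hypothetical and only illustrate the bookkeeping; they are not results from this notebook.

```python
def best_epoch(history, monitor="val_accuracy"):
    """Return (index, metrics) for the epoch where `monitor` is highest.

    `history` maps metric names to per-epoch lists, like `History.history`.
    """
    values = history[monitor]
    idx = max(range(len(values)), key=values.__getitem__)
    return idx, {name: per_epoch[idx] for name, per_epoch in history.items()}


# Made-up values for illustration only
demo_history = {
    "accuracy": [0.80, 0.93, 0.97],
    "val_accuracy": [0.78, 0.95, 0.94],
}

idx, metrics = best_epoch(demo_history)
print(f"Best epoch by val_accuracy: {idx + 1}, metrics: {metrics}")
```

Selecting on the validation metric keeps the reported numbers aligned with the weights that a `val_accuracy` checkpoint would have saved.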

### 2.6 Pre-trained Vision Transformer **without finetuning**

#### 2.6.1 Architecture

**Input and Data Augmentation**

- Input layer: Shape (224, 224, 3), as implied by the model ID vit-base-patch16-224
- Batch size: 32 (assumed, as for the previous models)
- Data Augmentation: Random rotations and horizontal flips (assumed, as for the previous models)

**Backbone Architecture** 
- Pretrained base: Vision Transformer (ViT) Base model (google/vit-base-patch16-224)
- Patching: Images are split into 16x16 patches
- Parameters: ~86 Million parameters
- Source: Pre-trained on ImageNet-21k and fine-tuned on ImageNet-1k (per the Hugging Face model card)
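
The patching step listed above can be worked through as simple arithmetic. The sketch below assumes a 224x224 RGB input and non-overlapping 16x16 patches, as the model ID suggests; the helper name `vit_token_count` is hypothetical.

```python
def vit_token_count(image_size=224, patch_size=16, channels=3):
    """Return (number of patches, sequence length incl. CLS token, values per patch)."""
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    values_per_patch = patch_size * patch_size * channels
    return num_patches, num_patches + 1, values_per_patch


num_patches, seq_len, values_per_patch = vit_token_count()
print(f"{num_patches} patches, {seq_len} tokens with CLS, {values_per_patch} values per patch")
```

Each image therefore becomes a sequence of 14 x 14 patches plus one classification token before it enters the transformer encoder.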

**Classification Head**
- Output Layer: Linear projection layer adjusted for num_classes (5 classes)
- Activation: Softmax (implicitly handled via from_logits = True in the loss function during training)
- Optimizer: Adam with a learning rate of 5e-5
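
To make the `from_logits=True` remark concrete, the following plain-Python sketch shows how a softmax and a categorical cross-entropy can be computed directly from raw logits. It is an illustration of the idea under stated assumptions, not the Keras implementation; the function names and the example logits are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of raw scores."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy_from_logits(logits, one_hot):
    """Categorical cross-entropy computed from logits via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return -sum(t * (z - log_sum_exp) for z, t in zip(logits, one_hot))


# Hypothetical logits for 5 classes; the true class is index 0
demo_logits = [2.0, 1.0, 0.1, -1.0, 0.5]
demo_target = [1, 0, 0, 0, 0]

probs = softmax(demo_logits)
loss = cross_entropy_from_logits(demo_logits, demo_target)
print("probabilities:", [round(p, 3) for p in probs])
print("loss:", round(loss, 4))
```

The loss equals the negative log of the softmax probability of the true class, which is why the model itself can output logits and leave the softmax to the loss function.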

In [None]:
# Derive the number of classes from the generator (should be 5)
num_classes = len(train_vit.class_indices)

# Load the ViT Base Model (86M Parameters)
vit_id = "google/vit-base-patch16-224"

vit = TFAutoModelForImageClassification.from_pretrained(
    vit_id, 
    num_labels=num_classes, 
    ignore_mismatched_sizes=True,
    use_safetensors=False
)


# Compile the model
# jit_compile=True will help speed up this heavy model on the GPU
vit.compile(
    optimizer = Adam(learning_rate=5e-5), # ViT requires a small learning rate (5e-5)
    loss = CategoricalCrossentropy(from_logits=True), 
    metrics=['accuracy'],
    jit_compile=True 
)

print(f"Successfully loaded and compiled {vit_id}")

#### 2.6.2 Train

In [None]:
### ================ CALLBACKS ================

early_stopping = EarlyStopping(
    monitor='val_loss', 
    patience=3, 
    restore_best_weights=True
)

timer_vit = ConvergenceTimer()


### ================ TRAIN ================

history_vit = vit.fit(
    train_vit,
    epochs=10, 
    validation_data=val_vit,
    callbacks=[early_stopping, timer_vit]
)

#### 2.6.3 Evaluate

In [None]:
eval_train(history_vit)

#### 2.6.4 Test

In [None]:
_, accuracy_vit = vit.evaluate(test_vit)

**Conclusion:** The model reaches an accuracy of 0.93 on the training set and 0.91 on the validation set at the best epoch. This indicates slight overfitting to the training data, as the training loss continues to decrease while the validation loss fluctuates. The gap remains small, however, suggesting that the model generalizes well. Notably, the pretrained model reached high performance very quickly, with ~90% validation accuracy by the second epoch, and converged in fewer epochs (11). However, its running time is the longest of all models (36.4 minutes).

***

### 2.7 Pre-trained ViT **with finetuning**

#### 2.7.1 Preparations
**Stage 1 (Frozen)**

- Action: The ViT backbone is set to trainable = False. A custom classification head (Dense 512 $\to$ Dropout $\to$ Dense 256 $\to$ Dropout $\to$ Output) is added. The model is compiled with a standard learning rate (1e-3) and trained for 5 epochs.

- Goal: To force the new, randomly initialized "Custom Head" layers to stabilize and learn basic patterns without disturbing the pre-trained weights of the ViT Backbone. This prevents "catastrophic forgetting" of the features the backbone already knows.

<br>

**Stage 2 (Unfrozen)** 

- Action: The backbone is set to trainable = True. The model is re-compiled with a much lower learning rate (1e-5) to ensure gentle updates. It is then trained for up to 10 more epochs, utilizing Early Stopping to prevent overfitting.

- Goal: Now that the head is stable, we let the whole network (Backbone + Head) train together. The low learning rate allows the backbone to "fine-tune" its features specifically for this dataset, perfecting the results without destroying the pre-trained knowledge.
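
As a rough size check on the custom head described in Stage 1, the number of weights in a stack of Dense layers can be computed by hand. The sketch below assumes the ViT-Base hidden size of 768 for the CLS token and 5 output classes; it counts only the three Dense layers (Dropout adds no weights), and the helper names are hypothetical.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


def head_params(hidden_size=768, widths=(512, 256), num_classes=5):
    """Total parameters of hidden_size -> widths... -> num_classes Dense layers."""
    sizes = [hidden_size, *widths, num_classes]
    return sum(dense_params(a, b) for a, b in zip(sizes, sizes[1:]))


print("Dense 768 -> 512:", dense_params(768, 512))
print("Dense 512 -> 256:", dense_params(512, 256))
print("Dense 256 -> 5:  ", dense_params(256, 5))
print("Head total:      ", head_params())
```

All of these weights start from random initialization, which is the motivation for training them with a frozen backbone first.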

In [None]:
# Define global timer for 2 stages
class CumulativeTimer(tf.keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self.total_time = 0.0
        self.stage_start_time = 0.0

    def on_train_begin(self, logs=None):
        # Called at the start of every .fit() call, i.e. once per stage
        self.stage_start_time = time.time()
        print("Starting training stage...")

    def on_train_end(self, logs=None):
        # Add this stage's duration (in seconds) to the running total
        end_time = time.time()
        duration = end_time - self.stage_start_time
        self.total_time += duration
        print(f"Stage finished in {duration:.2f} seconds.")
        print(f"Total Cumulative Time so far: {self.total_time:.2f} seconds ({self.total_time/60:.2f} minutes)")

# Instantiate ONE timer object to re-use across both .fit() calls
global_timer = CumulativeTimer()

#### 2.7.2 Stage 1

In [None]:
# 1. SETUP BACKBONE
backbone = TFViTModel.from_pretrained(
    "google/vit-base-patch16-224",
    use_safetensors=False
)
backbone.trainable = False # Freeze the backbone

# 2. BUILD MODEL (Same architecture as before)
inputs = Input(shape=(3, 224, 224), name="input_image")

def get_vit_features(x):
    return backbone(x).last_hidden_state

x = layers.Lambda(get_vit_features)(inputs)
x = x[:, 0, :] # CLS token

# Custom Head
x = layers.Dense(512, activation='relu')(x)
x = layers.Dropout(0.3)(x)
x = layers.Dense(256, activation='relu')(x)
x = layers.Dropout(0.2)(x)
outputs = layers.Dense(num_classes, name="prediction_head")(x)

model_two_stage = Model(inputs=inputs, outputs=outputs, name="ViT_Two_Stage")


# Compile
model_two_stage.compile(
    optimizer=Adam(learning_rate=1e-3),   # Standard learning rate to stabilize head
    loss=CategoricalCrossentropy(from_logits=True), 
    metrics=['accuracy']
)


# ==================== TRAIN ====================
print("\n--- STAGE 1: Training Head Only (Frozen Backbone) ---")

train_steps = len(train_df) // 32
val_steps = len(val_df) // 32

history_stage1 = model_two_stage.fit(
    train_vit,
    epochs=5,  
    validation_data=val_vit,
    batch_size = 32,
    steps_per_epoch=train_steps,
    validation_steps=val_steps,
    # No early stopping needed here
    callbacks=[global_timer]
)

#### 2.7.3 Stage 2

In [None]:
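# Unfreeze backbone
backbone.trainable = True

# count_params() returns all parameters, so count the trainable weights explicitly
n_trainable = sum(tf.keras.backend.count_params(w) for w in model_two_stage.trainable_weights)
print("\nBackbone unfrozen. The model now has", n_trainable, "trainable parameters.")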

# Re-compile
model_two_stage.compile(
    optimizer = Adam(learning_rate=1e-5), # low LR to avoid breaking pre-trained weights
    loss = CategoricalCrossentropy(from_logits=True), 
    metrics = ['accuracy']
)

# Early Stopping to prevent overfitting the backbone
early_stopping_stage2 = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)


# ==================== TRAIN ====================
history_stage2 = model_two_stage.fit(
    train_vit,
    epochs = 10,
    validation_data = val_vit,
    steps_per_epoch = train_steps,
    validation_steps = val_steps,
    callbacks = [early_stopping_stage2,global_timer]
)

#### 2.7.4 Evaluate

In [None]:
eval_train(history_stage2)

#### 2.7.5 Test

In [None]:
_, accuracy_vit_finetuned = model_two_stage.evaluate(test_vit)

***

**Conclusion**

The two-stage training approach successfully stabilized the model before fine-tuning, resulting in consistent performance gains. In Stage 1 (Frozen Backbone), the model quickly learned the new classification task, jumping from 54% to 81% accuracy in just 5 epochs. This confirmed that the custom head was successfully initialized without distorting the pre-trained features.

In Stage 2 (Fine-Tuning), unfreezing the backbone and applying a low learning rate further improved performance. The model refined its accuracy from ~82% to 86.7% on the training set and 86.1% on the validation set. The final validation loss (0.33) and accuracy remained very close to the training metrics, indicating that the model generalizes well with minimal overfitting. 

## 3. Results

### 3.2 Compile results

In [None]:
### DF with all accuracies and times

all_models = pd.DataFrame({
    'Model': [
        'Custom CNN', 
        'Custom ResNet', 
        'Pretrained ResNet', 
        'Fine-tuned ResNet', 
        'Pretrained ViT', 
        'Fine-tuned ViT'
    ],
    'Test Accuracy': [
        accuracy_CNN, 
        accuracy_rn_scratch, 
        accuracy_rn_pt, 
        accuracy_rn_pt_finetuned, 
        accuracy_vit, 
        accuracy_vit_finetuned
    ],
    'Time to Converge': [
        timer_cnn.total_time, 
        timer_rn_scratch.total_time, 
        timer_rn_pt.total_time, 
        timer_rn_pt_finetuned.total_time, 
        timer_vit.total_time, 
        global_timer.total_time / 60  # CumulativeTimer records seconds; assumes the other timers report minutes
    ]
})

# ============================= COMMENT OUT ================================
# BACKUP DATA in case the notebook has crashed.
# backup_data = {
#     'Model': ['Custom CNN', 'Custom ResNet', 'Pretrained ResNet', 'Fine-tuned ResNet', 'Pretrained ViT', 'Fine-tuned ViT'],
#     'Test Accuracy': [0.9020, 0.8685, 0.8630, 0.9390, 0.9150, 0.7615],
#     'Time to Converge': [10.58, 4.38, 8.95, 4.45, 36.41, 22.59]
# }
# all_models = pd.DataFrame(backup_data)
# ==========================================================================


# Sort by Accuracy
all_models = all_models.sort_values("Test Accuracy", ascending=False)
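
As a plain-Python cross-check of the ranking, the sketch below uses the backup values listed in the commented-out block of the cell above and finds the most accurate and the fastest model. It assumes the times are in minutes, as the axis label of the comparison plot suggests; the variable names are hypothetical.

```python
model_names = [
    "Custom CNN", "Custom ResNet", "Pretrained ResNet",
    "Fine-tuned ResNet", "Pretrained ViT", "Fine-tuned ViT",
]
test_accuracy = [0.9020, 0.8685, 0.8630, 0.9390, 0.9150, 0.7615]
minutes_to_converge = [10.58, 4.38, 8.95, 4.45, 36.41, 22.59]

rows = list(zip(model_names, test_accuracy, minutes_to_converge))

# Rank by test accuracy, highest first
ranked = sorted(rows, key=lambda row: row[1], reverse=True)
best_acc_name = ranked[0][0]

# Fastest model by time to converge
fastest_name = min(rows, key=lambda row: row[2])[0]

for name, acc, minutes in ranked:
    print(f"{name:<18} accuracy={acc:.4f}  time={minutes:.2f} min")
print("Most accurate:", best_acc_name)
print("Fastest:", fastest_name)
```

This mirrors what `sort_values`, `idxmax`, and `idxmin` do on the DataFrame and can serve as a sanity check when the live variables are unavailable.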

### 3.3 Comparison Plots

In [None]:
# ============== COMPARISON PLOTTING ==============

### COLORS & LOGIC

# Define plot colors
acc_color_light = "#b3cde0"  # Light Blue
acc_color_strong = "#005b96" # Strong Blue

time_color_light = "#fbb4ae" # Light Pink
time_color_strong = "#e31a1c" # Strong Red

# Find the best models
best_acc_model = all_models.loc[all_models['Test Accuracy'].idxmax(), 'Model']
best_time_model = all_models.loc[all_models['Time to Converge'].idxmin(), 'Model']

# Map highlighted color to best models
acc_palette = {model: acc_color_strong if model == best_acc_model else acc_color_light 
               for model in all_models['Model']}

time_palette = {model: time_color_strong if model == best_time_model else time_color_light 
                for model in all_models['Model']}


### PLOTTING

sns.set_theme(style="whitegrid")                 # Seaborn theme
fig, axes = plt.subplots(1, 2, figsize=(16, 6))  # 2 side-by-side plots

# ============= Plot 1: Accuracy =============
sns.barplot(
    data=all_models,
    x="Test Accuracy",
    y="Model",
    ax=axes[0],
    hue="Model",
    palette=acc_palette,   # Pass custom accuracy color map
    legend=False
)
axes[0].set_title("Model Accuracy", fontsize=18, fontweight='bold')
axes[0].set_xlabel("Accuracy", fontsize=14)
axes[0].set_ylabel("")
axes[0].set_xlim(0, 1.1)
axes[0].tick_params(axis='y', which='major', labelsize=16)

# Add labels
for container in axes[0].containers:
    axes[0].bar_label(container, fmt='%.2f', padding=3, fontsize=16)


# ============= Plot 2: Time =============
sns.barplot(
    data=all_models,
    x="Time to Converge",
    y="Model",
    ax=axes[1],
    hue="Model",
    palette=time_palette,  # Pass custom time color map
    legend=False
)
axes[1].set_title("Time to Convergence", fontsize=18, fontweight='bold')
axes[1].set_xlabel("Time (minutes)", fontsize=14)
axes[1].set_ylabel("")
axes[1].tick_params(axis='y', which='major', labelsize=16)
axes[1].set_xlim(0, 40)

# Add labels
for container in axes[1].containers:
    axes[1].bar_label(container, fmt='%.1f', padding=3, fontsize=16)

plt.tight_layout()
plt.show()


<br>

***

## 4. Conclusion and Discussion

#### **Generalizability**

The dataset consists of MRI images collected from multiple published studies, likely originating from different institutions and populations. Training on such heterogeneous data may enhance the model’s generalizability by reducing reliance on dataset- or scanner-specific features. However, since images from the same sources may appear in both training and evaluation sets, external validation on unseen institutions would be required to fully assess real-world generalizability.

#### **Opportunities for improvement**
The models could be optimized further (fine-tuning in particular), and they could be trained and tested on more, and more diverse, data.

#### **Important takeaways for practitioners and researchers**

Based on our results, the fine-tuned ResNet is the best model: it is the most accurate and only slightly slower than the fastest model. A ResNet is a sensible choice because a convolutional architecture is inherently suited to picking up local shapes and edges, such as those of a tumor. For the ViT, adding fine-tuning layers failed to improve performance, likely due to overfitting given our small dataset and the excessive number of extra parameters (about 1 million).
Moreover, it would be interesting to test hybrid models that combine CNNs, which capture local patterns, with Vision Transformers, which capture global patterns. According to Takahashi et al. (2024), combining ResNet and ViT outperformed other existing models in classifying brain tumors in MRI. This might result in a model that can handle more diverse and complex image datasets, making tumor classification more reliable in the real world.

## 5. References

Shanto, M. N. I., Mubtasim, M. T., Rakshit, S. V., & Ullah, M. A. (2025). Enhanced Classification of Brain Tumors from MRI Scans Using a Hybrid CNN-Transformer Model. 2025 International Conference on Quantum Photonics, Artificial Intelligence, and Networking (QPAIN), 1–6. https://doi.org/10.1109/qpain66474.2025.11171896

Takahashi, S., Sakaguchi, Y., Kouno, N., Takasawa, K., Ishizu, K., Akagi, Y., Aoyama, R., Teraya, N., Bolatkan, A., Shinkai, N., Machino, H., Kobayashi, K., Asada, K., Komatsu, M., Kaneko, S., Sugiyama, M., & Hamamoto, R. (2024). Comparison of Vision Transformers and Convolutional Neural Networks in Medical Image Analysis: A Systematic Review. Journal of Medical Systems, 48(1), 84. https://doi.org/10.1007/s10916-024-02105-8

## 6. Division of labor

**Charlotte de Vries:** Preprocessing data, Residual network models construction, training and evaluation, Conclusion and Discussion, Poster presentation design

**Jiazhen Tang:** Preprocessing data, Data setup (data structure and over-sampling), Vision Transformer models construction, training and evaluation

**Paulo Zirlis:** Preprocessing data, organising the Github environment, Introduction, Convolutional Neural Network model construction, training and evaluation + evaluation of all models