# Model Comparison in the Domain of Brain Tumor Image Classification
For our final project in CS 4644 (Deep Learning), Spring 2025, we analyze and compare the results of three models:
1. 3D CNN - Converts 2D images into 3D data points to reconstruct a full brain image.
2. ResNet-18 - Applies transfer learning by taking a ResNet-18 model pretrained on ImageNet and adapting it to MRI scans by fine-tuning its final fully connected layer.
3. Inception - Applies transfer learning in the same way as ResNet-18, but to another popular and successful architecture.

The following code lets the reader experiment with these three models and observe their results.

# Step 0: Get necessary imports and set global variables

/Users/willakins/Downloads/project-folder/Git/TumorTrace/TumorTrace/notebooks


In [None]:
# Basic imports
import os
import pickle
import sys

import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, random_split

# Make the project root (the parent of this notebook's working directory)
# importable. This MUST run before any project-local import below
# (`data.*`, `src.*`, `utils.*`): previously `data.Image_Loader` was imported
# first and only worked if the kernel happened to start in the project root.
# The membership check keeps re-runs of this cell from growing sys.path.
_project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if _project_root not in sys.path:
    sys.path.append(_project_root)

# Data imports
from data.Image_Loader import MRIDataset
from data.data_transforms import (
    get_fundamental_transforms,
    get_fundamental_normalization_transforms,
    get_fundamental_augmentation_transforms,
    get_all_transforms,
)

# Model imports
from src.models.CNN_3D.model import CNN_3D
from src.models.ResNet.model import MyResNet
from src.models.Inception.model import MyInception

# Helper imports
from utils.utils import compute_mean_and_std
from utils.utils import save_trained_model_weights
from utils.utils import convert_pickle_to_folder
from src.runner import Trainer
from src.optimizer import get_optimizer
from utils.confusion_matrix import (
    generate_confusion_data,
    generate_confusion_matrix,
    plot_confusion_matrix,
    get_pred_images_for_target,
    generate_and_plot_confusion_matrix,
    generate_and_plot_accuracy_table,
)

# Global variables
# NOTE(review): these paths are relative to the kernel's working directory,
# while the pickle-conversion cell writes to "../data/processed_mri". Confirm
# both refer to the same folder before training.
data_path = "data/processed_mri"
model_path = "src/models"

# Requires data_path to already contain the converted images; on a fresh
# checkout, run the pickle-conversion cell before this one.
dataset_mean, dataset_std = compute_mean_and_std(data_path)

batch_size = 32
num_classes = 3

In [None]:
# One-time data preparation: convert the raw pickled dataset into a folder at
# output_dir via the project helper convert_pickle_to_folder.
# NOTE(review): the setup cell reads "data/processed_mri" (relative to the
# working directory) while this writes "../data/processed_mri"; confirm they
# are the same folder, and run this before compute_mean_and_std on a fresh checkout.
convert_pickle_to_folder(
    pickle_path="../data/archive/brain_tumor_mri/new_dataset/training_data.pickle",
    output_dir="../data/processed_mri"
)

In [None]:
# Inspect the ResNet architecture. Build it with the same num_classes used
# for training in Step 2 so the printed classifier head matches the model
# that is actually trained.
model_resnet = MyResNet(num_classes=num_classes)
print(model_resnet)

In [None]:


# Sanity check: build an MRIDataset straight from the raw pickle and print the
# shape of one training batch.
# The pickle file is opened explicitly here: `data_path` names the converted
# image folder (see the global variables), and open() on a directory raises
# IsADirectoryError. This is the same file passed to convert_pickle_to_folder.
raw_pickle_path = "../data/archive/brain_tumor_mri/new_dataset/training_data.pickle"

# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
with open(raw_pickle_path, 'rb') as file:
    loaded_data = pickle.load(file)

# Unpacking the data into the images and their corresponding labels.
# NOTE(review): assumes loaded_data is an iterable of (image, label) pairs, as
# the zip(*...) unpacking requires.
images, labels = zip(*loaded_data)

# `transformations` was previously never defined (NameError). Reuse the same
# validation-time transforms the Trainers below use.
# NOTE(review): confirm MRIDataset accepts this transform object and that
# 224x224 is the intended size for this check.
transformations = get_fundamental_normalization_transforms(
    (224, 224), [dataset_mean], [dataset_std]
)
dataset = MRIDataset(images, labels, transformations, model_type=None)

# 80/20 train/test split; the seeded generator makes the split reproducible
# across re-runs of this cell.
training_size = int(.8 * len(dataset))
testing_size = len(dataset) - training_size
training_dataset, testing_dataset = random_split(
    dataset,
    [training_size, testing_size],
    generator=torch.Generator().manual_seed(0),
)

# Two separate loaders; only the training data needs shuffling.
train_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True)
testing_loader = DataLoader(testing_dataset, batch_size=batch_size, shuffle=False)

# Debug: show the tensor shape and labels of one training batch.
sample_image, sample_label = next(iter(train_loader))
print(f"Image shape: {sample_image.shape}")
print(f"Label: {sample_label}")


# Step 1: Test the 3D Convolutional Neural Network

In [None]:
# Square input resolution for the 3D CNN. NOTE(review): confirm CNN_3D expects 224x224.
inp_size = (224,) * 2

In [None]:
# Step 1 setup: the 3D CNN, its optimizer, and a Trainer that uses
# get_all_transforms for training data and
# get_fundamental_normalization_transforms for validation data.
model_cnn = CNN_3D(num_classes=num_classes)

cnn_optimizer_config = dict(optimizer_type="adam", lr=1e-3, weight_decay=1e-8)  # TODO: tune
cnn_optimizer = get_optimizer(model_cnn, cnn_optimizer_config)

cnn_trainer = Trainer(
    model=model_cnn,
    optimizer=cnn_optimizer,
    data_dir=data_path,
    model_dir=os.path.join(model_path, "CNN_3D"),
    batch_size=batch_size,
    load_from_disk=False,
    cuda=torch.cuda.is_available(),
    train_data_transforms=get_all_transforms(
        inp_size, [dataset_mean], [dataset_std]
    ),
    val_data_transforms=get_fundamental_normalization_transforms(
        inp_size, [dataset_mean], [dataset_std]
    ),
)

In [None]:
%%time
# Train the 3D CNN for 5 epochs; the %%time magic reports this cell's run time.
cnn_trainer.run_training_loop(num_epochs=5)

In [None]:
# Plot the 3D CNN's loss and accuracy curves, then print the last recorded
# train/validation accuracies (Step 4 repeats this for side-by-side review).
cnn_trainer.plot_loss_history()
cnn_trainer.plot_accuracy()

cnn_train_accuracy, cnn_validation_accuracy = (
    cnn_trainer.train_accuracy_history[-1],
    cnn_trainer.validation_accuracy_history[-1],
)
print(f"Train Accuracy = {cnn_train_accuracy}; Validation Accuracy = {cnn_validation_accuracy}")

In [None]:
# Save the trained 3D CNN via save_trained_model_weights into <model_path>/CNN_3D.
save_trained_model_weights(model_cnn, out_dir=os.path.join(model_path, "CNN_3D"))

# Step 2: Test ResNet Pretrained Model

In [None]:
# Square input resolution for ResNet; ImageNet-pretrained ResNets are
# conventionally fed 224x224 inputs. NOTE(review): confirm for MyResNet.
inp_size = (224,) * 2

In [None]:
# Step 2 setup: the pretrained ResNet wrapper, its optimizer, and a Trainer
# that uses get_all_transforms for training data and
# get_fundamental_normalization_transforms for validation data.
model_resnet = MyResNet(num_classes=num_classes)

resnet_optimizer_config = dict(optimizer_type="adam", lr=1e-3, weight_decay=1e-8)  # TODO: tune
resnet_optimizer = get_optimizer(model_resnet, resnet_optimizer_config)

resnet_trainer = Trainer(
    model=model_resnet,
    optimizer=resnet_optimizer,
    data_dir=data_path,
    model_dir=os.path.join(model_path, "ResNet"),
    batch_size=batch_size,
    load_from_disk=False,
    cuda=torch.cuda.is_available(),
    train_data_transforms=get_all_transforms(
        inp_size, [dataset_mean], [dataset_std]
    ),
    val_data_transforms=get_fundamental_normalization_transforms(
        inp_size, [dataset_mean], [dataset_std]
    ),
)

In [None]:
%%time
# Fine-tune ResNet for 5 epochs; the %%time magic reports this cell's run time.
resnet_trainer.run_training_loop(num_epochs=5)

In [None]:
# Save the fine-tuned ResNet via save_trained_model_weights into <model_path>/ResNet.
save_trained_model_weights(model_resnet, out_dir=os.path.join(model_path, "ResNet"))

# Step 3: Test Inception Pretrained Model

In [None]:
# Square input resolution for Inception; torchvision's Inception v3 expects
# 299x299 inputs. NOTE(review): confirm MyInception wraps Inception v3.
inp_size = (299,) * 2

In [None]:
# Step 3 setup: the pretrained Inception wrapper, its optimizer, and a Trainer
# that uses get_all_transforms for training data and
# get_fundamental_normalization_transforms for validation data.
# NOTE(review): torchvision's Inception v3 returns (logits, aux_logits) in
# training mode; confirm MyInception/Trainer handle that output.
model_inception = MyInception(num_classes=num_classes)

inception_optimizer_config = dict(optimizer_type="adam", lr=1e-3, weight_decay=1e-8)  # TODO: tune
inception_optimizer = get_optimizer(model_inception, inception_optimizer_config)

inception_trainer = Trainer(
    model=model_inception,
    optimizer=inception_optimizer,
    data_dir=data_path,
    model_dir=os.path.join(model_path, "Inception"),
    batch_size=batch_size,
    load_from_disk=False,
    cuda=torch.cuda.is_available(),
    train_data_transforms=get_all_transforms(
        inp_size, [dataset_mean], [dataset_std]
    ),
    val_data_transforms=get_fundamental_normalization_transforms(
        inp_size, [dataset_mean], [dataset_std]
    ),
)

In [None]:
%%time
# Fine-tune Inception for 5 epochs; the %%time magic reports this cell's run time.
inception_trainer.run_training_loop(num_epochs=5)

In [None]:
# Save the fine-tuned Inception model via save_trained_model_weights into <model_path>/Inception.
save_trained_model_weights(model_inception, out_dir=os.path.join(model_path, "Inception"))

# Step 4: Analyze Graphs and Final Accuracies

### Loss & Accuracy Graphs

In [None]:
# 3D CNN: loss/accuracy curves and the last recorded accuracies.
cnn_trainer.plot_loss_history()
cnn_trainer.plot_accuracy()

cnn_train_accuracy, cnn_validation_accuracy = (
    cnn_trainer.train_accuracy_history[-1],
    cnn_trainer.validation_accuracy_history[-1],
)
print(f"Train Accuracy = {cnn_train_accuracy}; Validation Accuracy = {cnn_validation_accuracy}")

In [None]:
# ResNet: loss/accuracy curves and the last recorded accuracies.
resnet_trainer.plot_loss_history()
resnet_trainer.plot_accuracy()

resnet_train_accuracy, resnet_validation_accuracy = (
    resnet_trainer.train_accuracy_history[-1],
    resnet_trainer.validation_accuracy_history[-1],
)
print(f"Train Accuracy = {resnet_train_accuracy}; Validation Accuracy = {resnet_validation_accuracy}")

In [None]:
# Inception: loss/accuracy curves and the last recorded accuracies.
inception_trainer.plot_loss_history()
inception_trainer.plot_accuracy()

inception_train_accuracy, inception_validation_accuracy = (
    inception_trainer.train_accuracy_history[-1],
    inception_trainer.validation_accuracy_history[-1],
)
print(f"Train Accuracy = {inception_train_accuracy}; Validation Accuracy = {inception_validation_accuracy}")

### Confusion Matrices

In [None]:
# Confusion matrix of the 3D CNN evaluated on cnn_trainer.val_dataset.
generate_and_plot_confusion_matrix(model_cnn, cnn_trainer.val_dataset, use_cuda=torch.cuda.is_available())

In [None]:
# Confusion matrix of ResNet evaluated on resnet_trainer.val_dataset.
generate_and_plot_confusion_matrix(model_resnet, resnet_trainer.val_dataset, use_cuda=torch.cuda.is_available())

In [None]:
# Confusion matrix of Inception evaluated on inception_trainer.val_dataset.
generate_and_plot_confusion_matrix(model_inception, inception_trainer.val_dataset, use_cuda=torch.cuda.is_available())

### Analyze errors revealed by the confusion matrices

In [None]:
# Browse validation images for one (true class, predicted class) cell of a
# confusion matrix above. Setting both class numbers to the same value shows
# correctly classified images.
trainer = resnet_trainer  # Change this (cnn_trainer / resnet_trainer / inception_trainer)
model = model_resnet  # Change this; must be the model trained by `trainer`

# Analyze confusion matrix and change these to observe results
predicted_class_num = 0
true_class_num = 0

# class_dict maps class name -> class index. Print it first so the valid
# indices are visible even if a lookup below fails.
class_dict = trainer.val_dataset.class_dict
print(class_dict)


def _class_name(class_num):
    """Return the class name whose index in class_dict is `class_num`.

    Raises ValueError (instead of an opaque IndexError) if no class matches.
    """
    for name, idx in class_dict.items():
        if idx == class_num:
            return name
    raise ValueError(f"class index {class_num} not found in {class_dict}")


correct_class = _class_name(true_class_num)
pred_class = _class_name(predicted_class_num)

paths = get_pred_images_for_target(
    model, trainer.val_dataset, predicted_class_num, true_class_num, torch.cuda.is_available()
)

# Only call images "misclassified" when the two classes actually differ.
if predicted_class_num == true_class_num:
    title = f'Image of {correct_class}, correctly classified'
else:
    title = f'Image of {correct_class}, misclassified as {pred_class}'

max_count = 10
count = 0
for path in paths:
    if count == max_count:
        break  # stop instead of opening every remaining file
    # The context manager closes the file even if convert() raises.
    with Image.open(path) as raw_img:
        img = raw_img.convert(mode='L')
    plt.imshow(img, cmap='gray')
    plt.title(title)
    plt.axis('off')  # Removes axis ticks
    plt.show()
    count += 1