# Importing Dependencies

In [None]:
!pip install --upgrade torch torchvision timm

In [None]:
!pip install timm

In [None]:
!pip install grad-cam

In [None]:
from google.colab import drive
import os
import zipfile
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import random
from collections import defaultdict
import time
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, confusion_matrix

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms, models
from scipy.ndimage import gaussian_filter, map_coordinates
import torchvision.transforms.functional as TF
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision.models as models
import warnings
import timm
from torch.cuda.amp import autocast, GradScaler
import traceback

In [None]:
# Mount Google Drive at /content/drive; the dataset archive and all saved
# artifacts used by later cells are addressed through this mount point.
drive.mount('/content/drive')

# Data preprocessing
## 1) Unzipping the MRI Images

In [None]:
# Dataset archive stored on the mounted Google Drive
zip_path = '/content/drive/My Drive/Colab Notebooks/mri_images.zip'
# Local directory reserved for the extracted images (created if missing)
extract_path = '/content/mri_images'
os.makedirs(name=extract_path, exist_ok=True)

In [None]:
# Archive file name with its directory part stripped off
zip_filename = os.path.split(zip_path)[1]

# Same name without the .zip extension
extracted_folder = os.path.splitext(zip_filename)[0]


In [None]:
# Unzip the archive into the current working directory. The following cells
# open the relative 'Training' and 'Testing' folders, so the files have to
# land in the working directory; the destination is now passed explicitly
# instead of relying on extractall()'s default.
extraction_dir = os.getcwd()
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extraction_dir)

print(f"Zip file: {zip_filename}")
print(f"Extracted folder: {extracted_folder}")
# Report the real destination: the previous message named extract_path,
# a directory extractall() never wrote to.
print(f"Files extracted to: {extraction_dir}")


In [None]:
# Relative location of the glioma test images inside the extracted archive
glioma_path = os.path.join('Testing', 'glioma')

# Report whether that folder is actually present
if not os.path.exists(glioma_path):
    print("Couldn't find the glioma folder. Please check the path.")
else:
    print(f"Successfully located the glioma folder at: {glioma_path}")

#### Viewing a Random Image:

In [None]:
# Collect every file in the glioma folder whose extension marks it as an image
image_files = [
    entry
    for entry in os.listdir(glioma_path)
    if entry.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))
]

if not image_files:
    print("No image files found in the glioma folder.")
else:
    # Pick one file at random and build its full path
    random_image = random.choice(image_files)
    image_path = os.path.join(glioma_path, random_image)

    # Load the picture and render it without axes
    img = Image.open(image_path)
    plt.figure(figsize=(8, 8))
    plt.imshow(img)
    plt.axis('off')
    plt.title(f"Sample Glioma Image: {random_image}")
    plt.show()

In [None]:
# Collect every file in the glioma folder whose extension marks it as an image
image_files = [
    entry
    for entry in os.listdir(glioma_path)
    if entry.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))
]

if not image_files:
    print("No image files found in the glioma folder.")
else:
    # Pick one file at random and build its full path
    random_image = random.choice(image_files)
    image_path = os.path.join(glioma_path, random_image)

    # Load the picture and turn it into a numpy array
    img = Image.open(image_path)
    img_array = np.array(img)

    # Render the array with the 'gray' colormap and no axes
    plt.figure(figsize=(8, 8))
    plt.imshow(img_array, cmap='gray')
    plt.axis('off')
    plt.title(f"Sample Glioma Image: {random_image}")
    plt.show()

## 2) Resizing all the images to ResNet-50 input size and Saving them in a new folder

In [None]:
def resize_images(input_folder, output_folder, target_size=(224, 224)):
    """
    Resize all images in the input folder and save them to the output folder.

    Only files whose name ends in .png, .jpg, .jpeg, .gif or .bmp (case
    insensitive) are processed; everything else is ignored. A file that
    cannot be processed is reported on stdout and skipped, so one bad image
    does not abort the whole folder.

    Args:
    input_folder (str): Path to the folder containing original images
    output_folder (str): Path to the folder where resized images will be saved
        (created, including parents, if it does not exist)
    target_size (tuple): The target size for the images (width, height)
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')

    # exist_ok avoids the check-then-create race of a separate os.path.exists test
    os.makedirs(output_folder, exist_ok=True)

    for filename in os.listdir(input_folder):
        if not filename.lower().endswith(image_suffixes):
            continue
        try:
            # The context manager releases the source file handle even when
            # resizing fails; resize() returns an independent image.
            with Image.open(os.path.join(input_folder, filename)) as img:
                resized = img.resize(target_size, Image.LANCZOS)
            resized.save(os.path.join(output_folder, filename))
            print(f"Resized {filename}")
        except Exception as e:
            # Deliberately best-effort: report the failure and keep going
            print(f"Error processing {filename}: {str(e)}")

# Define the size you want for all images
target_size = (224, 224)  # This is a common size for many CNN architectures

In [None]:
# Source root (extracted archive) and destination root (Google Drive) for the training split
base_input_dir = 'Training'  # Adjust this if your folder structure is different
base_output_dir = '/content/drive/MyDrive/Colab Notebooks/Resized_Training'

# Class sub-folders to process
tumor_types = ['glioma', 'meningioma', 'notumor', 'pituitary']

# Walk the class folders and resize each one into its mirror folder on Drive
for tumor_type in tumor_types:
    print(f"Resizing images in {tumor_type} folder...")
    input_folder = os.path.join(base_input_dir, tumor_type)
    output_folder = os.path.join(base_output_dir, tumor_type)
    resize_images(input_folder, output_folder, target_size)

print("Image resizing complete!")

In [None]:
# Source root (extracted archive) and destination root (Google Drive) for the testing split
base_input_dir = 'Testing'  # Adjust this if your folder structure is different
base_output_dir = '/content/drive/MyDrive/Colab Notebooks/Resized_Testing'

# Class sub-folders to process
tumor_types = ['glioma', 'meningioma', 'notumor', 'pituitary']

# Walk the class folders and resize each one into its mirror folder on Drive
for tumor_type in tumor_types:
    print(f"Resizing images in {tumor_type} folder...")
    input_folder = os.path.join(base_input_dir, tumor_type)
    output_folder = os.path.join(base_output_dir, tumor_type)
    resize_images(input_folder, output_folder, target_size)

print("Image resizing complete!")

### Loading the resized folders from google drive:

In [None]:
# Make sure Google Drive is reachable before touching it
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Root of the resized training images on Google Drive
base_output_dir = '/content/drive/MyDrive/Colab Notebooks/Resized_Training'

# Class sub-folders available under that root
tumor_types = ['glioma', 'meningioma', 'notumor', 'pituitary']

# Draw one class at random and point at its folder
tumor_type = random.choice(tumor_types)
resized_folder = os.path.join(base_output_dir, tumor_type)

if os.path.exists(resized_folder):
    image_files = [
        entry
        for entry in os.listdir(resized_folder)
        if entry.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))
    ]

    if image_files:
        random_image = random.choice(image_files)

        # Load the chosen file as an array so its shape can be reported
        img_path = os.path.join(resized_folder, random_image)
        img = Image.open(img_path)
        img_array = np.array(img)

        plt.figure(figsize=(8, 8))
        plt.imshow(img_array, cmap='gray')
        plt.axis('off')
        plt.title(f"Resized {tumor_type} Image: {random_image}\nSize: {img_array.shape}")
        plt.show()

        print(f"Image size: {img_array.shape}")
    else:
        print(f"Error: No image files found in {resized_folder}")
else:
    print(f"Error: The folder {resized_folder} does not exist in your Google Drive.")

In [None]:
# Make sure Google Drive is reachable before touching it
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Root of the resized testing images on Google Drive
base_output_dir = '/content/drive/MyDrive/Colab Notebooks/Resized_Testing'

# Class sub-folders available under that root
tumor_types = ['glioma', 'meningioma', 'notumor', 'pituitary']

# Draw one class at random and point at its folder
tumor_type = random.choice(tumor_types)
resized_folder = os.path.join(base_output_dir, tumor_type)

if os.path.exists(resized_folder):
    image_files = [
        entry
        for entry in os.listdir(resized_folder)
        if entry.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))
    ]

    if image_files:
        random_image = random.choice(image_files)

        # Load the chosen file as an array so its shape can be reported
        img_path = os.path.join(resized_folder, random_image)
        img = Image.open(img_path)
        img_array = np.array(img)

        plt.figure(figsize=(8, 8))
        plt.imshow(img_array, cmap='gray')
        plt.axis('off')
        plt.title(f"Resized {tumor_type} Image: {random_image}\nSize: {img_array.shape}")
        plt.show()

        print(f"Image size: {img_array.shape}")
    else:
        print(f"Error: No image files found in {resized_folder}")
else:
    print(f"Error: The folder {resized_folder} does not exist in your Google Drive.")

# PyTorch Data Pre-Processing

## 1) DataSet & DataLoader

Defining the data sets:

In [None]:
# Class sub-folder names shared by the training and testing roots
tumor_types = ['glioma', 'meningioma', 'notumor', 'pituitary']
# Roots of the resized images on Google Drive
train_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Training'
test_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Testing'

DataSet Class:

## 2) Data Augmentation

1. Random Crop
2. Random Flip
3. Random Rotate
4. Random Brightness
5. Random Contrast
6. On the fly augmentation during training (saves memory)

__Important__

1) Avoid extreme rotations or flips that could change the anatomical orientation.

2) Be cautious with color-based augmentations, as MRI intensity values often have specific meanings.

3) Maintain the overall structure and proportions of the brain.

1. Slight rotations (within ±10 degrees)
2. Small shifts (translations)
3. Zoom in/out (within a small range)
4. Minimal brightness and contrast adjustments
5. Gaussian noise addition (to simulate image noise)
6. Elastic deformations (subtle warping)  

In [None]:
class RandomMRIAugmentation(nn.Module):
    """Randomly apply mild geometric and noise augmentations to one image.

    Rotation, translation, zoom and Gaussian noise are each applied
    independently with probability ``p``. Non-tensor inputs are converted
    with ``TF.to_tensor`` first, and a tensor is always returned.

    Args:
        rotation_range: Maximum absolute rotation angle, in degrees.
        translation_range: Maximum shift as a fraction of the image width
            (horizontal) and height (vertical).
        zoom_range: The scale factor is drawn from
            ``[1 - zoom_range, 1 + zoom_range]``.
        noise_factor: Standard deviation of the additive Gaussian noise.
        p: Probability of applying each individual augmentation.
    """

    def __init__(self, rotation_range=10, translation_range=0.1, zoom_range=0.1, noise_factor=0.05, p=0.5):
        super().__init__()
        self.rotation_range = rotation_range
        self.translation_range = translation_range
        self.zoom_range = zoom_range
        self.noise_factor = noise_factor
        self.p = p  # Probability of applying each augmentation

    def forward(self, img):
        """Return an augmented copy of ``img`` as a tensor."""
        # Ensure input is a tensor
        if not isinstance(img, torch.Tensor):
            img = TF.to_tensor(img)

        # Random rotation
        if random.random() < self.p:
            angle = random.uniform(-self.rotation_range, self.rotation_range)
            img = TF.rotate(img, angle)

        # Random translation
        if random.random() < self.p:
            # TF.affine takes the shift in pixels, so the fractional range is
            # scaled by the image size; passing the raw fraction (< 1 px)
            # left the image essentially unmoved.
            height, width = img.shape[-2:]
            translate = [
                int(round(random.uniform(-self.translation_range, self.translation_range) * extent))
                for extent in (width, height)
            ]
            img = TF.affine(img, angle=0, translate=translate, scale=1, shear=0)

        # Random zoom
        if random.random() < self.p:
            scale = random.uniform(1 - self.zoom_range, 1 + self.zoom_range)
            img = TF.affine(img, angle=0, translate=(0, 0), scale=scale, shear=0)

        # Add Gaussian noise, then clip back to the [0, 1] range
        if random.random() < self.p:
            noise = torch.randn_like(img) * self.noise_factor
            img = img + noise
            img = torch.clamp(img, 0, 1)

        return img

def elastic_transform(image, alpha=1000, sigma=30, alpha_affine=30):
    """Elastic deformation of images as described in [Simard2003].

    A random affine warp is applied first, followed by a smooth random
    displacement field. Both steps use only numpy/scipy (the previous
    version called ``cv2``, which this file never imports). The same warp is
    applied to every channel of a multi-channel image.

    Args:
        image: Array of shape (H, W) or (H, W, C).
        alpha: Scale factor applied to the smoothed displacement fields.
        sigma: Standard deviation of the Gaussian used to smooth the fields.
        alpha_affine: Maximum perturbation, in pixels, of the three anchor
            points that define the random affine warp.

    Returns:
        numpy.ndarray with the same shape as ``image``.

    Raises:
        ValueError: If ``image`` is neither 2-D nor 3-D.
    """
    # Local import: only gaussian_filter and map_coordinates are imported at file level
    from scipy.ndimage import affine_transform

    image = np.asarray(image)
    if image.ndim not in (2, 3):
        raise ValueError(f"expected an (H, W) or (H, W, C) array, got shape {image.shape}")

    random_state = np.random.RandomState(None)
    height, width = image.shape[:2]

    # Random affine: jitter three anchor points placed around the image centre
    center = np.array([height // 2, width // 2], dtype=np.float64)
    half_side = min(height, width) // 3
    pts1 = np.array([
        center + half_side,
        [center[0] + half_side, center[1] - half_side],
        center - half_side,
    ])
    pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine, size=pts1.shape)

    # affine_transform wants the output -> input mapping, i.e. the affine map
    # sending the jittered points back onto the originals:
    # [pts2 | 1] @ X = pts1.
    solution = np.linalg.solve(np.hstack([pts2, np.ones((3, 1))]), pts1)
    matrix, offset = solution[:2].T, solution[2]

    # Smooth random displacement fields, one per spatial axis
    dx = gaussian_filter((random_state.rand(height, width) * 2 - 1), sigma) * alpha
    dy = gaussian_filter((random_state.rand(height, width) * 2 - 1), sigma) * alpha
    rows, cols = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
    coordinates = np.array([rows + dy, cols + dx])

    def _warp_plane(plane):
        # 'mirror' reflects about the edge pixel centre, matching the
        # BORDER_REFLECT_101 border handling used previously.
        plane = affine_transform(plane, matrix, offset=offset, order=1, mode='mirror')
        return map_coordinates(plane, coordinates, order=1, mode='reflect')

    if image.ndim == 2:
        return _warp_plane(image)
    return np.stack([_warp_plane(image[..., channel]) for channel in range(image.shape[2])], axis=-1)

class RandomElasticDeformation(nn.Module):
    """Apply ``elastic_transform`` to an image with probability ``p``.

    Args:
        p: Probability of deforming the image.
        alpha, sigma, alpha_affine: Forwarded to ``elastic_transform``.
    """

    def __init__(self, p=0.2, alpha=1000, sigma=30, alpha_affine=30):
        super().__init__()
        self.p = p
        self.alpha = alpha
        self.sigma = sigma
        self.alpha_affine = alpha_affine

    def forward(self, img):
        """Return ``img`` unchanged, or an elastically deformed tensor."""
        if random.random() >= self.p:
            return img

        is_tensor = isinstance(img, torch.Tensor)
        array = img.detach().cpu().numpy() if is_tensor else np.asarray(img)

        # Image tensors are channel-first (C, H, W) while elastic_transform
        # expects the spatial axes first, so move channels last and back.
        channels_first = is_tensor and array.ndim == 3
        if channels_first:
            array = np.transpose(array, (1, 2, 0))

        warped = elastic_transform(array, self.alpha, self.sigma, self.alpha_affine)

        if channels_first:
            warped = np.transpose(warped, (2, 0, 1))
        return torch.from_numpy(np.ascontiguousarray(warped))

def get_mri_augmentation(p_transform=0.5, p_elastic=0.2):
    """Build the augmentation pipeline: random geometric/noise step, random
    elastic deformation, then normalisation.

    Args:
        p_transform: Per-step probability handed to RandomMRIAugmentation.
        p_elastic: Probability handed to RandomElasticDeformation.

    Returns:
        A ``transforms.Compose`` chaining the three stages in that order.
    """
    geometric_stage = RandomMRIAugmentation(
        rotation_range=10,
        translation_range=0.1,
        zoom_range=0.1,
        noise_factor=0.05,
        p=p_transform,
    )
    elastic_stage = RandomElasticDeformation(p=p_elastic)
    # Adjust these values based on your MRI data statistics
    normalize_stage = transforms.Normalize(mean=[0.485], std=[0.229])
    return transforms.Compose([geometric_stage, elastic_stage, normalize_stage])

In [None]:
class MRIDataset(Dataset):
    """Image-folder dataset with one sub-folder per tumor type.

    The label of an image is the position of its folder name in
    ``tumor_types``. Files are listed in sorted order and filtered by image
    extension, so indices are reproducible between runs and stray non-image
    files are not picked up as samples.

    Args:
        folder_path: Root folder containing one sub-folder per tumor type.
        tumor_types: Sub-folder names; list order defines the integer labels.
        transform: Callable applied to the RGB PIL image when ``augment`` is
            False. When it is None the PIL image itself is returned.
        augment: When True, the pipeline from ``get_mri_augmentation`` is
            applied and ``transform`` is ignored.
    """

    # Same extension set the preview/resize cells use to recognise images
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')

    def __init__(self, folder_path, tumor_types, transform=None, augment=False):
        self.folder_path = folder_path
        self.tumor_types = tumor_types
        self.transform = transform
        self.augment = augment
        self.image_paths = []
        self.labels = []

        for label, tumor_type in enumerate(tumor_types):
            tumor_folder = os.path.join(folder_path, tumor_type)
            # os.listdir order is arbitrary; sorting keeps index -> image stable
            for img_name in sorted(os.listdir(tumor_folder)):
                if not img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue
                self.image_paths.append(os.path.join(tumor_folder, img_name))
                self.labels.append(label)

        if self.augment:
            self.aug_transform = get_mri_augmentation(p_transform=0.5, p_elastic=0.2)

    def __len__(self):
        """Number of images found across all tumor-type folders."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``."""
        img_path = self.image_paths[idx]
        # convert() yields an independent image, so the file can be closed right away
        with Image.open(img_path) as source:
            image = source.convert('RGB')
        label = self.labels[idx]

        if self.augment:
            image = self.aug_transform(image)
        elif self.transform:
            image = self.transform(image)

        return image, label

In [None]:
def create_mri_datasets(train_folder, test_folder, tumor_types, val_split=0.2, batch_size=32):
    """Create train/validation/test DataLoaders for the MRI folders.

    The training folder is split (stratified by label, seed 42) into a
    training part that uses the augmentation pipeline and a validation part
    that uses only the deterministic base transform. The test folder always
    uses the base transform.

    Args:
        train_folder: Root folder of the training images.
        test_folder: Root folder of the test images.
        tumor_types: Class sub-folder names; order defines the labels.
        val_split: Fraction of the training images held out for validation.
        batch_size: Batch size shared by all three loaders.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    # Define base transform for validation and test sets
    base_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485], std=[0.229])  # Adjust based on your MRI data
    ])

    # Two independent views of the training folder. Train and validation used
    # to share ONE dataset object, so switching `augment` off for validation
    # silently switched it off for training as well.
    augmented_dataset = MRIDataset(train_folder, tumor_types, augment=True)
    plain_dataset = MRIDataset(train_folder, tumor_types, transform=base_transform)
    # Pin both views to the same file order so one index addresses the same image in each
    plain_dataset.image_paths = list(augmented_dataset.image_paths)
    plain_dataset.labels = list(augmented_dataset.labels)

    test_dataset = MRIDataset(test_folder, tumor_types, transform=base_transform)

    # Split the training indices into train and validation
    train_indices, val_indices = train_test_split(
        range(len(augmented_dataset)),
        test_size=val_split,
        stratify=augmented_dataset.labels,
        random_state=42
    )

    train_dataset = torch.utils.data.Subset(augmented_dataset, train_indices)
    val_dataset = torch.utils.data.Subset(plain_dataset, val_indices)

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    return train_loader, val_loader, test_loader


In [None]:
# Build the three loaders with the default validation split and batch size
train_loader, val_loader, test_loader = create_mri_datasets(
    train_folder=train_folder,
    test_folder=test_folder,
    tumor_types=tumor_types,
)

# ResNets


## Model 1: ResNet50

Defining the model:


In [None]:
def initialize_resnet50_model(num_classes):
    """Return a pretrained ResNet-50 whose final layer outputs ``num_classes`` logits."""
    network = models.resnet50(pretrained=True)
    # Swap the classifier head for a fresh linear layer of the right width
    network.fc = nn.Linear(network.fc.in_features, num_classes)
    return network

def train_model(model, train_loader, val_loader, num_epochs=100, patience=10):
    """Train ``model`` with Adam, LR-on-plateau scheduling and early stopping.

    Args:
        model: Network to train; moved to CUDA when available.
        train_loader: DataLoader yielding ``(inputs, labels)`` training batches.
        val_loader: DataLoader yielding ``(inputs, labels)`` validation batches.
        num_epochs: Maximum number of epochs.
        patience: Stop after this many consecutive epochs without a lower
            validation loss.

    Returns:
        Tuple ``(model, best_val_acc, training_time)``: the model carrying
        the weights of the epoch with the lowest validation loss, the
        validation accuracy of that epoch, and the elapsed time in seconds.
    """
    import copy  # local import: not part of the file-level import list

    start_time = time.time()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    # `verbose` is deprecated in recent PyTorch releases; learning-rate drops
    # are reported by hand after scheduler.step() instead.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    best_val_loss = float('inf')
    best_val_acc = 0.0
    epochs_no_improve = 0
    best_model = None

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        # Average over the samples actually seen: with a sampler,
        # len(loader.dataset) counts the whole dataset, not the subset drawn.
        train_loss = train_loss / max(train_total, 1)
        train_acc = train_correct / max(train_total, 1)

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss = val_loss / max(val_total, 1)
        val_acc = val_correct / max(val_total, 1)

        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')
        print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

        previous_lrs = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(val_loss)
        for index, (old_lr, group) in enumerate(zip(previous_lrs, optimizer.param_groups)):
            if group['lr'] < old_lr:
                print(f'Reducing learning rate of group {index} to {group["lr"]:.4e}.')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_acc = val_acc
            epochs_no_improve = 0
            # state_dict() returns references to the live tensors; without a
            # deep copy the "best" snapshot would follow later updates.
            best_model = copy.deepcopy(model.state_dict())
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print('Early stopping!')
            break

    # Restore the best weights whether or not early stopping fired, so the
    # returned model matches the returned best_val_acc.
    if best_model is not None:
        model.load_state_dict(best_model)

    end_time = time.time()
    training_time = end_time - start_time
    return model, best_val_acc, training_time



In [None]:
def evaluate_model(model, test_loader, tumor_types):
    """Score ``model`` on ``test_loader`` and plot its confusion matrix.

    Prints accuracy, weighted precision/recall/F1, weighted one-vs-rest
    AUC-ROC and the mean cross-entropy loss, then shows a confusion-matrix
    heat map labelled with ``tumor_types``.

    Returns:
        Tuple ``(accuracy, precision, recall, f1, auc_roc, avg_loss,
        all_preds, all_labels)``; the last two are numpy arrays.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    loss_fn = nn.CrossEntropyLoss()
    total_loss = 0.0
    pred_list, label_list, prob_list = [], [], []

    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_inputs)

            # Weight the batch mean by the batch size to accumulate a sum
            total_loss += loss_fn(logits, batch_labels).item() * batch_inputs.size(0)

            batch_preds = torch.max(logits.data, 1)[1]
            batch_probs = torch.nn.functional.softmax(logits, dim=1)
            pred_list.extend(batch_preds.cpu().numpy())
            label_list.extend(batch_labels.cpu().numpy())
            prob_list.extend(batch_probs.cpu().numpy())

    all_preds = np.array(pred_list)
    all_labels = np.array(label_list)
    all_probs = np.array(prob_list)

    accuracy = (all_preds == all_labels).mean()
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='weighted')
    auc_roc = roc_auc_score(all_labels, all_probs, average='weighted', multi_class='ovr')
    avg_loss = total_loss / len(test_loader.dataset)

    for report_line in (
        f'Test Accuracy: {accuracy:.4f}',
        f'Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}',
        f'AUC-ROC: {auc_roc:.4f}',
        f'Average Loss: {avg_loss:.4f}',
    ):
        print(report_line)

    # Confusion matrix: rows are true labels, columns are predictions
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=tumor_types, yticklabels=tumor_types)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()

    return accuracy, precision, recall, f1, auc_roc, avg_loss, all_preds, all_labels

def visualize_model_attention(model, input_tensor, target_class):
    """Show a Grad-CAM heat map of ``model.layer4[-1]`` for one image.

    Args:
        model: Network exposing a ``layer4`` block sequence.
        input_tensor: Single image tensor without batch dimension; it is
            permuted as (C, H, W) -> (H, W, C) for display.
        target_class: Index of the class whose evidence is visualised.
    """
    model.eval()
    device = next(model.parameters()).device
    batch = input_tensor.unsqueeze(0).to(device)

    # Current pytorch-grad-cam releases take neither `use_cuda` nor
    # `target_category`: the CAM runs on the model's device and the class is
    # passed as a target object.
    cam = GradCAM(model=model, target_layers=[model.layer4[-1]])
    grayscale_cam = cam(input_tensor=batch, targets=[ClassifierOutputTarget(target_class)])

    # show_cam_on_image requires a float32 image in [0, 1]; a normalised
    # network input falls outside that range, so rescale it for display.
    image = input_tensor.detach().cpu().permute(1, 2, 0).numpy().astype(np.float32)
    value_range = image.max() - image.min()
    if value_range > 0:
        image = (image - image.min()) / value_range
    else:
        image = np.zeros_like(image)

    visualization = show_cam_on_image(image, grayscale_cam[0, :], use_rgb=True)
    plt.imshow(visualization)
    plt.axis('off')
    plt.title(f'Grad-CAM for class {target_class}')
    plt.show()

In [None]:
def get_model_size(model):
    """Return the serialized size of ``model``'s state dict in MB (1 MB = 1e6 bytes).

    The state dict is serialized into an in-memory buffer, so nothing is
    written to the working directory and no temporary file can be left
    behind (or clobber an existing ``temp.p``) if serialization fails.
    """
    import io  # local import: not part of the file-level import list

    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes / 1e6  # Size in MB

def create_metrics_dataframe(model, test_acc, precision, recall, f1, auc_roc, train_time, test_loss):
    """Summarise headline metrics and model footprint in a two-column table.

    The table has the columns 'Metric' and 'Value'. ``precision``,
    ``recall`` and ``auc_roc`` are accepted but not included in the table.
    """
    parameter_count = sum(p.numel() for p in model.parameters())
    rows = [
        ('Overall Accuracy', test_acc),
        ('F1 Score', f1),
        ('Cross Entropy Loss', test_loss),
        ('Training Time (s)', train_time),
        ('Number of Parameters', parameter_count),
        ('Model Size (MB)', get_model_size(model)),
    ]
    metric_names, metric_values = zip(*rows)
    return pd.DataFrame({'Metric': list(metric_names), 'Value': list(metric_values)})

In [None]:
# Define the base path in your Google Drive
base_path = '/content/drive/MyDrive/Colab Notebooks'

# Function to save the model
def save_model(model, filename):
    """Save ``model``'s state dict as ``filename`` under ``base_path``.

    Failures are reported on stdout instead of being raised, so a failed
    save does not interrupt the calling cell.
    """
    save_path = os.path.join(base_path, filename)
    try:
        torch.save(model.state_dict(), save_path)
        print(f"Model saved successfully to {save_path}")
    except Exception as e:
        print(f"Error saving model: {e}")

# Function to load the model
def load_model(model, filename):
    """Load the state dict ``filename`` from ``base_path`` into ``model``.

    Returns:
        ``model`` itself. On failure the error is printed and the model is
        returned with its weights unchanged.
    """
    load_path = os.path.join(base_path, filename)
    try:
        # map_location='cpu' lets a checkpoint saved on a GPU runtime be read
        # on a CPU-only one; load_state_dict then copies the values into the
        # model's existing parameters, wherever those live.
        model.load_state_dict(torch.load(load_path, map_location='cpu'))
        print(f"Model loaded successfully from {load_path}")
    except Exception as e:
        print(f"Error loading model: {e}")
    return model


Training Loop: GPU T4

In [None]:
# Main execution: 5-fold cross-validation of ResNet-50, a final training run,
# then evaluation on the test set and export of the metrics.
if __name__ == "__main__":
    # Setup
    train_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Training'
    test_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Testing'
    tumor_types = ['glioma', 'meningioma', 'notumor', 'pituitary']
    num_classes = len(tumor_types)

    # Data augmentation and normalization for training
    # Just normalization for validation/testing
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    # Create the dataset
    full_dataset = MRIDataset(train_folder, tumor_types, transform=data_transforms['train'])
    test_dataset = MRIDataset(test_folder, tumor_types, transform=data_transforms['val'])
    results = []

    # K-Fold Cross-validation
    k_folds = 5
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)

    for fold, (train_ids, val_ids) in enumerate(kfold.split(full_dataset)):
        print(f'FOLD {fold+1}')
        print('--------------------------------')

        train_subsampler = SubsetRandomSampler(train_ids)
        val_subsampler = SubsetRandomSampler(val_ids)

        # Both loaders read full_dataset, so validation images also pass
        # through the 'train' transforms.
        train_loader = DataLoader(full_dataset, batch_size=32, sampler=train_subsampler)
        val_loader = DataLoader(full_dataset, batch_size=32, sampler=val_subsampler)

        model = initialize_resnet50_model(num_classes)
        model, val_acc, train_time = train_model(model, train_loader, val_loader)

        results.append({
            'Fold': fold+1,
            'Validation Accuracy': val_acc,
            'Training Time (s)': train_time
        })

        # Save the model for this fold
        save_model(model, f'model_fold_{fold+1}.pth')

    # After k-fold cross-validation, train on the entire training set
    print('FINAL TRAINING')
    print('--------------------------------')
    train_loader = DataLoader(full_dataset, batch_size=32, shuffle=True)
    # NOTE(review): the test set doubles as the validation set here, so early
    # stopping and LR scheduling see the data the model is later scored on.
    val_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Fixed: this called the undefined name `initialize_model` (NameError);
    # the ResNet-50 factory in this notebook is initialize_resnet50_model.
    final_model = initialize_resnet50_model(num_classes)
    final_model, final_val_acc, final_train_time = train_model(final_model, train_loader, val_loader)

    results.append({
        'Fold': 'Final',
        'Validation Accuracy': final_val_acc,
        'Training Time (s)': final_train_time
    })

    # Create and display the summary table
    summary_df = pd.DataFrame(results)
    print("\nTraining Summary:")
    print(summary_df.to_string(index=False))

    # Evaluate on the test set
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    test_acc, precision, recall, f1, auc_roc, test_loss, _, _ = evaluate_model(final_model, test_loader, tumor_types)

    print(f"\nFinal Test Accuracy: {test_acc:.4f}")

    # Create and display the metrics DataFrame
    metrics_df = create_metrics_dataframe(final_model, test_acc, precision, recall, f1, auc_roc, final_train_time, test_loss)
    print("\nModel Metrics:")
    print(metrics_df.to_string(index=False))

    # Save the DataFrame
    metrics_csv_path = os.path.join(base_path, 'resnet50_metrics.csv')
    metrics_df.to_csv(metrics_csv_path, index=False)
    print(f"\nMetrics saved to {metrics_csv_path}")

    # Save the final model
    save_model(final_model, 'final_mriresnet50_model.pth')

    print("Training, evaluation, and metrics logging complete!")

## Model 2: ResNet101

In [None]:
def initialize_model_resnet101(num_classes):
    """Return a pretrained ResNet-101 whose final layer outputs ``num_classes`` logits."""
    network = models.resnet101(pretrained=True)
    # Swap the classifier head for a fresh linear layer of the right width
    network.fc = nn.Linear(network.fc.in_features, num_classes)
    return network

def train_model(model, train_loader, val_loader, num_epochs=100, patience=10):
    """Fine-tune the head and last stage of ``model`` with early stopping.

    Only ``model.fc`` (lr 0.001) and ``model.layer4`` (lr 0.0001) are handed
    to the Adam optimizer; the remaining parameters are never updated.

    Args:
        model: Network exposing ``fc`` and ``layer4``; moved to CUDA when
            available.
        train_loader: DataLoader yielding ``(inputs, labels)`` training batches.
        val_loader: DataLoader yielding ``(inputs, labels)`` validation batches.
        num_epochs: Maximum number of epochs.
        patience: Stop after this many consecutive epochs without a lower
            validation loss.

    Returns:
        Tuple ``(model, best_val_acc, training_time)``: the model carrying
        the weights of the epoch with the lowest validation loss, the
        validation accuracy of that epoch, and the elapsed time in seconds.
    """
    import copy  # local import: not part of the file-level import list

    start_time = time.time()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Initialize criterion, optimizer, and scheduler
    criterion = nn.CrossEntropyLoss()
    # Separate parameter groups for different learning rates
    optimizer = optim.Adam([
        {'params': model.fc.parameters(), 'lr': 0.001},
        {'params': model.layer4.parameters(), 'lr': 0.0001}
    ])
    # `verbose` is deprecated in recent PyTorch releases; learning-rate drops
    # are reported by hand after scheduler.step() instead.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    # Clip only what the optimizer updates: gradients of the other layers are
    # never applied, yet they would inflate the global norm and shrink the
    # real updates.
    optimized_params = [p for group in optimizer.param_groups for p in group['params']]

    best_val_loss = float('inf')
    best_val_acc = 0.0
    epochs_no_improve = 0
    best_model = None

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(optimized_params, max_norm=1.0)

            optimizer.step()

            train_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        # Average over the samples actually seen: with a sampler,
        # len(loader.dataset) counts the whole dataset, not the subset drawn.
        train_loss = train_loss / max(train_total, 1)
        train_acc = train_correct / max(train_total, 1)

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss = val_loss / max(val_total, 1)
        val_acc = val_correct / max(val_total, 1)

        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')
        print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

        previous_lrs = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(val_loss)
        for index, (old_lr, group) in enumerate(zip(previous_lrs, optimizer.param_groups)):
            if group['lr'] < old_lr:
                print(f'Reducing learning rate of group {index} to {group["lr"]:.4e}.')

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_acc = val_acc
            epochs_no_improve = 0
            # state_dict() returns references to the live tensors; without a
            # deep copy the "best" snapshot would follow later updates.
            best_model = copy.deepcopy(model.state_dict())
        else:
            epochs_no_improve += 1

        # Early stopping check
        if epochs_no_improve >= patience:
            print('Early stopping triggered!')
            break

    # Restore the best weights whether or not early stopping fired, so the
    # returned model matches the returned best_val_acc.
    if best_model is not None:
        model.load_state_dict(best_model)

    end_time = time.time()
    training_time = end_time - start_time
    print(f'Total training time: {training_time/60:.2f} minutes')

    return model, best_val_acc, training_time

In [None]:
# Main execution: 5-fold cross-validation of ResNet-101, a final training run
# on the whole training set, test-set evaluation, and export of metrics/model.
if __name__ == "__main__":
    # Setup
    train_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Training'
    test_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Testing'
    tumor_types = ['glioma', 'meningioma', 'notumor', 'pituitary']
    num_classes = len(tumor_types)

    # Data augmentation and normalization for training
    # Just normalization for validation/testing
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    # Create the dataset
    full_dataset = MRIDataset(train_folder, tumor_types, transform=data_transforms['train'])
    test_dataset = MRIDataset(test_folder, tumor_types, transform=data_transforms['val'])
    # One dict per fold (plus the final run), turned into the summary table below
    results = []

    # K-Fold Cross-validation
    k_folds = 5
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)

    for fold, (train_ids, val_ids) in enumerate(kfold.split(full_dataset)):
        print(f'FOLD {fold+1}')
        print('--------------------------------')

        # Samplers restrict each loader to this fold's index set
        train_subsampler = SubsetRandomSampler(train_ids)
        val_subsampler = SubsetRandomSampler(val_ids)

        # Both loaders read full_dataset, so validation images also pass
        # through the 'train' transforms.
        train_loader = DataLoader(full_dataset, batch_size=32, sampler=train_subsampler)
        val_loader = DataLoader(full_dataset, batch_size=32, sampler=val_subsampler)

        # A fresh pretrained model is created for every fold
        model = initialize_model_resnet101(num_classes)
        model, val_acc, train_time = train_model(model, train_loader, val_loader)

        results.append({
            'Fold': fold+1,
            'Validation Accuracy': val_acc,
            'Training Time (s)': train_time
        })

        # Save the model for this fold
        save_model(model, f'resnet101_model_fold_{fold+1}.pth')

    # After k-fold cross-validation, train on the entire training set
    print('FINAL TRAINING')
    print('--------------------------------')
    train_loader = DataLoader(full_dataset, batch_size=32, shuffle=True)
    # NOTE(review): the test set doubles as the validation set here, so early
    # stopping and LR scheduling see the data the model is later scored on.
    val_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    final_model = initialize_model_resnet101(num_classes)
    final_model, final_val_acc, final_train_time = train_model(final_model, train_loader, val_loader)

    results.append({
        'Fold': 'Final',
        'Validation Accuracy': final_val_acc,
        'Training Time (s)': final_train_time
    })

    # Create and display the summary table
    summary_df = pd.DataFrame(results)
    print("\nTraining Summary:")
    print(summary_df.to_string(index=False))

    # Evaluate on the test set
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    # The last two return values (predictions and labels) are not needed here
    test_acc, precision, recall, f1, auc_roc, test_loss, _, _ = evaluate_model(final_model, test_loader, tumor_types)

    print(f"\nFinal Test Accuracy: {test_acc:.4f}")

    # Create and display the metrics DataFrame
    metrics_df = create_metrics_dataframe(final_model, test_acc, precision, recall, f1, auc_roc, final_train_time, test_loss)
    print("\nModel Metrics:")
    print(metrics_df.to_string(index=False))

    # Save the DataFrame
    metrics_csv_path = os.path.join(base_path, 'resnet101_metrics.csv')
    metrics_df.to_csv(metrics_csv_path, index=False)
    print(f"\nMetrics saved to {metrics_csv_path}")

    # Save the final model
    save_model(final_model, 'final_resnet101_classification_model.pth')

    print("Training, evaluation, and metrics logging complete!")

def evaluate_model(model, test_loader, tumor_types):
    """Evaluate a classifier on a loader, print metrics and plot a confusion matrix.

    The model is run in eval mode without gradients. Per-batch cross-entropy
    loss, arg-max predictions and softmax probabilities are accumulated, then
    weighted precision/recall/F1 and one-vs-rest AUC-ROC are computed.

    Args:
        model: ``nn.Module`` returning one logit per class for each input.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        tumor_types: Class names used as heatmap tick labels. Labels are
            assumed to be integer indices into this list -- confirm against
            the dataset class.

    Returns:
        Tuple ``(accuracy, precision, recall, f1, auc_roc, avg_loss,
        all_preds, all_labels)``; the last two are NumPy arrays. ``auc_roc``
        is NaN when scikit-learn cannot compute it.
    """
    model.eval()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    # Fixed label set so matrix rows/columns always line up with tumor_types,
    # even when a class is missing from both the labels and the predictions.
    class_ids = list(range(len(tumor_types)))

    all_preds = []
    all_labels = []
    all_probs = []
    total_loss = 0.0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # CrossEntropyLoss returns the batch mean; re-weight by batch size.
            total_loss += criterion(outputs, labels).item() * inputs.size(0)
            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(torch.nn.functional.softmax(outputs, dim=1).cpu().numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    accuracy = (all_preds == all_labels).mean()
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='weighted')
    try:
        auc_roc = roc_auc_score(all_labels, all_probs, average='weighted', multi_class='ovr', labels=class_ids)
    except ValueError as err:
        # Raised e.g. when a class never occurs in the labels; keep the other
        # metrics instead of aborting the whole evaluation.
        print(f'AUC-ROC could not be computed: {err}')
        auc_roc = float('nan')
    # Average over the samples actually seen rather than len(loader.dataset),
    # which over-counts when the loader uses a sampler.
    avg_loss = total_loss / max(len(all_labels), 1)

    print(f'Test Accuracy: {accuracy:.4f}')
    print(f'Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}')
    print(f'AUC-ROC: {auc_roc:.4f}')
    print(f'Average Loss: {avg_loss:.4f}')

    # Confusion Matrix
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=tumor_types, yticklabels=tumor_types)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()

    return accuracy, precision, recall, f1, auc_roc, avg_loss, all_preds, all_labels

def visualize_model_attention(model, input_tensor, target_class):
    """Display a Grad-CAM overlay for one image using the last ``layer4`` block.

    Args:
        model: Network exposing ``layer4``; ``model.layer4[-1]`` is the
            Grad-CAM target layer.
        input_tensor: A single image tensor. It is batched with
            ``unsqueeze(0)`` and shown after ``permute(1, 2, 0)``, so it is
            presumably laid out as (C, H, W) -- confirm against the dataset.
        target_class: Integer class index to explain, or ``None`` to let
            Grad-CAM use the highest-scoring class.
    """
    model.eval()
    device = next(model.parameters()).device
    batch = input_tensor.unsqueeze(0).to(device)

    # Current pytorch-grad-cam releases take no `use_cuda` flag (the device is
    # read from the model) and select the class through `targets`.
    targets = None if target_class is None else [ClassifierOutputTarget(target_class)]
    with GradCAM(model=model, target_layers=[model.layer4[-1]]) as cam:
        grayscale_cam = cam(input_tensor=batch, targets=targets)

    # show_cam_on_image needs float32 RGB in [0, 1]; the tensor may hold
    # Normalize()d values outside that range, so rescale it for display only.
    image = input_tensor.detach().cpu().permute(1, 2, 0).numpy().astype(np.float32)
    low, high = image.min(), image.max()
    image = (image - low) / (high - low) if high > low else np.zeros_like(image)

    visualization = show_cam_on_image(image, grayscale_cam[0, :], use_rgb=True)
    plt.imshow(visualization)
    plt.axis('off')
    plt.title(f'Grad-CAM for class {target_class}')
    plt.show()

# EfficientNets

## Model 3: EfficientNetB0

In [None]:
def initialize_efficientnet0_model(num_classes):
    """Build an ImageNet-pretrained EfficientNet-B0 with a new output layer.

    Args:
        num_classes: Number of outputs of the replacement linear layer.

    Returns:
        The torchvision EfficientNet-B0 whose ``classifier[1]`` has been
        replaced by a fresh ``nn.Linear`` with ``num_classes`` outputs.
    """
    # `pretrained=True` is deprecated in torchvision; IMAGENET1K_V1 is the
    # checkpoint that flag used to select.
    model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    return model

In [None]:
def train_model(model, train_loader, val_loader, num_epochs=100, patience=10):
    """Train ``model`` with Adam, LR-on-plateau decay and early stopping.

    Every epoch runs one pass over ``train_loader`` followed by a no-grad pass
    over ``val_loader``. The weights of the epoch with the lowest validation
    loss are snapshotted and restored before returning.

    Args:
        model: ``nn.Module`` producing class logits.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        num_epochs: Maximum number of epochs to run.
        patience: Number of consecutive epochs without a lower validation
            loss after which training stops early.

    Returns:
        Tuple ``(model, best_val_acc, training_time)``: the model carrying the
        best weights, the validation accuracy measured at that epoch, and the
        elapsed wall-clock time in seconds.
    """
    start_time = time.time()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    # `verbose=` is deprecated on PyTorch LR schedulers; learning-rate changes
    # are reported explicitly after scheduler.step() below.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    best_val_loss = float('inf')
    best_val_acc = 0.0
    epochs_no_improve = 0
    best_state = None

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)
            predicted = outputs.argmax(dim=1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        # Average over the samples actually visited: with a sampler the loader
        # only covers part of `loader.dataset`, so len(dataset) under-reports
        # the loss.
        train_loss = train_loss / train_total
        train_acc = train_correct / train_total

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * inputs.size(0)
                predicted = outputs.argmax(dim=1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss = val_loss / val_total
        val_acc = val_correct / val_total

        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')
        print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after != lr_before:
            print(f'Reducing learning rate from {lr_before:.4e} to {lr_after:.4e}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_acc = val_acc
            epochs_no_improve = 0
            # state_dict() hands back references to the live tensors; clone
            # them so later optimizer steps cannot overwrite the snapshot.
            best_state = {key: value.detach().clone() for key, value in model.state_dict().items()}
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print('Early stopping!')
            break

    # Restore the best weights whether or not early stopping fired, so the
    # returned model matches the returned best_val_acc.
    if best_state is not None:
        model.load_state_dict(best_state)

    end_time = time.time()
    training_time = end_time - start_time
    return model, best_val_acc, training_time


In [None]:
def visualize_model_attention(model, input_tensor, target_class):
    """Display a Grad-CAM overlay for one image using the last ``features`` block.

    Args:
        model: Network exposing ``features``; ``model.features[-1]`` is the
            Grad-CAM target layer.
        input_tensor: A single image tensor. It is batched with
            ``unsqueeze(0)`` and shown after ``permute(1, 2, 0)``, so it is
            presumably laid out as (C, H, W) -- confirm against the dataset.
        target_class: Integer class index to explain, or ``None`` to let
            Grad-CAM use the highest-scoring class.
    """
    model.eval()
    device = next(model.parameters()).device
    batch = input_tensor.unsqueeze(0).to(device)

    # Current pytorch-grad-cam releases take no `use_cuda` flag (the device is
    # read from the model) and select the class through `targets`.
    targets = None if target_class is None else [ClassifierOutputTarget(target_class)]
    with GradCAM(model=model, target_layers=[model.features[-1]]) as cam:
        grayscale_cam = cam(input_tensor=batch, targets=targets)

    # show_cam_on_image needs float32 RGB in [0, 1]; the tensor may hold
    # Normalize()d values outside that range, so rescale it for display only.
    image = input_tensor.detach().cpu().permute(1, 2, 0).numpy().astype(np.float32)
    low, high = image.min(), image.max()
    image = (image - low) / (high - low) if high > low else np.zeros_like(image)

    visualization = show_cam_on_image(image, grayscale_cam[0, :], use_rgb=True)
    plt.imshow(visualization)
    plt.axis('off')
    plt.title(f'Grad-CAM for class {target_class}')
    plt.show()

In [None]:
# Main execution: 5-fold cross-validation, final fit and test evaluation (EfficientNet-B0)
if __name__ == "__main__":
    # Setup
    train_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Training'
    test_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Testing'
    tumor_types = ['glioma', 'meningioma', 'notumor', 'pituitary']
    num_classes = len(tumor_types)

    # Data augmentation and normalization for training
    # Just normalization for validation/testing
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    # Create the dataset
    full_dataset = MRIDataset(train_folder, tumor_types, transform=data_transforms['train'])
    test_dataset = MRIDataset(test_folder, tumor_types, transform=data_transforms['val'])
    results = []

    # K-Fold Cross-validation
    k_folds = 5
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)

    for fold, (train_ids, val_ids) in enumerate(kfold.split(full_dataset)):
        print(f'FOLD {fold+1}')
        print('--------------------------------')

        train_subsampler = SubsetRandomSampler(train_ids)
        val_subsampler = SubsetRandomSampler(val_ids)

        # NOTE(review): both loaders wrap `full_dataset`, so validation batches
        # also pass through the augmenting 'train' transform.
        train_loader = DataLoader(full_dataset, batch_size=32, sampler=train_subsampler)
        val_loader = DataLoader(full_dataset, batch_size=32, sampler=val_subsampler)

        model = initialize_efficientnet0_model(num_classes)
        model, val_acc, train_time = train_model(model, train_loader, val_loader)

        results.append({
            'Fold': fold+1,
            'Validation Accuracy': val_acc,
            'Training Time (s)': train_time
        })

        # Save the model for this fold
        save_model(model, f'efficientnet_b0_model_fold_{fold+1}.pth')

    # After k-fold cross-validation, train on the entire training set
    # NOTE(review): the test set doubles as the validation loader here, so the
    # early stopping inside train_model sees test data.
    print('FINAL TRAINING')
    print('--------------------------------')
    train_loader = DataLoader(full_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Build the final model with the same initializer as the folds; the
    # previous call to `initialize_model` did not use the EfficientNet-B0
    # builder defined for this section.
    final_model = initialize_efficientnet0_model(num_classes)
    final_model, final_val_acc, final_train_time = train_model(final_model, train_loader, val_loader)

    results.append({
        'Fold': 'Final',
        'Validation Accuracy': final_val_acc,
        'Training Time (s)': final_train_time
    })

    # Create and display the summary table
    summary_df = pd.DataFrame(results)
    print("\nTraining Summary:")
    print(summary_df.to_string(index=False))

    # Evaluate on the test set
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    test_acc, precision, recall, f1, auc_roc, test_loss, _, _ = evaluate_model(final_model, test_loader, tumor_types)

    print(f"\nFinal Test Accuracy: {test_acc:.4f}")

    # Create and display the metrics DataFrame
    metrics_df = create_metrics_dataframe(final_model, test_acc, precision, recall, f1, auc_roc, final_train_time, test_loss)
    print("\nModel Metrics:")
    print(metrics_df.to_string(index=False))

    # Save the DataFrame
    metrics_csv_path = os.path.join(base_path, 'efficientnet_b0_model_metrics.csv')
    metrics_df.to_csv(metrics_csv_path, index=False)
    print(f"\nMetrics saved to {metrics_csv_path}")

    # Save the final model
    save_model(final_model, 'final_efficientnet_b0_mri_classification_model.pth')

    print("Training, evaluation, and metrics logging complete for EfficientNet-B0 model!")

## Remark:
EfficientNet and ResNet are both CNNs, whose convolutional filters respond strongly to edges and local intensity changes. It is therefore normal to observe that the model's attention concentrates on the edges of tumours instead of hovering directly above them.

## Model 4: EfficientNetB1

In [None]:
def initialize_efficientnetb1_model(num_classes):
    """Build an ImageNet-pretrained EfficientNet-B1 with a new output layer.

    Args:
        num_classes: Number of outputs of the replacement linear layer.

    Returns:
        The torchvision EfficientNet-B1 whose ``classifier[1]`` has been
        replaced by a fresh ``nn.Linear`` with ``num_classes`` outputs.
    """
    # `pretrained=True` is deprecated in torchvision. IMAGENET1K_V1 is pinned
    # explicitly because it is the checkpoint that flag selected; `DEFAULT`
    # would switch B1 to a different set of weights.
    model = models.efficientnet_b1(weights=models.EfficientNet_B1_Weights.IMAGENET1K_V1)
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    return model

def train_model(model, train_loader, val_loader, num_epochs=100, patience=10):
    """Train ``model`` with Adam, LR-on-plateau decay and early stopping.

    Every epoch runs one pass over ``train_loader`` followed by a no-grad pass
    over ``val_loader``. The weights of the epoch with the lowest validation
    loss are snapshotted and restored before returning.

    Args:
        model: ``nn.Module`` producing class logits.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        num_epochs: Maximum number of epochs to run.
        patience: Number of consecutive epochs without a lower validation
            loss after which training stops early.

    Returns:
        Tuple ``(model, best_val_acc, training_time)``: the model carrying the
        best weights, the validation accuracy measured at that epoch, and the
        elapsed wall-clock time in seconds.
    """
    start_time = time.time()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    # `verbose=` is deprecated on PyTorch LR schedulers; learning-rate changes
    # are reported explicitly after scheduler.step() below.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    best_val_loss = float('inf')
    best_val_acc = 0.0
    epochs_no_improve = 0
    best_state = None

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)
            predicted = outputs.argmax(dim=1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        # Average over the samples actually visited: with a sampler the loader
        # only covers part of `loader.dataset`, so len(dataset) under-reports
        # the loss.
        train_loss = train_loss / train_total
        train_acc = train_correct / train_total

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * inputs.size(0)
                predicted = outputs.argmax(dim=1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss = val_loss / val_total
        val_acc = val_correct / val_total

        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')
        print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after != lr_before:
            print(f'Reducing learning rate from {lr_before:.4e} to {lr_after:.4e}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_acc = val_acc
            epochs_no_improve = 0
            # state_dict() hands back references to the live tensors; clone
            # them so later optimizer steps cannot overwrite the snapshot.
            best_state = {key: value.detach().clone() for key, value in model.state_dict().items()}
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print('Early stopping!')
            break

    # Restore the best weights whether or not early stopping fired, so the
    # returned model matches the returned best_val_acc.
    if best_state is not None:
        model.load_state_dict(best_state)

    end_time = time.time()
    training_time = end_time - start_time
    return model, best_val_acc, training_time



In [None]:
# Main execution: 5-fold cross-validation, final fit and test evaluation (EfficientNet-B1)
if __name__ == "__main__":
    # Setup
    train_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Training'
    test_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Testing'
    tumor_types = ['glioma', 'meningioma', 'notumor', 'pituitary']
    num_classes = len(tumor_types)

    # Data augmentation and normalization for training
    # Just normalization for validation/testing
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    # Create the dataset
    full_dataset = MRIDataset(train_folder, tumor_types, transform=data_transforms['train'])
    test_dataset = MRIDataset(test_folder, tumor_types, transform=data_transforms['val'])
    results = []

    # K-Fold Cross-validation
    k_folds = 5
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)

    for fold, (train_ids, val_ids) in enumerate(kfold.split(full_dataset)):
        print(f'FOLD {fold+1}')
        print('--------------------------------')

        train_subsampler = SubsetRandomSampler(train_ids)
        val_subsampler = SubsetRandomSampler(val_ids)

        # NOTE(review): both loaders wrap `full_dataset`, so validation batches
        # also pass through the augmenting 'train' transform.
        train_loader = DataLoader(full_dataset, batch_size=32, sampler=train_subsampler)
        val_loader = DataLoader(full_dataset, batch_size=32, sampler=val_subsampler)

        # Use the EfficientNet-B1 builder for the folds too; the previous call
        # to `initialize_model` meant the checkpoints saved below as
        # 'efficientnet_b1_model_fold_*' were not built by this section's
        # initializer.
        model = initialize_efficientnetb1_model(num_classes)
        model, val_acc, train_time = train_model(model, train_loader, val_loader)

        results.append({
            'Fold': fold+1,
            'Validation Accuracy': val_acc,
            'Training Time (s)': train_time
        })

        # Save the model for this fold
        save_model(model, f'efficientnet_b1_model_fold_{fold+1}.pth')

    # After k-fold cross-validation, train on the entire training set
    # NOTE(review): the test set doubles as the validation loader here, so the
    # early stopping inside train_model sees test data.
    print('FINAL TRAINING')
    print('--------------------------------')
    train_loader = DataLoader(full_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    final_model = initialize_efficientnetb1_model(num_classes)
    final_model, final_val_acc, final_train_time = train_model(final_model, train_loader, val_loader)

    results.append({
        'Fold': 'Final',
        'Validation Accuracy': final_val_acc,
        'Training Time (s)': final_train_time
    })

    # Create and display the summary table
    summary_df = pd.DataFrame(results)
    print("\nTraining Summary:")
    print(summary_df.to_string(index=False))

    # Evaluate on the test set
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    test_acc, precision, recall, f1, auc_roc, test_loss, _, _ = evaluate_model(final_model, test_loader, tumor_types)

    print(f"\nFinal Test Accuracy: {test_acc:.4f}")

    # Create and display the metrics DataFrame
    metrics_df = create_metrics_dataframe(final_model, test_acc, precision, recall, f1, auc_roc, final_train_time, test_loss)
    print("\nModel Metrics:")
    print(metrics_df.to_string(index=False))

    # Save the DataFrame
    metrics_csv_path = os.path.join(base_path, 'efficientnet_b1_model_metrics.csv')
    metrics_df.to_csv(metrics_csv_path, index=False)
    print(f"\nMetrics saved to {metrics_csv_path}")

    # Save the final model
    save_model(final_model, 'final_efficientnet_b1_mri_classification_model.pth')

    print("Training, evaluation, and metrics logging complete for EfficientNet-B1 model!")

# ViT: Vision transformer

## Model 5: ViT Small

In [None]:
def initialize_vit_model(num_classes):
    """Create a pretrained timm ViT-Small (patch 16, 224) with a new head.

    Args:
        num_classes: Number of outputs of the replacement ``head`` layer.

    Returns:
        The timm model whose ``head`` is a fresh ``nn.Linear`` sized for
        ``num_classes``.
    """
    vit = timm.create_model('vit_small_patch16_224', pretrained=True)
    # Keep the backbone's feature width and swap only the output layer.
    head_inputs = vit.head.in_features
    vit.head = nn.Linear(head_inputs, num_classes)
    return vit

def train_model(model, train_loader, val_loader, num_epochs=100, patience=10):
    """Train ``model`` with AdamW, LR-on-plateau decay and early stopping.

    Every epoch runs one pass over ``train_loader`` followed by a no-grad pass
    over ``val_loader``. The weights of the epoch with the lowest validation
    loss are snapshotted and restored before returning.

    Args:
        model: ``nn.Module`` producing class logits.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        num_epochs: Maximum number of epochs to run.
        patience: Number of consecutive epochs without a lower validation
            loss after which training stops early.

    Returns:
        Tuple ``(model, best_val_acc, training_time)``: the model carrying the
        best weights, the validation accuracy measured at that epoch, and the
        elapsed wall-clock time in seconds.
    """
    start_time = time.time()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    # Adjusted learning rate for ViT
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
    # `verbose=` is deprecated on PyTorch LR schedulers; learning-rate changes
    # are reported explicitly after scheduler.step() below.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    best_val_loss = float('inf')
    best_val_acc = 0.0
    epochs_no_improve = 0
    best_state = None

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)
            predicted = outputs.argmax(dim=1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        # Average over the samples actually visited: with a sampler the loader
        # only covers part of `loader.dataset`, so len(dataset) under-reports
        # the loss.
        train_loss = train_loss / train_total
        train_acc = train_correct / train_total

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * inputs.size(0)
                predicted = outputs.argmax(dim=1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss = val_loss / val_total
        val_acc = val_correct / val_total

        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')
        print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after != lr_before:
            print(f'Reducing learning rate from {lr_before:.4e} to {lr_after:.4e}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_acc = val_acc
            epochs_no_improve = 0
            # state_dict() hands back references to the live tensors; clone
            # them so later optimizer steps cannot overwrite the snapshot.
            best_state = {key: value.detach().clone() for key, value in model.state_dict().items()}
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print('Early stopping!')
            break

    # Restore the best weights whether or not early stopping fired, so the
    # returned model matches the returned best_val_acc.
    if best_state is not None:
        model.load_state_dict(best_state)

    end_time = time.time()
    training_time = end_time - start_time
    return model, best_val_acc, training_time



In [None]:
# Main execution: 5-fold cross-validation, final fit and test evaluation (ViT Small)
if __name__ == "__main__":
    # Setup
    train_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Training'
    test_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Testing'
    tumor_types = ['glioma', 'meningioma', 'notumor', 'pituitary']
    num_classes = len(tumor_types)

    # Data augmentation and normalization for training
    # Using ViT-specific preprocessing
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # ViT standard normalization
        ]),
        'val': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ]),
    }

    # Create the datasets
    full_dataset = MRIDataset(train_folder, tumor_types, transform=data_transforms['train'])
    test_dataset = MRIDataset(test_folder, tumor_types, transform=data_transforms['val'])
    results = []

    # K-Fold Cross-validation
    k_folds = 5
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)

    for fold, (train_ids, val_ids) in enumerate(kfold.split(full_dataset)):
        print(f'FOLD {fold+1}')
        print('--------------------------------')

        train_subsampler = SubsetRandomSampler(train_ids)
        val_subsampler = SubsetRandomSampler(val_ids)

        # NOTE(review): both loaders wrap `full_dataset`, so validation batches
        # also pass through the augmenting 'train' transform.
        train_loader = DataLoader(full_dataset, batch_size=32, sampler=train_subsampler)
        val_loader = DataLoader(full_dataset, batch_size=32, sampler=val_subsampler)

        # Use the ViT builder defined for this section; the previous call to
        # `initialize_model` never invoked `initialize_vit_model`, so the
        # checkpoints saved as 'vit_small_*' were not built by it.
        model = initialize_vit_model(num_classes)
        model, val_acc, train_time = train_model(model, train_loader, val_loader)

        results.append({
            'Fold': fold+1,
            'Validation Accuracy': val_acc,
            'Training Time (s)': train_time
        })

        # Save the model for this fold
        save_model(model, f'vit_small_model_fold_{fold+1}.pth')

    # Final training on entire dataset
    # NOTE(review): the test set doubles as the validation loader here, so the
    # early stopping inside train_model sees test data.
    print('FINAL TRAINING')
    print('--------------------------------')
    train_loader = DataLoader(full_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    final_model = initialize_vit_model(num_classes)
    final_model, final_val_acc, final_train_time = train_model(final_model, train_loader, val_loader)

    results.append({
        'Fold': 'Final',
        'Validation Accuracy': final_val_acc,
        'Training Time (s)': final_train_time
    })

    # Create and display the summary table
    summary_df = pd.DataFrame(results)
    print("\nTraining Summary:")
    print(summary_df.to_string(index=False))

    # Evaluate on the test set
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    test_acc, precision, recall, f1, auc_roc, test_loss, _, _ = evaluate_model(final_model, test_loader, tumor_types)

    print(f"\nFinal Test Accuracy: {test_acc:.4f}")

    # Create and display the metrics DataFrame
    metrics_df = create_metrics_dataframe(final_model, test_acc, precision, recall, f1, auc_roc, final_train_time, test_loss)
    print("\nModel Metrics:")
    print(metrics_df.to_string(index=False))

    # Save the metrics
    metrics_csv_path = os.path.join(base_path, 'vit_small_model_metrics.csv')
    metrics_df.to_csv(metrics_csv_path, index=False)
    print(f"\nMetrics saved to {metrics_csv_path}")

    # Save the final model
    save_model(final_model, 'final_vit_small_mri_classification_model.pth')

    print("Training, evaluation, and metrics logging complete for ViT Small model!")

# LeViTs

## Model 6: LeViT-256

In [None]:
class LeViT256Model(nn.Module):
    """timm ``levit_256`` backbone followed by a two-layer classifier head.

    Args:
        num_classes: Number of outputs of the final linear layer.
        pretrained: Forwarded to ``timm.create_model`` for the backbone.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()

        # The backbone is created with num_classes=0 and its `num_features`
        # attribute gives the input width of the head below.
        self.levit = timm.create_model('levit_256', pretrained=pretrained, num_classes=0)
        backbone_width = self.levit.num_features

        # Head: Linear -> ReLU -> Dropout(0.5) -> Linear
        hidden_width = 512
        self.fc = nn.Sequential(
            nn.Linear(backbone_width, hidden_width),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_width, num_classes),
        )

    def forward(self, x):
        """Return the head's output for the backbone features of ``x``."""
        return self.fc(self.levit(x))

# Function to initialize the model
def initialize_levit256_model(num_classes, pretrained=True):
    """Construct and return a new ``LeViT256Model``.

    Args:
        num_classes: Number of outputs of the model's final layer.
        pretrained: Passed through to ``LeViT256Model``.
    """
    levit_model = LeViT256Model(num_classes, pretrained)
    return levit_model

# Evaluation Function
def evaluate_model(model, test_loader, tumor_types):
    """Evaluate a classifier on a loader, print metrics and plot a confusion matrix.

    The model is run in eval mode without gradients. Per-batch cross-entropy
    loss, arg-max predictions and softmax probabilities are accumulated, then
    weighted precision/recall/F1 and one-vs-rest AUC-ROC are computed.

    Args:
        model: ``nn.Module`` returning one logit per class for each input.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        tumor_types: Class names used as heatmap tick labels. Labels are
            assumed to be integer indices into this list -- confirm against
            the dataset class.

    Returns:
        Tuple ``(accuracy, precision, recall, f1, auc_roc, avg_loss,
        all_preds, all_labels)``; the last two are NumPy arrays. ``auc_roc``
        is NaN when scikit-learn cannot compute it.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    criterion = nn.CrossEntropyLoss()
    # Fixed label set so matrix rows/columns always line up with tumor_types,
    # even when a class is missing from both the labels and the predictions.
    class_ids = list(range(len(tumor_types)))

    all_preds = []
    all_labels = []
    all_probs = []
    total_loss = 0.0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # CrossEntropyLoss returns the batch mean; re-weight by batch size.
            total_loss += criterion(outputs, labels).item() * inputs.size(0)
            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(torch.nn.functional.softmax(outputs, dim=1).cpu().numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    accuracy = (all_preds == all_labels).mean()
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='weighted')
    try:
        auc_roc = roc_auc_score(all_labels, all_probs, average='weighted', multi_class='ovr', labels=class_ids)
    except ValueError as err:
        # Raised e.g. when a class never occurs in the labels; keep the other
        # metrics instead of aborting the whole evaluation.
        print(f'AUC-ROC could not be computed: {err}')
        auc_roc = float('nan')
    # Average over the samples actually seen rather than len(loader.dataset),
    # which over-counts when the loader uses a sampler.
    avg_loss = total_loss / max(len(all_labels), 1)

    print(f'Test Accuracy: {accuracy:.4f}')
    print(f'Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}')
    print(f'AUC-ROC: {auc_roc:.4f}')
    print(f'Average Loss: {avg_loss:.4f}')

    # Confusion Matrix
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=tumor_types, yticklabels=tumor_types)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()

    return accuracy, precision, recall, f1, auc_roc, avg_loss, all_preds, all_labels

# Function to create metrics DataFrame
def create_metrics_dataframe(model, test_acc, precision, recall, f1, auc_roc, train_time, test_loss):
    """Summarise headline results and model size as a two-column table.

    ``precision``, ``recall`` and ``auc_roc`` are accepted but are not written
    to the table.

    Args:
        model: Object whose ``parameters()`` yields tensors; used for the
            parameter count and the size in MiB of the parameters only.
        test_acc: Value stored under 'Overall Accuracy'.
        precision: Unused.
        recall: Unused.
        f1: Value stored under 'F1 Score'.
        auc_roc: Unused.
        train_time: Value stored under 'Training Time (s)'.
        test_loss: Value stored under 'Cross Entropy Loss'.

    Returns:
        ``pandas.DataFrame`` with columns 'Metric' and 'Value', one row per
        metric.
    """
    param_count = 0
    param_bytes = 0
    for tensor in model.parameters():
        param_count += tensor.numel()
        param_bytes += tensor.nelement() * tensor.element_size()

    rows = [
        ('Overall Accuracy', test_acc),
        ('F1 Score', f1),
        ('Cross Entropy Loss', test_loss),
        ('Training Time (s)', train_time),
        ('Number of Parameters', param_count),
        ('Model Size (MB)', param_bytes / (1024 * 1024)),
    ]
    names, values = zip(*rows)
    return pd.DataFrame({'Metric': list(names), 'Value': list(values)})

In [None]:
def train_levit_model(train_dataset, test_dataset, num_classes, num_epochs=100, patience=10, k_folds=5):
    """Cross-validate a LeViT-256 classifier, then train a final model on all training data.

    Phase 1 runs ``k_folds``-fold cross-validation over ``train_dataset`` and records
    the best validation loss and wall-clock time of every fold. Phase 2 trains a fresh
    model on the whole ``train_dataset``, restores the weights of its best epoch and
    writes them to ``final_levit_256_model.pth`` in the current working directory.

    Args:
        train_dataset: Map-style dataset yielding ``(inputs, labels)`` batches through
            a ``DataLoader``; used for both fold training and fold validation.
        test_dataset: Dataset used as the validation set of the final training phase.
        num_classes: Number of output classes passed to ``initialize_levit_model``.
        num_epochs: Maximum number of epochs per fold and for the final phase.
        patience: Number of consecutive non-improving epochs that triggers early stopping.
        k_folds: Number of cross-validation folds.

    Returns:
        Tuple ``(results, final_model)``: ``results`` is a list of dicts with the keys
        'Fold', 'Best Validation Loss' and 'Training Time (s)' (one per fold plus a
        last entry whose 'Fold' is 'Final'); ``final_model`` carries the best weights.

    Raises:
        RuntimeError: If every batch of a training or validation pass fails.
    """
    import copy  # local import: only needed to snapshot the best weights

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    criterion = nn.CrossEntropyLoss()

    def run_epoch(model, loader, phase, optimizer=None):
        """Run one pass over ``loader``; trains when ``optimizer`` is given, else evaluates.

        Returns ``(mean loss per processed sample, accuracy)``.
        """
        training = optimizer is not None
        model.train(training)
        loss_sum, correct, seen = 0.0, 0, 0

        with torch.set_grad_enabled(training):
            for inputs, labels in loader:
                try:
                    inputs, labels = inputs.to(device), labels.to(device)
                    if training:
                        optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    if training:
                        loss.backward()
                        optimizer.step()

                    loss_sum += loss.item() * inputs.size(0)
                    _, predicted = torch.max(outputs.detach(), 1)
                    seen += labels.size(0)
                    correct += (predicted == labels).sum().item()
                except RuntimeError as e:
                    # Best effort: report the failing batch and carry on with the next one.
                    print(f"RuntimeError in {phase} loop: {e}")
                    print(f"Input shape: {inputs.shape}")
                    continue

        if seen == 0:
            raise RuntimeError(f"No batch was processed successfully in the {phase} loop")

        # Normalise by the samples actually processed. len(loader.dataset) would be
        # wrong for the fold loaders, whose sampler only draws a subset of the dataset.
        return loss_sum / seen, correct / seen

    def fit(model, train_loader, val_loader, phase_prefix, keep_best_state):
        """Train ``model`` with LR scheduling and early stopping on validation loss.

        Returns ``(best validation loss, best state dict or None, elapsed seconds)``.
        """
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        # `verbose` is deprecated in recent PyTorch releases, so LR reductions are
        # reported explicitly below instead of through the scheduler.
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)

        best_val_loss = float('inf')
        best_state = None
        epochs_no_improve = 0
        start_time = time.time()

        for epoch in range(num_epochs):
            train_loss, train_acc = run_epoch(model, train_loader, f"{phase_prefix}training", optimizer)
            val_loss, val_acc = run_epoch(model, val_loader, f"{phase_prefix}validation")

            print(f'Epoch {epoch+1}/{num_epochs}')
            print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
            print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

            lrs_before = [group['lr'] for group in optimizer.param_groups]
            scheduler.step(val_loss)
            for idx, (old_lr, group) in enumerate(zip(lrs_before, optimizer.param_groups)):
                if group['lr'] < old_lr:
                    print(f"Epoch {epoch+1}: reducing learning rate of group {idx} to {group['lr']:.4e}.")

            # Always accept the first epoch, afterwards only strict improvements.
            if epoch == 0 or val_loss < best_val_loss:
                best_val_loss = val_loss
                if keep_best_state:
                    # state_dict() returns references to the live tensors; without a
                    # deep copy the snapshot would be overwritten by later epochs.
                    best_state = copy.deepcopy(model.state_dict())
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve == patience:
                print('Early stopping!')
                break

        return best_val_loss, best_state, time.time() - start_time

    results = []

    # K-Fold Cross-validation
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)

    for fold, (train_ids, val_ids) in enumerate(kfold.split(train_dataset)):
        print(f'FOLD {fold+1}')
        print('--------------------------------')

        # Both loaders read from train_dataset; the samplers restrict them to the fold split.
        train_loader = DataLoader(train_dataset, batch_size=32, sampler=SubsetRandomSampler(train_ids),
                                  num_workers=2, pin_memory=True)
        val_loader = DataLoader(train_dataset, batch_size=32, sampler=SubsetRandomSampler(val_ids),
                                num_workers=2, pin_memory=True)

        model = initialize_levit_model(num_classes).to(device)
        # Fold weights are never reused, so no snapshot is kept for them.
        best_val_loss, _, training_time = fit(model, train_loader, val_loader, "", keep_best_state=False)
        results.append({
            'Fold': fold+1,
            'Best Validation Loss': best_val_loss,
            'Training Time (s)': training_time
        })

    # Final training on entire dataset
    # NOTE(review): test_dataset drives LR scheduling, early stopping and best-epoch
    # selection here, so metrics later computed on the same data are optimistic.
    print('FINAL TRAINING')
    print('--------------------------------')
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)

    final_model = initialize_levit_model(num_classes).to(device)
    best_val_loss, best_model, final_training_time = fit(
        final_model, train_loader, test_loader, "final ", keep_best_state=True
    )
    results.append({
        'Fold': 'Final',
        'Best Validation Loss': best_val_loss,
        'Training Time (s)': final_training_time
    })

    # Load the best model state (absent only when num_epochs is 0)
    if best_model is not None:
        final_model.load_state_dict(best_model)

    # Save the final model. 'epoch' holds the number of result rows (folds + final
    # run), not an epoch index; the key is kept for compatibility with existing files.
    torch.save({
        'epoch': len(results),
        'model_state_dict': best_model,
        'results': results
    }, 'final_levit_256_model.pth')

    return results, final_model

In [None]:
# Driver cell for LeViT-256: builds the datasets, runs train_levit_model, evaluates
# the returned model on the test set and writes history, metrics, weights and
# predictions under base_path.

# Silence all Python warnings for the rest of the session.
warnings.filterwarnings('ignore')

# Ask TorchDynamo to suppress compilation errors rather than raise them
# (this does not by itself switch torch.compile off).
torch._dynamo.config.suppress_errors = True

if __name__ == "__main__":
    # Setup
    base_path = '/content/drive/MyDrive/Colab Notebooks'
    train_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Training'
    test_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Testing'
    # Class names; a label's integer value is used as an index into this list below.
    tumor_types = ['glioma', 'meningioma', 'notumor', 'pituitary']
    num_classes = len(tumor_types)

    # Create checkpoint directory (not referenced again in this cell)
    checkpoint_dir = os.path.join(base_path, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Data augmentation and normalization for training
    # Just normalization for validation/testing
    # The Normalize constants are the commonly used ImageNet channel mean/std.
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    try:
        # Create the datasets
        train_dataset = MRIDataset(train_folder, tumor_types, transform=data_transforms['train'])
        test_dataset = MRIDataset(test_folder, tumor_types, transform=data_transforms['val'])

        # Print dataset sizes
        print("Dataset sizes:")
        print(f"Training: {len(train_dataset)}")
        print(f"Testing: {len(test_dataset)}")

        # Train the LeViT-256 model
        # NOTE(review): test_dataset is also handed to the training routine, which uses
        # it as the validation set of its final phase; the evaluation below is therefore
        # not on unseen data.
        results, final_model = train_levit_model(train_dataset, test_dataset, num_classes)

        # Display training results
        results_df = pd.DataFrame(results)
        print("\nTraining Results:")
        print(results_df)

        # Save training history
        history_path = os.path.join(base_path, 'levit_256_training_history.csv')
        results_df.to_csv(history_path, index=False)
        print(f"Training history saved to {history_path}")

        # Evaluate on the test set
        test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)
        test_acc, precision, recall, f1, auc_roc, test_loss, predictions, true_labels = evaluate_model(
            final_model,
            test_loader,
            tumor_types
        )

        # Create and display the metrics DataFrame
        # results[-1] is the 'Final' entry, so only the final run's training time is reported.
        metrics_df = create_metrics_dataframe(
            final_model,
            test_acc,
            precision,
            recall,
            f1,
            auc_roc,
            results[-1]['Training Time (s)'],
            test_loss
        )
        print("\nModel Metrics:")
        print(metrics_df.to_string(index=False))

        # Save the metrics DataFrame
        metrics_csv_path = os.path.join(base_path, 'levit_256_model_metrics.csv')
        metrics_df.to_csv(metrics_csv_path, index=False)
        print(f"\nMetrics saved to {metrics_csv_path}")

        # Save the final model with all relevant information
        # 'epoch' stores len(results), i.e. the number of result rows, not an epoch index.
        model_path = os.path.join(base_path, 'levit_256_model_final.pth')
        torch.save({
            'epoch': len(results),
            'model_state_dict': final_model.state_dict(),
            'results': results,
            'test_metrics': {
                'accuracy': test_acc,
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'auc_roc': auc_roc,
                'test_loss': test_loss
            }
        }, model_path)
        print(f"\nModel saved to {model_path}")

        # Save predictions (integer labels mapped back to class names)
        predictions_df = pd.DataFrame({
            'True_Label': [tumor_types[i] for i in true_labels],
            'Predicted_Label': [tumor_types[i] for i in predictions]
        })
        predictions_path = os.path.join(base_path, 'levit_256_test_predictions.csv')
        predictions_df.to_csv(predictions_path, index=False)
        print(f"Test predictions saved to {predictions_path}")

    except Exception as e:
        # Report the failure with its traceback, then re-raise so the cell still fails.
        print(f"An error occurred: {str(e)}")
        traceback.print_exc()  # Print the full error traceback
        raise e

    print("\nLeViT-256 training, evaluation, and metrics logging complete!")

## Model 7: LeViT-384

In [None]:
class LeViT384Model(nn.Module):
    """timm LeViT-384 backbone followed by a two-layer classification head.

    Args:
        num_classes: Number of output logits produced by the head.
        pretrained: Forwarded to ``timm.create_model`` to request pretrained weights.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()

        # num_classes=0 requests the backbone without timm's own classifier, so
        # the custom head below receives the backbone features directly.
        self.levit = timm.create_model('levit_384', num_classes=0, pretrained=pretrained)

        # Width of the hidden layer between the backbone features and the logits.
        hidden_units = 1024
        backbone_dim = self.levit.num_features

        self.fc = nn.Sequential(
            nn.Linear(backbone_dim, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        return self.fc(self.levit(x))

def initialize_levit384_model(num_classes, pretrained=True):
    """Create a ``LeViT384Model`` with ``num_classes`` outputs.

    ``pretrained`` is passed through to the model constructor unchanged.
    """
    model = LeViT384Model(num_classes=num_classes, pretrained=pretrained)
    return model

In [None]:
def train_levit_model(train_dataset, test_dataset, num_classes, num_epochs=100, patience=10, k_folds=5):
    """Cross-validate a LeViT-384 classifier, then train a final model on all training data.

    This definition replaces the earlier LeViT-256 function of the same name once the
    cell has run. Phase 1 runs ``k_folds``-fold cross-validation over ``train_dataset``
    and records the best validation loss and wall-clock time of every fold. Phase 2
    trains a fresh model on the whole ``train_dataset``, restores the weights of its
    best epoch and writes them to ``final_levit_384_model.pth`` in the current working
    directory.

    Args:
        train_dataset: Map-style dataset yielding ``(inputs, labels)`` batches through
            a ``DataLoader``; used for both fold training and fold validation.
        test_dataset: Dataset used as the validation set of the final training phase.
        num_classes: Number of output classes passed to ``initialize_levit384_model``.
        num_epochs: Maximum number of epochs per fold and for the final phase.
        patience: Number of consecutive non-improving epochs that triggers early stopping.
        k_folds: Number of cross-validation folds.

    Returns:
        Tuple ``(results, final_model)``: ``results`` is a list of dicts with the keys
        'Fold', 'Best Validation Loss' and 'Training Time (s)' (one per fold plus a
        last entry whose 'Fold' is 'Final'); ``final_model`` carries the best weights.

    Raises:
        RuntimeError: If every batch of a training or validation pass fails.
    """
    import copy  # local import: only needed to snapshot the best weights

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    criterion = nn.CrossEntropyLoss()

    def run_epoch(model, loader, phase, optimizer=None):
        """Run one pass over ``loader``; trains when ``optimizer`` is given, else evaluates.

        Returns ``(mean loss per processed sample, accuracy)``.
        """
        training = optimizer is not None
        model.train(training)
        loss_sum, correct, seen = 0.0, 0, 0

        with torch.set_grad_enabled(training):
            for inputs, labels in loader:
                try:
                    inputs, labels = inputs.to(device), labels.to(device)
                    if training:
                        optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    if training:
                        loss.backward()
                        optimizer.step()

                    loss_sum += loss.item() * inputs.size(0)
                    _, predicted = torch.max(outputs.detach(), 1)
                    seen += labels.size(0)
                    correct += (predicted == labels).sum().item()
                except RuntimeError as e:
                    # Best effort: report the failing batch and carry on with the next one.
                    print(f"RuntimeError in {phase} loop: {e}")
                    print(f"Input shape: {inputs.shape}")
                    continue

        if seen == 0:
            raise RuntimeError(f"No batch was processed successfully in the {phase} loop")

        # Normalise by the samples actually processed. len(loader.dataset) would be
        # wrong for the fold loaders, whose sampler only draws a subset of the dataset.
        return loss_sum / seen, correct / seen

    def fit(model, train_loader, val_loader, phase_prefix, keep_best_state):
        """Train ``model`` with LR scheduling and early stopping on validation loss.

        Returns ``(best validation loss, best state dict or None, elapsed seconds)``.
        """
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        # `verbose` is deprecated in recent PyTorch releases, so LR reductions are
        # reported explicitly below instead of through the scheduler.
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)

        best_val_loss = float('inf')
        best_state = None
        epochs_no_improve = 0
        start_time = time.time()

        for epoch in range(num_epochs):
            train_loss, train_acc = run_epoch(model, train_loader, f"{phase_prefix}training", optimizer)
            val_loss, val_acc = run_epoch(model, val_loader, f"{phase_prefix}validation")

            print(f'Epoch {epoch+1}/{num_epochs}')
            print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
            print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

            lrs_before = [group['lr'] for group in optimizer.param_groups]
            scheduler.step(val_loss)
            for idx, (old_lr, group) in enumerate(zip(lrs_before, optimizer.param_groups)):
                if group['lr'] < old_lr:
                    print(f"Epoch {epoch+1}: reducing learning rate of group {idx} to {group['lr']:.4e}.")

            # Always accept the first epoch, afterwards only strict improvements.
            if epoch == 0 or val_loss < best_val_loss:
                best_val_loss = val_loss
                if keep_best_state:
                    # state_dict() returns references to the live tensors; without a
                    # deep copy the snapshot would be overwritten by later epochs.
                    best_state = copy.deepcopy(model.state_dict())
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve == patience:
                print('Early stopping!')
                break

        return best_val_loss, best_state, time.time() - start_time

    results = []

    # K-Fold Cross-validation
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)

    for fold, (train_ids, val_ids) in enumerate(kfold.split(train_dataset)):
        print(f'FOLD {fold+1}')
        print('--------------------------------')

        # Both loaders read from train_dataset; the samplers restrict them to the fold split.
        train_loader = DataLoader(train_dataset, batch_size=32, sampler=SubsetRandomSampler(train_ids),
                                  num_workers=2, pin_memory=True)
        val_loader = DataLoader(train_dataset, batch_size=32, sampler=SubsetRandomSampler(val_ids),
                                num_workers=2, pin_memory=True)

        # Build the LeViT-384 network defined alongside this function; the LeViT-256
        # factory would silently train the wrong architecture.
        model = initialize_levit384_model(num_classes).to(device)
        # Fold weights are never reused, so no snapshot is kept for them.
        best_val_loss, _, training_time = fit(model, train_loader, val_loader, "", keep_best_state=False)
        results.append({
            'Fold': fold+1,
            'Best Validation Loss': best_val_loss,
            'Training Time (s)': training_time
        })

    # Final training on entire dataset
    # NOTE(review): test_dataset drives LR scheduling, early stopping and best-epoch
    # selection here, so metrics later computed on the same data are optimistic.
    print('FINAL TRAINING')
    print('--------------------------------')
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)

    final_model = initialize_levit384_model(num_classes).to(device)
    best_val_loss, best_model, final_training_time = fit(
        final_model, train_loader, test_loader, "final ", keep_best_state=True
    )
    results.append({
        'Fold': 'Final',
        'Best Validation Loss': best_val_loss,
        'Training Time (s)': final_training_time
    })

    # Load the best model state (absent only when num_epochs is 0)
    if best_model is not None:
        final_model.load_state_dict(best_model)

    # Save the final model. 'epoch' holds the number of result rows (folds + final
    # run), not an epoch index; the key is kept for compatibility with existing files.
    torch.save({
        'epoch': len(results),
        'model_state_dict': best_model,
        'results': results
    }, 'final_levit_384_model.pth')

    return results, final_model

In [None]:
# Driver cell for LeViT-384: builds the datasets, runs train_levit_model, evaluates
# the returned model on the test set and writes history, metrics, weights and
# predictions under base_path.

# Silence all Python warnings for the rest of the session.
warnings.filterwarnings('ignore')

# Ask TorchDynamo to suppress compilation errors rather than raise them
# (this does not by itself switch torch.compile off).
torch._dynamo.config.suppress_errors = True

if __name__ == "__main__":
    # Setup
    base_path = '/content/drive/MyDrive/Colab Notebooks'
    train_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Training'
    test_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Testing'
    # Class names; a label's integer value is used as an index into this list below.
    tumor_types = ['glioma', 'meningioma', 'notumor', 'pituitary']
    num_classes = len(tumor_types)

    # Create checkpoint directory (not referenced again in this cell)
    checkpoint_dir = os.path.join(base_path, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Data augmentation and normalization for training
    # Just normalization for validation/testing
    # The Normalize constants are the commonly used ImageNet channel mean/std.
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    try:
        # Create the datasets
        train_dataset = MRIDataset(train_folder, tumor_types, transform=data_transforms['train'])
        test_dataset = MRIDataset(test_folder, tumor_types, transform=data_transforms['val'])

        # Print dataset sizes
        print("Dataset sizes:")
        print(f"Training: {len(train_dataset)}")
        print(f"Testing: {len(test_dataset)}")

        # Train the LeViT-384 model
        # NOTE(review): test_dataset is also handed to the training routine, which uses
        # it as the validation set of its final phase; the evaluation below is therefore
        # not on unseen data.
        results, final_model = train_levit_model(train_dataset, test_dataset, num_classes)

        # Display training results
        results_df = pd.DataFrame(results)
        print("\nTraining Results:")
        print(results_df)

        # Save training history
        history_path = os.path.join(base_path, 'levit_384_training_history.csv')  # LeViT-384-specific filename
        results_df.to_csv(history_path, index=False)
        print(f"Training history saved to {history_path}")

        # Evaluate on the test set
        # Reduced batch size for LeViT-384 due to higher memory requirements
        test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)  # Half of the 128 used in the LeViT-256 cell
        test_acc, precision, recall, f1, auc_roc, test_loss, predictions, true_labels = evaluate_model(
            final_model,
            test_loader,
            tumor_types
        )

        # Create and display the metrics DataFrame
        # results[-1] is the 'Final' entry, so only the final run's training time is reported.
        metrics_df = create_metrics_dataframe(
            final_model,
            test_acc,
            precision,
            recall,
            f1,
            auc_roc,
            results[-1]['Training Time (s)'],
            test_loss
        )
        print("\nModel Metrics:")
        print(metrics_df.to_string(index=False))

        # Save the metrics DataFrame
        metrics_csv_path = os.path.join(base_path, 'levit_384_model_metrics.csv')  # LeViT-384-specific filename
        metrics_df.to_csv(metrics_csv_path, index=False)
        print(f"\nMetrics saved to {metrics_csv_path}")

        # Save the final model with all relevant information
        # 'epoch' stores len(results), i.e. the number of result rows, not an epoch index.
        model_path = os.path.join(base_path, 'levit_384_model_final.pth')  # LeViT-384-specific filename
        torch.save({
            'epoch': len(results),
            'model_state_dict': final_model.state_dict(),
            'results': results,
            'test_metrics': {
                'accuracy': test_acc,
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'auc_roc': auc_roc,
                'test_loss': test_loss
            }
        }, model_path)
        print(f"\nModel saved to {model_path}")

        # Save predictions (integer labels mapped back to class names)
        predictions_df = pd.DataFrame({
            'True_Label': [tumor_types[i] for i in true_labels],
            'Predicted_Label': [tumor_types[i] for i in predictions]
        })
        predictions_path = os.path.join(base_path, 'levit_384_test_predictions.csv')  # LeViT-384-specific filename
        predictions_df.to_csv(predictions_path, index=False)
        print(f"Test predictions saved to {predictions_path}")

    except Exception as e:
        # Report the failure with its traceback, then re-raise so the cell still fails.
        print(f"An error occurred: {str(e)}")
        traceback.print_exc()  # Print the full error traceback
        raise e

    print("\nLeViT-384 training, evaluation, and metrics logging complete!")  # Shown only when no exception was raised

# CoAtNets

## Model 8: CoAtNet-0

In [None]:
# Silence all Python warnings for the rest of the session.
warnings.filterwarnings('ignore')
# Ask TorchDynamo to suppress compilation errors rather than raise them.
torch._dynamo.config.suppress_errors = True

class CoAtNet0Model(nn.Module):
    """Pretrained timm ``coatnet_0_rw_224`` backbone with a two-layer classification head.

    Args:
        num_classes: Number of output logits produced by the head.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Pretrained backbone; num_classes=0 requests it without timm's own classifier.
        backbone = timm.create_model('coatnet_0_rw_224', pretrained=True, num_classes=0)
        self.coatnet = backbone

        # Turn on gradient checkpointing when the backbone offers the switch.
        if hasattr(backbone, 'set_grad_checkpointing'):
            backbone.set_grad_checkpointing(enable=True)

        # Hidden width of the head, kept small for memory efficiency.
        hidden_units = 512
        feature_dim = backbone.num_features

        self.fc = nn.Sequential(
            nn.Linear(feature_dim, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        return self.fc(self.coatnet(x))

def initialize_coatnet_model(num_classes):
    """Create a ``CoAtNet0Model`` with ``num_classes`` outputs."""
    model = CoAtNet0Model(num_classes)
    return model

In [None]:
def train_coatnet_model(train_dataset, test_dataset, num_classes, num_epochs=50, patience=5, k_folds=5):
    """Fine-tune CoAtNet-0 with K-fold cross-validation, mixed precision and early stopping.

    Every fold trains a freshly initialised model on the fold's training
    indices and validates on the held-out indices of ``train_dataset``. Only
    the last fold's best weights are written to
    ``final_pretrained_coatnet_0_model.pth`` and returned.

    Args:
        train_dataset: Dataset yielding ``(inputs, labels)`` pairs; it is split
            into training and validation folds.
        test_dataset: Not used by this function; kept for call compatibility.
        num_classes: Number of output classes for the model head.
        num_epochs: Maximum number of epochs per fold.
        patience: Number of consecutive epochs without validation-loss
            improvement after which a fold stops early.
        k_folds: Number of cross-validation folds.

    Returns:
        tuple: ``(results, model)`` where ``results`` holds one dict per fold
        (keys ``'Fold'``, ``'Best Validation Loss'``, ``'Training Time (s)'``)
        and ``model`` is the last fold's model with its best-epoch weights
        restored.

    Raises:
        RuntimeError: If every training or every validation batch of an epoch
            failed, so no average loss can be computed.
    """
    results = []
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print("Training with pre-trained weights")

    # cuDNN autotuning and TF32 matmuls/convolutions
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Initialize gradient scaler for mixed precision
    scaler = GradScaler()

    criterion = nn.CrossEntropyLoss()

    # Slightly larger batch size since we're fine-tuning
    batch_size = 24

    # K-Fold Cross-validation
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)

    for fold, (train_ids, val_ids) in enumerate(kfold.split(train_dataset)):
        print(f'FOLD {fold+1}')
        print('--------------------------------')

        # NOTE(review): validation samples are drawn from train_dataset, so they
        # go through whatever transform that dataset applies.
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=SubsetRandomSampler(train_ids),
            num_workers=2,
            pin_memory=True
        )
        val_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=SubsetRandomSampler(val_ids),
            num_workers=2,
            pin_memory=True
        )

        model = initialize_coatnet_model(num_classes).to(device)

        # Use different learning rates for pre-trained layers and new layers
        optimizer = optim.AdamW([
            {'params': model.coatnet.parameters(), 'lr': 1e-5},  # Lower learning rate for pre-trained layers
            {'params': model.fc.parameters(), 'lr': 1e-4}       # Higher learning rate for new layers
        ])

        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, verbose=True)

        best_val_loss = float('inf')
        best_model = None
        epochs_no_improve = 0
        start_time = time.time()

        for epoch in range(num_epochs):
            # Training phase
            model.train()
            train_loss = 0.0
            train_correct = 0
            train_total = 0

            for inputs, labels in train_loader:
                try:
                    inputs, labels = inputs.to(device), labels.to(device)
                    optimizer.zero_grad()

                    # Use mixed precision training
                    with autocast():
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # Scale gradients and optimize
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                    train_loss += loss.item() * inputs.size(0)
                    predicted = outputs.argmax(dim=1)
                    train_total += labels.size(0)
                    train_correct += (predicted == labels).sum().item()

                    # Release cached GPU memory after every batch
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                except RuntimeError as e:
                    # Best effort: report the failing batch and keep training
                    print(f"RuntimeError in training loop: {e}")
                    print(f"Input shape: {inputs.shape}")
                    continue

            # Validation phase
            model.eval()
            val_loss = 0.0
            val_correct = 0
            val_total = 0

            with torch.no_grad():
                for inputs, labels in val_loader:
                    try:
                        inputs, labels = inputs.to(device), labels.to(device)
                        with autocast():
                            outputs = model(inputs)
                            loss = criterion(outputs, labels)
                        val_loss += loss.item() * inputs.size(0)
                        predicted = outputs.argmax(dim=1)
                        val_total += labels.size(0)
                        val_correct += (predicted == labels).sum().item()
                    except RuntimeError as e:
                        print(f"RuntimeError in validation loop: {e}")
                        print(f"Input shape: {inputs.shape}")
                        continue

            if train_total == 0 or val_total == 0:
                raise RuntimeError(
                    f"No batches completed in fold {fold+1}, epoch {epoch+1} "
                    f"(train samples: {train_total}, val samples: {val_total})"
                )

            # Average over the samples actually processed. len(loader.dataset)
            # would count the whole dataset, although each loader only draws
            # its fold subset through the sampler.
            train_loss = train_loss / train_total
            val_loss = val_loss / val_total
            train_acc = train_correct / train_total
            val_acc = val_correct / val_total

            print(f'Epoch {epoch+1}/{num_epochs}')
            print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
            print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

            scheduler.step(val_loss)

            if best_model is None or val_loss < best_val_loss:
                best_val_loss = val_loss
                # state_dict() returns references to the live parameters, so
                # the snapshot must be cloned or later epochs overwrite it.
                best_model = {
                    name: tensor.detach().cpu().clone()
                    for name, tensor in model.state_dict().items()
                }
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve == patience:
                print('Early stopping!')
                break

            # Clear cache after each epoch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Roll the model back to its best epoch so the returned model matches
        # the checkpoint saved below.
        if best_model is not None:
            model.load_state_dict(best_model)

        end_time = time.time()
        training_time = end_time - start_time
        results.append({
            'Fold': fold+1,
            'Best Validation Loss': best_val_loss,
            'Training Time (s)': training_time
        })

    # Save the final model (best weights of the last fold)
    torch.save({
        'epoch': len(results),
        'model_state_dict': best_model,
        'results': results
    }, 'final_pretrained_coatnet_0_model.pth')

    return results, model

In [None]:
# End-to-end CoAtNet-0 run: build the datasets, train with K-fold CV, evaluate
# on the test set and write history, metrics, model and predictions to Drive.
# MRIDataset, evaluate_model and create_metrics_dataframe are not defined in
# this cell; they must already exist in the notebook session.
if __name__ == "__main__":
    # Setup
    base_path = '/content/drive/MyDrive/Colab Notebooks'
    train_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Training'
    test_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Testing'
    tumor_types = ['glioma', 'meningioma', 'notumor', 'pituitary']
    num_classes = len(tumor_types)

    # Create checkpoint directory
    # NOTE(review): checkpoint_dir is created but not referenced again in this cell.
    checkpoint_dir = os.path.join(base_path, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Data augmentation and normalization
    # 'train' adds a random horizontal flip and a rotation of up to 10 degrees;
    # 'val' only resizes, crops and normalizes.
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    try:
        # Create the datasets
        train_dataset = MRIDataset(train_folder, tumor_types, transform=data_transforms['train'])
        test_dataset = MRIDataset(test_folder, tumor_types, transform=data_transforms['val'])

        print("Dataset sizes:")
        print(f"Training: {len(train_dataset)}")
        print(f"Testing: {len(test_dataset)}")

        # Train the CoAtNet-0 model
        results, final_model = train_coatnet_model(train_dataset, test_dataset, num_classes)

        # Display training results
        results_df = pd.DataFrame(results)
        print("\nTraining Results:")
        print(results_df)

        # Save training history
        history_path = os.path.join(base_path, 'coatnet_0_training_history.csv')
        results_df.to_csv(history_path, index=False)
        print(f"Training history saved to {history_path}")

        # Evaluate on the test set
        test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)
        test_acc, precision, recall, f1, auc_roc, test_loss, predictions, true_labels = evaluate_model(
            final_model,
            test_loader,
            tumor_types
        )

        # Create and display metrics
        # NOTE(review): results[-1] is the last fold only, so the training time
        # passed here excludes the earlier folds.
        metrics_df = create_metrics_dataframe(
            final_model,
            test_acc,
            precision,
            recall,
            f1,
            auc_roc,
            results[-1]['Training Time (s)'],
            test_loss
        )
        print("\nModel Metrics:")
        print(metrics_df.to_string(index=False))

        # Save metrics
        metrics_csv_path = os.path.join(base_path, 'coatnet_0_model_metrics.csv')
        metrics_df.to_csv(metrics_csv_path, index=False)
        print(f"\nMetrics saved to {metrics_csv_path}")

        # Save final model
        model_path = os.path.join(base_path, 'coatnet_0_model_final.pth')
        torch.save({
            'epoch': len(results),
            'model_state_dict': final_model.state_dict(),
            'results': results,
            'test_metrics': {
                'accuracy': test_acc,
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'auc_roc': auc_roc,
                'test_loss': test_loss
            }
        }, model_path)
        print(f"\nModel saved to {model_path}")

        # Save predictions
        # Label indices are mapped back to class names through tumor_types.
        predictions_df = pd.DataFrame({
            'True_Label': [tumor_types[i] for i in true_labels],
            'Predicted_Label': [tumor_types[i] for i in predictions]
        })
        predictions_path = os.path.join(base_path, 'coatnet_0_test_predictions.csv')
        predictions_df.to_csv(predictions_path, index=False)
        print(f"Test predictions saved to {predictions_path}")

    except Exception as e:
        # Log the failure with its traceback, then re-raise so the cell still fails.
        print(f"An error occurred: {str(e)}")
        traceback.print_exc()
        raise e

    print("\nCoAtNet-0 training, evaluation, and metrics logging complete!")

Because the computational cost (i.e. the time needed to train) is so high, I have made the educated guess not to experiment with more complex CoAtNet models: research indicates that accuracy does not improve substantially with the larger variants.

Due to its complex hybrid architecture, CoAtNet is nearly impossible to interpret.

## Model 9: CoAtNet1

In [None]:
# Silence every Python warning for the rest of the session. This keeps the
# training logs short but also hides deprecation notices.
warnings.filterwarnings('ignore')
# Tell TorchDynamo to suppress compilation errors instead of raising them
# (per the PyTorch docs it then falls back to eager execution).
torch._dynamo.config.suppress_errors = True

class CoAtNet1Model(nn.Module):
    """Pre-trained CoAtNet-1 backbone (timm) followed by a two-stage MLP classifier.

    Args:
        num_classes: Number of logits produced by the classification head.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Headless backbone: num_classes=0 is passed so timm builds it without
        # its own classifier and the head below can be attached instead.
        backbone = timm.create_model('coatnet_1_rw_224', pretrained=True, num_classes=0)

        # Gradient checkpointing is switched on only when the model offers the hook.
        if hasattr(backbone, 'set_grad_checkpointing'):
            backbone.set_grad_checkpointing(enable=True)

        self.coatnet = backbone

        # Classification head: backbone features -> 1024 -> 512 -> num_classes,
        # with heavier dropout after the first stage than after the second.
        first_stage, second_stage = 1024, 512
        self.fc = nn.Sequential(
            nn.Linear(backbone.num_features, first_stage),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(first_stage, second_stage),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(second_stage, num_classes),
        )

    def forward(self, x):
        """Return the class logits for the input batch ``x``."""
        return self.fc(self.coatnet(x))

def initialize_coatnet_model(num_classes):
    """Build a new CoAtNet1Model whose head emits ``num_classes`` logits."""
    coatnet_model = CoAtNet1Model(num_classes)
    return coatnet_model


In [None]:
def train_coatnet_model(train_dataset, test_dataset, num_classes, num_epochs=100, patience=10, k_folds=5):
    """Fine-tune CoAtNet-1 with K-fold cross-validation, mixed precision and early stopping.

    Every fold trains a freshly initialised model on the fold's training
    indices and validates on the held-out indices of ``train_dataset``. Only
    the last fold's best weights are written to
    ``final_pretrained_coatnet_1_model.pth`` and returned.

    Args:
        train_dataset: Dataset yielding ``(inputs, labels)`` pairs; it is split
            into training and validation folds.
        test_dataset: Not used by this function; kept for call compatibility.
        num_classes: Number of output classes for the model head.
        num_epochs: Maximum number of epochs per fold.
        patience: Number of consecutive epochs without validation-loss
            improvement after which a fold stops early.
        k_folds: Number of cross-validation folds.

    Returns:
        tuple: ``(results, model)`` where ``results`` holds one dict per fold
        (keys ``'Fold'``, ``'Best Validation Loss'``, ``'Training Time (s)'``)
        and ``model`` is the last fold's model with its best-epoch weights
        restored.

    Raises:
        RuntimeError: If every training or every validation batch of an epoch
            failed, so no average loss can be computed.
    """
    results = []
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print("Training with pre-trained CoAtNet-1 weights")

    # cuDNN autotuning and TF32 matmuls/convolutions
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Initialize gradient scaler for mixed precision
    scaler = GradScaler()

    criterion = nn.CrossEntropyLoss()

    # Reduced batch size for CoAtNet-1 since it's larger
    batch_size = 16

    # K-Fold Cross-validation
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)

    for fold, (train_ids, val_ids) in enumerate(kfold.split(train_dataset)):
        print(f'FOLD {fold+1}')
        print('--------------------------------')

        # NOTE(review): validation samples are drawn from train_dataset, so they
        # go through whatever transform that dataset applies.
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=SubsetRandomSampler(train_ids),
            num_workers=2,
            pin_memory=True
        )
        val_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=SubsetRandomSampler(val_ids),
            num_workers=2,
            pin_memory=True
        )

        model = initialize_coatnet_model(num_classes).to(device)

        # Adjusted learning rates for CoAtNet-1
        optimizer = optim.AdamW([
            {'params': model.coatnet.parameters(), 'lr': 5e-6},  # Lower learning rate for pre-trained layers
            {'params': model.fc.parameters(), 'lr': 5e-5}       # Lower learning rate for new layers
        ])

        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, verbose=True)

        best_val_loss = float('inf')
        best_model = None
        epochs_no_improve = 0
        start_time = time.time()

        for epoch in range(num_epochs):
            # Training phase
            model.train()
            train_loss = 0.0
            train_correct = 0
            train_total = 0

            for inputs, labels in train_loader:
                try:
                    inputs, labels = inputs.to(device), labels.to(device)
                    optimizer.zero_grad()

                    # Use mixed precision training
                    with autocast():
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # Scale gradients and optimize
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                    train_loss += loss.item() * inputs.size(0)
                    predicted = outputs.argmax(dim=1)
                    train_total += labels.size(0)
                    train_correct += (predicted == labels).sum().item()

                    # Release cached GPU memory after every batch
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                except RuntimeError as e:
                    # Best effort: report the failing batch and keep training
                    print(f"RuntimeError in training loop: {e}")
                    print(f"Input shape: {inputs.shape}")
                    continue

            # Validation phase
            model.eval()
            val_loss = 0.0
            val_correct = 0
            val_total = 0

            with torch.no_grad():
                for inputs, labels in val_loader:
                    try:
                        inputs, labels = inputs.to(device), labels.to(device)
                        with autocast():
                            outputs = model(inputs)
                            loss = criterion(outputs, labels)
                        val_loss += loss.item() * inputs.size(0)
                        predicted = outputs.argmax(dim=1)
                        val_total += labels.size(0)
                        val_correct += (predicted == labels).sum().item()
                    except RuntimeError as e:
                        print(f"RuntimeError in validation loop: {e}")
                        print(f"Input shape: {inputs.shape}")
                        continue

            if train_total == 0 or val_total == 0:
                raise RuntimeError(
                    f"No batches completed in fold {fold+1}, epoch {epoch+1} "
                    f"(train samples: {train_total}, val samples: {val_total})"
                )

            # Average over the samples actually processed. len(loader.dataset)
            # would count the whole dataset, although each loader only draws
            # its fold subset through the sampler.
            train_loss = train_loss / train_total
            val_loss = val_loss / val_total
            train_acc = train_correct / train_total
            val_acc = val_correct / val_total

            print(f'Epoch {epoch+1}/{num_epochs}')
            print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
            print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

            scheduler.step(val_loss)

            if best_model is None or val_loss < best_val_loss:
                best_val_loss = val_loss
                # state_dict() returns references to the live parameters, so
                # the snapshot must be cloned or later epochs overwrite it.
                best_model = {
                    name: tensor.detach().cpu().clone()
                    for name, tensor in model.state_dict().items()
                }
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve == patience:
                print('Early stopping!')
                break

            # Clear cache after each epoch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Roll the model back to its best epoch so the returned model matches
        # the checkpoint saved below.
        if best_model is not None:
            model.load_state_dict(best_model)

        end_time = time.time()
        training_time = end_time - start_time
        results.append({
            'Fold': fold+1,
            'Best Validation Loss': best_val_loss,
            'Training Time (s)': training_time
        })

    # Save the final model (best weights of the last fold)
    torch.save({
        'epoch': len(results),
        'model_state_dict': best_model,
        'results': results
    }, 'final_pretrained_coatnet_1_model.pth')

    return results, model

In [None]:
# End-to-end CoAtNet-1 run: build the datasets, train with K-fold CV, evaluate
# on the test set and write history, metrics, confusion matrix, model and
# predictions to Drive. MRIDataset, evaluate_model and create_metrics_dataframe
# are not defined in this cell; they must already exist in the notebook session.
if __name__ == "__main__":
    # Setup
    base_path = '/content/drive/MyDrive/Colab Notebooks'
    train_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Training'
    test_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Testing'
    tumor_types = ['glioma', 'meningioma', 'notumor', 'pituitary']
    num_classes = len(tumor_types)

    print("Initializing CoAtNet-1 training pipeline...")
    print(f"Number of classes: {num_classes}")
    print(f"Tumor types: {tumor_types}")

    # Create checkpoint directory
    checkpoint_dir = os.path.join(base_path, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Data augmentation and normalization for CoAtNet-1
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.RandomAdjustSharpness(0.2),  # Added for CoAtNet-1
            transforms.RandomAutocontrast(),        # Added for CoAtNet-1
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    try:
        # Create the datasets
        train_dataset = MRIDataset(train_folder, tumor_types, transform=data_transforms['train'])
        test_dataset = MRIDataset(test_folder, tumor_types, transform=data_transforms['val'])

        print("\nDataset sizes:")
        print(f"Training: {len(train_dataset)}")
        print(f"Testing: {len(test_dataset)}")

        # Train the CoAtNet-1 model
        print("\nInitiating CoAtNet-1 training...")
        results, final_model = train_coatnet_model(
            train_dataset,
            test_dataset,
            num_classes,
            num_epochs=50,    # Adjust these parameters as needed
            patience=5,
            k_folds=5
        )

        # Display training results
        results_df = pd.DataFrame(results)
        print("\nTraining Results:")
        print(results_df)

        # Save training history
        history_path = os.path.join(base_path, 'coatnet_1_training_history.csv')
        results_df.to_csv(history_path, index=False)
        print(f"\nTraining history saved to {history_path}")

        # Evaluate on the test set
        print("\nEvaluating model on test set...")
        test_loader = DataLoader(
            test_dataset,
            batch_size=16,  # Reduced batch size for CoAtNet-1
            shuffle=False,
            num_workers=2,   # Reduced workers for memory efficiency
            pin_memory=True
        )

        test_acc, precision, recall, f1, auc_roc, test_loss, predictions, true_labels = evaluate_model(
            final_model,
            test_loader,
            tumor_types
        )

        # Create and display metrics
        # The time passed here is the last fold's, i.e. the fold that produced
        # final_model.
        metrics_df = create_metrics_dataframe(
            final_model,
            test_acc,
            precision,
            recall,
            f1,
            auc_roc,
            results[-1]['Training Time (s)'],
            test_loss
        )
        print("\nModel Metrics:")
        print(metrics_df.to_string(index=False))

        # Save metrics
        metrics_csv_path = os.path.join(base_path, 'coatnet_1_model_metrics.csv')
        metrics_df.to_csv(metrics_csv_path, index=False)
        print(f"\nMetrics saved to {metrics_csv_path}")

        # Save confusion matrix
        cm = confusion_matrix(true_labels, predictions)
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                   xticklabels=tumor_types,
                   yticklabels=tumor_types)
        plt.title('CoAtNet-1 Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        plt.savefig(os.path.join(base_path, 'coatnet_1_confusion_matrix.png'))
        plt.close()

        # Save final model
        model_path = os.path.join(base_path, 'coatnet_1_model_final.pth')
        torch.save({
            'epoch': len(results),
            'model_state_dict': final_model.state_dict(),
            'results': results,
            'test_metrics': {
                'accuracy': test_acc,
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'auc_roc': auc_roc,
                'test_loss': test_loss
            }
        }, model_path)
        print(f"\nModel saved to {model_path}")

        # Save predictions
        # Label indices are mapped back to class names through tumor_types.
        predictions_df = pd.DataFrame({
            'True_Label': [tumor_types[i] for i in true_labels],
            'Predicted_Label': [tumor_types[i] for i in predictions]
        })
        predictions_path = os.path.join(base_path, 'coatnet_1_test_predictions.csv')
        predictions_df.to_csv(predictions_path, index=False)
        print(f"Test predictions saved to {predictions_path}")

        # The total must add up every fold; results[-1] alone would report
        # only the last fold's time.
        total_training_time = sum(fold_result['Training Time (s)'] for fold_result in results)

        # Print final summary
        print("\nFinal Summary:")
        print(f"Test Accuracy: {test_acc:.4f}")
        print(f"Average F1 Score: {np.mean(f1):.4f}")
        print(f"Average AUC-ROC: {np.mean(auc_roc):.4f}")
        print(f"Total Training Time: {total_training_time:.2f} seconds")

    except Exception as e:
        # Log the failure with its traceback, then re-raise so the cell still fails.
        print(f"\nAn error occurred: {str(e)}")
        traceback.print_exc()
        raise

    print("\nCoAtNet-1 training, evaluation, and metrics logging complete!")

# XCiTs: Cross-Covariance Image Transformers

## Model 10: XCiT_small_12

In [None]:
# Silence every Python warning for the rest of the session. This keeps the
# training logs short but also hides deprecation notices.
warnings.filterwarnings('ignore')
# Tell TorchDynamo to suppress compilation errors instead of raising them
# (per the PyTorch docs it then falls back to eager execution).
torch._dynamo.config.suppress_errors = True

class XCiTModel(nn.Module):
    """XCiT small-12 backbone (timm ``xcit_small_12_p8_224``) followed by an MLP classifier.

    Args:
        num_classes: Number of logits produced by the classification head.
        pretrained: Forwarded to ``timm.create_model`` to choose between
            pre-trained and freshly initialised backbone weights.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()

        # Headless backbone: num_classes=0 is passed so timm builds it without
        # its own classifier and the head below can be attached instead.
        backbone = timm.create_model('xcit_small_12_p8_224', pretrained=pretrained, num_classes=0)

        # Gradient checkpointing is enabled for memory efficiency.
        backbone.set_grad_checkpointing(enable=True)

        self.xcit = backbone

        # Classification head: backbone features -> 512 -> num_classes
        hidden_units = 512
        self.fc = nn.Sequential(
            nn.Linear(backbone.num_features, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x):
        """Return the class logits for the input batch ``x``."""
        return self.fc(self.xcit(x))

def initialize_xcit_model(num_classes, pretrained=True):
    """Build a new XCiTModel with ``num_classes`` outputs, forwarding ``pretrained``."""
    xcit_model = XCiTModel(num_classes, pretrained=pretrained)
    return xcit_model

In [None]:
def train_xcit_model(train_dataset, test_dataset, num_classes, num_epochs=100, patience=10, k_folds=5, pretrained=True):
    """Train XCiT small-12 with K-fold cross-validation, mixed precision and early stopping.

    Every fold trains a freshly initialised model on the fold's training
    indices and validates on the held-out indices of ``train_dataset``. Only
    the last fold's best weights are written to
    ``final_xcit_small_12_model.pth`` and returned.

    Args:
        train_dataset: Dataset yielding ``(inputs, labels)`` pairs; it is split
            into training and validation folds.
        test_dataset: Not used by this function; kept for call compatibility.
        num_classes: Number of output classes for the model head.
        num_epochs: Maximum number of epochs per fold.
        patience: Number of consecutive epochs without validation-loss
            improvement after which a fold stops early.
        k_folds: Number of cross-validation folds.
        pretrained: Whether the backbone starts from pre-trained weights; also
            selects the initial learning rate (1e-4 if True, 1e-3 otherwise).

    Returns:
        tuple: ``(results, model)`` where ``results`` holds one dict per fold
        (keys ``'Fold'``, ``'Best Validation Loss'``, ``'Training Time (s)'``)
        and ``model`` is the last fold's model with its best-epoch weights
        restored.

    Raises:
        RuntimeError: If every training or every validation batch of an epoch
            failed, so no average loss can be computed.
    """
    results = []
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print(f"Training with {'pretrained' if pretrained else 'randomly initialized'} weights")

    # cuDNN autotuning and TF32 matmuls/convolutions
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Initialize gradient scaler for mixed precision
    scaler = GradScaler()

    criterion = nn.CrossEntropyLoss()

    # Use smaller batch size to prevent OOM errors
    batch_size = 16  # Reduced from 32 but can be larger than XCiT-24 since model is smaller

    # K-Fold Cross-validation
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)

    for fold, (train_ids, val_ids) in enumerate(kfold.split(train_dataset)):
        print(f'FOLD {fold+1}')
        print('--------------------------------')

        # NOTE(review): validation samples are drawn from train_dataset, so they
        # go through whatever transform that dataset applies.
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=SubsetRandomSampler(train_ids),
            num_workers=2,
            pin_memory=True
        )
        val_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=SubsetRandomSampler(val_ids),
            num_workers=2,
            pin_memory=True
        )

        model = initialize_xcit_model(num_classes, pretrained=pretrained).to(device)

        # Use different learning rates for pretrained vs non-pretrained
        initial_lr = 0.0001 if pretrained else 0.001
        optimizer = optim.Adam(model.parameters(), lr=initial_lr)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)

        best_val_loss = float('inf')
        best_model = None
        epochs_no_improve = 0
        start_time = time.time()

        for epoch in range(num_epochs):
            # Training phase
            model.train()
            train_loss = 0.0
            train_correct = 0
            train_total = 0

            for inputs, labels in train_loader:
                try:
                    inputs, labels = inputs.to(device), labels.to(device)
                    optimizer.zero_grad()

                    # Use mixed precision training
                    with autocast():
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # Scale gradients and optimize
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                    train_loss += loss.item() * inputs.size(0)
                    predicted = outputs.argmax(dim=1)
                    train_total += labels.size(0)
                    train_correct += (predicted == labels).sum().item()

                    # Release cached GPU memory after every batch
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                except RuntimeError as e:
                    # Best effort: report the failing batch and keep training
                    print(f"RuntimeError in training loop: {e}")
                    print(f"Input shape: {inputs.shape}")
                    continue

            # Validation phase
            model.eval()
            val_loss = 0.0
            val_correct = 0
            val_total = 0

            with torch.no_grad():
                for inputs, labels in val_loader:
                    try:
                        inputs, labels = inputs.to(device), labels.to(device)
                        with autocast():
                            outputs = model(inputs)
                            loss = criterion(outputs, labels)
                        val_loss += loss.item() * inputs.size(0)
                        predicted = outputs.argmax(dim=1)
                        val_total += labels.size(0)
                        val_correct += (predicted == labels).sum().item()
                    except RuntimeError as e:
                        print(f"RuntimeError in validation loop: {e}")
                        print(f"Input shape: {inputs.shape}")
                        continue

            if train_total == 0 or val_total == 0:
                raise RuntimeError(
                    f"No batches completed in fold {fold+1}, epoch {epoch+1} "
                    f"(train samples: {train_total}, val samples: {val_total})"
                )

            # Average over the samples actually processed. len(loader.dataset)
            # would count the whole dataset, although each loader only draws
            # its fold subset through the sampler.
            train_loss = train_loss / train_total
            val_loss = val_loss / val_total
            train_acc = train_correct / train_total
            val_acc = val_correct / val_total

            print(f'Epoch {epoch+1}/{num_epochs}')
            print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
            print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

            scheduler.step(val_loss)

            if best_model is None or val_loss < best_val_loss:
                best_val_loss = val_loss
                # state_dict() returns references to the live parameters, so
                # the snapshot must be cloned or later epochs overwrite it.
                best_model = {
                    name: tensor.detach().cpu().clone()
                    for name, tensor in model.state_dict().items()
                }
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve == patience:
                print('Early stopping!')
                break

            # Clear cache after each epoch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Roll the model back to its best epoch so the returned model matches
        # the checkpoint saved below.
        if best_model is not None:
            model.load_state_dict(best_model)

        end_time = time.time()
        training_time = end_time - start_time
        results.append({
            'Fold': fold+1,
            'Best Validation Loss': best_val_loss,
            'Training Time (s)': training_time
        })

    # Save the final model (best weights of the last fold)
    torch.save({
        'epoch': len(results),
        'model_state_dict': best_model,
        'results': results,
        'pretrained': pretrained
    }, 'final_xcit_small_12_model.pth')

    return results, model

In [None]:
# End-to-end XCiT small-12 run: build the datasets, train with K-fold CV,
# evaluate on the test set and write history, metrics, model and predictions
# to Drive. MRIDataset, evaluate_model and create_metrics_dataframe are not
# defined in this cell; they must already exist in the notebook session.
if __name__ == "__main__":
    # Setup
    base_path = '/content/drive/MyDrive/Colab Notebooks'
    train_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Training'
    test_folder = '/content/drive/MyDrive/Colab Notebooks/Resized_Testing'
    tumor_types = ['glioma', 'meningioma', 'notumor', 'pituitary']
    num_classes = len(tumor_types)

    # Create checkpoint directory
    # NOTE(review): checkpoint_dir is created but not referenced again in this cell.
    checkpoint_dir = os.path.join(base_path, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Data augmentation and normalization
    # 'train' adds a random horizontal flip and a rotation of up to 10 degrees;
    # 'val' only resizes, crops and normalizes.
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    try:
        # Create the datasets
        train_dataset = MRIDataset(train_folder, tumor_types, transform=data_transforms['train'])
        test_dataset = MRIDataset(test_folder, tumor_types, transform=data_transforms['val'])

        print("Dataset sizes:")
        print(f"Training: {len(train_dataset)}")
        print(f"Testing: {len(test_dataset)}")

        # Train the XCiT model
        results, final_model = train_xcit_model(train_dataset, test_dataset, num_classes)

        # Display training results
        results_df = pd.DataFrame(results)
        print("\nTraining Results:")
        print(results_df)

        # Save training history
        history_path = os.path.join(base_path, 'xcit_small_12_training_history.csv')
        results_df.to_csv(history_path, index=False)
        print(f"Training history saved to {history_path}")

        # Evaluate on the test set
        test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)
        test_acc, precision, recall, f1, auc_roc, test_loss, predictions, true_labels = evaluate_model(
            final_model,
            test_loader,
            tumor_types
        )

        # Create and display metrics
        # NOTE(review): results[-1] is the last fold only, so the training time
        # passed here excludes the earlier folds.
        metrics_df = create_metrics_dataframe(
            final_model,
            test_acc,
            precision,
            recall,
            f1,
            auc_roc,
            results[-1]['Training Time (s)'],
            test_loss
        )
        print("\nModel Metrics:")
        print(metrics_df.to_string(index=False))

        # Save metrics
        metrics_csv_path = os.path.join(base_path, 'xcit_small_12_model_metrics.csv')
        metrics_df.to_csv(metrics_csv_path, index=False)
        print(f"\nMetrics saved to {metrics_csv_path}")

        # Save final model
        model_path = os.path.join(base_path, 'xcit_small_12_model_final.pth')
        torch.save({
            'epoch': len(results),
            'model_state_dict': final_model.state_dict(),
            'results': results,
            'test_metrics': {
                'accuracy': test_acc,
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'auc_roc': auc_roc,
                'test_loss': test_loss
            }
        }, model_path)
        print(f"\nModel saved to {model_path}")

        # Save predictions
        # Label indices are mapped back to class names through tumor_types.
        predictions_df = pd.DataFrame({
            'True_Label': [tumor_types[i] for i in true_labels],
            'Predicted_Label': [tumor_types[i] for i in predictions]
        })
        predictions_path = os.path.join(base_path, 'xcit_small_12_test_predictions.csv')
        predictions_df.to_csv(predictions_path, index=False)
        print(f"Test predictions saved to {predictions_path}")

    except Exception as e:
        # Log the failure with its traceback, then re-raise so the cell still fails.
        print(f"An error occurred: {str(e)}")
        traceback.print_exc()
        raise e

    print("\nXCiT Small 12 training, evaluation, and metrics logging complete!")

# Visualising The Results

### Loading the model metrics

In [None]:
def load_model_metrics(file_paths=None):
    """
    Load the per-model metrics CSV files into DataFrames.

    Args:
        file_paths (dict, optional): Mapping of model name -> CSV path. When
            omitted, the default Google Drive locations listed below are used.

    Returns:
        dict: DataFrames keyed by model name. A model whose file is missing or
        cannot be parsed is reported on stdout and left out of the result.
    """
    if file_paths is None:
        # Default locations of the metrics files
        file_paths = {
            'ResNet50': '/content/drive/MyDrive/Colab Notebooks/resnet50_metrics.csv',
            'ResNet101': '/content/drive/MyDrive/Colab Notebooks/resnet101_metrics.csv',
            'EfficientNetB0': '/content/drive/MyDrive/Colab Notebooks/efficientnet_b0_model_metrics.csv',
            'EfficientNetB1': '/content/drive/MyDrive/Colab Notebooks/efficientnet_b1_model_metrics.csv',
            'Vit': '/content/drive/MyDrive/Colab Notebooks/vit_small_model_metrics.csv',
            'Levit256': '/content/drive/MyDrive/Colab Notebooks/levit_256_model_metrics.csv',
            'Levit384': '/content/drive/MyDrive/Colab Notebooks/levit_384_model_metrics.csv',
            'CoAtNet0': '/content/drive/MyDrive/Colab Notebooks/coatnet_0_model_metrics.csv',
            'CoAtNet1': '/content/drive/MyDrive/Colab Notebooks/coatnet_1_model_metrics.csv',
            'XCiT': '/content/drive/MyDrive/Colab Notebooks/xcit_small_12_model_metrics.csv',
        }

    # Dictionary to store dataframes
    dataframes = {}

    # Load each CSV file into a dataframe
    for model_name, file_path in file_paths.items():
        try:
            df = pd.read_csv(file_path)
        except FileNotFoundError:
            print(f"Warning: File not found for {model_name} at {file_path}")
        except (OSError, ValueError) as e:
            # pandas' parser/empty-data/decoding errors all derive from ValueError
            print(f"Error loading {model_name} metrics: {str(e)}")
        else:
            dataframes[model_name] = df
            print(f"Successfully loaded {model_name} metrics")

    return dataframes

# Load all model metrics into a dict keyed by model name.
# Models whose CSV could not be loaded are simply absent from the dict.
model_metrics = load_model_metrics()

# Individual dataframes can be accessed by model name, for example:
# resnet50_df = model_metrics['ResNet50']
# efficientnet_b0_df = model_metrics['EfficientNetB0']
# etc.

In [None]:
# Preview the ResNet50 metrics table (KeyError if 'ResNet50' is not in model_metrics)
resnet50_df = model_metrics['ResNet50']
resnet50_df

In [None]:
def merge_model_metrics(model_metrics):
    """
    Merge all model metrics dataframes into a single dataframe where each row
    represents a model and columns are the different metrics.

    Each input dataframe is read row by row: the 'Metric' cell becomes a column
    name (lower-cased, spaces replaced by underscores, any "(unit)" suffix
    dropped) and the 'Value' cell becomes that column's value for the model.

    Args:
        model_metrics (dict): Dictionary of dataframes keyed by model name

    Returns:
        pandas.DataFrame: Combined dataframe with one row per model. Only the
        'model' column plus the known metric columns that are actually present
        are kept. An empty input yields an empty dataframe with a 'model'
        column instead of raising.
    """
    # Metric columns kept in the output, in display order
    metric_columns = ['overall_accuracy', 'f1_score',
                      'cross_entropy_loss', 'training_time',
                      'number_of_parameters', 'model_size']

    all_model_data = []
    for model_name, df in model_metrics.items():
        model_data = {'model': model_name}

        for _, row in df.iterrows():
            metric_name = row['Metric'].lower().replace(' ', '_')
            if '(' in metric_name:  # Drop a trailing unit such as "(s)" or "(MB)"
                metric_name = metric_name.split('(')[0].strip('_')
            model_data[metric_name] = row['Value']

        all_model_data.append(model_data)

    merged_df = pd.DataFrame(all_model_data)
    if 'model' not in merged_df.columns:
        # No models were supplied: keep the schema so downstream code can
        # still select the 'model' column.
        merged_df = pd.DataFrame(columns=['model'])

    # Keep a consistent column order, skipping metrics no model reported
    final_columns = ['model'] + [col for col in metric_columns
                                 if col in merged_df.columns]
    merged_df = merged_df[final_columns]

    print(f"Successfully merged {len(model_metrics)} models")
    print(f"Columns in final dataframe: {merged_df.columns.tolist()}")

    return merged_df

In [None]:
# Combine the per-model tables in model_metrics into one row-per-model dataframe
merged_metrics = merge_model_metrics(model_metrics)

# Display the merged table as the cell output
merged_metrics

### Best accuracy (bar chart)

In [None]:
# Order the models from lowest to highest accuracy so the bars rise left to right
sorted_metrics = merged_metrics.sort_values('overall_accuracy')

# One wide figure with a single set of axes
fig, ax = plt.subplots(figsize=(12, 6))

# One bar per model
bars = ax.bar(sorted_metrics['model'], sorted_metrics['overall_accuracy'], color='cornflowerblue')

# Title and axis labels
ax.set_title('Model Accuracy Comparison', pad=20, size=14)
ax.set_xlabel('Model', labelpad=10)
ax.set_ylabel('Accuracy', labelpad=10)

# Slanted model names so they do not overlap
plt.xticks(rotation=45, ha='right')

# Zoom the y-axis to 85-100%, leaving headroom for the value labels
ax.set_ylim(0.85, 1.01)

# Horizontal guide lines only
ax.grid(True, linestyle='--', alpha=0.7, axis='y')

# Show y ticks as percentages
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: '{:.1%}'.format(y)))

# Print each bar's value just above it
for bar in bars:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height + 0.002,
            f'{height:.1%}',
            ha='center', va='bottom')

# Keep labels inside the figure, then render
fig.tight_layout()
plt.show()

### Best F1 scores (Bar Chart)

In [None]:
# Order the models from lowest to highest F1 score so the bars rise left to right
sorted_metrics = merged_metrics.sort_values('f1_score')

# One wide figure with a single set of axes
fig, ax = plt.subplots(figsize=(12, 6))

# One bar per model
bars = ax.bar(sorted_metrics['model'], sorted_metrics['f1_score'], color='cornflowerblue')

# Title and axis labels
ax.set_title('Model F1 Score Comparison', pad=20, size=14)
ax.set_xlabel('Model', labelpad=10)
ax.set_ylabel('F1 Score', labelpad=10)

# Slanted model names so they do not overlap
plt.xticks(rotation=45, ha='right')

# Zoom the y-axis to 85-100%, leaving headroom for the value labels
ax.set_ylim(0.85, 1.01)

# Horizontal guide lines only
ax.grid(True, linestyle='--', alpha=0.7, axis='y')

# Show y ticks as percentages
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: '{:.1%}'.format(y)))

# Print each bar's value just above it
for bar in bars:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height + 0.002,
            f'{height:.1%}',
            ha='center', va='bottom')

# Keep labels inside the figure, then render
fig.tight_layout()
plt.show()

### Accuracy (x-axis) vs F1 score (y-axis)

In [None]:
# Square figure, since both axes show the same kind of quantity
fig, ax = plt.subplots(figsize=(10, 10))

# The two series being compared
acc_values = merged_metrics['overall_accuracy']
f1_values = merged_metrics['f1_score']

# One point per model (larger markers for visibility)
ax.scatter(acc_values, f1_values, alpha=0.6, s=100)

# Name each point, nudged slightly away from the marker
for i, model in enumerate(merged_metrics['model']):
    ax.annotate(model,
                (acc_values.iloc[i], f1_values.iloc[i]),
                xytext=(5, 5),
                textcoords='offset points')

# Reference diagonal y=x spanning the combined range of both metrics
min_val = min(acc_values.min(), f1_values.min())
max_val = max(acc_values.max(), f1_values.max())
ax.plot([min_val, max_val], [min_val, max_val],
        'k--', alpha=0.3, label='y=x')

# Title and axis labels
ax.set_title('Model Performance: Accuracy vs F1 Score', pad=20, size=14)
ax.set_xlabel('Accuracy', labelpad=10)
ax.set_ylabel('F1 Score', labelpad=10)

# Both axes as percentages
ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: '{:.1%}'.format(x)))
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: '{:.1%}'.format(y)))

# Light guide lines
ax.grid(True, linestyle='--', alpha=0.3)

# Same padded limits on both axes
padding = 0.01
ax.set_xlim(min_val - padding, max_val + padding)
ax.set_ylim(min_val - padding, max_val + padding)

# Legend for the diagonal
ax.legend()

# Equal data scaling on both axes
ax.axis('equal')

# Keep labels inside the figure, then render
fig.tight_layout()
plt.show()

### Accuracy against number of parameters (scatter plot)

In [None]:
import matplotlib.pyplot as plt

# Single set of axes on a wide figure
fig, ax = plt.subplots(figsize=(10, 6))

# The two series being compared
param_counts = merged_metrics['number_of_parameters']
acc_values = merged_metrics['overall_accuracy']

# One point per model; s controls the marker size
ax.scatter(param_counts, acc_values, alpha=0.6, s=100)

# Name each point, nudged slightly away from the marker
for i, model in enumerate(merged_metrics['model']):
    ax.annotate(model,
                (param_counts.iloc[i], acc_values.iloc[i]),
                xytext=(5, 5),
                textcoords='offset points')

# Title and axis labels
ax.set_title('Model Accuracy vs Number of Parameters', pad=20, size=14)
ax.set_xlabel('Number of Parameters (millions)', labelpad=10)
ax.set_ylabel('Accuracy', labelpad=10)

# x ticks are raw parameter counts, displayed in millions
ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: '{:.1f}M'.format(x/1e6)))

# y ticks as percentages
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: '{:.1%}'.format(y)))

# Light guide lines
ax.grid(True, linestyle='--', alpha=0.3)

# Zoom the y-axis to the relevant accuracy range
ax.set_ylim(0.85, 1.02)

# Keep labels inside the figure, then render
fig.tight_layout()
plt.show()

### Time to train (bar chart)

In [None]:
import matplotlib.pyplot as plt

# Order the models from shortest to longest training time
sorted_metrics = merged_metrics.sort_values('training_time')

# One wide figure with a single set of axes
fig, ax = plt.subplots(figsize=(12, 6))

# One bar per model; bar heights are in seconds
bars = ax.bar(sorted_metrics['model'], sorted_metrics['training_time'], color='cornflowerblue')

# Title and axis labels
ax.set_title('Model Training Time Comparison', pad=20, size=14)
ax.set_xlabel('Model', labelpad=10)
ax.set_ylabel('Training Time (hours)', labelpad=10)

# Slanted model names so they do not overlap
plt.xticks(rotation=45, ha='right')

# Horizontal guide lines only
ax.grid(True, linestyle='--', alpha=0.7, axis='y')

# Convert the second-valued ticks to hours for display
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: '{:.1f}h'.format(y/3600)))

# Print each bar's duration (in hours) just above it; the +50 offset is in seconds
for bar in bars:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height + 50,
            f'{height/3600:.1f}h',
            ha='center', va='bottom')

# Keep labels inside the figure, then render
fig.tight_layout()
plt.show()

### Accuracy vs Training time

In [None]:
# Wide figure so the spread of training times is easy to read
fig, ax = plt.subplots(figsize=(12, 8))

# The two series being compared
time_values = merged_metrics['training_time']
acc_values = merged_metrics['overall_accuracy']

# One point per model (larger markers for visibility)
ax.scatter(time_values, acc_values, alpha=0.6, s=100)

# Name each point, nudged slightly away from the marker
for i, model in enumerate(merged_metrics['model']):
    ax.annotate(model,
                (time_values.iloc[i], acc_values.iloc[i]),
                xytext=(5, 5),
                textcoords='offset points')

# Title and axis labels
ax.set_title('Model Performance: Accuracy vs Training Time', pad=20, size=14)
ax.set_xlabel('Training Time (seconds)', labelpad=10)
ax.set_ylabel('Accuracy', labelpad=10)

# y ticks as percentages
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: '{:.1%}'.format(y)))

# Light guide lines
ax.grid(True, linestyle='--', alpha=0.3)

# Pad the accuracy range by two percentage points on each side
y_min = acc_values.min() - 0.02
y_max = acc_values.max() + 0.02
ax.set_ylim(y_min, y_max)

# Start time at zero and leave 5% headroom on the right
x_max = time_values.max() * 1.05
ax.set_xlim(0, x_max)

# Keep labels inside the figure, then render
fig.tight_layout()
plt.show()

ViT and ResNet101 stand out as the least efficient models to train.

### Training time against number of parameters




In [None]:
from sklearn.metrics import r2_score

# Parameter counts expressed in millions, and training times in seconds
params_millions = merged_metrics['number_of_parameters'] / 1e6
time_values = merged_metrics['training_time']

# Wide figure with a single set of axes
fig, ax = plt.subplots(figsize=(12, 8))

# One point per model (larger markers for visibility)
ax.scatter(params_millions, time_values, alpha=0.6, s=100)

# Name each point, nudged slightly away from the marker
for i, model in enumerate(merged_metrics['model']):
    ax.annotate(model,
                (params_millions.iloc[i], time_values.iloc[i]),
                xytext=(5, 5),
                textcoords='offset points')

# Title and axis labels
ax.set_title('Model Training Time vs Number of Parameters', pad=20, size=14)
ax.set_xlabel('Number of Parameters (millions)', labelpad=10)
ax.set_ylabel('Training Time (seconds)', labelpad=10)

# Light guide lines
ax.grid(True, linestyle='--', alpha=0.3)

# Neither quantity can be negative: start at 0 and leave 5% headroom
ax.set_xlim(0, params_millions.max() * 1.05)
ax.set_ylim(0, time_values.max() * 1.05)

# Least-squares straight line through the points
z = np.polyfit(params_millions, time_values, 1)
p = np.poly1d(z)
x_trend = np.linspace(0, params_millions.max(), 100)

# R² of the fitted line, shown in the legend
r_squared = r2_score(time_values, p(params_millions))
ax.plot(x_trend, p(x_trend), "r--", alpha=0.8,
        label=f'Trend line (R² = {r_squared:.3f})')

# Legend for the trend line
ax.legend()

# Keep labels inside the figure, then render
fig.tight_layout()
plt.show()

Training time depends more on the model architecture than on the number of parameters.

### Accuracy vs number of parameters

In [None]:
# Parameter counts expressed in millions, and accuracies as fractions
params_millions = merged_metrics['number_of_parameters'] / 1e6
acc_values = merged_metrics['overall_accuracy']

# Wide figure with a single set of axes
fig, ax = plt.subplots(figsize=(12, 8))

# One point per model
ax.scatter(params_millions, acc_values, alpha=0.6, s=100)

# Name each point, nudged slightly away from the marker
for i, model in enumerate(merged_metrics['model']):
    ax.annotate(model,
                (params_millions.iloc[i], acc_values.iloc[i]),
                xytext=(5, 5),
                textcoords='offset points')

# Title and axis labels
ax.set_title('Model Accuracy vs Number of Parameters', pad=20, size=14)
ax.set_xlabel('Number of Parameters (millions)', labelpad=10)
ax.set_ylabel('Accuracy', labelpad=10)

# y ticks as percentages
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: '{:.1%}'.format(y)))

# Light guide lines
ax.grid(True, linestyle='--', alpha=0.3)

# Parameter counts cannot be negative: start at 0 and leave 5% headroom
ax.set_xlim(0, params_millions.max() * 1.05)

# Pad the accuracy range by two percentage points on each side
y_min = acc_values.min() - 0.02
y_max = acc_values.max() + 0.02
ax.set_ylim(y_min, y_max)

# Least-squares straight line through the points
z = np.polyfit(params_millions, acc_values, 1)
p = np.poly1d(z)
x_trend = np.linspace(0, params_millions.max(), 100)

# R² of the fitted line, shown in the legend (r2_score comes from an earlier cell)
r_squared = r2_score(acc_values, p(params_millions))
ax.plot(x_trend, p(x_trend), "r--", alpha=0.8,
        label=f'Trend line (R² = {r_squared:.3f})')

# Legend for the trend line
ax.legend()

# Keep labels inside the figure, then render
fig.tight_layout()

plt.show()

### Inference time (bar chart)

##### Loading the model

In [None]:
import torch
import os
from typing import Dict, NamedTuple
from torch import nn

class Models(NamedTuple):
    """Immutable bundle of the loaded classifiers, one field per architecture."""
    # The field names must match the keys of the `configs` dict in load_models,
    # which builds this tuple with Models(**loaded).
    resnet50: nn.Module
    resnet101: nn.Module
    efficientnetb0: nn.Module
    efficientnetb1: nn.Module
    vit: nn.Module
    levit384: nn.Module
    levit256: nn.Module
    xcit: nn.Module

def load_models(num_classes: int, base_path: str = '/content/drive/MyDrive/Colab Notebooks/') -> Models:
    """Load all saved models and return them as a named tuple.

    Args:
        num_classes: Number of output classes passed to the model initialisers.
            The LeViT and XCiT entries ignore it and are always built with 4.
        base_path: Directory that contains the saved weight files.

    Returns:
        Models: every model with its weights loaded and switched to eval mode.

    Raises:
        FileNotFoundError: if a weight file is missing under ``base_path``.
        RuntimeError: if building a model or loading its weights fails; the
            original exception is attached as ``__cause__``.
    """
    # name -> (initialiser, weight file, key of the state dict inside the
    # checkpoint, or None when the file holds the state dict directly)
    configs = {
        'resnet50': (initialize_resnet50_model, 'final_mriresnet50_model.pth', None),
        'resnet101': (initialize_model_resnet101, 'final_resnet101_classification_model.pth', None),
        'efficientnetb0': (initialize_efficientnet0_model, 'final_efficientnet_b0_mri_classification_model.pth', None),
        'efficientnetb1': (initialize_efficientnetb1_model, 'final_efficientnet_b1_mri_classification_model.pth', None),
        'vit': (initialize_vit_model, 'final_vit_small_mri_classification_model.pth', None),
        'levit384': (initialize_levit384_model, 'levit_384_model_final.pth', 'model_state_dict'),
        'levit256': (initialize_levit256_model, 'levit_256_model_final.pth', 'model_state_dict'),
        'xcit': (initialize_xcit_model, 'xcit_small_12_model_final.pth', 'model_state_dict')
    }
    # Models that are always instantiated with 4 classes, whatever num_classes is
    four_class_models = ('levit384', 'levit256', 'xcit')

    loaded = {}
    for name, (initializer, path, state_dict_key) in configs.items():
        model_path = os.path.join(base_path, path)
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"{name} weights not found at {model_path}")

        try:
            model_classes = 4 if name in four_class_models else num_classes
            model = initializer(num_classes=model_classes)
            # map_location lets checkpoints saved from a GPU load on a CPU-only
            # runtime; load_state_dict copies the values into the model anyway.
            weights = torch.load(model_path, map_location='cpu')
            if state_dict_key:
                weights = weights[state_dict_key]
            model.load_state_dict(weights)
            model.eval()
            loaded[name] = model
        except Exception as e:
            # Chain the cause so the original traceback is not lost
            raise RuntimeError(f"Error loading {name}: {str(e)}") from e

    return Models(**loaded)

# Load all models into a named tuple
# NOTE(review): this rebinds the global name `models`, which the notebook's
# import cell binds to torchvision.models. Any code executed after this cell
# that expects the torchvision module under that name will receive this tuple
# instead - confirm the model initialisers do not rely on it, or rename.
models = load_models(num_classes=len(tumor_types))

# Now you can access models directly:
# models.resnet50
# models.vit
# models.xcit
# etc.

##### Inference

In [None]:
import statistics

def measure_inference_time(model, model_name, test_folder, tumor_types, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Time single-image forward passes of ``model`` over a test folder.

    Every image found in ``test_folder/<tumor_type>`` is resized to 224x224,
    normalised, and passed through the model one at a time; only the forward
    pass itself is timed.

    Args:
        model: Model to benchmark; it is moved to ``device`` and put in eval mode.
        model_name: Label copied into the result under 'Model'.
        test_folder: Directory containing one sub-folder per entry of ``tumor_types``.
        tumor_types: Names of the sub-folders to scan.
        device: Device to run on (string or ``torch.device``).

    Returns:
        dict: 'Model', 'Avg Inference Time (ms)', 'FPS' and 'Total Images'.

    Raises:
        statistics.StatisticsError: if no image files were found.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Accepts 'cuda', 'cuda:0' or a torch.device alike
    use_cuda = torch.device(device).type == 'cuda'

    model = model.to(device)
    model.eval()
    inference_times = []

    # NOTE: there is no warm-up pass, so the first timed image also pays any
    # one-off initialisation cost of the model on this device.
    with torch.no_grad():
        for tumor_type in tumor_types:
            tumor_folder = os.path.join(test_folder, tumor_type)
            for img_name in os.listdir(tumor_folder):
                # Case-insensitive so files such as 'x.JPG' are not skipped
                if not img_name.lower().endswith(('.jpg', '.jpeg', '.png')):
                    continue

                img_path = os.path.join(tumor_folder, img_name)
                with Image.open(img_path) as img_file:
                    image = img_file.convert('RGB')
                input_tensor = transform(image).unsqueeze(0).to(device)

                if use_cuda:
                    torch.cuda.empty_cache()
                    # Finish pending GPU work so it is not billed to this image
                    torch.cuda.synchronize()

                # perf_counter is monotonic and high-resolution, unlike time.time
                start_time = time.perf_counter()
                _ = model(input_tensor)
                if use_cuda:
                    # CUDA kernels run asynchronously; wait for them before
                    # stopping the clock. Must be skipped on CPU-only runtimes.
                    torch.cuda.synchronize()
                inference_times.append(time.perf_counter() - start_time)

    if not inference_times:
        raise statistics.StatisticsError(
            f"No .jpg/.jpeg/.png images found under {test_folder} for {list(tumor_types)}"
        )

    avg_time = statistics.mean(inference_times)
    return {
        'Model': model_name,
        'Avg Inference Time (ms)': avg_time * 1000,
        'FPS': 1.0 / avg_time,
        'Total Images': len(inference_times)
    }

# Display name -> loaded model, in the order the models are benchmarked
model_names = {
    'ResNet-50': models.resnet50,
    'ResNet-101': models.resnet101,
    'EfficientNet-B0': models.efficientnetb0,
    'EfficientNet-B1': models.efficientnetb1,
    'ViT-Small': models.vit,
    'LeViT-384': models.levit384,
    'LeViT-256': models.levit256,
    'XCiT-Small': models.xcit
}

# Time every model on the same test_folder; each call returns one summary dict
results = []
for model_name, model in model_names.items():
    print(f"\nProcessing {model_name}...")
    metrics = measure_inference_time(model, model_name, test_folder, tumor_types)
    results.append(metrics)

# Tabulate the results, fastest model first, rounding the timing columns to 2 decimals
df = pd.DataFrame(results)
df = df.sort_values('Avg Inference Time (ms)')
df[['Avg Inference Time (ms)', 'FPS']] = df[['Avg Inference Time (ms)', 'FPS']].round(2)

print("\nInference Time Comparison:")
print(df)
# Relative path: the CSV is written to the current working directory
df.to_csv('inference_times.csv', index=False)

##### Plotting:

In [None]:
# Work on a copy so the benchmark table itself keeps its display names
inference_df = df.copy()

# Benchmark display name -> model name used in merged_metrics
name_mapping = {
    'ViT-Small': 'Vit',
    'ResNet-50': 'ResNet50',
    'EfficientNet-B0': 'EfficientNetB0',
    'ResNet-101': 'ResNet101',
    'EfficientNet-B1': 'EfficientNetB1',
    'LeViT-256': 'Levit256',
    'LeViT-384': 'Levit384',
    'XCiT-Small': 'XCiT'
}

# Rename the benchmark rows so both tables share the same keys
inference_df['Model'] = inference_df['Model'].map(name_mapping)

# Inner join: only models present in both tables are kept
merged_complete = merged_metrics.merge(inference_df, left_on='model', right_on='Model', how='inner')

# Fastest model first
sorted_data = merged_complete.sort_values('Avg Inference Time (ms)')

# One wide figure with a single set of axes, one bar per model
fig, ax = plt.subplots(figsize=(12, 6))
bars = ax.bar(sorted_data['model'], sorted_data['Avg Inference Time (ms)'], color='cornflowerblue')

# Title and axis labels
ax.set_title('Model Inference Time Comparison', pad=20, size=14)
ax.set_xlabel('Model', labelpad=10)
ax.set_ylabel('Average Inference Time (ms)', labelpad=10)

# Slanted model names so they do not overlap
plt.xticks(rotation=45, ha='right')

# Horizontal guide lines only
ax.grid(True, linestyle='--', alpha=0.7, axis='y')

# Print each bar's value (in ms) just above it
for bar in bars:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height + 0.2,
            f'{height:.2f}ms',
            ha='center', va='bottom')

# Keep labels inside the figure, then render
fig.tight_layout()
plt.show()

# Visualizing Interpretability

### Selecting the same pictures

In [None]:
import os
import random
from PIL import Image
import matplotlib.pyplot as plt

def display_random_samples(test_folder, samples_per_category=10, seed=None):
    """
    Displays random samples of MRI images from each tumor category.

    Categories and file names are processed in sorted order, so a given
    ``seed`` selects the same images regardless of the order in which the
    file system lists them.

    Args:
        test_folder (str): Path to the test folder containing tumor type subfolders
        samples_per_category (int): Number of samples to display per category
        seed (int, optional): Random seed for reproducibility. When given, the
            global ``random`` generator is re-seeded with it.

    Returns:
        dict: Dictionary containing selected image file names for each category
    """
    if seed is not None:
        random.seed(seed)

    # os.listdir order is arbitrary; sort so the draws are reproducible
    tumor_categories = sorted(d for d in os.listdir(test_folder)
                              if os.path.isdir(os.path.join(test_folder, d)))

    selected_samples = {}
    n_cols = 5

    for category in tumor_categories:
        category_path = os.path.join(test_folder, category)
        image_files = sorted(f for f in os.listdir(category_path)
                             if f.lower().endswith(('.png', '.jpg', '.jpeg')))

        # Never ask for more samples than there are images
        selected_images = random.sample(image_files,
                                        min(samples_per_category, len(image_files)))
        selected_samples[category] = selected_images

        # Grow the grid with the sample count (ceil division); two rows is the
        # minimum so the layout for up to 10 samples is unchanged.
        n_rows = max(2, -(-len(selected_images) // n_cols))

        plt.figure(figsize=(20, 2 * n_rows))
        plt.suptitle(f'Category: {category}')

        for idx, img_name in enumerate(selected_images, 1):
            img_path = os.path.join(category_path, img_name)
            # Close the file handle as soon as the pixels are loaded
            with Image.open(img_path) as img_file:
                img = img_file.convert('RGB')

            plt.subplot(n_rows, n_cols, idx)
            plt.imshow(img)
            plt.title(f'File: {img_name}')
            plt.axis('off')

        plt.tight_layout()
        plt.show()

    return selected_samples

# Example usage:
# selected_images = display_random_samples(test_folder='/content/drive/MyDrive/Colab Notebooks/Resized_Testing')

In [None]:
# No seed is passed, so the random selection is not re-seeded for this call
selected_images = display_random_samples(test_folder='/content/drive/MyDrive/Colab Notebooks/Resized_Testing')

In [None]:
# Images:
# Te-gl_0261
# Te-gl_0037
# Te-gl_0131
# Te-me_0224
# Te-me_0231
# Te-me_0162
# Te-pi_0205
# Te-pi_0132
# Te-pi_0110

In [None]:
def load_images(test_folder, image_names):
    """Display the named test images in a three-column grid.

    Args:
        test_folder (str): Folder containing the 'glioma', 'meningioma' and
            'pituitary' sub-folders.
        image_names: Image base names without extension, e.g. 'Te-gl_0261'.
            The part between '-' and '_' selects the sub-folder and '.jpg'
            is appended to build the file name.

    Raises:
        KeyError: if a name's category prefix is not 'gl', 'me' or 'pi'.
    """
    # Mapping of short names to full folder names
    folder_map = {
        'gl': 'glioma',
        'me': 'meningioma',
        'pi': 'pituitary'
    }

    image_names = list(image_names)

    # Grow the grid with the number of images (ceil division); three rows is
    # the minimum so the layout for up to 9 images is unchanged.
    n_cols = 3
    n_rows = max(3, -(-len(image_names) // n_cols))

    plt.figure(figsize=(15, 10 * n_rows / 3))

    for i, name in enumerate(image_names, 1):
        # Get category (gl, me, or pi)
        category_short = name.split('-')[1].split('_')[0]
        category_full = folder_map[category_short]

        # Construct full path (files are expected to be .jpg)
        img_path = os.path.join(test_folder, category_full, name + '.jpg')

        # Load the pixels and close the file handle straight away
        with Image.open(img_path) as img_file:
            img = img_file.convert('RGB')

        plt.subplot(n_rows, n_cols, i)
        plt.imshow(img)
        plt.title(name)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Base names (without extension) of the nine images to load:
# three each with the 'gl', 'me' and 'pi' prefixes
images = [
    'Te-gl_0261',
    'Te-gl_0037',
    'Te-gl_0131',
    'Te-me_0224',
    'Te-me_0231',
    'Te-me_0162',
    'Te-pi_0205',
    'Te-pi_0132',
    'Te-pi_0110'
]

# Display the selected images from the resized test set
load_images('/content/drive/MyDrive/Colab Notebooks/Resized_Testing', images)

### Grad-CAM (ResNets, EfficientNets)

In [None]:
# 1 - take the 9 selected images
# 2 - compute the Grad-CAM of each image for each of the 4 models
# 3 - plot the results side by side

In [None]:
import torch
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict
from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

def get_model_specific_params(model_name):
    """Look up the GradCAM settings for one of the supported CNNs.

    Returns a dict with the target layer's name ('target_layer'), the square
    input resolution ('input_size') and a callable ('layer_getter') that picks
    the target layer out of a model instance. Raises KeyError for any name
    other than 'resnet50', 'resnet101', 'efficientnetb0' or 'efficientnetb1'.
    """
    def _last_layer4_block(model):
        # Final block of the ResNet's layer4 stage
        return model.layer4[-1]

    def _features_stage_8(model):
        # Entry 8 of the EfficientNet's `features` sequence
        return model.features[8]

    # model name -> (target layer name, input size, layer getter)
    settings = {
        'resnet50': ('layer4', 224, _last_layer4_block),
        'resnet101': ('layer4', 224, _last_layer4_block),
        'efficientnetb0': ('features.8', 224, _features_stage_8),
        'efficientnetb1': ('features.8', 240, _features_stage_8),
    }

    layer_name, resolution, getter = settings[model_name]
    return {
        'target_layer': layer_name,
        'input_size': resolution,
        'layer_getter': getter,
    }

def get_gradcam_for_image(model, model_name, image_path, training_order):
    """Generate GradCAM visualization for a single model and image.

    Args:
        model: Classifier to explain; it is moved to CUDA when available,
            otherwise to the CPU.
        model_name (str): Key accepted by get_model_specific_params; selects
            the target layer and the input resolution.
        image_path (str): Path of the image to load.
        training_order: Not used by this function; kept so existing callers
            keep working.

    Returns:
        tuple: (visualization, confidence, prediction) - the overlay returned
        by show_cam_on_image, the softmax probability of the predicted class
        in percent, and the predicted class index. (None, None, None) is
        returned, after printing the error, if any step raises.
    """
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
        model_params = get_model_specific_params(model_name)
        side = model_params['input_size']

        transform = transforms.Compose([
            transforms.Resize((side, side)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # Load the pixels and close the file handle straight away
        with Image.open(image_path) as img_file:
            image = img_file.convert('RGB')
        input_tensor = transform(image).unsqueeze(0).to(device)

        target_layer = model_params['layer_getter'](model)

        # The context manager releases GradCAM's hooks on the target layer as
        # soon as the block exits, instead of waiting for garbage collection.
        with GradCAM(
            model=model,
            target_layers=[target_layer],
            reshape_transform=None
        ) as grad_cam:
            # Get prediction and confidence (batch of one image)
            with torch.no_grad():
                output = model(input_tensor)
                probabilities = torch.nn.functional.softmax(output, dim=1)
                prediction = torch.argmax(output).item()
                confidence = probabilities[0][prediction].item() * 100

            # Explain the class the model actually predicted
            targets = [ClassifierOutputTarget(prediction)]
            grayscale_cam = grad_cam(input_tensor=input_tensor, targets=targets)
            grayscale_cam = grayscale_cam[0, :]

        # Overlay the heatmap on the image, rescaled to [0, 1] at the model's input size
        rgb_img = np.array(image.resize((side, side))) / 255.0
        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

        return visualization, confidence, prediction

    except Exception as e:
        # Best effort: report the failure and let the caller skip this panel
        print(f"Error processing {image_path} with {model_name}: {str(e)}")
        return None, None, None

def visualize_multiple_models_gradcam(models_dict, image_list, test_folder, training_order):
    """
    Draw a grid of GradCAM overlays: one row per image, with the original
    image in the first column followed by one column per model.

    models_dict maps a display name to a model; the lower-cased name is passed
    to get_gradcam_for_image. image_list holds base names such as 'Te-gl_0261'
    that are resolved to '<test_folder>/<category>/<name>.jpg'. training_order
    turns the predicted class index into the label shown in each title. A
    panel is left empty when GradCAM generation fails for that model.
    """
    prefix_to_folder = {
        'gl': 'glioma',
        'me': 'meningioma',
        'pi': 'pituitary'
    }

    # Grid shape: the extra column holds the original image
    n_images = len(image_list)
    n_models = len(models_dict)
    n_cols = n_models + 1

    plt.figure(figsize=(4 * n_models, 3 * n_images))

    for row, img_name in enumerate(image_list):
        # The text between '-' and '_' identifies the category folder
        prefix = img_name.split('-')[1].split('_')[0]
        img_path = os.path.join(test_folder, prefix_to_folder[prefix], img_name + '.jpg')

        # Original image for reference
        original_img = Image.open(img_path).convert('RGB')

        # Subplot numbers are 1-based; this row's first panel sits right after row_offset
        row_offset = row * n_cols
        plt.subplot(n_images, n_cols, row_offset + 1)
        plt.imshow(original_img)
        plt.title(f'Original\n{img_name}')
        plt.axis('off')

        # Model panels occupy columns 2 .. n_cols of this row
        for col, (model_name, model) in enumerate(models_dict.items(), start=2):
            vis, conf, pred = get_gradcam_for_image(model, model_name.lower(), img_path, training_order)
            if vis is None:
                continue

            plt.subplot(n_images, n_cols, row_offset + col)
            plt.imshow(vis)
            plt.title(f'{model_name}\nConf: {conf:.1f}%\nPred: {training_order[pred]}')
            plt.axis('off')

    plt.tight_layout()
    plt.show()

# Display name -> model for the four CNNs; the lower-cased display name is the
# key looked up in get_model_specific_params
models_dict = {
    'ResNet50': models.resnet50,
    'ResNet101': models.resnet101,
    'EfficientNetB0': models.efficientnetb0,
    'EfficientNetB1': models.efficientnetb1
}

# Base names (without extension) of the nine test images to explain
images_to_process = [
    'Te-gl_0261',
    'Te-gl_0037',
    'Te-gl_0131',
    'Te-me_0224',
    'Te-me_0231',
    'Te-me_0162',
    'Te-pi_0205',
    'Te-pi_0132',
    'Te-pi_0110'
]

# Class names indexed by predicted class id, used for the "Pred:" label
# (assumed to match the class order used during training - TODO confirm)
training_order = ['glioma', 'meningioma', 'notumor', 'pituitary']

# Draw the image-by-model GradCAM grid
visualize_multiple_models_gradcam(
    models_dict=models_dict,
    image_list=images_to_process,
    test_folder='/content/drive/MyDrive/Colab Notebooks/Resized_Testing',
    training_order=training_order
)

In [None]:
# Stand-alone instances of the four CNNs, each with its saved weights loaded and
# switched to eval mode. map_location='cpu' lets checkpoints saved from a GPU
# load on a CPU-only runtime.
resnet50_model = initialize_resnet50_model(num_classes=len(tumor_types))
resnet50_model.load_state_dict(torch.load('/content/drive/MyDrive/Colab Notebooks/final_mriresnet50_model.pth', map_location='cpu'))
resnet50_model.eval()
resnet101_model = initialize_model_resnet101(num_classes=len(tumor_types))
resnet101_model.load_state_dict(torch.load('/content/drive/MyDrive/Colab Notebooks/final_resnet101_classification_model.pth', map_location='cpu'))
resnet101_model.eval()
# EfficientNet-B0 uses its own initialiser and weight file; it was previously
# built as a second copy of B1.
efficientnetb0_model = initialize_efficientnet0_model(num_classes=len(tumor_types))
efficientnetb0_model.load_state_dict(torch.load('/content/drive/MyDrive/Colab Notebooks/final_efficientnet_b0_mri_classification_model.pth', map_location='cpu'))
efficientnetb0_model.eval()
efficientnetb1_model = initialize_efficientnetb1_model(num_classes=len(tumor_types))
efficientnetb1_model.load_state_dict(torch.load('/content/drive/MyDrive/Colab Notebooks/final_efficientnet_b1_mri_classification_model.pth', map_location='cpu'))
efficientnetb1_model.eval()

### Guided BP

In [None]:
def get_guided_bp_for_resnet(model, image_path, training_order):
    """Generate a colour-coded Guided Backpropagation map for a ResNet model.

    The image is resized to 224x224, normalised with the ImageNet statistics
    and pushed through ``model`` wrapped in ``GuidedBackpropReLUModel``. The
    gradient of the top-scoring logit with respect to the input is reduced to
    a per-pixel magnitude, min-max normalised and rendered so that the top 2%
    of pixels are red and the 90th-98th percentile band is light blue.

    Args:
        model: Classifier to explain; moved to CUDA when available and put in
            eval mode.
        image_path: Path of the image file to load.
        training_order: Class names. Accepted for signature compatibility with
            the other helpers; not used here.

    Returns:
        ``(visualization, confidence, prediction)`` where ``visualization`` is
        a 224x224x3 float array, ``confidence`` the softmax probability of the
        predicted class in percent and ``prediction`` the class index. On any
        error the message is printed and ``(None, None, None)`` is returned.
    """
    guided_bp_model = None
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
        model.eval()

        guided_bp_model = GuidedBackpropReLUModel(model=model, device=device)

        # convert('RGB') replicates the channel of greyscale ('L') images and
        # also handles palette/RGBA files, which would otherwise reach
        # Normalize with the wrong channel count. The context manager closes
        # the underlying file handle.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        input_tensor = preprocess(image).unsqueeze(0).to(device)
        input_tensor.requires_grad = True

        # Forward pass through the hooked model
        output = guided_bp_model.forward(input_tensor)

        # Predicted class and its softmax confidence (percent)
        probabilities = torch.nn.functional.softmax(output, dim=1)
        prediction = torch.argmax(output).item()
        confidence = probabilities[0][prediction].item() * 100

        # Back-propagate only the predicted class' logit
        one_hot = torch.zeros_like(output)
        one_hot[0][prediction] = 1
        output.backward(gradient=one_hot)

        # Input gradient, channels-last
        guided_bp_mask = input_tensor.grad[0].cpu().numpy().transpose(1, 2, 0)

        # Per-pixel gradient magnitude scaled to [0, 1]
        magnitude = np.sqrt(np.sum(guided_bp_mask**2, axis=2))
        magnitude = (magnitude - magnitude.min()) / (magnitude.max() - magnitude.min() + 1e-8)

        high_threshold = np.percentile(magnitude, 98)
        mid_threshold = np.percentile(magnitude, 90)

        # Base layer: faint cyan proportional to the magnitude
        visualization = np.zeros((224, 224, 3))
        visualization[:,:,2] = magnitude
        visualization[:,:,1] = magnitude * 0.8

        red_mask = magnitude > high_threshold
        blue_mask = (magnitude > mid_threshold) & (magnitude <= high_threshold)

        # Strongest responses in red, the next band in light blue
        visualization[red_mask] = (1.0, 0.2, 0.2)
        visualization[blue_mask] = (0.2, 0.8, 1.0)

        # Gamma < 1 brightens weak responses; the blur softens pixel noise
        visualization = np.power(visualization, 0.7)

        for i in range(3):
            visualization[:,:,i] = gaussian_filter(visualization[:,:,i], sigma=1)

        return visualization, confidence, prediction

    except Exception as e:
        print(f"Error processing {image_path}: {str(e)}")
        return None, None, None

    finally:
        # Always detach the hooks, even on failure; otherwise they stay on the
        # shared model and pile up across calls.
        if guided_bp_model is not None:
            guided_bp_model.cleanup()

def get_guided_bp_for_efficientnet(model, image_path, input_size, training_order):
    """Generate a colour-coded Guided Backpropagation map for an EfficientNet model.

    Same pipeline as the ResNet variant, but the image is resized to
    ``input_size`` x ``input_size`` so each EfficientNet variant can use its
    own resolution.

    NOTE(review): ``GuidedBackpropReLUModel`` attaches hooks to ``nn.ReLU``
    modules only. If this EfficientNet uses a different activation (e.g. SiLU)
    no hook fires and the result is a plain input-gradient map - confirm.

    Args:
        model: Classifier to explain; moved to CUDA when available and put in
            eval mode.
        image_path: Path of the image file to load.
        input_size: Side length (pixels) the image is resized to.
        training_order: Class names. Accepted for signature compatibility;
            not used here.

    Returns:
        ``(visualization, confidence, prediction)``: an
        ``input_size`` x ``input_size`` x 3 float array, the softmax
        confidence in percent and the predicted class index. On any error the
        message is printed and ``(None, None, None)`` is returned.
    """
    guided_bp_model = None
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
        model.eval()

        guided_bp_model = GuidedBackpropReLUModel(model=model, device=device)

        # convert('RGB') covers greyscale as well as palette/RGBA inputs, and
        # the context manager closes the file handle.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        preprocess = transforms.Compose([
            transforms.Resize((input_size, input_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        input_tensor = preprocess(image).unsqueeze(0).to(device)
        input_tensor.requires_grad = True

        # Forward pass through the hooked model
        output = guided_bp_model.forward(input_tensor)

        # Predicted class and its softmax confidence (percent)
        probabilities = torch.nn.functional.softmax(output, dim=1)
        prediction = torch.argmax(output).item()
        confidence = probabilities[0][prediction].item() * 100

        # Back-propagate only the predicted class' logit
        one_hot = torch.zeros_like(output)
        one_hot[0][prediction] = 1
        output.backward(gradient=one_hot)

        # Input gradient, channels-last
        guided_bp_mask = input_tensor.grad[0].cpu().numpy().transpose(1, 2, 0)

        # Per-pixel gradient magnitude scaled to [0, 1]
        magnitude = np.sqrt(np.sum(guided_bp_mask**2, axis=2))
        magnitude = (magnitude - magnitude.min()) / (magnitude.max() - magnitude.min() + 1e-8)

        high_threshold = np.percentile(magnitude, 98)
        mid_threshold = np.percentile(magnitude, 90)

        # Base layer: faint cyan proportional to the magnitude
        visualization = np.zeros((input_size, input_size, 3))
        visualization[:,:,2] = magnitude
        visualization[:,:,1] = magnitude * 0.8

        red_mask = magnitude > high_threshold
        blue_mask = (magnitude > mid_threshold) & (magnitude <= high_threshold)

        # Strongest responses in red, the next band in light blue
        visualization[red_mask] = (1.0, 0.2, 0.2)
        visualization[blue_mask] = (0.2, 0.8, 1.0)

        # Gamma < 1 brightens weak responses; the blur softens pixel noise
        visualization = np.power(visualization, 0.7)

        for i in range(3):
            visualization[:,:,i] = gaussian_filter(visualization[:,:,i], sigma=1)

        return visualization, confidence, prediction

    except Exception as e:
        print(f"Error processing {image_path}: {str(e)}")
        return None, None, None

    finally:
        # Always detach the hooks, even on failure, so they do not accumulate
        # on the shared model.
        if guided_bp_model is not None:
            guided_bp_model.cleanup()

In [None]:
# Display name -> loaded model, in the order the columns are drawn
models_dict = {
    'ResNet50': resnet50_model,
    'ResNet101': resnet101_model,
    'EfficientNetB0': efficientnetb0_model,
    'EfficientNetB1': efficientnetb1_model
}

# Test images to process, named Te-<class code>_<id>
images_to_process = [
    'Te-gl_0261',
    'Te-gl_0037',
    'Te-gl_0131',
    'Te-me_0224',
    'Te-me_0231',
    'Te-me_0162',
    'Te-pi_0205',
    'Te-pi_0132',
    'Te-pi_0110'
]

# Class labels indexed by the model's output position
training_order = ['glioma', 'meningioma', 'notumor', 'pituitary']

# Clear leftover hooks on EVERY model before running. Earlier runs that failed
# part-way may have left forward or backward hooks behind, and clearing only
# the first model in the dict would leave the other three affected.
for _model in models_dict.values():
    for name, module in _model.named_modules():
        for _registry in ('_forward_hooks', '_backward_hooks'):
            if hasattr(module, _registry):
                getattr(module, _registry).clear()

# Call the visualization function
visualize_guided_bp(
    models_dict=models_dict,
    image_list=images_to_process,
    test_folder='/content/drive/MyDrive/Colab Notebooks/Resized_Testing',
    training_order=training_order
)

## Grad-CAM and Guided Backpropagation (Guided BP) Together

In [None]:
import torch
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
import torch.nn.functional as F
import os
import math
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from scipy.ndimage import gaussian_filter

class GuidedBackpropReLUModel:
    """Wrap a model so that gradients flowing back through its ReLUs are clamped to >= 0.

    On construction a forward hook and a backward hook are attached to every
    ``torch.nn.ReLU`` module of ``model``. The backward hook replaces the
    gradient leaving each ReLU with its non-negative part. Modules of any
    other activation type are left untouched.

    Call :meth:`cleanup` when done; until then the hooks stay on ``model``.
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device
        # Outputs of the hooked ReLUs from the most recent forward pass
        self.forward_relu_outputs = []
        # Handles returned by the hook registrations, kept for cleanup()
        self.forward_hook_handles = []
        self.backward_hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Attach the forward/backward hooks to every ``nn.ReLU`` in the model."""
        def hook_fn(module, grad_in, grad_out):
            # Zero out negative gradients on the way back through the ReLU
            if isinstance(grad_in, tuple):
                if grad_in[0] is None:
                    # Nothing to clamp; returning None keeps the gradient as is
                    return None
                return (torch.clamp(grad_in[0], min=0.0),)
            return torch.clamp(grad_in, min=0.0)

        def forward_hook_fn(module, input, output):
            self.forward_relu_outputs.append(output)

        # Register hooks
        for module in self.model.modules():
            if isinstance(module, torch.nn.ReLU):
                self.forward_hook_handles.append(module.register_forward_hook(forward_hook_fn))
                # The legacy (non-full) backward hook is kept on purpose: full
                # backward hooks are not usable on in-place ReLUs.
                self.backward_hook_handles.append(module.register_backward_hook(hook_fn))

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output."""
        self.forward_relu_outputs = []
        return self.model(x)

    def cleanup(self):
        """Remove every hook this wrapper registered. Safe to call more than once."""
        for handle in self.forward_hook_handles:
            handle.remove()
        for handle in self.backward_hook_handles:
            handle.remove()
        # Drop the dead handles and cached activations so a second call is a no-op
        self.forward_hook_handles = []
        self.backward_hook_handles = []
        self.forward_relu_outputs = []

    def __call__(self, input_img, target_category=None):
        """Return the guided gradient of one logit with respect to ``input_img``.

        Args:
            input_img: Batched input tensor; must be a leaf tensor with
                ``requires_grad=True``. Only the first sample is returned.
            target_category: Index of the logit to explain. Defaults to the
                arg-max of the model output.

        Returns:
            numpy array holding the first sample's gradient, transposed from
            (C, H, W) to (H, W, C).

        Raises:
            RuntimeError: if the backward pass produced no gradient for
                ``input_img``.
        """
        self.model.zero_grad()
        # Discard any gradient left from a previous call; backward()
        # accumulates into .grad, which would otherwise sum the maps.
        if input_img.grad is not None:
            input_img.grad = None

        output = self.forward(input_img)

        if target_category is None:
            target_category = output.argmax().item()

        # Create one hot output
        one_hot = torch.zeros_like(output)
        one_hot[0][target_category] = 1

        output.backward(gradient=one_hot)

        if input_img.grad is None:
            raise RuntimeError(
                "No gradient reached input_img; pass a leaf tensor with requires_grad=True"
            )

        # Gradient with respect to input
        gradient = input_img.grad.detach().cpu().numpy()[0]
        gradient = gradient.transpose((1, 2, 0))

        return gradient

def get_model_specific_params(model_name):
    """Look up the GradCAM settings for one supported architecture.

    Args:
        model_name: One of 'resnet50', 'resnet101', 'efficientnetb0',
            'efficientnetb1' (lower case).

    Returns:
        A fresh dict with the keys 'target_layer' (layer name as text),
        'input_size' (square input side in pixels) and 'layer_getter'
        (callable that picks the target layer object out of a model).

    Raises:
        KeyError: if ``model_name`` is not one of the supported names.
    """
    def _last_layer4_block(model):
        return model.layer4[-1]

    def _features_stage_8(model):
        return model.features[8]

    def _spec(layer_name, side, getter):
        return {
            'target_layer': layer_name,
            'input_size': side,
            'layer_getter': getter,
        }

    # (layer name, input side, layer getter) for every known architecture
    table = {
        'resnet50': ('layer4', 224, _last_layer4_block),
        'resnet101': ('layer4', 224, _last_layer4_block),
        'efficientnetb0': ('features.8', 224, _features_stage_8),
        'efficientnetb1': ('features.8', 240, _features_stage_8),
    }
    layer_name, side, getter = table[model_name]
    return _spec(layer_name, side, getter)

def get_guided_bp_for_resnet(model, image_path, training_order):
    """Generate a colour-coded Guided Backpropagation map for a ResNet model.

    The image is resized to 224x224 and normalised with the ImageNet
    statistics. A no-grad forward pass yields the predicted class and its
    confidence; ``GuidedBackpropReLUModel`` then returns the guided gradient
    of that class with respect to the input. The gradient is reduced to a
    per-pixel magnitude, min-max normalised and rendered so that the top 2%
    of pixels are red and the 90th-98th percentile band is light blue.

    Args:
        model: Classifier to explain; moved to CUDA when available and put in
            eval mode.
        image_path: Path of the image file to load.
        training_order: Class names. Accepted for signature compatibility;
            not used here.

    Returns:
        ``(visualization, confidence, prediction)``: a 224x224x3 float array,
        the softmax confidence in percent and the predicted class index. On
        any error the message is printed and ``(None, None, None)`` is
        returned.
    """
    guided_bp_model = None
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
        model.eval()

        guided_bp_model = GuidedBackpropReLUModel(model=model, device=device)

        # convert('RGB') replicates the channel of greyscale ('L') images and
        # also handles palette/RGBA files; the context manager closes the
        # underlying file handle.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        input_tensor = preprocess(image).unsqueeze(0).to(device)
        input_tensor.requires_grad = True

        # Prediction and confidence come from a gradient-free pass
        with torch.no_grad():
            output = model(input_tensor)
            probabilities = torch.nn.functional.softmax(output, dim=1)
            prediction = torch.argmax(output).item()
            confidence = probabilities[0][prediction].item() * 100

        # Guided gradient of the predicted class, channels-last
        guided_bp_mask = guided_bp_model(input_tensor, target_category=prediction)

        if isinstance(guided_bp_mask, torch.Tensor):
            guided_bp_mask = guided_bp_mask.cpu().numpy()

        # Per-pixel gradient magnitude scaled to [0, 1]
        magnitude = np.sqrt(np.sum(guided_bp_mask**2, axis=2))
        magnitude = (magnitude - magnitude.min()) / (magnitude.max() - magnitude.min() + 1e-8)

        high_threshold = np.percentile(magnitude, 98)
        mid_threshold = np.percentile(magnitude, 90)

        # Base layer: faint cyan proportional to the magnitude
        visualization = np.zeros((224, 224, 3))
        visualization[:,:,2] = magnitude
        visualization[:,:,1] = magnitude * 0.8

        red_mask = magnitude > high_threshold
        blue_mask = (magnitude > mid_threshold) & (magnitude <= high_threshold)

        # Strongest responses in red, the next band in light blue
        visualization[red_mask] = (1.0, 0.2, 0.2)
        visualization[blue_mask] = (0.2, 0.8, 1.0)

        # Gamma < 1 brightens weak responses; the blur softens pixel noise
        visualization = np.power(visualization, 0.7)

        for i in range(3):
            visualization[:,:,i] = gaussian_filter(visualization[:,:,i], sigma=1)

        return visualization, confidence, prediction

    except Exception as e:
        print(f"Error processing {image_path}: {str(e)}")
        return None, None, None

    finally:
        # Always detach the hooks, even on failure; otherwise they stay on the
        # shared model and pile up across calls.
        if guided_bp_model is not None:
            guided_bp_model.cleanup()

def get_guided_bp_for_efficientnet(model, image_path, input_size, training_order):
    """Generate a colour-coded Guided Backpropagation map for an EfficientNet model.

    Same pipeline as the ResNet variant, but the image is resized to
    ``input_size`` x ``input_size`` so each EfficientNet variant can use its
    own resolution.

    NOTE(review): ``GuidedBackpropReLUModel`` attaches hooks to ``nn.ReLU``
    modules only. If this EfficientNet uses a different activation (e.g. SiLU)
    no hook fires and the result is a plain input-gradient map - confirm.

    Args:
        model: Classifier to explain; moved to CUDA when available and put in
            eval mode.
        image_path: Path of the image file to load.
        input_size: Side length (pixels) the image is resized to.
        training_order: Class names. Accepted for signature compatibility;
            not used here.

    Returns:
        ``(visualization, confidence, prediction)``: an
        ``input_size`` x ``input_size`` x 3 float array, the softmax
        confidence in percent and the predicted class index. On any error the
        message is printed and ``(None, None, None)`` is returned.
    """
    guided_bp_model = None
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
        model.eval()

        guided_bp_model = GuidedBackpropReLUModel(model=model, device=device)

        # convert('RGB') covers greyscale as well as palette/RGBA inputs, and
        # the context manager closes the file handle.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        preprocess = transforms.Compose([
            transforms.Resize((input_size, input_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        input_tensor = preprocess(image).unsqueeze(0).to(device)
        input_tensor.requires_grad = True

        # Prediction and confidence come from a gradient-free pass
        with torch.no_grad():
            output = model(input_tensor)
            probabilities = torch.nn.functional.softmax(output, dim=1)
            prediction = torch.argmax(output).item()
            confidence = probabilities[0][prediction].item() * 100

        # Guided gradient of the predicted class, channels-last
        guided_bp_mask = guided_bp_model(input_tensor, target_category=prediction)

        if isinstance(guided_bp_mask, torch.Tensor):
            guided_bp_mask = guided_bp_mask.cpu().numpy()

        # Per-pixel gradient magnitude scaled to [0, 1]
        magnitude = np.sqrt(np.sum(guided_bp_mask**2, axis=2))
        magnitude = (magnitude - magnitude.min()) / (magnitude.max() - magnitude.min() + 1e-8)

        high_threshold = np.percentile(magnitude, 98)
        mid_threshold = np.percentile(magnitude, 90)

        # Base layer: faint cyan proportional to the magnitude
        visualization = np.zeros((input_size, input_size, 3))
        visualization[:,:,2] = magnitude
        visualization[:,:,1] = magnitude * 0.8

        red_mask = magnitude > high_threshold
        blue_mask = (magnitude > mid_threshold) & (magnitude <= high_threshold)

        # Strongest responses in red, the next band in light blue
        visualization[red_mask] = (1.0, 0.2, 0.2)
        visualization[blue_mask] = (0.2, 0.8, 1.0)

        # Gamma < 1 brightens weak responses; the blur softens pixel noise
        visualization = np.power(visualization, 0.7)

        for i in range(3):
            visualization[:,:,i] = gaussian_filter(visualization[:,:,i], sigma=1)

        return visualization, confidence, prediction

    except Exception as e:
        print(f"Error processing {image_path}: {str(e)}")
        return None, None, None

    finally:
        # Always detach the hooks, even on failure, so they do not accumulate
        # on the shared model.
        if guided_bp_model is not None:
            guided_bp_model.cleanup()

def get_gradcam_for_image(model, model_name, image_path, training_order):
    """Generate a GradCAM overlay for a single model and image.

    The input resolution and the target layer are looked up with
    ``get_model_specific_params(model_name)``. The class explained is the
    model's own arg-max prediction.

    Args:
        model: Classifier to explain; moved to CUDA when available.
        model_name: Lower-case key understood by ``get_model_specific_params``.
        image_path: Path of the image file to load.
        training_order: Class names. Accepted for signature compatibility;
            not used here.

    Returns:
        ``(visualization, confidence, prediction)``: the heat-map blended
        over the resized image (as returned by ``show_cam_on_image`` with
        ``use_rgb=True``), the softmax confidence in percent and the predicted
        class index. On any error the message is printed and
        ``(None, None, None)`` is returned.
    """
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
        # Same inference-mode setup as the guided-backprop helpers
        model.eval()
        model_params = get_model_specific_params(model_name)
        side = model_params['input_size']

        transform = transforms.Compose([
            transforms.Resize((side, side)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # Context manager closes the file handle once the pixels are copied
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        input_tensor = transform(image).unsqueeze(0).to(device)

        target_layer = model_params['layer_getter'](model)

        # Using GradCAM as a context manager releases the hooks it puts on the
        # target layer; without it every call leaves another set of hooks on
        # the shared model.
        with GradCAM(
            model=model,
            target_layers=[target_layer],
            reshape_transform=None
        ) as grad_cam:
            with torch.no_grad():
                output = model(input_tensor)
                probabilities = torch.nn.functional.softmax(output, dim=1)
                prediction = torch.argmax(output).item()
                confidence = probabilities[0][prediction].item() * 100

            targets = [ClassifierOutputTarget(prediction)]
            grayscale_cam = grad_cam(input_tensor=input_tensor, targets=targets)
            grayscale_cam = grayscale_cam[0, :]

        # Un-normalised image in [0, 1] as the background of the overlay
        rgb_img = np.array(image.resize((side, side))) / 255.0
        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

        return visualization, confidence, prediction

    except Exception as e:
        print(f"Error processing {image_path} with {model_name}: {str(e)}")
        return None, None, None

def visualize_all_methods(models_dict, image_list, test_folder, training_order):
    """Plot GradCAM and Guided Backpropagation side by side for every model.

    One figure row per image: the original image first, then a GradCAM panel
    and a Guided BP panel for each model in ``models_dict`` order.

    Args:
        models_dict: Mapping of display name -> loaded model. Names starting
            with 'EfficientNet' use the EfficientNet guided-BP helper (224 px
            when the name contains 'B0', otherwise 240 px); everything else
            uses the ResNet helper. The lower-cased name is passed to
            ``get_gradcam_for_image``.
        image_list: Image names without extension, formatted
            ``<prefix>-<class code>_<id>`` (e.g. 'Te-gl_0261').
        test_folder: Directory containing one subfolder per class.
        training_order: Class names indexed by predicted class index.

    Raises:
        KeyError: if an image name carries an unknown class code.
    """
    # Class code in the file name -> subfolder name.
    # 'no' was added so images of the 'notumor' class listed in training_order
    # can be resolved too; assumes they follow the same Te-no_<id> naming -
    # TODO confirm against the dataset.
    folder_map = {
        'gl': 'glioma',
        'me': 'meningioma',
        'pi': 'pituitary',
        'no': 'notumor'
    }

    n_images = len(image_list)
    n_models = len(models_dict)
    # Original image + (GradCAM, GuidedBP) per model
    n_cols = 1 + 2 * n_models

    plt.figure(figsize=(4 * n_cols, 3 * n_images))

    for img_idx, img_name in enumerate(image_list):
        # 'Te-gl_0261' -> 'gl'
        category_short = img_name.split('-')[1].split('_')[0]
        category_full = folder_map[category_short]
        img_path = os.path.join(test_folder, category_full, img_name + '.jpg')

        # Context manager closes the file handle once the pixels are copied
        with Image.open(img_path) as raw_image:
            original_img = raw_image.convert('RGB')
        plt.subplot(n_images, n_cols, img_idx * n_cols + 1)
        plt.imshow(original_img)
        plt.title(f'Original\n{img_name}')
        plt.axis('off')

        for model_idx, (model_name, model) in enumerate(models_dict.items(), 1):
            gradcam_vis, gradcam_conf, gradcam_pred = get_gradcam_for_image(
                model,
                model_name.lower(),
                img_path,
                training_order
            )

            if model_name.startswith('EfficientNet'):
                input_size = 224 if 'B0' in model_name else 240
                guided_vis, guided_conf, guided_pred = get_guided_bp_for_efficientnet(
                    model,
                    img_path,
                    input_size,
                    training_order
                )
            else:
                guided_vis, guided_conf, guided_pred = get_guided_bp_for_resnet(
                    model,
                    img_path,
                    training_order
                )

            # Columns 2k / 2k+1 belong to the k-th model (1-based)
            col_idx = 2 * model_idx
            plt.subplot(n_images, n_cols, img_idx * n_cols + col_idx)
            if gradcam_vis is not None:
                plt.imshow(gradcam_vis)
                plt.title(f'{model_name} GradCAM\nConf: {gradcam_conf:.1f}%\n'
                         f'Pred: {training_order[gradcam_pred]}')
            # A failed helper leaves an empty, axis-less panel
            plt.axis('off')

            plt.subplot(n_images, n_cols, img_idx * n_cols + col_idx + 1)
            if guided_vis is not None:
                plt.imshow(guided_vis)
                plt.title(f'{model_name} GuidedBP\nConf: {guided_conf:.1f}%\n'
                         f'Pred: {training_order[guided_pred]}')
            plt.axis('off')

    plt.tight_layout()
    plt.show()

# Usage example:

# Display name -> loaded model, in the order the columns are drawn
models_dict = {
    'ResNet50': resnet50_model,
    'ResNet101': resnet101_model,
    'EfficientNetB0': efficientnetb0_model,
    'EfficientNetB1': efficientnetb1_model
}

# Test images to process, named Te-<class code>_<id>
images_to_process = [
    'Te-gl_0261',
    'Te-gl_0037',
    'Te-gl_0131',
    'Te-me_0224',
    'Te-me_0231',
    'Te-me_0162',
    'Te-pi_0205',
    'Te-pi_0132',
    'Te-pi_0110'
]

# Class labels indexed by the model's output position
training_order = ['glioma', 'meningioma', 'notumor', 'pituitary']

# Clear leftover hooks on EVERY model, not just the first one in the dict,
# and cover backward hooks as well as forward hooks: an earlier run that
# failed part-way may have left either kind behind.
for _model in models_dict.values():
    for name, module in _model.named_modules():
        for _registry in ('_forward_hooks', '_backward_hooks'):
            if hasattr(module, _registry):
                getattr(module, _registry).clear()

# Run visualization
visualize_all_methods(
    models_dict=models_dict,
    image_list=images_to_process,
    test_folder='/content/drive/MyDrive/Colab Notebooks/Resized_Testing',
    training_order=training_order
)

In [None]:
def visualize_all_methods_organized(models_dict, image_list, test_folder, training_order):
    """Plot GradCAM and Guided Backpropagation with one figure row per model.

    For every image the figure holds ``len(models_dict) + 1`` rows of three
    columns: a first row showing the original image (the other two cells are
    blank), then one row per model with the model name, its GradCAM overlay
    and its Guided BP map.

    Args:
        models_dict: Mapping of display name -> loaded model. Names starting
            with 'EfficientNet' use the EfficientNet guided-BP helper (224 px
            when the name contains 'B0', otherwise 240 px); everything else
            uses the ResNet helper. The lower-cased name is passed to
            ``get_gradcam_for_image``.
        image_list: Image names without extension, formatted
            ``<prefix>-<class code>_<id>`` (e.g. 'Te-gl_0261').
        test_folder: Directory containing one subfolder per class.
        training_order: Class names indexed by predicted class index.

    Raises:
        KeyError: if an image name carries an unknown class code.
    """
    # Class code in the file name -> subfolder name.
    # 'no' was added so images of the 'notumor' class listed in training_order
    # can be resolved too; assumes they follow the same Te-no_<id> naming -
    # TODO confirm against the dataset.
    folder_map = {
        'gl': 'glioma',
        'me': 'meningioma',
        'pi': 'pituitary',
        'no': 'notumor'
    }

    n_images = len(image_list)
    n_models = len(models_dict)
    n_cols = 3  # Original + GradCAM + GuidedBP
    n_rows_per_image = n_models + 1  # One row per model plus original image row
    total_rows = n_images * n_rows_per_image

    plt.figure(figsize=(12, 3 * n_images * n_rows_per_image))

    for img_idx, img_name in enumerate(image_list):
        base_row = img_idx * n_rows_per_image
        # 'Te-gl_0261' -> 'gl'
        category_short = img_name.split('-')[1].split('_')[0]
        category_full = folder_map[category_short]
        img_path = os.path.join(test_folder, category_full, img_name + '.jpg')

        # Load original image; the context manager closes the file handle
        with Image.open(img_path) as raw_image:
            original_img = raw_image.convert('RGB')

        # First row: Original image
        plt.subplot(total_rows, n_cols, base_row * n_cols + 1)
        plt.imshow(original_img)
        plt.title(f'Original\n{img_name}')
        plt.axis('off')

        # Hide empty plots in first row
        plt.subplot(total_rows, n_cols, base_row * n_cols + 2)
        plt.axis('off')
        plt.subplot(total_rows, n_cols, base_row * n_cols + 3)
        plt.axis('off')

        # Process each model in subsequent rows
        for model_idx, (model_name, model) in enumerate(models_dict.items(), 1):
            current_row = base_row + model_idx
            # 1-based subplot index of the first cell in this row
            row_start = current_row * n_cols + 1

            # Get GradCAM visualization
            gradcam_vis, gradcam_conf, gradcam_pred = get_gradcam_for_image(
                model,
                model_name.lower(),
                img_path,
                training_order
            )

            # Get Guided BP visualization
            if model_name.startswith('EfficientNet'):
                input_size = 224 if 'B0' in model_name else 240
                guided_vis, guided_conf, guided_pred = get_guided_bp_for_efficientnet(
                    model,
                    img_path,
                    input_size,
                    training_order
                )
            else:
                guided_vis, guided_conf, guided_pred = get_guided_bp_for_resnet(
                    model,
                    img_path,
                    training_order
                )

            # Plot model name in first column
            plt.subplot(total_rows, n_cols, row_start)
            plt.text(0.5, 0.5, model_name, ha='center', va='center')
            plt.axis('off')

            # Plot GradCAM (a failed helper leaves an empty, axis-less panel)
            plt.subplot(total_rows, n_cols, row_start + 1)
            if gradcam_vis is not None:
                plt.imshow(gradcam_vis)
                plt.title(f'GradCAM\nConf: {gradcam_conf:.1f}%\n'
                         f'Pred: {training_order[gradcam_pred]}')
            plt.axis('off')

            # Plot Guided BP
            plt.subplot(total_rows, n_cols, row_start + 2)
            if guided_vis is not None:
                plt.imshow(guided_vis)
                plt.title(f'GuidedBP\nConf: {guided_conf:.1f}%\n'
                         f'Pred: {training_order[guided_pred]}')
            plt.axis('off')

    plt.tight_layout()
    plt.show()

# Usage:
# Display name -> loaded model, in the order the rows are drawn.
# images_to_process and training_order are reused from the previous cell.
models_dict = {
    'ResNet50': resnet50_model,
    'ResNet101': resnet101_model,
    'EfficientNetB0': efficientnetb0_model,
    'EfficientNetB1': efficientnetb1_model
}

# Clear leftover hooks on EVERY model, not just the first one in the dict,
# and cover backward hooks as well as forward hooks: an earlier run that
# failed part-way may have left either kind behind.
for _model in models_dict.values():
    for name, module in _model.named_modules():
        for _registry in ('_forward_hooks', '_backward_hooks'):
            if hasattr(module, _registry):
                getattr(module, _registry).clear()

# Run visualization
visualize_all_methods_organized(
    models_dict=models_dict,
    image_list=images_to_process,
    test_folder='/content/drive/MyDrive/Colab Notebooks/Resized_Testing',
    training_order=training_order
)

## Attention Rollout and Beyond Attention for ViT

In [None]:
# Build the ViT classifier, load its fine-tuned weights and switch to
# inference mode. map_location='cpu' lets a checkpoint saved from a CUDA
# session load on a CPU-only runtime; the attention helpers move the model to
# the active device themselves.
vit_model = initialize_vit_model(num_classes=len(tumor_types))
vit_model.load_state_dict(torch.load('/content/drive/MyDrive/Colab Notebooks/final_vit_small_mri_classification_model.pth', map_location='cpu'))
vit_model.eval()

In [None]:
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
import torch.nn.functional as F
import math

class VitAttentionRollout:
    """Compute an attention-rollout map for a ViT-style model.

    The model must expose ``model.blocks``, each block having ``attn.qkv``
    (the fused query/key/value projection) and ``attn.num_heads``. A forward
    hook on every ``qkv`` layer rebuilds the softmax attention matrix from the
    projection output; the head-averaged matrices are then multiplied layer
    by layer.
    """

    def __init__(self, model, device='cuda'):
        self.model = model
        self.device = device
        # One (batch, heads, tokens, tokens) tensor per block, in forward order
        self.attention_layers = []
        # Handles of the hooks currently registered by this object
        self.hooks = []

    def register_hooks(self):
        """Attach an attention-capturing forward hook to every block's qkv layer."""
        def hook_fn(name):
            def hook(module, input, output):
                with torch.no_grad():
                    B, N, C = output.shape
                    num_heads = self.model.blocks[0].attn.num_heads
                    # C is 3 * num_heads * head_dim (q, k and v are fused)
                    head_dim = C // (3 * num_heads)

                    qkv = output.reshape(B, N, 3, num_heads, head_dim)
                    qkv = qkv.permute(2, 0, 3, 1, 4)
                    q, k, v = qkv[0], qkv[1], qkv[2]

                    # Scaled dot-product attention weights
                    attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
                    attn = torch.nn.functional.softmax(attn, dim=-1)
                    self.attention_layers.append(attn.detach())
            return hook

        for block in self.model.blocks:
            self.hooks.append(block.attn.qkv.register_forward_hook(hook_fn("attn")))

    def cleanup(self):
        """Remove every hook registered by this object. Safe to call repeatedly."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def get_attention_rollout(self, input_tensor):
        """Run the model once and return the rollout map as a 2-D numpy array.

        Args:
            input_tensor: Batched model input; only the first sample is used.

        Returns:
            Square ``(grid, grid)`` array scaled to [0, 1]. The first token is
            dropped before reshaping (presumably the class token - confirm for
            the model in use). A constant map is returned as all zeros.
        """
        self.attention_layers = []
        self.cleanup()
        self.register_hooks()

        try:
            with torch.no_grad():
                _ = self.model(input_tensor)
        finally:
            # Detach the hooks even when the forward pass fails, so they do
            # not stay on the shared model.
            self.cleanup()

        # Average over heads, then chain the layers by matrix product
        attention_maps = [attn.mean(dim=1) for attn in self.attention_layers]
        flat_attn = attention_maps[0]
        for attn in attention_maps[1:]:
            flat_attn = torch.matmul(attn, flat_attn)

        # First sample, without the first token; average over the key axis
        attention = flat_attn[0, 1:, 1:]
        attention = attention.mean(dim=-1)

        grid_size = int(math.sqrt(attention.shape[0]))
        attention = attention.reshape(grid_size, grid_size).cpu().numpy()

        # Min-max scale; skip the division for a constant map (0/0 -> NaN)
        attention = attention - attention.min()
        value_range = attention.max()
        if value_range > 0:
            attention = attention / value_range

        return attention

class VitBeyondAttention:
    """Compute a thresholded attention-rollout map for a ViT-style model.

    Like plain rollout, the softmax attention of every block is rebuilt from
    the output of ``block.attn.qkv`` and averaged over heads. Before the
    layers are multiplied together, the lowest ``discard_ratio`` share of the
    weights in each layer is set to zero; afterwards the identity is added and
    the rows are re-normalised.
    """

    def __init__(self, model, device='cuda', discard_ratio=0.9):
        self.model = model
        self.device = device
        # One (batch, heads, tokens, tokens) tensor per block, in forward order
        self.attention_layers = []
        # Handles of the hooks currently registered by this object
        self.hooks = []
        # Share of the smallest attention weights zeroed in every layer
        self.discard_ratio = discard_ratio

    def register_hooks(self):
        """Attach an attention-capturing forward hook to every block's qkv layer."""
        def hook_fn(name):
            def hook(module, input, output):
                with torch.no_grad():
                    B, N, C = output.shape
                    num_heads = self.model.blocks[0].attn.num_heads
                    # C is 3 * num_heads * head_dim (q, k and v are fused)
                    head_dim = C // (3 * num_heads)

                    qkv = output.reshape(B, N, 3, num_heads, head_dim)
                    qkv = qkv.permute(2, 0, 3, 1, 4)
                    q, k, v = qkv[0], qkv[1], qkv[2]

                    # Scaled dot-product attention weights
                    attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
                    attn = torch.nn.functional.softmax(attn, dim=-1)
                    self.attention_layers.append(attn.detach())
            return hook

        for block in self.model.blocks:
            self.hooks.append(block.attn.qkv.register_forward_hook(hook_fn("attn")))

    def cleanup(self):
        """Remove every hook registered by this object. Safe to call repeatedly."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def get_beyond_attention(self, input_tensor):
        """Run the model once and return the thresholded rollout map.

        Args:
            input_tensor: Batched model input; only the first sample is used.

        Returns:
            Square ``(grid, grid)`` numpy array scaled to [0, 1]. The first
            token is dropped before reshaping (presumably the class token -
            confirm for the model in use). A constant map is returned as all
            zeros.
        """
        self.attention_layers = []
        self.cleanup()
        self.register_hooks()

        try:
            with torch.no_grad():
                _ = self.model(input_tensor)
        finally:
            # Detach the hooks even when the forward pass fails, so they do
            # not stay on the shared model.
            self.cleanup()

        attention_maps = [attn.mean(dim=1) for attn in self.attention_layers]
        attention_maps = [F.relu(attn) for attn in attention_maps]

        # Zero the weakest weights of every layer (in place on our own copies)
        for attn in attention_maps:
            flat_attn = attn.reshape(-1)
            # Clamp the index so discard_ratio == 1.0 picks the largest value
            # rather than indexing one past the end.
            cutoff_index = min(int(flat_attn.shape[0] * self.discard_ratio),
                               flat_attn.shape[0] - 1)
            threshold = torch.sort(flat_attn)[0][cutoff_index]
            attn[attn < threshold] = 0

        # Chain the layers by matrix product
        flat_attn = attention_maps[0]
        for attn in attention_maps[1:]:
            flat_attn = torch.matmul(attn, flat_attn)

        # Add the identity, then make every row sum to one
        eye = torch.eye(flat_attn.shape[-1], device=flat_attn.device)
        flat_attn = flat_attn + eye
        flat_attn = flat_attn / flat_attn.sum(dim=-1, keepdim=True)

        # First sample, without the first token; average over the key axis
        attention = flat_attn[0, 1:, 1:]
        attention = attention.mean(dim=-1)

        grid_size = int(math.sqrt(attention.shape[0]))
        attention = attention.reshape(grid_size, grid_size).cpu().numpy()

        # Min-max scale; skip the division for a constant map (0/0 -> NaN)
        attention = attention - attention.min()
        value_range = attention.max()
        if value_range > 0:
            attention = attention / value_range

        return attention

def clear_hooks(model):
    """Drop every forward hook registered on ``model`` or any of its submodules.

    Only the forward-hook registries are emptied; other hook kinds are left
    untouched. Returns nothing.
    """
    for _, submodule in model.named_modules():
        if not hasattr(submodule, '_forward_hooks'):
            continue
        submodule._forward_hooks.clear()

def create_attention_overlay(image, attention_map, alpha=0.6):
    """Blend a brightened jet-coloured heat-map of ``attention_map`` over ``image``.

    ``alpha`` is the weight of the heat-map and ``1 - alpha`` the weight of
    the image; the result is clipped to [0, 1].
    """
    # Jet colours without the alpha channel, boosted by 20% and capped at 1
    jet_rgb = plt.cm.jet(attention_map)[..., :3]
    jet_rgb = (jet_rgb * 1.2).clip(0, 1)
    blended = image * (1 - alpha) + jet_rgb * alpha
    return np.clip(blended, 0, 1)

def visualize_combined_attention(model, image_path, training_order=None, discard_ratio=0.9):
    """
    Show the original image next to its attention-rollout and beyond-attention overlays.

    The image is resized to 224x224 and normalised with the ImageNet
    statistics; both attention maps are bilinearly upsampled to 224x224
    before being blended over the image. The prediction and its confidence
    are written underneath the three panels.

    Args:
        model: ViT model
        image_path: Path to input image
        training_order: Optional list of class names; when omitted (or empty)
            the raw class index is printed instead
        discard_ratio: Ratio for discarding low attention weights in beyond attention

    Returns:
        tuple: (rollout_map, beyond_map, prediction, confidence), the maps
        being the upsampled 224x224 arrays. On any error the traceback is
        printed and (None, None, None, None) is returned.
    """
    def _upsample(grid):
        # (h, w) numpy array -> (1, 1, h, w) tensor -> 224x224 numpy array
        as_batch = torch.tensor(grid).unsqueeze(0).unsqueeze(0)
        enlarged = F.interpolate(
            as_batch,
            size=(224, 224),
            mode='bilinear',
            align_corners=False
        )
        return enlarged.squeeze().numpy()

    try:
        run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(run_device)

        to_model_input = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        image = Image.open(image_path).convert('RGB')
        input_tensor = to_model_input(image).unsqueeze(0).to(run_device)

        # Class decision and its softmax confidence in percent
        with torch.no_grad():
            logits = model(input_tensor)
            class_probs = torch.nn.functional.softmax(logits, dim=1)
            prediction = torch.argmax(logits).item()
            confidence = class_probs[0][prediction].item() * 100

        # Each extractor starts from a model without forward hooks
        clear_hooks(model)
        rollout_extractor = VitAttentionRollout(model, device=run_device)
        rollout_map = rollout_extractor.get_attention_rollout(input_tensor)

        clear_hooks(model)
        beyond_extractor = VitBeyondAttention(model, device=run_device, discard_ratio=discard_ratio)
        beyond_map = beyond_extractor.get_beyond_attention(input_tensor)

        rollout_map = _upsample(rollout_map)
        beyond_map = _upsample(beyond_map)

        # Un-normalised image in [0, 1] as the background of both overlays
        img_np = np.array(image.resize((224, 224))) / 255.0

        panels = (
            (image, 'Original Image'),
            (create_attention_overlay(img_np, rollout_map), 'Attention Rollout'),
            (create_attention_overlay(img_np, beyond_map), 'Beyond Attention'),
        )

        # One row, three panels
        plt.figure(figsize=(12, 4))
        for position, (picture, caption) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.imshow(picture)
            plt.title(caption)
            plt.axis('off')

        # Prediction text centred at the bottom of the figure
        shown_label = training_order[prediction] if training_order else prediction
        plt.figtext(0.5, 0.02,
                   f"Prediction: {shown_label} (Confidence: {confidence:.1f}%)",
                   ha='center', va='center')

        plt.tight_layout()
        plt.show()

        return rollout_map, beyond_map, prediction, confidence

    except Exception as e:
        print(f"Error processing {image_path}: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None, None, None

In [None]:
# Remove forward hooks left on the ViT by earlier cells before capturing attention
clear_hooks(vit_model)

# Single-image demo: replace with the path of the image you want to inspect
img_path = '/content/drive/MyDrive/Colab Notebooks/Resized_Testing/glioma/Te-gl_0261.jpg'
# tumor_types is passed as the class-name list used to label the prediction.
# NOTE(review): other cells use training_order for this; presumably both list
# the classes in the same order - confirm.
rollout_map, beyond_map, pred, conf = visualize_combined_attention(vit_model, img_path, tumor_types)

In [None]:
def get_image_path(base_path, img_name):
    """Build the path of an image from its class-coded file name.

    The tumor class is inferred from a two-letter code contained in
    ``img_name`` and mapped to the subfolder holding that class.

    Args:
        base_path: Directory that contains one subfolder per tumor class.
        img_name: Image file name without the ``.jpg`` extension,
            e.g. ``'Te-gl_0261'``.

    Returns:
        The string ``"{base_path}/{subfolder}/{img_name}.jpg"``.

    Raises:
        ValueError: If ``img_name`` contains none of the known class codes.
            (Previously this case failed with an ``UnboundLocalError``.)
    """
    # Order matters: the first code found anywhere in the name wins, which
    # keeps the original precedence glioma > meningioma > pituitary.
    class_folders = (
        ('gl', 'glioma'),
        ('me', 'meningioma'),
        ('pi', 'pituitary'),
    )
    for code, subfolder in class_folders:
        if code in img_name:
            return f"{base_path}/{subfolder}/{img_name}.jpg"

    known_codes = ", ".join(repr(code) for code, _ in class_folders)
    raise ValueError(
        f"Cannot determine tumor class for image name {img_name!r}; "
        f"expected it to contain one of: {known_codes}"
    )

def visualize_batch_attention(model, image_names, base_path, training_order=None, discard_ratio=0.9):
    """
    Visualize attention for multiple images with smaller figure size.

    For every name in ``image_names`` the image is loaded, classified by
    ``model``, and shown in a three-panel figure: the original image, the
    attention-rollout overlay and the "beyond attention" overlay, with the
    prediction and its confidence printed underneath.

    Args:
        model: Network to explain; it is moved to CUDA when available,
            otherwise to the CPU.
        image_names: Iterable of image names without extension; each one is
            resolved to a file via ``get_image_path(base_path, name)``.
        base_path: Root directory passed to ``get_image_path``.
        training_order: Optional sequence mapping a predicted class index to
            a label. When falsy, the raw index is displayed instead.
        discard_ratio: Forwarded to ``VitBeyondAttention``.

    Returns:
        None. Figures are displayed with ``plt.show()``.

    Notes:
        A failure while processing one image is reported (message plus
        traceback) and the loop moves on to the next image. A failure while
        moving the model to the device is not caught.
    """
    # Loop-invariant setup: done once instead of once per image.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Resize to the 224x224 input size and normalize with the commonly used
    # ImageNet channel statistics.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    def resize_map(attention_map):
        # Bilinear interpolation needs a 4-D tensor, so the 2-D map gets two
        # leading singleton dimensions that are squeezed away afterwards.
        return F.interpolate(
            torch.tensor(attention_map).unsqueeze(0).unsqueeze(0),
            size=(224, 224),
            mode='bilinear',
            align_corners=False
        ).squeeze().numpy()

    for img_name in image_names:
        fig_local = None
        try:
            img_path = get_image_path(base_path, img_name)

            image = Image.open(img_path).convert('RGB')
            input_tensor = transform(image).unsqueeze(0).to(device)

            with torch.no_grad():
                output = model(input_tensor)
                probabilities = torch.nn.functional.softmax(output, dim=1)
                prediction = torch.argmax(output).item()
                confidence = probabilities[0][prediction].item() * 100

            # Each explainer is built after clearing the model's hooks.
            clear_hooks(model)
            attention_rollout = VitAttentionRollout(model, device=device)
            rollout_map = attention_rollout.get_attention_rollout(input_tensor)

            clear_hooks(model)
            beyond_attention = VitBeyondAttention(model, device=device, discard_ratio=discard_ratio)
            beyond_map = beyond_attention.get_beyond_attention(input_tensor)

            rollout_map = resize_map(rollout_map)
            beyond_map = resize_map(beyond_map)

            # Scale pixel values to [0, 1] before blending with the maps.
            img_np = np.array(image.resize((224, 224))) / 255.0
            rollout_overlay = create_attention_overlay(img_np, rollout_map)
            beyond_overlay = create_attention_overlay(img_np, beyond_map)

            # Reduced figure size from (12, 4) to (9, 3)
            fig_local = plt.figure(figsize=(9, 3))

            plt.subplot(131)
            plt.imshow(image)
            plt.title('Original')
            plt.axis('off')

            plt.subplot(132)
            plt.imshow(rollout_overlay)
            plt.title('Rollout')
            plt.axis('off')

            plt.subplot(133)
            plt.imshow(beyond_overlay)
            plt.title('Beyond')
            plt.axis('off')

            plt.figtext(0.5, 0.02,
                       f"{img_name}\nPred: {training_order[prediction] if training_order else prediction} ({confidence:.1f}%)",
                       ha='center', va='center')

            plt.tight_layout()
            plt.show()

        except Exception as e:
            # Best effort: report the failure and keep going with the
            # remaining images. The traceback matches the error reporting of
            # visualize_combined_attention.
            print(f"Error processing {img_name}: {str(e)}")
            import traceback
            traceback.print_exc()
        finally:
            # Close the figure even when plotting failed part-way, so that
            # figures do not pile up across a long batch.
            if fig_local is not None:
                plt.close(fig_local)


# Example run: visualize a hand-picked set of test scans
base_path = '/content/drive/MyDrive/Colab Notebooks/Resized_Testing'

# Names carry no extension; get_image_path() derives the class subfolder
# from the two-letter code in each name and appends '.jpg'.
images_to_process = [
    # glioma
    'Te-gl_0261', 'Te-gl_0037', 'Te-gl_0131',
    # meningioma
    'Te-me_0224', 'Te-me_0231', 'Te-me_0162',
    # pituitary
    'Te-pi_0205', 'Te-pi_0132', 'Te-pi_0110',
]

# Drop hooks left on the model by earlier cells before starting
clear_hooks(vit_model)

# Render one three-panel figure per image
visualize_batch_attention(
    vit_model,
    images_to_process,
    base_path,
    training_order=tumor_types,
)