<a href="https://colab.research.google.com/github/shreyas21004/mri/blob/main/mri_comparison.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Downloading the dataset

In [None]:
# Install the Kaggle API token: kaggle.json must be in the current working
# directory; it is copied to ~/.kaggle and made readable by the owner only.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Download the Brain Tumor MRI Dataset and extract it in place. The next cell
# reads the "Training" and "Testing" directories relative to the working dir.
# -o overwrites existing files without prompting, so re-running this cell does
# not block on an interactive "replace?" question; -q suppresses the per-file
# listing (thousands of lines) in the cell output.
!kaggle datasets download -d masoudnickparvar/brain-tumor-mri-dataset
!unzip -o -q brain-tumor-mri-dataset.zip

# Data loading and preprocessing

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import AutoImageProcessor, DefaultDataCollator
from PIL import Image
import os

# Image processor shipped with the Swin checkpoint; only its normalisation
# statistics (image_mean / image_std) are used in this cell.
processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")


def _make_train_transform(mean, std):
    """Return the shared training augmentation pipeline, normalised with *mean*/*std*.

    Resize to 224x224, random horizontal flip, random rotation (±15°), colour
    jitter, tensor conversion, then per-channel normalisation.
    """
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])


# CNN models (ResNet, EfficientNet, DenseNet) use ImageNet normalisation statistics.
cnn_transform = _make_train_transform([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Swin uses the normalisation statistics published with its image processor.
swin_transform = _make_train_transform(processor.image_mean, processor.image_std)

class BrainTumorDataset(Dataset):
    """Folder-per-class dataset for the four-class brain tumour MRI images.

    ``root_dir`` is expected to contain one sub-directory per class
    (``glioma``, ``meningioma``, ``notumor``, ``pituitary``), each holding
    image files. Directory and file listings are sorted so that the order of
    ``image_paths`` / ``labels`` is reproducible across machines. Hidden
    entries (e.g. ``.DS_Store``) and anything that is not a class directory
    (at the top level) or a regular file (inside a class) are ignored.

    Args:
        root_dir: Path of a split directory, e.g. ``"Training"``.
        transform: Optional callable applied to each RGB PIL image.
        model_type: ``"swin"`` makes ``__getitem__`` return a dict with
            ``pixel_values`` / ``labels`` keys; any other value returns an
            ``(image, label)`` tuple.

    Raises:
        ValueError: if ``root_dir`` contains a sub-directory that is not one
            of the known class names.
    """

    def __init__(self, root_dir, transform=None, model_type="cnn"):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.model_type = model_type
        self.class_to_idx = {"glioma": 0, "meningioma": 1, "notumor": 2, "pituitary": 3}

        for label in sorted(os.listdir(root_dir)):
            label_dir = os.path.join(root_dir, label)
            # Skip hidden entries and loose files (e.g. .DS_Store) next to the class folders.
            if label.startswith(".") or not os.path.isdir(label_dir):
                continue
            if label not in self.class_to_idx:
                raise ValueError(
                    f"Unexpected class directory {label!r} in {root_dir!r}; "
                    f"expected one of {sorted(self.class_to_idx)}"
                )
            class_idx = self.class_to_idx[label]
            for img_name in sorted(os.listdir(label_dir)):
                img_path = os.path.join(label_dir, img_name)
                # Only regular, non-hidden files are treated as images.
                if img_name.startswith(".") or not os.path.isfile(img_path):
                    continue
                self.image_paths.append(img_path)
                self.labels.append(class_idx)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # The context manager releases the file handle once the pixels are decoded;
        # convert() returns a new image that no longer depends on the file.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        # Swin (HF) consumes dict batches; the CNN loop unpacks (inputs, labels) tuples.
        if self.model_type == "swin":
            return {"pixel_values": image, "labels": torch.tensor(label)}
        return image, torch.tensor(label)

# Deterministic evaluation-time transforms. The test split must not be randomly
# flipped / rotated / colour-jittered like the training split: otherwise the
# reported test metrics are measured on perturbed images and vary run to run.
cnn_eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization
])
swin_eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),  # Swin normalization
])

# Create datasets: augmented transforms for training, deterministic ones for testing.
train_dataset_cnn = BrainTumorDataset(root_dir="Training", transform=cnn_transform, model_type="cnn")
test_dataset_cnn = BrainTumorDataset(root_dir="Testing", transform=cnn_eval_transform, model_type="cnn")

train_dataset_swin = BrainTumorDataset(root_dir="Training", transform=swin_transform, model_type="swin")
test_dataset_swin = BrainTumorDataset(root_dir="Testing", transform=swin_eval_transform, model_type="swin")

# Swin datasets yield dicts; DefaultDataCollator stacks them into a dict of batched tensors.
data_collator = DefaultDataCollator(return_tensors="pt")

# Create DataLoaders (only the training loaders shuffle).
train_loader_cnn = DataLoader(train_dataset_cnn, batch_size=8, shuffle=True)
test_loader_cnn = DataLoader(test_dataset_cnn, batch_size=8, shuffle=False)

train_loader_swin = DataLoader(train_dataset_swin, batch_size=8, shuffle=True, collate_fn=data_collator)
test_loader_swin = DataLoader(test_dataset_swin, batch_size=8, shuffle=False, collate_fn=data_collator)

# Print dataset size
print(f"Training Samples: {len(train_dataset_cnn)}, Testing Samples: {len(test_dataset_cnn)}")



# Training and evaluation

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from transformers import SwinForImageClassification, SwinConfig
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                            f1_score, roc_auc_score, confusion_matrix,
                            classification_report)
import numpy as np
import time
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import pandas as pd
import os

# Output folders for checkpoints, plots and error dumps (Kaggle-style paths);
# created up front so later torch.save / savefig calls find them.
for _output_dir in (
    "/kaggle/working/results",
    "/kaggle/working/plots",
    "/kaggle/working/saved_models",
):
    os.makedirs(_output_dir, exist_ok=True)

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

def initialize_models(num_classes=4):
    """Build the four compared classifiers with extra dropout regularisation.

    Every model is created from pretrained weights, gets a ``num_classes``-way
    head and is moved to the module-level ``device``:

    * ``"ResNet50"``: Dropout + an extra BatchNorm after ``layer1`` and
      ``layer2``; head is ``Dropout(0.6) -> Linear``.
    * ``"EfficientNetB0"``: Dropout appended after ``features``; head is
      ``Dropout(0.5) -> Linear``.
    * ``"DenseNet121"``: Dropout after ``transition1`` / ``transition2``;
      head is ``Dropout(0.5) -> Linear``.
    * ``"SwinTransformer"``: attention/hidden dropout and stochastic depth
      set through ``SwinConfig``. If the loaded classifier is a plain
      ``nn.Linear``, it is wrapped as ``Dropout(0.5) -> Linear``.

    Args:
        num_classes: Number of output classes for every head.

    Returns:
        dict mapping model name to the initialised ``nn.Module``.

    Raises:
        Any exception raised while building a model is printed and re-raised.
    """
    models_dict = {}

    try:
        # ResNet50
        resnet = models.resnet50(pretrained=True)
        resnet.layer1 = nn.Sequential(
            resnet.layer1,
            nn.Dropout(0.3),
            nn.BatchNorm2d(256)  # layer1 outputs 256 channels
        )
        resnet.layer2 = nn.Sequential(
            resnet.layer2,
            nn.Dropout(0.4),
            nn.BatchNorm2d(512)  # layer2 outputs 512 channels
        )
        resnet.fc = nn.Sequential(
            nn.Dropout(0.6),
            nn.Linear(resnet.fc.in_features, num_classes)
        )
        models_dict["ResNet50"] = resnet.to(device)

        # EfficientNetB0 (the RHS reads the *original* classifier[1] before it is replaced)
        efficientnet = models.efficientnet_b0(pretrained=True)
        efficientnet.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(efficientnet.classifier[1].in_features, num_classes)
        )
        efficientnet.features = nn.Sequential(
            efficientnet.features,
            nn.Dropout(0.2)
        )
        models_dict["EfficientNetB0"] = efficientnet.to(device)

        # DenseNet121
        densenet = models.densenet121(pretrained=True)
        densenet.features.transition1 = nn.Sequential(
            densenet.features.transition1,
            nn.Dropout(0.3)
        )
        densenet.features.transition2 = nn.Sequential(
            densenet.features.transition2,
            nn.Dropout(0.3)
        )
        densenet.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(densenet.classifier.in_features, num_classes)
        )
        models_dict["DenseNet121"] = densenet.to(device)

        # Swin Transformer
        config = SwinConfig.from_pretrained(
            "microsoft/swin-tiny-patch4-window7-224",
            num_labels=num_classes,
            attention_probs_dropout_prob=0.3,
            hidden_dropout_prob=0.3,
            # SwinConfig's stochastic-depth setting is ``drop_path_rate``; the
            # previous ``path_dropout`` kwarg is not a Swin config field and
            # therefore had no effect on the model.
            drop_path_rate=0.2
        )
        # ignore_mismatched_sizes: the checkpoint's 1000-class head does not fit num_classes.
        swin = SwinForImageClassification.from_pretrained(
            "microsoft/swin-tiny-patch4-window7-224",
            config=config,
            ignore_mismatched_sizes=True
        )
        if hasattr(swin, 'classifier'):
            if isinstance(swin.classifier, nn.Linear):
                swin.classifier = nn.Sequential(
                    nn.Dropout(0.5),
                    nn.Linear(swin.classifier.in_features, swin.classifier.out_features)
                )
        models_dict["SwinTransformer"] = swin.to(device)

    except Exception as e:
        print(f"Error initializing models: {str(e)}")
        raise

    return models_dict

def save_model_on_error(model, model_name, epoch, error):
    """Dump *model*'s current weights after a failure, then report where they went.

    The file name embeds *model_name* and the *epoch* value supplied by the
    caller, under ``/kaggle/working/saved_models``.
    """
    checkpoint_path = f"/kaggle/working/saved_models/error_{model_name}_epoch{epoch}.pth"
    weights = model.state_dict()
    torch.save(weights, checkpoint_path)
    print(f"Model saved at {checkpoint_path} due to error: {str(error)}")

def train_model(model, train_loader, val_loader, model_name, epochs=15):
    """Train *model* with per-epoch validation, checkpointing and early stopping.

    Optimisation:

    * ``"SwinTransformer"``: AdamW with backbone lr 2e-5 / wd 0.1 and head
      lr 1e-4 / wd 0.01. A OneCycleLR schedule peaks at those per-group
      rates and is stepped once per *batch*, as its
      ``steps_per_epoch * epochs`` budget requires.
    * Every other model: AdamW with backbone wd 1e-4 at AdamW's default lr
      (1e-3) and head lr 1e-4 / wd 0.01. CosineAnnealingLR (``T_max=epochs``)
      is stepped once per epoch.

    Gradients are clipped to norm 1.0. After each epoch the model is scored
    on *val_loader*. The weights with the best validation accuracy are saved
    to ``/kaggle/working/results/best_{model_name}.pth``; a checkpoint is
    always written after the first completed epoch. Training stops after 3
    epochs without a >0.001 (percentage-point) improvement, and the best
    weights are reloaded.

    NOTE: *val_loader* drives model selection, so it should be a held-out
    split distinct from the final test set.

    Any exception during an epoch dumps the current weights via
    ``save_model_on_error`` and ends training early (it is not re-raised).

    Args:
        model: The network to train (already on ``device``).
        train_loader: Training batches; dicts for Swin, ``(inputs, labels)`` otherwise.
        val_loader: Validation batches in the same format.
        model_name: Selects the Swin vs CNN code path and names the checkpoint.
        epochs: Maximum number of epochs.

    Returns:
        dict with per-epoch lists ``train_loss``, ``val_loss``, ``train_acc``, ``val_acc``.
    """
    is_swin = model_name == "SwinTransformer"
    checkpoint_path = f"/kaggle/working/results/best_{model_name}.pth"

    def _is_head(param_name):
        # Match the head by its top-level module name. A plain substring test
        # on 'fc' would also catch EfficientNet's squeeze-excitation convs
        # ('...block.N.fc1.weight'), moving them into the head group.
        return param_name.startswith(("fc.", "classifier."))

    backbone_params = [p for n, p in model.named_parameters() if not _is_head(n)]
    head_params = [p for n, p in model.named_parameters() if _is_head(n)]

    if is_swin:
        optimizer = optim.AdamW([
            {'params': backbone_params, 'weight_decay': 0.1, 'lr': 2e-5},
            {'params': head_params, 'weight_decay': 0.01, 'lr': 1e-4}
        ])
        # A per-group max_lr keeps the backbone/head split. A scalar max_lr
        # would push both groups to 1e-4 and discard the lower backbone rate.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=[2e-5, 1e-4], steps_per_epoch=len(train_loader), epochs=epochs,
            pct_start=0.1, anneal_strategy='cos', div_factor=10, final_div_factor=1e4
        )
        step_scheduler_per_batch = True
    else:
        # NOTE(review): the backbone group uses AdamW's default lr (1e-3), which is
        # higher than the head's 1e-4 (the reverse of the Swin setup). Confirm intended.
        optimizer = optim.AdamW([
            {'params': backbone_params, 'weight_decay': 1e-4},
            {'params': head_params, 'weight_decay': 0.01, 'lr': 1e-4}
        ])
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
        step_scheduler_per_batch = False

    criterion = nn.CrossEntropyLoss()
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    # -inf guarantees the first finished epoch writes a checkpoint, so the
    # torch.load on early stopping (and in callers) always finds a file.
    best_val_acc = float('-inf')
    patience = 3
    patience_counter = 0

    for epoch in range(epochs):
        try:
            model.train()
            running_loss = 0.0
            correct = 0
            total = 0

            train_loop = tqdm(train_loader, desc=f'Epoch {epoch + 1}/{epochs} [Train]')
            for batch in train_loop:
                if is_swin:
                    inputs = batch['pixel_values'].to(device)
                    labels = batch['labels'].to(device)
                    outputs = model(inputs, labels=labels)
                    loss = outputs.loss
                    logits = outputs.logits
                else:
                    inputs, labels = batch
                    inputs, labels = inputs.to(device), labels.to(device)
                    logits = model(inputs)
                    loss = criterion(logits, labels)

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                if step_scheduler_per_batch:
                    scheduler.step()

                running_loss += loss.item()
                _, predicted = torch.max(logits.detach(), 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                train_loop.set_postfix(loss=loss.item(), acc=100. * correct / total)

            train_loss = running_loss / len(train_loader)
            train_acc = 100. * correct / total
            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)

            val_loss, val_acc = evaluate_model(model, val_loader, model_name)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)

            if not step_scheduler_per_batch:
                scheduler.step()

            # Log before the early-stopping check so the final epoch is reported too.
            print(f'Epoch {epoch + 1}/{epochs} - '
                  f'Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f}% | '
                  f'Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}% | '
                  f'LR: {optimizer.param_groups[0]["lr"]:.2e}')

            if val_acc > best_val_acc + 0.001:
                best_val_acc = val_acc
                patience_counter = 0
                torch.save(model.state_dict(), checkpoint_path)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"\nEarly stopping triggered at epoch {epoch + 1}")
                    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
                    break

        except Exception as e:
            # Single handler: dump the weights once and stop training for this model.
            save_model_on_error(model, model_name, epoch, e)
            print(f"Training interrupted at epoch {epoch + 1} due to: {str(e)}")
            break

    return history


def evaluate_model(model, loader, model_name):
    """Compute the average loss and accuracy of *model* over *loader*.

    Runs in ``eval()`` mode under ``torch.no_grad()``. For
    ``"SwinTransformer"`` the batches are dicts (``pixel_values`` / ``labels``)
    and the loss returned by the HF model is used. Every other model receives
    ``(inputs, labels)`` batches scored with ``nn.CrossEntropyLoss``.

    Batches that raise are reported and skipped. Averages are taken over the
    samples that were actually evaluated, so skipped batches no longer
    deflate the loss.

    Args:
        model: Network to evaluate (already on ``device``).
        loader: Iterable of batches in the format described above.
        model_name: Selects the Swin vs CNN code path.

    Returns:
        ``(loss, accuracy)``: per-sample mean loss and accuracy in percent.

    Raises:
        RuntimeError: if no sample could be evaluated (empty loader or every
            batch failed).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    loss_sum = 0.0
    correct = 0
    total = 0
    skipped_batches = 0

    with torch.no_grad():
        for batch in loader:
            try:
                if model_name == "SwinTransformer":
                    inputs = batch['pixel_values'].to(device)
                    labels = batch['labels'].to(device)
                    outputs = model(inputs, labels=labels)
                    loss = outputs.loss
                    logits = outputs.logits
                else:
                    inputs, labels = batch
                    inputs, labels = inputs.to(device), labels.to(device)
                    logits = model(inputs)
                    loss = criterion(logits, labels)

                batch_size = labels.size(0)
                # Weight each batch mean by its size so a short final batch does not
                # skew the average (assumes a mean-reduced loss, as nn.CrossEntropyLoss
                # is; presumably the HF Swin loss too — verify if that head changes).
                loss_sum += loss.item() * batch_size
                _, predicted = torch.max(logits, 1)
                total += batch_size
                correct += (predicted == labels).sum().item()

            except Exception as e:
                skipped_batches += 1
                print(f"Error during evaluation batch: {str(e)}")

    if total == 0:
        raise RuntimeError(f"No samples could be evaluated for {model_name}")
    if skipped_batches:
        print(f"Warning: skipped {skipped_batches} failed batch(es); metrics cover {total} samples")

    val_loss = loss_sum / total
    val_acc = 100. * correct / total
    return val_loss, val_acc

def comprehensive_evaluation(model, loader, model_name, class_names):
    """Collect predictions over *loader* and compute the full metric set.

    Label ``i`` is assumed to correspond to ``class_names[i]``. The confusion
    matrix and classification report are always computed over every class
    index in ``range(len(class_names))``. A class that never appears in the
    labels or predictions therefore still gets its row/column, instead of
    shrinking the matrix or making ``target_names`` raise.

    Batches that raise are reported and skipped. Each metric that fails is
    printed and replaced with a placeholder (0 / zero matrix / zero report),
    so the caller always receives every key.

    Args:
        model: Network to evaluate (already on ``device``).
        loader: Dict batches for ``"SwinTransformer"``, ``(inputs, labels)`` otherwise.
        model_name: Selects the Swin vs CNN code path.
        class_names: Display names, indexed by class id.

    Returns:
        dict with ``accuracy``, ``precision``, ``recall``, ``f1_score``
        (weighted averages), ``roc_auc`` (weighted one-vs-rest on softmax
        probabilities), ``confusion_matrix`` (int ndarray of shape
        ``(len(class_names), len(class_names))``) and
        ``classification_report`` (``output_dict=True`` form keyed by class name).
    """
    model.eval()
    all_labels = []
    all_preds = []
    all_probs = []

    with torch.no_grad():
        for batch in loader:
            try:
                if model_name == "SwinTransformer":
                    inputs = batch['pixel_values'].to(device)
                    labels = batch['labels'].to(device)
                    outputs = model(inputs)
                    if not hasattr(outputs, 'logits'):
                        raise ValueError("SwinTransformer output missing logits")
                    logits = outputs.logits
                else:
                    inputs, labels = batch
                    inputs, labels = inputs.to(device), labels.to(device)
                    logits = model(inputs)

                probs = torch.nn.functional.softmax(logits, dim=1)
                _, preds = torch.max(logits, 1)

                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())

            except Exception as e:
                print(f"Error during evaluation batch: {str(e)}")
                continue

    num_classes = len(class_names)
    label_ids = list(range(num_classes))

    metrics = {}
    try:
        metrics['accuracy'] = accuracy_score(all_labels, all_preds)
        # zero_division=0 matches sklearn's default value but without the warning spam.
        metrics['precision'] = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
        metrics['recall'] = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
        metrics['f1_score'] = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    except Exception as e:
        print(f"Error calculating basic metrics: {str(e)}")
        metrics.update({'accuracy': 0, 'precision': 0, 'recall': 0, 'f1_score': 0})

    try:
        metrics['roc_auc'] = roc_auc_score(all_labels, np.asarray(all_probs),
                                           multi_class='ovr', average='weighted')
    except Exception as e:
        print(f"Error calculating ROC AUC: {str(e)}")
        metrics['roc_auc'] = 0.0

    try:
        metrics['confusion_matrix'] = confusion_matrix(all_labels, all_preds, labels=label_ids)
    except Exception as e:
        print(f"Error calculating confusion matrix: {str(e)}")
        # Integer dtype so the placeholder can still be rendered with an integer ('d') format.
        metrics['confusion_matrix'] = np.zeros((num_classes, num_classes), dtype=int)

    try:
        metrics['classification_report'] = classification_report(
            all_labels, all_preds, labels=label_ids, target_names=class_names,
            output_dict=True, zero_division=0)
    except Exception as e:
        print(f"Error generating classification report: {str(e)}")
        metrics['classification_report'] = {name: {'precision': 0, 'recall': 0, 'f1-score': 0}
                                            for name in class_names}

    return metrics

def plot_history(history, model_name):
    """Save side-by-side loss and accuracy curves for one training run.

    *history* must provide ``train_loss``, ``val_loss``, ``train_acc`` and
    ``val_acc`` sequences. The figure is written to
    ``/kaggle/working/plots/{model_name}_training_history.png`` and closed.
    """
    # (title suffix, train key, train legend, val key, val legend, y-axis label)
    panels = (
        ('Loss', 'train_loss', 'Train Loss', 'val_loss', 'Validation Loss', 'Loss'),
        ('Accuracy', 'train_acc', 'Train Accuracy', 'val_acc', 'Validation Accuracy', 'Accuracy (%)'),
    )

    plt.figure(figsize=(12, 5))
    for position, (suffix, train_key, train_label, val_key, val_label, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(history[train_key], label=train_label)
        plt.plot(history[val_key], label=val_label)
        plt.title(f'{model_name} - {suffix}')
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    output_path = f'/kaggle/working/plots/{model_name}_training_history.png'
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()

def plot_confusion_matrix(cm, class_names, model_name):
    """Render *cm* as an annotated heatmap and save it under the plots directory.

    Rows are labelled as true classes and columns as predicted classes, both
    tick-labelled with *class_names*. Cells are annotated with ``fmt='d'``,
    so *cm* must hold integer counts.
    """
    figure = plt.figure(figsize=(8, 6))
    axes = sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                       xticklabels=class_names, yticklabels=class_names)
    axes.set_title(f'{model_name} - Confusion Matrix')
    axes.set_ylabel('True Label')
    axes.set_xlabel('Predicted Label')

    output_path = f'/kaggle/working/plots/{model_name}_confusion_matrix.png'
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(figure)

def verify_dataloaders(cnn_loader, swin_loader):
    """Sanity-check the batch format of the CNN and Swin DataLoaders.

    Pulls the first batch from each loader. The CNN loader must yield a
    2-element tuple/list ``(inputs, labels)``. The Swin loader must yield a
    dict containing ``'pixel_values'`` and ``'labels'``.

    Raises:
        ValueError: if either loader is empty or yields batches of the wrong
            format (the message is printed before re-raising).
    """
    try:
        cnn_batch = next(iter(cnn_loader), None)
        if cnn_batch is None:
            raise ValueError("CNN DataLoader is empty")
        # Reject non-sequences AND sequences of the wrong length. The previous
        # `not (isinstance(...) or len(...) != 2)` accepted e.g. 3-tuples.
        if not isinstance(cnn_batch, (tuple, list)) or len(cnn_batch) != 2:
            raise ValueError("CNN DataLoader should return (inputs, labels) tuples")

        swin_batch = next(iter(swin_loader), None)
        if swin_batch is None:
            raise ValueError("Swin DataLoader is empty")
        if not isinstance(swin_batch, dict) or 'pixel_values' not in swin_batch or 'labels' not in swin_batch:
            raise ValueError("Swin DataLoader should return dicts with 'pixel_values' and 'labels'")

        print("DataLoader verification passed successfully")
    except Exception as e:
        print(f"DataLoader verification failed: {str(e)}")
        raise

def generate_comparison_table(results, class_names, output_dir="/kaggle/working/results"):
    """Print a model comparison table and export it as CSV (and LaTeX when possible).

    Args:
        results: Mapping of model name to a metrics dict with ``accuracy``,
            ``precision``, ``recall``, ``f1_score``, ``roc_auc``,
            ``total_params``, ``trainable_params``, ``training_time`` (seconds)
            and ``classification_report`` (sklearn ``output_dict`` form keyed
            by class name).
        class_names: Class names whose per-class precision/recall/F1 become columns.
        output_dir: Directory for ``model_comparison.csv`` / ``.tex``; created
            if missing. Defaults to the original Kaggle results path.

    Returns:
        pandas DataFrame with one row per model and a fixed column order
        (an empty frame with those columns when *results* is empty).
    """
    print("\n\nCOMPREHENSIVE MODEL COMPARISON")
    print("=" * 120)
    # Header and rows share the same column widths so the table lines up.
    print(f"{'Model':<20} | {'Accuracy':<10} | {'F1-Score':<10} | {'ROC AUC':<10} | "
          f"{'Params (M)':<10} | {'Time (min)':<10}")
    print("-" * 120)

    for model_name, metrics in results.items():
        print(f"{model_name:<20} | {metrics['accuracy']:<10.4f} | {metrics['f1_score']:<10.4f} | "
              f"{metrics['roc_auc']:<10.4f} | {metrics['total_params'] / 1e6:<10.2f} | "
              f"{metrics['training_time'] / 60:<10.2f}")

    column_order = ['Model', 'Accuracy', 'Precision', 'Recall', 'F1-Score', 'ROC AUC',
                    'Total Params (M)', 'Trainable Params (M)', 'Training Time (min)']
    for class_name in class_names:
        column_order.extend([f'{class_name} Precision', f'{class_name} Recall', f'{class_name} F1'])

    detailed_results = []
    for model_name, metrics in results.items():
        row = {
            'Model': model_name,
            'Accuracy': metrics['accuracy'],
            'Precision': metrics['precision'],
            'Recall': metrics['recall'],
            'F1-Score': metrics['f1_score'],
            'ROC AUC': metrics['roc_auc'],
            'Total Params (M)': metrics['total_params'] / 1e6,
            'Trainable Params (M)': metrics['trainable_params'] / 1e6,
            'Training Time (min)': metrics['training_time'] / 60
        }
        report = metrics['classification_report']
        for class_name in class_names:
            class_report = report[class_name]
            row[f'{class_name} Precision'] = class_report['precision']
            row[f'{class_name} Recall'] = class_report['recall']
            row[f'{class_name} F1'] = class_report['f1-score']
        detailed_results.append(row)

    # Passing `columns` fixes the order and also works for an empty result set.
    df_results = pd.DataFrame(detailed_results, columns=column_order)

    os.makedirs(output_dir, exist_ok=True)
    csv_path = os.path.join(output_dir, 'model_comparison.csv')
    df_results.to_csv(csv_path, index=False)
    print(f"\nDetailed results saved to '{csv_path}'")

    tex_path = os.path.join(output_dir, 'model_comparison.tex')
    try:
        latex_table = df_results.to_latex(index=False, float_format="%.4f",
                                          caption="Model Performance Comparison",
                                          label="tab:model_comparison")
    except ImportError as e:
        # pandas >= 2.0 renders LaTeX through Styler, which needs the optional
        # jinja2 package. The CSV is already written, so treat LaTeX as optional.
        print(f"LaTeX export skipped: {e}")
    else:
        with open(tex_path, 'w', encoding='utf-8') as f:
            f.write(latex_table)
        print(f"LaTeX table saved to '{tex_path}'")

    return df_results

if __name__ == "__main__":
    import traceback

    try:
        # Build every model up front; initialize_models prints and re-raises on failure.
        models_dict = initialize_models()
        class_names = ["glioma", "meningioma", "notumor", "pituitary"]

        # The DataLoaders come from the preprocessing cell. Fail fast if they are
        # missing (NameError) or yield the wrong batch format (ValueError).
        verify_dataloaders(train_loader_cnn, train_loader_swin)

        # NOTE(review): the test loaders are passed to train_model as the validation
        # set, so they drive early stopping and best-checkpoint selection. The test
        # metrics below are therefore optimistically biased; a validation split
        # carved out of Training/ would give an unbiased estimate.
        loaders_dict = {
            "ResNet50": (train_loader_cnn, test_loader_cnn),
            "EfficientNetB0": (train_loader_cnn, test_loader_cnn),
            "DenseNet121": (train_loader_cnn, test_loader_cnn),
            "SwinTransformer": (train_loader_swin, test_loader_swin)
        }

        # Train and evaluate models; a failure in one model does not stop the others.
        results = {}
        for model_name, model in models_dict.items():
            print(f"\n{'='*50}")
            print(f"Training {model_name}")
            print(f"{'='*50}")

            try:
                train_loader, test_loader = loaders_dict[model_name]

                start_time = time.time()
                history = train_model(model, train_loader, test_loader, model_name, epochs=20)
                training_time = time.time() - start_time

                plot_history(history, model_name)

                # train_model can stop before writing a checkpoint (e.g. an error in
                # the first epoch); evaluate the in-memory weights instead of crashing.
                best_path = f"/kaggle/working/results/best_{model_name}.pth"
                if os.path.exists(best_path):
                    model.load_state_dict(torch.load(best_path, map_location=device))
                else:
                    print(f"No checkpoint at {best_path}; evaluating the final weights of {model_name}")

                eval_results = comprehensive_evaluation(model, test_loader, model_name, class_names)
                eval_results['training_time'] = training_time
                eval_results['total_params'] = sum(p.numel() for p in model.parameters())
                eval_results['trainable_params'] = sum(p.numel() for p in model.parameters() if p.requires_grad)

                results[model_name] = eval_results
                plot_confusion_matrix(eval_results['confusion_matrix'], class_names, model_name)

            except Exception as e:
                print(f"Error processing {model_name}: {str(e)}")
                traceback.print_exc()
                continue

        # Generate comparison table and save results
        if results:
            comparison_df = generate_comparison_table(results, class_names)

            # Print class-wise comparison
            print("\n\nCLASS-WISE PERFORMANCE COMPARISON")
            print("="*90)
            for class_name in class_names:
                print(f"\n{class_name}:")
                print(f"{'Model':<20} | {'Precision':<10} | {'Recall':<10} | {'F1-Score':<10}")
                print("-"*50)
                for model_name, metrics in results.items():
                    class_report = metrics['classification_report'][class_name]
                    print(f"{model_name:<20} | {class_report['precision']:<10.4f} | "
                          f"{class_report['recall']:<10.4f} | "
                          f"{class_report['f1-score']:<10.4f}")
        else:
            print("\nNo models were successfully trained. Check error messages above.")

    except Exception as e:
        print(f"\nFatal error in main execution: {str(e)}")
        traceback.print_exc()