# Model Comparison: HSANet vs ViT, Swin, ResNet, VGG, EfficientNet

This notebook trains and evaluates multiple models on the Brain Tumor MRI Dataset and generates publication-ready comparison figures.

**Models compared:**
- ViT-B/16 (Vision Transformer)
- Swin-Tiny (Swin Transformer)
- ResNet-50
- VGG-16
- EfficientNet-B3 (Baseline)
- HSANet (our model; its metrics are entered manually from a separate training run rather than recomputed in this notebook)

## 1. Install Dependencies

In [None]:
!pip install timm torch torchvision matplotlib seaborn scikit-learn pandas -q

## 2. Imports and Setup

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import timm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
from sklearn.preprocessing import label_binarize
import pandas as pd
import time
import json
from pathlib import Path
from math import pi

# Create the output directory that every figure/table cell below writes into
Path("comparison_figures").mkdir(exist_ok=True)

# Seed the global NumPy and PyTorch generators so that a fresh
# "Restart & Run All" starts every run from the same random state.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 3. Data Loading

In [None]:
# Dataset location - UPDATE PATH FOR YOUR KAGGLE
DATA_DIR = Path('/kaggle/input/brain-tumor-mri-dataset')
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 32
NUM_WORKERS = 4

# Channel statistics shared by the train and test pipelines.
# NOTE(review): some timm checkpoints are pretrained with a different mean/std
# (see model.pretrained_cfg) - confirm these values suit every backbone.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Data transforms: augmentation (flip + rotation) is applied to training images only
train_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD)
])

test_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD)
])

# Load dataset: one sub-folder per class under Training/ and Testing/
train_dataset = datasets.ImageFolder(str(DATA_DIR / 'Training'), transform=train_transform)
test_dataset = datasets.ImageFolder(str(DATA_DIR / 'Testing'), transform=test_transform)

# Label ids come from each split's folder names, so both splits must expose the
# same classes; otherwise test labels would not line up with the trained head.
if train_dataset.classes != test_dataset.classes:
    raise ValueError(
        f"Training/Testing class folders differ: {train_dataset.classes} vs {test_dataset.classes}"
    )

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

class_names = train_dataset.classes
num_classes = len(class_names)
print(f"Classes: {class_names}")
print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

## 4. Model Definitions

In [None]:
def get_model(model_name, num_classes=4):
    """Create one of the benchmarked timm architectures with pretrained weights.

    Args:
        model_name: timm identifier; must be one of the five names listed below.
        num_classes: forwarded to ``timm.create_model`` as ``num_classes``.

    Returns:
        The model returned by ``timm.create_model(..., pretrained=True)``.

    Raises:
        ValueError: if ``model_name`` is not one of the supported identifiers.
    """
    supported = (
        'vit_base_patch16_224',
        'swin_tiny_patch4_window7_224',
        'resnet50',
        'vgg16',
        'efficientnet_b3',
    )
    # Reject anything outside the comparison set before touching timm
    if model_name not in supported:
        raise ValueError(f"Unknown model: {model_name}")
    return timm.create_model(model_name, pretrained=True, num_classes=num_classes)

# Display name (used in logs, figures and tables) -> timm identifier passed to get_model.
# Insertion order matters: the plotting cells pick colours by position.
MODELS = dict([
    ('ViT-B/16', 'vit_base_patch16_224'),
    ('Swin-Tiny', 'swin_tiny_patch4_window7_224'),
    ('ResNet-50', 'resnet50'),
    ('VGG-16', 'vgg16'),
    ('EfficientNet-B3', 'efficientnet_b3'),
])

print(f"Models to train: {list(MODELS)}")

## 5. Training Function

In [None]:
def train_and_evaluate(model_name, timm_name, epochs=15):
    """Fine-tune one model on the training split and score it on the test split.

    Uses the notebook-level globals ``train_loader``, ``test_loader``,
    ``test_dataset``, ``class_names``, ``num_classes`` and ``device``.

    Args:
        model_name: display name stored in the results and printed in the logs.
        timm_name: identifier handed to ``get_model``.
        epochs: number of training epochs; also the cosine-annealing period.

    Returns:
        dict with the keys
        ``model``, ``accuracy`` (%), ``f1_macro`` (%), ``params_m`` (parameter
        count / 1e6), ``inference_ms`` (forward-pass milliseconds per test image),
        ``confusion_matrix``, ``fpr`` / ``tpr`` / ``roc_auc`` (dicts keyed by class
        index, one-vs-rest), ``train_losses`` / ``train_accs`` (one entry per
        epoch) and ``per_class_f1`` (class name -> F1 in %).
    """
    print(f"\n{'='*50}")
    print(f"Training {model_name}")
    print(f"{'='*50}")

    model = get_model(timm_name, num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    # Count parameters (in millions)
    params = sum(p.numel() for p in model.parameters()) / 1e6

    # Training loop
    train_losses, train_accs = [], []

    for epoch in range(epochs):
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, preds = outputs.max(1)
            correct += preds.eq(labels).sum().item()
            total += labels.size(0)

        # The learning rate is annealed once per epoch
        scheduler.step()
        epoch_loss = running_loss / len(train_loader)
        train_acc = 100. * correct / total
        train_losses.append(epoch_loss)
        train_accs.append(train_acc)

        print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f}, Acc: {train_acc:.2f}%")

    # Evaluation
    model.eval()
    all_preds, all_labels, all_probs = [], [], []
    on_cuda = device.type == 'cuda'
    forward_seconds = 0.0

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)

            # Time only the forward pass: image loading and host<->device copies are
            # not model inference. CUDA kernels run asynchronously, so synchronize
            # on both sides of the timed region.
            if on_cuda:
                torch.cuda.synchronize()
            tick = time.perf_counter()
            outputs = model(images)
            probs = torch.softmax(outputs, dim=1)
            if on_cuda:
                torch.cuda.synchronize()
            forward_seconds += time.perf_counter() - tick

            _, preds = outputs.max(1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())  # labels never leave the CPU
            all_probs.extend(probs.cpu().numpy())

    inference_time = forward_seconds / len(test_dataset) * 1000

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    # Metrics. Passing the full label list keeps the report rows and the confusion
    # matrix aligned with class_names even if a class is absent from the test data.
    label_ids = list(range(num_classes))
    accuracy = 100. * np.mean(all_preds == all_labels)
    report = classification_report(all_labels, all_preds, labels=label_ids,
                                   target_names=class_names, output_dict=True)
    f1_macro = report['macro avg']['f1-score'] * 100
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    # One-vs-rest ROC curves. The indicator is built directly per class, which also
    # works for two classes (label_binarize would return a single column there).
    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(num_classes):
        is_class_i = (all_labels == i).astype(int)
        fpr[i], tpr[i], _ = roc_curve(is_class_i, all_probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    results = {
        'model': model_name,
        'accuracy': accuracy,
        'f1_macro': f1_macro,
        'params_m': params,
        'inference_ms': inference_time,
        'confusion_matrix': cm,
        'fpr': fpr,
        'tpr': tpr,
        'roc_auc': roc_auc,
        'train_losses': train_losses,
        'train_accs': train_accs,
        'per_class_f1': {class_names[i]: report[class_names[i]]['f1-score']*100 for i in range(num_classes)}
    }

    print(f"\n{model_name} Results:")
    print(f"  Accuracy: {accuracy:.2f}%")
    print(f"  F1 Macro: {f1_macro:.2f}%")
    print(f"  Params: {params:.2f}M")
    print(f"  Inference: {inference_time:.2f}ms")

    return results

## 6. Train All Models

In [None]:
# Fine-tune and evaluate every baseline; results are keyed by display name
all_results = {}

for model_name, timm_name in MODELS.items():
    all_results[model_name] = train_and_evaluate(model_name, timm_name, epochs=15)

# HSANet is not retrained here: its metrics are entered by hand (from your training).
# It has no confusion-matrix / ROC entries, and its training curves are None.
all_results['HSANet'] = dict(
    model='HSANet',
    accuracy=99.77,
    f1_macro=99.75,
    params_m=15.6,
    inference_ms=12.0,
    per_class_f1=dict(glioma=99.69, meningioma=99.69, notumor=99.87, pituitary=99.75),
    train_losses=None,
    train_accs=None,
)

banner = "=" * 50
print("\n" + banner)
print("ALL RESULTS SUMMARY")
print(banner)
for name, res in all_results.items():
    print(f"{name:15} | Acc: {res['accuracy']:.2f}% | F1: {res['f1_macro']:.2f}% | Params: {res['params_m']:.1f}M")

## 7. FIGURE 1: Accuracy Comparison Bar Chart

In [None]:
plt.style.use('seaborn-v0_8-whitegrid')
# One colour per model, picked by position in `models`; later figure cells reuse this palette
colors = ['#2ecc71', '#3498db', '#9b59b6', '#e74c3c', '#f39c12', '#1abc9c']

fig, ax = plt.subplots(figsize=(12, 6))
models = list(all_results.keys())
accs = [all_results[m]['accuracy'] for m in models]

bars = ax.bar(models, accs, color=colors[:len(models)], edgecolor='black', linewidth=1.5)

# Highlight HSANet
if 'HSANet' in models:
    hsanet_idx = models.index('HSANet')
    bars[hsanet_idx].set_edgecolor('gold')
    bars[hsanet_idx].set_linewidth(3)

ax.set_ylabel('Accuracy (%)', fontsize=14)
ax.set_xlabel('Model', fontsize=14)
ax.set_title('Model Comparison: Classification Accuracy on Brain Tumor MRI', fontsize=16, fontweight='bold')

# Zoom in on the top of the scale, but lower the floor when a model scores
# below 95 % so that its bar and value label are not cut out of the plot.
y_floor = max(0, min(95, np.floor(min(accs)) - 1))
ax.set_ylim(y_floor, 100.5)

# Print the exact accuracy above each bar
for bar, acc in zip(bars, accs):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1,
            f'{acc:.2f}%', ha='center', va='bottom', fontsize=11, fontweight='bold')

plt.xticks(rotation=15, ha='right')
plt.tight_layout()
plt.savefig('comparison_figures/fig1_accuracy_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

## 8. FIGURE 2: Parameters vs Accuracy Scatter

In [None]:
fig, ax = plt.subplots(figsize=(10, 7))

# Add efficiency region. It is drawn before the legend is built so that its
# 'Optimal Region' label actually shows up in the legend.
ax.axvspan(10, 20, alpha=0.1, color='green', label='Optimal Region')

for i, model in enumerate(models):
    params = all_results[model]['params_m']
    acc = all_results[model]['accuracy']

    # HSANet is drawn as a larger star so it stands out from the baselines
    size = 400 if model == 'HSANet' else 200
    marker = '*' if model == 'HSANet' else 'o'

    ax.scatter(params, acc, s=size, c=colors[i], marker=marker,
               label=model, edgecolor='black', linewidth=2, zorder=5)

ax.set_xlabel('Parameters (Millions)', fontsize=14)
ax.set_ylabel('Accuracy (%)', fontsize=14)
ax.set_title('Efficiency Analysis: Parameters vs Accuracy', fontsize=16, fontweight='bold')
ax.legend(loc='lower right', fontsize=11)

# Keep every point visible even when a model scores below 95 %
lowest_acc = min(all_results[m]['accuracy'] for m in models)
ax.set_ylim(max(0, min(95, np.floor(lowest_acc) - 1)), 100.5)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('comparison_figures/fig2_params_vs_accuracy.png', dpi=300, bbox_inches='tight')
plt.show()

## 9. FIGURE 3: Radar Chart Comparison

In [None]:
# Radar axes, listed in the same order as the values returned by normalize_metrics()
metrics = ['Accuracy', 'F1-Score', 'Efficiency', 'Speed']
num_metrics = len(metrics)

def normalize_metrics(results):
    """Turn one model's result dict into the four radar-chart values.

    Accuracy and macro-F1 are passed through unchanged. Parameter count and
    inference time are converted to "higher is better" scores relative to the
    largest value found in the global ``all_results``, so the model holding
    that maximum scores 0 on the corresponding axis.

    Args:
        results: dict with ``accuracy``, ``f1_macro``, ``params_m`` and
            ``inference_ms`` entries.

    Returns:
        ``[accuracy, f1_macro, efficiency, speed]``.
    """
    largest_params = max(entry['params_m'] for entry in all_results.values())
    slowest_ms = max(entry['inference_ms'] for entry in all_results.values())

    efficiency = 100 * (1 - results['params_m'] / largest_params)
    speed = 100 * (1 - results['inference_ms'] / slowest_ms)

    return [results['accuracy'], results['f1_macro'], efficiency, speed]

fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(projection='polar'))

# One spoke per metric; the first angle is repeated so each polygon closes
angles = [n / float(num_metrics) * 2 * pi for n in range(num_metrics)]
angles += angles[:1]

# Position of every model in all_results: the other figures pick colours by this
# position, so reusing it here keeps one colour per model across all figures.
palette_position = {name: idx for idx, name in enumerate(all_results)}

selected_models = ['HSANet', 'ViT-B/16', 'Swin-Tiny', 'ResNet-50']
for model in selected_models:
    if model not in all_results:
        continue

    values = normalize_metrics(all_results[model])
    values += values[:1]  # close the polygon

    color = colors[palette_position[model] % len(colors)]
    linewidth = 3 if model == 'HSANet' else 1.5
    ax.plot(angles, values, 'o-', linewidth=linewidth, label=model, color=color)
    ax.fill(angles, values, alpha=0.15, color=color)

ax.set_xticks(angles[:-1])
ax.set_xticklabels(metrics, fontsize=13)
ax.set_ylim(0, 105)
ax.set_title('Multi-Dimensional Model Comparison', fontsize=16, fontweight='bold', y=1.08)
ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0), fontsize=11)

plt.tight_layout()
plt.savefig('comparison_figures/fig3_radar_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

## 10. FIGURE 4: Training Curves

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for i, model in enumerate(models):
    losses = all_results[model].get('train_losses')
    if losses is None:
        # No recorded training history for this model (e.g. hand-entered results)
        continue
    epoch_accs = all_results[model]['train_accs']

    # Number epochs from 1 so the x-axis matches the "Epoch k/N" training log
    epoch_numbers = range(1, len(losses) + 1)

    linewidth = 3 if model == 'HSANet' else 1.5
    axes[0].plot(epoch_numbers, losses, label=model, linewidth=linewidth, color=colors[i])
    axes[1].plot(epoch_numbers, epoch_accs, label=model, linewidth=linewidth, color=colors[i])

axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training Loss Curves', fontsize=14, fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Accuracy (%)', fontsize=12)
axes[1].set_title('Training Accuracy Curves', fontsize=14, fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('comparison_figures/fig4_training_curves.png', dpi=300, bbox_inches='tight')
plt.show()

## 11. FIGURE 5: ROC Curves Grid

In [None]:
# Only models whose results carry ROC data get a panel
models_with_roc = [m for m in models if all_results[m].get('roc_auc')]

if models_with_roc:
    n_models = len(models_with_roc)
    n_cols = 3
    n_rows = (n_models + n_cols - 1) // n_cols

    # squeeze=False always returns a 2-D array of axes, so flattening is valid for
    # any model count (a single model still yields a 1x3 grid of axes).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5*n_rows), squeeze=False)
    axes = axes.flatten()

    for idx, model in enumerate(models_with_roc):
        ax = axes[idx]
        model_results = all_results[model]
        # One one-vs-rest curve per class
        for i, class_name in enumerate(class_names):
            ax.plot(model_results['fpr'][i], model_results['tpr'][i],
                    label=f'{class_name} (AUC={model_results["roc_auc"][i]:.3f})')

        # Chance diagonal for reference
        ax.plot([0, 1], [0, 1], 'k--', alpha=0.5)
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title(f'{model}', fontweight='bold')
        ax.legend(loc='lower right', fontsize=9)

    # Hide unused subplots
    for idx in range(n_models, len(axes)):
        axes[idx].axis('off')

    plt.tight_layout()
    plt.savefig('comparison_figures/fig5_roc_curves.png', dpi=300, bbox_inches='tight')
    plt.show()
else:
    print("No ROC data available")

## 12. FIGURE 6: Confusion Matrices Grid

In [None]:
# Only models whose results carry a confusion matrix get a panel
models_with_cm = [m for m in models if all_results[m].get('confusion_matrix') is not None]

if models_with_cm:
    n_models = len(models_with_cm)
    n_cols = 3
    n_rows = (n_models + n_cols - 1) // n_cols

    # squeeze=False always returns a 2-D array of axes, so flattening is valid for
    # any model count (a single model still yields a 1x3 grid of axes).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5*n_rows), squeeze=False)
    axes = axes.flatten()

    for idx, model in enumerate(models_with_cm):
        ax = axes[idx]
        cm = all_results[model]['confusion_matrix']
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                    xticklabels=class_names, yticklabels=class_names)
        ax.set_title(f'{model}', fontweight='bold')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')

    # Hide unused subplots
    for idx in range(n_models, len(axes)):
        axes[idx].axis('off')

    plt.tight_layout()
    plt.savefig('comparison_figures/fig6_confusion_matrices.png', dpi=300, bbox_inches='tight')
    plt.show()
else:
    print("No confusion matrix data available")

## 13. FIGURE 7: Per-Class F1 Score Comparison

In [None]:
fig, ax = plt.subplots(figsize=(14, 7))

models_with_f1 = [m for m in models if 'per_class_f1' in all_results[m]]

x = np.arange(len(class_names))
# 0.12 per bar fits the six models compared here; shrink the bars if more models
# are added so that one class group never overlaps the next.
width = min(0.12, 0.8 / max(len(models_with_f1), 1))

lowest_f1 = None
for i, model in enumerate(models_with_f1):
    # A class missing from a model's results is drawn as 0
    f1_scores = [all_results[model]['per_class_f1'].get(c, 0) for c in class_names]
    if f1_scores:
        lowest_f1 = min(f1_scores) if lowest_f1 is None else min(lowest_f1, min(f1_scores))

    edgecolor = 'gold' if model == 'HSANet' else 'black'
    linewidth = 2.5 if model == 'HSANet' else 0.5

    # Bars of the i-th model are shifted i bar-widths to the right within each group
    ax.bar(x + width * i, f1_scores, width, label=model, color=colors[i],
           edgecolor=edgecolor, linewidth=linewidth)

ax.set_xlabel('Tumor Class', fontsize=14)
ax.set_ylabel('F1-Score (%)', fontsize=14)
ax.set_title('Per-Class F1-Score Comparison Across Models', fontsize=16, fontweight='bold')
ax.set_xticks(x + width * (len(models_with_f1)-1) / 2)
ax.set_xticklabels(class_names, fontsize=12)
ax.legend(loc='lower right', fontsize=10)

# Zoom in on the top of the scale, but lower the floor when any score is below 90
y_floor = 90 if lowest_f1 is None else max(0, min(90, np.floor(lowest_f1) - 2))
ax.set_ylim(y_floor, 101)
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('comparison_figures/fig7_per_class_f1.png', dpi=300, bbox_inches='tight')
plt.show()

## 14. FIGURE 8: Computational Efficiency

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Sort by parameters (smallest model at the bottom of both panels)
models_sorted = sorted(models, key=lambda m: all_results[m]['params_m'])
params = [all_results[m]['params_m'] for m in models_sorted]
times = [all_results[m]['inference_ms'] for m in models_sorted]

# Parameters bar chart
bar_colors = ['gold' if m == 'HSANet' else '#3498db' for m in models_sorted]
axes[0].barh(models_sorted, params, color=bar_colors, edgecolor='black')
axes[0].set_xlabel('Parameters (Millions)', fontsize=12)
axes[0].set_title('Model Size Comparison', fontsize=14, fontweight='bold')
# Leave head-room on the right so the label of the longest bar stays inside the frame;
# the label gap scales with the data instead of being a fixed number of units.
axes[0].set_xlim(0, max(params) * 1.15)
for i, v in enumerate(params):
    axes[0].text(v + 0.01 * max(params), i, f'{v:.1f}M', va='center', fontsize=10)

# Inference time bar chart
bar_colors = ['gold' if m == 'HSANet' else '#e74c3c' for m in models_sorted]
axes[1].barh(models_sorted, times, color=bar_colors, edgecolor='black')
axes[1].set_xlabel('Inference Time (ms)', fontsize=12)
axes[1].set_title('Inference Speed Comparison', fontsize=14, fontweight='bold')
axes[1].set_xlim(0, max(times) * 1.15)
for i, v in enumerate(times):
    axes[1].text(v + 0.01 * max(times), i, f'{v:.1f}ms', va='center', fontsize=10)

plt.tight_layout()
plt.savefig('comparison_figures/fig8_computational_efficiency.png', dpi=300, bbox_inches='tight')
plt.show()

## 15. Summary Table

In [None]:
# Build one pre-formatted row per model, in the same order as the figures
summary_data = []
for model in models:
    stats = all_results[model]
    summary_data.append({
        'Model': model,
        'Accuracy (%)': format(stats['accuracy'], '.2f'),
        'F1-Score (%)': format(stats['f1_macro'], '.2f'),
        'Params (M)': format(stats['params_m'], '.1f'),
        'Inference (ms)': format(stats['inference_ms'], '.1f'),
    })

df = pd.DataFrame(summary_data)

rule = "=" * 70
print("\n" + rule)
print("FINAL RESULTS SUMMARY")
print(rule)
print(df.to_string(index=False))

# Persist the table next to the figures
df.to_csv('comparison_figures/model_comparison_results.csv', index=False)
print("\nResults saved to comparison_figures/model_comparison_results.csv")

## 16. Download All Figures

In [None]:
import os

banner = "=" * 50
print("\n" + banner)
print("ALL FIGURES GENERATED:")
print(banner)

# List every file in the output folder, alphabetically, with its size in KB
for filename in sorted(os.listdir('comparison_figures')):
    size_kb = os.path.getsize(os.path.join('comparison_figures', filename)) / 1024
    print(f"  {filename} ({size_kb:.1f} KB)")

print("\nDownload the 'comparison_figures' folder to get all figures!")