In [None]:
!pip install lime
!pip install git+https://github.com/jacobgil/pytorch-grad-cam.git

In [None]:
# libraries:
import copy
import itertools
import os
import random
import warnings
import numpy as np
import pandas as pd
from scipy.stats import entropy

# Deep Learning - PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split

# Computer Vision & Image Processing
import cv2
from PIL import Image
from skimage import segmentation
from skimage.metrics import structural_similarity as ssim
from skimage.segmentation import felzenszwalb
from skimage.transform import resize
from torchvision import datasets, transforms

# Visualization
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from matplotlib.gridspec import GridSpec
import seaborn as sns

# Machine Learning Metrics
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

# Explainability Frameworks
import lime
from lime import lime_image
from lime.wrappers.scikit_image import SegmentationAlgorithm
import shap
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# Suppress Warnings
warnings.filterwarnings('ignore', message='unrecognized nn.Module: Flatten')

In [None]:
from google.colab import drive

# Mount Google Drive; the relative 'drive/MyDrive/...' paths used below assume the
# working directory is /content (the Colab default) — confirm if running elsewhere.
COLAB_DRIVE_MOUNT = '/content/drive'
drive.mount(COLAB_DRIVE_MOUNT)

Mounted at /content/drive


Data and preprocessing:

In [None]:
# Dataset root on the mounted Drive (relative to the notebook's working directory).
_archive_root = os.path.join('drive', 'MyDrive', 'MT', 'archive')
train_dir = os.path.join(_archive_root, 'Training')
test_dir = os.path.join(_archive_root, 'Testing')

In [None]:
# Normalization statistics of the training images. They are identical across the
# three channels (presumably because the MRIs are grayscale). Keep every Normalize
# and every de-normalization in sync with these two constants.
NORM_MEAN = [0.1848, 0.1848, 0.1848]
NORM_STD = [0.1768, 0.1768, 0.1768]

# Training pipeline. Augmentation (e.g. transforms.TrivialAugmentWide()) would be
# inserted before Resize; it is currently disabled, so this equals val_test_transform.
train_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Evaluation pipeline: resize + normalize, no augmentation.
val_test_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Visualization pipeline: no Normalize, so ToTensor's [0, 1] pixel range is kept.
test_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

In [None]:
# Same Training folder under two pipelines: train_transform for fitting and
# val_test_transform for the train/validation split.
train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)
full_dataset = datasets.ImageFolder(root=train_dir, transform=val_test_transform)

# Test folder: normalized copy for the models, un-normalized copy for display.
test_dataset_norm = datasets.ImageFolder(root=test_dir, transform=val_test_transform)
test_dataset = datasets.ImageFolder(root=test_dir, transform=test_transform)

In [None]:
# Split the Training folder 80/20 into train/validation. The seeded generator
# makes the split identical on every kernel restart.
SPLIT_SEED = 42
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

train_split, val_dataset = random_split(
    full_dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED)
)
# full_dataset carries val_test_transform. Serve the same training indices from a
# dataset built with train_transform so that training-time augmentation applies.
train_dataset = Subset(
    datasets.ImageFolder(train_dir, transform=train_transform),
    train_split.indices
)

# Loaders (evaluation loaders are unshuffled so positions map to dataset indices)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
test_loader_norm = DataLoader(test_dataset_norm, batch_size=32, shuffle=False)

Class weights:

In [None]:
# Per-class image counts, keyed by ImageFolder label index
# (presumably the full Training folder — they total 5712).
train_counts = {
    0: 1321,  # Glioma
    1: 1339,  # Meningioma
    2: 1595,  # No Tumor
    3: 1457   # Pituitary
}

# Inverse-frequency weights total / count: rarer classes get larger weights.
total_samples = sum(train_counts.values())  # 5712
class_weights = dict(
    (class_idx, total_samples / count) for class_idx, count in train_counts.items()
)

print("Class weights:", class_weights)

Class weights: {0: 4.323996971990916, 1: 4.265870052277819, 2: 3.581191222570533, 3: 3.9203843514070007}


Model:

In [None]:
class BrainCNN(nn.Module):
    """Three-block CNN classifier for 3x128x128 inputs.

    Each block is conv(stride 2) -> ReLU -> BatchNorm -> Dropout2d and halves the
    spatial size, so a 128x128 input reaches the classifier as a 64x16x16 map.
    Attribute names and layer order define the state_dict keys of the saved
    checkpoints and must not change.
    """

    def __init__(self, num_classes=4):
        super().__init__()

        def conv_block(in_ch, out_ch):
            # One downsampling stage; returned as a list so it can be splatted.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(out_ch),
                nn.Dropout2d(0.2),
            ]

        self.features = nn.Sequential(
            *conv_block(3, 16),
            *conv_block(16, 32),
            *conv_block(32, 64),
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 16 * 16, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over `loader`; return (mean batch loss, accuracy in %).

    Trains (model.train(), gradient steps) when `optimizer` is given, otherwise
    evaluates (model.eval(), autograd disabled). Returns (nan, nan) for an empty loader.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, n_batches, n_correct, n_seen = 0.0, 0, 0, 0

    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            n_batches += 1
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)

    if n_seen == 0:
        return float('nan'), float('nan')
    return loss_sum / n_batches, 100 * n_correct / n_seen


def train_and_validate(model, train_loader, val_loader, test_loader, class_weights=None, epochs=10):
    """Train `model`, validate after every epoch, then report test metrics.

    Args:
        model: nn.Module classifier returning logits.
        train_loader, val_loader, test_loader: iterables of (inputs, labels) batches.
        class_weights: optional per-class loss weights indexable by 0..C-1
            (e.g. a {class_idx: weight} dict). None -> unweighted cross-entropy.
        epochs: number of training epochs.

    Returns:
        (history, model): history maps 'train_loss', 'train_acc', 'val_loss' and
        'val_acc' to per-epoch lists (accuracies in %); model is the trained
        module, left in eval mode.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # None -> plain (unweighted) cross-entropy.
    weight_tensor = None
    if class_weights is not None:
        weight_tensor = torch.tensor(
            [float(class_weights[i]) for i in range(len(class_weights))],
            dtype=torch.float32, device=device,
        )

    criterion = nn.CrossEntropyLoss(weight=weight_tensor)
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    for epoch in range(epochs):
        epoch_train_loss, epoch_train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        epoch_val_loss, epoch_val_acc = _run_epoch(model, val_loader, criterion, device)

        history['train_loss'].append(epoch_train_loss)
        history['train_acc'].append(epoch_train_acc)
        history['val_loss'].append(epoch_val_loss)
        history['val_acc'].append(epoch_val_acc)

        print(f'Epoch {epoch+1}:')
        print(f'Train Loss: {epoch_train_loss:.4f}, Train Accuracy: {epoch_train_acc:.2f}%')
        print(f'Val Loss: {epoch_val_loss:.4f}, Val Accuracy: {epoch_val_acc:.2f}%')

    # Final test phase
    test_loss, test_acc = _run_epoch(model, test_loader, criterion, device)
    print('\nFinal Test Results:')
    print(f'Test Loss: {test_loss:.4f}')
    print(f'Test Accuracy: {test_acc:.2f}%')

    return history, model

Load trained models:

In [None]:
def load_model_from_checkpoint(checkpoint_path, model_class):
    """
    Build `model_class()` and fill it with weights from a checkpoint file.

    Two checkpoint layouts are accepted: a dict holding 'model_state_dict' (plus
    an optional 'history'), or a bare state dict. Tensors are mapped to CUDA when
    available, else CPU, while loading; the fresh model itself is not moved.

    Args:
        checkpoint_path (str): Path to the checkpoint file
        model_class (nn.Module): The model class to instantiate (called with no args)

    Returns:
        tuple: (loaded_model, history) — the model in eval mode, and the stored
        history or None when the checkpoint has none.
    """
    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    checkpoint = torch.load(checkpoint_path, map_location=target_device)

    model = model_class()

    is_full_checkpoint = isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint
    state_dict = checkpoint['model_state_dict'] if is_full_checkpoint else checkpoint
    model.load_state_dict(state_dict)
    history = checkpoint.get('history', None) if is_full_checkpoint else None

    model.eval()
    return model, history

# Load all four trained variants from their checkpoints.
model_paths = {
    'base': 'drive/MyDrive/results/archive21/base_model.pth',
    'trivial_aug': 'drive/MyDrive/results/archive21/TA_model_3.pth',
    'gans': 'drive/MyDrive/results/archive21/gan_model_2.pth',
    'combined': 'drive/MyDrive/results/archive21/combined_model_2.pth'
}

models, histories = {}, {}

# Best effort: a checkpoint that fails to load is reported and skipped, so
# `models` then lacks that key.
for model_name, path in model_paths.items():
    try:
        model, history = load_model_from_checkpoint(path, BrainCNN)
    except Exception as e:
        print(f"Error loading {model_name} model: {str(e)}")
    else:
        models[model_name] = model
        histories[model_name] = history
        print(f"Successfully loaded {model_name} model")

  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))


Successfully loaded base model
Successfully loaded trivial_aug model
Successfully loaded gans model
Successfully loaded combined model


In [None]:
# Short aliases for the four loaded variants.
base_model, ta_model, gan_model, combined_model = (
    models['base'], models['trivial_aug'], models['gans'], models['combined']
)

See misclassified samples:

In [None]:
def get_misclassified_indices(model, test_loader):
    """Return the positions of every sample that `model` misclassifies.

    Positions come from a running batch offset, so they are dataset indices only
    when `test_loader` is not shuffled.

    Args:
        model: classifier returning logits; inputs are moved to its device.
        test_loader: iterable of (inputs, targets) batches with CPU targets.

    Returns:
        list[int] of misclassified positions, in loader order.
    """
    wrong = []
    offset = 0

    model.eval()
    with torch.no_grad():
        for inputs, targets in test_loader:
            device = next(model.parameters()).device
            preds = model(inputs.to(device)).argmax(dim=1).cpu()
            # flatten() keeps a 1-D index list even with zero or one mismatch
            mismatches = torch.nonzero(preds != targets).flatten().tolist()
            wrong.extend(offset + j for j in mismatches)
            offset += len(inputs)

    return wrong

# Misclassified test indices per model. Iterates `models` (loaded above): the
# `models_dict` name is only defined in a later cell and fails on a fresh kernel.
misclassified_indices = {
    model_name: get_misclassified_indices(model, test_loader_norm)
    for model_name, model in models.items()
}

In [None]:
# Display the misclassified test-set indices for each model.
misclassified_indices

# <a id='gen'>Fidelity: LIME</a>

In [28]:
def normalize_heatmap(heatmap):
    """Reduce a heatmap to 2-D where possible and min-max scale it to [0, 1].

    Torch tensors are converted to numpy. A >2-D array whose last (else first)
    dimension has size 3 or 4 is averaged over that channel axis. A constant map
    is returned unscaled.
    """
    arr = heatmap.cpu().numpy() if isinstance(heatmap, torch.Tensor) else heatmap

    if arr.ndim > 2:
        if arr.shape[-1] in (3, 4):      # channels last
            arr = arr.mean(axis=-1)
        elif arr.shape[0] in (3, 4):     # channels first
            arr = arr.mean(axis=0)

    lo, hi = arr.min(), arr.max()
    span = hi - lo
    if span != 0:
        arr = (arr - lo) / span
    return arr

def prepare_image(image, denormalize=True, mean=0.1848, std=0.1768):
    """Convert an image to an (H, W, C) float array clipped to [0, 1] for display.

    Args:
        image: torch tensor or numpy array; channels-first with 3 channels is
            transposed to channels-last, anything else is used as-is.
        denormalize: if True, undo transforms.Normalize via image * std + mean.
        mean, std: normalization statistics (scalar, or per-channel sequence
            broadcast over the last axis). The defaults match the Normalize used
            by the train/test transforms.

    Returns:
        numpy array in HWC layout with values clipped to [0, 1].
    """
    if torch.is_tensor(image):
        image = image.detach().cpu().numpy()

    if image.shape[0] == 3:  # CHW to HWC
        image = np.transpose(image, (1, 2, 0))

    if denormalize:
        # Must mirror the Normalize statistics exactly; a mismatch shifts every pixel.
        image = image * np.asarray(std) + np.asarray(mean)

    return np.clip(image, 0, 1)

from skimage.segmentation import felzenszwalb

def get_specific_batch_images(loader, index):
    """Return the (image, label) pair that `loader` yields at position `index`.

    For an unshuffled loader (SequentialSampler) the item is read directly from
    loader.dataset, instead of iterating over and loading every preceding batch.
    Other samplers fall back to walking the batches. Both image and label are
    returned as tensors, as collated batches yield them.

    Returns (None, None) when `index` is out of range.
    """
    if isinstance(loader.sampler, torch.utils.data.SequentialSampler):
        if not 0 <= index < len(loader.dataset):
            return None, None
        image, label = loader.dataset[index]
        return torch.as_tensor(image), torch.as_tensor(label)

    batch_size = loader.batch_size
    for i, (images, labels) in enumerate(loader):
        if i * batch_size <= index < (i + 1) * batch_size:
            batch_idx = index - (i * batch_size)
            if batch_idx >= len(images):  # index falls past a short final batch
                return None, None
            return images[batch_idx], labels[batch_idx]
    return None, None

def generate_lime_explanation(model, input_image, class_names, num_samples=1000,
                              norm_mean=0.1848, norm_std=0.1768):
    """Generate a LIME explanation for the model's prediction on one image.

    LIME perturbs the image in display space ([0, 1], HWC), so hide_color=0 paints
    superpixels black. Every perturbed image is re-normalized with
    `norm_mean`/`norm_std` before reaching the model, which expects normalized input.

    Args:
        model: classifier returning logits for (N, 3, H, W) normalized input.
        input_image: normalized CHW tensor, or an already de-normalized HWC array in [0, 1].
        class_names: list of class names; all classes are explained.
        num_samples: number of LIME perturbations.
        norm_mean, norm_std: scalar statistics of the dataset's Normalize transform.

    Returns:
        (exp_img, mask, pred_class, confidence): LIME overlay image, mask rescaled
        by normalize_heatmap, predicted class index, and its probability in %.
    """
    device = next(model.parameters()).device

    def batch_predict(images):
        # (N, H, W, C) display-space images -> normalized (N, C, H, W) -> probabilities
        model.eval()
        batch = torch.from_numpy(np.asarray(images)).permute(0, 3, 1, 2).float()
        batch = ((batch - norm_mean) / norm_std).to(device)
        with torch.no_grad():
            probs = F.softmax(model(batch), dim=1)
        return probs.cpu().numpy()

    if torch.is_tensor(input_image):
        # Normalized CHW tensor -> de-normalized HWC array in [0, 1]
        arr = input_image.detach().cpu().numpy()
        if arr.shape[0] == 3:
            arr = np.transpose(arr, (1, 2, 0))
        input_image = np.clip(arr * norm_std + norm_mean, 0, 1)

    # The model's prediction on exactly the input LIME will perturb.
    probs = batch_predict(input_image[np.newaxis])[0]
    pred_class = int(np.argmax(probs))
    confidence = float(probs[pred_class]) * 100

    explainer = lime_image.LimeImageExplainer()
    segmenter = SegmentationAlgorithm('felzenszwalb', scale=100, sigma=0.8, min_size=50)

    explanation = explainer.explain_instance(
        input_image,
        batch_predict,
        top_labels=len(class_names),  # explain every class
        hide_color=0,
        num_samples=num_samples,
        segmentation_fn=segmenter
    )

    exp_img, mask = explanation.get_image_and_mask(
        pred_class,
        positive_only=False,
        num_features=20,
        hide_rest=False,
        min_weight=0.01
    )

    mask = normalize_heatmap(mask)

    return exp_img, mask, pred_class, confidence

def visualize_results(results, original_image, class_names):
    """Plot the original image next to each model's LIME heatmap overlay.

    One column per entry in `results` (plus one for the original), so any number
    of models fits.

    Args:
        results: {model_name: {'lime': 2-D heatmap, optional 'predicted_class': int,
            optional 'probabilities': per-class probabilities}}.
        original_image: normalized image, de-normalized via prepare_image for display.
        class_names: list of class names, or None to omit prediction info.
    """
    n_cols = 1 + len(results)
    fig, axes = plt.subplots(1, n_cols, figsize=(4 * n_cols, 4), squeeze=False)
    axes = axes[0]
    fig.suptitle('LIME Explanations Comparison', fontsize=16)

    # Plot original image
    orig_img = prepare_image(original_image, denormalize=True)
    axes[0].imshow(orig_img)
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    # Plot LIME for each model
    for ax, (model_name, result) in zip(axes[1:], results.items()):
        pred_class = result.get('predicted_class', None)
        pred_info = ""
        if pred_class is not None and class_names is not None:
            pred_info = f"Pred: {class_names[pred_class]}"
            if 'probabilities' in result:
                conf = result['probabilities'][pred_class] * 100
                pred_info += f"\nConf: {conf:.1f}%"

        ax.imshow(orig_img)
        ax.imshow(result['lime'], cmap='RdYlBu_r', alpha=0.5)
        ax.set_title(f'{model_name.capitalize()}\n{pred_info}')
        ax.axis('off')

    plt.tight_layout()
    plt.show()


def _plot_lime_mask(ax, background, explanation, class_idx, title):
    """Overlay LIME's signed superpixel mask for `class_idx` on `background` in `ax`."""
    _, mask = explanation.get_image_and_mask(
        class_idx,
        positive_only=False,
        num_features=20,
        hide_rest=False,
        min_weight=0.01
    )
    ax.imshow(background)
    ax.imshow(normalize_heatmap(mask), cmap='RdBu', alpha=0.5)
    ax.set_title(title)
    ax.axis('off')


def visualize_model_comparisons(models_dict, test_loader_norm, test_loader_unnorm, class_names, image_index=0,
                                lime_samples=1000, norm_mean=0.1848, norm_std=0.1768):
    """
    Visualize LIME explanations for all models and compare their predictions.

    Top row: the original image, then each model's explanation for its predicted
    class. Bottom row: each model's explanation for the true class.
    Handles cases where explanations for a label are not available.

    Args:
        models_dict: {name: model}; each model returns logits for normalized input.
        test_loader_norm, test_loader_unnorm: unshuffled loaders over the same test
            images, with and without Normalize (the latter is only displayed).
        class_names: class names indexed by label.
        image_index: dataset position of the image to explain.
        lime_samples: number of LIME perturbations per model.
        norm_mean, norm_std: scalar statistics of the Normalize transform; LIME's
            perturbed images are re-normalized with them before reaching the model.

    Raises:
        IndexError: if image_index is outside the test set.
    """
    # Get images
    image_norm, label = get_specific_batch_images(test_loader_norm, image_index)
    image_unnorm, _ = get_specific_batch_images(test_loader_unnorm, image_index)
    if image_norm is None or image_unnorm is None:
        raise IndexError(f"image_index {image_index} is out of range for the test set")
    # Plain int: tensors hash by identity, so a tensor label would never be found
    # among the int keys of explanation.local_exp and the true-class row would vanish.
    label = int(label)
    true_class = class_names[label]

    # Get predictions on the normalized image
    direct_preds = {}
    for model_name, model in models_dict.items():
        model.eval()
        device = next(model.parameters()).device
        with torch.no_grad():
            probs = F.softmax(model(image_norm.unsqueeze(0).to(device)), dim=1)[0]
        pred_class = torch.argmax(probs).item()
        direct_preds[model_name] = {
            'class': pred_class,
            'confidence': probs[pred_class].item() * 100,
            'probabilities': probs.cpu().numpy()
        }

    print(f"\nTrue class: {true_class}")
    print("Model predictions:")
    for model_name, pred in direct_preds.items():
        print(f"{model_name}: {class_names[pred['class']]} ({pred['confidence']:.1f}%)")

    # Layout: column 0 holds the original image, then one column per model.
    n_cols = len(models_dict) + 1
    fig, axes = plt.subplots(2, n_cols, figsize=(4 * n_cols, 8), squeeze=False)
    fig.suptitle(f'LIME Explanations Comparison\nTrue Class: {true_class}', fontsize=16)

    # Plot original image
    orig_img = prepare_image(image_unnorm, denormalize=False)
    axes[0, 0].imshow(orig_img)
    axes[0, 0].set_title(f'Original Image\nTrue: {true_class}')
    axes[0, 0].axis('off')
    axes[1, 0].axis('off')

    # LIME perturbs the de-normalized image in [0, 1] (hide_color=0 -> black
    # superpixels); batch_predict re-normalizes, so LIME explains the same function
    # that produced direct_preds.
    input_image = np.clip(
        image_norm.detach().cpu().permute(1, 2, 0).numpy() * norm_std + norm_mean, 0, 1
    )
    segmenter = SegmentationAlgorithm('felzenszwalb', scale=100, sigma=0.8, min_size=50)

    # Generate explanations for each model
    print("\nGenerating LIME explanations:")
    for col, (model_name, model) in enumerate(models_dict.items(), start=1):
        print(f"\nProcessing {model_name}...")

        pred_class = direct_preds[model_name]['class']
        confidence = direct_preds[model_name]['confidence']
        device = next(model.parameters()).device

        # Default arguments bind this iteration's model/device to the closure.
        def batch_predict(images, model=model, device=device):
            model.eval()
            batch = torch.from_numpy(np.asarray(images)).permute(0, 3, 1, 2).float()
            batch = ((batch - norm_mean) / norm_std).to(device)
            with torch.no_grad():
                probs = F.softmax(model(batch), dim=1)
            return probs.cpu().numpy()

        explainer = lime_image.LimeImageExplainer()

        # Generate explanation for all classes
        explanation = explainer.explain_instance(
            input_image,
            batch_predict,
            top_labels=len(class_names),
            hide_color=0,
            num_samples=lime_samples,
            segmentation_fn=segmenter
        )

        # Explanation for predicted class (if available)
        if pred_class in explanation.local_exp:
            _plot_lime_mask(
                axes[0, col], orig_img, explanation, pred_class,
                f'{model_name.capitalize()}\nPredicted: {class_names[pred_class]}\nConf: {confidence:.1f}%'
            )
        else:
            axes[0, col].axis('off')
            print(f"Warning: No explanation found for predicted class {class_names[pred_class]} in {model_name}.")

        # Explanation for true class (if available)
        if label in explanation.local_exp:
            _plot_lime_mask(axes[1, col], orig_img, explanation, label, f'True Class Features\n{true_class}')
        else:
            axes[1, col].axis('off')
            print(f"Warning: No explanation found for true class {true_class} in {model_name}.")

    plt.tight_layout()
    plt.show()

In [None]:
models_dict = {
    'base': base_model,
    'trivial_aug': ta_model,
    'gans': gan_model,
    'combined': combined_model
}

# Folder-style class names, presumably in the same order as the ImageFolder labels.
class_names = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']

# Run visualization (meningioma cases)
for idx in (400, 415, 425):
    visualize_model_comparisons(models_dict, test_loader_norm, test_loader, class_names, image_index=idx)

In [None]:
# Run visualization (meningioma cases)
for idx in (350, 600, 550):
    visualize_model_comparisons(models_dict, test_loader_norm, test_loader, class_names, image_index=idx)

In [None]:
# Run visualization (glioma)
for idx in (192, 278, 285):
    visualize_model_comparisons(models_dict, test_loader_norm, test_loader, class_names, image_index=idx)

In [None]:
# Run visualization (glioma)
for idx in (68, 287, 38):
    visualize_model_comparisons(models_dict, test_loader_norm, test_loader, class_names, image_index=idx)

In [None]:
# Run visualization (no tumor)
for idx in (968, 987, 938):
    visualize_model_comparisons(models_dict, test_loader_norm, test_loader, class_names, image_index=idx)

In [None]:
# Run visualization (pituitary)
for idx in (1100, 1200, 1195):
    visualize_model_comparisons(models_dict, test_loader_norm, test_loader, class_names, image_index=idx)

In [None]:
# Run visualization (pituitary)
for idx in (1211, 1310, 1019):
    visualize_model_comparisons(models_dict, test_loader_norm, test_loader, class_names, image_index=idx)

Focus on misclassified samples:

In [None]:
# Selected misclassified test images
for idx in (30, 44, 45, 50, 229):
    visualize_model_comparisons(models_dict, test_loader_norm, test_loader, class_names, image_index=idx)

In [None]:
# More misclassified test images.
# NOTE(review): index 44 is also shown in the previous cell — intentional?
for idx in (3, 43, 4, 44, 1089):
    visualize_model_comparisons(models_dict, test_loader_norm, test_loader, class_names, image_index=idx)

Class-wise examples (focused on the base model):

In [None]:
# Class-wise examples, focused on the base model
for idx in (702, 635, 701, 714, 843):
    visualize_model_comparisons(models_dict, test_loader_norm, test_loader, class_names, image_index=idx)

Misclassified by the combined model:

In [None]:
# Images misclassified by the combined model
for idx in (159, 1043, 224, 356, 548):
    visualize_model_comparisons(models_dict, test_loader_norm, test_loader, class_names, image_index=idx)

In [None]:
# More images misclassified by the combined model
for idx in (242, 361, 583, 1013, 20):
    visualize_model_comparisons(models_dict, test_loader_norm, test_loader, class_names, image_index=idx)

Overall fidelity scores for LIME:

Smaller set:

In [None]:
def _collect_class_images(loader, n_classes, per_class):
    """Collect up to `per_class` images of every class from `loader`, in loader order.

    Returns (images, labels): a list of CHW tensors and an int ndarray of labels.
    Stops iterating once every class has `per_class` images.
    """
    images, labels = [], []
    counts = [0] * n_classes
    for batch_images, batch_labels in loader:
        for img, lbl in zip(batch_images, batch_labels):
            idx = int(lbl)
            if counts[idx] < per_class:
                images.append(img)
                labels.append(idx)
                counts[idx] += 1
        if all(count >= per_class for count in counts):
            break
    return images, np.array(labels, dtype=int)


def calculate_class_fidelities(models_dict, test_loader_norm, class_names, num_samples=5,
                               lime_samples=1000, seed=None, norm_mean=0.1848, norm_std=0.1768):
    """
    Calculate average LIME fidelity scores for each model and class.

    For every class, up to `num_samples` test images of that (true) class are drawn
    at random. Each model's prediction on each image is explained with LIME, and the
    explanation's `score` (R^2 of LIME's local linear surrogate) for the PREDICTED
    class is recorded under the image's true class.

    Args:
        models_dict: {name: model}; models return logits for normalized input.
        test_loader_norm: normalized test data loader.
        class_names: list of class names (index = label).
        num_samples: number of images to evaluate per class.
        lime_samples: number of LIME perturbations per explanation.
        seed: optional int making image selection and LIME sampling reproducible;
            None uses NumPy's global random state.
        norm_mean, norm_std: scalar statistics of the Normalize transform; LIME's
            perturbed images are re-normalized with them before reaching the model.

    Returns:
        {model_name: {class_name: [score, ...]}} with plain Python floats.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    segmenter = SegmentationAlgorithm('felzenszwalb', scale=100, sigma=0.8, min_size=50)

    # Initialize storage for fidelity scores
    fidelity_scores = {model_name: {class_name: [] for class_name in class_names}
                       for model_name in models_dict.keys()}

    # A small buffer above num_samples so the random draw has some choice.
    all_images, all_labels = _collect_class_images(test_loader_norm, len(class_names), num_samples + 2)

    for class_idx, class_name in enumerate(class_names):
        print(f"\nProcessing class: {class_name}")

        class_indices = np.where(all_labels == class_idx)[0]
        selected_indices = rng.choice(class_indices,
                                      size=min(num_samples, len(class_indices)),
                                      replace=False)

        for img_idx in selected_indices:
            image = all_images[img_idx]
            # LIME perturbs the de-normalized image (hide_color=0 -> black);
            # batch_predict re-normalizes before calling the model.
            input_image = np.clip(
                image.detach().cpu().permute(1, 2, 0).numpy() * norm_std + norm_mean, 0, 1
            )

            for model_name, model in models_dict.items():
                print(f"Processing model {model_name}, image {img_idx}")
                model.eval()
                device = next(model.parameters()).device

                # Default arguments bind this iteration's model/device to the closure.
                def batch_predict(images, model=model, device=device):
                    model.eval()
                    batch = torch.from_numpy(np.asarray(images)).permute(0, 3, 1, 2).float()
                    batch = ((batch - norm_mean) / norm_std).to(device)
                    with torch.no_grad():
                        probs = F.softmax(model(batch), dim=1)
                    return probs.cpu().numpy()

                # Model's prediction on the normalized image
                with torch.no_grad():
                    pred_class = int(model(image.unsqueeze(0).to(device)).argmax(dim=1)[0])

                # Explain only the predicted class so that `score` refers to it.
                explainer = lime_image.LimeImageExplainer(random_state=seed)
                explanation = explainer.explain_instance(
                    input_image,
                    batch_predict,
                    labels=(pred_class,),
                    top_labels=None,
                    hide_color=0,
                    num_samples=lime_samples,
                    segmentation_fn=segmenter
                )

                # Depending on the lime version, `score` is a {label: R^2} dict or a float.
                score = explanation.score
                if isinstance(score, dict):
                    score = score[pred_class]
                fidelity_scores[model_name][class_name].append(float(score))

    # Calculate and display average fidelity scores
    print("\nAverage Fidelity Scores per Class:")
    for model_name in models_dict.keys():
        print(f"\n{model_name}:")
        for class_name in class_names:
            scores = fidelity_scores[model_name][class_name]
            avg_score = np.mean(scores) if scores else 0
            std_score = np.std(scores) if scores else 0
            print(f"{class_name}: {avg_score:.3f} ± {std_score:.3f}")

    return fidelity_scores

In [None]:
import json

# Quick run: 5 images per class.
fidelity_scores = calculate_class_fidelities(
    models_dict=models_dict,
    test_loader_norm=test_loader_norm,
    class_names=class_names,
    num_samples=5
)

# Saved under its own name so the larger run's 'fidelity_scores.json' does not overwrite it.
with open('fidelity_scores_small.json', 'w') as f:
    json.dump(
        {model_name: {cls: [float(s) for s in scores] for cls, scores in per_class.items()}
         for model_name, per_class in fidelity_scores.items()},
        f
    )

Larger set:

In [None]:
import json

# Full run: 50 images per class.
fidelity_scores = calculate_class_fidelities(
    models_dict=models_dict, test_loader_norm=test_loader_norm,
    class_names=class_names, num_samples=50
)

# Plain floats so the nested {model: {class: [scores]}} dict is JSON-serializable.
serializable_scores = {
    model_name: {cls: list(map(float, scores)) for cls, scores in per_class.items()}
    for model_name, per_class in fidelity_scores.items()
}
with open('fidelity_scores.json', 'w') as f:
    json.dump(serializable_scores, f)

# <a id='gen'>Sanity checks (Grad-CAM)</a>

In [None]:

class SaveFeatures:
    """Forward hook that keeps the latest output of `module` in `.features`.

    `.gradient` is initialised to None and never filled in by this class.
    Call `remove()` to detach the hook.
    """

    def __init__(self, module):
        self.features, self.gradient = None, None
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # Keep a reference to the module's output from the most recent forward pass.
        self.features = output

    def remove(self):
        self.hook.remove()

def get_last_conv_layer(model):
    """Return the last nn.Conv2d in `model.features`, or None if there is none."""
    convs = [layer for layer in model.features if isinstance(layer, torch.nn.Conv2d)]
    return convs[-1] if convs else None


def generate_gradcam(model, image, target_class=None):
    """Compute a Grad-CAM heatmap from the output of the last Conv2d in `model.features`.

    Args:
        model: module with a `features` container holding at least one nn.Conv2d.
        image: single CHW input tensor (moved to the model's device).
        target_class: class index to explain; defaults to the model's prediction.

    Returns:
        Detached 2-D tensor at the conv-output resolution, min-max scaled to [0, 1]
        (all zeros when the ReLU'd map is flat at zero).

    Raises:
        ValueError: if model.features contains no Conv2d.
    """
    model.eval()

    convs = [m for m in model.features if isinstance(m, torch.nn.Conv2d)]
    if not convs:
        raise ValueError("model.features contains no Conv2d layer")
    target_layer = convs[-1]

    captured = {}

    def save_features(module, input, output):
        output.retain_grad()  # keep the gradient of this non-leaf activation
        captured['features'] = output

    handle = target_layer.register_forward_hook(save_features)
    try:
        device = next(model.parameters()).device
        output = model(image.unsqueeze(0).to(device))

        if target_class is None:
            target_class = output.argmax(dim=1).item()

        model.zero_grad()
        output[0, target_class].backward()
    finally:
        # Always detach the hook, even if the forward/backward pass fails.
        handle.remove()

    features = captured['features']
    # Channel weights: spatially averaged gradients of the target score.
    weights = features.grad.mean(dim=(0, 2, 3))
    cam = (features.detach() * weights.view(1, -1, 1, 1)).mean(dim=1)[0]
    cam = F.relu(cam)
    # Min-max scaling; the epsilon guards a flat map against division by zero.
    return (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

def compare_models_gradcam(models, image, class_names=None):
    """Plot every model's Grad-CAM overlay for one image, side by side.

    Args:
        models: dict mapping a display name to a model.
        image: Image tensor without batch dimension, permuted with (1, 2, 0) for
            display (so presumably (C, H, W)). A copy is moved to each model's
            device before the forward passes.
        class_names: Optional list used to show the predicted label in titles.
    """
    num_models = len(models)
    plt.figure(figsize=(5 * num_models, 5))

    # Display copy rescaled to [0, 1]; computed once because it does not depend
    # on the model. A constant image would otherwise divide by zero.
    image_np = image.detach().permute(1, 2, 0).cpu().numpy()
    value_range = image_np.max() - image_np.min()
    image_np = (image_np - image_np.min()) / (value_range if value_range > 0 else 1.0)

    for idx, (name, model) in enumerate(models.items(), 1):
        plt.subplot(1, num_models, idx)

        # Predict in inference mode (generate_gradcam switches to eval only later)
        # and on the model's own device.
        model.eval()
        model_image = image.to(next(model.parameters()).device)
        with torch.no_grad():
            pred = model(model_image.unsqueeze(0)).argmax().item()

        heatmap = generate_gradcam(model, model_image).detach().cpu().numpy()
        heatmap = transforms.functional.resize(
            torch.from_numpy(heatmap).unsqueeze(0).unsqueeze(0),
            image_np.shape[:2]
        ).squeeze().numpy()

        plt.imshow(image_np)
        plt.imshow(heatmap, cmap='jet', alpha=0.5)
        title = f"{name}\n"
        if class_names:
            title += f"Pred: {class_names[pred]}"
        plt.title(title)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Models compared throughout the explainability sections (display name -> model).
models_dict = dict(
    base=base_model,
    trivial_aug=ta_model,
    gans=gan_model,
    combined=combined_model,
)

In [None]:
# Take a sample image from the first test batch (index 25, so the batch must
# hold at least 26 images).
images, labels = next(iter(test_loader_norm))
# Loader batches are CPU tensors; put the sample on the models' device so the
# forward passes in compare_models_gradcam do not fail with a device mismatch.
image = images[25].to(next(next(iter(models_dict.values())).parameters()).device)

# Compare Grad-CAM across all models
compare_models_gradcam(models_dict, image)

Grad-CAM with model predictions:

In [None]:

class_names = ['Glioma', 'Meningioma', 'No Tumor', 'Pituitary']

def show_gradcam_with_predictions(models_dict, dataset, index=0, device='cuda'):
    """
    Show GradCAM visualization for a specific image index from the dataset.

    The first panel is the input image with its true class. Each following panel
    overlays one model's Grad-CAM heatmap and shows that model's prediction and
    softmax confidence. Titles use the module-level ``class_names``.

    Each model is deep-copied, and the copy (moved to ``device``) is used for
    both the prediction and Grad-CAM. This keeps weights and input on the same
    device and leaves the caller's model neither moved nor hooked.

    Args:
        models_dict (dict): Dictionary of models to generate GradCAM for
        dataset: The full dataset (not DataLoader); ``dataset[index]`` must
            return ``(image_tensor, label)``.
        index (int): Index of the image to visualize
        device (str): Device to run the models on ('cuda' or 'cpu')
    """
    image, label = dataset[index]
    image = torch.as_tensor(image).to(device)
    true_label = int(label)

    num_plots = len(models_dict) + 1  # +1 for the original image
    plt.figure(figsize=(5 * num_plots, 5))

    # Original image, rescaled to [0, 1] for display (guarding constant images)
    plt.subplot(1, num_plots, 1)
    image_np = image.cpu().permute(1, 2, 0).numpy()
    value_range = image_np.max() - image_np.min()
    image_np = (image_np - image_np.min()) / (value_range if value_range > 0 else 1.0)
    plt.imshow(image_np)
    plt.title(f"Original Image\nTrue Class: {class_names[true_label]}")
    plt.axis('off')

    # GradCAM for each model
    for idx, (name, model) in enumerate(models_dict.items(), 2):
        plt.subplot(1, num_plots, idx)

        model_copy = copy.deepcopy(model).to(device)
        model_copy.eval()

        # Prediction from the copy, which is guaranteed to be on `device`
        with torch.no_grad():
            logits = model_copy(image.unsqueeze(0))
            pred_class = logits.argmax().item()
            confidence = torch.softmax(logits, dim=1)[0, pred_class].item()

        heatmap = generate_gradcam(model_copy, image.clone()).detach().cpu().numpy()

        # Free the copy before the next model
        del model_copy
        torch.cuda.empty_cache()

        # Resize heatmap to match the image size
        heatmap = transforms.functional.resize(
            torch.from_numpy(heatmap).unsqueeze(0).unsqueeze(0),
            image_np.shape[:2]
        ).squeeze().numpy()

        plt.imshow(image_np)
        plt.imshow(heatmap, cmap='jet', alpha=0.5)
        plt.title(f"{name}\nPred: {class_names[pred_class]}\nConf: {confidence:.2%}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

Misclassified examples:

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Use the underlying dataset (not the DataLoader) so samples can be picked by index.
for sample_index in (30, 44, 45, 50, 229):
    show_gradcam_with_predictions(models_dict, test_dataset_norm, index=sample_index, device=device)


In [None]:
for sample_index in (400, 300, 500, 1000, 1100):
    show_gradcam_with_predictions(models_dict, test_dataset_norm, index=sample_index, device=device)


In [None]:
for sample_index in (242, 361, 583, 1013, 20):
    show_gradcam_with_predictions(models_dict, test_dataset_norm, index=sample_index, device=device)

In [None]:
# Index 400 was noted as a meningioma example.
for sample_index in (400, 415, 425):
    show_gradcam_with_predictions(models_dict, test_dataset_norm, index=sample_index, device=device)

In [None]:
def analyze_gradcam_activation(heatmap, image):
    """
    Summarise a Grad-CAM heatmap with four scalar statistics.

    Args:
        heatmap: NumPy array holding the activation map. Values are compared
            against a fixed 0.5 threshold, so presumably already scaled to [0, 1].
        image: Accepted for API compatibility; not used.

    Returns:
        dict of Python floats:
            'mean_activation': mean of the map,
            'activation_std': standard deviation of the map,
            'coverage': fraction of entries strictly above 0.5,
            'focus_ratio': max / (mean + 1e-6).
    """
    significant_threshold = 0.5
    mean_value = np.mean(heatmap)
    return {
        'mean_activation': float(mean_value),
        'activation_std': float(np.std(heatmap)),
        'coverage': float(np.mean(heatmap > significant_threshold)),
        'focus_ratio': float(np.max(heatmap) / (mean_value + 1e-6)),
    }

def summarize_gradcam_results(models_dict, test_loader, class_names=None):
    """
    Generate a comprehensive summary of Grad-CAM results across models for the entire test set.

    For every (image, model) pair, the function:
      1. predicts the class;
      2. computes the Grad-CAM heatmap and upsamples it to the image's spatial size;
      3. collects the statistics from ``analyze_gradcam_activation`` plus the
         confidence and correctness of the prediction.

    Args:
        models_dict: Display name -> model.
        test_loader: DataLoader yielding ``(images, labels)`` batches.
        class_names: Forwarded to ``format_gradcam_summary``.

    Returns:
        The report string built by ``format_gradcam_summary`` (which also shows
        per-model histograms).
    """
    summary = {model_name: {'correct': 0, 'stats': []} for model_name in models_dict}
    total_samples = len(test_loader.dataset)

    # Switch every model to inference mode before the first prediction
    # (generate_gradcam would only do so afterwards), and remember each model's
    # device so loader tensors can be placed correctly.
    model_devices = {}
    for name, model in models_dict.items():
        model.eval()
        model_devices[name] = next(model.parameters()).device

    print(f"Analyzing {total_samples} images...")

    for images, labels in tqdm(test_loader, desc="Processing batches"):
        for image_cpu, label in zip(images, labels):
            true_label = label.item()

            for name, model in models_dict.items():
                image = image_cpu.to(model_devices[name])

                with torch.no_grad():
                    pred = model(image.unsqueeze(0))
                    pred_class = pred.argmax().item()
                    confidence = torch.softmax(pred, dim=1)[0, pred_class].item()

                heatmap = generate_gradcam(model, image).detach().cpu().numpy()
                heatmap = transforms.functional.resize(
                    torch.from_numpy(heatmap).unsqueeze(0).unsqueeze(0),
                    image.shape[1:]
                ).squeeze().numpy()

                stats = analyze_gradcam_activation(heatmap, image)
                stats['confidence'] = confidence
                stats['correct_prediction'] = (pred_class == true_label)

                summary[name]['stats'].append(stats)
                if stats['correct_prediction']:
                    summary[name]['correct'] += 1

    return format_gradcam_summary(summary, total_samples, class_names)

def _plot_gradcam_histograms(model_name, stats):
    """Show a 2x2 grid of histograms of one model's per-image Grad-CAM statistics."""
    panels = [
        ('mean_activation', 'Distribution of Mean Activation'),
        ('coverage', 'Distribution of Coverage'),
        ('focus_ratio', 'Distribution of Focus Ratio'),
        ('confidence', 'Distribution of Confidence'),
    ]
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    for ax, (key, title) in zip(axes.flat, panels):
        ax.hist([s[key] for s in stats], bins=30)
        ax.set_title(title)
        ax.set_xlabel(key)
        ax.set_ylabel('Number of images')
    # Name the model; otherwise the per-model figures are indistinguishable.
    fig.suptitle(f'Grad-CAM statistics: {model_name}')
    fig.tight_layout()
    plt.show()


def format_gradcam_summary(summary, total_samples, class_names):
    """
    Format the Grad-CAM analysis results into a readable report.

    Args:
        summary: model name -> {'correct': int, 'stats': list of per-image dicts
            with 'mean_activation', 'coverage', 'focus_ratio', 'confidence'}.
        total_samples: Number of analysed images (accuracy denominator).
        class_names: Currently unused; kept for API compatibility.

    Returns:
        The report string. Also shows one histogram figure per model that has
        statistics.
    """
    report = "Grad-CAM Analysis Summary\n" + "="*50 + "\n\n"
    report += f"Total samples analyzed: {total_samples}\n\n"

    metric_keys = ('mean_activation', 'coverage', 'focus_ratio', 'confidence')

    for model_name, model_data in summary.items():
        report += f"\nModel: {model_name}\n{'-'*30}\n"

        # Accuracy (0 for an empty test set instead of ZeroDivisionError)
        accuracy = model_data['correct'] / total_samples if total_samples else 0.0
        report += f"Accuracy: {accuracy:.2%}\n"

        stats = model_data['stats']
        if not stats:
            report += "No Grad-CAM statistics collected.\n"
            continue

        columns = {key: [s[key] for s in stats] for key in metric_keys}
        avg_stats = {key: np.mean(values) for key, values in columns.items()}
        std_stats = {key: np.std(values) for key, values in columns.items()}

        report += f"Average Activation: {avg_stats['mean_activation']:.3f}\n"
        report += f"Average Coverage: {avg_stats['coverage']:.2%}\n"
        report += f"Focus Ratio: {avg_stats['focus_ratio']:.2f}\n"
        report += f"Average Confidence: {avg_stats['confidence']:.2%}\n"

        # Spread across images. "Activation Std" is the std of the per-image
        # mean activation, not the within-map std.
        report += "\nVariability Analysis:\n"
        report += f"Activation Std: {std_stats['mean_activation']:.3f}\n"
        report += f"Coverage Std: {std_stats['coverage']:.2%}\n"
        report += f"Focus Ratio Std: {std_stats['focus_ratio']:.2f}\n"
        report += f"Confidence Std: {std_stats['confidence']:.2%}\n"

        # Heuristic reading of the activation patterns
        report += "\nActivation Analysis:\n"
        if avg_stats['coverage'] > 0.7:
            report += "- Wide activation pattern (might be looking at too much)\n"
        elif avg_stats['coverage'] < 0.3:
            report += "- Focused activation pattern (concentrated attention)\n"

        if avg_stats['focus_ratio'] > 5:
            report += "- High focus ratio (very specific features)\n"
        elif avg_stats['focus_ratio'] < 2:
            report += "- Low focus ratio (more distributed attention)\n"

        _plot_gradcam_histograms(model_name, stats)

    return report

# Convenience wrapper around summarize_gradcam_results:
def run_gradcam_analysis(models_dict, test_loader, class_names):
    """Run the Grad-CAM summary over the whole test loader and print the report."""
    print(summarize_gradcam_results(models_dict, test_loader, class_names=class_names))

In [None]:
# Full test-set pass: Grad-CAM for every image and every model.
run_gradcam_analysis(
    models_dict=models_dict,
    test_loader=test_loader_norm,
    class_names=class_names,
)

SSIM sanity check (Grad-CAM similarity under progressive weight randomization)

In [None]:
def randomize_layer_weights(model, num_layers):
    """Overwrite the last ``num_layers * 2`` parameter tensors of ``model`` with N(0, 1) noise.

    Parameters are taken from the end of ``list(model.parameters())`` and
    processed from the last one backwards. The factor 2 assumes every "layer"
    owns exactly two parameter tensors (weight and bias).
    NOTE(review): that holds for Linear/Conv-with-bias but not for
    bias-free convolutions, so "layers" here are really pairs of tensors.

    Args:
        model: torch.nn.Module whose parameters are modified in place.
        num_layers: Number of (weight, bias) pairs to randomize; capped at the
            number of available parameter tensors.

    Returns:
        (original, index): a clone of the ORIGINAL value of the last tensor
        processed (the deepest one) and its negative index into
        ``list(model.parameters())``. Only that single tensor can be restored
        from the return value; snapshot ``model.state_dict()`` beforehand to
        undo everything.

    Raises:
        ValueError: if nothing would be randomized (``num_layers <= 0`` or the
            model has no parameters).
    """
    params = list(model.parameters())
    count = min(num_layers * 2, len(params))
    if count <= 0:
        raise ValueError("randomize_layer_weights: nothing to randomize "
                         f"(num_layers={num_layers}, parameters={len(params)})")

    original_weights = None
    for offset in range(1, count + 1):
        param = params[-offset]
        original_weights = param.data.clone()
        param.data = torch.randn_like(original_weights)
    return original_weights, -count

def reset_layer_weights(model, weights, layer_idx):
    """Install ``weights`` as the data of one parameter of ``model``.

    ``layer_idx`` indexes ``list(model.parameters())``; negative values count
    from the end, as in the index returned by ``randomize_layer_weights``.
    """
    target_param = list(model.parameters())[layer_idx]
    target_param.data = weights

def _ssim_cascade_scores(model, image, num_layers):
    """SSIM between the trained model's Grad-CAM mask and the masks after randomizing the top 1..num_layers layers.

    The complete state dict is snapshotted first and restored afterwards (also
    on error), so the trained weights are left intact for later cells.
    """
    snapshot = copy.deepcopy(model.state_dict())
    try:
        original_mask = generate_gradcam(model, image).detach().cpu().numpy()
        scores = [1.0]  # SSIM of the mask with itself
        for layer in range(num_layers):
            # Cascading randomization: step k re-randomizes the last 2*(k+1)
            # parameter tensors (everything from earlier steps plus two more).
            randomize_layer_weights(model, layer + 1)
            random_mask = generate_gradcam(model, image).detach().cpu().numpy()
            scores.append(ssim(original_mask, random_mask, data_range=1.0))
        return scores
    finally:
        model.load_state_dict(snapshot)


def _plot_ssim_stats(ssim_stats, x_labels, samples_processed):
    """Plot mean ± std SSIM per randomization step for every model."""
    plt.figure(figsize=(12, 7))
    for name, stats in ssim_stats.items():
        plt.errorbar(range(len(stats['mean'])),
                     stats['mean'],
                     yerr=stats['std'],
                     marker='o',
                     label=f'{name}',
                     capsize=5)

    plt.xticks(range(len(x_labels)), x_labels, rotation=45)
    plt.ylabel('SSIM score (mean ± std)')
    plt.xlabel('Randomized layers')
    plt.grid(True)
    plt.legend()
    plt.title(f'Sanity Check: SSIM Similarity vs. Layer Randomization\n(n={samples_processed} images)')
    plt.tight_layout()
    plt.show()


def calculate_ssim_masks_batch(models_dict, test_loader, num_layers=6, num_samples=None):
    """
    Calculate SSIM scores for multiple images (cascading weight-randomization sanity check).

    For each image and model, the Grad-CAM mask of the trained model is compared
    (SSIM, data_range=1) with the masks obtained after randomizing the last
    2, 4, ..., 2*num_layers parameter tensors. The weights are restored after
    every image.

    Args:
        models_dict: Dictionary of models to analyze
        test_loader: DataLoader containing test images
        num_layers: Number of layers to progressively randomize
        num_samples: Optional number of samples to analyze (None for all)

    Returns:
        dict: model name -> {'mean': [...], 'std': [...]}, one entry per
        randomization step (num_layers + 1 entries, the first being 1.0), or
        empty lists when no image was processed.

    NOTE(review): the tick labels ('classifier', 'denseblock4', ...) assume each
    step of two parameter tensors matches one named block; a step is really
    "the next two entries of model.parameters()" (which may also be e.g. a
    normalization layer's weight and bias). Confirm against the architecture.
    """
    base_labels = ['original', 'classifier', 'denseblock4', 'denseblock3',
                   'denseblock2', 'denseblock1']
    # Pad with generic labels so each of the num_layers + 1 points gets a tick.
    x_labels = (base_labels
                + [f'top {k} layers' for k in range(len(base_labels), num_layers + 1)]
                )[:num_layers + 1]

    ssim_scores = {name: [[] for _ in range(num_layers + 1)] for name in models_dict}
    total_samples = len(test_loader.dataset) if num_samples is None else num_samples
    samples_processed = 0

    for images, _ in tqdm(test_loader, desc="Processing images"):
        if samples_processed >= total_samples:
            break  # quota reached: stop loading further batches
        for image in images:
            if samples_processed >= total_samples:
                break
            for name, model in models_dict.items():
                model_image = image.to(next(model.parameters()).device)
                for step, score in enumerate(_ssim_cascade_scores(model, model_image, num_layers)):
                    ssim_scores[name][step].append(score)
            samples_processed += 1

    ssim_stats = {
        name: {
            'mean': [np.mean(step_scores) for step_scores in per_step if step_scores],
            'std': [np.std(step_scores) for step_scores in per_step if step_scores],
        }
        for name, per_step in ssim_scores.items()
    }

    _plot_ssim_stats(ssim_stats, x_labels, samples_processed)
    return ssim_stats

# Convenience wrapper with the number of randomization steps fixed to 6:
def run_ssim_analysis(models_dict, test_loader, num_samples=None):
    """
    Run the SSIM sanity check with 6 cascading randomization steps.

    Args:
        models_dict: Dictionary of models to analyze
        test_loader: DataLoader containing test images
        num_samples: Optional number of samples to analyze (None for all)

    Returns:
        The per-model statistics returned by ``calculate_ssim_masks_batch``.
    """
    return calculate_ssim_masks_batch(models_dict, test_loader,
                                      num_layers=6, num_samples=num_samples)

In [None]:
# SSIM sanity check on 100 test images.
stats = run_ssim_analysis(
    models_dict=models_dict,
    test_loader=test_loader_norm,
    num_samples=100,
)

# <a id='gen'>Sparsity using shap: </a>

(Before running SHAP, make sure no hooks left over from the Grad-CAM cells are still registered on the models.)

In [None]:
def _predicted_class_shap_map(shap_values, idx, pred_class, image_ndim):
    """|SHAP| map of image ``idx`` for output ``pred_class``.

    Handles both SHAP return formats: a list with one array per output, or a
    single array whose trailing axis indexes outputs (batch, *image_shape, n_outputs).
    A single-output array (batch, *image_shape) is used as is.
    """
    if isinstance(shap_values, list):
        return np.abs(shap_values[pred_class][idx])
    values = np.asarray(shap_values)
    if values.ndim == image_ndim + 2:
        return np.abs(values[idx, ..., pred_class])
    return np.abs(values[idx])


def calculate_shap_sparsity(models_dict, test_loader_norm, class_names, num_samples=15):
    """Calculate SHAP sparsity scores for each model and class.

    The sparsity of one explanation is the fraction of |SHAP| entries below 5% of
    that explanation's maximum. The explanation used is always the one for the
    model's predicted class, whichever format SHAP returns.

    Args:
        models_dict (dict): Dictionary of model name to model. Models are
            moved to the GPU (in place) when CUDA is available.
        test_loader_norm (DataLoader): Test data loader yielding (images, labels)
        class_names (list): List of class names, indexed by label
        num_samples (int): Number of samples per class to analyze

    Returns:
        dict: model name -> class name -> list of sparsity values.
    """
    sparsity_scores = {model_name: {class_name: [] for class_name in class_names}
                       for model_name in models_dict}

    # Collect up to num_samples images per class
    print("Collecting samples...")
    class_samples = {class_name: [] for class_name in class_names}
    for images, labels in test_loader_norm:
        for img, label in zip(images, labels):
            class_name = class_names[label.item()]
            if len(class_samples[class_name]) < num_samples:
                class_samples[class_name].append(img)
        if all(len(samples) >= num_samples for samples in class_samples.values()):
            break

    use_cuda = torch.cuda.is_available()
    batch_size = 5

    for model_name, model in models_dict.items():
        print(f"\nProcessing model: {model_name}")
        model.eval()
        if use_cuda:
            model = model.cuda()
        # Built lazily once per model: the all-zero background has the same
        # shape for every batch, so the explainer does not need rebuilding.
        explainer = None

        for class_name in class_names:
            print(f"\nProcessing class: {class_name}")
            samples = class_samples[class_name]
            if not samples:
                continue

            for start in range(0, len(samples), batch_size):
                batch = torch.stack(samples[start:start + batch_size])
                if use_cuda:
                    batch = batch.cuda()

                try:
                    if explainer is None:
                        explainer = shap.GradientExplainer(model, torch.zeros_like(batch[0:1]))
                    shap_values = explainer.shap_values(batch)

                    with torch.no_grad():
                        pred_classes = model(batch).argmax(dim=1).tolist()

                    for idx, pred_class in enumerate(pred_classes):
                        shap_map = _predicted_class_shap_map(shap_values, idx, pred_class, batch.dim() - 1)

                        # Sparsity with a 5% threshold
                        threshold = np.max(shap_map) * 0.05
                        sparsity = np.mean(shap_map < threshold)
                        sparsity_scores[model_name][class_name].append(sparsity)

                        print(f"Image {start+idx+1}/{len(samples)}: Sparsity = {sparsity:.3f}")

                except Exception as e:
                    # Best effort: report the failure and continue with the next batch.
                    print(f"Error processing batch: {str(e)}")
                    continue

                finally:
                    torch.cuda.empty_cache()

    print("\nFinal Sparsity Scores:")
    for model_name in models_dict:
        print(f"\n{model_name}:")
        for class_name in class_names:
            scores = sparsity_scores[model_name][class_name]
            if scores:
                print(f"{class_name}: {np.mean(scores):.3f} ± {np.std(scores):.3f} (n={len(scores)})")
            else:
                print(f"{class_name}: No valid scores")

    return sparsity_scores

In [None]:
sparsity_scores = calculate_shap_sparsity(models_dict, test_loader_norm, class_names)

# <a id='gen'>Class-wise contrastivity - using shap: </a>

Try on a small set of samples:

In [None]:
def calculate_class_contrast(models_dict, test_loader_norm, class_names, num_samples=10):
    """
    Calculate and visualize class-wise SHAP contrast for each model (debug version).

    For each class, up to ``num_samples`` images (the first ones the loader
    yields) are explained with ``shap.GradientExplainer`` against an all-zero
    background. Each explanation is reduced to a 2-D map by averaging |SHAP|
    over the channel axis and over all model outputs. The maps are then
    averaged per class, and every class pair is plotted as a difference image.
    Verbose debug output is printed along the way.

    Returns:
        dict: model name -> {class name: average 2-D |SHAP| map, or None if
        no explanation was obtained for that class}.
    """
    class_contrasts = {}

    print("\nCollecting samples...")
    all_images = {class_name: [] for class_name in class_names}
    with torch.no_grad():
        for images, labels in test_loader_norm:
            for img, label in zip(images, labels):
                class_name = class_names[label.item()]
                if len(all_images[class_name]) < num_samples:
                    all_images[class_name].append(img.cpu())
            # Stop once every class has enough samples
            if all(len(imgs) >= num_samples for imgs in all_images.values()):
                break

    for class_name in class_names:
        print(f"Collected {len(all_images[class_name])} images for {class_name}")

    if any(len(imgs) < num_samples for imgs in all_images.values()):
        print("Warning: Could not collect enough samples for all classes")

    for model_name, model in models_dict.items():
        print(f"\nProcessing model: {model_name}")
        class_shap_values = {class_name: [] for class_name in class_names}
        model.eval()
        device = next(model.parameters()).device

        for class_name in class_names:
            print(f"\nProcessing class: {class_name}")
            samples = all_images[class_name][:num_samples]
            if not samples:
                print(f"Warning: No samples found for class {class_name}")
                continue

            for i, image in enumerate(samples):
                try:
                    print(f"\nProcessing image {i+1}/{len(samples)}")
                    image = image.to(device)

                    # All-zero background
                    background = torch.zeros_like(image.unsqueeze(0))
                    explainer = shap.GradientExplainer(model, background)

                    shap_values = explainer.shap_values(image.unsqueeze(0))
                    print(f"Raw SHAP values type: {type(shap_values)}")
                    print(f"Raw SHAP values shape/length: {len(shap_values) if isinstance(shap_values, list) else shap_values.shape}")

                    # Bring both SHAP formats (list of per-output arrays, or one
                    # array with a trailing output axis) to (1, *image_shape, n_outputs),
                    # then average |SHAP| over channels (axis 0) and outputs (last axis).
                    if isinstance(shap_values, list):
                        values = np.stack(shap_values, axis=-1)
                    else:
                        values = np.asarray(shap_values)
                    shap_map = np.mean(np.abs(values[0]), axis=(0, -1))

                    print(f"Processed SHAP map shape: {shap_map.shape}")
                    print(f"SHAP values range: [{np.min(shap_map):.3f}, {np.max(shap_map):.3f}]")

                    class_shap_values[class_name].append(shap_map)

                except Exception as e:
                    # Best effort: report and continue with the next image.
                    print(f"Error processing image {i}: {str(e)}")
                    continue

        avg_shap_values = {
            class_name: np.mean(values, axis=0) if values else None
            for class_name, values in class_shap_values.items()
        }

        for class_name, avg_val in avg_shap_values.items():
            if avg_val is not None:
                print(f"\nAverage SHAP values for {class_name}:")
                print(f"Shape: {avg_val.shape}")
                print(f"Range: [{np.min(avg_val):.3f}, {np.max(avg_val):.3f}]")

        # Visualize class contrasts
        class_pairs = list(itertools.combinations(class_names, 2))
        num_comparisons = len(class_pairs)
        if num_comparisons > 0:
            plt.figure(figsize=(15, 5*((num_comparisons+1)//2)))

            for i, (class1, class2) in enumerate(class_pairs):
                if avg_shap_values[class1] is None or avg_shap_values[class2] is None:
                    print(f"\nSkipping {class1} vs {class2} - missing values")
                    continue

                contrast = avg_shap_values[class1] - avg_shap_values[class2]
                print(f"\nContrast shape for {class1} vs {class2}: {contrast.shape}")
                print(f"Contrast range: [{np.min(contrast):.3f}, {np.max(contrast):.3f}]")

                # Symmetric colour scale around zero; skip near-empty contrasts
                vmax = np.max(np.abs(contrast))
                if vmax < 1e-6:
                    print(f"Warning: Very small contrast between {class1} and {class2}")
                    continue

                norm = TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax)

                plt.subplot((num_comparisons+1)//2, 2, i+1)
                plt.imshow(contrast, cmap='RdBu_r', norm=norm)
                plt.title(f'{class1} vs {class2}')
                plt.colorbar()

            plt.suptitle(f'Class Contrasts - {model_name}')
            plt.tight_layout()
            plt.show()
        else:
            print("Warning: Not enough classes for comparison")

        class_contrasts[model_name] = avg_shap_values

    return class_contrasts

In [None]:
# 10 samples per class (the function's default)
contrasts = calculate_class_contrast(
    models_dict,
    test_loader_norm,
    class_names,
)

In [None]:
# NOTE(review): same call as the previous cell; it recomputes and re-plots.
contrasts = calculate_class_contrast(
    models_dict=models_dict,
    test_loader_norm=test_loader_norm,
    class_names=class_names,
)

New version (pairwise contrastivity matrix):

In [None]:
def calculate_class_contrast_matrix(models_dict, test_loader_norm, class_names, num_samples=10, seed=None):
    """
    Calculate class-wise contrastivity matrix using SHAP values.

    For every class, up to ``num_samples`` test images are drawn at random ONCE,
    and the same images are explained for every model, so the models' matrices
    are comparable. Each explanation is |SHAP| averaged over the model outputs,
    giving a per-channel, per-pixel map. Entry (i, j) is the mean absolute
    difference between the class-i and class-j average maps; the diagonal
    stays 0 and is printed as '-'.

    Args:
        models_dict: Model name -> model.
        test_loader_norm: DataLoader yielding (images, labels); the whole set
            is loaded into memory.
        class_names: Class names, indexed by label.
        num_samples: Images per class.
        seed: Optional int for a reproducible draw; ``None`` uses NumPy's
            global random state.

    Returns:
        dict: model name -> (n_classes, n_classes) numpy array.
    """
    class_contrasts = {}

    all_images = []
    all_labels = []
    with torch.no_grad():
        for images, labels in test_loader_norm:
            all_images.extend(images.cpu())
            all_labels.extend(labels.cpu())
    all_labels = np.array([label.item() for label in all_labels])

    # One shared draw per class so every model explains the same images.
    rng = np.random if seed is None else np.random.RandomState(seed)
    selected_per_class = {}
    for class_idx, class_name in enumerate(class_names):
        class_indices = np.where(all_labels == class_idx)[0]
        selected_per_class[class_name] = rng.choice(
            class_indices, size=min(num_samples, len(class_indices)), replace=False)

    for model_name, model in models_dict.items():
        print(f"\nProcessing model: {model_name}")
        class_shap_values = {class_name: [] for class_name in class_names}
        model.eval()
        device = next(model.parameters()).device

        for class_name in class_names:
            for img_idx in selected_per_class[class_name]:
                image = all_images[img_idx].to(device)

                try:
                    background = torch.zeros_like(image.unsqueeze(0))
                    explainer = shap.GradientExplainer(model, background)
                    shap_values = explainer.shap_values(image.unsqueeze(0))

                    # Bring both SHAP formats (list of per-output arrays, or one
                    # array with a trailing output axis) to (1, *image_shape, n_outputs),
                    # then average |SHAP| over the outputs.
                    if isinstance(shap_values, list):
                        values = np.stack(shap_values, axis=-1)
                    else:
                        values = np.asarray(shap_values)
                    shap_map = np.mean(np.abs(values[0]), axis=-1)

                    class_shap_values[class_name].append(shap_map)

                except Exception as e:
                    # Best effort: report and continue with the next image.
                    print(f"Error processing image {img_idx}: {str(e)}")
                    continue

        avg_shap_values = {
            class_name: np.mean(values, axis=0) if values else None
            for class_name, values in class_shap_values.items()
        }

        n_classes = len(class_names)
        contrast_matrix = np.zeros((n_classes, n_classes))

        for i, class1 in enumerate(class_names):
            for j, class2 in enumerate(class_names):
                if i != j and avg_shap_values[class1] is not None and avg_shap_values[class2] is not None:
                    # Unnormalized contrast: mean absolute difference of the class maps
                    diff = avg_shap_values[class1] - avg_shap_values[class2]
                    contrast_matrix[i, j] = np.mean(np.abs(diff))

        # Print formatted table
        print(f"\nClass-wise Contrastivity Scores for {model_name}")
        print("=" * 50)
        print(f"{'':15}", end="")
        for class_name in class_names:
            print(f"{class_name:12}", end="")
        print()

        for i, class1 in enumerate(class_names):
            print(f"{class1:15}", end="")
            for j in range(n_classes):
                if i == j:
                    print(f"{'-':12}", end="")
                else:
                    print(f"{contrast_matrix[i,j]:12.3f}", end="")
            print()

        class_contrasts[model_name] = contrast_matrix

    return class_contrasts

In [None]:
models_dict = dict(
    base=base_model,
    trivial_aug=ta_model,
    gans=gan_model,
    combined=combined_model,
)

class_names = ['Glioma', 'Meningioma', 'No Tumor', 'Pituitary']

# 10 SHAP explanations per class for each model (adjust as needed)
contrast_matrices = calculate_class_contrast_matrix(
    models_dict=models_dict,
    test_loader_norm=test_loader_norm,
    class_names=class_names,
    num_samples=10,
)

Contrastivity heatmaps using SHAP:

In [None]:
def create_enhanced_heatmaps(models_dict, test_loader_norm, test_loader, class_names, num_samples=20):
    """
    Plot class-specific SHAP importance maps for every model on one shared colour scale.

    For each (model, class), up to ``num_samples`` randomly chosen test images of
    that class are explained with ``shap.GradientExplainer`` (all-zeros baseline)
    w.r.t. that class's output. The mean |SHAP| over channels is averaged across
    the images and drawn in a (models x classes) grid. All panels share
    ``vmin=0`` and the global maximum, so intensities are comparable across panels.

    Args:
        models_dict: Mapping of display name -> PyTorch model.
        test_loader_norm: Loader yielding normalised ``(images, labels)`` batches.
        test_loader: Unused; kept so existing calls keep working.
        class_names: Class names indexed by label value.
        num_samples: Maximum number of images explained per class.

    Returns:
        matplotlib.figure.Figure: One panel per (model, class). Panels for which
        no SHAP values could be computed are labelled instead of shifting the grid.

    Raises:
        ValueError: If no SHAP map could be computed at all.
    """
    def _class_map(shap_values, class_idx):
        """Mean |SHAP| over channels of the first image, for output ``class_idx``."""
        if isinstance(shap_values, list):
            # Older SHAP: one (batch, C, H, W) array per model output.
            per_channel = np.asarray(shap_values[class_idx][0])
        else:
            per_channel = np.asarray(shap_values)[0]
            if per_channel.ndim == 4:
                # SHAP >= 0.45: (C, H, W, n_outputs) -- outputs live on the last axis.
                per_channel = per_channel[..., class_idx]
        return np.mean(np.abs(per_channel), axis=0)

    # Gather every normalised test image once so images can be sampled per class.
    all_images_norm = []
    all_labels = []
    for images, labels in test_loader_norm:
        all_images_norm.append(images)
        all_labels.append(labels)
    all_images_norm = torch.cat(all_images_norm, dim=0)
    all_labels = torch.cat(all_labels, dim=0).cpu().numpy()

    num_models = len(models_dict)
    num_classes = len(class_names)

    # Grid sized from the inputs (was a hard-coded 4x4).
    fig = plt.figure(figsize=(6 * num_classes, 5 * num_models))
    gs = GridSpec(num_models, num_classes, figure=fig, hspace=0.3, wspace=0.3)
    fig.suptitle('Enhanced Class-Specific SHAP Feature Importance Maps', fontsize=16, y=0.95)

    # First pass: averaged maps keyed by (model, class), so that a class with no
    # usable samples cannot shift the panels that follow it.
    avg_maps = {}
    for model_name, model in models_dict.items():
        print(f"\nCollecting SHAP values for model: {model_name}")
        model.eval()
        device = next(model.parameters()).device
        explainer = None  # built lazily; the zero baseline only depends on the input shape

        for class_idx, class_name in enumerate(class_names):
            class_indices = np.where(all_labels == class_idx)[0]
            selected_indices = np.random.choice(class_indices,
                                                size=min(num_samples, len(class_indices)),
                                                replace=False)

            map_sum = None
            count = 0
            for img_idx in selected_indices:
                image = all_images_norm[img_idx].unsqueeze(0).to(device)
                try:
                    if explainer is None:
                        explainer = shap.GradientExplainer(model, torch.zeros_like(image))
                    shap_map = _class_map(explainer.shap_values(image), class_idx)
                except Exception as e:
                    print(f"Error processing image {img_idx}: {str(e)}")
                    continue
                map_sum = shap_map if map_sum is None else map_sum + shap_map
                count += 1

            if count > 0:
                avg_maps[(model_name, class_name)] = map_sum / count

    if not avg_maps:
        raise ValueError("No SHAP values could be computed for any model/class combination.")

    # Global maximum gives every panel the same colour scale.
    global_max = max(float(np.max(m)) for m in avg_maps.values())

    # Second pass: draw the grid.
    for model_idx, model_name in enumerate(models_dict):
        for class_idx, class_name in enumerate(class_names):
            ax = fig.add_subplot(gs[model_idx, class_idx])
            ax.set_title(f'{model_name}\n{class_name}', fontsize=12, pad=10)
            ax.axis('off')

            shap_map = avg_maps.get((model_name, class_name))
            if shap_map is None:
                ax.text(0.5, 0.5, 'no SHAP values', ha='center', va='center',
                        transform=ax.transAxes)
                continue

            im = ax.imshow(shap_map, cmap='viridis', vmin=0, vmax=global_max)
            plt.colorbar(im, ax=ax, label='SHAP value', format='%.2e')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    return fig

In [None]:
# Class-specific SHAP maps for every model, 20 images per class.
fig = create_enhanced_heatmaps(models_dict, test_loader_norm, test_loader, class_names, num_samples=20)
plt.show()

In [None]:
def process_shap_values(shap_values, class_idx=None, verbose=False):
    """
    Reduce the SHAP output for one image to a 2-D |SHAP| importance map.

    Two result layouts are supported:
      * older SHAP releases return a list with one array per model output,
        each shaped (batch, C, H, W);
      * SHAP >= 0.45 returns one array shaped (batch, C, H, W, n_outputs).

    Args:
        shap_values: Result of ``explainer.shap_values(x)``. Only the first
            image of the batch is used.
        class_idx: Model output to explain. ``None`` means output 0 for the
            list layout and the average over all outputs for the array layout.
        verbose: Print the map shape before and after reduction (debug aid).

    Returns:
        np.ndarray: Mean absolute SHAP value over the channel axis, shape (H, W),
        for 3-D/4-D per-image maps. Lower-rank maps are returned unchanged.
    """
    if isinstance(shap_values, list):
        output_idx = 0 if class_idx is None else class_idx
        shap_map = np.asarray(shap_values[output_idx][0])
    else:
        shap_map = np.asarray(shap_values)[0]

    if verbose:
        print("Shape before reduction:", shap_map.shape)

    if shap_map.ndim == 4:
        # (C, H, W, n_outputs): the trailing axis indexes model outputs, not RGB.
        magnitude = np.abs(shap_map)
        if class_idx is not None:
            reduced = magnitude[..., class_idx].mean(axis=0)
        else:
            reduced = magnitude.mean(axis=(0, -1))
    elif shap_map.ndim == 3:
        # (C, H, W) channels-first, or (H, W, 3) channels-last RGB.
        channel_axis = -1 if shap_map.shape[-1] == 3 else 0
        reduced = np.abs(shap_map).mean(axis=channel_axis)
    else:
        reduced = shap_map

    if verbose:
        print("Shape after reduction:", reduced.shape)
    return reduced

def create_enhanced_contrast_heatmaps(models_dict, test_loader_norm, class_names, num_samples=20):
    """
    Plot pairwise class-contrast SHAP maps for every model.

    For each class, the first ``num_samples`` test images of that class are
    explained with ``shap.GradientExplainer`` (all-zeros baseline) w.r.t. that
    class's own output. Each image is reduced to a 2-D map by the notebook-level
    ``process_shap_values``, and the maps are averaged per class. Every class pair
    is then shown as ``avg(class1) - avg(class2)``, scaled to [-1, 1] by its
    largest magnitude: red areas matter more for class1, blue areas for class2.

    Args:
        models_dict: Mapping of display name -> PyTorch model.
        test_loader_norm: Loader yielding normalised ``(images, labels)`` batches.
        class_names: Class names indexed by label value.
        num_samples: Maximum number of images explained per class.

    Returns:
        matplotlib.figure.Figure: One row per model, one column per class pair.
        Pairs that cannot be computed are reported and left blank.
    """
    # Collect up to num_samples images per class (shared by all models).
    class_samples = {class_name: [] for class_name in class_names}
    print("Collecting samples...")
    for images, labels in test_loader_norm:
        for img, label in zip(images, labels):
            bucket = class_samples[class_names[label.item()]]
            if len(bucket) < num_samples:
                bucket.append(img)
        if all(len(bucket) >= num_samples for bucket in class_samples.values()):
            break

    class_pairs = list(itertools.combinations(class_names, 2))
    num_models = len(models_dict)

    fig = plt.figure(figsize=(20, 5 * num_models))
    gs = GridSpec(num_models, len(class_pairs), figure=fig, hspace=0.4, wspace=0.3)
    fig.suptitle('Class Contrast SHAP Feature Importance Maps', fontsize=16, y=0.95)

    for model_idx, (model_name, model) in enumerate(models_dict.items()):
        print(f"\nProcessing model: {model_name}")
        model.eval()
        device = next(model.parameters()).device
        explainer = None  # built lazily; the zero baseline only depends on the input shape

        # Average |SHAP| map per class, explained w.r.t. the class's own output so
        # the meaning does not depend on the SHAP version's result layout.
        class_avg_shap = {}
        for class_idx, class_name in enumerate(class_names):
            print(f"Processing class: {class_name}")
            if not class_samples[class_name]:
                print(f"Warning: No samples for class {class_name}")
                continue

            map_sum = None
            count = 0
            for image in class_samples[class_name]:
                batch = image.to(device).unsqueeze(0)
                try:
                    if explainer is None:
                        explainer = shap.GradientExplainer(model, torch.zeros_like(batch))
                    shap_map = process_shap_values(explainer.shap_values(batch), class_idx)
                except Exception as e:
                    print(f"Error processing image: {str(e)}")
                    continue
                map_sum = shap_map if map_sum is None else map_sum + shap_map
                count += 1

            if count > 0:
                class_avg_shap[class_name] = map_sum / count

        for comp_idx, (class1, class2) in enumerate(class_pairs):
            if class1 not in class_avg_shap or class2 not in class_avg_shap:
                print(f"Skipping {class1} vs {class2} - missing values")
                continue

            contrast = class_avg_shap[class1] - class_avg_shap[class2]
            if contrast.ndim != 2:
                # imshow would misread e.g. (H, W, n_outputs) as an RGB(A) image.
                print(f"Warning: unexpected SHAP map shape {contrast.shape} for {class1} vs {class2}; skipping")
                continue

            vmax = np.max(np.abs(contrast))
            if vmax < 1e-6:
                print(f"Warning: Very small contrast between {class1} and {class2}")
                continue
            contrast_normalized = contrast / (vmax + 1e-7)

            ax = fig.add_subplot(gs[model_idx, comp_idx])
            im = ax.imshow(contrast_normalized,
                           cmap='RdBu_r',
                           vmin=-1,
                           vmax=1,
                           interpolation='nearest')
            ax.set_title(f'{model_name}\n{class1} vs {class2}', fontsize=10)
            plt.colorbar(im, ax=ax, format='%.2f')
            ax.axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    return fig

def visualize_single_class_features(models_dict, test_loader_norm, class_names, num_samples=20):
    """
    Plot one class-specific SHAP importance map per (model, class).

    The first ``num_samples`` test images of each class are explained with
    ``shap.GradientExplainer`` (all-zeros baseline) w.r.t. that class's output.
    Each image is reduced to a 2-D map by the notebook-level
    ``process_shap_values``, and the maps are averaged. Every panel is scaled by
    its own maximum, so colours are comparable within a panel but not across panels.

    Args:
        models_dict: Mapping of display name -> PyTorch model.
        test_loader_norm: Loader yielding normalised ``(images, labels)`` batches.
        class_names: Class names indexed by label value.
        num_samples: Maximum number of images explained per class.

    Returns:
        matplotlib.figure.Figure: A (models x classes) grid of heatmaps.
    """
    num_models = len(models_dict)
    num_classes = len(class_names)

    fig = plt.figure(figsize=(4 * num_classes, 4 * num_models))
    gs = GridSpec(num_models, num_classes, figure=fig, hspace=0.3, wspace=0.3)
    fig.suptitle('Single Class SHAP Feature Importance Maps', fontsize=16, y=0.95)

    # Collect up to num_samples images per class (shared by all models).
    class_samples = {class_name: [] for class_name in class_names}
    print("Collecting samples...")
    for images, labels in test_loader_norm:
        for img, label in zip(images, labels):
            bucket = class_samples[class_names[label.item()]]
            if len(bucket) < num_samples:
                bucket.append(img)
        if all(len(bucket) >= num_samples for bucket in class_samples.values()):
            break

    for model_idx, (model_name, model) in enumerate(models_dict.items()):
        print(f"\nProcessing model: {model_name}")
        model.eval()
        device = next(model.parameters()).device
        explainer = None  # built lazily; the zero baseline only depends on the input shape

        for class_idx, class_name in enumerate(class_names):
            if not class_samples[class_name]:
                print(f"Warning: No samples for class {class_name}")
                continue

            map_sum = None
            count = 0
            for image in class_samples[class_name]:
                batch = image.to(device).unsqueeze(0)
                try:
                    if explainer is None:
                        explainer = shap.GradientExplainer(model, torch.zeros_like(batch))
                    shap_map = process_shap_values(explainer.shap_values(batch), class_idx)
                except Exception as e:
                    print(f"Error processing image: {str(e)}")
                    continue
                map_sum = shap_map if map_sum is None else map_sum + shap_map
                count += 1

            if count == 0:
                continue

            avg_map = map_sum / count
            if avg_map.ndim != 2:
                # imshow would misread e.g. (H, W, n_outputs) as an RGB(A) image.
                print(f"Warning: unexpected SHAP map shape {avg_map.shape} for {model_name}/{class_name}; skipping")
                continue

            # Scale to [0, 1] by the panel's own maximum.
            normalized_values = avg_map / (np.max(avg_map) + 1e-7)

            ax = fig.add_subplot(gs[model_idx, class_idx])
            im = ax.imshow(normalized_values,
                           cmap='viridis',
                           vmin=0,
                           vmax=1,
                           interpolation='nearest')
            ax.set_title(f'{model_name}\n{class_name}', fontsize=10)
            plt.colorbar(im, ax=ax, format='%.2f')
            ax.axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    return fig

In [None]:
# Shared export settings for both figures.
save_options = dict(bbox_inches='tight', dpi=300)

# Pairwise class-contrast maps, 20 images per class.
contrast_fig = create_enhanced_contrast_heatmaps(models_dict, test_loader_norm, class_names, num_samples=20)
contrast_fig.savefig('contrast_heatmaps.png', **save_options)

# Per-class maps, 5 images per class.
single_class_fig = visualize_single_class_features(models_dict, test_loader_norm, class_names, num_samples=5)
single_class_fig.savefig('single_class_heatmaps.png', **save_options)

In [26]:
def process_shap_values(shap_values, class_idx=None, verbose=False):
    """
    Reduce the SHAP output for one image to a 2-D |SHAP| importance map.

    Two result layouts are supported:
      * older SHAP releases return a list with one array per model output,
        each shaped (batch, C, H, W);
      * SHAP >= 0.45 returns one array shaped (batch, C, H, W, n_outputs).

    Args:
        shap_values: Result of ``explainer.shap_values(x)``. Only the first
            image of the batch is used.
        class_idx: Model output to explain. ``None`` means output 0 for the
            list layout and the average over all outputs for the array layout.
        verbose: Print the map shape before and after reduction. The debug
            prints used to run on every call and flooded the cell output.

    Returns:
        np.ndarray: Mean absolute SHAP value over the channel axis, shape (H, W),
        for 3-D/4-D per-image maps. Lower-rank maps are returned unchanged.
    """
    if isinstance(shap_values, list):
        output_idx = 0 if class_idx is None else class_idx
        shap_map = np.asarray(shap_values[output_idx][0])
    else:
        shap_map = np.asarray(shap_values)[0]

    if verbose:
        print("Shape before reduction:", shap_map.shape)

    if shap_map.ndim == 4:
        # (C, H, W, n_outputs): the trailing axis indexes model outputs, not RGB.
        magnitude = np.abs(shap_map)
        if class_idx is not None:
            reduced = magnitude[..., class_idx].mean(axis=0)
        else:
            reduced = magnitude.mean(axis=(0, -1))
    elif shap_map.ndim == 3:
        # (C, H, W) channels-first, or (H, W, 3) channels-last RGB.
        channel_axis = -1 if shap_map.shape[-1] == 3 else 0
        reduced = np.abs(shap_map).mean(axis=channel_axis)
    else:
        reduced = shap_map

    if verbose:
        print("Shape after reduction:", reduced.shape)
    return reduced

def create_enhanced_contrast_heatmaps(models_dict, test_loader_norm, class_names, num_samples=20):
    """
    Plot pairwise class-contrast SHAP maps for every model.

    For each class, the first ``num_samples`` test images of that class are
    explained with ``shap.GradientExplainer`` (all-zeros baseline) w.r.t. that
    class's own output. Each image is reduced to a 2-D map by the notebook-level
    ``process_shap_values``, and the maps are averaged per class. Every class pair
    is then shown as ``avg(class1) - avg(class2)``, scaled to [-1, 1] by its
    largest magnitude: red areas matter more for class1, blue areas for class2.

    Args:
        models_dict: Mapping of display name -> PyTorch model.
        test_loader_norm: Loader yielding normalised ``(images, labels)`` batches.
        class_names: Class names indexed by label value.
        num_samples: Maximum number of images explained per class.

    Returns:
        matplotlib.figure.Figure: One row per model, one column per class pair.
        Pairs that cannot be computed are reported and left blank.
    """
    # Collect up to num_samples images per class (shared by all models).
    class_samples = {class_name: [] for class_name in class_names}
    print("Collecting samples...")
    for images, labels in test_loader_norm:
        for img, label in zip(images, labels):
            bucket = class_samples[class_names[label.item()]]
            if len(bucket) < num_samples:
                bucket.append(img)
        if all(len(bucket) >= num_samples for bucket in class_samples.values()):
            break

    class_pairs = list(itertools.combinations(class_names, 2))
    num_models = len(models_dict)

    fig = plt.figure(figsize=(20, 5 * num_models))
    gs = GridSpec(num_models, len(class_pairs), figure=fig, hspace=0.4, wspace=0.3)
    fig.suptitle('Class Contrast SHAP Feature Importance Maps', fontsize=16, y=0.95)

    for model_idx, (model_name, model) in enumerate(models_dict.items()):
        print(f"\nProcessing model: {model_name}")
        model.eval()
        device = next(model.parameters()).device
        explainer = None  # built lazily; the zero baseline only depends on the input shape

        # Average |SHAP| map per class, explained w.r.t. the class's own output so
        # the meaning does not depend on the SHAP version's result layout.
        class_avg_shap = {}
        for class_idx, class_name in enumerate(class_names):
            print(f"Processing class: {class_name}")
            if not class_samples[class_name]:
                print(f"Warning: No samples for class {class_name}")
                continue

            map_sum = None
            count = 0
            for image in class_samples[class_name]:
                batch = image.to(device).unsqueeze(0)
                try:
                    if explainer is None:
                        explainer = shap.GradientExplainer(model, torch.zeros_like(batch))
                    shap_map = process_shap_values(explainer.shap_values(batch), class_idx)
                except Exception as e:
                    print(f"Error processing image: {str(e)}")
                    continue
                map_sum = shap_map if map_sum is None else map_sum + shap_map
                count += 1

            if count > 0:
                class_avg_shap[class_name] = map_sum / count

        for comp_idx, (class1, class2) in enumerate(class_pairs):
            if class1 not in class_avg_shap or class2 not in class_avg_shap:
                print(f"Skipping {class1} vs {class2} - missing values")
                continue

            contrast = class_avg_shap[class1] - class_avg_shap[class2]
            if contrast.ndim != 2:
                # Averaging "leading" axes cannot recover (H, W) from an unknown
                # layout (e.g. (H, W, n_outputs) would become (W, n_outputs)), so skip.
                print(f"Warning: unexpected SHAP map shape {contrast.shape} for {class1} vs {class2}; skipping")
                continue

            vmax = np.max(np.abs(contrast))
            if vmax < 1e-6:
                print(f"Warning: Very small contrast between {class1} and {class2}")
                continue
            contrast_normalized = contrast / (vmax + 1e-7)

            ax = fig.add_subplot(gs[model_idx, comp_idx])
            im = ax.imshow(contrast_normalized,
                           cmap='RdBu_r',
                           vmin=-1,
                           vmax=1,
                           interpolation='nearest')
            ax.set_title(f'{model_name}\n{class1} vs {class2}', fontsize=10)
            plt.colorbar(im, ax=ax, format='%.2f')
            ax.axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    return fig

def visualize_single_class_features(models_dict, test_loader_norm, class_names, num_samples=20):
    """
    Plot one class-specific SHAP importance map per (model, class).

    The first ``num_samples`` test images of each class are explained with
    ``shap.GradientExplainer`` (all-zeros baseline) w.r.t. that class's output.
    Each image is reduced to a 2-D map by the notebook-level
    ``process_shap_values``, and the maps are averaged. Every panel is scaled by
    its own maximum, so colours are comparable within a panel but not across panels.

    Args:
        models_dict: Mapping of display name -> PyTorch model.
        test_loader_norm: Loader yielding normalised ``(images, labels)`` batches.
        class_names: Class names indexed by label value.
        num_samples: Maximum number of images explained per class.

    Returns:
        matplotlib.figure.Figure: A (models x classes) grid of heatmaps.
    """
    num_models = len(models_dict)
    num_classes = len(class_names)

    fig = plt.figure(figsize=(4 * num_classes, 4 * num_models))
    gs = GridSpec(num_models, num_classes, figure=fig, hspace=0.3, wspace=0.3)
    fig.suptitle('Single Class SHAP Feature Importance Maps', fontsize=16, y=0.95)

    # Collect up to num_samples images per class (shared by all models).
    class_samples = {class_name: [] for class_name in class_names}
    print("Collecting samples...")
    for images, labels in test_loader_norm:
        for img, label in zip(images, labels):
            bucket = class_samples[class_names[label.item()]]
            if len(bucket) < num_samples:
                bucket.append(img)
        if all(len(bucket) >= num_samples for bucket in class_samples.values()):
            break

    for model_idx, (model_name, model) in enumerate(models_dict.items()):
        print(f"\nProcessing model: {model_name}")
        model.eval()
        device = next(model.parameters()).device
        explainer = None  # built lazily; the zero baseline only depends on the input shape

        for class_idx, class_name in enumerate(class_names):
            if not class_samples[class_name]:
                print(f"Warning: No samples for class {class_name}")
                continue

            map_sum = None
            count = 0
            for image in class_samples[class_name]:
                batch = image.to(device).unsqueeze(0)
                try:
                    if explainer is None:
                        explainer = shap.GradientExplainer(model, torch.zeros_like(batch))
                    shap_map = process_shap_values(explainer.shap_values(batch), class_idx)
                except Exception as e:
                    print(f"Error processing image: {str(e)}")
                    continue
                map_sum = shap_map if map_sum is None else map_sum + shap_map
                count += 1

            if count == 0:
                continue

            avg_map = map_sum / count
            if avg_map.ndim != 2:
                # Averaging "leading" axes cannot recover (H, W) from an unknown
                # layout (e.g. (H, W, n_outputs) would become (W, n_outputs)), so skip.
                print(f"Warning: unexpected SHAP map shape {avg_map.shape} for {model_name}/{class_name}; skipping")
                continue

            # Scale to [0, 1] by the panel's own maximum.
            normalized_values = avg_map / (np.max(avg_map) + 1e-7)

            ax = fig.add_subplot(gs[model_idx, class_idx])
            im = ax.imshow(normalized_values,
                           cmap='viridis',
                           vmin=0,
                           vmax=1,
                           interpolation='nearest')
            ax.set_title(f'{model_name}\n{class_name}', fontsize=10)
            plt.colorbar(im, ax=ax, format='%.2f')
            ax.axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    return fig

In [None]:
# Quick smoke run with only 3 images per class before the full-size run.
contrast_fig = create_enhanced_contrast_heatmaps(models_dict, test_loader_norm, class_names, num_samples=3)

In [None]:
# Full contrast run: 20 images per class.
contrast_fig = create_enhanced_contrast_heatmaps(models_dict, test_loader_norm, class_names, num_samples=20)

View-independent tests (metrics that do not depend on the imaging plane):

In [None]:
def analyze_shap_distributions(models_dict, test_loader_norm, class_names, num_samples=20):
    """
    Analyze SHAP value distributions and sparsity across classes and models.

    All metrics are computed on flattened maps, so they do not depend on where
    in the image the important pixels lie (view-independent). For each model,
    the first ``num_samples`` test images per class are explained with
    ``shap.GradientExplainer`` (all-zeros baseline) w.r.t. the image's own class
    output. Each image is reduced to its mean |SHAP| over channels.

    Args:
        models_dict: Mapping of display name -> PyTorch model.
        test_loader_norm: Loader yielding normalised ``(images, labels)`` batches.
        class_names: Class names indexed by label value.
        num_samples: Maximum number of images explained per class.

    Returns:
        dict: ``{model_name: {'sparsity': {class: {'gini', 'top_10_percent',
        'sparsity'}}, 'contrasts': {'A_vs_B': {'kl_divergence', 'wasserstein'}},
        'distributions': {class: {'mean', 'std', 'percentiles'}}}}``. Classes
        without any successfully explained image are omitted.
    """
    def _own_class_map(shap_values, class_idx):
        """Mean |SHAP| over channels of the first image, for output ``class_idx``."""
        if isinstance(shap_values, list):
            # Older SHAP: one (batch, C, H, W) array per model output.
            per_channel = np.asarray(shap_values[class_idx][0])
        else:
            per_channel = np.asarray(shap_values)[0]
            if per_channel.ndim == 4:
                # SHAP >= 0.45: (C, H, W, n_outputs) -- outputs live on the last axis.
                per_channel = per_channel[..., class_idx]
        return np.mean(np.abs(per_channel), axis=0)

    def _to_probability(values, epsilon=1e-10):
        """Flatten |values| into a strictly positive probability vector."""
        probs = np.abs(values).flatten().astype(float)
        total = np.sum(probs)
        if total > 0:
            probs = probs / total
        # Epsilon keeps the KL divergence finite where one map is exactly zero.
        probs = probs + epsilon
        return probs / np.sum(probs)

    def calculate_sparsity_metrics(shap_values):
        """Calculate concentration/sparsity metrics of a |SHAP| map."""
        flat_shap = np.abs(shap_values).flatten()
        total = np.sum(flat_shap)
        if total == 0:
            # Undefined for an all-zero map; report NaN explicitly instead of 0/0.
            return {'gini': float('nan'), 'top_10_percent': float('nan'), 'sparsity': 0.0}

        # 'gini' is the Gini-Simpson index 1 - sum(p^2): 0 when all mass is on one
        # pixel, approaching 1 when importance is spread evenly.
        gini = 1 - np.sum((flat_shap / total) ** 2)
        # At least one element; a count of 0 would make [-0:] select the whole array.
        top_k = max(1, int(len(flat_shap) * 0.1))
        top_10_percent = np.sum(np.sort(flat_shap)[-top_k:]) / total
        sparsity = np.mean(flat_shap < np.max(flat_shap) * 0.1)  # fraction of small values

        return {
            'gini': gini,
            'top_10_percent': top_10_percent,
            'sparsity': sparsity
        }

    def calculate_class_contrast_metrics(shap_values_dict):
        """Calculate distribution contrasts between every available class pair."""
        contrasts = {}
        for c1 in class_names:
            for c2 in class_names:
                # Each unordered pair once; skip classes that produced no SHAP values.
                if not c1 < c2 or c1 not in shap_values_dict or c2 not in shap_values_dict:
                    continue
                s1 = _to_probability(shap_values_dict[c1])
                s2 = _to_probability(shap_values_dict[c2])

                # Symmetric KL divergence between the two importance distributions.
                kl_div = (entropy(s1, s2) + entropy(s2, s1)) / 2

                # 1-D Wasserstein distance between the empirical distributions of
                # the normalised importance values (sorted-sample approximation).
                wasserstein = np.mean(np.abs(np.sort(s1) - np.sort(s2)))

                contrasts[f"{c1}_vs_{c2}"] = {
                    'kl_divergence': kl_div,
                    'wasserstein': wasserstein
                }
        return contrasts

    results = {}

    for model_name, model in models_dict.items():
        print(f"\nProcessing model: {model_name}")
        model.eval()
        # Inputs must live on the model's device (previously they stayed on the
        # CPU, so every sample failed for GPU models and the error was swallowed).
        device = next(model.parameters()).device
        explainer = None  # built lazily; the zero baseline only depends on the input shape

        class_shap_values = {class_name: [] for class_name in class_names}

        for images, labels in test_loader_norm:
            if all(len(v) >= num_samples for v in class_shap_values.values()):
                break

            for img, label in zip(images, labels):
                class_idx = label.item()
                class_name = class_names[class_idx]
                if len(class_shap_values[class_name]) >= num_samples:
                    continue

                batch = img.unsqueeze(0).to(device)
                try:
                    if explainer is None:
                        explainer = shap.GradientExplainer(model, torch.zeros_like(batch))
                    shap_values = explainer.shap_values(batch)
                    class_shap_values[class_name].append(_own_class_map(shap_values, class_idx))
                except Exception as e:
                    print(f"Error processing sample: {str(e)}")
                    continue

        model_results = {
            'sparsity': {},
            'contrasts': {},
            'distributions': {}
        }

        avg_shap_values = {}
        for class_name, values in class_shap_values.items():
            if not values:
                continue
            avg_shap_values[class_name] = np.mean(values, axis=0)
            model_results['sparsity'][class_name] = calculate_sparsity_metrics(avg_shap_values[class_name])

            flat_values = np.concatenate([v.flatten() for v in values])
            model_results['distributions'][class_name] = {
                'mean': np.mean(flat_values),
                'std': np.std(flat_values),
                'percentiles': np.percentile(flat_values, [25, 50, 75])
            }

        model_results['contrasts'] = calculate_class_contrast_metrics(avg_shap_values)
        results[model_name] = model_results

    return results

def visualize_metrics(results):
    """
    Plot sparsity metrics, pairwise KL divergence and SHAP-value quartiles per model.

    Args:
        results: Output of ``analyze_shap_distributions``.

    Returns:
        matplotlib.figure.Figure: One row per model with three panels. The panels
        show sparsity bars per (class, metric), KL-divergence bars per class pair,
        and one box per class built from the stored 25/50/75th percentiles, with
        the mean marked.
    """
    num_models = len(results)
    fig = plt.figure(figsize=(15, 5 * num_models))
    gs = GridSpec(num_models, 3, figure=fig)

    for row, (model_name, model_results) in enumerate(results.items()):
        # Panel 1: one bar per (class, sparsity metric).
        ax1 = fig.add_subplot(gs[row, 0])
        sparsity_items = [
            (class_name, metric_name, value)
            for class_name, metric_dict in model_results['sparsity'].items()
            for metric_name, value in metric_dict.items()
        ]
        bar_positions = list(range(len(sparsity_items)))
        ax1.bar(bar_positions, [value for _, _, value in sparsity_items])
        ax1.set_xticks(bar_positions)
        ax1.set_xticklabels([f"{c}\n{m}" for c, m, _ in sparsity_items], rotation=45)
        ax1.set_title(f"{model_name} - Sparsity Metrics")

        # Panel 2: symmetric KL divergence per class pair.
        ax2 = fig.add_subplot(gs[row, 1])
        pair_names = list(model_results['contrasts'])
        pair_positions = list(range(len(pair_names)))
        ax2.bar(pair_positions,
                [model_results['contrasts'][pair]['kl_divergence'] for pair in pair_names])
        ax2.set_xticks(pair_positions)
        ax2.set_xticklabels(pair_names, rotation=45)
        ax2.set_title(f"{model_name} - KL Divergence Between Classes")

        # Panel 3: quartile boxes drawn directly from the stored statistics. Calling
        # boxplot() on the three percentile values would recompute quartiles of
        # those three numbers and misplace the box. Calling it once per class also
        # reset the tick labels each time, so only the last class name showed.
        ax3 = fig.add_subplot(gs[row, 2])
        dist_names = list(model_results['distributions'])
        box_stats = []
        for class_name in dist_names:
            dist = model_results['distributions'][class_name]
            q1, med, q3 = dist['percentiles']
            box_stats.append({
                'q1': q1, 'med': med, 'q3': q3,
                # Only quartiles are stored, so whiskers collapse onto the box.
                'whislo': q1, 'whishi': q3,
                'mean': dist['mean'],
                'fliers': [],
            })
        if box_stats:
            box_positions = list(range(len(box_stats)))
            ax3.bxp(box_stats, positions=box_positions, showmeans=True,
                    showcaps=False, showfliers=False)
            ax3.set_xticks(box_positions)
            ax3.set_xticklabels(dist_names)
        ax3.set_title(f"{model_name} - SHAP Value Distributions")

    plt.tight_layout()
    return fig


In [None]:
# Models to compare; insertion order fixes the order of the rows in the figure.
models_dict = dict(
    base=base_model,
    trivial_aug=ta_model,
    gans=gan_model,
    combined=combined_model,
)


class_names = ['Glioma', 'Meningioma', 'No Tumor', 'Pituitary']

# View-independent SHAP statistics, 20 images per class and model.
results = analyze_shap_distributions(models_dict, test_loader_norm, class_names, num_samples=20)

fig = visualize_metrics(results)
plt.savefig('shap_analysis.png', bbox_inches='tight', dpi=300)

In [None]:
def _shap_summary_class_map(shap_values, class_idx, input_ndim):
    """Extract the first sample's SHAP attribution map for model output ``class_idx``.

    ``explainer.shap_values`` may return either a list with one array per model
    output, or a single array. In the single-array case, an extra trailing axis
    indexes the model outputs. The ndim check below decides which layout applies,
    instead of assuming a particular shap version.

    Args:
        shap_values: Result of ``explainer.shap_values`` for a batch of one input.
        class_idx: Index of the model output (class) whose attributions are wanted.
        input_ndim: Number of dimensions of one unbatched input (3 for C, H, W).

    Returns:
        np.ndarray with ``input_ndim`` dimensions.
    """
    if isinstance(shap_values, list):
        # One entry per output; fall back to the only entry for single-output models.
        output_idx = class_idx if len(shap_values) > 1 else 0
        return np.asarray(shap_values[output_idx])[0]
    sample_map = np.asarray(shap_values)[0]
    if sample_map.ndim == input_ndim + 1:
        # Outputs stacked on the last axis: keep only the requested class.
        return sample_map[..., class_idx]
    return sample_map


def _shap_summary_feature_stats(shap_map, sparsity_ratio=0.1):
    """Compute location-independent statistics of one SHAP attribution map.

    ``sparsity`` is the fraction of entries whose magnitude is below
    ``sparsity_ratio`` times the map's largest magnitude.
    """
    abs_shap = np.abs(shap_map)
    peak = np.max(abs_shap)
    return {
        'mean_importance': np.mean(abs_shap),
        'max_importance': peak,
        'percentile_95': np.percentile(abs_shap, 95),
        'sparsity': np.mean(abs_shap < peak * sparsity_ratio)
    }


def _shap_summary_mean(stat_dicts, key):
    """Return the mean of ``key`` over per-sample stat dicts, or NaN if there are none."""
    if not stat_dicts:
        return float('nan')
    return float(np.mean([stats[key] for stats in stat_dicts]))


def _shap_summary_collect(model, data_loader, class_names, num_samples):
    """Gather SHAP statistics for up to ``num_samples`` images of each class.

    Each image is explained for the output of its own (labelled) class.
    Failures on individual samples are reported and skipped (best effort).

    Returns:
        dict with
        - 'stats': class name -> list of per-sample stat dicts.
        - 'magnitudes': class name -> 1-D array of all |SHAP| values for that class.
    """
    model.eval()
    # Run inputs on the model's device. Assumes the loader yields CPU tensors;
    # a device mismatch would make every sample fail inside the try below.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    class_stats = {name: [] for name in class_names}
    magnitude_parts = {name: [] for name in class_names}
    explainer = None

    for images, labels in data_loader:
        if all(len(class_stats[name]) >= num_samples for name in class_names):
            break

        for img, label in zip(images, labels):
            class_idx = label.item()
            class_name = class_names[class_idx]
            if len(class_stats[class_name]) >= num_samples:
                continue

            try:
                batch = img.unsqueeze(0).to(device)
                if explainer is None:
                    # Single all-zero baseline. Build it once per model, since every
                    # image from the loader has the same shape.
                    explainer = shap.GradientExplainer(model, torch.zeros_like(batch))
                shap_values = explainer.shap_values(batch)

                shap_map = _shap_summary_class_map(shap_values, class_idx, img.dim())
                class_stats[class_name].append(_shap_summary_feature_stats(shap_map))
                # Keep numpy arrays instead of extending a list with millions of floats.
                magnitude_parts[class_name].append(np.abs(shap_map).ravel())
            except Exception as e:
                print(f"Error processing sample: {str(e)}")
                continue

    magnitudes = {
        name: np.concatenate(parts) if parts else np.empty(0)
        for name, parts in magnitude_parts.items()
    }
    return {'stats': class_stats, 'magnitudes': magnitudes}


def _shap_summary_plot(results, class_names):
    """Draw one row per model: magnitude violins, mean summary stats, mean sparsity."""
    num_models = len(results)
    fig = plt.figure(figsize=(15, 5 * num_models))
    gs = GridSpec(num_models, 3, figure=fig)
    positions = list(range(len(class_names)))

    for row, (model_name, model_results) in enumerate(results.items()):
        stats = model_results['stats']
        magnitudes = model_results['magnitudes']

        # 1. Feature importance violin plot (matplotlib rejects empty datasets,
        #    so classes without any successful sample are left blank).
        ax1 = fig.add_subplot(gs[row, 0])
        filled = [i for i, name in enumerate(class_names) if len(magnitudes[name]) > 0]
        if filled:
            ax1.violinplot([magnitudes[class_names[i]] for i in filled], positions=filled)
        ax1.set_xticks(positions)
        ax1.set_xticklabels(class_names, rotation=45)
        ax1.set_ylabel('|SHAP value|')
        ax1.set_title(f"{model_name} - Feature Importance Distributions")

        # 2. Summary statistics (per-class means of the per-sample values)
        ax2 = fig.add_subplot(gs[row, 1])
        for stat_name in ('mean_importance', 'max_importance', 'percentile_95'):
            stat_values = [_shap_summary_mean(stats[name], stat_name) for name in class_names]
            ax2.plot(class_names, stat_values,
                     marker='o',
                     label=stat_name.replace('_', ' ').title())
        ax2.tick_params(axis='x', labelrotation=45)
        ax2.legend()
        ax2.set_title(f"{model_name} - Summary Statistics")

        # 3. Sparsity comparison
        ax3 = fig.add_subplot(gs[row, 2])
        sparsity_values = [_shap_summary_mean(stats[name], 'sparsity') for name in class_names]
        ax3.bar(class_names, sparsity_values)
        ax3.tick_params(axis='x', labelrotation=45)
        ax3.set_title(f"{model_name} - Feature Sparsity")
        ax3.set_ylim(0, 1)

    plt.tight_layout()
    return fig


def create_shap_summary_plots(models_dict, test_loader_norm, class_names, num_samples=20):
    """
    Create SHAP summary visualizations that don't rely on spatial interpretations.

    For every model, up to ``num_samples`` test images per class are explained
    with ``shap.GradientExplainer`` (single all-zero baseline). Each image is
    explained for the output of its labelled class.

    Args:
        models_dict: Mapping of display name -> PyTorch model.
        test_loader_norm: DataLoader yielding (images, integer labels) batches.
        class_names: Class display names indexed by label value.
        num_samples: Maximum number of explained images per class.

    Returns:
        matplotlib Figure with one row of three panels per model.

    Raises:
        ValueError: If ``models_dict`` is empty.
    """
    if not models_dict:
        raise ValueError("models_dict must contain at least one model")

    results = {}
    for model_name, model in models_dict.items():
        print(f"\nProcessing model: {model_name}")
        results[model_name] = _shap_summary_collect(model, test_loader_norm,
                                                    class_names, num_samples)

    return _shap_summary_plot(results, class_names)

In [None]:
# Location-independent SHAP summaries (magnitude violins, summary stats,
# sparsity) for every model, written to disk at print resolution.
summary_fig = create_shap_summary_plots(
    models_dict=models_dict,
    test_loader_norm=test_loader_norm,
    class_names=class_names,
    num_samples=20,
)
summary_fig.savefig('shap_summary.png', dpi=300, bbox_inches='tight')