In [8]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import functional as TF
from sklearn.metrics import classification_report, confusion_matrix
from PIL import Image

# Shared configuration for the dataset, the model and the output location.
# A class's position in CLASS_NAMES is the integer label assigned to it.
CLASS_NAMES = [
    'Mild_Demented',
    'Moderate_Demented',
    'Non_Demented',
    'Very_Mild_Demented',
]
N_CLASSES = len(CLASS_NAMES)
# Spatial size the CNN's classifier head is dimensioned for.
IMG_SIZE = (128, 128)
BATCH_SIZE = 32
# Dataset root; evaluation artefacts are written to a sub-folder of it.
base_dir = "Combined_MRI_Dataset"
output_dir = os.path.join(base_dir, "model_output")

class VentricleDataset(Dataset):
    """Classification dataset built from ``*mask.png`` files on disk.

    Files are discovered under ``<base_dir>/normalized/<split>/<class_name>/``.
    A file's label is the position of its class folder in ``class_names``.

    Args:
        base_dir: Root directory of the dataset.
        split: Sub-folder of ``normalized`` to read. Defaults to ``'test'``.
        class_names: Ordered class folder names. Defaults to the module-level
            ``CLASS_NAMES``.

    Attributes:
        data_dir: Resolved ``<base_dir>/normalized/<split>`` directory.
        images: File paths, grouped by class and sorted by file name within
            each class.
        labels: Integer class index for each entry of ``images``.
    """

    def __init__(self, base_dir, split='test', class_names=None):
        if class_names is None:
            class_names = CLASS_NAMES

        self.data_dir = os.path.join(base_dir, "normalized", split)
        self.images = []
        self.labels = []

        for class_idx, class_name in enumerate(class_names):
            class_dir = os.path.join(self.data_dir, class_name)
            if not os.path.isdir(class_dir):
                # A missing class is reported but does not abort loading.
                print(f"Warning: Class directory {class_dir} not found.")
                continue

            # os.listdir() returns entries in arbitrary order. Sorting makes
            # the sample order (and therefore anything that picks "the first
            # N samples" from an unshuffled loader) reproducible.
            for img_file in sorted(os.listdir(class_dir)):
                # Case-insensitive suffix match: only mask images are used.
                if img_file.lower().endswith('mask.png'):
                    self.images.append(os.path.join(class_dir, img_file))
                    self.labels.append(class_idx)

        print(f"Loaded {len(self.images)} images for {split} split")

    def __len__(self):
        """Return the number of discovered image files."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample ``idx``.

        The image is converted with ``TF.to_tensor`` without any mode
        conversion or resizing.
        NOTE(review): the model's first conv layer takes one input channel, so
        this presumably relies on the PNGs being single-channel and already at
        the expected size - confirm against the preprocessing step.
        """
        # Image.open() is lazy; the context manager guarantees the underlying
        # file is closed once the pixel data has been copied into the tensor.
        with Image.open(self.images[idx]) as img:
            image = TF.to_tensor(img)
        return image, torch.tensor(self.labels[idx], dtype=torch.long)



In [10]:
class VentricleCNN(nn.Module):
    """Two-block CNN classifier for single-channel images.

    Each block is conv(3x3, padding=1) -> ReLU -> 2x2 max-pool -> batch-norm,
    followed by a three-layer fully connected head with dropout.

    Args:
        n_classes: Number of output logits. Defaults to the module-level
            ``N_CLASSES``.
        img_size: Two-element spatial size of the input images, used to
            dimension the first fully connected layer. Defaults to the
            module-level ``IMG_SIZE``. Inputs of any other size fail at the
            flatten -> ``fc1`` step.

    Attributes:
        activations: Output of ``conv2`` from the most recent forward pass
            (stored as-is, not detached), or ``None`` before any forward pass.
        gradients: Gradient w.r.t. the ``conv2`` output from the most recent
            backward pass, or ``None`` before any backward pass.
    """

    def __init__(self, n_classes=None, img_size=None):
        super().__init__()
        if n_classes is None:
            n_classes = N_CLASSES
        if img_size is None:
            img_size = IMG_SIZE

        # Attribute names and construction order are kept stable: the names
        # are the state_dict keys of saved checkpoints, and the order fixes
        # how the RNG is consumed during weight initialisation.

        # Two convolutional blocks
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.batchnorm1 = nn.BatchNorm2d(64)

        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.batchnorm2 = nn.BatchNorm2d(128)

        # Flattened size after 2 pooling layers (each halves both spatial
        # dimensions; the padded 3x3 convolutions preserve them).
        self.fc_input_size = 128 * (img_size[0] // 4) * (img_size[1] // 4)

        # Classifier
        self.fc1 = nn.Linear(self.fc_input_size, 512)
        self.dropout1 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(512, 128)
        self.dropout2 = nn.Dropout(0.2)
        self.out = nn.Linear(128, n_classes)

        # GradCAM attributes, filled in by the hooks registered below
        self.gradients = None
        self.activations = None

        # Hook the last convolutional layer. These hooks stay registered for
        # the lifetime of the module and fire on every forward/backward pass.
        self.conv2.register_forward_hook(self.save_activations)
        self.conv2.register_full_backward_hook(self.save_gradients)

    def save_activations(self, module, input, output):
        """Forward hook: remember the latest ``conv2`` output."""
        self.activations = output

    def save_gradients(self, module, grad_input, grad_output):
        """Backward hook: remember the gradient w.r.t. the ``conv2`` output."""
        self.gradients = grad_output[0]

    def forward(self, x):
        """Return class logits of shape ``(batch, n_classes)``.

        No softmax is applied; the raw output of the last linear layer is
        returned.
        """
        # Note the order inside each block: batch-norm runs after pooling.
        x = self.batchnorm1(self.pool1(F.relu(self.conv1(x))))
        x = self.batchnorm2(self.pool2(F.relu(self.conv2(x))))

        x = x.view(x.size(0), -1)  # Flatten everything except the batch axis
        x = self.dropout1(F.relu(self.fc1(x)))
        x = self.dropout2(F.relu(self.fc2(x)))
        return self.out(x)

def evaluate_model(model, test_loader):
    """Run ``model`` over ``test_loader`` and print/return summary metrics.

    The model is switched to eval mode and inference runs without gradient
    tracking. The predicted class is the arg-max over the model output's
    second axis.

    Args:
        model: Module called as ``model(images)``.
        test_loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        Tuple ``(accuracy, cm, report)``: the fraction of correct predictions,
        the confusion matrix, and scikit-learn's text classification report.
        The matrix and the report always cover every class in ``CLASS_NAMES``,
        in that order.
    """
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            preds = model(images).argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Pin the label set explicitly. Without it scikit-learn infers the classes
    # from the data, so a class absent from both labels and predictions would
    # shrink the matrix (misaligning the CLASS_NAMES tick labels downstream)
    # and make classification_report reject target_names.
    class_ids = list(range(len(CLASS_NAMES)))

    accuracy = (np.array(all_preds) == np.array(all_labels)).mean()
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    report = classification_report(
        all_labels,
        all_preds,
        labels=class_ids,
        target_names=CLASS_NAMES,
        digits=4,
    )

    print(f"\nTest Accuracy: {accuracy:.4f}")
    print("\nConfusion Matrix:")
    print(cm)
    print("\nClassification Report:")
    print(report)

    return accuracy, cm, report


In [12]:

def plot_confusion_matrix(cm):
    """Save ``cm`` as an annotated heatmap to ``<output_dir>/confusion_matrix.png``.

    Rows are labelled 'True' and columns 'Predicted', both ticked with
    ``CLASS_NAMES``. Cell annotations use the ``'d'`` format, so ``cm`` must
    hold integer counts. The figure is written to disk and closed; nothing is
    displayed or returned.

    Args:
        cm: Square confusion matrix with one row/column per entry of
            ``CLASS_NAMES``.
    """
    # Create the target directory here as well, so the function also works
    # when called from a cell before main() has run.
    os.makedirs(output_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, ax=ax)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title('Confusion Matrix')
        fig.savefig(os.path.join(output_dir, "confusion_matrix.png"))
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)


    
def generate_gradcam(model, input_tensor, target_class=None):
    """Compute a gradient-weighted activation map for one image.

    This is an absolute-value variant of Grad-CAM: channel weights are the
    spatial mean of the *absolute* gradients at ``model.conv2``, and the map
    is the weighted sum of the *absolute* activations. No ReLU is applied,
    because the result is non-negative by construction.

    Side effects: ``model`` is put into eval mode (and left there), its
    parameter gradients are zeroed and then populated by the backward pass.
    The temporary hooks on ``model.conv2`` are always removed, even on error.

    Args:
        model: Module exposing a ``conv2`` sub-module to hook.
        input_tensor: Single un-batched image; a batch axis is prepended with
            ``unsqueeze(0)``. Axes 1 and 2 are used as output height and
            width.
        target_class: Index of the logit to explain. Defaults to the arg-max
            of the model output.

    Returns:
        2-D float32 NumPy array of shape ``(input_tensor.shape[1],
        input_tensor.shape[2])``, scaled so its maximum is 1 (left unscaled,
        with a printed warning, when the maximum is 0).

    Raises:
        RuntimeError: If the hooks captured no activations or gradients, e.g.
            because ``conv2`` was not on a differentiable path.
    """
    model.eval()

    captured = {}

    def forward_hook(module, inputs, output):
        captured["activations"] = output

    def backward_hook(module, grad_in, grad_out):
        captured["gradients"] = grad_out[0]

    handles = []
    try:
        handles.append(model.conv2.register_forward_hook(forward_hook))
        handles.append(model.conv2.register_full_backward_hook(backward_hook))

        # Gradients are required here even if the caller is inside no_grad().
        with torch.enable_grad():
            output = model(input_tensor.unsqueeze(0))

            if target_class is None:
                target_class = torch.argmax(output)

            model.zero_grad()

            # Single backward pass, so the graph does not need to be retained.
            torch.sum(output[0, target_class]).backward()
    finally:
        # Never leave the temporary hooks attached to the caller's model.
        for handle in handles:
            handle.remove()

    if "activations" not in captured or "gradients" not in captured:
        raise RuntimeError(
            "Grad-CAM hooks on model.conv2 did not capture activations and "
            "gradients; is conv2 part of a differentiable forward path?"
        )

    # Drop the batch axis: both arrays are (channels, h, w).
    gradients = captured["gradients"].detach().cpu().numpy()[0]
    activations = captured["activations"].detach().cpu().numpy()[0]

    # Absolute gradients: channels count regardless of the sign of their
    # influence on the target logit.
    weights = np.mean(np.abs(gradients), axis=(1, 2))

    # Weighted sum over channels of the absolute activations.
    cam = (weights[:, None, None] * np.abs(activations)).sum(axis=0)
    cam = cam.astype(np.float32, copy=False)

    # Resize to input size; cv2 expects the target size as (width, height).
    cam = cv2.resize(cam, (input_tensor.shape[2], input_tensor.shape[1]))

    # Normalize
    if np.max(cam) > 0:
        cam = cam / np.max(cam)
    else:
        print("Warning: Maximum CAM value is 0, cannot normalize")

    return cam

def grad_cam_visualization(model, test_loader, num_samples=5):
    """Save Grad-CAM figures for the first ``num_samples`` loader samples.

    For each sample a three-panel figure (input, heatmap, overlay) is written
    to ``<output_dir>/gradcam/gradcam_sample_<idx>.png``. The map is computed
    for the sample's *true* class, not the predicted one.

    Args:
        model: Model passed through to ``generate_gradcam``.
        test_loader: Iterable yielding ``(images, labels)`` batches. Samples
            are taken in iteration order.
        num_samples: Maximum number of samples to visualise. Fewer are used if
            the loader is exhausted first.
    """
    gradcam_dir = os.path.join(output_dir, "gradcam")
    os.makedirs(gradcam_dir, exist_ok=True)

    # Take exactly the number still missing from each batch, so the total
    # never exceeds num_samples even when batches are smaller than it.
    samples = []
    if num_samples > 0:
        for images, labels in test_loader:
            take = min(len(images), num_samples - len(samples))
            samples.extend((images[i], labels[i]) for i in range(take))
            if len(samples) >= num_samples:
                break

    # Generate and save Grad-CAM for each sample
    for idx, (image, label) in enumerate(samples):
        true_class = label.item()

        # Generate Grad-CAM
        cam = generate_gradcam(model, image, true_class)

        # Convert image tensor to a 2-D numpy array for visualization
        img_np = image.detach().cpu().squeeze().numpy()

        fig, (ax_orig, ax_cam, ax_overlay) = plt.subplots(1, 3, figsize=(15, 5))
        try:
            # Original image
            ax_orig.imshow(img_np, cmap='gray')
            ax_orig.set_title(f'Original: {CLASS_NAMES[true_class]}')
            ax_orig.axis('off')

            # Grad-CAM heatmap
            ax_cam.imshow(cam, cmap='jet')
            ax_cam.set_title('Grad-CAM Heatmap')
            ax_cam.axis('off')

            # Overlay: heatmap drawn semi-transparently on top of the input
            ax_overlay.imshow(img_np, cmap='gray')
            ax_overlay.imshow(cam, cmap='jet', alpha=0.5)
            ax_overlay.set_title('Overlay')
            ax_overlay.axis('off')

            # Save figure
            fig.tight_layout()
            fig.savefig(os.path.join(gradcam_dir, f"gradcam_sample_{idx}.png"))
        finally:
            # Release the figure even if drawing or saving failed.
            plt.close(fig)
        
def load_model(model_path):
    """Build a ``VentricleCNN`` and load its weights from ``model_path``.

    Args:
        model_path: Path to a file containing the model's ``state_dict``.

    Returns:
        The model with the loaded weights, on the CPU. Its train/eval mode is
        left at the module default.
    """
    model = VentricleCNN()
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs. It avoids arbitrary code execution from
    # an untrusted checkpoint and silences torch.load's FutureWarning.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    print("Model loaded successfully")
    return model

def main():
    """Evaluate the saved model on the test split and write all artefacts.

    Steps, in order: create ``output_dir``, build the unshuffled test loader,
    load ``best_model.pt`` from ``output_dir``, print the evaluation metrics,
    then save the confusion-matrix plot and the Grad-CAM figures.
    """
    os.makedirs(output_dir, exist_ok=True)

    print("Loading data...")
    loader = DataLoader(
        VentricleDataset(base_dir, 'test'),
        batch_size=BATCH_SIZE,
        shuffle=False,
    )

    print("Loading model...")
    checkpoint_path = os.path.join(output_dir, "best_model.pt")
    net = load_model(checkpoint_path)

    print("Evaluating model...")
    # Only the confusion matrix is needed afterwards; accuracy and the text
    # report have already been printed by evaluate_model.
    _, conf_matrix, _ = evaluate_model(net, loader)

    print("Generating visualizations...")
    plot_confusion_matrix(conf_matrix)
    grad_cam_visualization(net, loader)

    print("Saved all results to:", output_dir)


# Run the full evaluation when the cell/script is executed directly.
if __name__ == "__main__":
    main()

Loading data...
Loaded 3584 images for test split
Loading model...


  model.load_state_dict(torch.load(model_path, map_location='cpu'))


Model loaded successfully
Evaluating model...


Evaluating: 100%|█████████████████████████████| 112/112 [00:21<00:00,  5.19it/s]



Test Accuracy: 0.9520

Confusion Matrix:
[[ 702    0    6   19]
 [   0  527    0    0]
 [  10    0 1203   67]
 [  20    0   50  980]]

Classification Report:
                    precision    recall  f1-score   support

     Mild_Demented     0.9590    0.9656    0.9623       727
 Moderate_Demented     1.0000    1.0000    1.0000       527
      Non_Demented     0.9555    0.9398    0.9476      1280
Very_Mild_Demented     0.9193    0.9333    0.9263      1050

          accuracy                         0.9520      3584
         macro avg     0.9585    0.9597    0.9590      3584
      weighted avg     0.9522    0.9520    0.9520      3584

Generating visualizations...
Saved all results to: Combined_MRI_Dataset/model_output
