In [6]:
import copy
import os

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import functional as TF

# Class mapping constants consistent with previous preprocessing.
# A class's index in CLASS_NAMES is the integer label used by the dataset and the model head.
CLASS_NAMES = ['Mild_Demented', 'Moderate_Demented', 'Non_Demented', 'Very_Mild_Demented']
N_CLASSES = len(CLASS_NAMES)  # derived from CLASS_NAMES so the two can never drift apart (== 4)

# Base directory from original preprocessing
base_dir = "Combined_MRI_Dataset"
output_dir = os.path.join(base_dir, "model_output")

# Create output directories (exist_ok makes this cell safe to re-run)
for _subdir in ("models", "plots"):
    os.makedirs(os.path.join(output_dir, _subdir), exist_ok=True)

# Custom Dataset class for processed ventricle images
class VentricleDataset(Dataset):
    """Labelled ventricle-MRI images read from ``<base_dir>/normalized/<split>/<class>/``.

    Each class sub-folder is named after an entry of ``CLASS_NAMES``; that entry's
    index is the integer label. Only files whose name ends in
    ``_<normalization_type>.png`` (case-insensitive) are kept, so several
    preprocessing variants can live side by side in the same folder.

    Args:
        base_dir: Root of the preprocessed dataset.
        split: Sub-folder under ``normalized`` (e.g. ``'train'`` or ``'val'``).
        normalization_type: Filename suffix selecting the preprocessing variant
            (e.g. ``'norm'``, ``'enhanced'``, ``'ventricle'``, ``'mask'``).
        transform: Optional callable applied to the image tensor produced by
            ``TF.to_tensor`` (e.g. augmentation). ``None`` leaves it unchanged.

    Raises:
        ValueError: If ``<base_dir>/normalized/<split>`` does not exist.
    """

    def __init__(self, base_dir, split='train', normalization_type='ventricle', transform=None):
        self.base_dir = base_dir
        self.split = split
        self.normalization_type = normalization_type
        self.transform = transform

        # Class mapping derived from CLASS_NAMES so the label order cannot drift from it
        self.classes = {name: idx for idx, name in enumerate(CLASS_NAMES)}

        # Use the normalized subdirectory matching the dataset organization
        self.data_dir = os.path.join(base_dir, "normalized", split)
        if not os.path.exists(self.data_dir):
            raise ValueError(f"Data directory {self.data_dir} does not exist")

        self.images = []
        self.labels = []

        # Compare both sides in lower case and anchor at the end of the name, so
        # e.g. "x_ventricle.png.bak" is not picked up by accident.
        suffix = f"_{normalization_type}.png".lower()
        for class_name, label in self.classes.items():
            class_dir = os.path.join(self.data_dir, class_name)
            if not os.path.exists(class_dir):
                print(f"Warning: Class directory {class_dir} does not exist")
                continue

            # os.listdir order is filesystem-dependent; sorting makes the sample
            # order (and therefore any seeded shuffling) reproducible.
            for img_file in sorted(os.listdir(class_dir)):
                if img_file.lower().endswith(suffix):
                    self.images.append(os.path.join(class_dir, img_file))
                    self.labels.append(label)

        print(f"Loaded {len(self.images)} {normalization_type} images for {split} split")
        self._print_class_distribution()

    def _print_class_distribution(self):
        """Print per-class sample counts and percentages; classes with no samples are omitted."""
        class_counts = {}
        for label in self.labels:
            class_name = CLASS_NAMES[label]
            class_counts[class_name] = class_counts.get(class_name, 0) + 1

        print("Class distribution:")
        total = len(self.labels)
        for class_name, count in class_counts.items():
            print(f"  {class_name}: {count} images ({count/total*100:.1f}%)")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` for sample ``idx``.

        The image is converted with ``TF.to_tensor`` (channel count follows the
        PNG's mode — presumably single-channel grayscale; TODO confirm against
        preprocessing), then ``self.transform`` is applied if set. The label is a
        0-d ``torch.long`` tensor.
        """
        img_path = self.images[idx]
        label = self.labels[idx]

        # Context manager releases the file handle deterministically
        with Image.open(img_path) as img:
            img_tensor = TF.to_tensor(img)

        if self.transform is not None:
            img_tensor = self.transform(img_tensor)

        return img_tensor, torch.tensor(label, dtype=torch.long)





In [8]:
class VentricleCNN(nn.Module):
    """Three-block CNN classifier for ventricle MRI slices.

    Each block is Conv(3x3, padding=1) -> ReLU -> MaxPool(2) -> BatchNorm, so each
    spatial side is floor-halved three times. The flattened features then pass
    through two dropout-regularised fully connected layers to ``num_classes`` logits.

    Args:
        input_channels: Channels of the input tensor (1 for grayscale).
        input_size: ``(height, width)`` of the input; used to size ``fc1``. Both
            sides must be at least 8, otherwise the third pooling leaves nothing.
        num_classes: Number of output logits. ``None`` (default) uses the
            module-level ``N_CLASSES``.

    Raises:
        ValueError: If either side of ``input_size`` is smaller than 8.
    """

    def __init__(self, input_channels=1, input_size=(128, 128), num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = N_CLASSES

        # First convolutional block
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.batchnorm1 = nn.BatchNorm2d(64)

        # Second convolutional block
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.batchnorm2 = nn.BatchNorm2d(128)

        # Third convolutional block for deeper features
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.batchnorm3 = nn.BatchNorm2d(256)

        # padding=1 keeps conv output size; three 2x2 pools floor-divide each side
        # by 2 three times, which equals integer division by 8.
        reduced_h = input_size[0] // 8
        reduced_w = input_size[1] // 8
        if reduced_h == 0 or reduced_w == 0:
            raise ValueError(
                f"input_size {tuple(input_size)} is too small; both sides must be >= 8"
            )
        self.fc_input_size = 256 * reduced_h * reduced_w

        # Fully connected layers
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(self.fc_input_size, 512)
        self.dropout1 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(512, 128)
        self.dropout2 = nn.Dropout(0.2)
        self.out = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a ``(batch, input_channels, H, W)`` tensor to ``(batch, num_classes)`` logits."""
        # First block
        x = self.batchnorm1(self.pool1(F.relu(self.conv1(x))))

        # Second block
        x = self.batchnorm2(self.pool2(F.relu(self.conv2(x))))

        # Third block
        x = self.batchnorm3(self.pool3(F.relu(self.conv3(x))))

        # Fully connected layers (raw logits; CrossEntropyLoss applies softmax)
        x = self.flatten(x)
        x = self.dropout1(F.relu(self.fc1(x)))
        x = self.dropout2(F.relu(self.fc2(x)))
        return self.out(x)

# Evaluation function for validation
def evaluate_model(model, data_loader, criterion):
    """Evaluate ``model`` on every batch of ``data_loader`` without gradients.

    Switches the model to eval mode (it is left in eval mode on return).

    Args:
        model: Classifier returning ``(batch, N_CLASSES)`` logits.
        data_loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable; assumed to use mean reduction (the
            ``nn.CrossEntropyLoss`` default), since it is re-weighted by batch size.

    Returns:
        ``(accuracy, avg_loss, class_accuracies)``. ``avg_loss`` is the per-sample
        mean over the whole set, so a short final batch is not over-weighted.
        ``class_accuracies`` maps ``CLASS_NAMES[i]`` to accuracy for every class
        that occurs at least once.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    model.eval()

    # Follow the model's device (CPU if it has no parameters)
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    correct = 0
    total = 0
    running_loss = 0.0
    class_correct = torch.zeros(N_CLASSES, dtype=torch.long)
    class_total = torch.zeros(N_CLASSES, dtype=torch.long)

    with torch.no_grad():
        for images, labels in data_loader:
            if device is not None:
                images = images.to(device)
                labels = labels.to(device)

            outputs = model(images)
            batch_size = labels.size(0)
            # Undo the criterion's per-batch mean so batches of different size weigh correctly
            running_loss += criterion(outputs, labels).item() * batch_size

            predicted = outputs.argmax(dim=1)
            hits = predicted == labels
            total += batch_size
            correct += hits.sum().item()

            # Per-class tallies, vectorised instead of a Python loop per sample
            labels_cpu = labels.cpu()
            class_total += torch.bincount(labels_cpu, minlength=N_CLASSES)
            class_correct += torch.bincount(labels_cpu[hits.cpu()], minlength=N_CLASSES)

    if total == 0:
        raise ValueError("evaluate_model received a data_loader that yielded no samples")

    accuracy = correct / total
    avg_loss = running_loss / total

    # Calculate per-class accuracy (classes absent from the data are omitted)
    class_accuracies = {
        CLASS_NAMES[i]: class_correct[i].item() / class_total[i].item()
        for i in range(N_CLASSES)
        if class_total[i] > 0
    }

    return accuracy, avg_loss, class_accuracies

In [10]:
# Training function with class metrics
def train_model(model, train_loader, val_loader=None, num_epochs=10, learning_rate=0.001, 
                save_dir=None):
    """Train ``model`` with AdamW + cross-entropy and keep the best validation checkpoint.

    After each epoch, if ``val_loader`` is given, the model is scored with
    ``evaluate_model``. The best-accuracy state is checkpointed to
    ``<save_dir>/best_model.pt`` and restored into ``model`` before returning.
    The learning rate is halved by ``ReduceLROnPlateau`` after 3 epochs without
    validation-loss improvement.

    Args:
        model: Classifier producing logits.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Optional iterable of validation batches; ``None`` disables
            validation, checkpointing and best-state restoration.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Initial AdamW learning rate.
        save_dir: Directory for ``best_model.pt`` (created if missing; default ``"."``).

    Returns:
        ``(model, train_losses, train_accuracies, val_losses, val_accuracies,
        best_epoch, best_val_acc, class_accuracies_history)``. ``best_epoch`` is
        1-based, or 0 when no validation was run.

    Raises:
        ValueError: If ``train_loader`` yields no samples in an epoch.
    """
    if save_dir is None:
        save_dir = "."  # Default to current directory
    os.makedirs(save_dir, exist_ok=True)

    # Move batches to wherever the model lives (CPU if it has no parameters)
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    # `verbose=` is deprecated for LR schedulers in recent PyTorch; LR drops are logged below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3
    )

    # Metrics for tracking performance
    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []
    class_accuracies_history = []

    # For tracking best model
    best_val_acc = 0.0
    best_model_state = None
    best_epoch = 0

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for inputs, labels in progress_bar:
            if device is not None:
                inputs = inputs.to(device)
                labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by batch size so the epoch loss is a true per-sample mean
            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()

            current_acc = 100 * correct / total
            progress_bar.set_postfix({'loss': loss.item(), 'acc': f'{current_acc:.2f}%'})

        if total == 0:
            raise ValueError("train_loader yielded no samples")

        epoch_loss = running_loss / total
        epoch_acc = correct / total
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_acc)

        print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}, Training Accuracy: {epoch_acc:.4f}')

        if val_loader is not None:
            val_acc, val_loss, class_accs = evaluate_model(model, val_loader, criterion)
            val_accuracies.append(val_acc)
            val_losses.append(val_loss)
            class_accuracies_history.append(class_accs)

            print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}')
            print("Per-class validation accuracy:")
            for class_name, acc in class_accs.items():
                print(f"  {class_name}: {acc:.4f}")

            # The first validated epoch always becomes the baseline, even at 0.0 accuracy
            if best_model_state is None or val_acc > best_val_acc:
                best_val_acc = val_acc
                # deepcopy is required: state_dict() tensors alias the live parameters,
                # so a shallow dict copy would silently follow later optimizer updates.
                best_model_state = copy.deepcopy(model.state_dict())
                best_epoch = epoch + 1

                best_model_path = os.path.join(save_dir, "best_model.pt")
                torch.save({
                    'epoch': best_epoch,
                    'model_state_dict': best_model_state,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'validation_accuracy': best_val_acc,
                    'validation_loss': val_loss,
                    'class_accuracies': class_accs,
                }, best_model_path)

                print(f"✅ New best model saved at epoch {best_epoch} with validation accuracy: {best_val_acc:.4f}")

            # Update learning rate based on validation loss and report reductions
            previous_lr = optimizer.param_groups[0]['lr']
            scheduler.step(val_loss)
            new_lr = optimizer.param_groups[0]['lr']
            if new_lr < previous_lr:
                print(f"Learning rate reduced from {previous_lr:.2e} to {new_lr:.2e}")

    print(f"\nTraining completed. Best model was from epoch {best_epoch} with validation accuracy: {best_val_acc:.4f}")

    # Restore the best-epoch weights (not the last epoch's)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, train_losses, train_accuracies, val_losses, val_accuracies, best_epoch, best_val_acc, class_accuracies_history

def _plot_training_curves(train_losses, train_accuracies, val_losses, val_accuracies,
                          best_epoch, out_path):
    """Save a 2x2 grid of loss/accuracy curves to ``out_path``, marking ``best_epoch`` if > 0."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    ax_loss, ax_train, ax_val, ax_both = axes.flat

    def _epochs(values):
        # 1-based epoch axis matching the values list
        return range(1, len(values) + 1)

    ax_loss.plot(_epochs(train_losses), train_losses, label='Training Loss')
    ax_loss.plot(_epochs(val_losses), val_losses, label='Validation Loss')
    ax_loss.set(xlabel='Epoch', ylabel='Loss', title='Training and Validation Loss')

    ax_train.plot(_epochs(train_accuracies), train_accuracies)
    ax_train.set(xlabel='Epoch', ylabel='Accuracy', title='Training Accuracy')

    ax_val.plot(_epochs(val_accuracies), val_accuracies)
    ax_val.set(xlabel='Epoch', ylabel='Accuracy', title='Validation Accuracy')

    ax_both.plot(_epochs(train_accuracies), train_accuracies, label='Train Accuracy')
    ax_both.plot(_epochs(val_accuracies), val_accuracies, label='Validation Accuracy')
    ax_both.set(xlabel='Epoch', ylabel='Metric Value', title='Training vs Validation Accuracy')

    for ax in axes.flat:
        # best_epoch == 0 means no validated best exists; don't draw a misleading marker
        if best_epoch > 0:
            ax.axvline(x=best_epoch, color='r', linestyle='--',
                       label=f'Best model (epoch {best_epoch})')
        if ax.get_legend_handles_labels()[1]:
            ax.legend()
        ax.grid(True)

    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


def training_pipeline(normalization_type='ventricle', num_epochs=10, batch_size=32,
                      learning_rate=0.001, img_size=(128, 128), seed=42):
    """Load the train/val splits, train a ``VentricleCNN``, then write plots and a summary file.

    All arguments are optional; the defaults match the original hard-coded run.

    Args:
        normalization_type: Preprocessing variant to train on
            (``'norm'``, ``'enhanced'``, ``'ventricle'`` or ``'mask'``).
        num_epochs: Number of training epochs.
        batch_size: Mini-batch size for both loaders.
        learning_rate: Initial AdamW learning rate.
        img_size: ``(height, width)`` of the preprocessed images; must match preprocessing.
        seed: Passed to ``torch.manual_seed`` before data loading and model creation
            (weight init, shuffling, dropout). ``None`` skips seeding.

    Returns:
        ``(model, best_val_acc)`` as produced by ``train_model``.
    """
    if seed is not None:
        torch.manual_seed(seed)

    models_dir = os.path.join(output_dir, "models")
    plots_dir = os.path.join(output_dir, "plots")
    # Cheap insurance in case the config cell was not (re-)run
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(plots_dir, exist_ok=True)

    # Load the processed datasets
    print("\nLoading datasets...")
    train_dataset = VentricleDataset(base_dir=base_dir, split='train', normalization_type=normalization_type)
    val_dataset = VentricleDataset(base_dir=base_dir, split='val', normalization_type=normalization_type)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    print(f"Dataset sizes: Train={len(train_dataset)}, Validation={len(val_dataset)}")

    # Create and train the model
    print("\n=== TRAINING PHASE ===")
    model = VentricleCNN(input_size=img_size)

    model, train_losses, train_accuracies, val_losses, val_accuracies, best_epoch, best_val_acc, class_accs_history = train_model(
        model, train_loader, val_loader, num_epochs, learning_rate, save_dir=models_dir
    )

    print("\nGenerating training curves plot...")
    _plot_training_curves(
        train_losses, train_accuracies, val_losses, val_accuracies, best_epoch,
        os.path.join(plots_dir, f"training_curves_{normalization_type}.png"),
    )

    # Save model information
    with open(os.path.join(output_dir, f"model_info_{normalization_type}.txt"), 'w') as f:
        f.write(f"Normalization Type: {normalization_type}\n")
        f.write(f"Best Epoch: {best_epoch}\n")
        f.write(f"Best Validation Accuracy: {best_val_acc:.4f}\n")
        f.write("\nPer-class Validation Accuracy:\n")
        # best_epoch is 1-based and may be 0 when no best was recorded; [-1] on an
        # empty history would raise IndexError, so guard the lookup.
        if 0 < best_epoch <= len(class_accs_history):
            for class_name, acc in class_accs_history[best_epoch - 1].items():
                f.write(f"  {class_name}: {acc:.4f}\n")

    print("\nTraining and evaluation complete!")
    print(f"Model trained on {normalization_type} normalization")
    print(f"Best model was from epoch {best_epoch} with validation accuracy: {best_val_acc:.4f}")

    return model, best_val_acc

def main():
    """Entry point: print a start banner, then run the full training pipeline (result discarded)."""
    banner = "Starting model training pipeline for MRI ventricle analysis..."
    print(banner)
    training_pipeline()


if __name__ == "__main__":
    # Inside the notebook kernel __name__ is "__main__", so running this cell starts training.
    main()

Starting model training pipeline for MRI ventricle analysis...

Loading datasets...
Loaded 12543 ventricle images for train split
Class distribution:
  Mild_Demented: 2545 images (20.3%)
  Moderate_Demented: 1845 images (14.7%)
  Non_Demented: 4480 images (35.7%)
  Very_Mild_Demented: 3673 images (29.3%)
Loaded 1792 ventricle images for val split
Class distribution:
  Mild_Demented: 363 images (20.3%)
  Moderate_Demented: 264 images (14.7%)
  Non_Demented: 640 images (35.7%)
  Very_Mild_Demented: 525 images (29.3%)
Dataset sizes: Train=12543, Validation=1792

=== TRAINING PHASE ===


Epoch 1/10: 100%|█████| 392/392 [05:19<00:00,  1.23it/s, loss=0.716, acc=55.82%]


Epoch 1/10, Training Loss: 1.0513, Training Accuracy: 0.5582
Validation Loss: 0.6341, Validation Accuracy: 0.6769
Per-class validation accuracy:
  Mild_Demented: 0.7107
  Moderate_Demented: 0.9659
  Non_Demented: 0.7703
  Very_Mild_Demented: 0.3943
✅ New best model saved at epoch 1 with validation accuracy: 0.6769


Epoch 2/10: 100%|█████| 392/392 [05:43<00:00,  1.14it/s, loss=0.349, acc=69.82%]


Epoch 2/10, Training Loss: 0.6372, Training Accuracy: 0.6982
Validation Loss: 0.6910, Validation Accuracy: 0.6607
Per-class validation accuracy:
  Mild_Demented: 0.9339
  Moderate_Demented: 0.9924
  Non_Demented: 0.4203
  Very_Mild_Demented: 0.5981


Epoch 3/10: 100%|█████| 392/392 [05:40<00:00,  1.15it/s, loss=0.307, acc=78.46%]


Epoch 3/10, Training Loss: 0.4728, Training Accuracy: 0.7846
Validation Loss: 0.3270, Validation Accuracy: 0.8650
Per-class validation accuracy:
  Mild_Demented: 0.8320
  Moderate_Demented: 1.0000
  Non_Demented: 0.9187
  Very_Mild_Demented: 0.7543
✅ New best model saved at epoch 3 with validation accuracy: 0.8650


Epoch 4/10: 100%|█████| 392/392 [05:50<00:00,  1.12it/s, loss=0.143, acc=86.30%]


Epoch 4/10, Training Loss: 0.3156, Training Accuracy: 0.8630
Validation Loss: 0.1864, Validation Accuracy: 0.9425
Per-class validation accuracy:
  Mild_Demented: 0.9614
  Moderate_Demented: 1.0000
  Non_Demented: 0.9797
  Very_Mild_Demented: 0.8552
✅ New best model saved at epoch 4 with validation accuracy: 0.9425


Epoch 5/10: 100%|█████| 392/392 [05:37<00:00,  1.16it/s, loss=0.151, acc=92.03%]


Epoch 5/10, Training Loss: 0.1959, Training Accuracy: 0.9203
Validation Loss: 0.1315, Validation Accuracy: 0.9459
Per-class validation accuracy:
  Mild_Demented: 0.9587
  Moderate_Demented: 1.0000
  Non_Demented: 0.9953
  Very_Mild_Demented: 0.8495
✅ New best model saved at epoch 5 with validation accuracy: 0.9459


Epoch 6/10: 100%|█████| 392/392 [05:29<00:00,  1.19it/s, loss=0.118, acc=95.38%]


Epoch 6/10, Training Loss: 0.1280, Training Accuracy: 0.9538
Validation Loss: 0.0525, Validation Accuracy: 0.9844
Per-class validation accuracy:
  Mild_Demented: 0.9725
  Moderate_Demented: 1.0000
  Non_Demented: 0.9969
  Very_Mild_Demented: 0.9695
✅ New best model saved at epoch 6 with validation accuracy: 0.9844


Epoch 7/10: 100%|███| 392/392 [21:07<00:00,  3.23s/it, loss=0.00501, acc=97.05%]


Epoch 7/10, Training Loss: 0.0852, Training Accuracy: 0.9705
Validation Loss: 0.0336, Validation Accuracy: 0.9922
Per-class validation accuracy:
  Mild_Demented: 0.9945
  Moderate_Demented: 1.0000
  Non_Demented: 0.9922
  Very_Mild_Demented: 0.9867
✅ New best model saved at epoch 7 with validation accuracy: 0.9922


Epoch 8/10: 100%|████| 392/392 [04:56<00:00,  1.32it/s, loss=0.0967, acc=97.50%]


Epoch 8/10, Training Loss: 0.0694, Training Accuracy: 0.9750
Validation Loss: 0.0459, Validation Accuracy: 0.9894
Per-class validation accuracy:
  Mild_Demented: 0.9945
  Moderate_Demented: 1.0000
  Non_Demented: 0.9766
  Very_Mild_Demented: 0.9962


Epoch 9/10: 100%|█████| 392/392 [23:42<00:00,  3.63s/it, loss=0.127, acc=98.35%]


Epoch 9/10, Training Loss: 0.0449, Training Accuracy: 0.9835
Validation Loss: 0.0195, Validation Accuracy: 0.9944
Per-class validation accuracy:
  Mild_Demented: 0.9917
  Moderate_Demented: 1.0000
  Non_Demented: 0.9984
  Very_Mild_Demented: 0.9886
✅ New best model saved at epoch 9 with validation accuracy: 0.9944


Epoch 10/10: 100%|███| 392/392 [07:37<00:00,  1.17s/it, loss=0.0127, acc=97.80%]


Epoch 10/10, Training Loss: 0.0677, Training Accuracy: 0.9780
Validation Loss: 0.0398, Validation Accuracy: 0.9877
Per-class validation accuracy:
  Mild_Demented: 0.9862
  Moderate_Demented: 1.0000
  Non_Demented: 0.9750
  Very_Mild_Demented: 0.9981

Training completed. Best model was from epoch 9 with validation accuracy: 0.9944

Generating training curves plot...

Training and evaluation complete!
Model trained on ventricle normalization
Best model was from epoch 9 with validation accuracy: 0.9944
