In [6]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import functional as TF

# Constants
# List order defines the integer label assigned to each class by the dataset.
CLASS_NAMES = ['Mild_Demented', 'Moderate_Demented', 'Non_Demented', 'Very_Mild_Demented']
# Derived rather than hard-coded so the label list and the model's output width cannot drift apart.
N_CLASSES = len(CLASS_NAMES)
# Input size the model's fully connected layer is sized for.
IMG_SIZE = (128, 128)
BATCH_SIZE = 32
# Seed torch's default generator (weight init, dropout, DataLoader shuffling) for reproducible runs.
SEED = 42
torch.manual_seed(SEED)
base_dir = "Combined_MRI_Dataset"
output_dir = os.path.join(base_dir, "model_output")

class VentricleDataset(Dataset):
    """Image/label pairs read from ``<base_dir>/normalized/<split>/<class_name>/``.

    Only files whose lower-cased name ends in ``mask.png`` are used. Each
    image's label is the index of its class directory in ``CLASS_NAMES``.
    """

    def __init__(self, base_dir, split='train'):
        """Index all mask images of ``split`` (no pixel data is read here).

        Args:
            base_dir: dataset root containing the ``normalized`` directory.
            split: sub-directory of ``normalized`` to read, e.g. 'train' or 'val'.
        """
        self.data_dir = os.path.join(base_dir, "normalized", split)
        self.images = []
        self.labels = []

        for class_idx, class_name in enumerate(CLASS_NAMES):
            class_dir = os.path.join(self.data_dir, class_name)
            if not os.path.isdir(class_dir):
                print(f"Warning: class directory not found, skipping: {class_dir}")
                continue

            # os.listdir order is arbitrary; sort so sample order (and hence
            # seeded shuffling) is reproducible across machines and runs.
            for img_file in sorted(os.listdir(class_dir)):
                if img_file.lower().endswith('mask.png'):
                    self.images.append(os.path.join(class_dir, img_file))
                    self.labels.append(class_idx)

        print(f"Loaded {len(self.images)} images for {split} split")

    def __len__(self):
        """Number of indexed images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` for sample ``idx``.

        The image is converted with ``TF.to_tensor``. The label is a 0-d
        ``torch.long`` tensor.
        """
        # The context manager closes the file handle; a bare Image.open keeps
        # it open until garbage collection, leaking descriptors in long runs.
        with Image.open(self.images[idx]) as img:
            # VentricleCNN's first conv takes a single channel: collapse
            # multi-band images (RGB, RGBA, LA) to grayscale. Single-band
            # modes are passed through untouched, exactly as before.
            if len(img.getbands()) != 1:
                img = img.convert('L')
            image = TF.to_tensor(img)
        return image, torch.tensor(self.labels[idx], dtype=torch.long)

class VentricleCNN(nn.Module):
    """Two convolutional blocks followed by a three-layer fully connected classifier.

    Each conv block is conv3x3 (padding 1) -> ReLU -> 2x2 max-pool -> BatchNorm,
    so the spatial size shrinks by a factor of 4 overall and ``fc1`` takes
    ``128 * (H // 4) * (W // 4)`` features.

    Args:
        in_channels: number of input channels (default 1, the original design).
        n_classes: number of output logits; ``None`` means ``N_CLASSES``.
        img_size: ``(H, W)`` of the inputs; ``None`` means ``IMG_SIZE``. The
            classifier is sized for exactly this input size.

    Raises:
        ValueError: if either image dimension is smaller than 4, since nothing
            would survive the two poolings.
    """

    def __init__(self, in_channels=1, n_classes=None, img_size=None):
        super().__init__()
        if n_classes is None:
            n_classes = N_CLASSES
        if img_size is None:
            img_size = IMG_SIZE
        height, width = img_size
        if height < 4 or width < 4:
            raise ValueError(f"img_size must be at least 4x4, got {img_size}")

        # Layers are created in the original order with the original attribute
        # names, so seeded initialisation and state_dict keys (best_model.pt)
        # are unchanged.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.batchnorm1 = nn.BatchNorm2d(64)

        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.batchnorm2 = nn.BatchNorm2d(128)

        # Two floor-halving pools: (n // 2) // 2 == n // 4.
        self.fc_input_size = 128 * (height // 4) * (width // 4)

        self.fc1 = nn.Linear(self.fc_input_size, 512)
        self.dropout1 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(512, 128)
        self.dropout2 = nn.Dropout(0.2)
        self.out = nn.Linear(128, n_classes)

    def forward(self, x):
        """Map a (batch, in_channels, H, W) input to raw (batch, n_classes) logits."""
        x = self.batchnorm1(self.pool1(F.relu(self.conv1(x))))
        x = self.batchnorm2(self.pool2(F.relu(self.conv2(x))))
        x = torch.flatten(x, 1)
        x = self.dropout1(F.relu(self.fc1(x)))
        x = self.dropout2(F.relu(self.fc2(x)))
        return self.out(x)





In [8]:
def evaluate(model, loader, criterion, class_names=None):
    """Evaluate ``model`` over every batch of ``loader`` without gradient tracking.

    The model is left in eval mode.

    Args:
        model: classifier returning (batch, n_classes) logits.
        loader: iterable of ``(images, labels)`` batches.
        criterion: loss with mean reduction (e.g. ``nn.CrossEntropyLoss()``).
            Its per-batch value is re-weighted by batch size.
        class_names: names used as per-class accuracy keys; ``None`` means
            ``CLASS_NAMES``. Label ``i`` maps to ``class_names[i]``.

    Returns:
        ``(accuracy, avg_loss, class_acc)``:
        - ``accuracy``: fraction of correctly classified samples.
        - ``avg_loss``: per-sample mean loss.
        - ``class_acc``: maps each class name to its accuracy. Classes with
          no samples are omitted.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if class_names is None:
        class_names = CLASS_NAMES
    n_classes = len(class_names)

    # Send batches to wherever the model lives (no-op for a CPU model).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    class_correct = torch.zeros(n_classes, dtype=torch.long)
    class_total = torch.zeros(n_classes, dtype=torch.long)

    with torch.no_grad():
        for images, labels in loader:
            if device is not None:
                images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size = labels.size(0)
            # Weight the batch-mean loss by batch size so a short final batch
            # does not count as much as a full one.
            loss_sum += criterion(outputs, labels).item() * batch_size
            _, predicted = torch.max(outputs, 1)
            hits = predicted == labels
            correct += hits.sum().item()
            total += batch_size

            # Vectorised per-class tallies instead of a per-sample Python loop.
            labels_cpu = labels.cpu()
            class_total += torch.bincount(labels_cpu, minlength=n_classes)
            class_correct += torch.bincount(labels_cpu[hits.cpu()], minlength=n_classes)

    if total == 0:
        raise ValueError("evaluate() received a loader that yielded no samples")

    accuracy = correct / total
    avg_loss = loss_sum / total
    per_class_correct = class_correct.tolist()
    per_class_total = class_total.tolist()
    class_acc = {
        class_names[i]: per_class_correct[i] / per_class_total[i]
        for i in range(n_classes)
        if per_class_total[i] > 0
    }

    return accuracy, avg_loss, class_acc

In [10]:
def train(model, train_loader, val_loader, epochs=10, lr=0.001, save_dir=None):
    """Train ``model`` with AdamW and cross-entropy, validating after every epoch.

    Whenever validation accuracy improves, the weights are written to
    ``<save_dir>/best_model.pt``. The returned model holds the weights from
    the final epoch, which may differ from the saved best checkpoint.

    Args:
        model: classifier producing (batch, n_classes) logits. Batches are
            moved to the device of the model's parameters.
        train_loader: iterable of ``(images, labels)`` training batches.
        val_loader: iterable of ``(images, labels)`` validation batches.
        epochs: number of passes over ``train_loader``.
        lr: initial AdamW learning rate. ReduceLROnPlateau (patience=3)
            lowers it when validation loss stops improving.
        save_dir: checkpoint directory; ``None`` means the module-level
            ``output_dir``. Created if missing.

    Returns:
        ``(model, history)``: ``history`` maps 'train_loss', 'train_acc',
        'val_loss' and 'val_acc' to per-epoch lists.

    Raises:
        ValueError: if ``train_loader`` yields no samples in an epoch.
    """
    save_dir = output_dir if save_dir is None else save_dir
    # torch.save fails if the directory is missing (e.g. train called outside main).
    os.makedirs(save_dir, exist_ok=True)
    checkpoint_path = os.path.join(save_dir, "best_model.pt")

    # Batches follow the model (no-op for a CPU model).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)

    # -inf rather than 0 so the first epoch always writes a checkpoint,
    # even if validation accuracy is exactly 0.
    best_acc = float("-inf")
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    for epoch in range(epochs):
        # Re-enter training mode every epoch: evaluate() switches to eval mode.
        model.train()
        loss_sum, train_correct, train_total = 0.0, 0, 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # criterion returns a batch mean; weight by batch size so a short
            # final batch is not over-counted in the epoch average.
            loss_sum += loss.item() * batch_size
            _, predicted = torch.max(outputs, 1)
            train_correct += (predicted == labels).sum().item()
            train_total += batch_size

        if train_total == 0:
            raise ValueError("train_loader yielded no samples")

        val_acc, val_loss, _ = evaluate(model, val_loader, criterion)
        train_acc = train_correct / train_total
        train_loss = loss_sum / train_total

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, "
              f"Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)

        scheduler.step(val_loss)

    return model, history

def plot_history(history, save_dir=None):
    """Save side-by-side loss and accuracy curves to ``training_curves.png``.

    Args:
        history: dict with per-epoch lists under 'train_loss', 'val_loss',
            'train_acc' and 'val_acc' (the structure built by ``train``).
        save_dir: output directory; ``None`` means the module-level
            ``output_dir``. Created if missing.
    """
    save_dir = output_dir if save_dir is None else save_dir
    os.makedirs(save_dir, exist_ok=True)

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))
    try:
        for ax, key, title in ((loss_ax, 'loss', 'Loss'), (acc_ax, 'acc', 'Accuracy')):
            train_values = history[f'train_{key}']
            val_values = history[f'val_{key}']
            # 1-based x-axis so points line up with the "Epoch N" log lines.
            ax.plot(range(1, len(train_values) + 1), train_values, label='Train')
            ax.plot(range(1, len(val_values) + 1), val_values, label='Validation')
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(title)
            ax.legend()
        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, "training_curves.png"))
    finally:
        # Always release the figure, even on error, so repeated calls in a
        # long-lived kernel do not accumulate open figures.
        plt.close(fig)

def main():
    """Load the train/val splits, train a VentricleCNN, and save its training curves.

    Checkpoints and plots are written under ``output_dir``, which is created
    if missing.

    Raises:
        ValueError: if either split contains no images.
    """
    os.makedirs(output_dir, exist_ok=True)

    print("Loading data...")
    train_set = VentricleDataset(base_dir, 'train')
    val_set = VentricleDataset(base_dir, 'val')

    # Fail fast with a readable message. Otherwise an empty train split is
    # rejected by the shuffling DataLoader with an unhelpful num_samples
    # error, and an empty val split only fails after a full training epoch.
    for split_name, dataset in (('train', train_set), ('val', val_set)):
        if len(dataset) == 0:
            raise ValueError(
                f"No '*mask.png' images found for the {split_name} split under {dataset.data_dir}"
            )

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=BATCH_SIZE)

    print("Training model...")
    model = VentricleCNN()
    model, history = train(model, train_loader, val_loader)

    print("Saving results...")
    plot_history(history)
    print(f"Best validation accuracy: {max(history['val_acc']):.4f}")

if __name__ == "__main__":
    main()

Loading data...
Loaded 12543 images for train split
Loaded 1792 images for val split
Training model...


Epoch 1/10: 100%|█████████████████████████████| 392/392 [06:19<00:00,  1.03it/s]


Epoch 1: Train Loss: 1.4965, Train Acc: 0.4729, Val Loss: 0.7987, Val Acc: 0.5921


Epoch 2/10: 100%|█████████████████████████████| 392/392 [06:01<00:00,  1.09it/s]


Epoch 2: Train Loss: 0.7488, Train Acc: 0.6299, Val Loss: 0.6204, Val Acc: 0.6981


Epoch 3/10: 100%|█████████████████████████████| 392/392 [04:55<00:00,  1.33it/s]


Epoch 3: Train Loss: 0.5779, Train Acc: 0.7366, Val Loss: 0.4018, Val Acc: 0.8298


Epoch 4/10: 100%|█████████████████████████████| 392/392 [06:23<00:00,  1.02it/s]


Epoch 4: Train Loss: 0.3905, Train Acc: 0.8370, Val Loss: 0.2867, Val Acc: 0.8750


Epoch 5/10: 100%|███████████████████████████| 392/392 [2:14:46<00:00, 20.63s/it]


Epoch 5: Train Loss: 0.2167, Train Acc: 0.9157, Val Loss: 0.2434, Val Acc: 0.8968


Epoch 6/10: 100%|███████████████████████████| 392/392 [6:07:38<00:00, 56.27s/it]


Epoch 6: Train Loss: 0.1356, Train Acc: 0.9500, Val Loss: 0.1689, Val Acc: 0.9381


Epoch 7/10: 100%|███████████████████████████| 392/392 [1:11:45<00:00, 10.98s/it]


Epoch 7: Train Loss: 0.0974, Train Acc: 0.9656, Val Loss: 0.1538, Val Acc: 0.9459


Epoch 8/10: 100%|█████████████████████████████| 392/392 [11:11<00:00,  1.71s/it]


Epoch 8: Train Loss: 0.0749, Train Acc: 0.9756, Val Loss: 0.1569, Val Acc: 0.9498


Epoch 9/10: 100%|█████████████████████████████| 392/392 [46:35<00:00,  7.13s/it]


Epoch 9: Train Loss: 0.0597, Train Acc: 0.9794, Val Loss: 0.1796, Val Acc: 0.9453


Epoch 10/10: 100%|████████████████████████████| 392/392 [16:21<00:00,  2.50s/it]


Epoch 10: Train Loss: 0.0529, Train Acc: 0.9825, Val Loss: 0.1362, Val Acc: 0.9587
Saving results...
Best validation accuracy: 0.9587
