In [None]:
# --- 1. Import Libraries ---
import os
import cv2
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import math
import random

# PyTorch imports
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from sklearn.model_selection import train_test_split

In [None]:
# NOTE(review): the import cell above already imports segmentation_models_pytorch
# and albumentations, so on a runtime where they are not preinstalled this cell
# has to be executed before that one (or be moved above it).
!pip install segmentation_models_pytorch
!pip install albumentations

In [None]:
# --- Function to seed everything for reproducibility ---
def seed_everything(seed=42):
    """Seed the random number generators used by this notebook.

    Sets ``PYTHONHASHSEED`` and seeds Python's ``random`` module, NumPy's
    global generator and PyTorch (CPU and CUDA). cuDNN is put into
    deterministic mode with benchmark auto-tuning switched off.

    Args:
        seed: Integer seed handed to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seed for every library-level generator, in the usual order
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    # Prefer reproducible cuDNN kernels over auto-tuned ones
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


seed_everything(42)

In [None]:
import os
import re
import pandas as pd

# --- 1. CONFIGURE YOUR PATHS ---
# Change DATASET_ROOT to point the notebook at a different dataset folder;
# the image, mask and metadata locations are all derived from it.
DATASET_ROOT = '/content/drive/MyDrive/200_AD_CN_MCI_11112025'
IMAGE_PATH = f'{DATASET_ROOT}/images/'
MASK_PATH = f'{DATASET_ROOT}/masks/'
OUTPUT_CSV_PATH = f'{DATASET_ROOT}/metadata.csv'

# --- 2. THE "ADVANCED NATURAL SORT" KEY FUNCTION ---
# This version correctly handles both text and numbers for perfect sorting.
def advanced_natural_sort_key(filename):
    """Build a ``(text, number)`` sort key for a file name.

    The first run of digits enclosed in parentheses is taken as the numeric
    part; everything before its opening parenthesis is the text part. Sorting
    by this key orders names by text first and then numerically, so
    'ad_image_0001 (10).png' comes after 'ad_image_0001 (2).png', and
    'cn_image_0001 (1).png' comes after every 'ad_image_0001' file.

    Args:
        filename: File name to derive the key from.

    Returns:
        ``(prefix, number)`` when a parenthesised number is present,
        otherwise ``(filename, 0)``.
    """
    found = re.search(r'\((\d+)\)', filename)
    if found is None:
        # No "(N)" marker: fall back to ordering by the complete name
        return (filename, 0)
    return (filename[:found.start()], int(found.group(1)))

# --- 3. THE MAIN SCRIPT ---

def create_parallel_metadata():
    """Pair every image with a mask and save the pairs to ``OUTPUT_CSV_PATH``.

    Lists the files ending in ``.png`` in ``IMAGE_PATH`` and ``MASK_PATH``,
    sorts both listings with ``advanced_natural_sort_key`` and pairs them by
    position. The resulting table (columns ``image_id`` and ``mask_id``) is
    written as CSV and its first and last rows are printed.

    Nothing is written when a directory is missing, when the two folders hold
    a different number of files, or when no image is found. Because the
    pairing is purely positional, pairs whose parenthesised index numbers
    differ are reported as a warning so that a shifted pairing does not go
    unnoticed; the CSV is still written in that case.

    Returns:
        None. Progress and problems are reported through ``print``.
    """
    print("--- Starting Metadata Creation ---")

    try:
        image_files = [f for f in os.listdir(IMAGE_PATH) if f.endswith('.png')]
        mask_files = [f for f in os.listdir(MASK_PATH) if f.endswith('.png')]
    except FileNotFoundError as e:
        print(f"ERROR: A directory was not found! {e}")
        return

    # --- CRITICAL STEP: Sort both lists using the ADVANCED natural sort key ---
    image_files.sort(key=advanced_natural_sort_key)
    mask_files.sort(key=advanced_natural_sort_key)

    print(f"Found and sorted {len(image_files)} images.")
    print(f"Found and sorted {len(mask_files)} masks.")

    # --- Validation Step ---
    if len(image_files) != len(mask_files):
        print("\n--- FATAL ERROR: MISMATCH IN FILE COUNTS! ---")
        return

    if not image_files:
        print("\n--- ERROR: No image files found. ---")
        return

    # --- Pairing sanity check ---
    # Equal counts alone do not prove that position k in one folder belongs to
    # position k in the other, so compare the parsed index number of each pair.
    suspicious_pairs = [
        (img, mask)
        for img, mask in zip(image_files, mask_files)
        if advanced_natural_sort_key(img)[1] != advanced_natural_sort_key(mask)[1]
    ]
    if suspicious_pairs:
        print(f"\n--- WARNING: {len(suspicious_pairs)} pair(s) have different index numbers ---")
        for img, mask in suspicious_pairs[:5]:
            print(f"   {img}  <->  {mask}")
        print("Check the pairing before training on this metadata.")

    # --- Create the pairs and the DataFrame ---
    file_pairs = [{'image_id': img, 'mask_id': mask} for img, mask in zip(image_files, mask_files)]
    metadata_df = pd.DataFrame(file_pairs)

    # Save the DataFrame to a CSV file
    metadata_df.to_csv(OUTPUT_CSV_PATH, index=False)

    print(f"\n--- SUCCESS! ---")
    if suspicious_pairs:
        # Do not claim a perfect match when the check above raised doubts
        print(f"Created metadata.csv with {len(metadata_df)} positionally matched pairs (see warning above).")
    else:
        print(f"Created metadata.csv with {len(metadata_df)} perfectly matched pairs.")
    print(f"File saved to: {OUTPUT_CSV_PATH}")

    print("\n--- Here is a sample of your CORRECT metadata.csv: ---")
    print("--- Top 5 rows: ---")
    print(metadata_df.head())
    print("\n--- Bottom 5 rows: ---")
    print(metadata_df.tail())

# --- Run the main function ---
create_parallel_metadata()

In [None]:
# Mount Google Drive at /content/drive so the dataset paths used in this
# notebook (all under /content/drive/MyDrive/...) can be read and written.
# NOTE(review): the metadata-creation cell above already lists directories under
# /content/drive, so on a fresh runtime this cell has to be run before it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# --- 2. Define Parameters and Paths ---
# Target size (in pixels) used when images and masks are resized
IMG_HEIGHT = 256
IMG_WIDTH = 256

# Dataset location: change DATASET_ROOT to retarget all three paths at once
DATASET_ROOT = '/content/drive/MyDrive/200_AD_CN_MCI_11112025'
IMAGE_PATH = f'{DATASET_ROOT}/images/'
MASK_PATH = f'{DATASET_ROOT}/masks/'
METADATA_PATH = f'{DATASET_ROOT}/metadata.csv'

# Optimisation hyper-parameters
BATCH_SIZE = 16
EPOCHS = 200
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5

# Encoder backbone and the weights it is initialised from
ENCODER = 'resnext50_32x4d'
PRETRAINED_WEIGHTS = 'imagenet'

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")


In [None]:
# --- 3. Data Loading and Splitting (Train-Val-Test) ---
# Hold out 15% of all pairs for testing, then split 15% of the remainder off
# as the validation set. Both splits use the same fixed random_state.
metadata_df = pd.read_csv(METADATA_PATH)
train_val_df, test_df = train_test_split(
    metadata_df,
    test_size=0.15,
    random_state=42,
)
train_df, val_df = train_test_split(
    train_val_df,
    test_size=0.15,
    random_state=42,
)

# One summary line per split
print(
    f"Total images: {len(metadata_df)}",
    f"Training images: {len(train_df)}",
    f"Validation images: {len(val_df)}",
    f"Testing images: {len(test_df)}",
    sep="\n",
)

In [None]:
# --- 4. Augmentations and PyTorch Dataset ---
# Training pipeline: random geometric and intensity changes, then the same
# normalisation and tensor conversion that validation/test data receive.
train_augs = A.Compose(
    [
        # Geometric transforms
        A.HorizontalFlip(p=0.5),
        A.Affine(
            p=0.7,
            scale=(0.9, 1.1),
            translate_percent=(-0.06, 0.06),
            rotate=(-20, 20),
        ),
        A.ElasticTransform(alpha=100, sigma=120 * 0.05, p=0.4),
        A.GridDistortion(p=0.4),
        # Intensity transforms
        A.RandomBrightnessContrast(p=0.4),
        A.GaussNoise(p=0.2),
        # Normalisation and conversion to a torch tensor
        A.Normalize(mean=(0.5,), std=(0.5,)),
        ToTensorV2(),
    ]
)

# Validation/test pipeline: no random transforms, only normalise and convert
val_augs = A.Compose(
    [
        A.Normalize(mean=(0.5,), std=(0.5,)),
        ToTensorV2(),
    ]
)

class BrainMRIDataset(Dataset):
    """Dataset yielding ``(image, mask)`` tensor pairs for binary segmentation.

    Each row of ``df`` names one image file (column ``image_id``) and its mask
    (column ``mask_id``). Both files are read as grayscale, resized to
    ``IMG_WIDTH`` x ``IMG_HEIGHT`` and optionally passed through an
    augmentation pipeline. The mask is binarised to the values 0.0 and 1.0.

    Args:
        df: DataFrame with ``image_id`` and ``mask_id`` columns.
        image_dir: Directory that holds the image files.
        mask_dir: Directory that holds the mask files.
        augmentations: Optional callable invoked as
            ``augmentations(image=..., mask=...)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries. When it is omitted, or when it
            returns NumPy arrays, the arrays are converted to tensors here
            without any normalisation.
    """

    def __init__(self, df, image_dir, mask_dir, augmentations=None):
        self.df = df
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.augmentations = augmentations

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.df)

    @staticmethod
    def _read_grayscale(path):
        """Load ``path`` as an (H, W, 1) grayscale array at the target size.

        Raises:
            FileNotFoundError: If the file cannot be read.
        """
        pixels = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if pixels is None:
            # cv2.imread reports failure by returning None instead of raising
            raise FileNotFoundError(f"Could not read image file: {path}")
        # cv2.resize takes the target size as (width, height)
        pixels = cv2.resize(pixels, (IMG_WIDTH, IMG_HEIGHT))
        return np.expand_dims(pixels, axis=-1)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as ``(image, mask)`` tensors.

        Both tensors are channel-first; the mask is a float tensor holding
        only 0.0 and 1.0.

        Raises:
            FileNotFoundError: If the image or mask file cannot be read.
        """
        row = self.df.iloc[idx]
        image = self._read_grayscale(os.path.join(self.image_dir, row['image_id']))
        mask = self._read_grayscale(os.path.join(self.mask_dir, row['mask_id']))

        if self.augmentations:
            transformed = self.augmentations(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']

        # Without a tensor-converting pipeline the data is still NumPy (H, W, C)
        if isinstance(image, np.ndarray):
            image = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).float()
        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(np.ascontiguousarray(mask))

        # Scale 0..255 to 0..1 and threshold at 0.5 to get a strictly binary mask
        mask = (mask / 255.0 > 0.5).float()
        return image, mask.permute(2, 0, 1)


In [None]:
# --- 5. Create Datasets and DataLoaders ---
# Only the training split uses the random augmentations; validation and test
# share the deterministic pipeline.
train_dataset = BrainMRIDataset(
    df=train_df, image_dir=IMAGE_PATH, mask_dir=MASK_PATH, augmentations=train_augs
)
val_dataset = BrainMRIDataset(
    df=val_df, image_dir=IMAGE_PATH, mask_dir=MASK_PATH, augmentations=val_augs
)
test_dataset = BrainMRIDataset(
    df=test_df, image_dir=IMAGE_PATH, mask_dir=MASK_PATH, augmentations=val_augs
)

# Shuffle training batches only; evaluation order stays fixed
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)
print("DataLoaders created successfully.")

In [None]:
# --- 6. Loss Function and Metrics ---
class TverskyLoss(nn.Module):
    """Tversky loss (``1 - Tversky index``) computed from raw logits.

    The index is evaluated over the whole batch at once: predictions and
    targets are flattened before the soft true-positive, false-negative and
    false-positive sums are taken.

    Args:
        alpha: Weight of the false-negative term in the denominator.
        beta: Weight of the false-positive term in the denominator.
        smooth: Constant added to numerator and denominator so the ratio stays
            defined when all sums are zero.
    """

    def __init__(self, alpha=0.7, beta=0.3, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth

    def forward(self, inputs, targets):
        """Return the scalar loss for logits ``inputs`` and binary ``targets``.

        ``inputs`` and ``targets`` must hold the same number of elements.
        """
        # reshape (unlike view) also accepts non-contiguous tensors
        probs = torch.sigmoid(inputs).reshape(-1)
        # Cast so that bool/integer masks work in the arithmetic below
        targets = targets.reshape(-1).to(probs.dtype)

        true_pos = (probs * targets).sum()
        false_neg = ((1 - probs) * targets).sum()
        false_pos = (probs * (1 - targets)).sum()
        tversky_index = (true_pos + self.smooth) / (
            true_pos + self.alpha * false_neg + self.beta * false_pos + self.smooth
        )
        return 1 - tversky_index

class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    Each element's binary cross-entropy is scaled by
    ``alpha * (1 - pt) ** gamma`` where ``pt = exp(-bce)``.

    Args:
        alpha: Constant factor applied to every element (it is not varied
            between the positive and negative class).
        gamma: Exponent of the ``(1 - pt)`` modulating factor.
        reduction: ``'mean'`` (default), ``'sum'`` or ``'none'``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, alpha=0.8, gamma=2.0, reduction='mean'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            # Fail early instead of silently returning an unreduced tensor
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        """Return the focal loss for logits ``inputs`` and float ``targets``."""
        bce_loss = nn.BCEWithLogitsLoss(reduction='none')(inputs, targets)
        pt = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

class WeightedFocalTverskyLoss(nn.Module):
    """Weighted sum of ``FocalLoss`` and ``TverskyLoss`` (both with defaults).

    Args:
        focal_weight: Multiplier of the focal term.
        tversky_weight: Multiplier of the Tversky term.
    """

    def __init__(self, focal_weight=0.8, tversky_weight=0.2):
        super().__init__()
        self.focal_loss = FocalLoss()
        self.tversky_loss = TverskyLoss()
        self.focal_weight = focal_weight
        self.tversky_weight = tversky_weight

    def forward(self, inputs, targets):
        """Return ``focal_weight * focal + tversky_weight * tversky``."""
        focal_term = self.focal_weight * self.focal_loss(inputs, targets)
        tversky_term = self.tversky_weight * self.tversky_loss(inputs, targets)
        return focal_term + tversky_term

class BCEFocalTverskyLoss(nn.Module):
    """Weighted sum of BCE-with-logits and ``WeightedFocalTverskyLoss``.

    Args:
        bce_weight: Multiplier of the binary cross-entropy term.
        focal_tversky_weight: Multiplier of the focal/Tversky term.
    """

    def __init__(self, bce_weight=0.5, focal_tversky_weight=0.5):
        super().__init__()
        # Operates on logits directly; more stable than Sigmoid + BCE
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.focal_tversky_loss = WeightedFocalTverskyLoss()
        self.bce_weight, self.focal_tversky_weight = bce_weight, focal_tversky_weight

    def forward(self, inputs, targets):
        """Return the combined scalar loss for logits ``inputs`` and ``targets``."""
        weighted_bce = self.bce_weight * self.bce_loss(inputs, targets)
        weighted_focal_tversky = self.focal_tversky_weight * self.focal_tversky_loss(inputs, targets)
        return weighted_bce + weighted_focal_tversky

def dice_coef(y_pred, y_true, smooth=1):
    """Soft Dice coefficient between logits and a binary mask.

    The sigmoid of ``y_pred`` is used as-is (no thresholding) and the
    coefficient is computed over all elements of the batch together.

    Args:
        y_pred: Raw logits.
        y_true: Binary mask with the same number of elements as ``y_pred``.
        smooth: Constant added to numerator and denominator.

    Returns:
        Scalar tensor ``(2 * intersection + smooth) / (sum_pred + sum_true + smooth)``.
    """
    # reshape (unlike view) also accepts non-contiguous tensors
    probs = torch.sigmoid(y_pred).reshape(-1)
    truth = y_true.reshape(-1).to(probs.dtype)
    intersection = (probs * truth).sum()
    return (2. * intersection + smooth) / (probs.sum() + truth.sum() + smooth)


In [None]:
# --- 7. The Model ---
# U-Net++ with the configured encoder; one input channel (grayscale) and one
# output channel (binary mask logits).
model = smp.UnetPlusPlus(
    encoder_name=ENCODER,
    encoder_weights=PRETRAINED_WEIGHTS,
    in_channels=1,
    classes=1,
)
model = model.to(DEVICE)
print(f"Model created with a pre-trained {ENCODER} encoder.")

In [None]:
# --- 8. The Training Loop (with TQDM progress bars) ---
# Trains `model` for up to EPOCHS epochs. After every epoch the model is
# evaluated on the validation loader; the weights with the lowest average
# validation loss are saved to 'best_segmentation_model.pth', and training
# stops early once that loss has not improved for `patience` epochs in a row.
from tqdm import tqdm # Make sure to import tqdm

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE,weight_decay=WEIGHT_DECAY)
# Multiplies the learning rate by 0.5 when the value passed to scheduler.step()
# (the average validation loss, see below) stops decreasing; patience=15 here
# is the scheduler's own patience, separate from the early-stopping one.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=15)
loss_fn = BCEFocalTverskyLoss()

# Early Stopping parameters
patience = 30                   # epochs without improvement tolerated before stopping
best_val_loss = float('inf')    # lowest average validation loss seen so far
patience_counter = 0            # consecutive epochs without improvement

# Per-epoch averages, one entry appended per completed epoch
history = {'train_loss': [], 'val_loss': [], 'train_dice': [], 'val_dice': []}

print("\n--- Starting Model Training ---")

for epoch in range(EPOCHS):
    model.train()
    # Running sums of the per-batch loss and Dice values for this epoch
    train_loss, train_dice = 0.0, 0.0

    # --- Training Phase with TQDM ---
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]")
    for images, masks in loop:
        images, masks = images.to(DEVICE), masks.to(DEVICE)

        # Standard optimisation step: clear gradients, forward, backward, update
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, masks)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_dice += dice_coef(outputs, masks).item()

        # Update the progress bar with the current loss
        loop.set_postfix(loss=loss.item())

    # Averages are taken over batches, so every batch counts equally
    # regardless of how many samples it contains
    avg_train_loss = train_loss / len(train_loader)
    avg_train_dice = train_dice / len(train_loader)
    history['train_loss'].append(avg_train_loss)
    history['train_dice'].append(avg_train_dice)

    # --- Validation Phase with TQDM ---
    model.eval()
    val_loss, val_dice = 0.0, 0.0
    # No gradients are needed while evaluating
    with torch.no_grad():
        loop = tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]")
        for images, masks in loop:
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            outputs = model(images)
            loss = loss_fn(outputs, masks)
            val_loss += loss.item()
            val_dice += dice_coef(outputs, masks).item()

            # Update the progress bar with the current loss
            loop.set_postfix(loss=loss.item())

    avg_val_loss = val_loss / len(val_loader)
    avg_val_dice = val_dice / len(val_loader)
    history['val_loss'].append(avg_val_loss)
    history['val_dice'].append(avg_val_dice)

    print(f"Epoch {epoch+1}/{EPOCHS}: Train Loss: {avg_train_loss:.4f}, Train Dice: {avg_train_dice:.4f} | Val Loss: {avg_val_loss:.4f}, Val Dice: {avg_val_dice:.4f}")

    # Step the scheduler based on validation loss
    scheduler.step(avg_val_loss)

    # Early Stopping and Model Saving
    if avg_val_loss < best_val_loss:
        # New best: remember it, checkpoint the weights and reset the counter
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), 'best_segmentation_model.pth')
        print("   -> Model saved (best validation loss)")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("--- Early stopping triggered ---")
            break

print("\n--- MODEL TRAINING COMPLETE ---")

In [None]:
# --- 9. Final Evaluation on the Test Set (with TQDM progress bar) ---
from tqdm import tqdm # Make sure tqdm is imported

print("\n--- Evaluating on the Test Set ---")

# Load the best performing model from training.
# map_location makes a checkpoint that was saved on a GPU loadable on a
# CPU-only runtime as well (DEVICE is whichever device this session uses).
model.load_state_dict(torch.load('best_segmentation_model.pth', map_location=DEVICE))
model.eval()

# Running sums of the per-batch loss and Dice values
test_loss = 0.0
test_dice = 0.0
with torch.no_grad():
    # Wrap the test_loader with tqdm
    loop = tqdm(test_loader, desc="Testing")
    for images, masks in loop:
        images, masks = images.to(DEVICE), masks.to(DEVICE)
        outputs = model(images)

        loss = loss_fn(outputs, masks)
        test_loss += loss.item()

        dice_score = dice_coef(outputs, masks).item()
        test_dice += dice_score

        # Update the progress bar with the current metrics
        loop.set_postfix(loss=loss.item(), dice=dice_score)

# Averages over batches (each batch counts equally)
avg_test_loss = test_loss / len(test_loader)
avg_test_dice = test_dice / len(test_loader)

print(f"\nFinal Test Set Performance:")
print(f"   - Test Loss: {avg_test_loss:.4f}")
print(f"   - Test Dice Coefficient: {avg_test_dice:.4f}")

In [None]:
# --- PLOTTING TRAINING HISTORY ---

print("\n--- Plotting Training and Validation History ---")

# The 'history' dictionary was populated during the training loop
train_loss = history['train_loss']
val_loss = history['val_loss']
train_dice = history['train_dice']
val_dice = history['val_dice']

# Number of epochs that actually ran (fewer than EPOCHS if early stopping fired)
epochs_ran = range(1, len(train_loss) + 1)

plt.style.use('seaborn-v0_8-darkgrid') # Using a nice style for the plots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))

# Left panel: loss, right panel: Dice coefficient (training vs. validation)
for ax, train_series, val_series, metric in (
    (ax1, train_loss, val_loss, 'Loss'),
    (ax2, train_dice, val_dice, 'Dice Coefficient'),
):
    ax.plot(epochs_ran, train_series, 'b-o', label=f'Training {metric}')
    ax.plot(epochs_ran, val_series, 'r-o', label=f'Validation {metric}')
    ax.set_title(f'Training & Validation {metric}', fontsize=16)
    ax.set_xlabel('Epochs', fontsize=12)
    ax.set_ylabel(metric, fontsize=12)
    ax.legend(fontsize=12)
    ax.grid(True)

plt.tight_layout()
plt.show()

# Restore the best checkpoint saved during training so the following cells use
# those weights. map_location makes a checkpoint saved on a GPU loadable on a
# CPU-only runtime as well.
model.load_state_dict(torch.load('best_segmentation_model.pth', map_location=DEVICE))
# Report success only once the weights have actually been loaded
print("\n--- Best Model Weights Loaded for Final Evaluation ---")

In [None]:
# --- 10. Visualization on Test Samples (with Post-Processing) ---

print("\n--- Visualizing Sample Predictions from the Test Set ---")
# Cap the request at the size of the test set so sampling without replacement works
num_samples_to_show = min(10, len(test_dataset))
indices = np.random.choice(range(len(test_dataset)), num_samples_to_show, replace=False)

# 5x5 structuring element shared by both morphological passes
kernel = np.ones((5, 5), np.uint8)

model.eval()
with torch.no_grad():
    for i in indices:
        # One image/mask pair straight from the test dataset
        test_img_tensor, ground_truth_tensor = test_dataset[i]

        # Add the batch axis the model expects, predict, then drop the axis again
        test_img_input = test_img_tensor.unsqueeze(0).to(DEVICE)
        prediction_prob = torch.sigmoid(model(test_img_input)).squeeze(0)

        # Binarise the probabilities at 0.5 to get the raw predicted mask
        predicted_mask = (prediction_prob > 0.5).cpu().numpy().squeeze()

        # Post-processing on a 0/255 uint8 copy:
        # opening (erosion then dilation) removes small speckles,
        # closing (dilation then erosion) then fills small holes.
        cleaned_mask_np = cv2.morphologyEx(
            (predicted_mask * 255).astype(np.uint8), cv2.MORPH_OPEN, kernel, iterations=1
        )
        cleaned_mask_np = cv2.morphologyEx(cleaned_mask_np, cv2.MORPH_CLOSE, kernel, iterations=1)

        # NumPy views of the input image and ground truth for plotting
        test_img_np = test_img_tensor.numpy().squeeze()
        ground_truth_np = ground_truth_tensor.numpy().squeeze()

        # Four panels side by side: input, ground truth, raw and cleaned prediction
        panels = (
            ('Testing Image', test_img_np),
            ('Ground Truth Mask', ground_truth_np),
            ("Model's Raw Prediction", predicted_mask),
            ("Cleaned Predicted Mask", cleaned_mask_np),
        )
        plt.figure(figsize=(18, 6))
        for position, (panel_title, panel_data) in enumerate(panels, start=1):
            plt.subplot(1, 4, position)
            plt.title(panel_title)
            plt.imshow(panel_data, cmap='gray')
            plt.axis('off')

        plt.tight_layout()
        plt.show()

In [None]:
# --- 10. Visualization on Test Samples ---
print("\n--- Visualizing Sample Predictions from the Test Set ---")
# Show at most five randomly chosen test samples, never more than exist
num_samples_to_show = min(5, len(test_dataset))
indices = np.random.choice(range(len(test_dataset)), num_samples_to_show, replace=False)

model.eval()
with torch.no_grad():
    for i in indices:
        img_tensor, gt_tensor = test_dataset[i]

        # Predict on a batch of one and threshold the probabilities at 0.5
        img_input = img_tensor.unsqueeze(0).to(DEVICE)
        pred_prob = torch.sigmoid(model(img_input)).squeeze(0)
        pred_mask = (pred_prob > 0.5).cpu().numpy().squeeze()

        img_np = img_tensor.numpy().squeeze()
        gt_np = gt_tensor.numpy().squeeze()

        # Three panels side by side: input, ground truth, prediction
        panels = (
            ('Testing Image', img_np),
            ('Ground Truth Mask', gt_np),
            ("Model's Predicted Mask", pred_mask),
        )
        plt.figure(figsize=(15, 5))
        for position, (panel_title, panel_data) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.title(panel_title)
            plt.imshow(panel_data, cmap='gray')
            plt.axis('off')
        plt.tight_layout()
        plt.show()