In [None]:
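# Uncomment to install the dependencies into this kernel's environment.
# (%pip targets the running kernel's interpreter; !pip may hit a different one.)
# %pip install numpy
# %pip install matplotlib
# %pip install monai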


In [None]:
# Import required libraries
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import torch
from monai.config import print_config
from monai.data import CacheDataset, DataLoader, decollate_batch
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    ScaleIntensityd,
    RandRotate90d,
    RandFlipd,
    RandGaussianNoised,  # Fix typo here
    LambdaD,  
    RandAdjustContrastd,
    RandZoomd,
    ToTensord,
)
from monai.networks.nets import UNet
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.utils import set_determinism

In [None]:
# Pick the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if device.type != 'cuda':
    print("GPU is not available, using CPU")
else:
    print("GPU is available")
    # Environment report for the first visible GPU (index 0).
    print('__CUDNN VERSION:', torch.backends.cudnn.version())
    print('__Number CUDA Devices:', torch.cuda.device_count())
    print('__CUDA Device Name:',torch.cuda.get_device_name(0))
    print('__CUDA Device Total Memory [GB]:',torch.cuda.get_device_properties(0).total_memory/1e9)

In [None]:
# Emit MONAI's configuration report so the library versions used for this run
# are recorded alongside the notebook outputs.
print_config()
# Seed MONAI's determinism helper with a fixed value (42) so that the random
# augmentations and weight initialisation below are repeatable between runs.
set_determinism(seed=42)

In [None]:

# Locate the image and mask folders for the hard-exudate dataset
# /home/hamada/retina_segmentation/microaneurysm/images_microaneurysm
data_dir = "/home/hamada/retina_segmentation/hard_exudate"
train_image_dir = os.path.join(data_dir, "images_hard_exudate")
train_mask_dir = os.path.join(data_dir, "masks_hard_exudate")

# Sorted listings: the i-th image is paired with the i-th mask
all_images = sorted(glob.glob(os.path.join(train_image_dir, "*.jpg")))
all_masks = sorted(glob.glob(os.path.join(train_mask_dir, "*.png")))

# Sanity checks: same number of files, and each pair shares the part of the
# file name that precedes the first "."
assert len(all_images) == len(all_masks), "Image-mask count mismatch!"
for img, msk in zip(all_images, all_masks):
    assert (
        os.path.basename(img).split(".")[0] == os.path.basename(msk).split(".")[0]
    ), "Mismatched filenames!"

# 80 % train / 10 % validation / remainder test, taken in sorted order
total = len(all_images)
train_count = int(0.8 * total)
val_count = int(0.1 * total)
test_count = total - train_count - val_count

# Build every {"image", "mask"} record once, then slice out contiguous splits
paired_files = [
    {"image": image_path, "mask": mask_path}
    for image_path, mask_path in zip(all_images, all_masks)
]
train_files = paired_files[:train_count]
val_files = paired_files[train_count:train_count + val_count]
test_files = paired_files[train_count + val_count:]

print(f"Total samples: {total}")
print(f"Training: {len(train_files)}, Validation: {len(val_files)}, Testing: {len(test_files)}")

In [None]:
def show_rgb_overlay(sample):
    """Display a sample as three side-by-side panels: image, mask, overlay.

    Args:
        sample: mapping with "image" and "mask" entries, each exposing
            ``.numpy()``. The image's first axis is moved to the end before
            plotting (channel-first to channel-last); the mask has its
            size-1 axes squeezed away.
    """
    rgb = np.moveaxis(sample["image"].numpy(), 0, -1)  # CHW → HWC
    seg = sample["mask"].numpy().squeeze()

    _, (ax_image, ax_mask, ax_overlay) = plt.subplots(1, 3, figsize=(12, 4))

    ax_image.imshow(rgb)
    ax_image.set_title("RGB Image")

    ax_mask.imshow(seg, cmap="gray")
    ax_mask.set_title("Mask")

    # Semi-transparent mask drawn on top of the image
    ax_overlay.imshow(rgb)
    ax_overlay.imshow(seg, alpha=0.5, cmap="jet")
    ax_overlay.set_title("Overlay")

    plt.show()

In [None]:

# Define transforms

def _binarize_mask(x):
    """Turn a loaded mask into a {0, 1} float32 map (pixels above 200 are foreground)."""
    return (x > 200).astype(np.float32)


# Training pipeline: load -> channel-first -> scale image to [0, 1] -> binarise
# mask -> random augmentations -> tensors.
train_transforms = Compose([
    LoadImaged(keys=["image", "mask"]),
    EnsureChannelFirstd(keys=["image", "mask"]),
    ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0),
    LambdaD(keys=["mask"], func=_binarize_mask),

    # Enhanced augmentations
    # Spatial transforms are applied to image and mask together so they stay aligned.
    RandRotate90d(keys=["image", "mask"], prob=0.75),  # Higher probability
    RandFlipd(keys=["image", "mask"], prob=0.5),
    # Intensity transforms touch the image only.
    RandGaussianNoised(keys=["image"], prob=0.5, std=0.05),
    RandAdjustContrastd(keys=["image"], prob=0.5, gamma=(0.8, 1.2)),
    # Per-key interpolation: "area" for the image, "nearest" for the mask.
    # Interpolating the mask with the default "area" mode yields fractional
    # values at lesion borders, which are no longer valid class labels.
    RandZoomd(
        keys=["image", "mask"],
        prob=0.5,
        min_zoom=0.8,
        max_zoom=1.2,
        mode=("area", "nearest"),
    ),

    ToTensord(keys=["image", "mask"]),
])

# Validation transforms (no random augmentations)
val_transforms = Compose([
    LoadImaged(keys=["image", "mask"]),
    EnsureChannelFirstd(keys=["image", "mask"]),
    ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0),
    LambdaD(keys=["mask"], func=_binarize_mask),
    ToTensord(keys=["image", "mask"]),
])

# Create datasets and dataloaders
# Training set: cache_rate=0.5 asks CacheDataset to pre-cache half of the items;
# batches of 4 are drawn in shuffled order.
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)

# Validation set: same caching and batch size, fixed order, no random transforms.
val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=0.5)
val_loader = DataLoader(val_ds, batch_size=4, shuffle=False)

# Check mask validity
# Pull the first training item through the full pipeline and list the distinct
# values present in its mask.
sample = train_ds[0]
print("Unique mask values:", np.unique(sample["mask"].numpy()))
# Should output [0., 1.] - if not, mask processing is broken


# Visualize sample training data
# NOTE(review): train_ds[0] is indexed a second time here rather than reusing
# `sample`; since train_transforms contains random augmentations, the item shown
# is presumably a different augmentation of the one checked above - confirm
# whether that is intended.
show_rgb_overlay(train_ds[0])

# Build the network, loss, optimizer, metric and learning-rate schedule
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 2-D UNet: 3-channel (RGB) input, 2 output classes, five resolution levels,
# four residual units per level, batch normalisation
model = UNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=2,
    strides=(2, 2, 2, 2),
    channels=(32, 64, 128, 256, 512),
    norm="BATCH",
    num_res_units=4,
)
model = model.to(device)

# Dice + cross-entropy on softmaxed logits; integer labels are one-hot encoded
# inside the loss. The class weight vector gives class 1 fifteen times the
# weight of class 0.
loss_func = DiceCELoss(
    to_onehot_y=True,
    softmax=True,
    weight=torch.tensor([1.0, 15.0]).to(device),
)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)
# Mean Dice over the foreground class only (background channel excluded)
dice_metric = DiceMetric(reduction="mean", include_background=False)


# Training-loop bookkeeping
max_epochs = 50
best_metric = best_metric_epoch = -1

# One-cycle schedule sized in optimizer steps (batches per epoch x epochs),
# with the first 30 % of steps used to ramp up towards max_lr.
# NOTE(review): OneCycleLR manages the optimizer's learning rate itself, so the
# lr passed to Adam above appears to act only as a placeholder - confirm.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=3e-4,
    pct_start=0.3,
    total_steps=max_epochs * len(train_loader),
)

# Train for `max_epochs`, validating after every epoch, checkpointing the model
# with the best validation Dice, and plotting a few predictions every other epoch.
for epoch in range(max_epochs):
    print(f"Epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0

    # Training
    for batch_data in train_loader:
        inputs = batch_data["image"].to(device)
        masks = batch_data["mask"].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_func(outputs, masks)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
        optimizer.step()
        # The one-cycle schedule is sized in batches (len(train_loader) * max_epochs),
        # so it has to advance once per optimizer step; without this call the
        # learning rate never leaves its initial value.
        scheduler.step()

        epoch_loss += loss.item()

    epoch_loss /= len(train_loader)
    print(f"Train loss: {epoch_loss:.4f}")

    # Validation
    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images = val_data["image"].to(device)
            val_masks = val_data["mask"].to(device)

            val_outputs = model(val_images)

            # DiceMetric expects binarised, one-hot tensors of identical shape.
            # Take the arg-max class of the logits (same as arg-max of the
            # softmax) and one-hot encode both prediction and label.
            # Assumes val_masks is (B, 1, H, W) with values in {0, 1}.
            num_classes = val_outputs.shape[1]
            val_pred_onehot = torch.nn.functional.one_hot(
                torch.argmax(val_outputs, dim=1), num_classes=num_classes
            ).movedim(-1, 1).float()
            val_true_onehot = torch.nn.functional.one_hot(
                val_masks.squeeze(1).long(), num_classes=num_classes
            ).movedim(-1, 1).float()

            dice_metric(y_pred=val_pred_onehot, y=val_true_onehot)

        metric = dice_metric.aggregate().item()
        dice_metric.reset()

        # Keep the weights of the best-scoring epoch on disk
        if metric > best_metric:
            best_metric = metric
            best_metric_epoch = epoch + 1
            torch.save(model.state_dict(), "best_model.pth")

        print(f"Validation Dice: {metric:.4f}")

        # Qualitative check every 2 epochs: up to 3 random validation samples,
        # one row each (image | ground truth overlay | prediction overlay).
        num_vis = min(3, len(val_ds))  # never request more samples than exist
        if epoch % 2 == 0 and num_vis > 0:
            vis_indices = np.random.choice(len(val_ds), num_vis, replace=False)

            plt.figure(figsize=(15, 5))
            for i, idx in enumerate(vis_indices):
                sample = val_ds[idx]
                image = sample["image"].unsqueeze(0).to(device)
                mask = sample["mask"].squeeze().cpu().numpy()

                # Get prediction
                pred = model(image)
                pred_mask = torch.argmax(pred, dim=1).squeeze().cpu().numpy()

                # Convert image for display
                img_display = image[0].cpu().numpy().transpose(1, 2, 0)  # (H, W, C)

                # Plot
                plt.subplot(num_vis, 3, i*3+1)
                plt.imshow(img_display)
                plt.title(f"Image {idx}")
                plt.axis('off')

                plt.subplot(num_vis, 3, i*3+2)
                plt.imshow(img_display)
                plt.imshow(mask, alpha=0.4, cmap="jet")
                plt.title("Ground Truth")
                plt.axis('off')

                plt.subplot(num_vis, 3, i*3+3)
                plt.imshow(img_display)
                plt.imshow(pred_mask, alpha=0.4, cmap="jet")
                plt.title("Prediction")
                plt.axis('off')

            plt.tight_layout()
            plt.show()

print(f"Best validation Dice: {best_metric:.4f} at epoch {best_metric_epoch}")

# Testing visualization using the test split
# Reload the best checkpoint; map_location places the weights on the current
# device, so a checkpoint written on a GPU still loads on a CPU-only kernel.
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model.eval()

# Same deterministic preprocessing as validation, one image per batch
test_ds = CacheDataset(data=test_files, transform=val_transforms)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False)

with torch.no_grad():
    for test_data in test_loader:
        test_image = test_data["image"].to(device)
        test_mask = test_data["mask"].cpu().numpy().squeeze()

        # Predicted class per pixel = arg-max over the channel axis
        test_output = model(test_image)
        test_pred = torch.argmax(test_output, dim=1).squeeze().cpu().numpy()

        # Original Image (convert from tensor to HWC for RGB)
        img_np = test_image[0].cpu().numpy().transpose(1, 2, 0)

        # Plot results: every panel shows the image; the last two add an overlay
        panels = (
            ("Input Image", None),
            ("Ground Truth", test_mask),
            ("Prediction", test_pred),
        )
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, (title, overlay) in zip(axes, panels):
            ax.imshow(img_np)
            if overlay is not None:
                ax.imshow(overlay, alpha=0.5, cmap="jet")
            ax.set_title(title)
            ax.axis("off")

        plt.show()
        plt.close(fig)  # one figure per test image: release it once displayed