In [None]:
# Colab-only setup: mount Google Drive so the project folder under
# /content/drive is reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

# pycocotools provides the COCO annotation reader imported later in this notebook.
# NOTE(review): the install is unpinned; consider `%pip install pycocotools==<version>`
# so a fresh runtime reproduces the same environment.
!pip install pycocotools

In [None]:
# albumentations supplies the augmentation pipelines imported later in this notebook.
# NOTE(review): `-U` installs whatever release is newest at run time; pin a
# known-good version once identified so transform behaviour cannot drift.
!pip install -U albumentations

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os
import cv2
import numpy as np
from tqdm import tqdm
from pycocotools.coco import COCO
import json
from google.colab import drive

# Mount Google Drive
# NOTE(review): Drive is also mounted in the first cell of the notebook; this
# second call looks redundant - confirm before removing.
drive.mount('/content/drive')

# Project paths
# Root folder on Drive; every dataset folder below is joined onto this path.
PROJECT_FOLDER_IN_DRIVE = '/content/drive/MyDrive/oct_major_project/'

# Define all 4 dataset folders
# Keys are the short dataset names used to index CATEGORY_MAPPINGS; values are
# folder names relative to PROJECT_FOLDER_IN_DRIVE.
DATASETS = {
    'NORMAL': 'NORMAL 2.v1i.coco-segmentation',
    'DME': 'DME 2.v1i.coco-segmentation',
    'CNV': 'CNV 2.v1i.coco-segmentation',
    'DRUSEN': 'drusen 3.v1i.coco-segmentation'  # Using version 3
}

# Pick the compute device: GPU when CUDA is visible, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Using device: {}".format(DEVICE))

# Optimisation hyperparameters read by the training cells.
LEARNING_RATE = 1e-4  # initial Adam step size
BATCH_SIZE = 8        # images per mini-batch
NUM_EPOCHS = 75       # full passes over the training split

# ========================================
# UNIFIED CATEGORY MAPPING
# ========================================

# Unified label space shared by all datasets. The position of a name in this
# tuple is its integer class id, so the order must not change.
_UNIFIED_CLASS_NAMES = (
    'background',
    'GCL',                   # Ganglion Cell Layer
    'INL',                   # Inner Nuclear Layer
    'IPL',                   # Inner Plexiform Layer
    'ONL',                   # Outer Nuclear Layer
    'OPL',                   # Outer Plexiform Layer
    'RNFL',                  # Retinal Nerve Fiber Layer
    'RPE',                   # Retinal Pigment Epithelium
    'CHOROID',               # Choroid (only in DME)
    'INTRA-RETINAL-FLUID',   # Fluid within retinal layers
    'SUB-RETINAL-FLUID',     # Fluid below retina (CNV)
    'PED',                   # Pigment Epithelial Detachment
    'DRUSENOID-PED',         # Drusen-specific PED
)
UNIFIED_CATEGORIES = dict(enumerate(_UNIFIED_CLASS_NAMES))

NUM_CLASSES = len(_UNIFIED_CLASS_NAMES)
print("\nTotal unified classes: {}".format(NUM_CLASSES))
print("Categories:", [name for name in UNIFIED_CATEGORIES.values()])

# ========================================
# DATASET-SPECIFIC MAPPINGS
# ========================================

# Translation tables: original COCO category id -> unified id, one per dataset.
# The original ids in every table run consecutively from 1, so each table is
# written as the ordered sequence of unified ids and expanded by the helper.
def _number_from_one(unified_ids):
    """Return ``{1: unified_ids[0], 2: unified_ids[1], ...}``."""
    return {original_id: unified_id
            for original_id, unified_id in enumerate(unified_ids, start=1)}


CATEGORY_MAPPINGS = {
    'NORMAL': _number_from_one([
        1,   # GCL
        2,   # INL
        3,   # IPL
        4,   # ONL
        5,   # OPL
        6,   # RNFL
        7,   # RPE
    ]),
    'DME': _number_from_one([
        0,   # category named "0" -> background (artifact, ignore it)
        8,   # CHOROID
        1,   # GCL
        2,   # INL
        9,   # INTRA-RETINAL-FLUID
        3,   # IPL
        4,   # ONL
        5,   # OPL
        6,   # RNFL
        7,   # RPE
    ]),
    'CNV': _number_from_one([
        1,   # GCL
        2,   # INL
        9,   # INTRA-RETINAL-FLUID
        3,   # IPL
        4,   # ONL
        5,   # OPL
        11,  # PED
        6,   # RNFL
        7,   # RPE
        10,  # SUB-RETINAL-FLUID
    ]),
    'DRUSEN': _number_from_one([
        12,  # DRUSENOID-PED
        1,   # GCL
        2,   # INL
        9,   # INTRA-RETINAL-FLUID
        3,   # IPL
        4,   # ONL
        5,   # OPL
        11,  # PED
        6,   # RNFL
        7,   # RPE
    ]),
}

print("\n‚úì Configuration complete!")
print(f"Datasets to process: {list(DATASETS.keys())}")

In [None]:
# ========================================
# CATEGORY CHECK FOR EVERY DATASET
# ========================================
# Read-only report: for each dataset, list the categories found in its COCO
# file next to the unified id they translate to, plus the image count.

banner = "=" * 70
print(banner)
print("VERIFYING CATEGORIES ACROSS ALL DATASETS")
print(banner)

for dataset_name, folder_name in DATASETS.items():
    print("\n" + banner)
    print("Dataset: " + dataset_name)
    print(banner)

    # The annotation file sits next to the training images.
    data_dir = os.path.join(PROJECT_FOLDER_IN_DRIVE, folder_name, 'train/')
    json_path = os.path.join(data_dir, '_annotations.coco.json')

    # Nothing to report when the export is missing; say so and move on.
    if not os.path.exists(json_path):
        print(f"‚ùå ERROR: JSON file not found at {json_path}")
        continue

    coco = COCO(json_path)
    categories = coco.loadCats(coco.getCatIds())

    print(f"\nFound {len(categories)} categories:")
    print(f"{'Original ID':<15} {'Category Name':<30} {'‚Üí Unified ID':<15}")
    print("-" * 60)

    mapping = CATEGORY_MAPPINGS[dataset_name]
    for cat in categories:
        orig_id, cat_name = cat['id'], cat['name']
        # Ids missing from the table fall back to background (0).
        unified_id = mapping.get(orig_id, 0)
        unified_name = UNIFIED_CATEGORIES.get(unified_id, 'UNMAPPED')
        print(f"{orig_id:<15} {cat_name:<30} ‚Üí {unified_id:<3} ({unified_name})")

    img_ids = coco.getImgIds()
    print(f"\n‚úì Total images: {len(img_ids)}")

print("\n" + banner)
print("VERIFICATION COMPLETE")
print(banner)

In [None]:
# ========================================
# GENERATE UNIFIED MASKS FOR ALL DATASETS
# ========================================
# Rasterises the COCO annotations of every dataset into one single-channel PNG
# per image. Pixel values are unified class ids (0 = background) and the files
# are written to <dataset>/train/masks/.

print("=" * 70)
print("GENERATING UNIFIED MASKS FOR ALL DATASETS")
print("=" * 70)

total_images_processed = 0

for dataset_name, folder_name in DATASETS.items():
    print(f"\n{'='*70}")
    print(f"Processing: {dataset_name}")
    print(f"{'='*70}")

    # Build paths
    data_dir = os.path.join(PROJECT_FOLDER_IN_DRIVE, folder_name, 'train/')
    json_path = os.path.join(data_dir, '_annotations.coco.json')
    image_dir = data_dir
    mask_save_dir = os.path.join(data_dir, 'masks/')

    # A missing export should not abort the remaining datasets.
    if not os.path.exists(json_path):
        print(f"ERROR: JSON file not found at {json_path} - skipping {dataset_name}")
        continue

    # Create masks directory
    os.makedirs(mask_save_dir, exist_ok=True)

    # Load COCO annotations
    print(f"Loading COCO annotations from: {json_path}")
    coco = COCO(json_path)

    # Get all images
    img_ids = coco.getImgIds()
    images = coco.loadImgs(img_ids)

    print(f"Found {len(images)} images. Starting mask generation...")

    # Get mapping for this dataset
    category_mapping = CATEGORY_MAPPINGS[dataset_name]

    # Generate masks
    for img_info in tqdm(images, desc=f"Generating masks for {dataset_name}"):
        img_id = img_info['id']
        img_file_name = img_info['file_name']
        img_height = img_info['height']
        img_width = img_info['width']

        # Create blank mask (all background = 0)
        mask = np.zeros((img_height, img_width), dtype=np.uint8)

        # Get annotations for this image
        ann_ids = coco.getAnnIds(imgIds=img_id)
        anns = coco.loadAnns(ann_ids)

        # Process each annotation
        for ann in anns:
            category_id = ann['category_id']

            # Map to unified category ID (unknown ids count as background)
            unified_label = category_mapping.get(category_id, 0)

            # Background-mapped annotations are ignored. Painting them would
            # write zeros over labels drawn by earlier annotations.
            if unified_label == 0:
                continue

            segmentation = ann['segmentation']

            # RLE-encoded segmentations are dicts, not polygon lists; let
            # pycocotools decode them into a binary mask.
            if isinstance(segmentation, dict):
                mask[coco.annToMask(ann) > 0] = unified_label
                continue

            # Draw polygons
            for seg in segmentation:
                # A polygon needs at least three (x, y) pairs.
                if len(seg) < 6 or len(seg) % 2 != 0:
                    continue
                # Round (not truncate) fractional coordinates to pixel indices.
                poly = (
                    np.rint(np.asarray(seg, dtype=np.float64))
                    .astype(np.int32)
                    .reshape((-1, 1, 2))
                )
                cv2.fillPoly(mask, [poly], color=int(unified_label))

        # Save mask. Only the final extension is stripped, so file names that
        # contain several dots (e.g. "scan.rf.abc.jpg" and "scan.rf.def.jpg")
        # keep distinct mask names and do not overwrite each other.
        # CombinedOCTDataset looks for this name first.
        base_name = os.path.splitext(img_file_name)[0]
        mask_file_name = f"{base_name}.png"
        save_path = os.path.join(mask_save_dir, mask_file_name)
        if not cv2.imwrite(save_path, mask):
            raise IOError(f"Could not write mask to {save_path}")

    print(f"‚úì Completed {dataset_name}: {len(images)} masks saved to {mask_save_dir}")
    total_images_processed += len(images)

print("\n" + "=" * 70)
print(f"MASK GENERATION COMPLETE!")
print(f"Total images processed: {total_images_processed}")
print("=" * 70)

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

class RetinaDataset(Dataset):
    """Image/mask pairs from a single folder of ``.jpg`` scans.

    Args:
        image_dir: Folder containing the ``.jpg`` images.
        mask_dir: Folder containing one ``.png`` label mask per image.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            and returning a mapping with ``'image'`` and ``'mask'`` entries.

    ``__getitem__`` returns ``(image, mask)``. With a transform these are the
    transformed image and the mask cast to ``long``; without one they are the
    preprocessed RGB array and the raw mask array.

    Raises:
        ValueError: from ``__getitem__`` when the image or mask cannot be read.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.images = sorted([f for f in os.listdir(image_dir) if f.endswith('.jpg')])
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def _resolve_mask_path(self, img_file):
        """Return the mask path for ``img_file``.

        Tries the name with only the final extension replaced first
        ("a.rf.1.jpg" -> "a.rf.1.png"), then the legacy name cut at the first
        dot ("a.rf.1.jpg" -> "a.png"). When neither file exists the legacy
        path is returned so the caller reports it as unreadable.
        """
        full_stem = os.path.splitext(img_file)[0]
        short_stem = img_file.split('.')[0]
        legacy_path = os.path.join(self.mask_dir, short_stem + '.png')
        if full_stem != short_stem:
            preferred_path = os.path.join(self.mask_dir, full_stem + '.png')
            if os.path.exists(preferred_path):
                return preferred_path
        return legacy_path

    def __getitem__(self, index):
        img_file = self.images[index]
        img_path = os.path.join(self.image_dir, img_file)
        mask_path = self._resolve_mask_path(img_file)

        # Read image and mask; cv2.imread signals failure by returning None.
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise ValueError(f"Failed to load image: {img_path}")
        if mask is None:
            raise ValueError(f"Failed to load mask: {mask_path}")

        # Enhanced preprocessing pipeline
        # 1. Denoise while preserving edges
        image = cv2.fastNlMeansDenoising(image, None, h=10, searchWindowSize=21)

        # 2. Enhance edges using unsharp masking: 2*image - blurred
        gaussian_3 = cv2.GaussianBlur(image, (0, 0), 2.0)
        unsharp_image = cv2.addWeighted(image, 2.0, gaussian_3, -1.0, 0)

        # 3. Enhance local contrast
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        image = clahe.apply(unsharp_image)

        # 4. Expand to three identical channels. The transform pipelines in
        #    this notebook normalise with 3-channel mean/std.
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask'].long()

        return image, mask


# Settings shared by the training and validation pipelines.
_TARGET_SIDE = 512
_NORM_MEAN = [0.5, 0.5, 0.5]
_NORM_STD = [0.5, 0.5, 0.5]

# Training: resize, then random flip / photometric jitter / noise / sharpening,
# then normalisation and conversion to a tensor.
_train_steps = [
    A.Resize(height=_TARGET_SIDE, width=_TARGET_SIDE, interpolation=cv2.INTER_LANCZOS4),
    A.HorizontalFlip(p=0.5),
    # At most one contrast-style change, applied to 30% of samples.
    A.OneOf(
        [
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1.0),
            A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=1.0),
        ],
        p=0.3,
    ),
    # At most one noise model, applied to 20% of samples.
    A.OneOf(
        [
            A.GaussNoise(p=1.0),
            A.ISONoise(p=1.0),
        ],
        p=0.2,
    ),
    A.Sharpen(p=0.3),
    A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ToTensorV2(),
]
train_transform = A.Compose(_train_steps)

# Validation: deterministic resize, contrast enhancement and sharpening.
# NOTE(review): CLAHE and Sharpen run on every validation image (p=1.0) but
# only on a fraction of training images - confirm this mismatch is intended.
_val_steps = [
    A.Resize(height=_TARGET_SIDE, width=_TARGET_SIDE, interpolation=cv2.INTER_LANCZOS4),
    A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=1.0),
    A.Sharpen(p=1.0),
    A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ToTensorV2(),
]
val_transform = A.Compose(_val_steps)

print("Dataset class and augmentation transforms have been updated.")

In [None]:
# ========================================
# U-NET MODEL ARCHITECTURE
# ========================================

class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> BatchNorm -> ReLU stages.

    The convolutions use stride 1 and padding 1, so height and width are
    unchanged; only the channel count goes from ``in_channels`` to
    ``out_channels``. The layers live in ``self.conv`` (an ``nn.Sequential``).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in -> out channels, second stage out -> out.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1,
                          padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.conv(x)
        return out


class UNet(nn.Module):
    """U-Net encoder/decoder built from ``DoubleConv`` blocks.

    Args:
        in_channels: Channels of the input image.
        out_channels: Number of output channels (one logit map per class).
        features: Channel width of each encoder level, shallowest first. The
            bottleneck uses twice the last width. Any non-empty sequence is
            accepted; the default is an immutable tuple.

    ``forward`` returns logits with ``out_channels`` channels and the same
    height and width as the input.

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(self, in_channels=1, out_channels=13, features=(64, 128, 256, 512)):
        super(UNet, self).__init__()
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per level.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder: ups holds (upsample, DoubleConv) pairs, deepest level first.
        # The DoubleConv input is feature*2 because the skip tensor is
        # concatenated onto the upsampled tensor.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature*2, feature))

        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1]*2)

        # Final 1x1 convolution to per-class logits
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        # Encoder: remember each level's output before pooling.
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        # Bottleneck
        x = self.bottleneck(x)

        # Deepest skip first, matching the order of self.ups.
        skip_connections = skip_connections[::-1]

        # Decoder
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            # Pooling floors odd sizes, so the upsampled map can be one pixel
            # short; only height/width matter for the concatenation.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = torch.nn.functional.interpolate(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x)


print("‚úì U-Net model architecture defined!")
print(f"Model will output {NUM_CLASSES} classes")

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss over all classes and pixels at once.

    ``weight`` and ``size_average`` are accepted by the constructor but not
    used.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - dice`` for raw logits and same-shaped targets.

        Args:
            inputs: Logits; softmax is taken over dimension 1.
            targets: Tensor with the same number of elements as ``inputs``.
            smooth: Constant added to numerator and denominator.
        """
        # Class probabilities, flattened together with the targets.
        probabilities = torch.softmax(inputs, dim=1).reshape(-1)
        flat_targets = targets.reshape(-1)

        overlap = (probabilities * flat_targets).sum()
        numerator = 2. * overlap + smooth
        denominator = probabilities.sum() + flat_targets.sum() + smooth

        return 1 - numerator / denominator

class EdgeAwareLoss(nn.Module):
    """Cross-entropy + Dice + Sobel edge loss for segmentation logits.

    ``total = alpha * CE + (1 - alpha) * Dice + beta * Edge``

    Args:
        alpha: Weight of the cross-entropy term; Dice gets ``1 - alpha``.
        beta: Weight of the edge term.
        weight: Optional per-class weight tensor passed to ``CrossEntropyLoss``.

    ``forward(inputs, targets)`` expects logits ``(N, C, H, W)`` and integer
    labels ``(N, H, W)`` and returns a scalar tensor.
    """

    def __init__(self, alpha=0.5, beta=0.25, weight=None):
        super(EdgeAwareLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.cross_entropy = nn.CrossEntropyLoss(weight=weight)
        self.dice = DiceLoss()
        # 3x3 Sobel kernels shaped (out_channels=1, in_channels=1, 3, 3).
        self.sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
        self.sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)

    def _edge_magnitude(self, plane):
        """Sobel gradient magnitude of a ``(N, 1, H, W)`` tensor."""
        kernel_x = self.sobel_x.to(device=plane.device, dtype=plane.dtype)
        kernel_y = self.sobel_y.to(device=plane.device, dtype=plane.dtype)
        grad_x = torch.conv2d(plane, kernel_x, padding=1)
        grad_y = torch.conv2d(plane, kernel_y, padding=1)
        # The epsilon keeps the square-root gradient finite where the map is flat.
        return torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-8)

    def edge_loss(self, pred, target):
        """Mean absolute difference between predicted and true foreground edges.

        Both maps are foreground-vs-background: the prediction is the summed
        probability of classes 1..C-1, the target is the binary mask
        ``target > 0``, so the two edge maps share the same [0, 1] scale.
        """
        pred_soft = torch.softmax(pred, dim=1)
        pred_foreground = pred_soft[:, 1:].sum(1, keepdim=True)

        # conv2d needs an explicit channel axis: (N, H, W) -> (N, 1, H, W).
        target_foreground = target > 0
        if target_foreground.dim() == 3:
            target_foreground = target_foreground.unsqueeze(1)
        target_foreground = target_foreground.to(pred_foreground.dtype)

        pred_edges = self._edge_magnitude(pred_foreground)
        target_edges = self._edge_magnitude(target_foreground)

        return torch.mean(torch.abs(pred_edges - target_edges))

    def forward(self, inputs, targets):
        # The class count comes from the logits, so the loss does not depend
        # on a module-level constant.
        num_classes = inputs.shape[1]
        target_one_hot = torch.nn.functional.one_hot(targets, num_classes=num_classes).permute(0, 3, 1, 2)

        # Calculate individual losses
        loss_ce = self.cross_entropy(inputs, targets)
        loss_dice = self.dice(inputs, target_one_hot)
        loss_edge = self.edge_loss(inputs, targets)

        # Combine losses with weights
        total_loss = (self.alpha * loss_ce +
                     (1 - self.alpha) * loss_dice +
                     self.beta * loss_edge)

        return total_loss

# Status message names the classes this cell actually defines.
print("EdgeAwareLoss (Cross-Entropy + Dice + Edge) and DiceLoss defined.")

In [None]:
# ========================================
# CLASS WEIGHT CALCULATION FUNCTION
# ========================================

def calculate_class_weights(loader, num_classes, device):
    """
    Calculate inverse-frequency class weights from the labels in ``loader``.

    Each class that occurs gets ``total_pixels / (num_classes * count)``,
    rescaled so the most frequent class has weight 1.0. Classes that never
    occur get the neutral weight 1.0. Labels outside ``[0, num_classes)`` are
    not counted.

    Args:
        loader: Iterable yielding ``(inputs, targets)`` batches; only the
            integer ``targets`` tensors are read.
        num_classes: Number of classes in the dataset
        device: torch device (cuda/cpu) for the returned tensor

    Returns:
        float32 tensor of shape ``(num_classes,)`` on ``device``
    """
    # int64 counters: float32 cannot represent every integer above 2**24, so
    # float accumulation silently drops pixels on large datasets.
    class_counts = torch.zeros(num_classes, dtype=torch.int64)

    print("\n" + "="*70)
    print("CALCULATING CLASS WEIGHTS")
    print("="*70)

    for _, targets in tqdm(loader, desc="Counting pixels"):
        labels = targets.detach().reshape(-1).to(device="cpu", dtype=torch.int64)
        labels = labels[(labels >= 0) & (labels < num_classes)]
        # One histogram per batch instead of one comparison per class.
        class_counts += torch.bincount(labels, minlength=num_classes)

    # Inverse frequency weighting over the classes that actually occur
    counts = class_counts.to(torch.float64)
    present = counts > 0
    class_weights = torch.ones(num_classes, dtype=torch.float64)
    if present.any():
        total_pixels = counts.sum()
        inverse_freq = total_pixels / (num_classes * counts[present])
        # Normalize so minimum weight = 1.0
        class_weights[present] = inverse_freq / inverse_freq.min()
    class_weights = class_weights.to(torch.float32)

    print(f"\nClass Distribution and Weights:")
    print(f"{'Class ID':<10} {'Name':<25} {'Pixel Count':<15} {'Weight':<10}")
    print("-" * 70)
    for c in range(num_classes):
        if class_counts[c] > 0:
            # "<15," = left-aligned, width 15, thousands separators.
            print(f"{c:<10} {UNIFIED_CATEGORIES[c]:<25} {int(class_counts[c]):<15,} {class_weights[c]:.3f}")

    print("="*70)
    return class_weights.to(device)

print("‚úì Class weight calculation function defined!")


In [None]:
# ========================================
# DIAGNOSTIC: WHAT IS ACTUALLY ON DISK
# ========================================
# Read-only report of each dataset's train/ folder and its masks/ sub-folder:
# whether they exist, how many files of each type they hold, and a few samples.

def _regular_files(directory, names, suffix):
    """Names ending in ``suffix`` that are regular files inside ``directory``."""
    return [name for name in names
            if name.endswith(suffix) and os.path.isfile(os.path.join(directory, name))]


rule = "=" * 70
print(rule)
print("DIAGNOSTIC: CHECKING FILE STRUCTURE")
print(rule)

for dataset_name, folder_name in DATASETS.items():
    print("\n" + rule)
    print("Dataset: " + dataset_name)
    print(rule)

    # masks/ lives inside the image folder.
    image_dir = os.path.join(PROJECT_FOLDER_IN_DRIVE, folder_name, 'train')
    mask_dir = os.path.join(image_dir, 'masks')

    image_dir_found = os.path.exists(image_dir)
    print("Image directory: " + image_dir)
    print("Exists: " + str(image_dir_found))

    if image_dir_found:
        all_files = os.listdir(image_dir)
        print("Total items in directory: " + str(len(all_files)))

        # Count per extension (case-sensitive match, regular files only).
        jpg_files = _regular_files(image_dir, all_files, '.jpg')
        png_files = _regular_files(image_dir, all_files, '.png')
        jpeg_files = _regular_files(image_dir, all_files, '.jpeg')

        for label, found in (('.jpg', jpg_files), ('.png', png_files), ('.jpeg', jpeg_files)):
            print("  - " + label + " files: " + str(len(found)))

        # Show a few file names, or what the folder holds instead.
        image_files = jpg_files + png_files + jpeg_files
        if image_files:
            print(f"  - Sample files: {image_files[:3]}")
        else:
            print(f"  - ‚ö†Ô∏è  NO IMAGE FILES FOUND!")
            print(f"  - All items in directory: {all_files[:10]}")

    mask_dir_found = os.path.exists(mask_dir)
    print("\nMask directory: " + mask_dir)
    print("Exists: " + str(mask_dir_found))

    if mask_dir_found:
        mask_files = [f for f in os.listdir(mask_dir) if f.endswith('.png')]
        print("  - Mask files: " + str(len(mask_files)))
        if mask_files:
            print(f"  - Sample masks: {mask_files[:3]}")
        else:
            print(f"  - ‚ö†Ô∏è  NO MASK FILES FOUND!")

print("\n" + rule)
print("DIAGNOSTIC COMPLETE")
print(rule)


In [None]:
# ========================================
# COMBINED DATASET CLASS (ROBUST VERSION)
# ========================================

class CombinedOCTDataset(Dataset):
    """
    Dataset that combines several OCT dataset folders into one sample list.

    For every ``name -> folder`` entry of ``datasets_config`` the images are
    taken from ``<project_folder>/<folder>/train`` (or ``train/images``) and
    paired with masks from ``<project_folder>/<folder>/train/masks``. Images
    without a matching mask are left out.

    Args:
        datasets_config: Mapping of dataset name to folder name.
        project_folder: Root folder containing the dataset folders.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            and returning a mapping with ``'image'`` and ``'mask'`` entries.

    ``__getitem__`` returns ``(image, mask)`` tensors; the mask is ``long``.

    Raises:
        ValueError: from ``__getitem__`` when an image or mask cannot be read.
    """

    # Lower-case extensions treated as images when scanning a folder.
    IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, datasets_config, project_folder, transform=None):
        self.transform = transform
        self.samples = []  # List of (image_path, mask_path, dataset_name) tuples

        for dataset_name, folder_name in datasets_config.items():
            train_dir = os.path.join(project_folder, folder_name, 'train')
            mask_dir = os.path.join(train_dir, 'masks')
            image_dir = self._find_image_dir(
                [train_dir, os.path.join(train_dir, 'images')]
            )

            if image_dir is None or not os.path.exists(mask_dir):
                print(f"‚ö†Ô∏è  Skipping {dataset_name}: Images or Masks folder missing.")
                continue

            images = self._list_images(image_dir)
            dataset_count = 0
            last_tried = None  # mask names tried for the most recent image

            for img_file in images:
                last_tried = self._mask_name_candidates(img_file)
                for mask_file in last_tried:
                    mask_path = os.path.join(mask_dir, mask_file)
                    if os.path.exists(mask_path):
                        img_path = os.path.join(image_dir, img_file)
                        self.samples.append((img_path, mask_path, dataset_name))
                        dataset_count += 1
                        break

            print(f"  - {dataset_name}: Loaded {dataset_count} pairs (Total images in folder: {len(images)})")
            # last_tried stays None when filtering left no images, so there is
            # no candidate name to report in that case.
            if dataset_count == 0 and last_tried is not None:
                mask_file_strict, mask_file_simple = last_tried
                print(f"    ‚ùå DEBUG: Looked for masks named '{mask_file_strict}' OR '{mask_file_simple}' but found neither.")

        print(f"‚úì Loaded {len(self.samples)} total image-mask pairs.")

    @classmethod
    def _find_image_dir(cls, candidates):
        """Return the first candidate directory holding an image file, else None."""
        for directory in candidates:
            if not os.path.isdir(directory):
                continue
            if any(name.lower().endswith(cls.IMAGE_EXTENSIONS)
                   for name in os.listdir(directory)):
                return directory
        return None

    @classmethod
    def _list_images(cls, image_dir):
        """Sorted image names, skipping hidden files and names containing 'mask'."""
        return sorted(
            name for name in os.listdir(image_dir)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
            and not name.startswith('.')
            and 'mask' not in name.lower()
        )

    @staticmethod
    def _mask_name_candidates(img_file):
        """Mask names to try, in order of preference.

        First the name with only the final extension replaced
        (image.01.jpg -> image.01.png), then the name cut at the first dot
        (image.01.jpg -> image.png).
        """
        strict = img_file.rsplit('.', 1)[0] + '.png'
        simple = img_file.split('.')[0] + '.png'
        return strict, simple

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        img_path, mask_path, dataset_name = self.samples[index]

        # cv2.imread signals failure by returning None.
        gray = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if gray is None:
            raise ValueError(f"Failed to load image: {img_path}")
        if mask is None:
            raise ValueError(f"Failed to load mask: {mask_path}")

        # Denoise -> unsharp mask (2*image - blurred) -> CLAHE. If OpenCV
        # rejects the image, fall back to the unprocessed grayscale scan.
        try:
            denoised = cv2.fastNlMeansDenoising(gray, None, h=10, searchWindowSize=21)
            blurred = cv2.GaussianBlur(denoised, (0, 0), 2.0)
            sharpened = cv2.addWeighted(denoised, 2.0, blurred, -1.0, 0)
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            enhanced = clahe.apply(sharpened)
        except cv2.error as e:
            print(f"Preprocessing error on {img_path}: {e}")
            enhanced = gray

        # Three identical channels, outside the try block so a failure here
        # is not retried inside the error handler.
        image = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2RGB)

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask'].long()
        else:
            # No transform: channels-first float image scaled to [0, 1].
            image = torch.from_numpy(image.transpose(2, 0, 1)).float() / 255.0
            mask = torch.from_numpy(mask).long()

        return image, mask

print("‚úì Robust CombinedOCTDataset class updated!")

In [None]:
# ========================================
# INITIALIZE MODEL, OPTIMIZER, AND DATALOADERS
# ========================================
#
# ‚ö†Ô∏è  IMPORTANT: Before running this cell, make sure you have:
#    1. Run Cell 4 to generate masks from COCO annotations
#    2. Verified that masks were created successfully
#
# If you get an error about 0 images, go back and run Cell 4 first!
# ========================================

import gc
from torch.utils.data import random_split, Subset

# Helpers for the repeated dataset / loader construction below.
def _make_combined_dataset(transform):
    """Scan the dataset folders and wrap them with ``transform``."""
    return CombinedOCTDataset(
        datasets_config=DATASETS,
        project_folder=PROJECT_FOLDER_IN_DRIVE,
        transform=transform,
    )


def _make_loader(subset, shuffle):
    """DataLoader with the notebook-wide batch size and worker settings."""
    return DataLoader(subset, batch_size=BATCH_SIZE, shuffle=shuffle,
                      num_workers=2, pin_memory=True)


# Release cached GPU memory left over from earlier runs.
print("Clearing CUDA cache...")
torch.cuda.empty_cache()
gc.collect()

if torch.cuda.is_available():
    torch.cuda.synchronize()
    print("‚úì GPU: " + torch.cuda.get_device_name(0))

# Model
print("\nInitializing U-Net model with {} output classes...".format(NUM_CLASSES))
model = UNet(in_channels=3, out_channels=NUM_CLASSES).to(DEVICE)
print("‚úì Model initialized!")

# Optimizer (the loss function is configured in a later cell)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
print("‚úì Optimizer initialized with learning rate: {}".format(LEARNING_RATE))

# Untransformed dataset: used only for its length when splitting.
print("\n" + "="*70)
print("CREATING COMBINED DATASET")
print("="*70)

full_dataset = _make_combined_dataset(None)

# An empty dataset means no image/mask pair was found; explain and stop.
if len(full_dataset) == 0:
    print("\n".join([
        "\n" + "="*70,
        "‚ö†Ô∏è  ERROR: NO IMAGES FOUND!",
        "="*70,
        "\n‚ö†Ô∏è  The dataset is empty. This means masks haven't been generated yet.",
        "\nüìù TO FIX THIS:",
        "   1. Go back and run Cell 4: 'GENERATE UNIFIED MASKS FOR ALL DATASETS'",
        "   2. Wait for mask generation to complete (this may take a few minutes)",
        "   3. Then come back and run this cell again",
        "\nüí° Cell 4 will create masks from COCO annotations and save them to:",
    ]))
    for dataset_name, folder_name in DATASETS.items():
        mask_dir = os.path.join(PROJECT_FOLDER_IN_DRIVE, folder_name, 'train', 'masks')
        print("   - " + mask_dir)
    print("="*70)
    raise ValueError("Dataset is empty. Please run Cell 4 to generate masks first!")

# 80/20 split of sample indices with a fixed seed, so the same indices can
# select from the two differently-transformed datasets below.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_indices, val_indices = random_split(
    range(len(full_dataset)),
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

print("\n‚úì Training: {} images".format(len(train_indices)))
print("‚úì Validation: {} images".format(len(val_indices)))

# Same files, different transforms.
train_dataset = _make_combined_dataset(train_transform)
val_dataset = _make_combined_dataset(val_transform)

train_subset = Subset(train_dataset, train_indices)
val_subset = Subset(val_dataset, val_indices)

train_loader = _make_loader(train_subset, shuffle=True)
val_loader = _make_loader(val_subset, shuffle=False)

print("‚úì DataLoaders created!")
print("  - Training batches: {}".format(len(train_loader)))
print("  - Validation batches: {}".format(len(val_loader)))
print("  - Batch size: {}".format(BATCH_SIZE))


In [None]:
# ========================================
# LOSS FUNCTION (CLASS-WEIGHTED) AND LR SCHEDULER
# ========================================

# Per-class weights from the label frequencies of the training loader.
class_weights = calculate_class_weights(train_loader, NUM_CLASSES, DEVICE)

# EdgeAwareLoss = alpha * CE + (1 - alpha) * Dice + beta * Edge
loss_fn = EdgeAwareLoss(alpha=0.5, beta=0.25, weight=class_weights)

print("\n‚úì Using EdgeAwareLoss with class weights")
print("  - Alpha (CE weight): 0.5")
print("  - Beta (Edge weight): 0.25")
print("  - Dice weight: 0.5 (1-alpha)")

# ReduceLROnPlateau lowers the learning rate when the monitored score
# (intended to be the validation Dice score) stops improving.
scheduler_settings = dict(
    mode='max',        # higher monitored value is better
    factor=0.5,        # multiply the LR by 0.5 on each reduction
    patience=10,       # epochs without improvement before reducing
    min_lr=1e-7,       # floor for the learning rate
    threshold=0.001,   # minimum change that counts as improvement
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_settings)

print("\n‚úì Learning rate scheduler initialized")
print("  - Initial LR: {}".format(LEARNING_RATE))
print("  - Patience: 10 epochs")
print("  - Reduction factor: 0.5x")
print("  - Minimum LR: 1e-7")


In [None]:
# ========================================
# TRAINING FUNCTION
# ========================================

def train_fn(loader, model, optimizer, loss_fn, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        loader: Sized iterable of ``(inputs, targets)`` batches.
        model: Module to train; switched to training mode.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar.
        device: Device the batches are moved to.

    Returns:
        Sum of the batch losses divided by ``len(loader)`` (also printed).
    """
    progress = tqdm(loader, desc="Training")
    running_loss = 0.0

    model.train()

    for _step, batch in enumerate(progress):
        inputs, labels = batch
        inputs = inputs.to(device=device)
        labels = labels.to(device=device)

        # Forward pass and loss
        loss = loss_fn(model(inputs), labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix(loss=batch_loss)

    avg_loss = running_loss / len(loader)
    print(f"Average training loss: {avg_loss:.4f}")
    return avg_loss


# ========================================
# EVALUATION FUNCTION
# ========================================

def check_metrics(loader, model, device, num_classes):
    """Evaluate pixel accuracy, Dice and IoU of ``model`` over ``loader``.

    Dice and IoU are computed per class from pixel counts accumulated over the
    whole loader, then averaged over the classes that occur in the predictions
    or the labels. (Summing over all one-hot channels at once makes Dice equal
    to pixel accuracy, because every pixel belongs to exactly one class.)

    Args:
        loader: Iterable of ``(inputs, integer_labels)`` batches.
        model: Module returning ``(N, C, H, W)`` logits; put in eval mode for
            the evaluation and back in training mode afterwards.
        device: Device used for evaluation.
        num_classes: Number of classes to score.

    Returns:
        ``(pixel_acc, avg_dice, avg_iou)`` as 0-dim float tensors;
        ``pixel_acc`` is a percentage.
    """
    num_correct = torch.zeros((), dtype=torch.int64, device=device)
    num_pixels = 0
    # Per-class pixel counts: correct pixels, predicted pixels, labelled pixels.
    intersection = torch.zeros(num_classes, dtype=torch.int64, device=device)
    pred_area = torch.zeros(num_classes, dtype=torch.int64, device=device)
    target_area = torch.zeros(num_classes, dtype=torch.int64, device=device)

    model.eval()

    with torch.no_grad():
        for x, y in tqdm(loader, desc="Calculating Metrics"):
            x = x.to(device)
            y = y.to(device).long()

            preds = torch.argmax(model(x), dim=1)
            hits = preds == y

            num_correct += hits.sum()
            num_pixels += torch.numel(preds)

            # A correctly classified pixel counts towards its own class only.
            intersection += torch.bincount(y[hits], minlength=num_classes)[:num_classes]
            pred_area += torch.bincount(preds.reshape(-1), minlength=num_classes)[:num_classes]
            target_area += torch.bincount(y.reshape(-1), minlength=num_classes)[:num_classes]

    # max(..., 1) keeps an empty loader from dividing by zero.
    pixel_acc = (num_correct.float() / max(num_pixels, 1)) * 100

    overlap = intersection.double()
    area_sum = (pred_area + target_area).double()
    # Classes absent from both predictions and labels have no defined score.
    present = area_sum > 0
    if bool(present.any()):
        avg_dice = (2.0 * overlap[present] / area_sum[present]).mean().float()
        avg_iou = (overlap[present] / (area_sum[present] - overlap[present])).mean().float()
    else:
        avg_dice = torch.zeros((), dtype=torch.float32, device=device)
        avg_iou = torch.zeros((), dtype=torch.float32, device=device)

    print(f"\n{'='*60}")
    print(f"Pixel Accuracy: {pixel_acc:.2f}%")
    print(f"Dice Score: {avg_dice:.4f}")
    print(f"IoU Score: {avg_iou:.4f}")
    print(f"{'='*60}")

    model.train()

    return pixel_acc, avg_dice, avg_iou


print("‚úì Training and evaluation functions defined!")

In [None]:
import matplotlib.pyplot as plt
from torch.utils.data import random_split, Subset
import gc

# GPU Cleanup
# Run the Python garbage collector first: tensors kept alive only by
# unreachable objects are handed back to PyTorch's caching allocator at that
# point, and only afterwards can empty_cache() release those blocks.
print("Clearing CUDA cache...")
gc.collect()
torch.cuda.empty_cache()

if torch.cuda.is_available():
    # Wait for pending GPU work to finish before reporting the device.
    torch.cuda.synchronize()
    print(f"‚úì GPU: {torch.cuda.get_device_name(0)}")

# Initialize Model: a fresh U-Net with one output channel per unified class,
# cross-entropy loss and Adam at the configured learning rate.
print(f"\nInitializing U-Net model with {NUM_CLASSES} output classes...")
model = UNet(in_channels=3, out_channels=NUM_CLASSES).to(DEVICE)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
print(f"‚úì Model initialized!")

# Create Combined Dataset
print("\n" + "="*70)
print("CREATING COMBINED DATASET")
print("="*70)

# All three dataset instances read the same folders; only the transform differs.
dataset_args = {
    "datasets_config": DATASETS,
    "project_folder": PROJECT_FOLDER_IN_DRIVE,
}

# Untransformed instance; in this cell it is used only for its length.
full_dataset = CombinedOCTDataset(transform=None, **dataset_args)

# Split 80/20 over the index range with a fixed seed, so the train and
# validation subsets below draw from disjoint indices.
total_samples = len(full_dataset)
train_size = int(0.8 * total_samples)
val_size = total_samples - train_size
split_rng = torch.Generator().manual_seed(42)
train_indices, val_indices = random_split(
    range(total_samples),
    [train_size, val_size],
    generator=split_rng,
)

print(f"\n‚úì Training: {len(train_indices)} images")
print(f"‚úì Validation: {len(val_indices)} images")

# Same data, different transforms: one instance per pipeline, each restricted
# to its own share of the indices.
train_dataset = CombinedOCTDataset(transform=train_transform, **dataset_args)
val_dataset = CombinedOCTDataset(transform=val_transform, **dataset_args)

train_subset = Subset(train_dataset, train_indices)
val_subset = Subset(val_dataset, val_indices)

# Only the training loader shuffles; everything else is shared.
loader_args = {"batch_size": BATCH_SIZE, "num_workers": 2, "pin_memory": True}
train_loader = DataLoader(train_subset, shuffle=True, **loader_args)
val_loader = DataLoader(val_subset, shuffle=False, **loader_args)

print("‚úì DataLoaders created!")

# Training Loop
print("\n" + "="*70)
print("STARTING TRAINING")
print("="*70)


def _to_host(metric):
    """Return a tensor metric as a NumPy value on the CPU; pass anything else through."""
    if torch.is_tensor(metric):
        return metric.cpu().numpy()
    return metric


# Per-epoch history, filled in as training progresses.
training_losses = []
validation_accuracies = []
validation_dice_scores = []
validation_iou_scores = []
best_dice = 0.0

# The best checkpoint is always written to the same file.
best_checkpoint_path = os.path.join(PROJECT_FOLDER_IN_DRIVE, "unet_combined_best.pth")

for epoch in range(NUM_EPOCHS):
    print(f"\n{'='*70}")
    print(f"EPOCH {epoch+1}/{NUM_EPOCHS}")
    print(f"{'='*70}")

    # One optimisation pass over the training data.
    avg_loss = train_fn(train_loader, model, optimizer, loss_fn, DEVICE)
    training_losses.append(avg_loss)

    print("\nValidation:")
    pixel_acc, dice, iou = check_metrics(val_loader, model, DEVICE, NUM_CLASSES)
    recorded = (
        (validation_accuracies, pixel_acc),
        (validation_dice_scores, dice),
        (validation_iou_scores, iou),
    )
    for history, value in recorded:
        history.append(_to_host(value))

    # Keep only the weights with the highest validation Dice seen so far.
    improved = dice > best_dice
    if improved:
        best_dice = dice
        torch.save(model.state_dict(), best_checkpoint_path)
        print(f"‚úì New best model saved! (Dice: {best_dice:.4f})")

print("\n" + "="*70)
print("TRAINING COMPLETE!")
print("="*70)

# Plot results
# Convert the per-epoch histories to arrays so they can be scaled and indexed.
training_losses = np.array(training_losses)
validation_accuracies = np.array(validation_accuracies)
validation_dice_scores = np.array(validation_dice_scores)
validation_iou_scores = np.array(validation_iou_scores)

epochs = range(1, NUM_EPOCHS + 1)
fig, (loss_ax, metrics_ax, iou_ax) = plt.subplots(1, 3, figsize=(18, 5))

# Panel 1: training loss
loss_ax.plot(epochs, training_losses, 'b-', linewidth=2)
loss_ax.set_title('Training Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.grid(True)

# Panel 2: accuracy is already a percentage, so Dice is scaled by 100 to share the axis
metrics_ax.plot(epochs, validation_accuracies, 'g-', linewidth=2, label='Pixel Accuracy')
metrics_ax.plot(epochs, validation_dice_scores * 100, 'r-', linewidth=2, label='Dice (√ó100)')
metrics_ax.set_title('Validation Metrics')
metrics_ax.set_xlabel('Epoch')
metrics_ax.set_ylabel('Score')
metrics_ax.grid(True)
metrics_ax.legend()

# Panel 3: IoU on its native scale
iou_ax.plot(epochs, validation_iou_scores, 'purple', linewidth=2)
iou_ax.set_title('IoU Score')
iou_ax.set_xlabel('Epoch')
iou_ax.set_ylabel('IoU')
iou_ax.grid(True)

fig.tight_layout()
fig.savefig(os.path.join(PROJECT_FOLDER_IN_DRIVE, 'training_metrics.png'), dpi=300)
plt.show()

# Last-epoch values, plus the best Dice seen during training.
print(f"\nFinal Loss: {training_losses[-1]:.4f}")
print(f"Final Accuracy: {validation_accuracies[-1]:.2f}%")
print(f"Final Dice: {validation_dice_scores[-1]:.4f}")
print(f"Final IoU: {validation_iou_scores[-1]:.4f}")
print(f"Best Dice: {best_dice:.4f}")

In [None]:
import matplotlib.patches as mpatches
from google.colab import files
from PIL import Image
import io

# ========================================
# SINGLE IMAGE PREDICTION WITH MANUAL UPLOAD
# ========================================

def predict_single_image(image_path, model_path, device='cuda'):
    """Segment one image file with saved U-Net weights.

    The image is read as grayscale, denoised, sharpened and contrast-equalised
    (CLAHE), expanded to three channels, passed through ``val_transform`` and
    fed to the network. The per-pixel argmax over the class axis is the mask.

    Args:
        image_path: Path of the image to read with OpenCV.
        model_path: Path of the saved ``state_dict`` to load.
        device: Device for the model and the input tensor.

    Returns:
        Tuple ``(original_image, predicted_mask, colored_mask)``: the grayscale
        image as read from disk, the 2-D array of predicted class ids, and an
        ``H x W x 3`` uint8 rendering of that mask. The two mask arrays share
        the size produced by ``val_transform``.

    Raises:
        ValueError: If OpenCV cannot read ``image_path``.
    """
    # Restore the trained weights into a fresh network
    print(f"Loading model from: {model_path}")
    net = UNet(in_channels=3, out_channels=NUM_CLASSES).to(device)
    weights = torch.load(model_path, map_location=device)
    net.load_state_dict(weights)
    net.eval()
    print("‚úì Model loaded successfully!")

    # Read the input as a single-channel image
    print(f"\nLoading image: {image_path}")
    original_image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if original_image is None:
        raise ValueError(f"Could not load image from {image_path}")
    print(f"Original image shape: {original_image.shape}")

    # Enhancement chain: denoise -> unsharp mask -> CLAHE
    denoised = cv2.fastNlMeansDenoising(original_image, None, h=10, searchWindowSize=21)
    blurred = cv2.GaussianBlur(denoised, (0, 0), 2.0)
    sharpened = cv2.addWeighted(denoised, 2.0, blurred, -1.0, 0)
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    enhanced = equalizer.apply(sharpened)

    # The transform pipeline is fed a 3-channel image
    rgb_input = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2RGB)

    # val_transform is called with an image/mask pair; a blank mask stands in
    transformed = val_transform(image=rgb_input, mask=np.zeros_like(rgb_input))
    batch = transformed['image'].unsqueeze(0).to(device)

    # Forward pass without gradient tracking
    print("Running inference...")
    with torch.no_grad():
        logits = net(batch)
        predicted_mask = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

    print(f"Predicted mask shape: {predicted_mask.shape}")
    print(f"Unique classes in prediction: {np.unique(predicted_mask)}")

    # RGB colour per class id; the list position is the class id
    palette = [
        [0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255],
        [255, 255, 0], [255, 0, 255], [0, 255, 255], [255, 128, 0],
        [128, 0, 255], [255, 192, 203], [0, 128, 128], [128, 128, 0],
        [192, 192, 192],
    ]

    height, width = predicted_mask.shape[0], predicted_mask.shape[1]
    colored_mask = np.zeros((height, width, 3), dtype=np.uint8)
    for class_id, rgb in enumerate(palette):
        colored_mask[predicted_mask == class_id] = rgb

    print("‚úì Prediction complete!")
    return original_image, predicted_mask, colored_mask


def visualize_prediction(original_image, predicted_mask, colored_mask, save_path=None):
    """Show the original image, the coloured mask and a blended overlay.

    Args:
        original_image: Single-channel image (it is converted with
            ``cv2.COLOR_GRAY2RGB`` for the overlay).
        predicted_mask: 2-D array of class ids; only used to decide which
            classes appear in the legend.
        colored_mask: ``H x W x 3`` colour rendering of ``predicted_mask``.
            It may have a different size than ``original_image`` (for example
            when the mask was produced at the network's input resolution); it
            is then resized to the image size before display.
            Assumes the same dtype as ``original_image`` (uint8 in this
            notebook) -- required by ``cv2.addWeighted``.
        save_path: Optional file path; when given, the figure is saved there.
    """
    # cv2.addWeighted needs both inputs to have the same size. Bring the mask
    # to the image size if necessary; nearest-neighbour keeps the class
    # colours pure instead of blending them at region borders.
    image_h, image_w = original_image.shape[:2]
    if colored_mask.shape[:2] != (image_h, image_w):
        colored_mask = cv2.resize(
            colored_mask, (image_w, image_h), interpolation=cv2.INTER_NEAREST
        )

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    axes[0].imshow(original_image, cmap='gray')
    axes[0].set_title('Original Image', fontsize=14, fontweight='bold')
    axes[0].axis('off')

    axes[1].imshow(colored_mask)
    axes[1].set_title('Predicted Segmentation', fontsize=14, fontweight='bold')
    axes[1].axis('off')

    # 60 % image + 40 % mask
    overlay = cv2.cvtColor(original_image, cv2.COLOR_GRAY2RGB)
    overlay = cv2.addWeighted(overlay, 0.6, colored_mask, 0.4, 0)
    axes[2].imshow(overlay)
    axes[2].set_title('Overlay', fontsize=14, fontweight='bold')
    axes[2].axis('off')

    # Create legend: one entry per class that actually occurs in the mask
    unique_classes = np.unique(predicted_mask)
    legend_patches = []
    color_map = {
        0: [0, 0, 0], 1: [255, 0, 0], 2: [0, 255, 0], 3: [0, 0, 255],
        4: [255, 255, 0], 5: [255, 0, 255], 6: [0, 255, 255], 7: [255, 128, 0],
        8: [128, 0, 255], 9: [255, 192, 203], 10: [0, 128, 128], 11: [128, 128, 0],
        12: [192, 192, 192]
    }

    for class_id in unique_classes:
        if class_id in UNIFIED_CATEGORIES:
            # matplotlib expects RGB components in [0, 1]
            color = np.array(color_map[class_id]) / 255.0
            label = UNIFIED_CATEGORIES[class_id]
            legend_patches.append(mpatches.Patch(color=color, label=label))

    # Anchored to the right of the last panel so it does not cover the overlay
    axes[2].legend(handles=legend_patches, bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"‚úì Visualization saved to: {save_path}")

    plt.show()


# ========================================
# MANUAL IMAGE UPLOAD AND PREDICTION
# ========================================

print("=" * 70)
print("SINGLE IMAGE PREDICTION - MANUAL UPLOAD")
print("=" * 70)

# Model path: the best checkpoint written by the training cell
MODEL_PATH = os.path.join(PROJECT_FOLDER_IN_DRIVE, 'unet_combined_best.pth')

# Without trained weights there is nothing to predict with
if not os.path.exists(MODEL_PATH):
    print(f"\n‚ùå Model not found at: {MODEL_PATH}")
    print("Please train the model first or update the MODEL_PATH variable.")
else:
    print("\nüì§ Please upload an OCT image for segmentation prediction")
    print("Click 'Choose Files' button below to upload your image...")

    # Upload image (returns a mapping of filename -> file bytes)
    uploaded = files.upload()

    if uploaded:
        # Only the first uploaded file is processed
        uploaded_filename = next(iter(uploaded))
        print(f"\n‚úì Image uploaded: {uploaded_filename}")

        # Location of the temporary grayscale copy handed to the predictor
        temp_image_path = f"/content/{uploaded_filename}"

        try:
            # Decode the uploaded bytes
            img = Image.open(io.BytesIO(uploaded[uploaded_filename]))

            # The predictor reads a grayscale file, so convert if needed
            if img.mode != 'L':
                img = img.convert('L')

            img.save(temp_image_path)
            print(f"‚úì Image saved to: {temp_image_path}")

            # Predict
            print("\n" + "=" * 70)
            original, mask, colored = predict_single_image(
                image_path=temp_image_path,
                model_path=MODEL_PATH,
                device=DEVICE
            )

            # Visualize and keep a copy of the figure in the project folder
            save_path = os.path.join(PROJECT_FOLDER_IN_DRIVE, 'prediction_result.png')
            visualize_prediction(original, mask, colored, save_path=save_path)

            # Print statistics: share of pixels assigned to each predicted class
            print("\n" + "=" * 70)
            print("PREDICTION STATISTICS")
            print("=" * 70)
            unique, counts = np.unique(mask, return_counts=True)
            total_pixels = mask.size

            print(f"\n{'Class ID':<10} {'Class Name':<25} {'Pixels':<12} {'Percentage':<10}")
            print("-" * 70)
            for class_id, count in zip(unique, counts):
                class_name = UNIFIED_CATEGORIES.get(class_id, 'Unknown')
                percentage = (count / total_pixels) * 100
                print(f"{class_id:<10} {class_name:<25} {count:<12} {percentage:<10.2f}%")

        except Exception as e:
            # Best effort: report the problem and keep the notebook usable
            print(f"\n‚ùå Error processing image: {e}")
        finally:
            # Clean up the temporary file whether or not prediction succeeded,
            # so a failed run does not leave it behind in /content.
            if os.path.exists(temp_image_path):
                os.remove(temp_image_path)
                print(f"\n‚úì Temporary file cleaned up")
    else:
        print("\n‚ö†Ô∏è No image was uploaded. Please run this cell again to upload an image.")

print("\n" + "=" * 70)
print("To predict on another image, simply run this cell again!")
print("=" * 70)
