# Retinal Fundus Dataset Visualization

This notebook visualizes the DRIVE, HRF, CHASEDB1, and STARE datasets for retinal vessel segmentation. It includes dataset descriptions, folder hierarchy, image counts, formats, resolutions, and visualizations of sample images and data distributions.

**A quick overview of the major datasets is also available here:** [Datasets Overview](https://www.medicmind.tech/retinal-image-databases)

## Step 1: Dataset Folder Hierarchy

The datasets are stored in the `/kaggle/input/final_dataset/` directory with the following structure (corrected for the typo in DRIVE training masks):

```
final_dataset/
├── CHASEDB1/
│   ├── images/
│   │   ├── Image_01L.jpg
│   │   ├── Image_01R.jpg
│   │   └── ... (28 images)
│   ├── labels1/
│   │   ├── Image_01L_1stHO.png
│   │   ├── Image_01R_1stHO.png
│   │   └── ... (28 images)
│   ├── labels2/
│   │   ├── Image_01L_2ndHO.png
│   │   ├── Image_01R_2ndHO.png
│   │   └── ... (28 images)
│   ├── masks/
│   │   ├── mask_01L.png
│   │   ├── mask_01R.png
│   │   └── ... (28 images)
├── DRIVE/
│   ├── test/
│   │   ├── 1st_manual/
│   │   │   ├── 01_manual1.gif
│   │   │   ├── 02_manual1.gif
│   │   │   └── ... (20 images)
│   │   ├── 2nd_manual/
│   │   │   ├── 01_manual2.gif
│   │   │   ├── 02_manual2.gif
│   │   │   └── ... (20 images)
│   │   ├── images/
│   │   │   ├── 01_test.tif
│   │   │   ├── 02_test.tif
│   │   │   └── ... (20 images)
│   │   ├── mask/
│   │   │   ├── 01_test_mask.gif
│   │   │   ├── 02_test_mask.gif
│   │   │   └── ... (20 images)
│   ├── training/
│   │   ├── 1st_manual/
│   │   │   ├── 21_manual1.gif
│   │   │   ├── 22_manual1.gif
│   │   │   └── ... (20 images)
│   │   ├── images/
│   │   │   ├── 21_training.tif
│   │   │   ├── 22_training.tif
│   │   │   └── ... (20 images)
│   │   ├── mask/
│   │   │   ├── 21_training_mask.gif
│   │   │   ├── 22_training_mask.gif
│   │   │   └── ... (20 images)
├── HRF/
│   ├── images/
│   │   ├── 01_dr.JPG
│   │   ├── 01_g.jpg
│   │   ├── 01_h.jpg
│   │   └── ... (45 images)
│   ├── manual1/
│   │   ├── 01_dr.tif
│   │   ├── 01_g.tif
│   │   ├── 01_h.tif
│   │   └── ... (45 images)
│   ├── mask/
│   │   ├── 01_dr_mask.tif
│   │   ├── 01_g_mask.tif
│   │   ├── 01_h_mask.tif
│   │   └── ... (45 images)
├── STARE/
│   ├── images/
│   │   ├── im0001.ppm
│   │   ├── im0002.ppm
│   │   └── ... (20 images, non-sequential numbering)
│   ├── labels-ah/
│   │   ├── im0001.ah.ppm
│   │   ├── im0002.ah.ppm
│   │   └── ... (20 images, non-sequential numbering)
│   ├── masks/
│   │   ├── mask_0001.png
│   │   ├── mask_0002.png
│   │   └── ... (20 images, non-sequential numbering)
```
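
One way to check this hierarchy against the counts listed in the sections that follow is to walk the tree and tally files per folder and extension. The sketch below is a minimal illustration, not the notebook's own code: `count_files` is a hypothetical helper, and the root path in the usage comment is an assumption to adapt to your mount point.

```python
import os
from collections import Counter

def count_files(root):
    """Return {relative_folder: Counter({extension: n_files})} for every folder under root."""
    summary = {}
    for dirpath, _dirnames, filenames in os.walk(root):
        if not filenames:
            continue  # Skip folders that only contain subfolders
        rel = os.path.relpath(dirpath, root)
        # Lower-case extensions so .JPG and .jpg are counted together
        summary[rel] = Counter(os.path.splitext(f)[1].lower() for f in filenames)
    return summary

# Usage (assumed path; adjust as needed):
# for folder, exts in sorted(count_files('/kaggle/input/final_dataset').items()):
#     print(folder, dict(exts))
```

If the layout matches the tree above, `CHASEDB1/images` should report 28 `.jpg` files and `DRIVE/test/images` 20 `.tif` files.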

## Step 3: CHASEDB1 Dataset Overview

**Dataset Details:**
- **Online/Release**: CHASEDB1 (Child Heart and Health Study in England Database) is a publicly available dataset for retinal vessel segmentation, containing 28 fundus images of left and right eyes from 14 children. Originally, masks were not included but were retrieved from [GitHub](https://github.com/zhengyuan-liu/Retinal-Vessel-Segmentation). Original link: [CHASEDB1](https://www.kaggle.com/datasets/khoongweihao/chasedb1/data)
- **In Code**: Stored in `/kaggle/input/final_dataset/CHASEDB1/` with subfolders: `images` (.jpg), `labels1` (.png, first annotator), `labels2` (.png, second annotator), and `masks` (.png).
- **Count**: 28 images, 28 masks, 28 first annotations, 28 second annotations.
- **Format**: Images (.jpg), Masks/Labels (.png).
- **Resolution**: Typically 999x960 pixels.
- **Division**: No predefined train/test split.

**Output**: Displays counts, unique resolutions, and a row of sample images (original, mask, first annotation, second annotation).

## Step 4: DRIVE Dataset Overview

**Dataset Details:**
- **Online/Release**: DRIVE (Digital Retinal Images for Vessel Extraction) contains 40 fundus images, split into 20 training and 20 test images. Test set manual annotations were originally missing but retrieved from [GitHub](https://github.com/zhengyuan-liu/Retinal-Vessel-Segmentation). Original link: [DRIVE](https://www.kaggle.com/datasets/andrewmvd/drive-digital-retinal-images-for-vessel-extraction)
- **In Code**: Stored in `/kaggle/input/final_dataset/DRIVE/` with `test` and `training` subfolders, each containing `images` (.tif), `1st_manual` (.gif), `2nd_manual` (.gif, test only), and `mask` (.gif).
- **Count**: Training: 20 images, 20 masks, 20 first annotations. Test: 20 images, 20 masks, 20 first annotations, 20 second annotations.
- **Format**: Images (.tif), Masks/Annotations (.gif).
- **Resolution**: 565x584 pixels.
- **Division**: Predefined train/test split (20/20).

**Output**: Displays counts, resolutions, and two rows of sample images (training and test sets).

## Step 5: HRF Dataset Overview

**Dataset Details:**
- **Online/Release**: HRF (High-Resolution Fundus) contains 45 high-resolution fundus images: 15 healthy (h), 15 diabetic retinopathy (dr), 15 glaucomatous (g). Available publicly with masks and annotations. Original link: [HRF](https://datasetninja.com/high-resolution-fundus)
- **In Code**: Stored in `/kaggle/input/final_dataset/HRF/` with `images` (.JPG), `manual1` (.tif), and `mask` (.tif).
- **Count**: 45 images, 45 masks, 45 annotations.
- **Format**: Images (.JPG), Masks/Annotations (.tif).
- **Resolution**: Typically 3504x2336 pixels.
- **Division**: No predefined train/test split.

**Output**: Displays counts, resolutions, sample images for each category, and a bar plot of category distribution.
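
The category distribution can be derived from the file names alone, since each HRF image carries its category as a suffix (`_h`, `_dr`, `_g`). The following is a minimal sketch under that assumption; `hrf_category_counts` and `HRF_CATEGORIES` are hypothetical names, and the example list is illustrative rather than a real directory listing.

```python
import os
from collections import Counter

# Suffix-to-label mapping taken from the dataset description above
HRF_CATEGORIES = {"h": "healthy", "dr": "diabetic retinopathy", "g": "glaucomatous"}

def hrf_category_counts(filenames):
    """Count HRF images per category from names like '01_dr.JPG' or '07_h.jpg'."""
    counts = Counter()
    for name in filenames:
        stem = os.path.splitext(name)[0]    # e.g. '01_dr'
        suffix = stem.rsplit("_", 1)[-1]    # e.g. 'dr'
        counts[HRF_CATEGORIES.get(suffix, "unknown")] += 1
    return counts

example = ["01_dr.JPG", "01_g.jpg", "01_h.jpg", "02_dr.JPG"]
print(dict(hrf_category_counts(example)))
```

On the full dataset, the resulting counts can be passed directly to a Matplotlib bar plot.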

## Step 6: STARE Dataset Overview

**Dataset Details:**
- **Online/Release**: STARE (Structured Analysis of the Retina) contains 20 fundus images with non-sequential numbering. Masks were not originally included but retrieved from [GitHub](https://github.com/zhengyuan-liu/Retinal-Vessel-Segmentation). Original links: [Official STARE](https://cecas.clemson.edu/~ahoover/stare/), [Kaggle STARE](https://www.kaggle.com/datasets/vidheeshnacode/stare-dataset)
- **In Code**: Stored in `/kaggle/input/final_dataset/STARE/` with `images` (.ppm), `labels-ah` (.ppm), and `masks` (.png). File names handled dynamically due to irregular numbering.
- **Count**: 20 images, 20 masks, 20 annotations.
- **Format**: Images/Annotations (.ppm), Masks (.png).
- **Resolution**: Typically 700x605 pixels.
- **Division**: No predefined train/test split.

**Output**: Displays counts, resolutions, and sample images, dynamically handling file names.
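
Because STARE numbering is non-sequential, pairing files by position is fragile; matching on the 4-digit id embedded in each name is more robust. The sketch below assumes the naming shown in the folder hierarchy (`im0001.ppm`, `im0001.ah.ppm`, `mask_0001.png`). `stare_key` and `pair_stare_files` are hypothetical helpers, and the file lists in the usage lines are illustrative.

```python
import re

def stare_key(filename):
    """Extract the 4-digit image id from a STARE file name, or None if absent."""
    match = re.search(r"(\d{4})", filename)
    return match.group(1) if match else None

def pair_stare_files(images, labels, masks):
    """Return (image, label, mask) triples for ids present in all three lists."""
    def by_id(names):
        return {stare_key(n): n for n in names if stare_key(n) is not None}

    img, lab, msk = by_id(images), by_id(labels), by_id(masks)
    common = sorted(set(img) & set(lab) & set(msk))
    return [(img[k], lab[k], msk[k]) for k in common]

triples = pair_stare_files(
    ["im0044.ppm", "im0001.ppm"],
    ["im0001.ah.ppm", "im0044.ah.ppm"],
    ["mask_0044.png", "mask_0001.png"],
)
for triple in triples:
    print(triple)
```

Ids missing from any of the three folders are silently dropped, so comparing `len(triples)` with the expected 20 is a quick sanity check.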

# Model For Training: R2U-Net for Retinal Blood Vessel Segmentation



## Step 1: Import Libraries

We import necessary libraries for data handling, image processing, model building, and evaluation. TensorFlow/Keras is used for the R2U-Net implementation, and Matplotlib for visualization.

**Output**: Libraries are loaded, and the base path is set.

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import accuracy_score, roc_auc_score, precision_recall_fscore_support
import random
import cv2
import logging

# Set random seed for reproducibility
tf.random.set_seed(42)
np.random.seed(42)
random.seed(42)

# Base path for dataset (the folder hierarchy above writes it as final_dataset; adjust to match your Kaggle mount)
base_path = '/kaggle/input/final-dataset'

## Step 2: Define Helper Functions

We define functions for loading images, extracting patches, and computing evaluation metrics (Accuracy, Sensitivity, Specificity, F1-Score, Dice Coefficient, Jaccard Similarity, AUC). These are used across all datasets.

**Output**: Helper functions are defined for later use.
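
To make the metric definitions concrete, here is a small worked example in plain Python that computes them from confusion counts. It is an illustrative sketch, not the notebook's `compute_metrics`: `confusion_counts` and `segmentation_metrics` are hypothetical names, AUC is omitted because it needs continuous scores, and the example assumes both classes are present (otherwise a denominator is zero).

```python
def confusion_counts(y_true, y_pred):
    """Count TN, FP, FN, TP for two equal-length sequences of 0/1 values."""
    tn = fp = fn = tp = 0
    for t, p in zip(y_true, y_pred):
        if t == 1 and p == 1:
            tp += 1
        elif t == 1 and p == 0:
            fn += 1
        elif t == 0 and p == 1:
            fp += 1
        else:
            tn += 1
    return tn, fp, fn, tp

def segmentation_metrics(y_true, y_pred):
    """Accuracy, sensitivity, specificity, Dice, and Jaccard from binary labels."""
    tn, fp, fn, tp = confusion_counts(y_true, y_pred)
    return {
        "AC": (tp + tn) / (tp + tn + fp + fn),
        "SE": tp / (tp + fn),            # Fraction of vessel pixels found
        "SP": tn / (tn + fp),            # Fraction of background pixels kept
        "DC": 2 * tp / (2 * tp + fp + fn),
        "JS": tp / (tp + fp + fn),
    }

y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 1, 0, 0, 0, 0]
m = segmentation_metrics(y_true, y_pred)
print(m)
# Jaccard and Dice are linked by JS = DC / (2 - DC)
print(abs(m["JS"] - m["DC"] / (2 - m["DC"])) < 1e-12)
```

For binary masks the F1-score and the Dice coefficient are the same quantity, which is why the two values coincide in the results.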

In [None]:
def load_image(image_path):
    """Load and preprocess retinal fundus image to enhance vessel visibility with controlled contrast."""
    img = Image.open(image_path)
    
    # Convert to RGB if not already (in case of grayscale or other formats)
    if img.mode != 'RGB':
        img = img.convert('RGB')
    
    # Extract green channel (vessels are most prominent in green channel)
    img_array = np.array(img)
    green_channel = img_array[:, :, 1]  # Green channel (index 1 in RGB)
    
    # Apply CLAHE to enhance contrast with reduced amplification
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    green_channel = clahe.apply(green_channel)
    
    return green_channel

def extract_patches(image, mask, patch_size=48, num_patches=10000):
    """Extract random patches from image and mask."""
    patches = []
    labels = []
    h, w = image.shape
    for _ in range(num_patches):
        y = random.randint(0, h - patch_size)
        x = random.randint(0, w - patch_size)
        patch = image[y:y+patch_size, x:x+patch_size]
        label = mask[y:y+patch_size, x:x+patch_size]
        patches.append(patch[..., np.newaxis])  # Add channel dimension
        labels.append(label[..., np.newaxis] / 255.0)  # Normalize to [0,1]
    return np.array(patches), np.array(labels)

def extract_patches_for_eval(image, patch_size=48):
    """Extract all non-overlapping patch_size x patch_size patches from an image for evaluation.

    Accepts (H, W) or (H, W, 1) input and returns patches shaped (N, patch_size, patch_size, 1).
    """
    if image.ndim == 2:
        image = image[..., np.newaxis]  # Add the channel dimension once, up front
    patches = []
    h, w = image.shape[:2]
    for y in range(0, h - patch_size + 1, patch_size):  # Non-overlapping patches
        for x in range(0, w - patch_size + 1, patch_size):
            patches.append(image[y:y+patch_size, x:x+patch_size])
    return np.array(patches), h, w

def reconstruct_image(patches, h, w, patch_size=48):
    """Reconstruct image from non-overlapping patches."""
    recon = np.zeros((h, w, 1))
    patch_idx = 0
    for y in range(0, h - patch_size + 1, patch_size):
        for x in range(0, w - patch_size + 1, patch_size):
            if patch_idx < len(patches):  # Ensure index is valid
                recon[y:y+patch_size, x:x+patch_size] = patches[patch_idx]
            patch_idx += 1
    return recon

def compute_metrics(y_true, y_pred):
    """Compute evaluation metrics: AC, SE, SP, F1, DC, JS, AUC."""
    y_true = (y_true.flatten() > 0.5).astype(np.int32)  # Convert float64 to int32
    y_pred_binary = (y_pred.flatten() > 0.5).astype(np.int32)  # Threshold at 0.5
    # Replace NaN or inf with 0 to avoid issues in metrics computation
    y_pred_binary = np.nan_to_num(y_pred_binary, nan=0, posinf=0, neginf=0)
    confusion = y_true * 2 + y_pred_binary
    tn, fp, fn, tp = np.bincount(confusion, minlength=4)[0:4]
    ac = (tp + tn) / (tp + tn + fp + fn + 1e-10)
    se = tp / (tp + fn + 1e-10)
    sp = tn / (tn + fp + 1e-10)
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred_binary, average='binary', zero_division=0)
    dc = 2 * (precision * recall) / (precision + recall + 1e-10)
    js = precision * recall / (precision + recall - precision * recall + 1e-10)
    # Handle NaN in y_pred before AUC calculation
    y_pred_flat = np.nan_to_num(y_pred.flatten(), nan=0)
    auc = roc_auc_score(y_true, y_pred_flat)
    return {'AC': ac, 'SE': se, 'SP': sp, 'F1': f1, 'DC': dc, 'JS': js, 'AUC': auc}

def plot_sample_patches(patches, labels, preds, num_samples=5):
    """Plot sample patches, ground truth, and predictions."""
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4*num_samples))
    for i in range(num_samples):
        axes[i, 0].imshow(patches[i, :, :, 0], cmap='gray')
        axes[i, 0].set_title('Input Patch')
        axes[i, 1].imshow(labels[i, :, :, 0], cmap='gray')
        axes[i, 1].set_title('Ground Truth')
        axes[i, 2].imshow(preds[i, :, :, 0] > 0.5, cmap='gray')
        axes[i, 2].set_title('Prediction')
        for ax in axes[i]: ax.axis('off')
    plt.tight_layout()
    plt.show()

##########   New functions added for a better training process (i.e., preventing overfitting)   ############
@tf.keras.utils.register_keras_serializable()
def dice_loss(y_true, y_pred, smooth=1e-6):
    """Compute Dice loss for binary segmentation."""
    y_true_f = tf.keras.backend.flatten(y_true)
    y_pred_f = tf.keras.backend.flatten(y_pred)
    intersection = tf.keras.backend.sum(y_true_f * y_pred_f)
    return 1 - (2. * intersection + smooth) / (tf.keras.backend.sum(y_true_f) + tf.keras.backend.sum(y_pred_f) + smooth)

@tf.keras.utils.register_keras_serializable()
def combined_loss(y_true, y_pred):
    """Combine binary cross-entropy and Dice loss."""
    bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    dice = dice_loss(y_true, y_pred)
    return 0.5 * bce + 0.5 * dice

def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Compute Dice coefficient for evaluation during training."""
    y_true_f = tf.keras.backend.flatten(tf.cast(y_true, tf.float32))  # Cast to float32
    y_pred_f = tf.keras.backend.flatten(tf.cast(tf.where(y_pred > 0.5, 1.0, 0.0), tf.float32))  # Cast to float32
    intersection = tf.keras.backend.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.keras.backend.sum(y_true_f) + tf.keras.backend.sum(y_pred_f) + smooth)
    
class DiceCoefficientCallback(tf.keras.callbacks.Callback):
    """Custom callback to compute Dice Coefficient and save best model based on val_dice_coefficient."""
    def __init__(self, checkpoint_path, val_patches, val_labels, monitor='val_dice_coefficient', mode='max'):
        super(DiceCoefficientCallback, self).__init__()
        self.checkpoint_path = checkpoint_path
        self.val_patches = val_patches
        self.val_labels = val_labels
        self.monitor = monitor
        self.mode = mode
        self.best = -float('inf') if mode == 'max' else float('inf')
        self.best_weights = None

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Compute Dice Coefficient for validation
        val_pred = self.model.predict(self.val_patches, verbose=0)
        current_val_dice = dice_coefficient(self.val_labels, val_pred)
        logs['val_dice_coefficient'] = float(current_val_dice)
        
        # Check if current epoch is the best based on monitor
        current = logs.get(self.monitor)
        if current is None:
            return
        
        if (self.mode == 'max' and current > self.best) or (self.mode == 'min' and current < self.best):
            self.best = current
            self.best_weights = self.model.get_weights()
            self.model.save(self.checkpoint_path)
            print(f"\nEpoch {epoch + 1}: saving model to {self.checkpoint_path} with {self.monitor}={current:.4f}")
        

## Step 3: Load and Preprocess Datasets

We load the DRIVE, STARE, CHASE_DB1, and HRF datasets, applying the paper’s preprocessing (cropping DRIVE to 565x565, using provided masks for STARE/CHASE_DB1, patch-based for HRF). We extract 190,000 patches for DRIVE and 250,000 each for STARE and CHASE_DB1, following the paper’s splits; HRF is reduced to 45,000 patches to fit in memory.

**Output**: Prints the shapes of training/validation patches and the number of test images for each dataset. Expected shapes: DRIVE (171,000/19,000 patches, 20 test images), STARE/CHASE_DB1 (approximately 225,000/25,000 patches, 4/8 test images), HRF (40,500/4,500 patches, 9 test images).
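
The expected shapes follow from simple arithmetic. The sketch below assumes the total patch budget is divided evenly across the training images and that each image's patches are split 90/10 into training and validation; `patch_budget` is a hypothetical helper, and the image counts (20, 16, 20, 36) are the training-set sizes used in the loading code.

```python
def patch_budget(total_patches, n_train_images, train_fraction=0.9):
    """Return (patches per image, total training patches, total validation patches)."""
    per_image = total_patches // n_train_images
    train_per_image = int(train_fraction * per_image)  # Same truncation as int(0.9 * len(patches))
    val_per_image = per_image - train_per_image
    return per_image, train_per_image * n_train_images, val_per_image * n_train_images

budgets = [
    ("DRIVE", 190_000, 20),
    ("STARE", 250_000, 16),
    ("CHASEDB1", 250_000, 20),
    ("HRF", 45_000, 36),
]
for name, total, n_images in budgets:
    print(name, patch_budget(total, n_images))
```

STARE lands slightly off the round numbers because 15,625 patches per image do not split evenly 90/10.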

In [None]:
def load_dataset(dataset_name, total_patches, train_patches, val_patches):
    """Load images and extract patches for a dataset with predefined splits."""
    dataset_path = os.path.join(base_path, dataset_name)
    
    if dataset_name == 'DRIVE':
        train_image_path = os.path.join(dataset_path, 'training/images')
        train_mask_path = os.path.join(dataset_path, 'training/mask')
        train_label_path = os.path.join(dataset_path, 'training/1st_manual')
        test_image_path = os.path.join(dataset_path, 'test/images')
        test_mask_path = os.path.join(dataset_path, 'test/mask')
        test_label_path = os.path.join(dataset_path, 'test/1st_manual')
        
        train_image_files = sorted([f for f in os.listdir(train_image_path) if f.endswith(('.tif', '.TIF'))])
        train_mask_files = sorted([f for f in os.listdir(train_mask_path) if f.endswith(('.gif', '.GIF'))])
        train_label_files = sorted([f for f in os.listdir(train_label_path) if f.endswith(('.gif', '.GIF'))])
        test_image_files = sorted([f for f in os.listdir(test_image_path) if f.endswith(('.tif', '.TIF'))])
        test_mask_files = sorted([f for f in os.listdir(test_mask_path) if f.endswith(('.gif', '.GIF'))])
        test_label_files = sorted([f for f in os.listdir(test_label_path) if f.endswith(('.gif', '.GIF'))])
        
        train_files = train_image_files
        test_files = test_image_files
    else:
        image_path = os.path.join(dataset_path, 'images')
        mask_path = os.path.join(dataset_path, 'masks' if dataset_name in ['STARE', 'CHASEDB1'] else 'mask')
        label_path = os.path.join(dataset_path, 'labels-ah' if dataset_name == 'STARE' else 'labels1' if dataset_name == 'CHASEDB1' else 'manual1')
        
        image_files = sorted([f for f in os.listdir(image_path) if f.endswith(('.jpg', '.JPG', '.ppm'))])
        mask_files = sorted([f for f in os.listdir(mask_path) if f.endswith(('.png', '.PNG', '.tif', '.TIF'))])
        label_files = sorted([f for f in os.listdir(label_path) if f.endswith(('.ppm', '.PPM', '.png', '.PNG', '.tif', '.TIF'))])
        
        if dataset_name == 'CHASEDB1':
            train_files = image_files[:20]
            test_files = image_files[20:]
        elif dataset_name == 'HRF':
            train_files = image_files[:36]  # 80% of 45 images
            test_files = image_files[36:]
        else:  # STARE
            train_files = image_files[:16]  # ~80% of 20 images
            test_files = image_files[16:]
    
    train_patches_list, train_labels_list = [], []
    val_patches_list, val_labels_list = [], []
    test_images, test_labels = [], []

    # Training and validation patches
    for i, img_file in enumerate(train_files):
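        if dataset_name == 'DRIVE':
            img = load_image(os.path.join(train_image_path, img_file))
            label_file = os.path.join(train_label_path, train_label_files[i])
        else:
            img = load_image(os.path.join(image_path, img_file))
            label_file = os.path.join(label_path, label_files[i])
        # Read the annotation as plain grayscale; the CLAHE step in load_image is meant for fundus images only
        label = np.array(Image.open(label_file).convert('L'))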
        
        if dataset_name == 'DRIVE':
            img = img[9:574, :]  # Crop rows (584 -> 565) to get a 565x565 image
            label = label[9:574, :]
        
        # Split the total patch budget evenly across the training images;
        # the 90/10 split below then yields the train/validation counts
        patches_per_image = total_patches // len(train_files)
        if dataset_name == 'HRF':
            patches_per_image = 1250  # Capped (45,000 / 36 images) to fit memory
        patches, labels = extract_patches(img, label, num_patches=patches_per_image)
        split_idx = int(0.9 * len(patches))
        train_patches_list.append(patches[:split_idx])
        train_labels_list.append(labels[:split_idx])
        val_patches_list.append(patches[split_idx:])
        val_labels_list.append(labels[split_idx:])

    
    train_patches = np.concatenate(train_patches_list, axis=0)
    train_labels = np.concatenate(train_labels_list, axis=0)
    val_patches = np.concatenate(val_patches_list, axis=0)
    val_labels = np.concatenate(val_labels_list, axis=0)

    # Test images (full images for evaluation)
    for i, img_file in enumerate(test_files):
        if dataset_name == 'DRIVE':
            img = load_image(os.path.join(test_image_path, img_file))
            label_file = os.path.join(test_label_path, test_label_files[i])
        else:
            img = load_image(os.path.join(image_path, img_file))
            label_file = os.path.join(label_path, label_files[i + len(train_files)])
        # Read the annotation as plain grayscale; the CLAHE step in load_image is meant for fundus images only
        label = np.array(Image.open(label_file).convert('L'))
        
        if dataset_name == 'DRIVE':
            img = img[9:574, :]  # Crop rows (584 -> 565) to get a 565x565 image
            label = label[9:574, :]
        
        test_images.append(img[..., np.newaxis])
        test_labels.append(label[..., np.newaxis] / 255.0)
    
    return train_patches, train_labels, val_patches, val_labels, test_images, test_labels

# Load datasets with original patch numbers and splits

# original from paper
# ('DRIVE', 190000, 171000, 19000)
# ('STARE', 250000, 225000, 25000)
# ('CHASEDB1', 250000, 225000, 25000)
# ('HRF', 250000, 225000, 25000) <- causes out-of-memory errors; do not use

# Patch budgets used here (paper values, except HRF, which is reduced to fit memory)
drive_patches_train, drive_labels_train, drive_patches_val, drive_labels_val, drive_test_images, drive_test_labels = load_dataset('DRIVE', 190000, 171000, 19000)
stare_patches_train, stare_labels_train, stare_patches_val, stare_labels_val, stare_test_images, stare_test_labels = load_dataset('STARE', 250000, 225000, 25000)
chase_patches_train, chase_labels_train, chase_patches_val, chase_labels_val, chase_test_images, chase_test_labels = load_dataset('CHASEDB1', 250000, 225000, 25000)
hrf_patches_train, hrf_labels_train, hrf_patches_val, hrf_labels_val, hrf_test_images, hrf_test_labels = load_dataset('HRF', 45000, 40500, 4500)  # Reduced to 45,000 total

# Print shapes for verification
print(f"DRIVE: Train patches {drive_patches_train.shape}, Val patches {drive_patches_val.shape}, Test images {len(drive_test_images)}")
print(f"STARE: Train patches {stare_patches_train.shape}, Val patches {stare_patches_val.shape}, Test images {len(stare_test_images)}")
print(f"CHASEDB1: Train patches {chase_patches_train.shape}, Val patches {chase_patches_val.shape}, Test images {len(chase_test_images)}")
print(f"HRF: Train patches {hrf_patches_train.shape}, Val patches {hrf_patches_val.shape}, Test images {len(hrf_test_images)}")

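## Step 4: Define R2U-Net Model ($t=3$)

We implement the R2U-Net model with $t=3$ using the architecture $1 \to 16 \to 32 \to 64 \to 128 \to 64 \to 32 \to 16 \to 1$. In this implementation, each block applies 1 forward convolution followed by $t-1 = 2$ further convolutions whose outputs are summed with the forward response. The model includes recurrent residual convolutional units (RRCU) and skip connections with concatenation.

**Output**: Defines the model-building functions. The ~1.037M parameter figure quoted for the $t=3$ architecture may not match this implementation, which uses a separate convolution per step; calling `r2unet().summary()` prints the actual count.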
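
The encoder halves the spatial size three times and the decoder doubles it back, so the patch size has to survive three halvings without a remainder for the skip connections to line up (48 → 24 → 12 → 6). The sketch below traces those shapes in plain Python; `trace_shapes` is a hypothetical helper that mirrors the stage names in the model code and does not build any Keras layers.

```python
def trace_shapes(patch_size=48, encoder_filters=(16, 32, 64, 128)):
    """List (stage, height/width, channels) through the encoder and decoder."""
    stages = []
    size = patch_size
    # Encoder: every level except the deepest is followed by 2x2 max pooling
    for i, filters in enumerate(encoder_filters):
        stages.append((f"e{i + 1}", size, filters))
        if i < len(encoder_filters) - 1:
            if size % 2:
                raise ValueError(
                    f"size {size} is odd before pooling; the upsampled tensor "
                    "would not match its skip connection"
                )
            size //= 2
    # Decoder: each transposed convolution doubles the spatial size
    for i, filters in reversed(list(enumerate(encoder_filters[:-1]))):
        size *= 2
        stages.append((f"d{i + 1}", size, filters))
    stages.append(("out", size, 1))  # 1x1 convolution with sigmoid
    return stages

for stage in trace_shapes():
    print(stage)
```

With three pooling steps, any patch size divisible by 8 passes the check; a size such as 50 fails at the second pooling step.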

In [None]:
def rrcu_block(inputs, filters, kernel_size=3, t=3):
    """Recurrent Residual Convolutional Unit (RRCU) as used in this notebook.

    Applies 1 forward convolution plus t-1 further convolutions. Each step uses its own
    Conv2D layer (weights are not shared), and no identity shortcut from `inputs` is added.
    """
    conv = layers.Conv2D(filters, kernel_size, padding='same', kernel_initializer='he_normal')(inputs)
    x = layers.BatchNormalization()(conv)
    x = layers.ReLU()(x)
    for _ in range(t-1):  # t-1 further convolutions on the running state
        conv_r = layers.Conv2D(filters, kernel_size, padding='same', kernel_initializer='he_normal')(x)
        x = layers.Add()([conv, conv_r])  # Sum the forward and recurrent responses
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
    x = layers.Dropout(0.1)(x)  # Dropout for regularization
    return x

def r2unet(input_shape=(48, 48, 1)):
    """R2U-Net model with t=3."""
    inputs = layers.Input(input_shape)
    # Encoder
    e1 = rrcu_block(inputs, 16)
    p1 = layers.MaxPooling2D((2, 2))(e1)
    e2 = rrcu_block(p1, 32)
    p2 = layers.MaxPooling2D((2, 2))(e2)
    e3 = rrcu_block(p2, 64)
    p3 = layers.MaxPooling2D((2, 2))(e3)
    e4 = rrcu_block(p3, 128)
    # Decoder
    u3 = layers.Conv2DTranspose(64, 2, strides=(2, 2), padding='same')(e4)
    u3 = layers.Concatenate()([u3, e3])
    d3 = rrcu_block(u3, 64)
    u2 = layers.Conv2DTranspose(32, 2, strides=(2, 2), padding='same')(d3)
    u2 = layers.Concatenate()([u2, e2])
    d2 = rrcu_block(u2, 32)
    u1 = layers.Conv2DTranspose(16, 2, strides=(2, 2), padding='same')(d2)
    u1 = layers.Concatenate()([u1, e1])
    d1 = rrcu_block(u1, 16)
    # Output
    outputs = layers.Conv2D(1, 1, activation='sigmoid')(d1)
    return tf.keras.models.Model(inputs, outputs)  # Explicitly use tf.keras.models.Model


## Step 5: Train R2U-Net on Each Dataset

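We train the R2U-Net model separately on each dataset for up to 50 epochs with batch size 8, the Adam optimizer (learning rate 2e-4), and a combined binary cross-entropy + Dice loss. Early stopping (patience 15) and learning-rate reduction on plateau (factor 0.5, patience 5) both monitor the validation loss. Checkpoints are saved to handle Kaggle’s session limits. As written, the cell trains on DRIVE only; the other datasets are commented out.

**Output**: For each dataset, prints training progress and displays plots of training/validation loss and accuracy over up to 50 epochs (fewer if early stopping triggers).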
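
The training cell compiles the model with `combined_loss`, an equal-weight sum of binary cross-entropy and Dice loss. The sketch below reproduces that arithmetic on flat Python lists to show how the two terms behave; it is a simplified scalar illustration, not the TensorFlow implementation, and the function names and the clipping constant `eps` are assumptions.

```python
import math

def bce(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over flattened pixels."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1 - eps)  # Clip to avoid log(0)
        total -= t * math.log(p) + (1 - t) * math.log(1 - p)
    return total / len(y_true)

def dice_loss(y_true, y_pred, smooth=1e-6):
    """1 - soft Dice overlap between targets and predicted probabilities."""
    intersection = sum(t * p for t, p in zip(y_true, y_pred))
    return 1 - (2 * intersection + smooth) / (sum(y_true) + sum(y_pred) + smooth)

def combined(y_true, y_pred):
    """Equal-weight mix of the two losses."""
    return 0.5 * bce(y_true, y_pred) + 0.5 * dice_loss(y_true, y_pred)

y_true = [1, 1, 0, 0]
good = [0.9, 0.8, 0.2, 0.1]   # Confident, mostly correct
bad = [0.1, 0.1, 0.1, 0.1]    # Predicts background everywhere
print(combined(y_true, good), combined(y_true, bad))
```

The all-background prediction is penalized by both terms; the Dice term in particular stays high whenever vessel pixels are missed, which is the motivation for adding it when the foreground class is sparse.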

In [None]:
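# Configure logging
def setup_logging(dataset_name):
    logging.basicConfig(
        filename=f'training_log_{dataset_name}.txt',
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        force=True  # Replace handlers from earlier calls so each dataset gets its own log file (Python 3.8+)
    )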

def train_model(model, train_patches, train_labels, val_patches, val_labels, dataset_name, r2unet_func):
    """Train R2U-Net and save checkpoints and final model."""
    checkpoint_path = f'r2unet_{dataset_name}_checkpoint.weights.h5'  # Use .weights.h5
    final_model_path = f'r2unet_{dataset_name}_final.keras'
    dice_checkpoint_path = f'r2unet_{dataset_name}_checkpoint_dice.keras'  # Added for clarity
    
    # Setup logging
    setup_logging(dataset_name)
    logging.info(f"Starting training for {dataset_name}")
    
    # Check for existing checkpoint and load if available
    initial_epoch = 0
    if os.path.exists(checkpoint_path):
        try:
            model.load_weights(checkpoint_path)
            logging.info(f"Loaded checkpoint from {checkpoint_path}")
            print(f"Resuming training from checkpoint: {checkpoint_path}")
            initial_epoch = 0  # Adjust based on saved epoch if metadata is available
        except Exception as e:
            logging.error(f"Failed to load checkpoint {checkpoint_path}: {str(e)}")
            print(f"Warning: Failed to load checkpoint {checkpoint_path}. Starting from scratch.")
    
    # Define callbacks
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_path, 
        save_best_only=True, 
        monitor='val_loss', 
        mode='min', 
        save_weights_only=True
    )
    dice_checkpoint = DiceCoefficientCallback(
        dice_checkpoint_path, 
        val_patches=val_patches, 
        val_labels=val_labels, 
        monitor='val_dice_coefficient', 
        mode='max'
    )
    early_stopping = tf.keras.callbacks.EarlyStopping(
        patience=15, 
        restore_best_weights=True, 
        monitor='val_loss', 
        mode='min'
    )
    lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(
        factor=0.5, 
        patience=5, 
        min_lr=1e-6, 
        monitor='val_loss', 
        mode='min'
    )
    
    # # Data augmentation
    # datagen = ImageDataGenerator(rotation_range=10, horizontal_flip=True, fill_mode='nearest')
    # datagen.fit(train_patches)
    
    # Compile model with custom loss
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=2e-4),
        loss=combined_loss,
        metrics=['accuracy', tf.keras.metrics.MeanSquaredError(), dice_coefficient]
    )
    
    # Train model
    try:
        # history = model.fit(
        #     datagen.flow(train_patches, train_labels, batch_size=8),
        #     validation_data=(val_patches, val_labels),
        #     epochs=80,  # Align with provided code
        #     initial_epoch=initial_epoch,
        #     callbacks=[checkpoint, dice_checkpoint, early_stopping, lr_scheduler]
        # )

        history = model.fit(
            x=train_patches,
            y=train_labels,
            validation_data=(val_patches, val_labels),
            batch_size=8,
            epochs=50,
            initial_epoch=initial_epoch,
            callbacks=[checkpoint, dice_checkpoint, early_stopping, lr_scheduler]
        )
        
        # Log training metrics
        for epoch in range(len(history.history['loss'])):
            log_message = (
                f"Epoch {epoch + 1}: "
                f"loss={history.history['loss'][epoch]:.4f}, "
                f"val_loss={history.history['val_loss'][epoch]:.4f}, "
                f"val_dice_coefficient={history.history['val_dice_coefficient'][epoch]:.4f}"
            )
            logging.info(log_message)
        
        # Save final model
        model.save(final_model_path)
        logging.info(f"Saved final model to {final_model_path}")
        print(f"Saved final model to {final_model_path}")

        # Evaluate all saved models on validation set
        print(f"\n\nEvaluating saved models for {dataset_name} on validation set...")
        saved_models = [
            ('Dice Checkpoint', dice_checkpoint_path, True),   # Full model
            ('Weights Checkpoint', checkpoint_path, False),    # Weights only
            ('Final Model', final_model_path, True)            # Full model
        ]

        for model_name, path, is_full_model in saved_models:
            try:
                if is_full_model:
                    eval_model = tf.keras.models.load_model(
                        path,
                        custom_objects={'combined_loss': combined_loss, 'dice_coefficient': dice_coefficient}
                    )
                    print(f"Loaded full model from {path}")
                else:
                    eval_model = r2unet_func()  # Recreate model using provided function
                    eval_model.compile(
                        optimizer=tf.keras.optimizers.Adam(learning_rate=2e-4),
                        loss=combined_loss,
                        metrics=['accuracy', tf.keras.metrics.MeanSquaredError(), dice_coefficient]
                    )
                    eval_model.load_weights(path)
                    print(f"Loaded weights from {path}")

                # Predict on validation patches
                val_pred = eval_model.predict(val_patches, verbose=0)

                # Compute metrics
                metrics = compute_metrics(val_labels, val_pred)
                print(f"\n{dataset_name} {model_name} Validation Metrics:")
                for metric_name, value in metrics.items():
                    print(f"{metric_name}: {value:.4f}")

                # Clear session to free memory
                tf.keras.backend.clear_session()

            except FileNotFoundError:
                print(f"Error: {model_name} file {path} not found. Skipping evaluation.")
            except Exception as e:
                print(f"Error evaluating {model_name} from {path}: {str(e)}")

        return history
    
    except Exception as e:
        logging.error(f"Training failed for {dataset_name}: {str(e)}")
        print(f"Error: Training failed for {dataset_name}: {str(e)}")
        raise

# Train on each dataset
# datasets = {
#     'DRIVE': (drive_patches_train, drive_labels_train, drive_patches_val, drive_labels_val),
#     'STARE': (stare_patches_train, stare_labels_train, stare_patches_val, stare_labels_val),
#     'CHASEDB1': (chase_patches_train, chase_labels_train, chase_patches_val, chase_labels_val),
#     'HRF': (hrf_patches_train, hrf_labels_train, hrf_patches_val, hrf_labels_val)
# }

# Train a separate model per dataset (only DRIVE is enabled here; the commented dict above lists all four)
datasets = {
    'DRIVE': (drive_patches_train, drive_labels_train, drive_patches_val, drive_labels_val)
}

models_dict = {}
histories = {}
for dataset_name, (train_patches, train_labels, val_patches, val_labels) in datasets.items():
    print(f"Training on {dataset_name}...")
    model = r2unet()
    # Compile model with custom loss and metrics
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=2e-4),
        loss=combined_loss,
        metrics=['accuracy', tf.keras.metrics.MeanSquaredError(), dice_coefficient]
    )
    histories[dataset_name] = train_model(model, train_patches, train_labels, val_patches, val_labels, dataset_name, r2unet_func=r2unet)
    models_dict[dataset_name] = model
    
    # Plot training history
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(histories[dataset_name].history['loss'], label='Train Loss')
    plt.plot(histories[dataset_name].history['val_loss'], label='Val Loss')
    plt.title(f'{dataset_name} Loss')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(histories[dataset_name].history['accuracy'], label='Train Accuracy')
    plt.plot(histories[dataset_name].history['val_accuracy'], label='Val Accuracy')
    plt.title(f'{dataset_name} Accuracy')
    plt.legend()
    plt.show()

## Step 6: Evaluate and Visualize Results

We evaluate each trained model on the test set of its dataset, computing the metrics AC (accuracy), SE (sensitivity), SP (specificity), F1, DC (Dice coefficient), JS (Jaccard similarity), and AUC, and visualizing sample predictions.

**Output**: For each dataset, displays the first five test samples (input, ground truth, prediction) and prints the average metrics (AC, SE, SP, F1, DC, JS, AUC).
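
As a rough illustration of what these abbreviations measure, the sketch below derives the threshold-based metrics from pixel-wise confusion counts. It is not the notebook's `compute_metrics`; `sketch_metrics`, its `threshold` argument, and the `eps` smoothing term are hypothetical names introduced only for this example. AUC is left out because it is computed from the continuous scores rather than from thresholded predictions (for instance with scikit-learn's `roc_auc_score`).

```python
import numpy as np

def sketch_metrics(y_true, y_prob, threshold=0.5, eps=1e-7):
    """Hypothetical helper: pixel-wise metrics from a mask and a probability map."""
    # Binarize both inputs; "> 0.5" works for ground truth stored as 0/1 or 0/255.
    y_true = np.asarray(y_true).ravel() > 0.5
    y_pred = np.asarray(y_prob).ravel() > threshold

    # Confusion counts over all pixels
    tp = np.sum(y_true & y_pred)
    tn = np.sum(~y_true & ~y_pred)
    fp = np.sum(~y_true & y_pred)
    fn = np.sum(y_true & ~y_pred)

    precision = tp / (tp + fp + eps)
    se = tp / (tp + fn + eps)                      # sensitivity (recall)
    return {
        'AC': (tp + tn) / (tp + tn + fp + fn + eps),  # accuracy
        'SE': se,
        'SP': tn / (tn + fp + eps),                   # specificity
        'F1': 2 * precision * se / (precision + se + eps),
        'DC': 2 * tp / (2 * tp + fp + fn + eps),      # Dice coefficient
        'JS': tp / (tp + fp + fn + eps),              # Jaccard similarity
    }

# Toy example: four pixels, two vessel and two background
example = sketch_metrics([1, 1, 0, 0], [0.9, 0.2, 0.8, 0.1])
print(example)
```

For binary masks, F1 and DC are algebraically the same quantity, so they should agree up to the smoothing term.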

In [None]:
def evaluate_model(model, test_images, test_labels, dataset_name):
    """Evaluate model on test images using patch-based approach and visualize results."""
    metrics_list = []
    for i, (img, label) in enumerate(zip(test_images, test_labels)):
        # Ensure input has channel dimension if missing
        if len(img.shape) == 2:
            img = img[..., np.newaxis]
        if len(label.shape) == 2:
            label = label[..., np.newaxis]
        
        # Extract patches
        patches, h, w = extract_patches_for_eval(img)
        if len(patches) == 0:
            print(f"Warning: No patches extracted for {dataset_name} sample {i+1}. Skipping.")
            continue
        
        # Predict on patches
        pred_patches = model.predict(patches, verbose=0)
        
        # Reconstruct predicted image
        pred = reconstruct_image(pred_patches, h, w)
        
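        # Ensure label and pred have compatible shapes
        if pred.ndim == 2:
            pred = pred[..., np.newaxis]
        if pred.shape != label.shape:
            pred = cv2.resize(pred, (label.shape[1], label.shape[0]), interpolation=cv2.INTER_NEAREST)
            # cv2.resize drops a trailing singleton channel; restore it for the indexing below
            if pred.ndim == 2:
                pred = pred[..., np.newaxis]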
        
        # Compute metrics
        metrics = compute_metrics(label, pred)
        metrics_list.append(metrics)
        
        # Visualize first few samples (full images)
        if i < 5:
            plt.figure(figsize=(12, 4))
            plt.subplot(1, 3, 1)
            plt.imshow(img[:, :, 0], cmap='gray')
            plt.title('Input Image')
            plt.subplot(1, 3, 2)
            plt.imshow(label[:, :, 0], cmap='gray')
            plt.title('Ground Truth')
            plt.subplot(1, 3, 3)
            plt.imshow(pred[:, :, 0] > 0.5, cmap='gray')
            plt.title('Prediction')
            plt.suptitle(f'{dataset_name} Test Sample {i+1}')
            plt.tight_layout()
            plt.show()
    
    # Average metrics
    if metrics_list:
        avg_metrics = {key: np.mean([m[key] for m in metrics_list]) for key in metrics_list[0]}
    else:
        avg_metrics = {key: 0.0 for key in ['AC', 'SE', 'SP', 'F1', 'DC', 'JS', 'AUC']}
    return avg_metrics


# To evaluate on all four datasets, uncomment this mapping instead:
# test_datasets = {
#     'DRIVE': (drive_test_images, drive_test_labels),
#     'STARE': (stare_test_images, stare_test_labels),
#     'CHASEDB1': (chase_test_images, chase_test_labels),
#     'HRF': (hrf_test_images, hrf_test_labels)
# }

# Evaluate on each dataset using saved models
test_datasets = {
    'DRIVE': (drive_test_images, drive_test_labels)
}

for dataset_name, (test_images, test_labels) in test_datasets.items():
    
    print(f"Evaluating on {dataset_name}...")
    
    # Saved files, tried in order of preference:
    # 1) r2unet_DRIVE_checkpoint_dice.keras: full model saved by the DiceCoefficientCallback
    #    whenever the validation Dice coefficient (val_dice_coefficient) improves.
    # 2) r2unet_DRIVE_checkpoint.weights.h5: weights saved by the ModelCheckpoint callback
    #    whenever the validation loss (val_loss) improves.
    # 3) r2unet_DRIVE_final.keras: full model saved at the end of training (after 2 epochs),
    #    regardless of performance.

    
    dice_checkpoint_path = f'r2unet_{dataset_name}_checkpoint_dice.keras'
    weights_checkpoint_path = f'r2unet_{dataset_name}_checkpoint.weights.h5'
    final_model_path = f'r2unet_{dataset_name}_final.keras'
    custom_objects = {'combined_loss': combined_loss, 'dice_coefficient': dice_coefficient}

    # Check for each file explicitly: depending on the Keras version, a missing
    # model file may raise OSError or ValueError rather than FileNotFoundError.
    if tf.io.gfile.exists(dice_checkpoint_path):
        model = tf.keras.models.load_model(dice_checkpoint_path, custom_objects=custom_objects)
        print(f"Loaded best model from {dice_checkpoint_path}")
    elif tf.io.gfile.exists(weights_checkpoint_path):
        print(f"Model file {dice_checkpoint_path} not found. Trying weights checkpoint...")
        model = r2unet()  # Rebuild the architecture, then restore the saved weights
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=2e-4),
            loss=combined_loss,
            metrics=['accuracy', tf.keras.metrics.MeanSquaredError(), dice_coefficient]
        )
        model.load_weights(weights_checkpoint_path)
        print(f"Loaded weights from {weights_checkpoint_path}")
    elif tf.io.gfile.exists(final_model_path):
        print(f"Weights file {weights_checkpoint_path} not found. Trying final model...")
        model = tf.keras.models.load_model(final_model_path, custom_objects=custom_objects)
        print(f"Loaded final model from {final_model_path}")
    else:
        print(f"Error: No saved model found for {dataset_name}. Please train and save the model first.")
        continue

    # Evaluate whichever model was loaded
    metrics = evaluate_model(model, test_images, test_labels, dataset_name)
    print(f"{dataset_name} Metrics: {metrics}")