# **Retinal Blood Vessel Segmentation with Attention U-Net**

# 1. Setup and data preparation


*   Connected to Google Drive and downloaded the DRIVE dataset from Kaggle.

*   Data Organization: Created a specific directory in Google Drive and copied the downloaded dataset there

In [None]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import random
from PIL import Image
from google.colab import drive
import glob
import itertools
import kagglehub
import shutil
import tqdm

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import SGD
import tensorflow.keras.backend as K
from tensorflow.keras.preprocessing.image import load_img
import cv2


In [None]:
# Mount Google Drive so the /content/drive/MyDrive/... paths used below exist.
_drive_mountpoint = '/content/drive'
drive.mount(_drive_mountpoint)


In [None]:
def _replace_dir_with_copy(src, dst):
    """Delete *dst* if it exists, recreate it empty, then copy the tree at *src* into it."""
    if os.path.exists(dst):
        shutil.rmtree(dst)
    os.makedirs(dst)
    # dst now exists (empty), so copytree needs dirs_exist_ok=True.
    shutil.copytree(src, dst, dirs_exist_ok=True)


# Fetch the DRIVE 2004 dataset from Kaggle; dataset_download returns the local path.
path = kagglehub.dataset_download("zionfuo/drive2004")
print("Path to dataset files:", path)

# Keep a fresh copy of the dataset in Google Drive.
target_dir = "/content/drive/MyDrive/DRIVE_2004"
_replace_dir_with_copy(path, target_dir)
print(f"Dataset copied to: {target_dir}")


# 2. Helpers



*   **Format Conversion**: Developed a utility function to convert GIF image files to TIF format, addressing potential compatibility issues with other libraries.\
Functions: `convert_gif_to_tif` (uses `PIL.Image`).

*   **Image Loading**: Created a function to load image files, specifically extracting the green channel (though later modified to load as grayscale).\
Functions: `load_image` (uses `cv2.imread`).

*   **Mask Loading**: Implemented a function to load mask files and process them into binary representations.\
Functions: `load_mask` (uses `cv2.imread`).

In [None]:
def convert_gif_to_tif(src_folder, dest_folder):
    """Re-save every ``.gif`` in *src_folder* as a same-named ``.tif`` in *dest_folder*.

    *dest_folder* is created if missing. The extension match is
    case-insensitive. Each source GIF is deleted after its TIF has been
    written, so calling this again on an already-converted folder is a no-op.
    """
    if not os.path.exists(dest_folder):
        os.makedirs(dest_folder)

    gif_names = [name for name in os.listdir(src_folder) if name.lower().endswith('.gif')]
    for gif_name in gif_names:
        gif_path = os.path.join(src_folder, gif_name)
        tif_path = os.path.join(dest_folder, os.path.splitext(gif_name)[0] + '.tif')
        with Image.open(gif_path) as img:
            img.save(tif_path)
        # Delete only after the with-block has closed the source file.
        os.remove(gif_path)


Optional CLAHE (contrast-limited adaptive histogram equalization) preprocessing. It is left disabled: it distorted the input images and biased the model toward many more false positives.

In [None]:
# create your CLAHE object once
# _clahe = cv2.createCLAHE( clipLimit=2.0, tileGridSize=(8,8) )


In [None]:
def load_image(filepath):
    """
    Load an image from disk as a single-channel grayscale array.

    The file is read with ``cv2.IMREAD_GRAYSCALE``, so OpenCV converts colour
    inputs to one intensity channel. (An earlier version extracted the green
    channel instead; it was replaced by this grayscale load.)

    Parameters
    ----------
    filepath : str
        Path to an image file OpenCV can decode (e.g. the DRIVE ``.tif`` files).

    Returns
    -------
    numpy.ndarray
        2-D array of shape (H, W), typically ``uint8``.

    Raises
    ------
    ValueError
        If OpenCV cannot read the file (``cv2.imread`` returned ``None``).
    """
    img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise ValueError(f"Could not load image: {filepath}")
    return img


# def load_image(filepath, target_size=(512,512)):
#     img = cv2.imread(filepath, cv2.IMREAD_COLOR)
#     if img is None:
#         raise ValueError(f"Could not load image: {filepath}")
#     # 1) resize
#     img = cv2.resize(img, target_size, interpolation=cv2.INTER_AREA)
#     # 2) extract green channel
#     green = img[..., 1]
#     return green[..., np.newaxis]  # shape=(512,512,1)


In [None]:
def load_mask(filepath):
    """Read a vessel annotation and binarise it.

    Returns a ``uint8`` array of shape (H, W, 1) holding 1 where the first
    image plane is above 128 and 0 elsewhere.

    Raises ValueError if OpenCV cannot read the file.
    """
    bgr = cv2.imread(filepath, cv2.IMREAD_COLOR)
    if bgr is None:
        raise ValueError(f"Could not load image: {filepath}")
    # IMREAD_COLOR yields BGR planes; keep the first one. For a grayscale
    # annotation OpenCV replicates the same values into every plane.
    first_plane = bgr[..., 0] if bgr.ndim == 3 else bgr
    binary = (first_plane > 128).astype(np.uint8)
    return binary[..., np.newaxis]


# def load_mask(filepath):
#     """
#     Load a mask using PIL (to support .gif files), convert to a numpy array,
#     and threshold it to create a binary mask.
#     """
#     with Image.open(filepath) as img:
#         mask = np.array(img)
#     # If the mask has 3 channels, take the first channel
#     if mask.ndim == 3:
#         mask = mask[:, :, 0]
#     # Convert to binary mask (assuming values >128 are foreground)
#     mask = (mask > 128).astype(np.uint8)
#     return mask


# 3. Loading



*   Converted GIF files in the training and test mask directories to TIF format.\
`convert_gif_to_tif()` was called on `train_mask_dir` and `test_mask_dir`.

*   Tested loading a sample training image and its corresponding mask.\
`load_image()` and `load_mask()` were used to load them, and
`plt.imshow()` was used to display the loaded images.
The output shows the shapes of the loaded training image `(584, 565)` and mask `(584, 565, 1)`, confirming successful loading.



    
        
    
        


In [None]:
# DRIVE layout: <root>/{training,test}/{images,1st_manual}/
# The trailing "/" is kept so file names can be appended directly.
_drive_root = "/content/drive/MyDrive/DRIVE_2004/DRIVE"

train_image_dir = f"{_drive_root}/training/images/"
train_mask_dir  = f"{_drive_root}/training/1st_manual/"

test_image_dir = f"{_drive_root}/test/images/"
test_mask_dir  = f"{_drive_root}/test/1st_manual/"


In [None]:
# OpenCV cannot read GIF files, so re-save the manual annotations as TIF in place.
for _mask_dir in (train_mask_dir, test_mask_dir):
    convert_gif_to_tif(_mask_dir, _mask_dir)

In [None]:
# Sanity check: load one training image and display it.
train_image_path = os.path.join(train_image_dir, '22_training.tif')
train_image = load_image(train_image_path)

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(train_image, cmap='gray')
ax.set_title("Training Image Sample")
ax.axis("off")
plt.show()
train_image.shape

In [None]:
# Sanity check: load the matching manual annotation and display it.
mask_path = os.path.join(train_mask_dir, '22_manual1.tif')
mask = load_mask(mask_path)

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(mask, cmap='gray')
ax.set_title("Manual Annotation Mask")
ax.axis("off")
plt.show()
mask.shape

# 4. U-Net



*   Defined `double_conv_block` function for two convolutional layers with ReLU activation.\
`layers.Conv2D()` with `kernel_size=3`, `padding="same"`, `activation="relu"`

*   Defined `downsample_block` function which applies the double convolution block, followed by MaxPooling and Dropout.\
`layers.MaxPooling2D(pool_size=(2, 2))`, `layers.Dropout(0.2)`

*   Defined `upsample_block` function for the decoder path which performs transposed convolution, concatenates with skip features, and applies the double convolution block.\
`layers.Conv2DTranspose()`, `layers.concatenate()`



    
        
    
        
    
        


In [None]:
def double_conv_block(x, n_filters):
    """Apply two successive 3x3, 'same'-padded, ReLU convolutions with *n_filters* each."""
    for _ in range(2):
        x = layers.Conv2D(n_filters, kernel_size=3, padding="same", activation="relu")(x)
    return x


In [None]:
def downsample_block(x, n_filters):
    """Encoder step: double conv, then 2x2 max-pooling and 20% dropout.

    Returns ``(features, pooled)``. *features* (pre-pooling) feeds the skip
    connection; *pooled* feeds the next encoder level.
    """
    features = double_conv_block(x, n_filters)
    pooled = layers.MaxPooling2D(pool_size=(2, 2))(features)
    # Dropout was added after observing overfitting.
    pooled = layers.Dropout(0.2)(pooled)
    return features, pooled


In [None]:
def upsample_block(x, conv_features, n_filters):
    """Decoder step: 2x upsampling, skip concatenation, then a double conv.

    The 3x3 stride-2 transposed convolution doubles H and W of *x*. The result
    is concatenated along the channel axis with the encoder map
    *conv_features* before the double conv.
    """
    upsampled = layers.Conv2DTranspose(n_filters, kernel_size=3, strides=2, padding="same")(x)
    merged = layers.concatenate([upsampled, conv_features], axis=-1)
    return double_conv_block(merged, n_filters)


In [None]:
def Unet(input_shape=(128, 128, 1)):
    """
    Build a plain U-Net (no attention gates) for binary vessel segmentation.

    Parameters
    ----------
    input_shape : tuple of int
        (H, W, C) of the input patches. Defaults to 128x128 single-channel
        patches. H and W should be divisible by 16: the encoder halves the
        resolution four times, and each decoder step doubles it and
        concatenates with the matching encoder features, so odd sizes would
        not line up.

    Returns
    -------
    tf.keras.Model
        Model named "U-Net" whose output has one channel with a sigmoid,
        i.e. a per-pixel foreground probability at the input resolution.
    """
    inputs = layers.Input(shape=input_shape)

    # Encoder - contracting path (spatial sizes shown for the default 128x128 input)
    f1, p1 = downsample_block(inputs, 64)   # 128 -> 64
    f2, p2 = downsample_block(p1, 128)      # 64 -> 32
    f3, p3 = downsample_block(p2, 256)      # 32 -> 16
    f4, p4 = downsample_block(p3, 512)      # 16 -> 8

    # Bottleneck
    b = double_conv_block(p4, 1024)         # 8x8

    # Decoder - expanding path, concatenating the matching encoder features
    u6 = upsample_block(b, f4, 512)         # 8 -> 16
    u7 = upsample_block(u6, f3, 256)        # 16 -> 32
    u8 = upsample_block(u7, f2, 128)        # 32 -> 64
    u9 = upsample_block(u8, f1, 64)         # 64 -> 128

    # Output layer: 1x1 conv to a single sigmoid channel (binary segmentation)
    outputs = layers.Conv2D(1, 1, activation="sigmoid")(u9)

    return tf.keras.Model(inputs, outputs, name="U-Net")


# 5. Attention



*   Implemented `gating_signal` for the attention mechanism.\

    `layers.Conv2D()`, `layers.BatchNormalization()`, `layers.Activation('relu')`


*   Implemented `attention_block` to apply additive attention to filter skip connections\

    `layers.Conv2D()`, `layers.Conv2DTranspose()`, `layers.add()`, `layers.Activation('relu')`, `layers.Conv2D(1, ...)` (for attention coefficients), `layers.UpSampling2D()`, `layers.Lambda()` (for tiling), `layers.multiply()`, `layers.Conv2D(..., 1, ...)` (for final projection), `layers.BatchNormalization()`


*   Modified the `upsample_block` to incorporate the attention mechanism by using `gating_signal` and `attention_block`.\


*   Defined the `Unet` model incorporating the attention blocks in the decoder path.\

    *    `layers.Input()`
    *    Sequential application of `downsample_block`, `double_conv_block` (bottleneck), and `upsample_block`.
    *    Final `layers.Conv2D(1, 1, activation="sigmoid")` for the output layer.


In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras import backend as K

# --------------------------------------------------
#   Base convolutional block: two 3x3 conv layers
# --------------------------------------------------
def double_conv_block(x, n_filters):
    """Stack two 3x3 'same'-padded ReLU convolutions, each with *n_filters* filters."""
    for _ in range(2):
        x = layers.Conv2D(n_filters, kernel_size=3, padding="same", activation="relu")(x)
    return x

# --------------------------------------------------
#   Attention components
# --------------------------------------------------

def gating_signal(x, out_size):
    '''Turn decoder features into a gating signal: 1x1 conv -> BatchNorm -> ReLU.'''
    projected = layers.Conv2D(out_size, kernel_size=1, padding='same')(x)
    normalised = layers.BatchNormalization()(projected)
    return layers.Activation('relu')(normalised)


def attention_block(x, gating, inter_channels):
    '''Additive attention gate that re-weights skip-connection features.

    Args:
        x: encoder (skip) feature map; its static H, W and channel count are
           read with K.int_shape, so they must be known when the graph is built.
        gating: gating signal computed from the decoder features.
        inter_channels: channel width of the intermediate projections.

    Returns:
        Tensor with the same spatial size and channel count as x: x multiplied
        pixel-wise by a sigmoid attention map, followed by a 1x1 conv and
        BatchNorm.

    NOTE: the statement order below fixes the layer creation order. Keep it
    unchanged if existing .h5 weight files must still load.
    '''
    # Project encoder features (x) down: the stride-2 conv halves H and W
    theta_x = layers.Conv2D(inter_channels, kernel_size=2, strides=2, padding='same')(x)

    # Project gating signal
    phi_g = layers.Conv2D(inter_channels, kernel_size=1, padding='same')(gating)
    # Upsample gating to match theta_x's spatial dimensions. The stride is the
    # integer size ratio; when called from upsample_block the gating already has
    # theta_x's resolution, so these strides presumably evaluate to 1.
    shape_theta = K.int_shape(theta_x)
    shape_phi = K.int_shape(phi_g)
    up_phi = layers.Conv2DTranspose(
        inter_channels,
        kernel_size=3,
        strides=(shape_theta[1] // shape_phi[1], shape_theta[2] // shape_phi[2]),
        padding='same'
    )(phi_g)

    # Fuse and activate (additive attention: despite the variable name this is
    # an element-wise add, not a concatenation)
    concat_xg = layers.add([up_phi, theta_x])
    act_xg = layers.Activation('relu')(concat_xg)

    # Generate one attention coefficient per pixel in [0, 1]
    psi = layers.Conv2D(1, kernel_size=1, padding='same')(act_xg)
    sigmoid_xg = layers.Activation('sigmoid')(psi)

    # Upsample coefficients back to x's spatial size
    shape_sig = K.int_shape(sigmoid_xg)
    shape_x = K.int_shape(x)
    up_psi = layers.UpSampling2D(
        size=(shape_x[1] // shape_sig[1], shape_x[2] // shape_sig[2])
    )(sigmoid_xg)

    # Tile the single coefficient channel across all channels of x
    up_psi = layers.Lambda(
        lambda z, rep: K.repeat_elements(z, rep, axis=3),
        arguments={'rep': shape_x[3]}
    )(up_psi)

    # Apply attention to encoder features
    y = layers.multiply([up_psi, x])

    # Final linear projection back to x's channel count, then BatchNorm
    result = layers.Conv2D(shape_x[3], kernel_size=1, padding='same')(y)
    result = layers.BatchNormalization()(result)
    return result

# --------------------------------------------------
#   Downsampling block remains unchanged
# --------------------------------------------------

def downsample_block(x, n_filters):
    """Encoder step: double conv -> 2x2 max-pool -> 20% dropout.

    Returns ``(features, pooled)``: *features* is kept for the skip
    connection and *pooled* continues down the encoder.
    """
    features = double_conv_block(x, n_filters)
    pooled = layers.MaxPooling2D(pool_size=(2, 2))(features)
    return features, layers.Dropout(0.2)(pooled)

# --------------------------------------------------
#   Upsampling block with attention
# --------------------------------------------------

def upsample_block(x, conv_features, n_filters):
    """Attention decoder step.

    A gating signal from the decoder map *x* (taken before upsampling) filters
    the skip features *conv_features* through ``attention_block``. *x* is
    upsampled 2x with a transposed conv, concatenated with the gated skip
    features, and passed through a double conv.
    """
    gate = gating_signal(x, n_filters)
    upsampled = layers.Conv2DTranspose(n_filters, kernel_size=3, strides=2, padding="same")(x)
    gated_skip = attention_block(conv_features, gate, inter_channels=n_filters)
    merged = layers.concatenate([upsampled, gated_skip], axis=-1)
    return double_conv_block(merged, n_filters)

# --------------------------------------------------
#   U-Net with Attention Gates
# --------------------------------------------------

def Unet(input_shape=(128, 128, 1)):
    """Build the Attention U-Net.

    It has four encoder levels (64-512 filters), a 1024-filter bottleneck,
    and four attention-gated decoder levels. The output is a 1-channel
    sigmoid map. Layers are created in the same order as the
    straight-line version (encoder, bottleneck, decoder), so saved weights
    stay compatible.
    """
    inputs = layers.Input(shape=input_shape)

    level_filters = (64, 128, 256, 512)
    skips = []
    x = inputs
    for n in level_filters:                      # encoder
        skip, x = downsample_block(x, n)
        skips.append(skip)

    x = double_conv_block(x, 1024)               # bottleneck

    for skip, n in zip(reversed(skips), reversed(level_filters)):   # decoder
        x = upsample_block(x, skip, n)

    outputs = layers.Conv2D(1, kernel_size=1, activation="sigmoid")(x)
    return models.Model(inputs, outputs, name="Attention_U-Net")

# 6. Patches



*   Using patches preserves thin vessel structures and increases the number of training samples. **Because the vessel annotations are very thin, their information gets distorted, and is sometimes lost, when the whole image is resized.** Rather than warping the whole image, we extract many overlapping 128×128 patches from the 584×565 frames. That means:

  *    No global resize → vessel shapes stay true.

  *    We get more training samples.

  *    During inference we tile-and-stitch back together.



*   Defined `extract_patches` function to extract overlapping patches from full-size images and masks.

  *    Uses nested loops to slide a window of `patch_size` with a specified `stride`.
  *    Uses `np.stack()` to combine patches into a single array.

*   Defined `load_and_prep_image` and `load_and_prep_mask` to load full-resolution images and masks and ensure they have a channel dimension.
  *   `load_image()`, `load_mask() `
  *   `np.newaxis` to add a channel dimension


In [None]:
def extract_patches(img, mask, patch_size=128, stride=8, cover_border=False):
    """
    Cut aligned, overlapping square patches out of an image and its mask.

    A ``patch_size`` x ``patch_size`` window slides over both arrays with step
    ``stride``. Windows are taken row by row (y outer, x inner), at top-left
    corners 0, stride, 2*stride, ... With ``cover_border=False`` (the
    original behaviour), the last ``(H - patch_size) % stride`` rows and
    ``(W - patch_size) % stride`` columns are never sampled. For example,
    a 565-px-wide image with the defaults leaves 5 columns out.
    ``cover_border=True`` appends one extra window flush with the
    bottom/right edge so every pixel is covered.

    Parameters
    ----------
    img, mask : numpy.ndarray
        Arrays of shape (H, W, C). They must share H and W.
    patch_size : int
        Side length of each square patch (> 0).
    stride : int
        Step between window origins (> 0).
    cover_border : bool
        Add an edge-aligned window on each axis when the stride leaves a remainder.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Patches of shape (N, patch_size, patch_size, C_img) and
        (N, patch_size, patch_size, C_mask). N is 0 (empty arrays, not an
        error) when the image is smaller than one patch.

    Raises
    ------
    ValueError
        If the inputs are not 3-D, their spatial sizes differ, or
        ``patch_size``/``stride`` is not positive.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if img.ndim != 3 or mask.ndim != 3:
        raise ValueError(f"img and mask must be (H, W, C); got {img.shape} and {mask.shape}")
    H, W = img.shape[:2]
    if mask.shape[:2] != (H, W):
        raise ValueError(f"img and mask spatial sizes differ: {img.shape[:2]} vs {mask.shape[:2]}")
    if patch_size <= 0 or stride <= 0:
        raise ValueError("patch_size and stride must be positive")

    def _starts(length):
        """Window start offsets along one axis of size *length*."""
        starts = list(range(0, length - patch_size + 1, stride))
        if cover_border and starts and starts[-1] != length - patch_size:
            starts.append(length - patch_size)
        return starts

    ys, xs = _starts(H), _starts(W)
    if not ys or not xs:
        # Image smaller than one patch: np.stack([]) would raise, so return empty batches.
        return (np.empty((0, patch_size, patch_size, img.shape[2]), dtype=img.dtype),
                np.empty((0, patch_size, patch_size, mask.shape[2]), dtype=mask.dtype))

    patches_img = [img[y:y + patch_size, x:x + patch_size] for y in ys for x in xs]
    patches_msk = [mask[y:y + patch_size, x:x + patch_size] for y in ys for x in xs]

    # Stack into (num_patches, patch_size, patch_size, C)
    return np.stack(patches_img, axis=0), np.stack(patches_msk, axis=0)

In [None]:
def load_and_prep_image(fp):
    """Load a full-resolution image as a (H, W, 1) uint8 array.

    Values stay in [0, 255]; normalisation to [0, 1] happens later.
    """
    img = load_image(fp)
    img = img[..., np.newaxis] if img.ndim == 2 else img
    return img.astype("uint8")


def load_and_prep_mask(fp):
    """Load a full-resolution binary mask as a (H, W, 1) uint8 array with values in {0, 1}.

    load_mask already appends the channel axis; the 2-D branch is only a safeguard.
    """
    msk = load_mask(fp)
    return msk[..., np.newaxis] if msk.ndim == 2 else msk


# 7. Train and Test


*   **Data Preparation**: Loaded file paths for training images and masks and then extracted patches from all training data using the defined `extract_patches` function.


*   **Data Formatting**: Concatenated all extracted patches and normalized the image data.
  *    Outputs: Printed the shape of the final training set, indicating a significant number of patches (e.g., (63800, 128, 128, 1)).

*   **Sample Visualization**: Implemented and used a function to display pairs of image patches and their corresponding masks, visually validating the patch extraction.
  *    Functions: `show_sample` (uses `matplotlib.pyplot`).
  *    Outputs: Displayed sample patches.


In [None]:
# 1. Gather the file lists (sorted so each image lines up with its mask).
train_imgs = sorted(glob.glob(os.path.join(train_image_dir, "*_training.tif")))
train_msks = sorted(glob.glob(os.path.join(train_mask_dir,  "*_manual1.tif")))


def _drive_id(fp):
    """Return the numeric id prefix of a DRIVE file name ('21_training.tif' -> '21')."""
    return os.path.basename(fp).split("_", 1)[0]


# zip() silently truncates or mis-pairs if the lists disagree, so check first.
if not train_imgs:
    raise FileNotFoundError(f"No '*_training.tif' files found in {train_image_dir}")
if [_drive_id(p) for p in train_imgs] != [_drive_id(p) for p in train_msks]:
    raise ValueError("Training images and masks do not pair up one-to-one by DRIVE id")

all_X, all_y = [], []

for img_fp, msk_fp in zip(train_imgs, train_msks):
    # a) load full-res image + mask
    img = load_and_prep_image(img_fp)    # (584,565,1) uint8
    msk = load_and_prep_mask(msk_fp)     # (584,565,1) uint8 in {0,1}

    # b) extract overlapping patches: (n_patches_image, 128, 128, 1) each
    Xp, yp = extract_patches(img, msk, patch_size=128, stride=8)

    all_X.append(Xp)
    all_y.append(yp)

# 2. Concatenate along the example axis -> (total_patches, 128, 128, 1)
X_train = np.concatenate(all_X, axis=0)
y_train = np.concatenate(all_y, axis=0)

# 3. Convert to float32 & normalize images (masks stay 0/1).
# NOTE: at stride 8 a 584x565 image yields 58 x 55 = 3190 patches, so the 20
# training images give 63,800 patches; as float32, X_train and y_train take
# roughly 4.2 GB of RAM each.
X_train = X_train.astype("float32") / 255.0
y_train = y_train.astype("float32")

print("Final training set:", X_train.shape, y_train.shape)
# Expected for the DRIVE training split: (63800, 128, 128, 1) for both arrays.


In [None]:
import matplotlib.pyplot as plt

def show_sample(X, y, index=0):
    """Show one input patch next to its mask (channel axis squeezed for display).

    X, y: arrays indexed by patch, e.g. (N, 128, 128, 1).
    index: which patch to show.
    """
    fig, axs = plt.subplots(1, 2, figsize=(6, 3))
    axs[0].imshow(X[index].squeeze(), cmap='gray')
    axs[0].set_title("Input Image")
    axs[0].axis("off")

    axs[1].imshow(y[index].squeeze(), cmap='gray')
    axs[1].set_title("Mask")
    axs[1].axis("off")

    plt.tight_layout()
    plt.show()

# Show 10 randomly chosen training patches. randrange(n) yields 0..n-1, whereas
# the previous randint(0, n) could return n and raise an IndexError.
for _ in range(10):
    show_sample(X_train, y_train, index=random.randrange(len(X_train)))


# 8. Loss function and training

1. Loss components

    Implemented **Binary Focal Loss**.
        FL = − α y (1−p)^γ log p − (1−α)(1−y) p^γ log(1−p)
        α = 0.9 gives vessels 9× weight; γ = 7 forces focus on hard pixels.
    Implemented **Dice coefficient** and **Dice loss (1 − Dice)**.
        Dice = (2 |Y∩Ŷ| + ε) / (|Y| + |Ŷ| + ε)
        Balances classes and rewards global overlap.

2. Combined loss

    Added the two terms.
        loss = Focal(α = 0.9, γ = 7) + (1 − Dice)
        Pixel-level hardness + structure-level overlap share the same optimum.

3. Training dynamics

    Early epochs: Focal dominates.
    Mid-late epochs: Dice ramps up; closes gaps and enforces continuity.
    Convergence: both terms → 0 only when every vessel pixel is captured and background is clean.


In [None]:
import tensorflow as tf

_EPS = 1e-6      # clip margin that keeps log() away from 0 in the focal loss
_SMOOTH = 1e-6   # additive smoothing used by dice_coef (avoids 0/0 on empty masks)

# ─────────────────────────────────────────────────────────────
# Binary Focal Loss
# ─────────────────────────────────────────────────────────────

def binary_focal_loss(alpha=0.25, gamma=2.0, from_logits=False):
    """
    Build a binary focal loss function for ``model.compile``.

        FL = - alpha * y_true * (1-p)^gamma * log(p)
             - (1-alpha) * (1-y_true) * p^gamma * log(1-p)

    Parameters
    ----------
    alpha : float
        Weight of pixels where y_true == 1; all other pixels get 1 - alpha.
    gamma : float
        Focusing exponent; larger values down-weight well-classified pixels.
    from_logits : bool
        If True, ``y_pred`` is passed through a sigmoid first.

    Returns
    -------
    callable
        ``loss_fn(y_true, y_pred)`` returning the mean focal loss over all
        pixels and batch elements (a scalar tensor).
    """
    def loss_fn(y_true, y_pred):
        # If logits, convert to probabilities:
        if from_logits:
            y_pred = tf.sigmoid(y_pred)

        # Clip to avoid log(0) instabilities
        y_pred = tf.clip_by_value(y_pred, _EPS, 1.0 - _EPS)
        # Masks may arrive as integers (e.g. uint8) or with a different float
        # dtype; cast so the arithmetic below does not fail on a dtype mismatch.
        y_true = tf.cast(y_true, y_pred.dtype)

        # per-pixel binary cross-entropy
        ce = -(y_true * tf.math.log(y_pred)
               + (1.0 - y_true) * tf.math.log(1.0 - y_pred))

        # p_t = p where y_true == 1, else 1 - p
        is_pos = tf.equal(y_true, 1.0)
        p_t = tf.where(is_pos, y_pred, 1.0 - y_pred)

        # focal weights, built in y_pred's dtype (e.g. float16 under mixed precision)
        alpha_t = tf.cast(alpha, y_pred.dtype)
        alpha_factor = tf.where(is_pos, alpha_t, 1.0 - alpha_t)
        modulating_factor = tf.pow(1.0 - p_t, gamma)

        # average over all pixels & batch
        return tf.reduce_mean(alpha_factor * modulating_factor * ce)

    return loss_fn

In [None]:
# Combined: focal + dice
def dice_coef(y_true, y_pred):
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    inter = tf.reduce_sum(y_true_f * y_pred_f)
    return (2.*inter + _SMOOTH)/(tf.reduce_sum(y_true_f)+tf.reduce_sum(y_pred_f)+_SMOOTH)



*   **Defined Keras Callbacks**: `ModelCheckpoint` to save the best model based on validation accuracy, and `EarlyStopping` to stop training when validation accuracy stops improving. (In the final training run only `ModelCheckpoint` was passed to `fit`.)

  *   `tf.keras.callbacks.ModelCheckpoint(...)`
  *   `tf.keras.callbacks.EarlyStopping(...)`


In [None]:
import os
import tensorflow as tf

# ModelCheckpoint writes into ./checkpoints; create the folder up front so
# saving does not depend on Keras creating it.
os.makedirs('checkpoints', exist_ok=True)

# 1. ModelCheckpoint: keep only the best full model, judged by validation accuracy.
#    NOTE(review): pixel accuracy is dominated by the (presumably majority)
#    background class; val_loss or a Dice metric may track vessel quality better.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath='checkpoints/best_model.h5',
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
    save_weights_only=False,
    verbose=1
)

# 2. EarlyStopping: stop after 3 epochs without a val_accuracy improvement and
#    restore the best weights. Only takes effect if it is included in the
#    `callbacks` list passed to `fit`.
earlystop_cb = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy',
    patience=3,
    mode='max',
    restore_best_weights=True,
    verbose=1
)


Compiled the model using the Adam optimizer and a combined loss function (Binary Focal Loss + Dice Loss).

In [None]:
unet = Unet()

# Build the focal-loss closure once instead of inside every loss call.
_focal_loss = binary_focal_loss(alpha=0.9, gamma=7.0)


def focal_dice_loss(yt, yp):
    """Combined loss: binary focal loss (alpha=0.9, gamma=7) + Dice loss (1 - Dice).

    A named function (rather than a lambda) gets a readable name in Keras
    output and can be passed via custom_objects when reloading a saved model.
    """
    return _focal_loss(yt, yp) + (1. - dice_coef(yt, yp))


unet.compile(
    optimizer="adam",
    loss=focal_dice_loss,
    metrics=['accuracy'],
)
unet.summary()



The training output shows metrics (accuracy and loss) for each epoch on both the training and validation sets. It also shows when the model checkpoint is saved (when `val_accuracy` improves). The model trained for 20 epochs, and validation accuracy generally improved over time, reaching a peak of 0.97764.

In [None]:
# Keep the History object so the training curves can be plotted afterwards.
# NOTE(review): Keras takes the validation split from the *last* 1% of X_train,
# before shuffling. That is ~638 patches from the last training image only,
# and they overlap heavily with that image's training patches (stride 8 < patch 128).
# Validation metrics are therefore optimistic; a per-image hold-out would be
# a fairer check.
history = unet.fit(
    X_train,
    y_train,
    validation_split=0.01,
    batch_size=8,
    epochs=20,
    shuffle=True,
    callbacks=[checkpoint_cb],
)

# 9. Inference

To run the model yourself, download the `.h5` weights file, place it in `/content/drive/MyDrive`, and mount your Google Drive in Colab, as done below.

In [None]:
# Check whether a 'best_model' checkpoint file is present in the root of My Drive
# (IPython shell escape; only runs inside a notebook).
!ls /content/drive/MyDrive/ | grep 'best_model'

In [None]:
from google.colab import files

# Send the best checkpoint written by ModelCheckpoint to the browser as a download.
_best_ckpt = 'checkpoints/best_model.h5'
files.download(_best_ckpt)


In [None]:
# Move the best checkpoint into Google Drive under a descriptive name.
_ckpt_src = 'checkpoints/best_model.h5'
_ckpt_dst = '/content/drive/MyDrive/best_model_Unet_RBVS.h5'
shutil.move(_ckpt_src, _ckpt_dst)


In [None]:
# Rebuild the same architecture, then restore the trained weights saved on Drive.
_weights_path = '/content/drive/MyDrive/best_model_Unet_RBVS.h5'
unet = Unet()
unet.load_weights(_weights_path)


The training history was not saved, so the curves below are plotted from values transcribed manually from the training log.

In [None]:
import matplotlib.pyplot as plt

# Per-epoch values transcribed by hand from the 20-epoch training log.
epochs = list(range(1, 21))
accuracy = [0.9401, 0.9784, 0.9851, 0.9878, 0.9893, 0.9904, 0.9912, 0.9917, 0.9921, 0.9924, 0.9927, 0.9929, 0.9931, 0.9933, 0.9935, 0.9936, 0.9937, 0.9938, 0.9939, 0.9940]
val_accuracy = [0.9695, 0.9734, 0.9737, 0.9698, 0.9757, 0.9729, 0.9726, 0.9742, 0.9758, 0.9759, 0.9766, 0.9770, 0.9770, 0.9773, 0.9766, 0.9758, 0.9773, 0.9768, 0.9758, 0.9776]
loss = [0.3007, 0.1259, 0.0880, 0.0727, 0.0647, 0.0578, 0.0541, 0.0509, 0.0485, 0.0466, 0.0450, 0.0438, 0.0425, 0.0413, 0.0404, 0.0394, 0.0389, 0.0383, 0.0377, 0.0374]
val_loss = [0.1870, 0.1868, 0.2021, 0.3013, 0.2054, 0.2670, 0.2703, 0.2444, 0.2195, 0.2305, 0.2055, 0.1951, 0.2019, 0.2007, 0.1958, 0.2233, 0.1783, 0.2048, 0.2235, 0.1840]

# One panel per metric: (train values, val values, train label, val label, y label, title)
_panels = (
    (accuracy, val_accuracy, 'Train Accuracy', 'Val Accuracy', 'Accuracy', 'Accuracy over Epochs'),
    (loss, val_loss, 'Train Loss', 'Val Loss', 'Loss', 'Loss over Epochs'),
)

plt.figure(figsize=(12, 5))
for _pos, (_train, _val, _train_lbl, _val_lbl, _ylabel, _title) in enumerate(_panels, start=1):
    plt.subplot(1, 2, _pos)
    plt.plot(epochs, _train, label=_train_lbl)
    plt.plot(epochs, _val, label=_val_lbl)
    plt.xlabel('Epoch')
    plt.ylabel(_ylabel)
    plt.title(_title)
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()


In [None]:
def plot_history(history):
    """
    Plot train/validation accuracy and loss curves from a Keras training run.

    Parameters
    ----------
    history : keras History or dict
        The object returned by ``model.fit``, or its ``.history`` dict mapping
        metric names ('accuracy', 'val_accuracy', 'loss', 'val_loss') to
        per-epoch lists. Metrics missing from the record are not drawn.
    """
    record = getattr(history, "history", history)
    plt.figure(figsize=(12, 5))
    for pos, (key, label) in enumerate((("accuracy", "Accuracy"), ("loss", "Loss")), start=1):
        plt.subplot(1, 2, pos)
        for prefix, split in (("", "Train"), ("val_", "Val")):
            values = record.get(prefix + key)
            if values:
                plt.plot(range(1, len(values) + 1), values, label=f"{split} {label}")
        plt.xlabel("Epoch")
        plt.ylabel(label)
        plt.title(f"{label} over Epochs")
        plt.legend()
        plt.grid(True)
    plt.tight_layout()
    plt.show()


# `history` exists only if the training cell kept the return value of unet.fit(...).
if "history" in globals():
    plot_history(history)
else:
    print("No `history` object found: assign `history = unet.fit(...)` and re-run training to plot it.")



*   Defined `predict_full` function to predict the segmentation mask for a full-size image by extracting patches, running inference on each patch, and stitching the results back together
  *    Handles padding using `np.pad(..., mode="reflect")`.
  *    Uses nested loops to iterate through patch locations.
  *    `model.predict()` on individual patches.
  *    Accumulates predictions in `sum_probs` and `count_map`.
  *    Normalizes accumulated probabilities by `count_map`.
  *    Crops the result back to the original image size.
  *    Applies a `threshold` to get a binary mask. (This mattered less after a few training iterations, as the model's confidence kept increasing; the final threshold was set to 0.99.)
  

*   Defined `predict_full_debug` which is similar to `predict_full` but also returns a list of individual patch predictions for debugging.

*   Defined `show_patches` function to visualize individual image patches, their predicted probability maps, and the corresponding ground truth patches side-by-side.


*   Selected a test image and its ground truth mask (image index 18).



In [None]:
import math
import numpy as np
import tensorflow as tf

def predict_full(image,
                 model,
                 patch_size=128,
                 stride=8,
                 threshold=0.5,
                 batch_size=64):
    """
    Segment a full-size image by sliding-window patch inference and stitching.

    The image is reflect-padded on the bottom/right so windows of
    ``patch_size`` with step ``stride`` tile it exactly. Per-pixel
    probabilities from all overlapping windows are averaged, cropped back to
    the original size, and thresholded.

    Parameters
    ----------
    image : numpy.ndarray
        H×W×1 (or H×W) array, uint8 in [0, 255] or float in [0, 1]. Values
        are divided by 255 when the maximum exceeds 1.
    model : object with ``predict(batch, verbose=0)``
        e.g. the trained U-Net. It must map (B, patch_size, patch_size, C)
        to (B, patch_size, patch_size, 1) probabilities.
    patch_size : int
        Side of the square window fed to the model.
    stride : int
        Step between neighbouring windows.
    threshold : float
        Pixels whose averaged probability is >= threshold become 1.
    batch_size : int
        Maximum number of patches sent to ``model.predict`` per call. Batching
        avoids thousands of single-patch predict calls.

    Returns
    -------
    numpy.ndarray
        Binary mask of shape (H, W), dtype uint8.
    """
    # 1) Ensure float32 [0,1] with a channel axis
    img = image.astype("float32")
    if img.max() > 1.0:
        img /= 255.0
    if img.ndim == 2:
        img = img[..., np.newaxis]

    H, W = img.shape[:2]

    # 2) Pad so that (H_pad - patch_size) % stride == 0 (same for W).
    #    max(..., 0) keeps one window when the image is smaller than a patch.
    n_steps_h = max(math.ceil((H - patch_size) / stride), 0) + 1
    n_steps_w = max(math.ceil((W - patch_size) / stride), 0) + 1
    H_pad = (n_steps_h - 1) * stride + patch_size
    W_pad = (n_steps_w - 1) * stride + patch_size

    # "reflect" avoids zero-border artifacts
    img_p = np.pad(img, ((0, H_pad - H), (0, W_pad - W), (0, 0)), mode="reflect")

    # 3) Accumulators. The counter must not be uint8: with the defaults
    #    (128/8)^2 = 256 windows cover an interior pixel, which wraps a uint8
    #    to 0 and turns the average into inf/nan.
    sum_probs = np.zeros((H_pad, W_pad), dtype=np.float32)
    count_map = np.zeros((H_pad, W_pad), dtype=np.int32)

    # 4) Slide over patches (row-major), predicting in batches
    coords = [(y, x)
              for y in range(0, H_pad - patch_size + 1, stride)
              for x in range(0, W_pad - patch_size + 1, stride)]
    for start in range(0, len(coords), batch_size):
        chunk = coords[start:start + batch_size]
        batch = np.stack([img_p[y:y + patch_size, x:x + patch_size, :] for y, x in chunk])
        probs = np.asarray(model.predict(batch, verbose=0))[..., 0]  # (B, patch, patch)
        for (y, x), prob in zip(chunk, probs):
            sum_probs[y:y + patch_size, x:x + patch_size] += prob
            count_map[y:y + patch_size, x:x + patch_size] += 1

    # 5) Average over coverage (every padded pixel is covered at least once), crop
    avg_probs = (sum_probs / count_map)[:H, :W]

    # 6) Threshold to binary mask
    return (avg_probs >= threshold).astype(np.uint8)

# ——————————————————————————————————————————————————————————————


In [None]:
import math
import numpy as np
import tensorflow as tf

def predict_full_debug(image,
                       model,
                       patch_size=128,
                       stride=64,
                       threshold=0.5,
                       batch_size=64):
    """
    Sliding-window inference like ``predict_full``, also returning every raw patch prediction.

    Parameters
    ----------
    image : numpy.ndarray
        H×W×1 (or H×W) array; divided by 255 when its maximum exceeds 1.
    model : object with ``predict(batch, verbose=0)`` returning
        (B, patch_size, patch_size, 1) probabilities.
    patch_size, stride : int
        Window side and step.
    threshold : float
        Averaged probability at or above which a pixel becomes 1.
    batch_size : int
        Maximum number of patches per ``model.predict`` call.

    Returns
    -------
    bin_mask : numpy.ndarray
        H×W uint8 array (aggregated & thresholded).
    patch_preds : list of (y, x, prob_map)
        One entry per window in row-major order. (y, x) is the top-left
        corner in the padded image; prob_map is a (patch_size, patch_size)
        array.
    """
    # --- prep ---
    img = image.astype("float32")
    if img.max() > 1.0:
        img /= 255.0
    if img.ndim == 2:
        img = img[..., np.newaxis]
    H, W = img.shape[:2]

    # padded dims; max(..., 0) keeps one window when the image is smaller than a patch
    n_h = max(math.ceil((H - patch_size) / stride), 0) + 1
    n_w = max(math.ceil((W - patch_size) / stride), 0) + 1
    H_pad = (n_h - 1) * stride + patch_size
    W_pad = (n_w - 1) * stride + patch_size
    img_p = np.pad(img, ((0, H_pad - H), (0, W_pad - W), (0, 0)), mode="reflect")

    # accumulators for full-image fusion (int32 so small strides cannot overflow)
    sum_probs = np.zeros((H_pad, W_pad), dtype=np.float32)
    count_map = np.zeros((H_pad, W_pad), dtype=np.int32)

    patch_preds = []

    # --- slide & predict in batches ---
    coords = [(y, x)
              for y in range(0, H_pad - patch_size + 1, stride)
              for x in range(0, W_pad - patch_size + 1, stride)]
    for start in range(0, len(coords), batch_size):
        chunk = coords[start:start + batch_size]
        batch = np.stack([img_p[y:y + patch_size, x:x + patch_size, :] for y, x in chunk])
        probs = np.asarray(model.predict(batch, verbose=0))[..., 0]
        for (y, x), prob_map in zip(chunk, probs):
            patch_preds.append((y, x, prob_map.copy()))
            sum_probs[y:y + patch_size, x:x + patch_size] += prob_map
            count_map[y:y + patch_size, x:x + patch_size] += 1

    # normalize & crop
    avg_probs = (sum_probs / count_map)[:H, :W]

    bin_mask = (avg_probs >= threshold).astype(np.uint8)
    return bin_mask, patch_preds


In [None]:
def show_patches(img, patches, gt_mask,
                 patch_size=128, n_patches=4, threshold=0.5):
    """
    Plot one row per patch: raw crop | predicted prob map with a threshold
    contour | ground-truth crop, plus a shared colour bar for the maps.

    img         : H×W×1 array (a 2-D H×W array is also accepted)
    patches     : list of tuples where [0]=y, [1]=x, [-1]=prob_map
    gt_mask     : full-size H×W binary mask (H×W×1 is also accepted)
    patch_size  : int, side of the square crop taken from img / gt_mask
    n_patches   : how many rows to plot (capped at len(patches))
    threshold   : contour level on prob_map

    Returns None. Does nothing when there is no patch to plot.
    """
    n = min(n_patches, len(patches))
    if n <= 0:
        # plt.subplots(0, 3) raises and there would be no image for the
        # colour bar, so there is nothing sensible to draw.
        return

    # squeeze=False keeps `axes` 2-D even when n == 1, so per-row unpacking
    # and axes[:, 1] below work for every n.
    fig, axes = plt.subplots(n, 3, figsize=(12, 3*n), squeeze=False)

    # Normalise both inputs to 2-D so H×W and H×W×1 arrays are handled alike
    # (indexing [..., 0] on a 2-D image would fail).
    img2d = img[..., 0] if img.ndim == 3 else img
    gt2d = gt_mask[..., 0] if gt_mask.ndim == 3 else gt_mask

    for (ax_raw, ax_pred, ax_gt), entry in zip(axes, patches[:n]):
        y, x = int(entry[0]), int(entry[1])
        prob = entry[-1]
        # (y, x) may lie in a padded border used during inference, in which
        # case these crops come out smaller than prob.
        window = (slice(y, y + patch_size), slice(x, x + patch_size))

        # Raw
        ax_raw.imshow(img2d[window], cmap='gray')
        ax_raw.set_title(f'Raw @ ({y},{x})')
        ax_raw.axis('off')

        # Prediction + contour
        im = ax_pred.imshow(prob, cmap='viridis', vmin=0, vmax=1)
        ax_pred.contour(prob >= threshold, colors='r', linewidths=0.5)
        ax_pred.set_title('Prediction')
        ax_pred.axis('off')

        # Ground truth
        ax_gt.imshow(gt2d[window], cmap='gray')
        ax_gt.set_title('Ground Truth')
        ax_gt.axis('off')

    fig.colorbar(im, ax=axes[:, 1].tolist(),
                 shrink=0.6, label='P(fg)')
    # tight_layout is left off: it conflicts with the shared colour bar axes
    plt.show()

In [None]:
# Sanity check: number of test image paths currently held in `test_images`
# (the name must already exist from a previously executed cell).
len(test_images)

In [None]:
# 1. File lists (must be sorted so image[i] matches mask[i])
test_images = sorted(glob.glob(os.path.join(test_image_dir, "*_test.tif")))
test_masks  = sorted(glob.glob(os.path.join(test_mask_dir,  "*_manual1.tif")))
# Index-based pairing is only valid when both globs found the same number of files.
if len(test_images) != len(test_masks):
    raise ValueError(
        f"Found {len(test_images)} test images but {len(test_masks)} masks; "
        "cannot pair them by index."
    )

# 2. Pick the idx'th case (0-based index)
idx = 18
img_fp  = test_images[idx]
msk_fp  = test_masks[idx]

# 3. Load & prep
raw         = load_image(img_fp)           # (584,565)
# clahe_raw   = _clahe.apply(raw)            # apply CLAHE
test_img    = raw[..., np.newaxis]         # (584,565,1): add channel axis for the model

# Dense sliding window (stride 16) with a very strict probability cut-off.
pred_mask   = predict_full(test_img,
                           unet,
                          #  wnet,
                           patch_size=128,
                          #  patch_size=64,
                           stride=16,
                           threshold=0.999)

# NOTE(review): later cells both index gt_mask[..., 0] and call .squeeze() on
# it — confirm whether load_mask returns (H, W) or (H, W, 1).
gt_mask     = load_mask(msk_fp)


In [None]:
# 1. File lists (must be sorted so image[i] matches mask[i])
test_images = sorted(glob.glob(os.path.join(test_image_dir, "*_test.tif")))
test_masks  = sorted(glob.glob(os.path.join(test_mask_dir,  "*_manual1.tif")))
# Index-based pairing is only valid when both globs found the same number of files.
if len(test_images) != len(test_masks):
    raise ValueError(
        f"Found {len(test_images)} test images but {len(test_masks)} masks; "
        "cannot pair them by index."
    )

# 2. Pick the idx'th case (0-based index)
idx = 18
img_fp  = test_images[idx]
msk_fp  = test_masks[idx]

# 3. Load & prep
raw         = load_image(img_fp)           # (584,565)
# clahe_raw   = _clahe.apply(raw)            # apply CLAHE
test_img    = raw[..., np.newaxis]         # (584,565,1): add channel axis for the model

# stride == patch_size, so windows should not overlap; 0.9 probability cut-off.
pred_mask   = predict_full(test_img,
                           unet,
                          #  wnet,
                           patch_size=128,
                          #  patch_size=64,
                           stride=128,
                           threshold=0.9)

# NOTE(review): later cells both index gt_mask[..., 0] and call .squeeze() on
# it — confirm whether load_mask returns (H, W) or (H, W, 1).
gt_mask     = load_mask(msk_fp)

# 4. Plot side-by-side: input image, prediction, ground truth
plt.figure(figsize=(15,5))

for pos, (title, arr) in enumerate([
    ("Green Channel", test_img[..., 0]),
    ("Predicted Mask", pred_mask),
    ("Ground-Truth Mask", gt_mask),
], start=1):
    plt.subplot(1, 3, pos)
    plt.title(title)
    plt.imshow(arr, cmap='gray')
    plt.axis('off')

plt.tight_layout()
plt.show()


In [None]:
# 1. collect file paths (sorted so image i pairs with mask i)
test_images = sorted(glob.glob(os.path.join(test_image_dir, "*_test.tif")))
test_masks  = sorted(glob.glob(os.path.join(test_mask_dir,  "*_manual1.tif")))
# Index-based pairing is only valid when both globs found the same number of files.
if len(test_images) != len(test_masks):
    raise ValueError(
        f"Found {len(test_images)} test images but {len(test_masks)} masks; "
        "cannot pair them by index."
    )

# 2. pick an index (0-based)
idx = 18
img_fp, msk_fp = test_images[idx], test_masks[idx]

# 3. load data
raw      = load_image(img_fp)           # (H, W)
test_img = raw[..., np.newaxis]         # (H, W, 1)
gt_mask  = load_mask(msk_fp)            # binary; NOTE(review): may be (H, W) or (H, W, 1)

# 4. run your debug-predict
#    note: threshold here only affects the *aggregated* full mask,
#          patch_preds always hold raw prob maps.
bin_mask, patch_preds = predict_full_debug(
    test_img, unet,
    patch_size=128,
    stride=64,
    threshold=0.999
)

# 5. show the first N_SHOW patches (in the order predict_full_debug produced them)
N_SHOW = 30
show_patches(
    test_img,
    patch_preds,     # list of (y, x, prob_map) tuples
    gt_mask,
    patch_size=128,
    n_patches=N_SHOW,
    threshold=0.999
)

# # 6. if you also want the full-image side-by-side:
# plt.figure(figsize=(12,5))
# for i,(arr,title) in enumerate([
#     (test_img[...,0],   "Raw Image"),
#     (bin_mask,          "Aggregated Prediction"),
#     (gt_mask,           "Ground Truth")
# ]):
#     ax = plt.subplot(1,3,i+1)
#     ax.imshow(arr, cmap='gray')
#     ax.set_title(title); ax.axis('off')
# plt.tight_layout()
# plt.show()

Displayed an overlay of the ground-truth mask contour (red) and the predicted mask contour (green) on the original test image (test image at index 18, 0-based; prediction from the `stride=128`, `threshold=0.9` run).

In [None]:
# Overlay contours on the test image: ground truth in red, prediction in green.
# np.squeeze makes this work whether the masks are (H, W) or (H, W, 1);
# indexing [..., 0] on a 2-D mask would pick a single column instead.
plt.imshow(test_img[..., 0], cmap='gray')
plt.contour(np.squeeze(gt_mask),   colors='r', linewidths=1)   # GT in red
plt.contour(np.squeeze(pred_mask), colors='g', linewidths=1)   # Pred in green
plt.title("GT (red) vs Pred (green)")
plt.axis('off')
plt.show()


# 10. Results





*   Defined a `compute_metrics` function to calculate Accuracy, IoU, and F1 Score for binary masks
  *    Uses NumPy operations on flattened arrays.
  *    Includes `_SMOOTH` for numerical stability.

*   Iterated through all 20 test images.
  *    Loaded each image and its corresponding ground truth mask.
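  *    Generated a predicted mask using `predict_full` with `patch_size=128`, `stride=64`, and `threshold=0.9`. (Note: the evaluation cell below currently passes `threshold=0.5`; confirm which threshold produced the averages reported here.)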
  *    Computed Accuracy, IoU, and F1 Score for each image using `compute_metrics`.
  *    Printed the metrics for each image.
  *    Displayed the side-by-side visualization (Image, Predicted Mask, Ground Truth Mask) and the contour overlay (GT red, Pred green) for each image.

*   Computed and printed the average Accuracy, IoU, and F1 Score across all test images.
  *    Average Accuracy: 0.9639
  *    Average IoU: 0.6383
  *    Average F1 Score: 0.7787


In [None]:
# Drop any singleton axes (e.g. a trailing channel dim) so the mask is 2-D.
pred_mask = np.squeeze(pred_mask)

In [None]:
# Drop any singleton axes (e.g. a trailing channel dim) so the mask is 2-D.
gt_mask = np.squeeze(gt_mask)

In [None]:

import matplotlib.pyplot as plt
import numpy as np
def compute_metrics(y_true, y_pred):
    """
    Compute Accuracy, IoU (Jaccard) and F1 (Dice) for a pair of binary masks.

    Args:
      y_true, y_pred: arrays with the same number of elements (their shapes
        may differ by singleton axes, e.g. (H, W) vs (H, W, 1), since both
        are flattened). Any value > 0 counts as foreground, so 0/1, bool and
        0/255 encoded masks are all handled.

    Returns:
      (accuracy, iou, f1). A tiny smoothing term keeps IoU and F1 finite
      (and equal to 1) when both masks are empty.

    Raises:
      ValueError: if the two masks have different numbers of elements.
    """
    _SMOOTH = 1e-7  # for numerical stability

    # Binarize to bool: avoids uint8 overflow in products (255*255 wraps to 1)
    # and makes 0/255 ground truth comparable with 0/1 predictions.
    t = np.asarray(y_true).ravel() > 0
    p = np.asarray(y_pred).ravel() > 0
    if t.size != p.size:
        raise ValueError(
            f"Mask size mismatch: y_true has {t.size} elements, "
            f"y_pred has {p.size}"
        )

    # Accuracy: fraction of pixels on which the two masks agree
    accuracy = float(np.mean(t == p))

    # IoU (Jaccard Index)
    intersection = np.count_nonzero(t & p)
    n_true = np.count_nonzero(t)
    n_pred = np.count_nonzero(p)
    union = n_true + n_pred - intersection
    iou = (intersection + _SMOOTH) / (union + _SMOOTH)

    # F1 Score (Dice Coefficient)
    f1 = (2 * intersection + _SMOOTH) / (n_true + n_pred + _SMOOTH)

    return accuracy, iou, f1


# 1. Collect file paths (sorted so that image i pairs with mask i)
test_images = sorted(glob.glob(os.path.join(test_image_dir, "*_test.tif")))
test_masks  = sorted(glob.glob(os.path.join(test_mask_dir,  "*_manual1.tif")))

# Fail loudly instead of averaging an empty list (np.mean([]) -> nan) or
# silently pairing an image with the wrong mask.
if not test_images:
    raise FileNotFoundError(f"No '*_test.tif' files found in {test_image_dir}")
if len(test_images) != len(test_masks):
    raise ValueError(
        f"Found {len(test_images)} test images but {len(test_masks)} masks; "
        "cannot pair them by index."
    )

# Probability cut-off used to binarize predictions for evaluation.
# NOTE(review): the Results notes above this cell mention threshold=0.9;
# confirm which value produced the averages reported there.
EVAL_THRESHOLD = 0.5
n_cases = len(test_images)

all_accuracies = []
all_ious = []
all_f1s = []

# 2. Iterate through all test images
for case, (img_fp, msk_fp) in enumerate(zip(test_images, test_masks), start=1):
    # 3. Load & prep
    raw         = load_image(img_fp)
    test_img    = raw[..., np.newaxis]

    # 4. Predict the mask and squeeze it to 2-D
    pred_mask   = predict_full(test_img,
                               unet,
                               patch_size=128,
                               stride=64,
                               threshold=EVAL_THRESHOLD).squeeze()

    # Ground truth: any non-zero value counts as foreground, then 2-D
    gt_mask = (load_mask(msk_fp) > 0).astype(np.uint8).squeeze()

    # 5. Compute metrics for the current image
    accuracy, iou, f1 = compute_metrics(gt_mask, pred_mask)
    all_accuracies.append(accuracy)
    all_ious.append(iou)
    all_f1s.append(f1)

    print(f"Image {case}/{n_cases}: Accuracy = {accuracy:.4f}, IoU = {iou:.4f}, F1 = {f1:.4f}")

    # 6. Show visualizations for the current image
    plt.figure(figsize=(15,5))
    for pos, (title, arr) in enumerate([
        (f"Image {case}: Green Channel", test_img[..., 0]),
        (f"Image {case}: Predicted Mask", pred_mask),
        (f"Image {case}: Ground-Truth Mask", gt_mask),
    ], start=1):
        plt.subplot(1, 3, pos)
        plt.title(title)
        plt.imshow(arr, cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(6,6))
    plt.imshow(test_img[..., 0], cmap='gray')
    plt.contour(gt_mask,  colors='r', linewidths=1)   # GT in red
    plt.contour(pred_mask, colors='g', linewidths=1)  # Pred in green
    plt.title(f"Image {case}: GT (red) vs Pred (green)")
    plt.axis('off')
    plt.show()

# 7. Compute and print average (per-image mean) metrics
avg_accuracy = np.mean(all_accuracies)
avg_iou = np.mean(all_ious)
avg_f1 = np.mean(all_f1s)

print("\n--- Overall Average Metrics ---")
print(f"Average Accuracy across test images: {avg_accuracy:.4f}")
print(f"Average IoU across test images: {avg_iou:.4f}")
print(f"Average F1 Score across test images: {avg_f1:.4f}")
