In [None]:
import os
import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras import layers as L, Model
from tensorflow.keras.metrics import BinaryAccuracy, Precision, BinaryIoU
from tensorflow.keras.optimizers import Adam
import tensorflow.keras.backend as K
import matplotlib.pyplot as plt
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array

In [None]:
# Report the TensorFlow build and how many GPUs it can see.
gpu_count = len(tf.config.list_physical_devices('GPU'))
print(tf.__version__)
print("Num GPUs Available:", gpu_count)

In [None]:
# Data loading
def load_masks(mask_dir, image_size=(256, 256)):
    """Load every mask in ``mask_dir`` as a binary float32 array.

    Files are read in sorted filename order (the same ordering ``load_images``
    uses, which is how images and masks are paired), decoded as grayscale,
    resized to ``image_size`` and binarised.

    Args:
        mask_dir: Directory that contains only mask image files.
        image_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        float32 array of shape ``(N, height, width, 1)`` with values in {0, 1}.

    Raises:
        ValueError: if a file in ``mask_dir`` cannot be decoded as an image.
    """
    masks = []
    for mask_file in sorted(os.listdir(mask_dir)):
        path = os.path.join(mask_dir, mask_file)
        mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"Could not read mask file: {path}")
        # Nearest-neighbour keeps the mask binary; bilinear would blend edge pixels.
        mask = cv2.resize(mask, image_size, interpolation=cv2.INTER_NEAREST)
        # Threshold at mid-grey so near-white pixels (compression, anti-aliasing)
        # count as foreground; casting mask / 255 to uint8 would truncate every
        # value below 255 to 0.
        mask = (mask > 127).astype(np.uint8)
        masks.append(mask)

    return np.expand_dims(np.array(masks, dtype=np.float32), -1)

def load_images(image_dir, image_size=(256, 256)):
    """Load every image in ``image_dir`` as a normalised grayscale float32 array.

    Files are read in sorted filename order (matching ``load_masks``), decoded
    as grayscale, resized to ``image_size`` and scaled from [0, 255] to [0, 1].

    Args:
        image_dir: Directory that contains only image files.
        image_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        float32 array of shape ``(N, height, width, 1)``.

    Raises:
        ValueError: if a file in ``image_dir`` cannot be decoded as an image.
    """
    images = []
    for img_file in sorted(os.listdir(image_dir)):
        path = os.path.join(image_dir, img_file)
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"Could not read image file: {path}")
        images.append(cv2.resize(img, image_size) / 255.0)

    return np.expand_dims(np.array(images, dtype=np.float32), -1)

# Plotting
def plot_image_mask(idx, X, Y, title=""):
    """Display sample ``idx`` from ``X`` beside its ground-truth mask from ``Y``.

    Args:
        idx: Index of the sample to show.
        X: Indexable collection of input images.
        Y: Indexable collection of masks aligned with ``X``.
        title: Figure-level title.
    """
    plt.figure(figsize=(10, 5))
    panels = (("Input Image", X), ("Ground Truth", Y))
    for position, (panel_title, source) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(panel_title)
        plt.imshow(source[idx].squeeze(), cmap='gray')
    plt.suptitle(title)
    plt.show()

def dice_loss(y_true, y_pred, smooth=1e-6):
    """Soft Dice loss pooled over the whole batch.

    Both tensors are flattened, so the overlap is accumulated across every
    sample and pixel rather than averaged per image.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted probabilities (presumably the same number of
            elements as ``y_true`` -- the element-wise product requires it).
        smooth: Constant added to numerator and denominator to avoid 0/0.

    Returns:
        Scalar tensor equal to ``1 - Dice``.
    """
    truth = K.flatten(y_true)
    probs = K.flatten(y_pred)
    overlap = K.sum(truth * probs)
    denominator = K.sum(truth) + K.sum(probs) + smooth
    dice = (2. * overlap + smooth) / denominator
    return 1 - dice

def combined_dice_bce_loss(y_true, y_pred):
    """Weighted sum of binary cross-entropy (weight 0.4) and Dice loss (weight 0.6).

    ``binary_crossentropy`` yields a per-pixel tensor while ``dice_loss`` is a
    scalar; the scalar term broadcasts over the per-pixel one.
    """
    bce_weight, dice_weight = 0.4, 0.6
    weighted_bce = bce_weight * tf.keras.losses.binary_crossentropy(y_true, y_pred)
    weighted_dice = dice_weight * dice_loss(y_true, y_pred)
    return weighted_bce + weighted_dice

# Model Blocks
def conv_block(x, filters):
    """Apply (3x3 'same' Conv2D -> BatchNormalization -> ReLU) twice."""
    def _conv_bn_relu(tensor):
        # One convolution unit; `filters` output channels.
        tensor = L.Conv2D(filters, 3, padding="same")(tensor)
        tensor = L.BatchNormalization()(tensor)
        return L.Activation("relu")(tensor)

    return _conv_bn_relu(_conv_bn_relu(x))

def encoder_block(x, filters):
    """Encoder stage: conv_block followed by 2x2 max pooling.

    Returns:
        ``(skip, pooled)`` -- the pre-pooling features (used as the skip
        connection) and the downsampled features passed to the next stage.
    """
    skip = conv_block(x, filters)
    pooled = L.MaxPool2D((2, 2))(skip)
    return skip, pooled

def attention_gate(g, s, filters):
    """Additive attention gate that re-weights skip features ``s`` using ``g``.

    Both inputs are projected with a 1x1 conv + batch norm, summed, passed
    through ReLU, projected again with a 1x1 conv and squashed by a sigmoid.
    The resulting map (``filters`` channels) multiplies ``s`` element-wise.

    Args:
        g: Gating signal (the upsampled decoder features).
        s: Skip-connection features from the encoder.
        filters: Channel count for the projections and the attention map.
    """
    def _project(tensor):
        norm = L.BatchNormalization()
        return norm(L.Conv2D(filters, 1, padding="same")(tensor))

    gate = _project(g)
    skip = _project(s)
    fused = L.Activation("relu")(gate + skip)
    sigmoid = L.Activation("sigmoid")
    weights = sigmoid(L.Conv2D(filters, 1, padding="same")(fused))
    return weights * s

def decoder_block(x, s, filters):
    """Decoder stage: 2x bilinear upsample, attention-gate the skip, concat, conv.

    Args:
        x: Features from the deeper decoder/bottleneck stage.
        s: Matching encoder skip features.
        filters: Channel count for the gate and the trailing conv_block.
    """
    upsampled = L.UpSampling2D(interpolation="bilinear")(x)
    gated_skip = attention_gate(upsampled, s, filters)
    merged = L.Concatenate()([upsampled, gated_skip])
    return conv_block(merged, filters)

def attention_unet(input_shape=(None, None, 1)):
    """Build and compile a 5-level Attention U-Net for binary segmentation.

    Args:
        input_shape: ``(height, width, channels)``. Spatial dims may be ``None``
            for a size-agnostic model; concrete sizes must be divisible by 32
            because the encoder max-pools five times before the bottleneck.

    Returns:
        A compiled ``Model`` producing a single-channel sigmoid mask, trained
        with ``combined_dice_bce_loss`` and tracking IoU, accuracy and precision.
    """
    inputs = L.Input(input_shape)

    # Encoder: each stage returns (skip features, pooled features).
    c0, p0 = encoder_block(inputs, 32)
    c1, p1 = encoder_block(p0, 64)
    c2, p2 = encoder_block(p1, 128)
    c3, p3 = encoder_block(p2, 256)
    c4, p4 = encoder_block(p3, 512)

    # Bottleneck at 1/32 of the input resolution.
    c5 = conv_block(p4, 1024)

    # Decoder: upsample and fuse with attention-gated skips.
    u1 = decoder_block(c5, c4, 512)
    u2 = decoder_block(u1, c3, 256)
    u3 = decoder_block(u2, c2, 128)
    u4 = decoder_block(u3, c1, 64)
    u5 = decoder_block(u4, c0, 32)
    outputs = L.Conv2D(1, 1, padding="same", activation="sigmoid")(u5)

    model = Model(inputs, outputs)

    # Explicit metric names keep History keys ('binary_io_u', 'binary_accuracy',
    # 'precision') stable; unnamed metrics can otherwise receive suffixed
    # auto-generated names such as 'precision_1' when several models are
    # built in one session.
    model.compile(
        optimizer=Adam(0.001),
        loss=combined_dice_bce_loss,
        metrics=[
            BinaryIoU(target_class_ids=[0, 1], threshold=0.5, name="binary_io_u"),
            BinaryAccuracy(name="binary_accuracy"),
            Precision(name="precision"),
        ],
    )

    return model

# Metrics
def plot_metrics(history, metric_name, ylabel, ylim=None):
    """Plot training and validation curves for one metric from a Keras History.

    Args:
        history: Object returned by ``model.fit``; ``history.history`` must
            contain ``metric_name`` and ``val_<metric_name>``.
        metric_name: Key of the training metric to plot.
        ylabel: Y-axis label.
        ylim: Optional y-axis limits, applied only when truthy.
    """
    plt.figure()
    series = (
        (metric_name, 'b', 'Training'),
        (f"val_{metric_name}", 'r', 'Validation'),
    )
    for key, colour, label in series:
        plt.plot(history.epoch, history.history[key], colour, label=label)
    plt.title(f'Training and Validation {metric_name.capitalize()}')
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    if ylim:
        plt.ylim(ylim)
    plt.legend()
    plt.show()

# Predict and Visualize Test Results
def visualize_prediction(idx, model, X, Y):
    """Show input, ground truth, raw prediction and an overlay for sample ``idx``.

    The overlay puts the ground truth in the red and green channels and the
    (unthresholded) prediction in the blue channel, so ground-truth-only areas
    look yellow, prediction-only areas blue, and agreement white.
    """
    sample = X[idx]
    pred = model.predict(np.expand_dims(sample, axis=0)).squeeze()
    gt = Y[idx].squeeze()
    overlay = np.stack([gt, gt, pred], axis=-1)

    panels = [
        ("Input Image", sample.squeeze(), 'gray'),
        ("Ground Truth", gt, 'gray'),
        ("Prediction", pred, 'gray'),
        ("Differences", overlay, None),
    ]
    plt.figure(figsize=(15, 5))
    for position, (panel_title, data, cmap) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        plt.title(panel_title)
        plt.imshow(data, cmap=cmap)
    plt.show()

# Evaluate with the Dice coefficient
def DICE_COE(mask1, mask2):
    """Dice coefficient ``2|A.B| / (|A| + |B|)`` rounded to 3 decimals.

    Inputs are multiplied element-wise, so they are expected to be binary (or
    probability) arrays of the same shape. A 1e-7 term keeps the denominator
    non-zero; two empty masks therefore score 0.0.
    """
    overlap = np.sum(mask1 * mask2)
    total = np.sum(mask1) + np.sum(mask2) + 1e-7
    return round(2 * overlap / total, 3)

In [None]:
channels = ['gray', 'RGB_B', 'RGB_G', 'RGB_R', 'YUV_Y', 'YUV_U', 'YUV_V',
            'HSV_H', 'HSV_S', 'HSV_V', 'HLS_H', 'HLS_L', 'HLS_S',
            'CIELab_L', 'CIELab_a', 'CIELab_b', 'YCrCb_Y', 'YCrCb_Cr', 'YCrCb_Cb']

# Experiment configuration.
DATA_ROOT = 'C:/Users/User/Desktop/Helevorn/data'
INPUT_SHAPE = (256, 256, 1)  # must match the (256, 256) default image_size of the loaders
BATCH_SIZE = 8
EPOCHS = 50
SEED = 42
PRED_THRESHOLD = 0.5  # probability cut-off used to binarise predictions

# Common target mask directories (shared by every channel).
mask_train_dir = f'{DATA_ROOT}/segmentation/train'
mask_val_dir = f'{DATA_ROOT}/segmentation/val'
mask_test_dir = f'{DATA_ROOT}/segmentation/test'

Y_train = load_masks(mask_train_dir)
Y_val = load_masks(mask_val_dir)
Y_test = load_masks(mask_test_dir)

for channel in channels:
    print(f"\n--- Processing Channel: {channel} ---")

    img_train_dir = f'{DATA_ROOT}/chan/{channel}/images/train'
    img_val_dir = f'{DATA_ROOT}/chan/{channel}/images/val'
    img_test_dir = f'{DATA_ROOT}/chan/{channel}/images/test'

    # Load channel-specific input images
    X_train = load_images(img_train_dir)
    X_val = load_images(img_val_dir)
    X_test = load_images(img_test_dir)

    print(f"Train shape: {X_train.shape}, Val shape: {X_val.shape}, Test shape: {X_test.shape}")

    # Images and masks are paired purely by sorted file order, so the counts must agree.
    for split, X, Y in (("train", X_train, Y_train), ("val", X_val, Y_val), ("test", X_test, Y_test)):
        assert len(X) == len(Y), f"{channel}/{split}: {len(X)} images vs {len(Y)} masks"

    # Same seed for every channel so differences between channels are not
    # caused by different weight initialisation or shuffling.
    tf.keras.utils.set_random_seed(SEED)
    model = attention_unet(INPUT_SHAPE)

    history = model.fit(X_train, Y_train,
                        validation_data=(X_val, Y_val),
                        batch_size=BATCH_SIZE,
                        epochs=EPOCHS)

    # Plot metrics
    plot_metrics(history, 'binary_accuracy', 'Binary Accuracy', [0, 1])
    plot_metrics(history, 'precision', 'Precision', [0, 1])
    plot_metrics(history, 'loss', 'Loss', [0, 1])
    plot_metrics(history, 'binary_io_u', 'IoU', [0, 1])

    # Predict the whole test set in batches and binarise before scoring, so the
    # reported Dice is the hard-mask Dice (scoring raw probabilities yields a
    # soft Dice that is not comparable with thresholded evaluations).
    test_probs = model.predict(X_test, batch_size=BATCH_SIZE)
    test_masks = (test_probs > PRED_THRESHOLD).astype(np.float32)
    dice_scores = [DICE_COE(pred.squeeze(), y.squeeze())
                   for pred, y in zip(test_masks, Y_test)]

    print(f"{channel} - Avg DICE Score:", np.mean(dice_scores))
    print(f"{channel} - Max DICE Score:", np.max(dice_scores))
    print(f"{channel} - Min DICE Score:", np.min(dice_scores))
    print(f"{channel} - Std DICE Score:", np.std(dice_scores))

    # Save model
    model.save(f'retinal_{channel}.keras')

    # Free per-channel data and the graph before the next channel.
    del X_train, X_val, X_test, model, history, test_probs, test_masks
    K.clear_session()

In [None]:
def attention_unet(input_shape=(None, None, 1)):
    """Build and compile the Attention U-Net with a size-agnostic default input.

    NOTE(review): this redefines (and shadows) the builder from the training
    cell with an identical architecture; only the default ``input_shape``
    differs. Consider keeping a single definition.

    Args:
        input_shape: ``(height, width, channels)``; ``None`` spatial dims allow
            any input whose height and width are divisible by 32 (five
            max-pooling stages).

    Returns:
        A compiled ``Model`` producing a single-channel sigmoid mask.
    """
    inputs = L.Input(input_shape)

    # Encoder: each stage returns (skip features, pooled features).
    c0, p0 = encoder_block(inputs, 32)
    c1, p1 = encoder_block(p0, 64)
    c2, p2 = encoder_block(p1, 128)
    c3, p3 = encoder_block(p2, 256)
    c4, p4 = encoder_block(p3, 512)

    # Bottleneck at 1/32 of the input resolution.
    c5 = conv_block(p4, 1024)

    # Decoder: upsample and fuse with attention-gated skips.
    u1 = decoder_block(c5, c4, 512)
    u2 = decoder_block(u1, c3, 256)
    u3 = decoder_block(u2, c2, 128)
    u4 = decoder_block(u3, c1, 64)
    u5 = decoder_block(u4, c0, 32)
    outputs = L.Conv2D(1, 1, padding="same", activation="sigmoid")(u5)

    model = Model(inputs, outputs)

    # Explicit metric names keep History keys stable regardless of how many
    # models have already been built in the session.
    model.compile(
        optimizer=Adam(0.001),
        loss=combined_dice_bce_loss,
        metrics=[
            BinaryIoU(target_class_ids=[0, 1], threshold=0.5, name="binary_io_u"),
            BinaryAccuracy(name="binary_accuracy"),
            Precision(name="precision"),
        ],
    )

    return model

for channel in channels:
    model_path = f"retinal_{channel}.keras"
    weights_path = f"model_weights_{channel}.h5"

    print(f"Processing {channel}...")

    # Only the weights are needed, so skip compiling: no optimizer state or
    # metrics are rebuilt. The custom loss objects are still supplied in case
    # the saved config references them during deserialisation.
    fixed_model = load_model(
        model_path,
        custom_objects={
            "combined_dice_bce_loss": combined_dice_bce_loss,
            "dice_loss": dice_loss,
            "BinaryIoU": tf.keras.metrics.BinaryIoU,
            "BinaryAccuracy": tf.keras.metrics.BinaryAccuracy,
            "Precision": tf.keras.metrics.Precision
        },
        compile=False,
    )

    # Save weights only
    fixed_model.save_weights(weights_path)
    print(f"Saved weights to {weights_path}")

    # Release each model before loading the next of the per-channel models.
    del fixed_model
    K.clear_session()


In [None]:
def calculate_metrics(y_true, y_pred):
    """Pixel-wise accuracy, precision and specificity for binary masks.

    Both arrays are flattened. Only entries exactly equal to 1 (positive) or
    0 (negative) are counted; a 1e-6 term in each denominator avoids 0/0.

    Args:
        y_true: Ground-truth mask array.
        y_pred: Predicted mask array with the same number of elements.

    Returns:
        ``(accuracy, precision, specificity)``.
    """
    truth = y_true.ravel()
    guess = y_pred.ravel()

    truth_pos, truth_neg = truth == 1, truth == 0
    guess_pos, guess_neg = guess == 1, guess == 0

    tp = np.sum(truth_pos & guess_pos)
    tn = np.sum(truth_neg & guess_neg)
    fp = np.sum(truth_neg & guess_pos)
    fn = np.sum(truth_pos & guess_neg)

    eps = 1e-6
    accuracy = (tp + tn) / (tp + tn + fp + fn + eps)
    precision = tp / (tp + fp + eps)
    specificity = tn / (tn + fp + eps)
    return accuracy, precision, specificity

In [None]:
# Rebuild the size-agnostic model and load the weights exported from the gray-channel model.
flex_model = attention_unet(input_shape=(None, None, 1))
flex_model.load_weights("model_weights_gray.h5")

# NOTE(review): training used images resized to 256x256, while inference here
# runs at the image's native resolution, so structures may appear at a
# different scale than during training -- confirm this is intended.
img_path = "C:/Users/User/Desktop/Fangorn/data/numeric/images/grey/1.png"
gt_path = "C:/Users/User/Desktop/Fangorn/data/numeric/gt/1.png"
SIZE_MULTIPLE = 32    # the encoder max-pools 5 times, so H and W must divide by 2**5
PRED_THRESHOLD = 0.5  # probability cut-off for the binary mask

# Load the input image as an (H, W, 1) array.
img = img_to_array(load_img(img_path, color_mode="grayscale"))
h, w = img.shape[:2]

# Zero-pad bottom/right up to the next multiple of SIZE_MULTIPLE, then scale to [0, 1].
pad_h = (SIZE_MULTIPLE - h % SIZE_MULTIPLE) % SIZE_MULTIPLE
pad_w = (SIZE_MULTIPLE - w % SIZE_MULTIPLE) % SIZE_MULTIPLE
img_padded = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode="constant") / 255.0
img_input = np.expand_dims(img_padded, axis=0)  # (1, H_pad, W_pad, 1)

# Predict, crop the padding away, and threshold.
pred = flex_model.predict(img_input)[0, ..., 0][:h, :w]
binary_mask = (pred > PRED_THRESHOLD).astype(np.uint8)

# Ground truth as an (H, W) 0/1 mask aligned with the prediction.
gt_mask = img_to_array(load_img(gt_path, color_mode="grayscale")).squeeze()
if gt_mask.shape != (h, w):
    # Nearest-neighbour keeps the mask's hard edges; bilinear would blur them before thresholding.
    gt_mask = cv2.resize(gt_mask, (w, h), interpolation=cv2.INTER_NEAREST)
gt_mask = (gt_mask > 127).astype(np.uint8)

dice_score = DICE_COE(gt_mask, binary_mask)
acc, prec, spec = calculate_metrics(gt_mask, binary_mask)

print(f"Dice coefficient : {dice_score:.4f}")
print(f"Accuracy         : {acc:.4f}")
print(f"Precision        : {prec:.4f}")
print(f"Specificity      : {spec:.4f}")

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    (img.squeeze(), "Original Image"),
    (gt_mask, "Ground Truth Mask"),
    (binary_mask, f"Predicted Mask Dice: {dice_score:.4f}"),
]
for ax, (data, title) in zip(axes, panels):
    ax.imshow(data, cmap="gray")
    ax.set_title(title)
    ax.axis("off")

fig.tight_layout()
plt.show()