In [None]:
"""
3-class U-Net training with 3-color masks
"""

import os
import glob
import random
import numpy as np
import cv2
import matplotlib.pyplot as plt
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, BatchNormalization, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import *
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K

In [None]:
# --- Dataset locations -------------------------------------------------------
# Images and masks are paired purely by their position in the sorted listings,
# so both folders must hold the same number of files in matching order.
image_list = sorted(glob.glob("K:/images/*"))
label_list = sorted(glob.glob("K:/masks/*"))
if len(image_list) != len(label_list):
    raise ValueError(
        f"Found {len(image_list)} images but {len(label_list)} masks; "
        "image/mask pairs are matched by sorted position and must line up."
    )

# --- Training hyper-parameters -----------------------------------------------
image_size = (512, 512)  # resize target for cv2.resize and spatial size of the model input
batch_size = 8
epochs = 50

# --- Mask colour coding ------------------------------------------------------
# Colours are in RGB order (masks are converted BGR -> RGB before matching).
# The row index inside COLOR_DICT is the class id.
background = [255, 0, 0]
boundary = [0, 255, 0]
inside = [0, 0, 255]
COLOR_DICT = np.array([background, boundary, inside])
num_class = 3

# Preview up to five random image/mask pairs: images on the top row, masks below.
n_pairs = min(len(image_list), len(label_list))
n_preview = min(5, n_pairs)  # random.sample raises ValueError if k exceeds the population
if n_preview == 0:
    print("No image/mask pairs found to preview.")
else:
    # squeeze=False keeps `axs` two-dimensional even when a single column is drawn.
    fig, axs = plt.subplots(2, n_preview, figsize=(6 * n_preview, 10), squeeze=False)
    sampling = random.sample(range(n_pairs), k=n_preview)
    for col, i in enumerate(sampling):
        # OpenCV loads BGR; matplotlib expects RGB.
        img = cv2.cvtColor(cv2.imread(image_list[i]), cv2.COLOR_BGR2RGB)
        label = cv2.cvtColor(cv2.imread(label_list[i]), cv2.COLOR_BGR2RGB)
        axs[0, col].imshow(img)
        axs[0, col].axis('off')
        axs[0, col].set_title(f"Image {col+1}")
        axs[1, col].imshow(label)
        axs[1, col].axis('off')
        axs[1, col].set_title(f"Mask {col+1}")
    plt.show()

In [None]:
def mask_to_onehot(mask, color_dict=COLOR_DICT, tol=10):
    """Convert a colour-coded mask into a one-hot class map.

    Args:
        mask: 3-D array (height, width, channels) holding one colour per pixel.
        color_dict: sequence of reference colours; entry ``i`` defines class ``i``.
        tol: maximum absolute per-channel difference for a pixel to count as
            matching a reference colour.

    Returns:
        uint8 array of shape (height, width, len(color_dict)). Channel ``i`` is
        1 where the pixel matches colour ``i`` on every channel and 0 elsewhere.
        A pixel that matches no colour is 0 in every channel.
    """
    # Compare in floating point: if both operands are unsigned (e.g. a uint8
    # mask against a uint8 palette) the subtraction wraps around (0 - 255 == 1)
    # and distant colours would pass the tolerance test.
    pixels = np.asarray(mask, dtype=np.float64)
    h, w, _ = pixels.shape
    onehot = np.zeros((h, w, len(color_dict)), dtype=np.uint8)
    for i, color in enumerate(color_dict):
        reference = np.asarray(color, dtype=np.float64)
        matches = np.all(np.abs(pixels - reference) <= tol, axis=-1)
        onehot[..., i] = matches.astype(np.uint8)
    return onehot

def adjustData(img, mask, target_size):
    """Resize a batch of images and masks and convert them to training arrays.

    Args:
        img: iterable of images.
        mask: iterable of colour-coded masks.
        target_size: size tuple handed unchanged to ``cv2.resize``.

    Returns:
        Tuple ``(images, onehot_masks)``: the resized images stacked and divided
        by 255.0, and the resized masks (nearest-neighbour) stacked after
        conversion with ``mask_to_onehot``.
    """
    resized_images = []
    for picture in img:
        resized_images.append(cv2.resize(picture, target_size))
    scaled_images = np.array(resized_images) / 255.0

    encoded_masks = []
    for raw_mask in mask:
        # Nearest-neighbour resizing does not invent new colours.
        nearest = cv2.resize(raw_mask, target_size, interpolation=cv2.INTER_NEAREST)
        encoded_masks.append(mask_to_onehot(nearest))

    return scaled_images, np.array(encoded_masks)

In [None]:
def _read_rgb(path):
    """Load an image file with OpenCV and return it in RGB channel order.

    Raises:
        OSError: if OpenCV cannot read the file (``cv2.imread`` returns None
            in that case instead of raising).
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise OSError(f"Could not read image file: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def trainGenerator(batch_size, image_list, mask_list, augment_dict=None, target_size=image_size):
    """Endlessly yield shuffled ``(images, one-hot masks)`` batches.

    Args:
        batch_size: number of samples per batch; the last batch of each pass
            may be smaller.
        image_list: paths of the input images.
        mask_list: paths of the masks; entry ``j`` belongs to ``image_list[j]``.
        augment_dict: keyword arguments for ``ImageDataGenerator``. When empty
            or None, no augmentation is applied.
        target_size: size passed on to ``adjustData`` for resizing.

    Yields:
        The ``(X, Y)`` pair returned by ``adjustData`` for each batch.

    Raises:
        ValueError: if ``image_list`` is empty or the two lists differ in length.
        OSError: if an image or mask file cannot be read.
    """
    if len(image_list) == 0:
        # Without this guard the `while True` loop below would spin forever
        # without ever yielding a batch.
        raise ValueError("trainGenerator received an empty image_list")
    if len(image_list) != len(mask_list):
        raise ValueError(
            f"image_list has {len(image_list)} entries but mask_list has {len(mask_list)}"
        )

    datagen_args = augment_dict if augment_dict else {}
    image_datagen = ImageDataGenerator(**datagen_args)
    mask_datagen = ImageDataGenerator(**datagen_args)

    while True:
        # Fresh random order for every pass over the data.
        idxs = np.random.permutation(len(image_list))
        for start in range(0, len(image_list), batch_size):
            batch_idxs = idxs[start:start + batch_size]
            batch_images = [_read_rgb(image_list[j]) for j in batch_idxs]
            batch_masks = [_read_rgb(mask_list[j]) for j in batch_idxs]
            if augment_dict:
                # One shared seed so both flows make the same random choices
                # for an image and its mask.
                # NOTE(review): np.array(...) stacking assumes every image in
                # the batch has the same shape - confirm for this dataset.
                seed = np.random.randint(0, 10000)
                batch_images = next(image_datagen.flow(np.array(batch_images), batch_size=batch_size, seed=seed))
                batch_masks = next(mask_datagen.flow(np.array(batch_masks), batch_size=batch_size, seed=seed))
            X, Y = adjustData(batch_images, batch_masks, target_size)
            yield X, Y

In [None]:
# U-Net model
def conv_bn(filters, x):
    """Apply two consecutive Conv2D(3x3, 'same') -> BatchNormalization -> ReLU stages.

    Args:
        filters: number of filters used by both convolutions.
        x: input tensor.

    Returns:
        The tensor produced by the second stage.
    """
    for _ in range(2):
        x = Conv2D(filters, 3, padding='same', kernel_initializer='he_normal')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    return x

def unet(input_size=(512, 512, 3)):
    """Build the U-Net segmentation model.

    Four ``conv_bn`` + max-pool stages (64, 128, 256, 512 filters), a
    1024-filter bottleneck, then four upsample + 2x2 conv + skip-concatenate +
    ``conv_bn`` stages (512, 256, 128, 64 filters), finished by a 1x1 softmax
    convolution with ``num_class`` output channels.

    Args:
        input_size: shape passed to ``Input``.

    Returns:
        The (uncompiled) Keras ``Model``.
    """
    inputs = Input(input_size)

    # Contracting path: remember each stage's output for the skip connections.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = conv_bn(filters, x)
        skips.append(x)
        x = MaxPooling2D((2,2))(x)

    # Bottleneck
    x = conv_bn(1024, x)

    # Expanding path: deepest skip connection is consumed first.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = UpSampling2D((2,2))(x)
        x = Conv2D(filters, 2, activation='relu', padding='same')(x)
        x = concatenate([skip, x])
        x = conv_bn(filters, x)

    outputs = Conv2D(num_class, 1, activation='softmax')(x)
    return Model(inputs, outputs)

In [None]:
# Augmentation settings: random flips for training, none for validation.
train_gen_args = {"horizontal_flip": True, "vertical_flip": True}
val_gen_args = {"horizontal_flip": False, "vertical_flip": False}

# 80/20 train/validation split by position in the sorted file lists.
split_idx = int(0.8*len(image_list))
train_images, val_images = image_list[:split_idx], image_list[split_idx:]
train_masks, val_masks = label_list[:split_idx], label_list[split_idx:]

# Endless batch generators for each split.
trainGene = trainGenerator(batch_size, train_images, train_masks, augment_dict=train_gen_args)
valGene = trainGenerator(batch_size, val_images, val_masks, augment_dict=val_gen_args)

# Build the network for 3-channel inputs of the configured size and compile it.
model = unet(input_size=(*image_size[:2], 3))
model.compile(
    loss='categorical_crossentropy',
    optimizer=Adam(learning_rate=1e-4),
    metrics=['accuracy'],
)

In [None]:
# Print the layer-by-layer summary of the compiled model.
model.summary()

In [None]:
# Number of generator batches per pass over each split. The generators also
# emit the trailing partial batch, so round up: floor division would skip that
# batch and give 0 steps for a split smaller than one batch.
steps_per_epoch = max(1, int(np.ceil(len(train_images) / batch_size)))
validation_steps = max(1, int(np.ceil(len(val_images) / batch_size)))

# Keep only the best weights, judged on validation loss: fit() is given
# validation data, and training loss can keep falling while the model overfits.
checkpoint = ModelCheckpoint('unet_best.h5', monitor='val_loss', verbose=1, save_best_only=True)

In [None]:
# Train the network; the returned History object is used for plotting afterwards.
history = model.fit(
    trainGene,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=valGene,
    validation_steps=validation_steps,
    callbacks=[checkpoint],
)

In [None]:
# Training vs. validation accuracy per epoch
fig_acc, ax_acc = plt.subplots()
for metric_key, legend_name in (('accuracy', 'train'), ('val_accuracy', 'val')):
    ax_acc.plot(history.history[metric_key], label=legend_name)
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.legend()
plt.show()

In [None]:
from sklearn.metrics import jaccard_score
import matplotlib.pyplot as plt

# Evaluation inputs and the folder that receives the predicted masks.
test_img_dir = "/content/images-in"
test_mask_dir = "/content/testMask"
save_pred_dir = "test_predictions1"
os.makedirs(save_pred_dir, exist_ok=True)

test_images = sorted(glob.glob(os.path.join(test_img_dir, "*")))
# Ground-truth masks are optional: None when the mask folder does not exist.
if os.path.exists(test_mask_dir):
    test_masks = sorted(glob.glob(os.path.join(test_mask_dir, "*")))
else:
    test_masks = None

def preprocess_test_image(img_path):
    """Load one test image and prepare it as network input.

    Args:
        img_path: path of the image file.

    Returns:
        The image in RGB order, resized to ``image_size`` and divided by 255.0.

    Raises:
        OSError: if OpenCV cannot read the file (``cv2.imread`` returns None
            in that case instead of raising).
    """
    bgr = cv2.imread(img_path)
    if bgr is None:
        raise OSError(f"Could not read image file: {img_path}")
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, image_size)
    return img / 255.0

def onehot_to_color(pred_mask, color_dict=COLOR_DICT):
    """Render a per-pixel class score map as a colour image.

    Args:
        pred_mask: array whose last axis holds the per-class scores; after the
            argmax over that axis it must be two-dimensional.
        color_dict: colours to paint with; entry ``i`` is used for class ``i``.

    Returns:
        uint8 array of shape (height, width, 3). Pixels whose winning class has
        no entry in ``color_dict`` stay zero.
    """
    class_idx = np.argmax(pred_mask, axis=-1)
    rows, cols = class_idx.shape
    rendered = np.zeros((rows, cols, 3), dtype=np.uint8)
    for class_id, rgb in enumerate(color_dict):
        selected = class_idx == class_id
        rendered[selected] = rgb
    return rendered

# For every test image: predict, save the colour-coded prediction, show it next
# to the input (and the ground truth when available), and report the IoU.
for i, img_path in enumerate(test_images):
    img_name = os.path.basename(img_path)

    img = preprocess_test_image(img_path)
    pred = model.predict(np.expand_dims(img, 0))[0]
    pred_color = onehot_to_color(pred)

    # pred_color is RGB; OpenCV writes BGR.
    save_path = os.path.join(save_pred_dir, img_name)
    cv2.imwrite(save_path, cv2.cvtColor(pred_color, cv2.COLOR_RGB2BGR))

    # Ground truth is paired with the image by sorted position. There may be
    # fewer masks than images, or a mask may be unreadable; such images are
    # still predicted and shown, just without ground truth and IoU.
    true_mask = None
    if test_masks and i < len(test_masks):
        mask_path = test_masks[i]
        mask_bgr = cv2.imread(mask_path)
        if mask_bgr is None:
            print(f"{img_name} - could not read ground-truth mask {mask_path}; skipping IoU")
        else:
            true_mask = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2RGB)
            # Nearest-neighbour keeps the exact class colours. The default
            # bilinear interpolation blends colours along class borders, and
            # blended pixels can fail to match any reference colour.
            true_mask = cv2.resize(true_mask, image_size, interpolation=cv2.INTER_NEAREST)

    plt.figure(figsize=(15,5))

    plt.subplot(1, 3, 1)
    plt.imshow((img * 255).astype(np.uint8))
    plt.title("Image")
    plt.axis("off")

    if true_mask is not None:
        plt.subplot(1, 3, 2)
        plt.imshow(true_mask)
        plt.title("Ground Truth Mask")
        plt.axis("off")

    plt.subplot(1, 3, 3)
    plt.imshow(pred_color)
    plt.title("Predicted Mask")
    plt.axis("off")

    plt.show()

    # ----- IoU -----
    if true_mask is not None:
        # Map both colour images back to class ids, then compare per pixel.
        true_onehot = mask_to_onehot(true_mask)
        pred_onehot = mask_to_onehot(pred_color)

        y_true = np.argmax(true_onehot, axis=-1).flatten()
        y_pred = np.argmax(pred_onehot, axis=-1).flatten()

        iou = jaccard_score(y_true, y_pred, average='macro')
        print(f"{img_name} - IoU: {iou:.4f}")