In [None]:
def data_generator(num_samples=1000, image_size=(256, 256)):
    """Yield synthetic (image, mask) pairs for smoke-testing a segmentation pipeline.

    Args:
        num_samples: Number of pairs to yield. Defaults to 1000.
        image_size: ``(height, width)`` of every generated sample.
            Defaults to ``(256, 256)``.

    Yields:
        tuple: ``(image, mask)`` where ``image`` is a float32 array of shape
        ``(height, width, 3)`` produced by ``np.random.rand`` and ``mask`` is a
        uint8 array of shape ``(height, width, 1)`` containing only 0 and 1.
    """
    height, width = image_size
    # A non-default image_size must be paired with a matching output_signature
    # when this generator is handed to tf.data.Dataset.from_generator.
    for _ in range(num_samples):
        image = np.random.rand(height, width, 3).astype(np.float32)  # RGB image
        mask = np.random.randint(0, 2, (height, width, 1), dtype=np.uint8)  # Binary mask
        yield image, mask

# Wrap the generator in a tf.data pipeline: declare the per-sample signature,
# then shuffle, batch and prefetch in one fluent chain.
dataset = (
    tf.data.Dataset
    .from_generator(
        data_generator,
        output_signature=(
            tf.TensorSpec(shape=(256, 256, 3), dtype=tf.float32),  # image
            tf.TensorSpec(shape=(256, 256, 1), dtype=tf.uint8),    # mask
        ),
    )
    .shuffle(1000)               # shuffle buffer of 1000 samples
    .batch(16)                   # 16 samples per batch
    .prefetch(tf.data.AUTOTUNE)  # let tf.data pick the prefetch depth
)

In [None]:
# Generator over the in-memory samples.
# NOTE(review): this reuses the name `data_generator` from an earlier cell and
# silently replaces that definition once this cell runs.
def data_generator():
    """Yield ``(x, y)`` pairs by walking ``X_list`` and ``Y_list`` in lockstep.

    Iteration stops at the end of the shorter list. ``X_list`` and ``Y_list``
    are read from the notebook's global namespace; they are not defined in
    this cell.
    """
    yield from zip(X_list, Y_list)

# Build the training dataset from the generator.
BATCH_SIZE = 16
dataset = tf.data.Dataset.from_generator(
    data_generator,
    output_signature=(
        tf.TensorSpec(shape=(256, 256, 3), dtype=tf.float32),  # image
        tf.TensorSpec(shape=(256, 256, 4), dtype=tf.float32),  # one-hot mask (4 classes)
    ),
)

# Input pipeline: shuffle -> batch -> prefetch, one step per line.
dataset = dataset.shuffle(buffer_size=1000)
dataset = dataset.batch(batch_size=BATCH_SIZE)
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

In [None]:
from tensorflow.keras import layers, Model

def unet_multiclass(input_shape=(256, 256, 3), num_classes=4):
    """Build a three-level U-Net for multi-class semantic segmentation.

    The encoder applies double 3x3 convolutions with 64, 128 and 256 filters,
    each followed by 2x2 max pooling. The bottleneck uses 512 filters. The
    decoder mirrors the encoder with 2x2 upsampling, a skip connection to the
    matching encoder stage, and double convolutions with 256, 128 and 64
    filters. A final 1x1 convolution with softmax produces per-class scores.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
            Height and width must be divisible by 8 (three 2x2 poolings) so
            the upsampled decoder tensors line up with the skip connections;
            ``None`` dimensions are not checked.
        num_classes: Number of output channels of the softmax layer.

    Returns:
        An uncompiled Keras ``Model`` mapping ``input_shape`` images to a
        ``(height, width, num_classes)`` softmax map.

    Raises:
        ValueError: If a known height or width is not divisible by 8.
    """
    encoder_filters = (64, 128, 256)
    bottleneck_filters = 512

    def double_conv(tensor, filters):
        """Apply two 3x3 same-padded ReLU convolutions with `filters` channels."""
        for _ in range(2):
            tensor = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        return tensor

    # Fail early with a clear message instead of an obscure concatenate shape
    # mismatch deep inside the decoder.
    downsample_factor = 2 ** len(encoder_filters)
    for dim in input_shape[:2]:
        if dim is not None and dim % downsample_factor != 0:
            raise ValueError(
                f"input_shape height and width must be divisible by "
                f"{downsample_factor}, got {tuple(input_shape)}"
            )

    inputs = layers.Input(input_shape)

    # Encoder: remember each stage's output for the skip connections.
    skips = []
    x = inputs
    for filters in encoder_filters:
        x = double_conv(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = double_conv(x, bottleneck_filters)

    # Decoder: walk the encoder stages in reverse, deepest skip first.
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = layers.UpSampling2D((2, 2))(x)
        x = layers.concatenate([skip, x], axis=-1)
        x = double_conv(x, filters)

    # Output layer for multi-class segmentation
    outputs = layers.Conv2D(num_classes, (1, 1), activation='softmax')(x)

    return Model(inputs, outputs)

# Instantiate the U-Net with its default arguments and configure it for training
# with categorical cross-entropy, tracking accuracy and mean IoU.
# NOTE(review): `iou_metric` is defined in a later cell of this notebook, so on a
# fresh kernel that cell must be run before this one — TODO move it above.
model = unet_multiclass()
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy', iou_metric],
)
model.summary()


In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping

# Training configuration
EPOCHS = 50
CHECKPOINT_PATH = "best_unet_multiclass.h5"

# Callbacks, all driven by the validation loss.
callbacks = [
    # Keep only the weights with the lowest val_loss seen so far.
    ModelCheckpoint(
        filepath=CHECKPOINT_PATH,
        monitor="val_loss",
        mode="min",
        save_best_only=True,
        verbose=1,
    ),
    # Halve the learning rate after 3 epochs without improvement (floor 1e-6).
    ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=3,
        min_lr=1e-6,
        verbose=1,
    ),
    # Stop after 7 stagnant epochs and roll back to the best weights.
    EarlyStopping(
        monitor="val_loss",
        patience=7,
        restore_best_weights=True,
        verbose=1,
    ),
]

# Run training.
# NOTE(review): `val_dataset` is not created in the cells shown here; it must
# already exist in the kernel — TODO confirm where it is built.
history = model.fit(
    x=dataset,
    validation_data=val_dataset,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1,
)


In [None]:
import tensorflow as tf

def iou_metric(y_true, y_pred):
    """
    Computes the mean Intersection over Union (IoU) for multi-class segmentation.

    Both inputs are reduced to class labels with ``argmax`` over the last axis,
    so ``y_true`` is expected to be one-hot and ``y_pred`` per-class scores.
    The IoU of every class is accumulated over all positions in the batch and
    the per-class values are averaged.

    Args:
        y_true: One-hot ground truth with the class axis last.
        y_pred: Predicted class scores with the class axis last; its last
            dimension defines the number of classes.

    Returns:
        Scalar float32 tensor with the mean IoU. A class absent from both
        ``y_true`` and ``y_pred`` scores 1 because of the 1e-6 smoothing term.
    """
    y_pred = tf.convert_to_tensor(y_pred)

    # Take the class count from the prediction's channel axis rather than from
    # max(y_true): the largest label in a batch varies between batches and would
    # ignore predictions of any class above it.
    num_classes = y_pred.shape[-1]
    if num_classes is None:
        num_classes = tf.shape(y_pred)[-1]

    pred_labels = tf.argmax(y_pred, axis=-1)  # Convert softmax probabilities to class labels
    true_labels = tf.argmax(y_true, axis=-1)  # Convert one-hot to class labels

    # One row per position, one column per class. This vectorised form replaces
    # a Python loop over a tensor-valued range that appended to a Python list,
    # which cannot be traced in graph mode.
    true_mask = tf.reshape(
        tf.one_hot(true_labels, num_classes, dtype=tf.float32), [-1, num_classes]
    )
    pred_mask = tf.reshape(
        tf.one_hot(pred_labels, num_classes, dtype=tf.float32), [-1, num_classes]
    )

    intersection = tf.reduce_sum(true_mask * pred_mask, axis=0)
    union = tf.reduce_sum(true_mask, axis=0) + tf.reduce_sum(pred_mask, axis=0) - intersection

    iou = (intersection + 1e-6) / (union + 1e-6)  # Avoid division by zero

    return tf.reduce_mean(iou)  # Average IoU across classes
    

In [None]:
def dice_loss(y_true, y_pred):
    """Return one minus the batch-mean soft Dice coefficient.

    ``y_true`` is cast to float32 and ``y_pred`` is clipped to
    ``[1e-7, 1 - 1e-7]``. For every sample the products and sums are reduced
    over axes 1, 2 and 3, and a 1e-6 smoothing term in numerator and
    denominator keeps the ratio finite when both sums are zero.

    Args:
        y_true: Ground-truth tensor; reduced over axes 1-3, so rank 4 is required.
        y_pred: Prediction tensor with the same shape as ``y_true``.

    Returns:
        Scalar tensor ``1 - mean(dice)``.
    """
    smooth = 1e-6
    per_sample_axes = [1, 2, 3]

    targets = tf.cast(y_true, tf.float32)
    probs = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)

    overlap = tf.reduce_sum(targets * probs, axis=per_sample_axes)
    total = (
        tf.reduce_sum(targets, axis=per_sample_axes)
        + tf.reduce_sum(probs, axis=per_sample_axes)
    )

    dice = (2. * overlap + smooth) / (total + smooth)
    return 1 - tf.reduce_mean(dice)

# Re-compile the model so that it trains against the Dice loss.
model.compile(
    optimizer='adam',
    loss=dice_loss,
    metrics=['accuracy', iou_metric],
)

In [None]:
def dice_ce_loss(y_true, y_pred):
    """Blend Dice loss and categorical cross-entropy with equal 0.5 weights.

    Args:
        y_true: Ground-truth tensor, passed unchanged to both loss terms.
        y_pred: Prediction tensor, passed unchanged to both loss terms.

    Returns:
        ``0.5 * dice_loss(...) + 0.5 * categorical_crossentropy(...)``.
    """
    dice_term = dice_loss(y_true, y_pred)
    ce_term = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    return 0.5 * dice_term + 0.5 * ce_term

# Re-compile the model with the combined Dice + cross-entropy loss.
model.compile(
    optimizer='adam',
    loss=dice_ce_loss,
    metrics=['accuracy', iou_metric],
)


In [None]:
import tensorflow as tf

# Example per-class loss weights for 4 classes, indexed by class id.
# Tune these to the class distribution of the dataset.
class_weights = tf.constant(
    [0.1, 0.3, 0.3, 0.3],
    dtype=tf.float32,
)

def weighted_categorical_crossentropy(y_true, y_pred):
    """Categorical cross-entropy scaled by the weight of each true class.

    The weight at every position is the sum of ``class_weights * y_true`` over
    the last axis, which selects the weight of the labelled class when
    ``y_true`` is one-hot.

    Args:
        y_true: Ground-truth tensor with the class axis last; cast to float32.
        y_pred: Predicted class probabilities with the same shape.

    Returns:
        Scalar tensor: the mean of the weighted cross-entropy values.
    """
    targets = tf.cast(y_true, tf.float32)

    # Weight of the labelled class at every position.
    label_weights = tf.reduce_sum(targets * class_weights, axis=-1)

    unweighted = tf.keras.losses.categorical_crossentropy(targets, y_pred)
    return tf.reduce_mean(label_weights * unweighted)

# Re-compile the model with the class-weighted categorical cross-entropy.
model.compile(
    optimizer='adam',
    loss=weighted_categorical_crossentropy,
    metrics=['accuracy', iou_metric],
)
