# Module 7 — Image Segmentation (UNet) — Expanded

This notebook demonstrates an end-to-end segmentation workflow using a tiny synthetic dataset of shapes (circles/rectangles) with binary masks. It includes:

- creating synthetic images + masks
- `tf.data` pipeline for segmentation (image, mask)
- a simple UNet implementation in Keras
- a combined BCE + Dice loss and a Dice-coefficient metric
- short training and visualization of predictions
- saving the model

This notebook is classroom-friendly and runs quickly. For real projects, replace the synthetic data with real image/mask pairs.

## 1 — Setup (install packages and imports)

In [None]:
# Install/ensure tensorflow
!pip -q install -U tensorflow==2.12.0 --quiet

import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models
print('TF version:', tf.__version__)

# Make sure the output folders for the synthetic images and masks exist
for folder in ('/mnt/data/seg_dataset/images', '/mnt/data/seg_dataset/masks'):
    os.makedirs(folder, exist_ok=True)


## 2 — Create synthetic dataset (images + binary masks)

In [None]:
from PIL import Image, ImageDraw
import random

def create_shape_image(path_img, path_mask, size=(128,128)):
    """Draw one random filled shape and save it together with its mask.

    The image is an RGB canvas with a white background; the mask is an
    8-bit grayscale canvas where the background is 0 and the shape is 255.
    With probability 0.5 the shape is a circle, otherwise a rectangle.

    Args:
        path_img: Destination path of the RGB image.
        path_mask: Destination path of the grayscale mask.
        size: (width, height) of both canvases.
    """
    canvas = Image.new('RGB', size, (255,255,255))
    target = Image.new('L', size, 0)
    pen = ImageDraw.Draw(canvas)
    mask_pen = ImageDraw.Draw(target)
    width, height = size[0], size[1]

    def random_colour():
        # Mid-range channel values keep the shape distinct from the white background
        return tuple(random.randint(50,200) for _ in range(3))

    is_circle = random.random() < 0.5
    if is_circle:
        # Centre is sampled between 16 px and the middle of each axis
        cx = random.randint(16, width//2)
        cy = random.randint(16, height//2)
        radius = random.randint(16, min(size)//3)
        box = [cx-radius, cy-radius, cx+radius, cy+radius]
        paint, paint_mask = pen.ellipse, mask_pen.ellipse
    else:
        # First corner lies in the first half of each axis, second corner in the other half
        left = random.randint(8, width//2)
        top = random.randint(8, height//2)
        right = random.randint(width//2, width-8)
        bottom = random.randint(height//2, height-8)
        box = [left, top, right, bottom]
        paint, paint_mask = pen.rectangle, mask_pen.rectangle

    # The same bounding box is drawn on both canvases so image and mask align
    paint(box, fill=random_colour())
    paint_mask(box, fill=255)

    canvas.save(path_img)
    target.save(path_mask)

# Dataset sizes (n_train / n_val are module-level names)
n_train = 80
n_val = 20

# Generate the training pairs first, then the validation pairs
for idx in range(n_train):
    create_shape_image(
        f'/mnt/data/seg_dataset/images/train_{idx}.png',
        f'/mnt/data/seg_dataset/masks/train_{idx}.png',
    )
for idx in range(n_val):
    create_shape_image(
        f'/mnt/data/seg_dataset/images/val_{idx}.png',
        f'/mnt/data/seg_dataset/masks/val_{idx}.png',
    )

print('Created synthetic segmentation dataset with', n_train, 'train and', n_val, 'val samples')

# Preview the first three training samples: images on the top row, masks below
fig, axes = plt.subplots(2,3, figsize=(9,6))
for col in range(3):
    sample = Image.open(f'/mnt/data/seg_dataset/images/train_{col}.png')
    sample_mask = Image.open(f'/mnt/data/seg_dataset/masks/train_{col}.png')
    top_ax, bottom_ax = axes[0,col], axes[1,col]
    top_ax.imshow(sample)
    top_ax.axis('off')
    bottom_ax.imshow(sample_mask, cmap='gray')
    bottom_ax.axis('off')
plt.suptitle('Sample images (top) and masks (bottom)')
plt.show()


## 3 — tf.data pipeline (image, mask pairs)

In [None]:
import tensorflow as tf
from glob import glob

IMG_SIZE = (128,128)
BATCH = 8

DATA_ROOT = '/mnt/data/seg_dataset'


def _split_paths(prefix):
    """Return sorted (image_paths, mask_paths) for the split named *prefix*.

    Files are selected by their filename prefix ('train' or 'val') rather
    than by slicing one sorted listing, so the split does not depend on how
    the prefixes happen to sort or on how many files are on disk.

    Raises:
        ValueError: if the image and mask folders do not contain the same
            filenames for this split (the pairs would be misaligned).
    """
    images = sorted(glob(f'{DATA_ROOT}/images/{prefix}_*.png'))
    masks = sorted(glob(f'{DATA_ROOT}/masks/{prefix}_*.png'))
    image_names = [os.path.basename(p) for p in images]
    mask_names = [os.path.basename(p) for p in masks]
    if image_names != mask_names:
        raise ValueError(f'Image/mask files for split {prefix!r} do not pair up')
    return images, masks


# split by filename prefix (train_* vs val_*)
train_images, train_masks = _split_paths('train')
val_images, val_masks = _split_paths('val')

# Combined listings, kept for any code that still expects the full sorted lists
image_paths = train_images + val_images
mask_paths = train_masks + val_masks

def load_image_mask(img_path, mask_path):
    """Read one PNG image and its PNG mask and return them as float32 tensors.

    Args:
        img_path: Path of the image file (decoded with 3 channels).
        mask_path: Path of the mask file (decoded with 1 channel).

    Returns:
        (img, m): both resized to IMG_SIZE and divided by 255.0.
    """
    img = tf.io.read_file(img_path)
    img = tf.image.decode_png(img, channels=3)
    img = tf.image.resize(img, IMG_SIZE)
    img = tf.cast(img, tf.float32) / 255.0
    m = tf.io.read_file(mask_path)
    m = tf.image.decode_png(m, channels=1)
    # Nearest-neighbour resizing copies existing pixel values, so the mask
    # stays strictly 0/255; the default bilinear filter would blend labels
    # at shape edges whenever the file size differs from IMG_SIZE.
    m = tf.image.resize(m, IMG_SIZE, method='nearest')
    m = tf.cast(m, tf.float32) / 255.0
    return img, m

def _make_dataset(images, masks, shuffle=False):
    """Build a batched, prefetched dataset of decoded (image, mask) pairs."""
    ds = tf.data.Dataset.from_tensor_slices((images, masks))
    ds = ds.map(load_image_mask, num_parallel_calls=tf.data.AUTOTUNE)
    if shuffle:
        # Shuffling happens on decoded samples, before batching
        ds = ds.shuffle(100)
    return ds.batch(BATCH).prefetch(tf.data.AUTOTUNE)


train_ds = _make_dataset(train_images, train_masks, shuffle=True)
val_ds = _make_dataset(val_images, val_masks)

# Sanity check: print the shapes of a single training batch
for imgs, masks in train_ds.take(1):
    print('Batch shapes:', imgs.shape, masks.shape)


## 4 — Simple UNet implementation (Keras)

In [None]:
from tensorflow.keras import layers, models

def conv_block(x, filters):
    """Apply two stacked 3x3 'same'-padded ReLU convolutions with `filters` channels."""
    for _ in range(2):
        x = layers.Conv2D(filters, 3, padding='same', activation='relu')(x)
    return x

def encoder_block(x, filters):
    """Run a conv block, then 2x2 max-pool it.

    Returns:
        (features, pooled): the conv-block output (used as the skip
        connection) and its max-pooled version.
    """
    features = conv_block(x, filters)
    pooled = layers.MaxPooling2D((2,2))(features)
    return features, pooled

def decoder_block(x, skip, filters):
    """Upsample `x` by 2, concatenate it with `skip`, and apply a conv block."""
    upsampled = layers.UpSampling2D((2,2))(x)
    merged = layers.Concatenate()([upsampled, skip])
    return conv_block(merged, filters)

inputs = layers.Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 3))

# Encoder: keep each stage's pre-pooling features for the skip connections
skips = []
x = inputs
for width in (16, 32, 64):
    skip, x = encoder_block(x, width)
    skips.append(skip)

# Bottleneck
x = conv_block(x, 128)

# Decoder: mirror the encoder, consuming the skips from deepest to shallowest
for width, skip in zip((64, 32, 16), reversed(skips)):
    x = decoder_block(x, skip, width)

# 1x1 convolution with sigmoid gives one per-pixel output in [0, 1]
outputs = layers.Conv2D(1, 1, activation='sigmoid')(x)

unet = models.Model(inputs, outputs)
unet.summary()


## 5 — Loss (BCE + Dice) and metrics

In [None]:
import tensorflow as tf

def dice_coef(y_true, y_pred, smooth=1e-6):
    """Return the Dice coefficient computed over all elements of the inputs.

    Both tensors are flattened, so the score is a single scalar for the
    whole batch: (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth).

    Args:
        y_true: Ground-truth mask values.
        y_pred: Predicted mask values.
        smooth: Small constant added to numerator and denominator so the
            ratio is defined when both masks are empty.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Align dtypes first: integer/bool masks (or non-float32 predictions)
    # would otherwise make the element-wise product fail.
    y_true = tf.cast(y_true, y_pred.dtype)
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)

def bce_dice_loss(y_true, y_pred):
    """Return binary cross-entropy plus Dice loss (1 - dice_coef)."""
    bce_term = tf.keras.losses.BinaryCrossentropy()(y_true, y_pred)
    dice_term = 1 - dice_coef(y_true, y_pred)
    total = bce_term + dice_term
    return total

# Optimise the combined BCE + Dice loss with Adam and report Dice as the metric
unet.compile(
    optimizer='adam',
    loss=bce_dice_loss,
    metrics=[dice_coef],
)


## 6 — Train UNet (short demo)

In [None]:
EPOCHS = 8
history = unet.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)

# Plot the per-epoch curves side by side: loss on the left, Dice on the right
curves = history.history
fig, (loss_ax, dice_ax) = plt.subplots(1, 2, figsize=(10,4))

loss_ax.plot(curves['loss'], label='train_loss')
loss_ax.plot(curves['val_loss'], label='val_loss')
loss_ax.legend()
loss_ax.set_title('Loss')

dice_ax.plot(curves['dice_coef'], label='train_dice')
dice_ax.plot(curves['val_dice_coef'], label='val_dice')
dice_ax.legend()
dice_ax.set_title('Dice')

plt.show()


## 7 — Visualize predictions on validation images

In [None]:
# Visualize predictions for the first validation batch.
# Exactly one batch is requested, so no early `break` is needed.
N_COLS = 6  # at most this many samples are shown, one per column
for imgs, masks in val_ds.take(1):
    preds = unet.predict(imgs)
    plt.figure(figsize=(12,6))
    # Grid is 3 rows x N_COLS columns: row 1 images, row 2 masks, row 3 predictions
    for i in range(min(N_COLS, imgs.shape[0])):
        plt.subplot(3, N_COLS, i + 1)
        plt.imshow(imgs[i]); plt.axis('off'); plt.title('Image')
        plt.subplot(3, N_COLS, i + 1 + N_COLS)
        plt.imshow(masks[i,:,:,0], cmap='gray'); plt.axis('off'); plt.title('Mask')
        # Threshold the sigmoid output at 0.5 to get a binary prediction
        plt.subplot(3, N_COLS, i + 1 + 2 * N_COLS)
        plt.imshow(preds[i,:,:,0]>0.5, cmap='gray'); plt.axis('off'); plt.title('Pred >0.5')
    plt.show()


## 8 — Save model and tips

In [None]:
# Persist the trained model to a single .h5 file.
# NOTE(review): the model was compiled with the custom `bce_dice_loss` and
# `dice_coef` functions, so reloading it presumably requires passing them via
# `custom_objects` (or loading with compile=False) — confirm against the Keras docs.
unet.save('/mnt/data/unet_synthetic.h5')
print('Saved UNet to /mnt/data/unet_synthetic.h5')

# Tips: Replace synthetic dataset with real image/mask pairs; adjust model capacity for real tasks.


## 9 — Exercises & Instructor Notes

- Replace synthetic dataset with a real segmentation dataset (e.g., plant leaf segmentation) and re-run training.
- Try different loss functions (Dice loss alone, Focal loss) for class imbalance.
- Show morphological post-processing (opening/closing) on predicted masks to clean artifacts.
- Export model to SavedModel or TFLite for deployment demos.
