## Importation

In [1]:
import os
import random
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models, backend as K
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping

In [2]:
# -------------------------
# CONFIG
# -------------------------
ROOT = r"D:\WORK\Python\Project\CV\flood_map_segmentation\floodmap_segmentation\data"            # <-- set this to the folder that contains metadata.csv, images/, masks/
CSV_PATH = os.path.join(ROOT, "metadata.csv")
IMAGES_DIR = os.path.join(ROOT, "Image")   # .jpg
MASKS_DIR  = os.path.join(ROOT, "Mask")    # .png

IMAGE_SIZE = 256
BATCH_SIZE = 8
EPOCHS = 40
SEED = 42
AUTOTUNE = tf.data.AUTOTUNE
N_CLASSES = 1        # 1 for binary segmentation (sigmoid). If multiclass, set >1 and adjust masks accordingly.

# Reproducibility
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

## Pair images and masks by meta

In [3]:
def discover_pairs_from_csv(csv_path, images_dir, masks_dir):
    """
    Reads CSV and tries to infer image & mask filenames.
    Accepts common column names, or falls back to using first two columns.
    Returns list of (image_path, mask_path).
    """
    df = pd.read_csv(csv_path)
    # common column names
    possible_img_cols = ['image', 'img', 'image_filename', 'image_file', 'filename', 'file']
    possible_mask_cols = ['mask', 'mask_filename', 'mask_file', 'mask_path', 'segmentation']

    img_col = None
    mask_col = None
    for c in possible_img_cols:
        if c in df.columns:
            img_col = c
            break
    for c in possible_mask_cols:
        if c in df.columns:
            mask_col = c
            break

    if img_col is None or mask_col is None:
        # fallback: use first two columns
        if len(df.columns) >= 2:
            img_col, mask_col = df.columns[0], df.columns[1]
        else:
            raise ValueError("Could not detect image/mask columns from CSV. Please ensure metadata.csv contains filenames.")

    imgs = df[img_col].astype(str).tolist()
    masks = df[mask_col].astype(str).tolist()

    pairs = []
    for im, m in zip(imgs, masks):
        im_path = im if os.path.isabs(im) else os.path.join(images_dir, im)
        m_path  = m  if os.path.isabs(m) else os.path.join(masks_dir, m)
        if not os.path.exists(im_path):
            raise FileNotFoundError(f"Image not found: {im_path}")
        if not os.path.exists(m_path):
            raise FileNotFoundError(f"Mask not found: {m_path}")
        pairs.append((im_path, m_path))
    return pairs


In [14]:
# -------------------------
# DATA LOADING & PREPROCESSING
# -------------------------
def read_image(path, image_size=IMAGE_SIZE):
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)  # [0,1]
    image = tf.image.resize(image, [image_size, image_size])
    return image

def read_mask(path, image_size=IMAGE_SIZE, n_classes=N_CLASSES):
    mask = tf.io.read_file(path)
    # Masks are PNG where each pixel may be 0 or 255 (or class index)
    mask = tf.image.decode_png(mask, channels=1)  # keep 1 channel
    # convert to float32
    mask = tf.image.convert_image_dtype(mask, tf.float32)  # 0..1
    # If mask is 0/255, convert to 0/1 by thresholding at 0.5
    mask = tf.image.resize(mask, [image_size, image_size], method='nearest')
    mask = tf.where(mask>=0.5, 1.0, 0.0)
    if n_classes == 1:
        return mask
    else:
        # For multi-class integer masks, we expect integer values (0..n_classes-1)
        # Convert to one-hot if requested
        mask_int = tf.cast(tf.squeeze(mask, axis=-1) * (n_classes-1), tf.int32)  # attempt mapping 0..1 -> 0..n_classes-1
        mask_oh = tf.one_hot(mask_int, depth=n_classes)
        return tf.cast(mask_oh, tf.float32)

def preprocess_pair(image_path, mask_path, augment=False):
    image = read_image(image_path)
    mask = read_mask(mask_path)

    if augment:
        if tf.random.uniform(()) > 0.5:
            image = tf.image.flip_left_right(image)
            mask  = tf.image.flip_left_right(mask)
        if tf.random.uniform(()) > 0.5:
            image = tf.image.flip_up_down(image)
            mask  = tf.image.flip_up_down(mask)

    return image, mask


In [15]:
def make_dataset(pairs, batch_size=BATCH_SIZE, augment=False, shuffle=True):
    image_paths = [p for p, m in pairs]
    mask_paths  = [m for p, m in pairs]
    ds = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
    if shuffle:
        ds = ds.shuffle(buffer_size=len(pairs), seed=SEED, reshuffle_each_iteration=True)

    # Directly map preprocess_pair
    ds = ds.map(lambda i, m: preprocess_pair(i, m, augment=augment),
                num_parallel_calls=AUTOTUNE)

    ds = ds.batch(batch_size).prefetch(AUTOTUNE)
    return ds


## Unet Architecture

In [6]:
# -------------------------
# U-Net model
# -------------------------
def conv_block(x, filters, kernel_size=3, activation='relu', batchnorm=True, name=None):
    x = layers.Conv2D(filters, kernel_size, padding="same", use_bias=not batchnorm,
                      kernel_initializer="he_normal", name=(None if name is None else f"{name}_conv1"))(x)
    if batchnorm:
        x = layers.BatchNormalization(name=(None if name is None else f"{name}_bn1"))(x)
    x = layers.Activation(activation, name=(None if name is None else f"{name}_act1"))(x)

    x = layers.Conv2D(filters, kernel_size, padding="same", use_bias=not batchnorm,
                      kernel_initializer="he_normal", name=(None if name is None else f"{name}_conv2"))(x)
    if batchnorm:
        x = layers.BatchNormalization(name=(None if name is None else f"{name}_bn2"))(x)
    x = layers.Activation(activation, name=(None if name is None else f"{name}_act2"))(x)
    return x

def build_unet(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), filters=64, dropout=0.1, n_classes=N_CLASSES):
    inputs = layers.Input(shape=input_shape)

    # Encoder
    c1 = conv_block(inputs, filters, name="enc1")
    p1 = layers.MaxPooling2D((2,2))(c1)

    c2 = conv_block(p1, filters*2, name="enc2")
    p2 = layers.MaxPooling2D((2,2))(c2)

    c3 = conv_block(p2, filters*4, name="enc3")
    p3 = layers.MaxPooling2D((2,2))(c3)

    c4 = conv_block(p3, filters*8, name="enc4")
    p4 = layers.MaxPooling2D((2,2))(c4)

    # Bottleneck
    b = conv_block(p4, filters*16, name="bottleneck")
    b = layers.Dropout(dropout)(b)

    # Decoder
    u1 = layers.Conv2DTranspose(filters*8, (2,2), strides=(2,2), padding="same")(b)
    u1 = layers.concatenate([u1, c4])
    c5 = conv_block(u1, filters*8, name="dec1")

    u2 = layers.Conv2DTranspose(filters*4, (2,2), strides=(2,2), padding="same")(c5)
    u2 = layers.concatenate([u2, c3])
    c6 = conv_block(u2, filters*4, name="dec2")

    u3 = layers.Conv2DTranspose(filters*2, (2,2), strides=(2,2), padding="same")(c6)
    u3 = layers.concatenate([u3, c2])
    c7 = conv_block(u3, filters*2, name="dec3")

    u4 = layers.Conv2DTranspose(filters, (2,2), strides=(2,2), padding="same")(c7)
    u4 = layers.concatenate([u4, c1])
    c8 = conv_block(u4, filters, name="dec4")

    if n_classes == 1:
        outputs = layers.Conv2D(1, (1,1), padding="same", activation="sigmoid")(c8)
    else:
        outputs = layers.Conv2D(n_classes, (1,1), padding="same", activation="softmax")(c8)

    model = models.Model(inputs, outputs, name="UNet")
    return model

## Metrics and Lose

In [7]:
def dice_coef(y_true, y_pred, smooth=1e-6):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    return 1.0 - dice_coef(y_true, y_pred)

def bce_dice_loss(y_true, y_pred):
    bce = tf.keras.losses.BinaryCrossentropy()(y_true, y_pred)
    return 0.5 * bce + 0.5 * dice_loss(y_true, y_pred)

def iou_metric(y_true, y_pred, thresh=0.5, smooth=1e-6):
    y_pred = tf.cast(y_pred > thresh, tf.float32)
    intersection = K.sum(y_true * y_pred)
    union = K.sum(y_true) + K.sum(y_pred) - intersection
    return (intersection + smooth) / (union + smooth)

## Training Prepare 


In [8]:
pairs = discover_pairs_from_csv(CSV_PATH, IMAGES_DIR, MASKS_DIR)
print(f"Found {len(pairs)} image-mask pairs.")

# 2) Train/Val split
random.shuffle(pairs)
val_split = 0.15
n_val = int(len(pairs)*val_split)
val_pairs = pairs[:n_val]
train_pairs = pairs[n_val:]
print(f"Train: {len(train_pairs)}, Val: {len(val_pairs)}")

# 3) Create datasets
train_ds = make_dataset(train_pairs, batch_size=BATCH_SIZE, augment=True, shuffle=True)
val_ds   = make_dataset(val_pairs, batch_size=BATCH_SIZE, augment=False, shuffle=False)

Found 290 image-mask pairs.
Train: 247, Val: 43


## Model , Compile, Callbacks

In [12]:
# 4) Build model
model = build_unet(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), filters=32, dropout=0.1, n_classes=N_CLASSES)
model.summary()

# 5) Compile
if N_CLASSES == 1:
    loss = bce_dice_loss
    metrics = [dice_coef, lambda yt, yp: iou_metric(yt, yp), tf.keras.metrics.BinaryAccuracy(name="bin_acc")]
else:
    loss = tf.keras.losses.CategoricalCrossentropy()
    metrics = [tf.keras.metrics.CategoricalAccuracy(name="cat_acc")]

model.compile(optimizer=tf.keras.optimizers.Adam(1e-4), loss=loss, metrics=metrics)

# 6) Callbacks
os.makedirs("checkpoints", exist_ok=True)
checkpoint_cb = ModelCheckpoint("checkpoints/unet_best.h5", save_best_only=True, monitor="val_loss", mode="min")
reduce_cb = ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=5, min_lr=1e-7, verbose=1)
early_cb = EarlyStopping(monitor="val_loss", patience=12, restore_best_weights=True, verbose=1)

# 7) Fit
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=[checkpoint_cb, reduce_cb, early_cb]
)


Epoch 1/40


InvalidArgumentError: Graph execution error:

Detected at node IteratorGetNext defined at (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main

  File "<frozen runpy>", line 88, in _run_code

  File "C:\Users\PC\AppData\Roaming\Python\Python312\site-packages\ipykernel_launcher.py", line 18, in <module>

  File "C:\Users\PC\AppData\Roaming\Python\Python312\site-packages\traitlets\config\application.py", line 1075, in launch_instance

  File "C:\Users\PC\AppData\Roaming\Python\Python312\site-packages\ipykernel\kernelapp.py", line 739, in start

  File "C:\Users\PC\AppData\Roaming\Python\Python312\site-packages\tornado\platform\asyncio.py", line 205, in start

  File "c:\Users\PC\anaconda3\envs\image_compression\Lib\asyncio\base_events.py", line 640, in run_forever

  File "c:\Users\PC\anaconda3\envs\image_compression\Lib\asyncio\base_events.py", line 1992, in _run_once

  File "c:\Users\PC\anaconda3\envs\image_compression\Lib\asyncio\events.py", line 88, in _run

  File "C:\Users\PC\AppData\Roaming\Python\Python312\site-packages\ipykernel\kernelbase.py", line 545, in dispatch_queue

  File "C:\Users\PC\AppData\Roaming\Python\Python312\site-packages\ipykernel\kernelbase.py", line 534, in process_one

  File "C:\Users\PC\AppData\Roaming\Python\Python312\site-packages\ipykernel\kernelbase.py", line 437, in dispatch_shell

  File "C:\Users\PC\AppData\Roaming\Python\Python312\site-packages\ipykernel\ipkernel.py", line 362, in execute_request

  File "C:\Users\PC\AppData\Roaming\Python\Python312\site-packages\ipykernel\kernelbase.py", line 778, in execute_request

  File "C:\Users\PC\AppData\Roaming\Python\Python312\site-packages\ipykernel\ipkernel.py", line 449, in do_execute

  File "C:\Users\PC\AppData\Roaming\Python\Python312\site-packages\ipykernel\zmqshell.py", line 549, in run_cell

  File "C:\Users\PC\AppData\Roaming\Python\Python312\site-packages\IPython\core\interactiveshell.py", line 3075, in run_cell

  File "C:\Users\PC\AppData\Roaming\Python\Python312\site-packages\IPython\core\interactiveshell.py", line 3130, in _run_cell

  File "C:\Users\PC\AppData\Roaming\Python\Python312\site-packages\IPython\core\async_helpers.py", line 128, in _pseudo_sync_runner

  File "C:\Users\PC\AppData\Roaming\Python\Python312\site-packages\IPython\core\interactiveshell.py", line 3334, in run_cell_async

  File "C:\Users\PC\AppData\Roaming\Python\Python312\site-packages\IPython\core\interactiveshell.py", line 3517, in run_ast_nodes

  File "C:\Users\PC\AppData\Roaming\Python\Python312\site-packages\IPython\core\interactiveshell.py", line 3577, in run_code

  File "C:\Users\PC\AppData\Local\Temp\ipykernel_4556\3430092765.py", line 22, in <module>

  File "c:\Users\PC\anaconda3\envs\image_compression\Lib\site-packages\keras\src\utils\traceback_utils.py", line 117, in error_handler

  File "c:\Users\PC\anaconda3\envs\image_compression\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 377, in fit

  File "c:\Users\PC\anaconda3\envs\image_compression\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 220, in function

  File "c:\Users\PC\anaconda3\envs\image_compression\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 133, in multi_step_on_iterator

Incompatible shapes at component 0: expected [?,256,256,3] but got [8,1,256,256,3].
	 [[{{node IteratorGetNext}}]] [Op:__inference_multi_step_on_iterator_82478]

## Save model and Testing

In [None]:
model.save("unet_flood.h5")
print("Saved model to unet_final.h5")

# 9) Quick evaluation on a few validation samples (visual check)
import matplotlib.pyplot as plt

def show_sample(model, ds, n=3):
    for images, masks in ds.unbatch().batch(1).take(n):
        preds = model.predict(images)
        plt.figure(figsize=(12,4))
        plt.subplot(1,3,1); plt.title("Image"); plt.imshow(images[0]); plt.axis("off")
        plt.subplot(1,3,2); plt.title("Mask"); plt.imshow(tf.squeeze(masks[0]), cmap="gray"); plt.axis("off")
        plt.subplot(1,3,3); plt.title("Pred"); plt.imshow(tf.squeeze(preds[0])>0.5, cmap="gray"); plt.axis("off")
        plt.show()

# Show 3 val samples
show_sample(model, val_ds, n=3)

In [16]:
# Full U-Net segmentation training script (copy into one notebook cell)
import os
import random
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models, backend as K
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping

# -------------------------
# CONFIG
# -------------------------
ROOT = r"D:\WORK\Python\Project\CV\flood_map_segmentation\floodmap_segmentation\data"            # <-- set this to the folder that contains metadata.csv, images/, masks/
CSV_PATH = os.path.join(ROOT, "metadata.csv")
IMAGES_DIR = os.path.join(ROOT, "Image")   # .jpg
MASKS_DIR  = os.path.join(ROOT, "Mask")    # .png

IMAGE_SIZE = 256
BATCH_SIZE = 8
EPOCHS = 40
SEED = 42
AUTOTUNE = tf.data.AUTOTUNE
N_CLASSES = 1        # 1 for binary segmentation (sigmoid). If multiclass, set >1 and adjust masks accordingly.

# Reproducibility
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)


# -------------------------
# HELPERS: parse metadata.csv to pairs
# -------------------------
def discover_pairs_from_csv(csv_path, images_dir, masks_dir):
    """
    Reads CSV and tries to infer image & mask filenames.
    Accepts common column names, or falls back to using first two columns.
    Returns list of (image_path, mask_path).
    """
    df = pd.read_csv(csv_path)
    # common column names
    possible_img_cols = ['image', 'img', 'image_filename', 'image_file', 'filename', 'file']
    possible_mask_cols = ['mask', 'mask_filename', 'mask_file', 'mask_path', 'segmentation']

    img_col = None
    mask_col = None
    for c in possible_img_cols:
        if c in df.columns:
            img_col = c
            break
    for c in possible_mask_cols:
        if c in df.columns:
            mask_col = c
            break

    if img_col is None or mask_col is None:
        # fallback: use first two columns
        if len(df.columns) >= 2:
            img_col, mask_col = df.columns[0], df.columns[1]
        else:
            raise ValueError("Could not detect image/mask columns from CSV. Please ensure metadata.csv contains filenames.")

    imgs = df[img_col].astype(str).tolist()
    masks = df[mask_col].astype(str).tolist()

    pairs = []
    for im, m in zip(imgs, masks):
        im_path = im if os.path.isabs(im) else os.path.join(images_dir, im)
        m_path  = m  if os.path.isabs(m) else os.path.join(masks_dir, m)
        if not os.path.exists(im_path):
            raise FileNotFoundError(f"Image not found: {im_path}")
        if not os.path.exists(m_path):
            raise FileNotFoundError(f"Mask not found: {m_path}")
        pairs.append((im_path, m_path))
    return pairs

# -------------------------
# DATA LOADING & PREPROCESSING
# -------------------------
def read_image(path, image_size=IMAGE_SIZE):
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)  # [0,1]
    image = tf.image.resize(image, [image_size, image_size])
    return image

def read_mask(path, image_size=IMAGE_SIZE, n_classes=N_CLASSES):
    mask = tf.io.read_file(path)
    # Masks are PNG where each pixel may be 0 or 255 (or class index)
    mask = tf.image.decode_png(mask, channels=1)  # keep 1 channel
    # convert to float32
    mask = tf.image.convert_image_dtype(mask, tf.float32)  # 0..1
    # If mask is 0/255, convert to 0/1 by thresholding at 0.5
    mask = tf.image.resize(mask, [image_size, image_size], method='nearest')
    mask = tf.where(mask>=0.5, 1.0, 0.0)
    if n_classes == 1:
        return mask
    else:
        # For multi-class integer masks, we expect integer values (0..n_classes-1)
        # Convert to one-hot if requested
        mask_int = tf.cast(tf.squeeze(mask, axis=-1) * (n_classes-1), tf.int32)  # attempt mapping 0..1 -> 0..n_classes-1
        mask_oh = tf.one_hot(mask_int, depth=n_classes)
        return tf.cast(mask_oh, tf.float32)

def make_dataset(pairs, batch_size=BATCH_SIZE, augment=False, shuffle=True):
    image_paths = [p for p, m in pairs]
    mask_paths  = [m for p, m in pairs]
    ds = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
    if shuffle:
        ds = ds.shuffle(buffer_size=len(pairs), seed=SEED, reshuffle_each_iteration=True)

    # Directly map preprocess_pair
    ds = ds.map(lambda i, m: preprocess_pair(i, m, augment=augment),
                num_parallel_calls=AUTOTUNE)

    ds = ds.batch(batch_size).prefetch(AUTOTUNE)
    return ds


def preprocess_pair(image_path, mask_path, augment=False):
    image = read_image(image_path)
    mask = read_mask(mask_path)

    if augment:
        if tf.random.uniform(()) > 0.5:
            image = tf.image.flip_left_right(image)
            mask  = tf.image.flip_left_right(mask)
        if tf.random.uniform(()) > 0.5:
            image = tf.image.flip_up_down(image)
            mask  = tf.image.flip_up_down(mask)

    return image, mask


# -------------------------
# U-Net model
# -------------------------
def conv_block(x, filters, kernel_size=3, activation='relu', batchnorm=True, name=None):
    x = layers.Conv2D(filters, kernel_size, padding="same", use_bias=not batchnorm,
                      kernel_initializer="he_normal", name=(None if name is None else f"{name}_conv1"))(x)
    if batchnorm:
        x = layers.BatchNormalization(name=(None if name is None else f"{name}_bn1"))(x)
    x = layers.Activation(activation, name=(None if name is None else f"{name}_act1"))(x)

    x = layers.Conv2D(filters, kernel_size, padding="same", use_bias=not batchnorm,
                      kernel_initializer="he_normal", name=(None if name is None else f"{name}_conv2"))(x)
    if batchnorm:
        x = layers.BatchNormalization(name=(None if name is None else f"{name}_bn2"))(x)
    x = layers.Activation(activation, name=(None if name is None else f"{name}_act2"))(x)
    return x

def build_unet(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), filters=64, dropout=0.1, n_classes=N_CLASSES):
    inputs = layers.Input(shape=input_shape)

    # Encoder
    c1 = conv_block(inputs, filters, name="enc1")
    p1 = layers.MaxPooling2D((2,2))(c1)

    c2 = conv_block(p1, filters*2, name="enc2")
    p2 = layers.MaxPooling2D((2,2))(c2)

    c3 = conv_block(p2, filters*4, name="enc3")
    p3 = layers.MaxPooling2D((2,2))(c3)

    c4 = conv_block(p3, filters*8, name="enc4")
    p4 = layers.MaxPooling2D((2,2))(c4)

    # Bottleneck
    b = conv_block(p4, filters*16, name="bottleneck")
    b = layers.Dropout(dropout)(b)

    # Decoder
    u1 = layers.Conv2DTranspose(filters*8, (2,2), strides=(2,2), padding="same")(b)
    u1 = layers.concatenate([u1, c4])
    c5 = conv_block(u1, filters*8, name="dec1")

    u2 = layers.Conv2DTranspose(filters*4, (2,2), strides=(2,2), padding="same")(c5)
    u2 = layers.concatenate([u2, c3])
    c6 = conv_block(u2, filters*4, name="dec2")

    u3 = layers.Conv2DTranspose(filters*2, (2,2), strides=(2,2), padding="same")(c6)
    u3 = layers.concatenate([u3, c2])
    c7 = conv_block(u3, filters*2, name="dec3")

    u4 = layers.Conv2DTranspose(filters, (2,2), strides=(2,2), padding="same")(c7)
    u4 = layers.concatenate([u4, c1])
    c8 = conv_block(u4, filters, name="dec4")

    if n_classes == 1:
        outputs = layers.Conv2D(1, (1,1), padding="same", activation="sigmoid")(c8)
    else:
        outputs = layers.Conv2D(n_classes, (1,1), padding="same", activation="softmax")(c8)

    model = models.Model(inputs, outputs, name="UNet")
    return model

# -------------------------
# Losses & Metrics
# -------------------------
def dice_coef(y_true, y_pred, smooth=1e-6):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    return 1.0 - dice_coef(y_true, y_pred)

def bce_dice_loss(y_true, y_pred):
    bce = tf.keras.losses.BinaryCrossentropy()(y_true, y_pred)
    return 0.5 * bce + 0.5 * dice_loss(y_true, y_pred)

def iou_metric(y_true, y_pred, thresh=0.5, smooth=1e-6):
    y_pred = tf.cast(y_pred > thresh, tf.float32)
    intersection = K.sum(y_true * y_pred)
    union = K.sum(y_true) + K.sum(y_pred) - intersection
    return (intersection + smooth) / (union + smooth)

# -------------------------
# Put it all together and run
# -------------------------
# 1) discover pairs
pairs = discover_pairs_from_csv(CSV_PATH, IMAGES_DIR, MASKS_DIR)
print(f"Found {len(pairs)} image-mask pairs.")

# 2) Train/Val split
random.shuffle(pairs)
val_split = 0.15
n_val = int(len(pairs)*val_split)
val_pairs = pairs[:n_val]
train_pairs = pairs[n_val:]
print(f"Train: {len(train_pairs)}, Val: {len(val_pairs)}")

# 3) Create datasets
train_ds = make_dataset(train_pairs, batch_size=BATCH_SIZE, augment=True, shuffle=True)
val_ds   = make_dataset(val_pairs, batch_size=BATCH_SIZE, augment=False, shuffle=False)

# 4) Build model
model = build_unet(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), filters=32, dropout=0.1, n_classes=N_CLASSES)
model.summary()

# 5) Compile
if N_CLASSES == 1:
    loss = bce_dice_loss
    metrics = [dice_coef, lambda yt, yp: iou_metric(yt, yp), tf.keras.metrics.BinaryAccuracy(name="bin_acc")]
else:
    loss = tf.keras.losses.CategoricalCrossentropy()
    metrics = [tf.keras.metrics.CategoricalAccuracy(name="cat_acc")]

model.compile(optimizer=tf.keras.optimizers.Adam(1e-4), loss=loss, metrics=metrics)

# 6) Callbacks
os.makedirs("checkpoints", exist_ok=True)
checkpoint_cb = ModelCheckpoint("checkpoints/unet_best.h5", save_best_only=True, monitor="val_loss", mode="min")
reduce_cb = ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=5, min_lr=1e-7, verbose=1)
early_cb = EarlyStopping(monitor="val_loss", patience=12, restore_best_weights=True, verbose=1)

# 7) Fit
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=[checkpoint_cb, reduce_cb, early_cb]
)

# 8) Save final model
model.save("unet_final.h5")
print("Saved model to unet_final.h5")

# 9) Quick evaluation on a few validation samples (visual check)
import matplotlib.pyplot as plt

def show_sample(model, ds, n=3):
    for images, masks in ds.unbatch().batch(1).take(n):
        preds = model.predict(images)
        plt.figure(figsize=(12,4))
        plt.subplot(1,3,1); plt.title("Image"); plt.imshow(images[0]); plt.axis("off")
        plt.subplot(1,3,2); plt.title("Mask"); plt.imshow(tf.squeeze(masks[0]), cmap="gray"); plt.axis("off")
        plt.subplot(1,3,3); plt.title("Pred"); plt.imshow(tf.squeeze(preds[0])>0.5, cmap="gray"); plt.axis("off")
        plt.show()

# Show 3 val samples
show_sample(model, val_ds, n=3)


Found 290 image-mask pairs.
Train: 247, Val: 43


Epoch 1/40


InvalidArgumentError: Graph execution error:

Detected at node resize/ResizeBilinear defined at (most recent call last):
<stack traces unavailable>
input must be 4-dimensional[1,1,575,862,3]
	 [[{{node resize/ResizeBilinear}}]]
	 [[IteratorGetNext]] [Op:__inference_multi_step_on_iterator_119078]