# Projet 8

In [None]:
# Notebook-wide switch: the augmentation preview cell only runs when this is False.
train_mode = True

In [None]:
# --- Simple dataset scanner (counts only, no filenames, no JSON) ---
from pathlib import Path
import os
from collections import Counter

# >>> Change this to your dataset root (WSL path) <<<
ROOT = Path("../data")

IGNORE_HIDDEN = True       # skip .git, __pycache__, etc.
MAX_DIRS_TO_SHOW = 80      # cap on printed directory lines

by_ext = Counter()         # lower-cased file extension -> number of files
by_dir = {}                # directory path relative to ROOT -> number of files

for current, subdirs, files in os.walk(ROOT):
    if IGNORE_HIDDEN:
        # Prune in place so os.walk never descends into hidden/internal dirs.
        subdirs[:] = [name for name in subdirs if not name.startswith((".", "__"))]
    current_path = Path(current)
    rel_dir = Path(".") if current_path == ROOT else current_path.relative_to(ROOT)
    by_dir[str(rel_dir)] = len(files)
    by_ext.update(Path(name).suffix.lower() for name in files)

# Every visited directory has exactly one entry in by_dir.
total_dirs = len(by_dir)
total_files = sum(by_dir.values())

print(f"[ROOT] {ROOT}")
print(f"dirs={total_dirs:,}  files={total_files:,}\n")

print("By extension (top 10):")
for ext, n_files in by_ext.most_common(10):
    print(f"  {ext or '(no ext)'}: {n_files:,}")
print()

print(f"Directory counts (first {MAX_DIRS_TO_SHOW}):")
ordered_dirs = sorted(by_dir.items())
for rel_name, n_files in ordered_dirs[:MAX_DIRS_TO_SHOW]:
    print(f"  {rel_name}: {n_files}")
if len(ordered_dirs) > MAX_DIRS_TO_SHOW:
    print("  ... (truncated)")

# -------- Cityscapes mini-summary (counts only) --------
def count_pattern(base: Path, split: str, suffix: str) -> int:
    """Count files named ``*<suffix>`` one level below ``base/split``.

    Only regular files sitting directly inside the sub-directories of
    ``base/split`` are counted. A missing split directory yields 0.
    """
    split_dir = base / split
    if not split_dir.exists():
        return 0
    return sum(
        1
        for city_dir in split_dir.iterdir()
        if city_dir.is_dir()
        for entry in city_dir.iterdir()
        if entry.is_file() and entry.name.endswith(suffix)
    )

# Per-split counts of the left images and of each gtFine annotation flavour.
print("\n[Cityscapes summary]")
gt_base = ROOT / "gtFine"
left_base = ROOT / "leftImg8bit"
for split in ("train", "val", "test"):
    left = count_pattern(left_base, split, "_leftImg8bit.png")
    label, color, inst, poly = (
        count_pattern(gt_base, split, gt_suffix)
        for gt_suffix in (
            "_gtFine_labelIds.png",
            "_gtFine_color.png",
            "_gtFine_instanceIds.png",
            "_gtFine_polygons.json",
        )
    )
    print(f"  {split:5s}: leftImg8bit={left:6d}  labelIds={label:6d}  color={color:6d}  instanceIds={inst:6d}  polygons.json={poly:6d}")


In [None]:
from pathlib import Path

ROOT = Path("../data")  # adjust if needed
SUF_LEFT = "_leftImg8bit.png"
SUF_LBL  = "_gtFine_labelIds.png"

def base_id(name: str) -> str:
    """Return ``name`` without its Cityscapes image or labelIds suffix.

    The result is the identifier shared by an image and its label mask.
    A name carrying neither suffix is returned unchanged; it used to be
    truncated blindly by ``len(SUF_LBL)`` characters.
    """
    for suffix in (SUF_LEFT, SUF_LBL):
        if name.endswith(suffix):
            return name[:-len(suffix)]
    return name

def split_counts(split: str):
    """Print how many images, label masks and matched pairs a split contains.

    An image and a labelIds mask are paired when their file names share the
    same base id (file name minus its suffix).
    """
    def _collect(folder, suffix):
        # A missing folder simply contributes no files.
        if not folder.exists():
            return []
        return sorted(folder.rglob(f"*{suffix}"))

    left = _collect(ROOT / "leftImg8bit" / split, SUF_LEFT)
    lbl = _collect(ROOT / "gtFine" / split, SUF_LBL)
    paired = {base_id(p.name) for p in left}.intersection(base_id(p.name) for p in lbl)
    print(f"{split:<5} | left={len(left):4d}  labels={len(lbl):4d}  paired={len(paired):4d}")

for sp in ("train", "val", "test"):
    split_counts(sp)



In [None]:
from pathlib import Path
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import random

ROOT = Path("../data")

# Display colour (RGB) per Cityscapes labelId, consumed by `colorize`.
# labelIds without an entry are left black.
PALETTE = {
    # road / sidewalk
    7: (128, 64, 128),
    8: (244, 35, 232),
    # building + barriers
    11: (70, 70, 70),
    12: (102, 102, 156),
    13: (190, 153, 153),
    # traffic objects
    17: (153, 153, 153),
    19: (250, 170, 30),
    20: (220, 220, 0),
    # vegetation + terrain
    21: (107, 142, 35),
    22: (152, 251, 152),
    # sky
    23: (70, 130, 180),
    # person + rider
    24: (220, 20, 60),
    25: (255, 0, 0),
    # vehicles
    26: (0, 0, 142),
    27: (0, 0, 70),
    28: (0, 60, 100),
    31: (0, 80, 100),
    32: (0, 0, 230),
    33: (119, 11, 32),
}

def pairs(split="val"):
    """List ``(image_path, label_path)`` tuples for one split.

    Label masks are discovered recursively under ``ROOT/gtFine/<split>`` and
    visited in sorted order; a mask is kept only when the matching image
    exists in the same city folder under ``ROOT/leftImg8bit/<split>``.
    """
    image_root = ROOT / "leftImg8bit" / split
    matched = []
    for label_path in sorted((ROOT / "gtFine" / split).rglob("*_gtFine_labelIds.png")):
        frame_id = label_path.name.replace("_gtFine_labelIds.png", "")
        image_path = image_root / label_path.parent.name / (frame_id + "_leftImg8bit.png")
        if not image_path.exists():
            continue
        matched.append((image_path, label_path))
    return matched

def colorize(ids: np.ndarray) -> Image.Image:
    """Turn an HxW labelId array into an RGB image using ``PALETTE``.

    Pixels whose labelId has no palette entry stay black.
    """
    height, width = ids.shape
    canvas = np.zeros((height, width, 3), np.uint8)
    for label_id in PALETTE:
        canvas[ids == label_id] = PALETTE[label_id]
    return Image.fromarray(canvas, "RGB")

def overlay(img: Image.Image, mask_rgb: Image.Image, alpha=0.5) -> Image.Image:
    """Alpha-blend ``mask_rgb`` over ``img`` and return the result as an image.

    ``alpha`` is the weight given to the mask; the image gets ``1 - alpha``.
    """
    photo = np.asarray(img.convert("RGB"), np.float32)
    tint = np.asarray(mask_rgb, np.float32)
    blended = (1-alpha)*photo + alpha*tint
    # Clamp before the cast so out-of-range values cannot wrap around in uint8.
    return Image.fromarray(np.clip(blended, 0, 255).astype(np.uint8))

samples = pairs("val")
assert samples, "No pairs found — check your paths."
random.shuffle(samples)
k = 3

# One row per sampled frame: photo, colorized labelIds, blended overlay.
plt.figure(figsize=(15, 5*k))
for i, (left_p, lbl_p) in enumerate(samples[:k]):
    left = Image.open(left_p).convert("RGB")
    ids = np.array(Image.open(lbl_p))
    mask = colorize(ids)
    over = overlay(left, mask, alpha=0.45)
    panels = [(left, "leftImg8bit"), (mask, "labelIds (colored)"), (over, "overlay")]
    for j, (img, title) in enumerate(panels):
        ax = plt.subplot(k, 3, i*3 + j + 1)
        ax.imshow(img)
        ax.set_title(f"{left_p.parent.name} — {title}", fontsize=10)
        ax.axis("off")
plt.tight_layout()
plt.show()



### Remapping Cityscapes 32→8 classes

À vérifier : l'équilibre des classes après le remapping (voir la cellule « Class balance » plus bas).


In [None]:
# --- 32→8 mapping (Cityscapes labelIds -> 8-class IDs), ignore = 255 ---
import numpy as np

# (8-class id, Cityscapes labelIds folded into it). labelIds absent from this
# table (0..5 and anything above 33) fall through to the LUT's ignore value.
_CLASS8_TO_LABELIDS = (
    (0, (6, 7, 9, 10)),                        # labelId 6 + road-like: road, parking, rail track
    (1, (8,)),                                 # sidewalk
    (2, (11, 12, 13, 14, 15, 16)),             # building + barriers
    (3, (17, 18, 19, 20)),                     # traffic objs (pole/ts/tl)
    (4, (21, 22)),                             # vegetation + terrain
    (5, (23,)),                                # sky
    (6, (24, 25)),                             # person + rider
    (7, (26, 27, 28, 29, 30, 31, 32, 33)),     # vehicles
)

CS_LABELID_TO_8 = {
    label_id: class_id
    for class_id, label_ids in _CLASS8_TO_LABELIDS
    for label_id in label_ids
}

def build_labelid_to8_lut(ignore_value: int = 255) -> np.ndarray:
    """Return a 256-entry uint8 table: Cityscapes labelId -> 8-class id.

    Entries for labelIds missing from ``CS_LABELID_TO_8`` hold ``ignore_value``.
    """
    lut = np.full(256, ignore_value, dtype=np.uint8)
    lut[list(CS_LABELID_TO_8)] = list(CS_LABELID_TO_8.values())
    return lut

LUT_32TO8 = build_labelid_to8_lut(ignore_value=255)

def remap_labelids_to8(arr_uint16: np.ndarray) -> np.ndarray:
    """Map an array of labelIds to 8-class ids through ``LUT_32TO8``.

    Values are capped at 255 before the lookup, so any labelId above 255
    lands on the table's last entry (the ignore value).
    """
    capped = np.minimum(arr_uint16.astype(np.uint16), 255)
    return LUT_32TO8[capped.astype(np.uint8)]


In [None]:
# RGB display colour for each of the 8 training class ids (0..7).
# The ignore value 255 has no entry here.
PALETTE_8 = {
    0:(128,64,128),   # road
    1:(244,35,232),   # sidewalk
    2:(70,70,70),     # building+barrier
    3:(220,220,0),    # traffic objs
    4:(107,142,35),   # vegetation/terrain
    5:(70,130,180),   # sky
    6:(220,20,60),    # person+rider
    7:(0,0,142),      # vehicle
}


In [None]:
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

def colorize_8(label8: np.ndarray, palette: dict) -> Image.Image:
    """Render an HxW class-id mask as an RGB image using ``palette``.

    ``palette`` maps a class id to an RGB triple. Pixels whose id is not a
    key of ``palette`` stay black.
    """
    height, width = label8.shape
    canvas = np.zeros((height, width, 3), np.uint8)
    for class_id in palette:
        canvas[label8 == class_id] = palette[class_id]
    return Image.fromarray(canvas, "RGB")

# Demo: remap one Frankfurt validation mask and display it with the 8-class palette.
sample_lbl = next(Path("../data/gtFine/val/frankfurt").glob("*_gtFine_labelIds.png"))
arr = np.array(Image.open(sample_lbl))
arr8 = remap_labelids_to8(arr)

plt.figure(figsize=(8,4))
plt.imshow(colorize_8(arr8, PALETTE_8))
plt.axis("off")
plt.title("8-class mask")
plt.show()


In [None]:
import os

# Quieter TensorFlow logs. These variables are set BEFORE importing tf.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"     # 0=all, 1=INFO off, 2=INFO+WARNING off, 3=all off
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"  # avoid grabbing all GPU memory
# Optional: disable oneDNN (removes the "oneDNN custom ops are on" line, and tiny numeric diffs)
# os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

import tensorflow as tf
from absl import logging as absl_logging
absl_logging.set_verbosity(absl_logging.ERROR)  # reduce absl spam

# Enable per-GPU memory growth as an extra safety net on top of the env var.
gpus = tf.config.list_physical_devices("GPU")
for g in gpus:
    try:
        tf.config.experimental.set_memory_growth(g, True)
    except Exception as err:
        # Still best-effort, but report the failure instead of hiding it so
        # a later out-of-memory error can be traced back to this step.
        print(f"[warn] could not enable memory growth on {g}: {err}")

print("TF:", tf.__version__, "| GPUs:", gpus)

# ==========================
# Cityscapes 32→8 remapping
# ==========================
import numpy as np

# 8 classes for embedded use (ignore=255):
# 0=road (6,7,9,10) | 1=sidewalk(8) | 2=building+barriers(11–16) | 3=traffic objs(17–20)
# 4=vegetation+terrain(21,22) | 5=sky(23) | 6=person+rider(24,25) | 7=vehicle(26–33)
# labelIds not listed below (0..5 and anything above 33) become the ignore value.
CS_LABELID_TO_8 = {
    6:0,
    7:0, 9:0, 10:0,
    8:1,
    11:2, 12:2, 13:2, 14:2, 15:2, 16:2,
    17:3, 18:3, 19:3, 20:3,
    21:4, 22:4,
    23:5,
    24:6, 25:6,
    26:7, 27:7, 28:7, 29:7, 30:7, 31:7, 32:7, 33:7,
}

def build_labelid_to8_lut(ignore_value: int = 255) -> np.ndarray:
    """Return a 256-entry uint8 table mapping labelId -> 8-class id.

    Entries for labelIds missing from ``CS_LABELID_TO_8`` hold ``ignore_value``.
    """
    lut = np.full(256, ignore_value, dtype=np.uint8)
    for k, v in CS_LABELID_TO_8.items():
        lut[k] = v
    return lut

# NumPy table plus a TensorFlow copy used by the tf.data pipeline (tf.gather).
LUT_32TO8 = build_labelid_to8_lut(ignore_value=255)
LUT_TF = tf.convert_to_tensor(LUT_32TO8, dtype=tf.uint8)  # shape [256]

# ===================
# Dataset (tf.data)
# ===================
from pathlib import Path
ROOT = Path("../data")               # <<< change if needed (WSL path)
INPUT_SIZE = (512, 1024)             # (H, W)
BATCH_SIZE = 4

SUF_LEFT = "_leftImg8bit.png"
SUF_LBL  = "_gtFine_labelIds.png"

def list_pairs(split: str):
    """Return two index-aligned lists of path strings: images and label masks.

    Label masks are discovered recursively under ``ROOT/gtFine/<split>`` in
    sorted order; a mask is kept only if its image exists in the same city
    folder under ``ROOT/leftImg8bit/<split>``.

    Raises:
        FileNotFoundError: when the split contains no image/mask pair.
    """
    image_root = ROOT / "leftImg8bit" / split
    matched = []
    for label_path in sorted((ROOT / "gtFine" / split).rglob(f"*{SUF_LBL}")):
        frame_id = label_path.name.replace(SUF_LBL, "")
        image_path = image_root / label_path.parent.name / f"{frame_id}{SUF_LEFT}"
        if image_path.exists():
            matched.append((str(image_path), str(label_path)))
    if not matched:
        raise FileNotFoundError(f"No pairs found for split='{split}'. Check your paths under {ROOT}.")
    # Unzip the pairs back into the two parallel lists callers expect.
    lefts, labels = (list(column) for column in zip(*matched))
    return lefts, labels

def decode_and_preprocess(left_path, lbl_path, training: bool):
    """Load one image/mask pair and return ``(image, labels, weights)``.

    Args:
        left_path: path of the PNG image.
        lbl_path: path of the PNG labelIds mask.
        training: when true, apply a random horizontal flip to both tensors.

    Returns:
        image: float32 tensor resized to ``INPUT_SIZE`` (bilinear).
        labels: int32 class ids resized to ``INPUT_SIZE`` (nearest), with
            ignore pixels replaced by 0.
        weights: float32 mask, 0 where the pixel was ignore (255), else 1.
    """
    # Image: PNG -> 3 channels -> float32 via convert_image_dtype.
    image = tf.io.decode_png(tf.io.read_file(left_path), channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)

    # Mask: PNG -> 1 channel, then labelId -> 8-class id through the LUT.
    raw_ids = tf.io.decode_png(tf.io.read_file(lbl_path), channels=1)
    lut_index = tf.minimum(tf.cast(raw_ids, tf.int32), 255)
    mask = tf.squeeze(tf.gather(LUT_TF, lut_index), axis=-1)

    if training:
        # A single random draw drives both flips so image and mask stay aligned.
        flip = tf.random.uniform(()) > 0.5
        image = tf.cond(flip, lambda: tf.image.flip_left_right(image), lambda: image)
        mask = tf.cond(
            flip,
            lambda: tf.image.flip_left_right(mask[..., None])[:, :, 0],
            lambda: mask,
        )

    # Nearest-neighbour for the mask so no new class ids are interpolated.
    image = tf.image.resize(image, INPUT_SIZE, method="bilinear")
    mask = tf.image.resize(mask[..., None], INPUT_SIZE, method="nearest")[:, :, 0]
    mask = tf.cast(mask, tf.uint8)

    # Ignore pixels get weight 0; their label is set to 0 only to keep it in range.
    is_ignored = tf.equal(mask, tf.constant(255, dtype=tf.uint8))
    weights = tf.where(
        is_ignored,
        tf.zeros_like(mask, dtype=tf.float32),
        tf.ones_like(mask, dtype=tf.float32),
    )
    labels = tf.cast(tf.where(is_ignored, tf.zeros_like(mask), mask), tf.int32)

    return image, labels, weights

def make_dataset(split: str, batch_size: int = BATCH_SIZE, training: bool = True) -> tf.data.Dataset:
    """Build the batched ``tf.data`` pipeline for one split.

    Training datasets are shuffled (buffer capped at 2000 paths, reshuffled
    every iteration) and drop the last incomplete batch; evaluation datasets
    keep file order and every sample.
    """
    lefts, labels = list_pairs(split)

    def _load(image_path, label_path):
        return decode_and_preprocess(image_path, label_path, training)

    ds = tf.data.Dataset.from_tensor_slices((lefts, labels))
    if training:
        ds = ds.shuffle(buffer_size=min(len(lefts), 2000), reshuffle_each_iteration=True)
    return (
        ds.map(_load, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size, drop_remainder=training)
        .prefetch(tf.data.AUTOTUNE)
    )

# ==============
# Smoke test
# ==============
# Pull a single training batch and print shapes/dtypes to validate the pipeline end to end.
train_ds = make_dataset("train", batch_size=2, training=True)
xb, yb, wb = next(iter(train_ds))
print("x:", xb.shape, xb.dtype, "| y:", yb.shape, yb.dtype, "| w:", wb.shape, wb.dtype)

# Example compile/fit, left disabled (model must output logits with 8 channels).
# The dataset yields (image, labels, weights) triples.
# loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# opt  = tf.keras.optimizers.Adam(1e-3)
# model.compile(optimizer=opt, loss=loss, metrics=["accuracy"])
# model.fit(train_ds,
#           validation_data=make_dataset("val", batch_size=2, training=False),
#           epochs=1)


In [None]:
# ==== Class balance for Cityscapes 8 classes (with ignore=255) ====
from pathlib import Path
import numpy as np
from PIL import Image

# ---- Config (adapt if needed) ----
ROOT = Path("../data")          # dataset root (WSL path)
SPLIT = "train"                 # "train" | "val" | "test"; default split of class_balance()
SUF_LBL = "_gtFine_labelIds.png"

# Display names of the 8 training classes, indexed by class id (0..7).
CLASS8_NAMES = [
    "road", "sidewalk", "building+barriers", "traffic-objs",
    "vegetation+terrain", "sky", "person+rider", "vehicle"
]

# Reuse the LUT built by an earlier cell when present; otherwise rebuild it here.
if "LUT_32TO8" not in globals():
    CS_LABELID_TO_8 = {
        6: 0, 7: 0, 9: 0, 10: 0,
        8: 1,
        11: 2, 12: 2, 13: 2, 14: 2, 15: 2, 16: 2,
        17: 3, 18: 3, 19: 3, 20: 3,
        21: 4, 22: 4,
        23: 5,
        24: 6, 25: 6,
        26: 7, 27: 7, 28: 7, 29: 7, 30: 7, 31: 7, 32: 7, 33: 7,
    }
    # Every labelId not listed above keeps the fill value 255 (ignore).
    LUT_32TO8 = np.full(256, 255, dtype=np.uint8)
    LUT_32TO8[list(CS_LABELID_TO_8)] = list(CS_LABELID_TO_8.values())

def remap_to8_np(arr_label_ids: np.ndarray) -> np.ndarray:
    """Map a labelId array to 8-class ids through ``LUT_32TO8``.

    Values are capped at 255 before the lookup, so labelIds above 255 land
    on the table's last entry.
    """
    capped = np.minimum(arr_label_ids.astype(np.uint16), 255)
    return LUT_32TO8[capped.astype(np.uint8)]

def class_balance(split: str = SPLIT):
    """Count pixels per 8-class id over every labelIds mask of a split.

    Returns:
        counts: int64 array of 8 per-class pixel counts.
        ignore: number of pixels remapped to 255.
        total_valid: sum of ``counts``.
        total_pixels: ``total_valid + ignore``.
        freqs: ``counts`` divided by ``total_valid`` (or by 1 if it is zero).
    """
    lbl_paths = sorted((ROOT/"gtFine"/split).rglob(f"*{SUF_LBL}"))
    assert lbl_paths, f"No labels found under {ROOT}/gtFine/{split}"
    n_images = len(lbl_paths)
    counts = np.zeros(8, dtype=np.int64)
    ignore = 0
    for done, path in enumerate(lbl_paths, start=1):
        lab8 = remap_to8_np(np.array(Image.open(path)))
        # One histogram over all uint8 values: bins 0..7 are the classes,
        # bin 255 is the ignore value.
        hist = np.bincount(lab8.ravel(), minlength=256)
        counts += hist[:8]
        ignore += int(hist[255])
        if done == n_images or done % 500 == 0:
            print(f"[{split}] processed {done}/{n_images} images...", end="\r")
    print()
    total_valid = int(counts.sum())
    freqs = counts / max(total_valid, 1)
    return counts, ignore, total_valid, total_valid + ignore, freqs

counts, ignore, total_valid, total_pixels, freqs = class_balance("train")

print("\n=== Class balance (train) ===")
for k, name in enumerate(CLASS8_NAMES):
    print(f"{k}: {name:<20s}  pixels={counts[k]:,}   freq={freqs[k]:.4%}")
print(f"\nignore pixels (==255): {ignore:,}")
print(f"total valid pixels:     {total_valid:,}")
print(f"total pixels (incl. ignore): {total_pixels:,}")


def _print_class_weights(title, weights):
    """Print one weight per class under ``title``."""
    print(title)
    for k, (name, w) in enumerate(zip(CLASS8_NAMES, weights)):
        print(f"{k}: {name:<20s}  w={w:.3f}")


# ---- Optional: derive class weights ----
# Frequencies are floored at 1e-12 so an absent class cannot divide by zero.
# Inverse-frequency, normalized to mean=1 (good starting point)
weights_inv = 1.0 / np.maximum(freqs, 1e-12)
weights_inv = weights_inv / weights_inv.mean()
_print_class_weights("\nSuggested class weights (inverse-freq, mean≈1):", weights_inv)

# Median-frequency balancing (alternative); the median skips absent classes.
median_f = np.median(freqs[freqs > 0])
weights_med = median_f / np.maximum(freqs, 1e-12)
weights_med = weights_med / weights_med.mean()
_print_class_weights("\nSuggested class weights (median-freq, mean≈1):", weights_med)


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# expects: counts (np.array shape [8]), freqs (shape [8]), ignore (int),
#          total_valid (int), total_pixels (int), CLASS8_NAMES (list of 8 str)

# ---- 1) Bar chart of the 8 classes (sorted by decreasing frequency) ----
order = np.argsort(freqs)[::-1]
names_sorted = [CLASS8_NAMES[idx] for idx in order]
freqs_sorted = freqs[order]
counts_sorted = counts[order]

plt.figure(figsize=(10, 5))
bars = plt.bar(range(len(names_sorted)), freqs_sorted)  # no explicit colors
plt.xticks(range(len(names_sorted)), names_sorted, rotation=20, ha="right")
plt.ylabel("Frequency (share of valid pixels)")
plt.title("Cityscapes (train) — Class balance (8 classes)")

# Annotate each bar with its share (%) and its pixel count in millions.
for bar, share, n_pixels in zip(bars, freqs_sorted, counts_sorted):
    plt.text(
        bar.get_x() + bar.get_width()/2,
        bar.get_height() + 0.002,
        f"{share*100:.1f}%\n{n_pixels/1e6:.1f}M",
        ha="center", va="bottom", fontsize=9,
    )

plt.ylim(0, max(freqs_sorted)*1.15)
plt.tight_layout()
plt.show()

# ---- 2) Valid vs ignore pixels (for information) ----
valid_share = total_valid / total_pixels
ignore_share = 1.0 - valid_share
shares = [valid_share, ignore_share]

plt.figure(figsize=(5, 4))
bars2 = plt.bar([0,1], shares)
plt.xticks([0,1], ["valid", "ignore (==255)"])
plt.ylabel("Share of total pixels")
plt.title("Valid vs Ignore pixels (train)")

for x_pos, share in zip([0,1], shares):
    plt.text(x_pos, share + 0.005, f"{share*100:.1f}%", ha="center", va="bottom")

plt.ylim(0, 1.05)
plt.tight_layout()
plt.show()


In [None]:
# In notebook (Python)
from scripts.config import DataConfig, TrainConfig, AugmentConfig
from scripts.train import train

# Data settings shared by the model-comparison runs: 512x1024 inputs and a
# capped number of train/val samples (the final training cell lifts the caps).
data_cfg = DataConfig(
    data_root="../data",
    height=512, width=1024,
    batch_size=2,
    max_train_samples=100,
    max_val_samples=100,
)

# Trial 1: DeepLab ResNet50, light augmentation
# NOTE(review): `enabled=False` here, while the final training cell switches it
# to True via replace() — presumably augmentation is off for the comparison
# runs; confirm against scripts.augment.
aug_cfg = AugmentConfig(
    enabled=False, hflip=True, vflip=False,
    random_rotate_deg=3.0,
    random_scale_min=0.85, random_scale_max=1.20,
    random_crop=True,
    brightness_delta=0.10, contrast_delta=0.10, saturation_delta=0.05, hue_delta=0.02,
    gaussian_noise_std=0.00
)
# Training settings reused by every model cell below (each overrides output_dir only).
train_cfg = TrainConfig(lr=3e-4, epochs=60, optimizer="adam", exp_name="cityscapes-seg-8cls")


## Contrôle visuel de la data augmentation

La cellule suivante pioche quelques paires image/masque, applique le pipeline Albumentations configuré (`aug_cfg`) et affiche les versions redimensionnées et augmentées côte à côte pour vérifier que les masques restent alignés.


In [None]:
# Visual sanity check of the augmentation pipeline; skipped while train_mode is True.
# Relies on `pairs`, `PALETTE_8`, `data_cfg` and `aug_cfg` from earlier cells.
if not train_mode :
    import random
    import numpy as np
    import tensorflow as tf
    import matplotlib.pyplot as plt
    from PIL import Image

    from scripts.config import AugmentConfig
    from scripts.augment import build_augment_fn
    from scripts.remap import build_cityscapes_8cls_lut, remap_labels

    # The two pipelines to compare: augmentation disabled vs. the notebook's aug_cfg.
    # NOTE(review): aug_cfg is created with enabled=False earlier in this notebook,
    # so both pipelines may produce the same output — confirm against scripts.augment.
    lut = build_cityscapes_8cls_lut(data_cfg.ignore_index)
    no_aug_fn = build_augment_fn(AugmentConfig(enabled=False), data_cfg.height, data_cfg.width, data_cfg.ignore_index)
    augmented_fn = build_augment_fn(aug_cfg, data_cfg.height, data_cfg.width, data_cfg.ignore_index)

    def remap_to_training_ids(mask_np):
        """Remap a NumPy labelId mask with the project's `remap_labels` and return NumPy."""
        mask_tf = tf.convert_to_tensor(mask_np, dtype=tf.int32)
        return remap_labels(mask_tf, lut).numpy()

    def colorize_mask(mask_np, palette=PALETTE_8, ignore_value=data_cfg.ignore_index):
        """Return an HxWx3 uint8 RGB rendering of a class-id mask; ignore pixels are black."""
        rgb = np.zeros((mask_np.shape[0], mask_np.shape[1], 3), dtype=np.uint8)
        for cls_id, color in palette.items():
            rgb[mask_np == cls_id] = color
        if ignore_value is not None:
            rgb[mask_np == ignore_value] = (0, 0, 0)
        return rgb

    def overlay_mask(image_uint8, mask_uint8, alpha=0.45):
        """Blend the colorized mask over the image; `alpha` is the mask weight."""
        colored = colorize_mask(mask_uint8)
        return np.clip((1.0 - alpha) * image_uint8 + alpha * colored, 0, 255).astype(np.uint8)

    samples = pairs("train")
    assert samples, "Aucun couple image/masque trouvé — vérifie le dossier data."

    random.shuffle(samples)
    num_rows = min(3, len(samples))
    fig, axes = plt.subplots(num_rows, 6, figsize=(22, 5 * num_rows))
    if num_rows == 1:
        # plt.subplots returns a 1-D array for a single row; restore the 2-D indexing used below.
        axes = np.expand_dims(axes, axis=0)

    for row, (left_path, lbl_path) in enumerate(samples[:num_rows]):
        # Image scaled to [0, 1]; mask kept as raw labelIds until remapped.
        raw_img = np.array(Image.open(left_path).convert("RGB"), dtype=np.float32) / 255.0
        raw_mask = np.array(Image.open(lbl_path), dtype=np.int32)

        mask8 = remap_to_training_ids(raw_mask)

        img_tf = tf.convert_to_tensor(raw_img, dtype=tf.float32)
        mask_tf = tf.convert_to_tensor(mask8, dtype=tf.int32)

        # The same inputs go through both pipelines.
        base_img, base_mask = no_aug_fn(img_tf, mask_tf)
        aug_img, aug_mask = augmented_fn(img_tf, mask_tf)

        # Back to uint8 for display.
        base_img_u8 = np.clip(base_img.numpy() * 255.0, 0, 255).astype(np.uint8)
        aug_img_u8 = np.clip(aug_img.numpy() * 255.0, 0, 255).astype(np.uint8)
        base_mask_u8 = base_mask.numpy().astype(np.uint8)
        aug_mask_u8 = aug_mask.numpy().astype(np.uint8)

        base_mask_rgb = colorize_mask(base_mask_u8)
        aug_mask_rgb = colorize_mask(aug_mask_u8)

        # Columns 0-2: non-augmented pipeline; columns 3-5: aug_cfg pipeline.
        axes[row, 0].imshow(base_img_u8)
        axes[row, 0].set_title("Image (resize)")
        axes[row, 1].imshow(base_mask_rgb)
        axes[row, 1].set_title("Masque (resize)")
        axes[row, 2].imshow(overlay_mask(base_img_u8, base_mask_u8))
        axes[row, 2].set_title("Overlay resize")
        axes[row, 3].imshow(aug_img_u8)
        axes[row, 3].set_title("Image augmentée")
        axes[row, 4].imshow(aug_mask_rgb)
        axes[row, 4].set_title("Masque augmenté")
        axes[row, 5].imshow(overlay_mask(aug_img_u8, aug_mask_u8))
        axes[row, 5].set_title("Overlay augmentée")

        for ax in axes[row]:
            ax.axis("off")

    plt.tight_layout()
    plt.show()


In [None]:

    # Preview of one validation batch produced by the project's dataset builder.
    # The cell uses `colorize_mask` / `overlay_mask`, which are only defined by
    # the augmentation preview cell when train_mode is False, so it is guarded
    # the same way instead of failing with a NameError during training runs.
    if not train_mode:
        # Same package path as the other project imports in this notebook (`scripts.*`).
        from scripts.data import build_dataset

        val_ds = build_dataset(
            data_cfg,
            AugmentConfig(enabled=False),
            split="val",
            training=False,
        )

        # The dataset yields 3-tuples; the third element is not needed here.
        images, masks, _ = next(iter(val_ds))
        images_np = images.numpy()
        masks_np = masks.numpy()

        num_samples = min(3, images_np.shape[0])
        plt.figure(figsize=(15, 5 * num_samples))

        for i in range(num_samples):
            image = images_np[i]
            mask = masks_np[i]

            # Non-uint8 images are treated as [0, 1] floats and rescaled for display.
            if image.dtype != np.uint8:
                image_u8 = np.clip(image * 255.0, 0, 255).astype(np.uint8)
            else:
                image_u8 = image

            mask_u8 = mask.astype(np.uint8)
            mask_rgb = colorize_mask(mask_u8)
            overlay_rgb = overlay_mask(image_u8, mask_u8)

            # Helper sanity check: at alpha=1 on a black image the overlay must
            # reproduce the colorized mask exactly.
            overlay_on_black = overlay_mask(np.zeros_like(image_u8), mask_u8, alpha=1.0)
            assert np.array_equal(overlay_on_black, mask_rgb), "Overlay misaligned with mask (check dataset pipeline)."

            for j, (img, title) in enumerate([
                (image_u8, "Image (val)"),
                (mask_rgb, "Masque colorisé"),
                (overlay_rgb, "Overlay"),
            ]):
                ax = plt.subplot(num_samples, 3, i * 3 + j + 1)
                ax.imshow(img)
                ax.set_title(f"Échantillon {i + 1} — {title}")
                ax.axis("off")

        plt.tight_layout()
        plt.show()

### UNet Mini
Modèle U-Net compact utilisé comme référence rapide pour valider le pipeline d'entraînement.
Sa légèreté le rend adapté aux tests itératifs sur Cityscapes réduit.


In [None]:
    # Train the compact U-Net: same settings as train_cfg, only the output directory differs.
    # `replace` imported here is reused by the model cells that follow.
    from dataclasses import replace
    unet_mini_cfg = replace(
        train_cfg,
        output_dir="artifacts/unet_mini",
    )
    train("unet_mini", data_cfg, unet_mini_cfg, aug_cfg)


### UNet VGG16
Architecture U-Net à encodeur VGG16 offrant une capacité accrue pour capturer des détails fins.
Elle reste compatible avec notre pipeline et sert de baseline plus profonde.

/!\ Le décodeur a été réduit sur plusieurs couches par souci de consommation de VRAM.


In [None]:
    # Train the VGG16-encoder U-Net with its own output directory.
    # Needs `replace` (dataclasses), imported in the U-Net mini cell.
    unet_vgg16_cfg = replace(
        train_cfg,
        output_dir="artifacts/unet_vgg16",
    )
    train("unet_vgg16", data_cfg, unet_vgg16_cfg, aug_cfg)


### MobileDet Seg
Variante segmentation de MobileDet pensée pour l'inférence embarquée tout en conservant une précision correcte.
Ce modèle illustre un compromis agressif entre vitesse et qualité.


In [None]:
    # Train the MobileDet segmentation variant with its own output directory.
    # Needs `replace` (dataclasses), imported in the U-Net mini cell.
    mobiledet_seg_cfg = replace(
        train_cfg,
        output_dir="artifacts/mobiledet_seg",
    )
    train("mobiledet_seg", data_cfg, mobiledet_seg_cfg, aug_cfg)


### YOLOv9 Seg
Déclinaison segmentation de YOLOv9 visant une extraction simultanée des instances et des masques.
Nous l'évaluons pour mesurer le gain potentiel des architectures one-stage.


In [None]:
    # Train the YOLOv9 segmentation variant with its own output directory.
    # Needs `replace` (dataclasses), imported in the U-Net mini cell.
    yolov9_seg_cfg = replace(
        train_cfg,
        output_dir="artifacts/yolov9_seg",
    )
    train("yolov9_seg", data_cfg, yolov9_seg_cfg, aug_cfg)


## deeplab resnet50

In [None]:
    # Training: DeepLab ResNet50 with train_cfg as-is (no output_dir override).
    train("deeplab_resnet50", data_cfg, train_cfg, aug_cfg)

## 🧩 Vue d’ensemble des résultats

| Modèle                     |   Durée  | `masked_mIoU` (train) | `val_masked_mIoU` | `pix_acc` | `val_pix_acc` | `dice_coef` | `val_dice_coef` |
| :------------------------- | :------: | :-------------------: | :---------------: | :-------: | :-----------: | :---------: | :-------------: |
| **DeepLabV3+ (ResNet50)**  | 13.4 min |       **0.947**       |     **0.639**     | **0.989** |   **0.872**   |  **0.965**  |    **0.716**    |
| **YOLOv9_seg (simplifié)** | 10.5 min |         0.689         |       0.400       |   0.913   |     0.714     |    0.753    |      0.494      |
| **MobileDet_seg**          | 16.3 min |         0.938         |       0.502       |   0.987   |     0.779     |    0.953    |      0.600      |
| **U-Net VGG16**            | 29.7 min |         0.903         |       0.542       |   0.977   |     0.805     |    0.923    |      0.633      |
| **U-Net mini**             |  6.1 min |         0.563         |       0.319       |   0.851   |     0.634     |    0.650    |      0.407      |

---

## 🔍 Interprétation métrique par métrique

### 🟦 `masked_mIoU` (train)

* Mesure principale de segmentation (intersection sur union moyenne).
* DeepLab, MobileDet et U-Net VGG16 dépassent 0.9 en entraînement → bon apprentissage ; YOLOv9_seg (0.69) reste en retrait.
* U-Net mini (0.56) : trop léger, manque de capacité.

### 🟧 `val_masked_mIoU`

* Évalue la **généralisation**.
* DeepLab (0.639) est **nettement supérieur** aux autres.
* U-Net VGG16 (0.54) et MobileDet (0.50) suivent derrière.
* YOLOv9 seg (0.40) et U-Net mini (0.32) décrochent clairement.

### 🟩 `val_pix_acc`

* Corrélation assez bonne avec `val_mIoU`.
* DeepLab atteint 0.87 → très bonne segmentation globale.
* U-Net VGG16 ≈ 0.80 → correct.
* Les autres chutent < 0.78.

### 🟪 `val_dice_coef`

* Très proche du mIoU mais plus sensible aux petits objets.
* DeepLab ≈ 0.72 → cohérent avec sa bonne mIoU.
* U-Net VGG16 ≈ 0.63 et MobileDet ≈ 0.60 → acceptables.
* YOLOv9 ≈ 0.49, U-Net mini ≈ 0.40 → faibles.

---

## ⚖️ Analyse comparative

| Critère                               | Meilleur modèle                            |
| :------------------------------------ | :----------------------------------------- |
| **Précision globale (mIoU/Dice)**     | 🟢 **DeepLabV3+ ResNet50**                 |
| **Généralisation / stabilité val**    | 🟢 **DeepLabV3+ ResNet50**                 |
| **Compromis vitesse/qualité**         | 🟢 **MobileDet_seg** (plus léger, correct) |
| **Performance brute (haute qualité)** | 🟢 **U-Net VGG16** si VRAM suffisante      |
| **Légèreté / prototypage rapide**     | 🟢 **U-Net mini**, mais précision faible   |

---

## Interprétation détaillée

### 🥇 **DeepLabV3+ (ResNet50)**

* **Meilleur équilibre** entre précision et stabilité.
* mIoU = 0.64 (val) et Dice = 0.72 (val) : excellents scores sur 8 classes.
* Surapprentissage modéré (train-val gap raisonnable).
* Très bonne capacité à capter les contours fins et la hiérarchie spatiale.
  ✅ **→ Modèle à garder comme référence.**

### 🥈 **U-Net VGG16**

* Très bon entraînement, mais écart train-val > 0.35 : léger overfit.
* Lourdeur mémoire (VGG16) mais résultats solides.
  ✅ Alternative si tu veux plus de stabilité visuelle (textures fines).

### 🥉 **MobileDet_seg**

* Performances correctes pour un modèle “mobile-like”.
* Efficacité correcte : 16.3 min d’entraînement (plus long que DeepLab, 13.4 min) pour des résultats décents.
  🟡 Bon compromis si tu cibles l’inférence embarquée.

### ⚙️ **YOLOv9_seg**

* Correct mais sous-optimal : architecture pas parfaitement adaptée à la segmentation dense.
* Val mIoU = 0.40, Dice = 0.49 : pas suffisant pour une segmentation de qualité.
  🔴 À éviter pour cette tâche spécifique.

### ⚪ **U-Net mini**

* Très rapide mais sous-entraîné / sous-dimensionné.
* Mauvais scores val (mIoU = 0.32, Dice = 0.40).
  🔴 Bon pour tests rapides, pas pour production.

---

## 🧾 Conclusion

| Rang | Modèle                    | Pourquoi                                                  |
| :--: | :------------------------ | :-------------------------------------------------------- |
|  🥇  | **DeepLabV3+ (ResNet50)** | Meilleur équilibre précision / généralisation / stabilité |
|  🥈  | **U-Net VGG16**           | Très bon mais plus lourd, tendance à overfitter           |
|  🥉  | **MobileDet_seg**         | Légèreté et vitesse, mais précision un cran en dessous    |
|   4  | **YOLOv9_seg**            | Pas adapté à la segmentation dense                        |
|   5  | **U-Net mini**            | Trop limité, résultats faibles                            |

---

### 🔧 En résumé

> **DeepLabV3+ ResNet50** est le **meilleur modèle global** :
>
> * meilleures métriques de validation,
> * bon Dice et mIoU,
> * rapport vitesse/qualité très favorable,
> * faible overfit comparé à VGG16.


## 🚀 Entraînement final DeepLabV3+ (ResNet50) sur l'intégralité du dataset

Nous relançons DeepLabV3+ avec **toutes** les images `train/val` de Cityscapes (plus de limite `max_*_samples`).
Les artefacts (best/final) seront exportés dans `artifacts/deeplab_resnet50_full` et suivis dans **MLflow** (`artifacts/mlruns`).

In [None]:
# Final configuration: full dataset + dedicated output directory.
# Needs `replace` (dataclasses), imported in the U-Net mini cell.
final_data_cfg = replace(
    data_cfg,
    max_train_samples=None,   # lift the sample caps set in data_cfg
    max_val_samples=None,
)
final_train_cfg = replace(
    train_cfg,
    output_dir="artifacts/deeplab_resnet50_full",
    exp_name="cityscapes-seg-8cls-full",
)
# Same augmentation settings as aug_cfg, with the `enabled` flag switched on.
final_aug_cfg = replace(
    aug_cfg,
    enabled=True,
)

# Launch the full training run (local checkpoints + MLflow tracking).
train("deeplab_resnet50", final_data_cfg, final_train_cfg, final_aug_cfg)

In [None]:
# Copie explicite du meilleur modèle pour l'API
from pathlib import Path
import shutil

best_model = Path("artifacts/deeplab_resnet50_full/deeplab_resnet50_best.keras")
final_model = Path("artifacts/deeplab_resnet50_full/deeplab_resnet50_final.keras")
api_export = Path("artifacts/api/deeplabv3plus_resnet50_full.keras")
api_export.parent.mkdir(parents=True, exist_ok=True)

# Prefer the best checkpoint; fall back to the final one; fail if neither exists.
if best_model.exists():
    chosen_model, exported_best = best_model, True
elif final_model.exists():
    chosen_model, exported_best = final_model, False
else:
    raise FileNotFoundError("Aucun modèle entraîné trouvé. Lance d'abord la cellule d'entraînement.")

shutil.copy2(chosen_model, api_export)
if exported_best:
    print(f"✅ Modèle API (best) : {api_export}")
else:
    print(f"⚠️ Best absent, export du modèle final : {api_export}")