# 1. Setup (import paths so we can use src/)

In [None]:
# Make src importable from the notebook folder.
# Assumes the kernel's working directory is <project root>/notebooks.
import sys
import pathlib

ROOT = pathlib.Path.cwd().parent  # notebooks/ -> project root
SRC = ROOT / "src"

if not SRC.is_dir():
    print(f"WARNING: {SRC} does not exist; project imports below will fail. "
          "Is the notebook running from the notebooks/ folder?")

# Prepend (not append) so the project's generic module names (config, dataset,
# model) take precedence over any same-named installed packages, and only
# insert once so re-running this cell does not pile up duplicate entries.
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))

print("ROOT:", ROOT)
print("SRC :", SRC)


# 2. Imports & Config

In [None]:
import tensorflow as tf
from pathlib import Path
import json

import config as cfg
import dataset as ds
from model import unet_model, bce_dice_loss, BinaryMeanIoU, DiceMetric

# Echo the key configuration values so each run's settings appear in the output.
for attr in ("IMG_SIZE", "TRAIN_IMG_DIR", "TRAIN_MASK_IMG_DIR", "RESULTS_DIR"):
    print(f"{attr}:", getattr(cfg, attr))


# 3. (Optional) GPU sanity + mixed precision

In [None]:
# Show GPUs
gpus = tf.config.list_physical_devices("GPU")
print("GPUs:", gpus)

# Mixed precision (float16 compute, float32 variables) often speeds up training
# on modern GPUs, but on CPU it runs significantly slower. Enable it only when a
# GPU is visible; otherwise keep the default policy.
# The policy must be set before the model is built (done in a later cell).
# NOTE(review): under mixed_float16 the model's final activation should run in
# float32 (e.g. dtype="float32" on the output layer) for numerical stability —
# confirm this in unet_model.
from tensorflow.keras import mixed_precision
if gpus:
    policy = mixed_precision.Policy("mixed_float16")
    mixed_precision.set_global_policy(policy)
else:
    policy = mixed_precision.global_policy()
    print("No GPU found; keeping the default precision policy.")
print("Policy:", mixed_precision.global_policy())


# 4. Load Datasets

In [None]:
train_ds, val_ds = ds.get_datasets(val_split=0.2, batch_size=cfg.BATCH_SIZE)
train_card = tf.data.experimental.cardinality(train_ds).numpy()
val_card = tf.data.experimental.cardinality(val_ds).numpy()


def _describe_cardinality(card):
    """Return a readable label for a tf.data cardinality value.

    ``cardinality`` returns the element count when it is known (presumably
    batches here, since ``get_datasets`` takes ``batch_size`` — confirm).
    Otherwise it returns a negative sentinel, which is translated to text
    instead of being shown as a bare -1/-2.
    """
    if card == tf.data.experimental.INFINITE_CARDINALITY:
        return "infinite"
    if card == tf.data.experimental.UNKNOWN_CARDINALITY:
        return "unknown"
    return str(card)


print("Train cardinality:", _describe_cardinality(train_card),
      " Val cardinality:", _describe_cardinality(val_card))

# Every callback below monitors val_loss. With an empty validation set that
# metric never appears, so checkpointing, early stopping and LR reduction
# would silently do nothing.
if val_card == 0:
    print("WARNING: validation dataset is empty; callbacks monitoring "
          "val_loss will not work.")


# 5. Build & Compile Model

In [None]:
# The last input dimension is 1 channel (presumably grayscale images — confirm
# against dataset.py).
model = unet_model(input_shape=(cfg.IMG_SIZE[0], cfg.IMG_SIZE[1], 1))
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=cfg.LEARNING_RATE),
    loss=bce_dice_loss,
    # The bare string "accuracy" makes Keras guess binary vs. categorical
    # accuracy from the output shape. This is binary segmentation (BCE+Dice
    # loss, binary IoU), so pin BinaryAccuracy explicitly. name="accuracy"
    # keeps the History keys ("accuracy" / "val_accuracy") the plotting
    # cell reads.
    metrics=[
        tf.keras.metrics.BinaryAccuracy(name="accuracy"),
        BinaryMeanIoU(),
        DiceMetric(),
    ],
)
model.summary()


# 6. Callbacks (mirror train.py)

In [None]:
cfg.RESULTS_DIR.mkdir(parents=True, exist_ok=True)
# MODEL_PATH and HISTORY_PATH are not guaranteed to sit directly under
# RESULTS_DIR. Create their parent folders now, so a bad path fails fast here
# rather than at the first checkpoint write after a full training epoch.
for out_path in (cfg.MODEL_PATH, cfg.HISTORY_PATH):
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)

# Save the weights with the lowest validation loss seen so far.
ckpt = tf.keras.callbacks.ModelCheckpoint(
    filepath=str(cfg.MODEL_PATH),
    save_best_only=True,
    monitor="val_loss",
    mode="min",
    verbose=1
)
# Stop after 8 epochs without val_loss improvement. Because this patience is
# longer than ReduceLROnPlateau's (3), the LR gets a chance to drop before
# training is stopped.
early = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=8,
    restore_best_weights=True,
    verbose=1
)
# Halve the learning rate after 3 stagnant epochs, never going below 1e-7.
reduce = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=3,
    verbose=1,
    min_lr=1e-7
)
callbacks = [ckpt, early, reduce]


# 7. Train

In [None]:
print("Starting training…")
# Keyword arguments for Keras ``fit``. The returned History object holds the
# per-epoch train/val metrics used by the saving and plotting cell.
fit_kwargs = {
    "validation_data": val_ds,
    "epochs": cfg.EPOCHS,
    "callbacks": callbacks,
    "verbose": 1,
}
history = model.fit(train_ds, **fit_kwargs)


# 8. Save History + Quick Plots

In [None]:
# Save history dict
with open(cfg.HISTORY_PATH, "w") as f:
    f.write(str(history.history))
print("History saved to:", cfg.HISTORY_PATH)

# Plot
import matplotlib.pyplot as plt

def plot_curve(hist, key, title=None):
    """Plot the training curve for one metric, plus its validation curve if recorded.

    Args:
        hist: A Keras ``History``-like object whose ``.history`` dict maps
            metric names to per-epoch values.
        key: Metric name as it appears in ``hist.history`` (e.g. ``"loss"``).
        title: Optional figure title.

    Returns:
        True if a figure was drawn, False if ``key`` is not in the history.
    """
    # History keys come from each metric's ``name`` attribute. A mismatch
    # (e.g. a custom metric class that names itself differently) should skip
    # this plot, not abort the remaining plots with a KeyError.
    if key not in hist.history:
        print(f"Skipping plot: '{key}' not in history "
              f"(available: {sorted(hist.history)})")
        return False

    plt.figure()
    plt.plot(hist.history[key], label=f"train_{key}")
    val_key = f"val_{key}"
    if val_key in hist.history:
        plt.plot(hist.history[val_key], label=val_key)
    plt.xlabel("Epoch")
    plt.ylabel(key)
    if title:
        plt.title(title)
    plt.legend()
    plt.show()
    return True

# One figure per tracked metric, in a fixed order: (History key, figure title).
for metric_key, metric_title in (
    ("loss", "Loss"),
    ("accuracy", "Accuracy"),
    ("mean_io_u", "Mean IoU"),
    ("dice_metric", "Dice (thresholded)"),
):
    plot_curve(history, metric_key, metric_title)
