# In [None]:  (notebook cell prompt left over from export)
# utils.py
import os
import random
import shutil
from pathlib import Path
import matplotlib.pyplot as plt

def create_train_val_test_split(source_dir, target_dir, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=42):
    """
    Copy a class-per-folder dataset into train/val/test subfolders.

    source_dir: folder with subfolders per class (all images)
    target_dir: output folder with train/val/test subfolders
    train_ratio: fraction of each class copied to "train" (rounded down)
    val_ratio: fraction of each class copied to "val" (rounded down)
    test_ratio: kept for interface compatibility; the "test" split always
        receives every file left over after train and val, so no file is
        lost to rounding
    seed: seed applied to the global ``random`` generator before shuffling

    Raises FileNotFoundError if source_dir does not exist, and ValueError
    if a ratio is negative or train_ratio + val_ratio exceeds 1.
    """
    if min(train_ratio, val_ratio, test_ratio) < 0:
        raise ValueError("split ratios must be non-negative")
    # Small tolerance so float sums such as 0.9 + 0.1 are never rejected.
    if train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError("train_ratio + val_ratio must not exceed 1")

    random.seed(seed)
    src = Path(source_dir)
    tgt = Path(target_dir)
    if not src.exists():
        raise FileNotFoundError(source_dir)
    for split in ("train", "val", "test"):
        (tgt / split).mkdir(parents=True, exist_ok=True)

    # Directory listing order is arbitrary, so sort classes and files to
    # make the shuffle (and therefore the split) reproducible for a seed.
    class_dirs = sorted(d for d in src.iterdir() if d.is_dir())
    for cls in class_dirs:
        # Only regular files: shutil.copy cannot copy a nested directory.
        imgs = sorted(p for p in cls.iterdir() if p.is_file())
        random.shuffle(imgs)
        n = len(imgs)
        ntrain = int(n * train_ratio)
        nval = int(n * val_ratio)
        splits = {
            "train": imgs[:ntrain],
            "val": imgs[ntrain:ntrain + nval],
            "test": imgs[ntrain + nval:],
        }
        for split, files in splits.items():
            if not files:
                continue
            dst = tgt / split / cls.name
            dst.mkdir(parents=True, exist_ok=True)
            for p in files:
                shutil.copy(p, dst / p.name)

def plot_history(history, save_path=None):
    """
    Plot accuracy and loss curves side by side.

    history: object exposing a ``history`` dict of per-epoch metric lists
        (presumably a Keras ``History`` -- confirm against the caller).
        Keys read: "accuracy" (or "acc"), "val_accuracy" (or "val_acc"),
        "loss" and "val_loss".
    save_path: if truthy, the figure is saved there; otherwise it is shown.

    A curve whose key is absent from the dict is skipped rather than
    raising, so a run without validation data can still be plotted.
    """
    # Imported locally so the function is self-contained.
    import matplotlib.pyplot as plt

    metrics = history.history
    acc = metrics.get("accuracy", metrics.get("acc"))
    # Mirror the train-accuracy fallback for the validation key.
    val_acc = metrics.get("val_accuracy", metrics.get("val_acc"))
    loss = metrics.get("loss")
    val_loss = metrics.get("val_loss")

    def _plot_curve(values, fmt, label):
        """Draw one series against 1-based epoch numbers; report if drawn."""
        if values is None:
            return False
        plt.plot(range(1, len(values) + 1), values, fmt, label=label)
        return True

    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    drawn = [
        _plot_curve(acc, "b-", "train_acc"),
        _plot_curve(val_acc, "r--", "val_acc"),
    ]
    if any(drawn):
        plt.legend()

    plt.subplot(1, 2, 2)
    drawn = [
        _plot_curve(loss, "b-", "train_loss"),
        _plot_curve(val_loss, "r--", "val_loss"),
    ]
    if any(drawn):
        plt.legend()

    if save_path:
        plt.savefig(save_path)
    else:
        plt.show()