<a href="https://colab.research.google.com/github/sayyamalam/sign-language-recognition/blob/main/2D_CNN/2D_CNN_baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Nur fehlende Pakete installieren
%pip install -q --upgrade-strategy only-if-needed mlflow pyyaml
%pip install -q --no-deps decord

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m26.4/26.4 MB[0m [31m93.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m88.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m66.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m247.4/247.4 kB[0m [31m19.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m147.8/147.8 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.9/114.9 kB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.0/85.0 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m703.4/703.4 kB[0m [31m48.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
# Basis-Imports
import os, json, random, math, time, pathlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import mlflow

# decord (Video-Decode)
try:
    import decord
    decord.bridge.set_bridge("native")  # numpy-Arrays
    _DECORD_AVAILABLE = True
except Exception as e:
    print("Warnung: decord nicht verfügbar:", e)
    _DECORD_AVAILABLE = False

# Drive mounten
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Standardpfade in Drive
MLFLOW_DIR   = "/content/drive/MyDrive/mlruns"
ARTIFACTS_DIR= "/content/drive/MyDrive/ml_artifacts/msasl_selected50/3d_baseline"
DATA_ROOT    = "/content/drive/MyDrive/msasl_clips"
SELECTED_DIR = "/content/sign-language-recognition/meta/selected50"

# Ordner anlegen
os.makedirs(MLFLOW_DIR, exist_ok=True)
os.makedirs(ARTIFACTS_DIR, exist_ok=True)

def device_report():
    print("TensorFlow:", tf.__version__)
    print("NumPy:", np.__version__)
    print("Pandas:", pd.__version__)
    print("GPUs:", tf.config.list_physical_devices('GPU'))
device_report()

# Performance-Flags
tf.keras.mixed_precision.set_global_policy("mixed_float16")
tf.config.optimizer.set_jit(True)
print("Mixed precision:", tf.keras.mixed_precision.global_policy())

Mounted at /content/drive
TensorFlow: 2.19.0
NumPy: 2.0.2
Pandas: 2.2.2
GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Mixed precision: <DTypePolicy "mixed_float16">


In [3]:
# Mini-Config (Baseline hart codiert)
CFG = dict(
    # Pfade
    DRIVE_ROOT=DATA_ROOT,
    SELECTED_DIR=SELECTED_DIR,
    ARTIFACTS_DIR=ARTIFACTS_DIR,
    MLFLOW_URI=MLFLOW_DIR,

    # MLflow Identität
    EXPERIMENT_NAME="MSASL_selected50",
    RUN_GROUP="3D_BASELINE",
    MODEL_FAMILY="3D",

    # Daten
    T=16, STRIDE=2, IMG_SIZE=160, BATCH=8,
    AUG="light", CROP_POLICY="box_1.2x", PAD_MODE="repeat",

    # Training
    EPOCHS=30, BASE_LR=3e-4, WEIGHT_DECAY=1e-4,
    WARMUP_EPOCHS=2, COSINE=True, CLIP_NORM=5.0,

    # Debug
    DRY_RUN=False, MAX_STEPS=0, EVAL_MAX_BATCHES=0, SKIP_DECODE=False,
    LIMIT_SAMPLES=dict(train=None, val=None, test=None),

    # Modell
    BACKBONE="i3d_baseline",
    FINE_TUNE="linear",
    HEAD_TYPE="gap_fc", HEAD_DIM=512, DROPOUT=0.3,
    LABEL_SMOOTHING=0.0,
)

# Falls decord fehlt und Decoding vorgesehen war → Dummy-Frames
if not _DECORD_AVAILABLE and not CFG["SKIP_DECODE"]:
    print("decord nicht verfügbar – schalte SKIP_DECODE=True (Dummy-Frames).")
    CFG["SKIP_DECODE"] = True

# --- MLflow robust initialisieren (repariert defektes .trash) ---
import mlflow
def _ensure_mlflow_dir_ok(path: str):
    os.makedirs(path, exist_ok=True)
    trash = os.path.join(path, ".trash")
    if os.path.exists(trash) and not os.path.isdir(trash):
        # defektes .trash (z. B. Datei) entfernen
        os.remove(trash)
    os.makedirs(trash, exist_ok=True)

_ensure_mlflow_dir_ok(CFG["MLFLOW_URI"])
mlflow.end_run()
mlflow.set_tracking_uri(CFG["MLFLOW_URI"])
mlflow.set_experiment(CFG["EXPERIMENT_NAME"])

def mlflow_start(run_name: str, tags: dict, params: dict = None):
    run = mlflow.start_run(run_name=run_name)
    mlflow.set_tags(tags)
    if params:
        to_log = {}
        for k,v in params.items():
            to_log[k] = float(v) if isinstance(v, (np.floating,)) else v
        mlflow.log_params(to_log)
    try:
        mlflow.tensorflow.autolog(log_models=True)
    except Exception as e:
        print("mlflow autolog Warnung:", e)
    return run

def mlflow_log_metrics(metrics: dict, step: int = None):
    if not metrics: return
    mlflow.log_metrics({k: float(v) for k, v in metrics.items()}, step=step)

def mlflow_log_artifact(path: str, artifact_path: str = None):
    if os.path.exists(path):
        mlflow.log_artifact(path, artifact_path=artifact_path)

def mlflow_end():
    mlflow.end_run()

print("MLflow bereit:", mlflow.get_tracking_uri())

MLflow bereit: /content/drive/MyDrive/mlruns


In [4]:
import re, glob, shutil

def youtube_id_from_url(url: str) -> str:
    """Extrahiert YouTube-ID aus URL."""
    m = re.search(r"(?:v=|youtu\.be/)([A-Za-z0-9_\-]{6,})", url)
    if m: return m.group(1)
    return url.split("v=")[-1].split("&")[0]

def _slug(s: str, maxlen: int = 40) -> str:
    """Säubert Texte für Dateinamen."""
    s = re.sub(r"\s+", "_", str(s).strip())
    s = re.sub(r"[^\w\-]+", "_", s)
    s = re.sub(r"_+", "_", s).strip("_")
    return s[:maxlen] if maxlen else s

def _ts_token(start: float, end: float) -> str:
    """Erzeugt Zeitstempel-Token in ms."""
    s_ms = int(round(float(start) * 1000.0))
    e_ms = int(round(float(end)   * 1000.0))
    return f"{s_ms:07d}-{e_ms:07d}ms"

def make_filename(entry: dict) -> str:
    """Baut Dateinamen exakt wie im Downloader."""
    ytid  = youtube_id_from_url(entry["url"])
    split = entry["split"]
    label = int(entry.get("label", -1))
    signer_id = entry.get("signer_id", "na")

    start = float(entry.get("start_time", 0.0))
    end   = float(entry.get("end_time",   0.0))
    if end <= start: end = start + 3.0
    ts = _ts_token(start, end)

    clean_text = _slug(entry.get("clean_text", "na"), maxlen=40)
    return f"{ytid}__s-{ts}__lab-{label}__sig-{signer_id}__{clean_text}__{split}.mp4"

def load_split_index(selected_dir: str) -> dict:
    """Lädt JSONs und baut DataFrames mit Dateinamen wie im Downloader."""
    files = {"train": "MSASL_train_selected50.json",
             "val":   "MSASL_val_selected50.json",
             "test":  "MSASL_test_selected50.json"}
    out = {}
    for split, fname in files.items():
        path = os.path.join(selected_dir, fname)
        with open(path, "r") as f:
            data = json.load(f)
        recs = []
        for it in data:
            it["split"] = split
            fn = make_filename(it)
            recs.append({
                "label": str(it.get("label")),
                "filename": fn,
                "box": it.get("box") or None,
                "split": split
            })
        out[split] = pd.DataFrame.from_records(recs)
    return out

def build_class_mapping(dfs: dict):
    labels = pd.concat([dfs["train"]["label"], dfs["val"]["label"], dfs["test"]["label"]]).unique()
    class_names = sorted([str(x) for x in labels])
    label_to_id = {c: i for i, c in enumerate(class_names)}
    return label_to_id, class_names

# --- Lade Indexe ---
dfs = load_split_index(CFG["SELECTED_DIR"])
label_to_id, class_names = build_class_mapping(dfs)
num_classes = len(class_names)
print("Splits:", {k: len(v) for k, v in dfs.items()}, "| num_classes:", num_classes)

# --- Lokales Mirror-Directory für Geschwindigkeit ---
LOCAL_DATA_ROOT = "/content/msasl_clips_local"
if not os.path.exists(LOCAL_DATA_ROOT):
    for split in ["train","val","test"]:
        src = os.path.join(CFG["DRIVE_ROOT"], split)
        dst = os.path.join(LOCAL_DATA_ROOT, split)
        os.makedirs(dst, exist_ok=True)
        files = glob.glob(f"{src}/**/*.mp4", recursive=True)[:500]  # Beispiel: nur 500 pro Split
        for f in files:
            rel = os.path.relpath(f, src)
            dst_path = os.path.join(dst, rel)
            os.makedirs(os.path.dirname(dst_path), exist_ok=True)
            if not os.path.exists(dst_path):
                shutil.copy2(f, dst_path)
    print("✅ Clips lokal gespiegelt:", LOCAL_DATA_ROOT)

# CFG umschalten auf lokale Daten
CFG["DRIVE_ROOT"] = LOCAL_DATA_ROOT

# --- Quick Check ---
for split in ["train","val","test"]:
    df = dfs[split].head(3)
    for _, row in df.iterrows():
        p = os.path.join(CFG["DRIVE_ROOT"], split, row["label"], row["filename"])
        print(split, "→", os.path.exists(p), p)


Splits: {'train': 1677, 'val': 374, 'test': 248} | num_classes: 50
✅ Clips lokal gespiegelt: /content/msasl_clips_local
train → True /content/msasl_clips_local/train/8/jQb9NL9_S6U__s-0385765-0392077ms__lab-8__sig-6__want__train.mp4
train → True /content/msasl_clips_local/train/29/jQb9NL9_S6U__s-0433452-0437071ms__lab-29__sig-6__must__train.mp4
train → True /content/msasl_clips_local/train/2/_HOx2QkkTsg__s-0013995-0015498ms__lab-2__sig-144__teacher__train.mp4
val → True /content/msasl_clips_local/val/28/nhEw0JSb-XQ__s-0000000-0002933ms__lab-28__sig-3__table__val.mp4
val → True /content/msasl_clips_local/val/13/koMZVbqiXf4__s-0151351-0154755ms__lab-13__sig-125__white__val.mp4
val → True /content/msasl_clips_local/val/41/koMZVbqiXf4__s-0155956-0159126ms__lab-41__sig-125__black__val.mp4
test → True /content/msasl_clips_local/test/51/G77ZoILMYw4__s-0080380-0083283ms__lab-51__sig-9__doctor__test.mp4
test → True /content/msasl_clips_local/test/2/G77ZoILMYw4__s-0332331-0335301ms__lab-2__sig-9_

In [5]:
def make_clip_path(drive_root: str, split: str, label: str, filename: str) -> str:
    """Baut den absoluten Pfad zu einem Clip."""
    return os.path.join(drive_root, split, label, filename)

def sample_indices(num_frames: int, T: int, stride: int) -> np.ndarray:
    if num_frames <= 0:
        return np.zeros((T,), dtype=np.int64)
    idx = np.arange(0, T*stride, stride, dtype=np.int64)
    return np.clip(idx, 0, max(0, num_frames-1))

def decode_clip(path: str, indices: np.ndarray) -> np.ndarray:
    if CFG["SKIP_DECODE"]:
        return np.zeros((indices.shape[0], CFG["IMG_SIZE"], CFG["IMG_SIZE"], 3), dtype=np.uint8)
    vr = decord.VideoReader(path, num_threads=2, ctx=decord.cpu(0))
    indices = np.clip(indices, 0, len(vr)-1)
    return vr.get_batch(indices.tolist()).asnumpy()

def resize_and_box_crop(frames: np.ndarray, box, img_size: int, expand: float = 1.2) -> np.ndarray:
    T, H, W = frames.shape[0], frames.shape[1], frames.shape[2]
    if box is None:
        side = min(H, W); y1=(H-side)//2; x1=(W-side)//2; y2=y1+side; x2=x1+side
    else:
        x1r,y1r,x2r,y2r = box
        x1 = int(x1r*W); y1 = int(y1r*H); x2 = int(x2r*W); y2 = int(y2r*H)
        cx,cy = (x1+x2)/2,(y1+y2)/2
        bw=max(1,int((x2-x1)*expand)); bh=max(1,int((y2-y1)*expand))
        x1=int(cx-bw/2); x2=int(cx+bw/2); y1=int(cy-bh/2); y2=int(cy+bh/2)
        x1=max(0,x1); y1=max(0,y1); x2=min(W,x2); y2=min(H,y2)
    out = []
    for t in range(T):
        img = frames[t, y1:y2, x1:x2, :]
        img = tf.image.resize(img, (img_size, img_size), method="bilinear").numpy().astype(np.uint8)
        out.append(img)
    return np.stack(out, axis=0)

def augment_light(frames: np.ndarray, training: bool) -> np.ndarray:
    if not training or CFG["AUG"] in (None, "none"):
        return frames
    f = frames
    if random.random() < 0.5:
        f = f[:, :, ::-1, :]
    ff = tf.convert_to_tensor(f, dtype=tf.float32)
    ff = tf.image.random_brightness(ff, max_delta=0.05)
    ff = tf.image.random_contrast(ff, lower=0.95, upper=1.05)
    return tf.cast(tf.clip_by_value(ff, 0, 255), tf.uint8).numpy()


In [6]:
def build_split_dataframe(dfs: dict, split: str, limit: int | None) -> pd.DataFrame:
    df = dfs[split]
    if limit is not None:
        df = df.sample(n=min(limit, len(df)), random_state=42).reset_index(drop=True)
    return df

def gen_examples(dfs: dict, split: str):
    df = build_split_dataframe(dfs, split, CFG["LIMIT_SAMPLES"].get(split))
    for _, row in df.iterrows():
        label = row["label"]; label_id = label_to_id[label]
        path = make_clip_path(CFG["DRIVE_ROOT"], split, label, row["filename"])
        if not os.path.exists(path) and not CFG["SKIP_DECODE"]:
            continue
        try:
            if CFG["SKIP_DECODE"]:
                num_frames = CFG["T"] * CFG["STRIDE"]
            else:
                vr = decord.VideoReader(path, num_threads=1, ctx=decord.cpu(0))
                num_frames = len(vr)
            idx = sample_indices(num_frames, CFG["T"], CFG["STRIDE"])
            frames = decode_clip(path, idx)
        except Exception:
            continue
        frames = resize_and_box_crop(frames, row.get("box"), CFG["IMG_SIZE"], expand=1.2)
        frames = augment_light(frames, training=(split == "train"))
        yield frames, label_id

def _to_tensor(frames, label_id):
    x = tf.convert_to_tensor(frames, dtype=tf.uint8)
    x = tf.image.convert_image_dtype(x, tf.float32)
    y = tf.one_hot(label_id, depth=num_classes, dtype=tf.float32)
    return x, y

def _map_to_tensor(f, y):
    x, y = tf.py_function(func=_to_tensor, inp=[f, y], Tout=(tf.float32, tf.float32))
    x.set_shape((CFG["T"], CFG["IMG_SIZE"], CFG["IMG_SIZE"], 3))
    y.set_shape((num_classes,))
    return x, y

def build_dataset(dfs: dict, split: str, batch_size: int, shuffle: bool):
    output_sig = (
        tf.TensorSpec(shape=(CFG["T"], CFG["IMG_SIZE"], CFG["IMG_SIZE"], 3), dtype=tf.uint8),
        tf.TensorSpec(shape=(), dtype=tf.int32),
    )
    ds = tf.data.Dataset.from_generator(lambda: gen_examples(dfs, split), output_signature=output_sig)
    if shuffle: ds = ds.shuffle(buffer_size=2048, reshuffle_each_iteration=True)
    ds = ds.map(_map_to_tensor, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False if shuffle else True)
    ds = ds.batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)
    return ds

ds_train = build_dataset(dfs, "train", CFG["BATCH"], shuffle=True)
ds_val   = build_dataset(dfs, "val",   CFG["BATCH"], shuffle=False)
ds_test  = build_dataset(dfs, "test",  CFG["BATCH"], shuffle=False)

xb, yb = next(iter(ds_train))
print("Batch shapes:", xb.shape, yb.shape)


Batch shapes: (8, 16, 160, 160, 3) (8, 50)


In [7]:
from tensorflow.keras import layers, models, applications

def build_2d_baseline(input_shape, num_classes, dropout=0.3):
    """
    2D-Baseline: Pretrained MobileNetV3-Small als Frame-Encoder + Temporal Average.
    """
    inp = layers.Input(shape=input_shape, dtype=tf.float32)  # (T, H, W, 3)

    # Frames einzeln durch 2D-Backbone
    x = layers.TimeDistributed(
            applications.MobileNetV3Small(
                include_top=False,
                weights="imagenet",
                input_shape=input_shape[1:],  # (H, W, 3)
                pooling="avg"
            )
        )(inp)   # (T, feature_dim)

    # Zeitliche Aggregation
    x = layers.GlobalAveragePooling1D()(x)   # (feature_dim,)

    if dropout > 0:
        x = layers.Dropout(dropout)(x)

    out = layers.Dense(num_classes, activation="softmax", dtype="float32")(x)
    return models.Model(inp, out, name="mobilenetv3_2d_baseline")

model = build_2d_baseline(
    input_shape=(CFG["T"], CFG["IMG_SIZE"], CFG["IMG_SIZE"], 3),
    num_classes=num_classes,
    dropout=CFG["DROPOUT"]
)
model.summary()


  return MobileNetV3(


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v3/weights_mobilenet_v3_small_224_1.0_float_no_top_v2.h5
[1m4334752/4334752[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [8]:
def build_optimizer(steps_per_epoch=None):
    lr, wd = CFG["BASE_LR"], CFG["WEIGHT_DECAY"]
    if CFG["COSINE"] and steps_per_epoch:
        total_steps = steps_per_epoch * CFG["EPOCHS"]
        warmup_steps = CFG["WARMUP_EPOCHS"] * steps_per_epoch
        def schedule(step):
            step = tf.cast(step, tf.float32)
            lr_warm = lr * (step / tf.cast(max(1,warmup_steps), tf.float32))
            lr_cos  = 0.5*lr*(1.0+tf.cos(np.pi*(step-warmup_steps)/tf.cast(max(1,total_steps-warmup_steps), tf.float32)))
            return tf.where(step<warmup_steps, lr_warm, lr_cos)
        lr_schedule = tf.keras.optimizers.schedules.LearningRateSchedule(schedule)
        return tf.keras.optimizers.AdamW(learning_rate=lr_schedule, weight_decay=wd, global_clipnorm=CFG["CLIP_NORM"])
    return tf.keras.optimizers.AdamW(learning_rate=lr, weight_decay=wd, global_clipnorm=CFG["CLIP_NORM"])

def compile_model(model, steps_per_epoch=None):
    opt = build_optimizer(steps_per_epoch)
    loss = tf.keras.losses.CategoricalCrossentropy(label_smoothing=CFG["LABEL_SMOOTHING"])
    metrics = [tf.keras.metrics.TopKCategoricalAccuracy(k=1, name="top1"),
               tf.keras.metrics.TopKCategoricalAccuracy(k=5, name="top5")]
    model.compile(optimizer=opt, loss=loss, metrics=metrics)
    return model

def build_callbacks(out_dir: str):
    os.makedirs(out_dir, exist_ok=True)
    return [
        tf.keras.callbacks.TerminateOnNaN(),
        tf.keras.callbacks.ModelCheckpoint(os.path.join(out_dir,"ckpt_best.weights.h5"),
                                           monitor="val_top1", mode="max",
                                           save_best_only=True, save_weights_only=True),
        tf.keras.callbacks.EarlyStopping(monitor="val_top1", mode="max", patience=5, restore_best_weights=True),
        tf.keras.callbacks.CSVLogger(os.path.join(out_dir,"train_log.csv"))
    ]

steps_per_epoch = None if (not CFG["DRY_RUN"] or not CFG["MAX_STEPS"]) else CFG["MAX_STEPS"]
model = compile_model(model, steps_per_epoch)
callbacks = build_callbacks(CFG["ARTIFACTS_DIR"])

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import json

def evaluate_and_log(model, dataset, class_names, run_name="test_2d"):
    """
    Evaluiert das Modell auf dataset und loggt Ergebnisse in MLflow.
    """
    y_true, y_pred, y_pred_top5 = [], [], []
    loss_fn = tf.keras.losses.CategoricalCrossentropy()
    losses = []

    for xb, yb in dataset:
        probs = model(xb, training=False).numpy()
        preds = np.argmax(probs, axis=1)
        preds_top5 = np.argsort(probs, axis=1)[:, -5:]
        labels = np.argmax(yb.numpy(), axis=1)

        y_true.extend(labels)
        y_pred.extend(preds)
        y_pred_top5.extend(preds_top5)
        losses.append(loss_fn(yb, probs).numpy())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    y_pred_top5 = np.array(y_pred_top5)
    avg_loss = float(np.mean(losses))

    # --- Top1 / Top5 ---
    top1 = float((y_true == y_pred).mean())
    top5 = float(np.mean([yt in yp for yt, yp in zip(y_true, y_pred_top5)]))

    # --- Confusion Matrix ---
    cm = confusion_matrix(y_true, y_pred, labels=range(len(class_names)))
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, xticklabels=class_names, yticklabels=class_names,
                cmap="Blues", cbar=True)
    plt.title(f"Confusion Matrix ({run_name})")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    cm_png = os.path.join(CFG["ARTIFACTS_DIR"], f"confusion_matrix_{run_name}.png")
    plt.tight_layout(); plt.savefig(cm_png, dpi=150); plt.close()
    mlflow_log_artifact(cm_png)

    # --- Per-Class Report ---
    report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
    report_df = pd.DataFrame(report).transpose()
    report_csv = os.path.join(CFG["ARTIFACTS_DIR"], f"class_report_{run_name}.csv")
    report_df.to_csv(report_csv)
    mlflow_log_artifact(report_csv)

    # --- Metrics JSON ---
    mlflow.log_metrics({
        f"{run_name}_loss": avg_loss,
        f"{run_name}_top1": top1,
        f"{run_name}_top5": top5
    })

    print(f"[Eval-{run_name}] Top1={top1:.3f}, Top5={top5:.3f}, Loss={avg_loss:.3f}")
    return {"loss": avg_loss, "top1": top1, "top5": top5}

# =========================================================
# Training + Eval + Logging (2D-Baseline)
# =========================================================
run_name = f"2D_BASE_mobilenetv3_T{CFG['T']}_s{CFG['STRIDE']}"
tags = {
    "run_group": "2D_BASELINE",
    "model_family": "2D",
    "backbone": "mobilenetv3_small",
    "fine_tune": CFG["FINE_TUNE"],
    "head": "temporal_avg",
    "baseline": "true"
}
params = {
    "epochs": CFG["EPOCHS"],
    "batch_size": CFG["BATCH"],
    "base_lr": CFG["BASE_LR"]
}

_ = mlflow_start(run_name, tags, params)
print("MLflow run gestartet:", run_name)

# --- Training ---
history = model.fit(
    ds_train,
    validation_data=ds_val,
    epochs=CFG["EPOCHS"],
    steps_per_epoch=steps_per_epoch,
    verbose=1,
    callbacks=callbacks
)

# --- Lernkurven speichern & loggen ---
curves_png = os.path.join(CFG["ARTIFACTS_DIR"], "curves_2d.png")
plt.figure()
plt.plot(history.history["loss"], label="train_loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.plot(history.history["top1"], label="train_top1")
plt.plot(history.history["val_top1"], label="val_top1")
plt.legend(); plt.xlabel("epoch"); plt.ylabel("metric")
plt.title("Training Curves (2D)")
plt.savefig(curves_png); plt.close()
mlflow_log_artifact(curves_png)

# --- Evaluation auf Testset ---
evaluate_and_log(model, ds_test, class_names, run_name="test_2d")

# --- Run schließen ---
mlflow_end()
print("Training + Evaluation (2D-Baseline) beendet & MLflow-Run geschlossen.")


MLflow run gestartet: 2D_BASE_mobilenetv3_T16_s2
Epoch 1/30
     63/Unknown [1m904s[0m 73ms/step - loss: 4.0148 - top1: 0.0482 - top5: 0.2273



[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1191s[0m 5s/step - loss: 4.0070 - top1: 0.0484 - top5: 0.2287 - val_loss: 4.4013 - val_top1: 0.0163 - val_top5: 0.0924
Epoch 2/30
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m228s[0m 2s/step - loss: 2.6606 - top1: 0.1392 - top5: 0.5428 - val_loss: 4.8298 - val_top1: 0.0163 - val_top5: 0.1005
Epoch 3/30
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m227s[0m 2s/step - loss: 2.3749 - top1: 0.2703 - top5: 0.6764 - val_loss: 6.7552 - val_top1: 0.0082 - val_top5: 0.0897
Epoch 4/30
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m229s[0m 2s/step - loss: 2.1681 - top1: 0.3070 - top5: 0.7270 - val_loss: 29.5912 - val_top1: 0.0272 - val_top5: 0.1114
Epoch 5/30
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 72ms/step - loss: 1.9591 - top1: 0.3943 - top5: 0.7518