In [2]:
# ============================================
# WildTrack — SeaTurtleID2022 Re-ID baseline (robust)
# Fixes: single-class issue, top-5 crash, smarter label parsing
# Saves: model.h5, preprocessor.pkl, model_config.yaml, metrics.json
# ============================================
import os, json, pickle, random
from datetime import datetime
from collections import Counter

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.metrics import accuracy_score, top_k_accuracy_score
from sklearn.utils.class_weight import compute_class_weight
import yaml

# -------------------------
# Paths
# -------------------------
DATA_DIR   = r"C:\Users\sagni\Downloads\WildTrack\archive\turtles-data\data"
OUTPUT_DIR = r"C:\Users\sagni\Downloads\WildTrack"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# -------------------------
# Config
# -------------------------
SEED = 42
random.seed(SEED); np.random.seed(SEED); tf.random.set_seed(SEED)

IMG_SIZE     = (224, 224)
BATCH        = 32
EPOCHS       = 3
VAL_SPLIT    = 0.15
MAX_CLASSES  = None           # keep all if available
MIN_IMAGES_PER_CLASS = 2      # lower to avoid dropping many IDs
EMBED_DIM    = 256
BACKBONE     = "EfficientNetB0"  # or "MobileNetV2"

# -------------------------
# Utils
# -------------------------
def find_images(root, exts=(".jpg",".jpeg",".png",".bmp",".tif",".tiff")):
    out = []
    for dp, _, files in os.walk(root):
        for f in files:
            if f.lower().endswith(exts):
                out.append(os.path.join(dp, f))
    return out

def try_labels_csv(root):
    # Try to auto-detect a CSV with columns mapping image->id
    candidate_csvs = [
        "labels.csv","train.csv","metadata.csv","ids.csv","annotations.csv",
        "train_labels.csv","train_annotations.csv"
    ]
    for name in candidate_csvs:
        p = os.path.join(root, name)
        if os.path.isfile(p):
            try:
                df = pd.read_csv(p)
                cl = [c.lower() for c in df.columns]
                img_col = None
                for c in ["image","img","path","file","filename","filepath"]:
                    if c in cl: img_col = df.columns[cl.index(c)]; break
                id_col = None
                for c in ["id","individual_id","label","class","identity"]:
                    if c in cl: id_col = df.columns[cl.index(c)]; break
                if img_col and id_col:
                    m = df[[img_col,id_col]].rename(columns={img_col:"image", id_col:"id"})
                    return m
            except Exception:
                pass
    return None

# Smarter label-from-path: skip container names
SKIP_DIRS = set([
    "data","dataset","datasets","images","imgs","img",
    "train","val","valid","validation","test","all","photos","pictures"
])
def smart_label_from_path(p):
    parts = os.path.normpath(p).split(os.sep)
    # Start from parent and walk up until a non-container folder is found
    for i in range(len(parts)-2, -1, -1):
        name = parts[i]
        if name.lower() not in SKIP_DIRS:
            return name
    # Fallback to immediate parent
    return os.path.basename(os.path.dirname(p))

# -------------------------
# Build (path, label) table
# -------------------------
images, labels = [], []

# Prefer a CSV mapping if present
df_map = try_labels_csv(DATA_DIR)
if df_map is not None:
    for _, r in df_map.iterrows():
        p = r["image"]
        if not os.path.isabs(p):
            p = os.path.join(DATA_DIR, p)
        if os.path.isfile(p):
            images.append(p)
            labels.append(str(r["id"]))
else:
    # Fallback: discover all images and derive labels from path
    all_imgs = find_images(DATA_DIR)
    if not all_imgs:
        raise RuntimeError(f"No images found under {DATA_DIR}. Check the path.")
    for p in all_imgs:
        labels.append(smart_label_from_path(p))
        images.append(p)

df = pd.DataFrame({"path": images, "label": labels})

# Drop tiny classes
vc = df["label"].value_counts()
keep = vc[vc >= MIN_IMAGES_PER_CLASS].index
df = df[df["label"].isin(keep)].reset_index(drop=True)

# Optional cap
if MAX_CLASSES is not None and df["label"].nunique() > MAX_CLASSES:
    top_ids = df["label"].value_counts().head(MAX_CLASSES).index
    df = df[df["label"].isin(top_ids)].reset_index(drop=True)

classes = sorted(df["label"].unique().tolist())
n_classes = len(classes)
print(f"[INFO] Images: {len(df)} | Classes: {n_classes}")
if n_classes < 2:
    print("[WARN] Only one class detected. Training a classifier is not meaningful. "
          "Consider pointing DATA_DIR to a deeper folder (where subfolders are individual IDs), "
          "or include/enable a labels CSV. Proceeding anyway and skipping Top-5 metrics.")

# Stratified train/val
rng = np.random.RandomState(SEED)
df["rnd"] = rng.rand(len(df))
val_mask = df.groupby("label")["rnd"].transform(lambda s: s.rank(pct=True)) <= VAL_SPLIT
df_train = df[~val_mask].drop(columns=["rnd"]).reset_index(drop=True)
df_val   = df[val_mask].drop(columns=["rnd"]).reset_index(drop=True)
print(f"[INFO] Train: {len(df_train)} | Val: {len(df_val)}")

label2id = {c:i for i,c in enumerate(classes)}
id2label = {i:c for c,i in label2id.items()}

# -------------------------
# tf.data pipelines
# -------------------------
AUTOTUNE = tf.data.AUTOTUNE

def decode_img(path):
    x = tf.io.read_file(path)
    x = tf.image.decode_image(x, channels=3, expand_animations=False)
    x = tf.image.convert_image_dtype(x, tf.float32)
    x = tf.image.resize(x, IMG_SIZE)
    return x

def py_map(path, label_str):
    idx = label2id[label_str.numpy().decode("utf-8")]
    return path, idx

def tf_map(path, label_str):
    p, y = tf.py_function(py_map, [path, label_str], [tf.string, tf.int64])
    p.set_shape([]); y.set_shape([])
    img = decode_img(p)
    return img, tf.cast(y, tf.int32)

def augment(img, y):
    img = tf.image.random_flip_left_right(img)
    img = tf.image.random_brightness(img, 0.1)
    img = tf.image.random_contrast(img, 0.9, 1.1)
    return img, y

def make_ds(frame, training=True):
    ds = tf.data.Dataset.from_tensor_slices((frame["path"].values, frame["label"].values))
    if training: ds = ds.shuffle(len(frame), seed=SEED)
    ds = ds.map(tf_map, num_parallel_calls=AUTOTUNE)
    if training: ds = ds.map(augment, num_parallel_calls=AUTOTUNE)
    ds = ds.batch(BATCH).prefetch(AUTOTUNE)
    return ds

train_ds = make_ds(df_train, True)
val_ds   = make_ds(df_val, False)

# -------------------------
# Model
# -------------------------
def build_model(nc:int, embed_dim:int=256, backbone:str="EfficientNetB0"):
    inputs = keras.Input(shape=(*IMG_SIZE,3))
    if backbone == "MobileNetV2":
        base = keras.applications.MobileNetV2(include_top=False, input_tensor=inputs, weights="imagenet")
    else:
        base = keras.applications.EfficientNetB0(include_top=False, input_tensor=inputs, weights="imagenet")
    x = layers.GlobalAveragePooling2D()(base.output)
    x = layers.Dense(embed_dim, activation=None, name="embedding")(x)
    x = layers.Lambda(lambda t: tf.math.l2_normalize(t, axis=-1), name="l2norm")(x)
    if nc == 1:
        # Binary-ish head (degenerate dataset). Use sigmoid to avoid softmax(1) warning.
        logits = layers.Dense(1, activation="sigmoid", name="classifier")(x)
        loss   = "binary_crossentropy"
        metrics= [keras.metrics.BinaryAccuracy(name="acc")]
    else:
        logits = layers.Dense(nc, activation="softmax", name="classifier")(x)
        loss   = "sparse_categorical_crossentropy"
        metrics= [keras.metrics.SparseCategoricalAccuracy(name="acc")]
    model = keras.Model(inputs, logits, name="reid_classifier")
    model.compile(optimizer=keras.optimizers.Adam(1e-3), loss=loss, metrics=metrics)
    return model

model = build_model(n_classes, EMBED_DIM, BACKBONE)

# Class weights (only if >1 class)
class_weight = None
if n_classes > 1:
    y_train_idx = df_train["label"].map(label2id).values
    cw = compute_class_weight("balanced", classes=np.arange(n_classes), y=y_train_idx)
    class_weight = {i: float(w) for i,w in enumerate(cw)}

callbacks = [
    keras.callbacks.ModelCheckpoint(os.path.join(OUTPUT_DIR, "tmp_best.keras"),
                                    monitor="val_acc", mode="max", save_best_only=True),
    keras.callbacks.EarlyStopping(monitor="val_acc", mode="max", patience=2, restore_best_weights=True)
]

# -------------------------
# Train
# -------------------------
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    class_weight=class_weight,
    verbose=1,
    callbacks=callbacks
)

# Reload best
try:
    model = keras.models.load_model(os.path.join(OUTPUT_DIR, "tmp_best.keras"))
except Exception:
    pass

# -------------------------
# Evaluate
# -------------------------
y_true, y_prob = [], []
for bx, by in val_ds:
    pr = model.predict(bx, verbose=0)
    y_prob.append(pr)
    y_true.append(by.numpy())
y_prob = np.concatenate(y_prob, axis=0)
y_true = np.concatenate(y_true, axis=0)

if n_classes == 1:
    # binary head: y_prob shape (N,1); fake y_pred as zeros (single class)
    y_pred = (y_prob >= 0.5).astype(int).squeeze()
    top1 = float(accuracy_score(np.zeros_like(y_true), np.zeros_like(y_pred)))  # degenerate but defined
    top5 = None
else:
    y_pred = y_prob.argmax(axis=1)
    top1 = float(accuracy_score(y_true, y_pred))
    # Guard top-k when shapes/classes mismatch or nc < 2
    try:
        top5 = float(top_k_accuracy_score(y_true, y_prob, k=min(5, n_classes), labels=np.arange(n_classes)))
    except Exception:
        top5 = None

print(f"[INFO] Val Top-1 Acc: {top1:.4f}" + ("" if top5 is None else f" | Top-5 Acc: {top5:.4f}"))

# -------------------------
# Save artifacts
# -------------------------
# model.h5
h5_path = os.path.join(OUTPUT_DIR, "model.h5")
model.save(h5_path)
print(f"[INFO] Saved -> {h5_path}")

# preprocessor.pkl
preproc = {
    "image_size": IMG_SIZE,
    "embed_dim": EMBED_DIM,
    "backbone": BACKBONE,
    "label2id": label2id,
    "id2label": {int(k):v for k,v in id2label.items()},
    "class_count": n_classes,
    "min_images_per_class": MIN_IMAGES_PER_CLASS,
    "val_split": VAL_SPLIT,
    "selected_classes": classes,
}
with open(os.path.join(OUTPUT_DIR, "preprocessor.pkl"), "wb") as f:
    pickle.dump(preproc, f)
print(f"[INFO] Saved -> {os.path.join(OUTPUT_DIR, 'preprocessor.pkl')}")

# model_config.yaml
cfg = {
    "created": datetime.now().isoformat(timespec="seconds"),
    "task": "wildlife_reid_baseline_classifier",
    "data_dir": DATA_DIR,
    "output_dir": OUTPUT_DIR,
    "image_size": list(IMG_SIZE),
    "batch": BATCH,
    "epochs": EPOCHS,
    "backbone": BACKBONE,
    "embed_dim": EMBED_DIM,
    "num_classes": n_classes,
    "training": {"optimizer":"Adam","lr":1e-3,
                 "loss": "binary_crossentropy" if n_classes==1 else "sparse_categorical_crossentropy",
                 "metrics": ["acc"]}
}
with open(os.path.join(OUTPUT_DIR, "model_config.yaml"), "w", encoding="utf-8") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)
print(f"[INFO] Saved -> {os.path.join(OUTPUT_DIR, 'model_config.yaml')}")

# metrics.json
metrics = {
    "val_top1_accuracy": top1,
    "val_top5_accuracy": top5,
    "num_validation_samples": int(len(df_val)),
    "num_train_samples": int(len(df_train)),
    "num_classes": n_classes
}
with open(os.path.join(OUTPUT_DIR, "metrics.json"), "w", encoding="utf-8") as f:
    json.dump(metrics, f, indent=2)
print(f"[INFO] Saved -> {os.path.join(OUTPUT_DIR, 'metrics.json')}")

print("\n[INFO] Done. If you still see only one class, your paths likely point to a container folder.")
print("Try setting DATA_DIR to the level where subfolders are actual individual IDs, or provide a labels CSV.")


[INFO] Images: 8728 | Classes: 437
[INFO] Train: 7633 | Val: 1095
Epoch 1/3
[1m239/239[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m538s[0m 2s/step - acc: 0.0117 - loss: 6.0230 - val_acc: 0.0119 - val_loss: 6.0649
Epoch 2/3
[1m239/239[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m475s[0m 2s/step - acc: 0.0538 - loss: 5.8971 - val_acc: 0.0091 - val_loss: 6.1290
Epoch 3/3
[1m239/239[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m474s[0m 2s/step - acc: 0.0788 - loss: 5.4494 - val_acc: 9.1324e-04 - val_loss: 6.3437




[INFO] Val Top-1 Acc: 0.0119 | Top-5 Acc: 0.0320
[INFO] Saved -> C:\Users\sagni\Downloads\WildTrack\model.h5
[INFO] Saved -> C:\Users\sagni\Downloads\WildTrack\preprocessor.pkl
[INFO] Saved -> C:\Users\sagni\Downloads\WildTrack\model_config.yaml
[INFO] Saved -> C:\Users\sagni\Downloads\WildTrack\metrics.json

[INFO] Done. If you still see only one class, your paths likely point to a container folder.
Try setting DATA_DIR to the level where subfolders are actual individual IDs, or provide a labels CSV.
