In [1]:
# ============================================================
# WildTrack (SeaTurtleID) — Re-ID Evaluation & Visualization
# Saves under OUTPUT_DIR:
#  - metrics.json
#  - tsne.png
#  - retrieval_example_*.png
#  - gradcam_*.png   (if a conv layer is found)
# Requires: model.{keras|h5}, preprocessor.pkl from your training step
# ============================================================
import os, io, json, random, math, csv, pickle
from collections import defaultdict
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw, ImageFont

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from sklearn.manifold import TSNE
from sklearn.metrics import average_precision_score

# ------------ Paths (edit if needed) ------------
# DATA_DIR is walked recursively for images; OUTPUT_DIR holds the trained
# model/preprocessor and receives every artifact this script writes.
DATA_DIR   = r"C:\Users\sagni\Downloads\WildTrack\archive\turtles-data\data"
OUTPUT_DIR = r"C:\Users\sagni\Downloads\WildTrack"

# Candidate model files (native Keras format first, legacy H5 as fallback)
# and the pickled preprocessing metadata.
MODEL_KERAS = os.path.join(OUTPUT_DIR, "model.keras")
MODEL_H5    = os.path.join(OUTPUT_DIR, "model.h5")
PP_PATH     = os.path.join(OUTPUT_DIR, "preprocessor.pkl")

os.makedirs(OUTPUT_DIR, exist_ok=True)

# ------------ Load preprocessor ------------
# SECURITY NOTE: pickle.load executes arbitrary code embedded in the file.
# Only point PP_PATH at a preprocessor.pkl you produced yourself.
with open(PP_PATH, "rb") as f:
    preproc = pickle.load(f)

# Settings read from the preprocessor dict. "label2id" and "id2label" are
# required keys (KeyError if absent); the others fall back to the defaults
# shown. IMG_SIZE is used as (height, width) by the model input and resize.
IMG_SIZE  = tuple(preproc.get("image_size", (224, 224)))
label2id  = {k:int(v) for k,v in preproc["label2id"].items()}
id2label  = {int(k):v for k,v in preproc["id2label"].items()}
backbone  = str(preproc.get("backbone", "EfficientNetB0"))
embed_dim = int(preproc.get("embed_dim", 256))
use_l2    = bool(preproc.get("use_l2norm", True))
# Seed all three RNGs (Python, NumPy, TensorFlow) so the query/gallery split
# and the random example picks below are repeatable.
SEED      = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# ------------ Version-safe serializable L2 ------------
# Locate register_keras_serializable wherever this install exposes it:
# tensorflow.keras first, then standalone keras. If neither import works,
# define a no-op decorator factory so the class definition below still runs.
try:
    from tensorflow.keras.utils import register_keras_serializable
except Exception:
    try:
        from keras.utils import register_keras_serializable
    except Exception:
        def register_keras_serializable(package="WildTrack"):
            def deco(obj): return obj
            return deco

@register_keras_serializable(package="WildTrack")
class L2Normalize(layers.Layer):
    """Keras layer that L2-normalises its input along one axis.

    Registered under the "WildTrack" package so saved models that contain
    the layer can be deserialised.

    Args:
        axis: Axis along which to normalise (default: last axis).
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, axis=-1, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, x):
        """Return ``x`` scaled to unit L2 norm along ``self.axis``."""
        return tf.math.l2_normalize(x, axis=self.axis)

    def get_config(self):
        """Return the base layer config extended with ``axis``."""
        config = super().get_config()
        config["axis"] = self.axis
        return config

# ------------ Robust model loader ------------
def build_classifier(n_classes:int, image_size, backbone="EfficientNetB0",
                     embed_dim=256, use_l2=True):
    """Build an untrained classifier: backbone -> dropout -> [emb] -> [l2norm] -> softmax.

    Args:
        n_classes: Number of output classes of the softmax head.
        image_size: Indexable whose first two items are used as the input
            height and width.
        backbone: Case-insensitive backbone name. "EfficientNetB0",
            "ResNet50" and "MobileNetV2" are recognised; any other value
            silently falls back to EfficientNetB0.
        embed_dim: Width of the "emb" Dense layer; the layer is omitted
            when this is falsy or not positive.
        use_l2: Whether to add the "l2norm" layer before the softmax head.

    Returns:
        A ``keras.Model`` named "wildtrack_classifier". The backbone is
        created with ``weights=None``, so every weight is randomly
        initialised until weights are loaded into the model.
    """
    recognised = {
        "efficientnetb0": "EfficientNetB0",
        "resnet50": "ResNet50",
        "mobilenetv2": "MobileNetV2",
    }
    height, width = image_size[0], image_size[1]
    # The input is created before the backbone, as in the original layer order.
    inputs = keras.Input(shape=(height, width, 3))

    app_name = recognised.get(backbone.lower(), "EfficientNetB0")
    make_backbone = getattr(keras.applications, app_name)
    base = make_backbone(include_top=False, weights=None, pooling="avg")

    features = layers.Dropout(0.2, name="dropout")(base(inputs))
    if embed_dim and embed_dim > 0:
        features = layers.Dense(embed_dim, name="emb")(features)
    if use_l2:
        features = L2Normalize(name="l2norm")(features)

    head = layers.Dense(n_classes, activation="softmax", name="softmax")
    return keras.Model(inputs, head(features), name="wildtrack_classifier")

def load_classifier():
    """Load the trained classifier, trying progressively weaker strategies.

    Order of attempts:
      1. Full model from MODEL_KERAS.
      2. Full model from MODEL_H5 (``compile=False``).
      3. Rebuild the architecture with ``build_classifier`` from the
         preprocessor settings and load the H5 weights by layer name.

    Returns:
        The loaded ``keras.Model``.

    Raises:
        FileNotFoundError: If no usable model file is available. When
            model.keras exists but could not be loaded and there is no H5
            fallback, the message says so and the load error is chained as
            the cause (rather than claiming the file is missing).
        Exception: Whatever ``load_weights`` raised if the by-name H5
            weight load fails as well.
    """
    custom = {"L2Normalize": L2Normalize}
    keras_error = None

    if os.path.exists(MODEL_KERAS):
        try:
            m = keras.models.load_model(MODEL_KERAS, custom_objects=custom)
        except Exception as e:
            print("[WARN] model.keras load failed:", e)
            keras_error = e  # kept for the final error message
        else:
            print(f"[INFO] Loaded model: {MODEL_KERAS}")
            return m

    if os.path.exists(MODEL_H5):
        try:
            m = keras.models.load_model(MODEL_H5, compile=False, custom_objects=custom)
        except Exception as e:
            print("[WARN] Direct H5 load failed:", e)
        else:
            print(f"[INFO] Loaded model: {MODEL_H5}")
            return m

        # Rebuild & try by_name
        m = build_classifier(len(id2label), IMG_SIZE, backbone, embed_dim, use_l2)
        try:
            # NOTE(review): skip_mismatch=True silently leaves any layer whose
            # name or shape does not match at its random initialisation.
            # Verify the restored weights before trusting the metrics.
            m.load_weights(MODEL_H5, by_name=True, skip_mismatch=True)
        except Exception as e2:
            print("[ERROR] Could not load weights:", e2)
            raise
        print("[INFO] Loaded weights by_name from H5.")
        return m

    if keras_error is not None:
        raise FileNotFoundError(
            f"{MODEL_KERAS} exists but could not be loaded "
            f"({keras_error}), and no H5 fallback was found at {MODEL_H5}."
        ) from keras_error
    raise FileNotFoundError("Missing model file (.keras/.h5).")

# Load the classifier once; the embedding model and Grad-CAM below both use it.
model = load_classifier()

# ------------ Embedding model (pre-softmax) ------------
def build_embedding_model(classifier: keras.Model):
    """Derive a model that outputs embeddings from a softmax classifier.

    Strategy, in order:
      1. If the classifier has a layer named "l2norm", expose its output.
      2. Otherwise find the softmax head (a Dense layer with softmax
         activation, or any layer named "softmax", searching from the end)
         and expose its input, L2-normalised.
      3. If no softmax head is found, fall back to the last layer's output.

    Args:
        classifier: The trained classification model.

    Returns:
        A ``keras.Model`` sharing ``classifier``'s inputs and weights.
    """
    # if L2 head exists, use it directly
    try:
        l2_layer = classifier.get_layer("l2norm")
    except ValueError:
        # get_layer signals "no such layer" with ValueError.
        l2_layer = None
    if l2_layer is not None:
        print("[INFO] Using 'l2norm' output for embeddings.")
        return keras.Model(classifier.inputs, l2_layer.output)

    # else, take tensor feeding softmax and normalize
    softmax_layer = None
    for lyr in reversed(classifier.layers):
        is_softmax_dense = (
            isinstance(lyr, layers.Dense)
            and getattr(lyr, "activation", None) == keras.activations.softmax
        )
        if is_softmax_dense or lyr.name.lower() == "softmax":
            softmax_layer = lyr
            break
    if softmax_layer is None:
        print("[WARN] Softmax not found; using last layer.")
        return keras.Model(classifier.inputs, classifier.layers[-1].output)

    emb_t = softmax_layer.input
    # Normalise through a layer: Keras 3 rejects raw TensorFlow ops and
    # arithmetic applied directly to symbolic KerasTensors, so the previous
    # `emb_t / tf.norm(emb_t)` fails while building the functional model.
    x = L2Normalize(name="embedding_l2norm")(emb_t)
    return keras.Model(classifier.inputs, x)

# Embedding extractor derived from the loaded classifier; embed_paths() calls it.
emb_model = build_embedding_model(model)

# ------------ Dataset utils ------------
def list_images(root, exts=(".jpg",".jpeg",".png",".bmp",".tif",".tiff")):
    """Recursively collect image file paths under ``root``.

    Args:
        root: Directory to walk.
        exts: Tuple of lowercase filename suffixes to accept; matching is
            case-insensitive on the filename.

    Returns:
        List of paths (each directory joined with the filename) in
        ``os.walk`` order. Empty when nothing matches.
    """
    return [
        os.path.join(folder, filename)
        for folder, _subdirs, filenames in os.walk(root)
        for filename in filenames
        if filename.lower().endswith(exts)
    ]

# Generic container directory names that never identify an individual.
SKIP_DIRS = {"data","dataset","datasets","images","imgs","img","train",
             "val","valid","validation","test","all","photos","pictures"}

def smart_label_from_path(p):
    """Infer an identity label from an image path.

    Walks the directory components from the one nearest the file towards
    the root and returns the first whose lowercase name is not in
    SKIP_DIRS. The drive/root part of an absolute path and the "." / ".."
    components are never returned as labels (previously an absolute path
    made only of skipped directories yielded '' or a drive such as 'C:').

    Args:
        p: Path to an image file.

    Returns:
        The chosen directory name, or the name of the file's immediate
        parent directory when every candidate is skipped.
    """
    _drive, tail = os.path.splitdrive(os.path.normpath(p))
    not_labels = ("", os.curdir, os.pardir)
    components = [c for c in tail.split(os.sep) if c not in not_labels]
    # components[-1] is the filename itself; only directories are candidates.
    for name in reversed(components[:-1]):
        if name.lower() not in SKIP_DIRS:
            return name
    return os.path.basename(os.path.dirname(p))

def load_and_preprocess(path):
    """Read one image file and return it as a float32 tensor of size IMG_SIZE.

    The file is decoded to 3 channels (animations are not expanded),
    converted to float32 with ``tf.image.convert_image_dtype`` and resized
    to the module-level IMG_SIZE.

    NOTE(review): confirm this matches the preprocessing used at training
    time (value range in particular); a mismatch would degrade embeddings.

    Args:
        path: Path of the image file.

    Returns:
        Tensor of shape ``IMG_SIZE + (3,)``.
    """
    raw_bytes = tf.io.read_file(path)
    pixels = tf.image.decode_image(raw_bytes, channels=3, expand_animations=False)
    as_float = tf.image.convert_image_dtype(pixels, tf.float32)
    return tf.image.resize(as_float, IMG_SIZE)

# ------------ Build per-ID groups & split ------------
# Discover every image under DATA_DIR; abort early if the folder is empty.
all_paths = list_images(DATA_DIR)
if not all_paths:
    raise RuntimeError(f"No images found under {DATA_DIR}")

# One label per image, inferred from its directory structure.
all_labels = [smart_label_from_path(p) for p in all_paths]
# keep only labels known to training
# NOTE(review): this presumably means evaluation images overlap the training
# set; confirm whether a held-out split was intended.
keep = [lab in label2id for lab in all_labels]
paths  = [p for p,k in zip(all_paths, keep) if k]
labels = [l for l,k in zip(all_labels, keep) if k]

# label -> list of image paths for that identity
groups = defaultdict(list)
for p,l in zip(paths, labels):
    groups[l].append(p)

# Require at least 2 images per ID for re-ID
# (one becomes the query, so at least one must remain in the gallery)
groups = {l:imgs for l,imgs in groups.items() if len(imgs) >= 2}
ids = sorted(groups.keys())  # sorted so the seeded shuffle below is repeatable
print(f"[INFO] Eligible IDs (>=2 imgs): {len(ids)}")

# Split: 1 query per ID, rest gallery
# queries[i] / q_labels[i] and gallery[j] / g_labels[j] are parallel lists.
queries, q_labels = [], []
gallery, g_labels = [], []
for l in ids:
    imgs = groups[l].copy()   # copy so the shuffle leaves `groups` untouched
    random.shuffle(imgs)
    q = imgs.pop()            # one query
    queries.append(q); q_labels.append(l)
    for g in imgs:            # remaining gallery
        gallery.append(g); g_labels.append(l)

print(f"[INFO] Queries: {len(queries)} | Gallery: {len(gallery)}")

# ------------ Embed (batched) ------------
def embed_paths(p_list, batch=32):
    """Embed a list of image paths with ``emb_model``, ``batch`` images at a time.

    Each batch is loaded with ``load_and_preprocess``, passed through
    ``emb_model.predict`` and re-normalised row-wise to unit L2 norm (a
    1e-12 term guards against division by zero).

    Args:
        p_list: Sequence of image paths (must be non-empty).
        batch: Number of images per forward pass.

    Returns:
        float32 array of shape ``(len(p_list), D)``, rows in ``p_list`` order.
    """
    embedded_chunks = []
    for start in range(0, len(p_list), batch):
        images = [load_and_preprocess(p) for p in p_list[start:start+batch]]
        raw = emb_model.predict(tf.stack(images, axis=0), verbose=0)
        row_norms = np.linalg.norm(raw, axis=1, keepdims=True) + 1e-12
        embedded_chunks.append((raw / row_norms).astype(np.float32))
    return np.vstack(embedded_chunks)

# Embed both sets once; row i of G / Q corresponds to gallery[i] / queries[i].
print("[INFO] Embedding gallery...")
G = embed_paths(gallery, batch=32)  # (Ng, D)
print("[INFO] Embedding queries...")
Q = embed_paths(queries, batch=32)  # (Nq, D)

# ------------ Similarity & metrics ------------
def rank_search(q, G):
    """Rank gallery rows by dot-product similarity to a single query vector.

    Args:
        q: 1-D query embedding of length D.
        G: 2-D gallery matrix of shape (Ng, D).

    Returns:
        ``(idxs, sims)``: gallery indices ordered from most to least
        similar, and the similarities in that same order. With the
        unit-norm embeddings produced by ``embed_paths`` the dot product
        equals cosine similarity.
    """
    sims = G @ q.astype(np.float32)
    idxs = np.argsort(sims)[::-1]
    return idxs, sims[idxs]

def cmc_map(q_labels, g_labels, Q, G, ks=(1,5,10)):
    """Compute R@k (CMC) and mean average precision for a retrieval run.

    Args:
        q_labels: Identity label of each query (scalar labels, e.g. str).
        g_labels: Identity label of each gallery item.
        Q: Query embeddings, row i belongs to ``q_labels[i]``.
        G: Gallery embeddings, row j belongs to ``g_labels[j]``.
        ks: Non-empty iterable of cut-offs to report as "R@k".

    Returns:
        Dict with one "R@k" entry per cut-off, "mAP", "num_queries",
        "num_gallery" and "num_ids". A query counts as a hit at k when a
        gallery item with the same label appears in its top k; this stays
        true for cut-offs larger than the gallery itself. With no queries
        every score is 0.0 instead of NaN.
    """
    ks = tuple(sorted(ks))
    max_k = max(ks)
    n_queries = len(q_labels)
    # Object dtype keeps `==` an element-wise Python comparison for any label type.
    gallery_labels = np.asarray(g_labels, dtype=object)

    correct_at = np.zeros((n_queries, max_k), dtype=np.int32)
    ap_list = []
    for qi, q_label in enumerate(q_labels):
        idxs, _ = rank_search(Q[qi], G)
        rel = np.asarray(gallery_labels[idxs] == q_label).astype(np.int32)
        hit_ranks = np.flatnonzero(rel)
        if hit_ranks.size == 0:
            # No match anywhere in the gallery: AP is 0 and the CMC row stays 0.
            ap_list.append(0.0)
            continue
        # CMC: once the first match appears at rank r, every cut-off beyond r
        # is a hit -- including cut-offs larger than the gallery size.
        first_hit = int(hit_ranks[0])
        if first_hit < max_k:
            correct_at[qi, first_hit:] = 1
        # classic AP: mean of precision@rank over the relevant ranks
        cum_rel = np.cumsum(rel)
        precisions = cum_rel / (np.arange(len(rel)) + 1)
        ap_list.append(float((precisions * rel).sum() / rel.sum()))

    if n_queries:
        recalls = {f"R@{k}": float(correct_at[:, k-1].mean()) for k in ks}
        mean_ap = float(np.mean(ap_list))
    else:
        recalls = {f"R@{k}": 0.0 for k in ks}
        mean_ap = 0.0
    metrics = {
        **recalls,
        "mAP": mean_ap,
        "num_queries": int(n_queries),
        "num_gallery": int(len(g_labels)),
        "num_ids": int(len(set(q_labels)))
    }
    return metrics

# Compute retrieval metrics and persist them as JSON.
metrics = cmc_map(q_labels, g_labels, Q, G, ks=(1,5,10))
print("[INFO] Metrics:", metrics)

with open(os.path.join(OUTPUT_DIR, "metrics.json"), "w", encoding="utf-8") as f:
    json.dump(metrics, f, indent=2)
print(f"[INFO] Saved -> {os.path.join(OUTPUT_DIR, 'metrics.json')}")

# ------------ t-SNE (subset for speed) ------------
# Best-effort visualisation: any failure is reported and the script continues.
try:
    subset = min(2000, len(gallery))
    sel = np.random.choice(len(gallery), subset, replace=False)
    X = G[sel]
    y = np.array([g_labels[i] for i in sel])
    print("[INFO] Running t-SNE on", X.shape)
    # scikit-learn requires perplexity < n_samples; clamp it for small galleries.
    perplexity = min(30, subset - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, learning_rate='auto', init='pca', random_state=SEED)
    Z = tsne.fit_transform(X)
    # Colour each point by its identity so same-ID clusters are visible
    # (the colours used to be random numbers unrelated to the labels).
    # tab20 has 20 colours, so colours repeat across identities.
    _, identity_idx = np.unique(y, return_inverse=True)
    # simple scatter with tiny dots (no legend to avoid clutter)
    import matplotlib
    matplotlib.use("Agg")  # headless backend: render straight to file
    import matplotlib.pyplot as plt
    plt.figure(figsize=(7,6))
    plt.scatter(Z[:,0], Z[:,1], s=4, c=identity_idx % 20, cmap="tab20",
                vmin=0, vmax=19, alpha=0.7)
    plt.title("t-SNE of Re-ID Embeddings (gallery subset)")
    out_tsne = os.path.join(OUTPUT_DIR, "tsne.png")
    plt.tight_layout(); plt.savefig(out_tsne, dpi=160); plt.close()
    print(f"[INFO] Saved -> {out_tsne}")
except Exception as e:
    print("[WARN] t-SNE failed:", e)

# ------------ Retrieval panels (top-5) ------------
def load_thumb(path, side=224):
    """Load an image as a square RGB thumbnail.

    Args:
        path: Image file to open.
        side: Edge length in pixels of the returned square image. The
            aspect ratio is not preserved.

    Returns:
        A new ``PIL.Image.Image`` in RGB mode, ``side`` x ``side`` pixels.
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it as soon as convert() has produced an in-memory copy.
    with Image.open(path) as src:
        rgb = src.convert("RGB")
    return rgb.resize((side, side), Image.BILINEAR)

def draw_panel(query_path, q_label, top_paths, top_labels, top_sims, save_path):
    """Render a one-row retrieval panel: the query followed by its top results.

    Each result is captioned with its label and similarity; the caption is
    green when the label equals ``q_label`` and red otherwise.

    Args:
        query_path: Image file of the query.
        q_label: Identity label of the query.
        top_paths: Image files of the retrieved gallery items, best first.
        top_labels: Labels of those items.
        top_sims: Similarity scores of those items.
        save_path: Destination image file.

    The canvas always has room for the query plus five results and grows
    when more than five results are passed, so none is drawn off-canvas.
    """
    pad = 8
    cell = 200
    results = list(zip(top_paths, top_labels, top_sims))
    cols = 1 + max(5, len(results))  # 1 query + results (at least 5 slots)
    W = cols*cell + (cols+1)*pad
    H = cell + 2*pad + 40  # 40 px band under the thumbnails for captions
    canvas = Image.new("RGB", (W, H), (250,250,250))
    draw = ImageDraw.Draw(canvas)
    # fonts (fallback to default if not available)
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except (OSError, ImportError):
        # OSError: font file not found/readable; ImportError: Pillow was
        # built without FreeType support.
        font = ImageFont.load_default()

    # place query
    q = load_thumb(query_path, side=cell)
    canvas.paste(q, (pad, pad))
    draw.text((pad, pad+cell+5), f"QUERY\n{q_label}", fill=(0,0,0), font=font)

    # place results
    for i,(p,l,s) in enumerate(results, start=1):
        x = i*(cell+pad) + pad
        y = pad
        im = load_thumb(p, side=cell)
        canvas.paste(im, (x, y))
        col = (0,140,0) if l == q_label else (180,0,0)
        txt = f"{l}  sim={s:.3f}"
        draw.text((x, y+cell+5), txt, fill=col, font=font)

    canvas.save(save_path)

# Render retrieval panels for up to 6 randomly chosen queries (no repeats).
N_EXAMPLES = min(6, len(queries))
example_q = np.random.choice(len(queries), N_EXAMPLES, replace=False)
for idx in example_q:
    q = Q[idx]
    qlab = q_labels[idx]
    # Full gallery ranking for this query, best match first.
    order, sims = rank_search(q, G)
    topk = 5
    inds = order[:topk]
    tops = [gallery[i] for i in inds]
    tlabs = [g_labels[i] for i in inds]
    tsims = sims[:topk]  # sims is already in ranked order
    # The file is named after the query's index in `queries`.
    out = os.path.join(OUTPUT_DIR, f"retrieval_example_{idx}.png")
    draw_panel(queries[idx], qlab, tops, tlabs, tsims, out)
    print(f"[INFO] Saved -> {out}")

# ------------ Grad-CAM on a few queries (if conv exists) ------------
def find_last_conv(model):
    """Return the name of the last convolutional layer in ``model.layers``.

    Only the layers listed directly in ``model.layers`` are examined, and
    Conv2D, SeparableConv2D and DepthwiseConv2D count as convolutional.

    Args:
        model: Keras model to inspect.

    Returns:
        The layer name, or ``None`` when no such layer is present.
    """
    conv_types = (layers.Conv2D, layers.SeparableConv2D, layers.DepthwiseConv2D)
    matches = (lyr.name for lyr in reversed(model.layers) if isinstance(lyr, conv_types))
    return next(matches, None)

# Grad-CAM is best-effort: it is skipped when no conv layer is found and any
# runtime failure is reported as a warning instead of aborting the script.
last_conv_name = find_last_conv(model)
if last_conv_name is None:
    print("[WARN] No conv layer found; skipping Grad-CAM.")
else:
    try:
        # Auxiliary model returning the conv feature map and the predictions.
        # model.inputs is already a list; wrapping it again nests the inputs.
        grad_model = keras.Model(
            model.inputs,
            [model.get_layer(last_conv_name).output, model.output]
        )
        def gradcam(path, save_path):
            """Save a Grad-CAM overlay for the model's top predicted class.

            Args:
                path: Image file to explain.
                save_path: Destination file for the overlay image.

            Raises:
                RuntimeError: If no gradient flows from the class score to
                    the conv feature map.
            """
            img = load_and_preprocess(path)[None, ...]  # add batch dimension
            with tf.GradientTape() as tape:
                conv_out, preds = grad_model(img, training=False)
                class_idx = tf.argmax(preds[0])
                loss = preds[:, class_idx]
            grads = tape.gradient(loss, conv_out)
            if grads is None:
                raise RuntimeError("no gradient between class score and conv output")
            # Channel weights = spatially averaged gradients.
            pooled = tf.reduce_mean(grads, axis=(1,2), keepdims=True)
            cam = tf.nn.relu(tf.reduce_sum(pooled * conv_out, axis=-1))[0]  # (H,W)
            cam = cam.numpy()
            cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # scale to [0,1]
            # overlay
            with Image.open(path) as src:
                base = src.convert("RGB").resize((IMG_SIZE[1], IMG_SIZE[0]))
            heat = Image.fromarray(np.uint8(255*cam)).resize(base.size, Image.BILINEAR)
            # Keep the heatmap single-channel: converting it to RGBA first
            # made the array (H,W,4), which cannot be assigned into the
            # (H,W) channel slices below and made every call fail.
            r = np.array(heat)  # (H,W) uint8
            # colorize heatmap (simple red colormap)
            rgba = np.zeros((r.shape[0], r.shape[1], 4), dtype=np.uint8)
            rgba[...,0] = r   # R
            rgba[...,3] = (r*0.6).astype(np.uint8)  # alpha
            heat_col = Image.fromarray(rgba)  # (H,W,4) uint8 is read as RGBA
            out = base.copy()
            out.paste(heat_col, (0,0), heat_col)  # heatmap alpha is the paste mask
            out.save(save_path)

        N_GC = min(4, len(queries))
        picks = np.random.choice(len(queries), N_GC, replace=False)
        for idx in picks:
            out = os.path.join(OUTPUT_DIR, f"gradcam_{idx}.png")
            gradcam(queries[idx], out)
            print(f"[INFO] Saved -> {out}")
    except Exception as e:
        print("[WARN] Grad-CAM failed:", e)

print("\n[INFO] Done. Artifacts:")
print(" - metrics.json")
print(" - tsne.png")
print(" - retrieval_example_*.png")
print(" - gradcam_*.png (if available)")



[WARN] Direct H5 load failed: Exception encountered when calling Lambda.call().

[1mWe could not automatically infer the shape of the Lambda's output. Please specify the `output_shape` argument for this Lambda layer.[0m

Arguments received by Lambda.call():
  • args=('<KerasTensor shape=(None, 256), dtype=float32, sparse=False, ragged=False, name=keras_tensor_502>',)
  • kwargs={'mask': 'None'}
[INFO] Loaded weights by_name from H5.
[INFO] Using 'l2norm' output for embeddings.
[INFO] Eligible IDs (>=2 imgs): 437
[INFO] Queries: 437 | Gallery: 8291
[INFO] Embedding gallery...
[INFO] Embedding queries...
[INFO] Metrics: {'R@1': 0.13272311212814644, 'R@5': 0.2723112128146453, 'R@10': 0.32036613272311215, 'mAP': 0.04022098200458359, 'num_queries': 437, 'num_gallery': 8291, 'num_ids': 437}
[INFO] Saved -> C:\Users\sagni\Downloads\WildTrack\metrics.json
[INFO] Running t-SNE on (2000, 256)
[INFO] Saved -> C:\Users\sagni\Downloads\WildTrack\tsne.png
[INFO] Saved -> C:\Users\sagni\Downloads\