In [1]:
# ============================================================
# WildTrack — Build embedding index + search API (SeaTurtleID)
# Saves: embeddings.npy, meta.csv, index.faiss/index.npz
# Writes: app.py (FastAPI) and index.html demo
# ============================================================
import os, io, sys, csv, base64, pickle, json
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from PIL import Image

# ---------- Paths (edit if needed) ----------
DATA_DIR   = r"C:\Users\sagni\Downloads\WildTrack\archive\turtles-data\data"
OUTPUT_DIR = r"C:\Users\sagni\Downloads\WildTrack"

# Artifacts read by this script (model checkpoints and the pickled preprocessor).
MODEL_KERAS, MODEL_H5, PP_PATH = (
    os.path.join(OUTPUT_DIR, fname)
    for fname in ("model.keras", "model.h5", "preprocessor.pkl")
)

# Artifacts written by this script (embeddings, metadata and the search index).
EMB_NPY, META_CSV, FAISS_IDX, NP_IDX = (
    os.path.join(OUTPUT_DIR, fname)
    for fname in ("embeddings.npy", "meta.csv", "index.faiss", "index.npz")
)

os.makedirs(OUTPUT_DIR, exist_ok=True)

# ---------- Load preprocessor ----------
# NOTE: pickle.load can execute arbitrary code; only point PP_PATH at a file
# you produced yourself.
with open(PP_PATH, "rb") as pp_file:
    preproc = pickle.load(pp_file)

# Settings stored in the preprocessor; defaults apply when a key is absent.
IMG_SIZE   = tuple(preproc.get("image_size", (224, 224)))
label2id   = {name: int(idx) for name, idx in preproc["label2id"].items()}
id2label   = {int(idx): name for idx, name in preproc["id2label"].items()}
backbone   = str(preproc.get("backbone", "EfficientNetB0"))
embed_dim  = int(preproc.get("embed_dim", 256))
use_l2     = bool(preproc.get("use_l2norm", True))
VAL_SPLIT  = float(preproc.get("val_split", 0.15))
SEED       = 42

# ---------- Version-safe custom layer (for older Keras) ----------
# Resolve `register_keras_serializable` from whichever location this install
# provides: first tensorflow.keras, then standalone keras.
try:
    from tensorflow.keras.utils import register_keras_serializable
except Exception:
    try:
        from keras.utils import register_keras_serializable
    except Exception:
        # Neither import worked: fall back to a no-op decorator factory so the
        # class definition below still runs (the class is then not registered;
        # load_classifier passes it explicitly via custom_objects).
        def register_keras_serializable(package="WildTrack"):
            def deco(obj): return obj
            return deco

@register_keras_serializable(package="WildTrack")
class L2Normalize(layers.Layer):
    """Keras layer that L2-normalizes its input along a configurable axis."""

    def __init__(self, axis=-1, **kwargs):
        super().__init__(**kwargs)
        # Axis along which the vectors are normalized.
        self.axis = axis

    def call(self, x):
        normalized = tf.math.l2_normalize(x, axis=self.axis)
        return normalized

    def get_config(self):
        # Base layer config plus the one extra constructor argument, so the
        # layer can be re-created when a saved model is deserialized.
        return {**super().get_config(), "axis": self.axis}

# ---------- Robust model loader ----------
def build_classifier(n_classes:int, image_size, backbone="EfficientNetB0",
                     embed_dim=256, use_l2=True):
    """Build an uninitialized classifier whose layer names match the trained one.

    Used as a skeleton for ``load_weights(..., by_name=True)`` when the saved
    model cannot be deserialized directly.

    Args:
        n_classes: Width of the final softmax layer.
        image_size: ``(height, width)`` of the input images.
        backbone: Backbone name, matched case-insensitively against
            ``EfficientNetB0``, ``ResNet50`` and ``MobileNetV2``; any other
            value falls back to EfficientNetB0.
        embed_dim: Width of the ``emb`` Dense layer; the layer is skipped when
            this is 0 or ``None``.
        use_l2: Whether to append the ``l2norm`` layer before the softmax.

    Returns:
        A ``keras.Model`` named ``wildtrack_classifier`` with random weights.
    """
    backbone_builders = {
        "efficientnetb0": keras.applications.EfficientNetB0,
        "resnet50": keras.applications.ResNet50,
        "mobilenetv2": keras.applications.MobileNetV2,
    }
    make_base = backbone_builders.get(backbone.lower(),
                                      keras.applications.EfficientNetB0)

    inputs = keras.Input(shape=(image_size[0], image_size[1], 3))
    # weights=None: the real weights are loaded from the checkpoint afterwards.
    base = make_base(include_top=False, weights=None, pooling="avg")
    x = base(inputs)
    x = layers.Dropout(0.2, name="dropout")(x)
    if embed_dim and embed_dim > 0:
        x = layers.Dense(embed_dim, name="emb")(x)
    if use_l2:
        x = L2Normalize(name="l2norm")(x)
    # Size the head from the argument rather than the module-level id2label,
    # so the function honours its own signature.
    outputs = layers.Dense(n_classes, activation="softmax", name="softmax")(x)
    return keras.Model(inputs, outputs, name="wildtrack_classifier")

def load_classifier():
    """Load the trained classifier, trying progressively more lenient strategies.

    Order: full ``model.keras`` load, full ``model.h5`` load, then a rebuilt
    architecture with weights loaded by layer name from the H5 file.

    Raises:
        FileNotFoundError: If no strategy returned a model and the H5 file is
            absent.
        Exception: Whatever ``load_weights`` raised when the last-resort
            weight loading fails.
    """
    custom_objs = {"L2Normalize": L2Normalize}

    # Strategy 1: native .keras archive.
    if os.path.exists(MODEL_KERAS):
        try:
            loaded = keras.models.load_model(MODEL_KERAS, custom_objects=custom_objs)
        except Exception as keras_err:
            print("[WARN] model.keras load failed:", keras_err)
        else:
            print(f"[INFO] Loaded model: {MODEL_KERAS}")
            return loaded

    # Strategy 2: legacy H5 file, first as a whole model, then weights only.
    if os.path.exists(MODEL_H5):
        try:
            loaded = keras.models.load_model(MODEL_H5, compile=False, custom_objects=custom_objs)
        except Exception as h5_err:
            print("[WARN] Direct H5 load failed:", h5_err)
            # Rebuild the architecture and load weights by layer name,
            # skipping any layer whose shapes do not match.
            rebuilt = build_classifier(len(id2label), IMG_SIZE, backbone, embed_dim, use_l2)
            try:
                rebuilt.load_weights(MODEL_H5, by_name=True, skip_mismatch=True)
            except Exception as weights_err:
                print("[ERROR] Could not load weights:", weights_err)
                raise
            else:
                print("[INFO] Loaded weights by_name from H5.")
                return rebuilt
        else:
            print(f"[INFO] Loaded model: {MODEL_H5}")
            return loaded

    raise FileNotFoundError("Missing model file (.keras or .h5) in OUTPUT_DIR.")

model = load_classifier()

# ---------- Build embedding model (pre-softmax) ----------
def build_embedding_model(classifier: keras.Model):
    """Derive a model that maps input images to embedding vectors.

    Selection order:
      1. The output of the layer named ``l2norm`` when the classifier has one.
      2. Otherwise the tensor feeding the softmax layer, passed through an
         ``L2Normalize`` layer so cosine similarity stays meaningful.
      3. Otherwise (no softmax layer identified) the output of the last layer.

    Args:
        classifier: The loaded classification model.

    Returns:
        A ``keras.Model`` sharing the classifier's inputs and weights.
    """
    # Prefer "l2norm" layer if present (already L2-normalized).
    # get_layer raises ValueError for an unknown name.
    try:
        l2_output = classifier.get_layer("l2norm").output
    except (ValueError, AttributeError):
        l2_output = None
    if l2_output is not None:
        print("[INFO] Using 'l2norm' layer output for embeddings.")
        return keras.Model(classifier.inputs, l2_output)

    # Else take the tensor feeding the final softmax: scan from the end for a
    # Dense layer with softmax activation, or a layer literally named "softmax".
    softmax_layer = None
    for candidate in reversed(classifier.layers):
        is_softmax_dense = (
            isinstance(candidate, layers.Dense)
            and getattr(candidate, "activation", None) == keras.activations.softmax
        )
        if is_softmax_dense or candidate.name.lower() == "softmax":
            softmax_layer = candidate
            break

    if softmax_layer is None:
        # fallback to last Dense or GlobalPool
        print("[WARN] Softmax layer not found; using last layer output.")
        return keras.Model(classifier.inputs, classifier.layers[-1].output)

    emb_tensor = softmax_layer.input  # tensor before softmax
    print("[INFO] Using pre-softmax tensor for embeddings.")
    # Normalize inside the graph with a real layer: raw tf.* ops cannot be
    # applied to symbolic Keras 3 tensors, and no nested sub-model is needed.
    normalized = L2Normalize()(emb_tensor)
    return keras.Model(classifier.inputs, normalized)

emb_model = build_embedding_model(model)

# ---------- Image utils ----------
def list_images(root, exts=(".jpg",".jpeg",".png",".bmp",".tif",".tiff")):
    """Recursively collect image file paths under *root*.

    Args:
        root: Directory to walk.
        exts: Tuple of lowercase filename suffixes to accept; matching is
            case-insensitive on the filename.

    Returns:
        A list of paths in ``os.walk`` order (empty if nothing matches).
    """
    return [
        os.path.join(folder, fname)
        for folder, _subdirs, fnames in os.walk(root)
        for fname in fnames
        if fname.lower().endswith(exts)
    ]

# Generic folder names that never identify an individual; skipped when
# deriving a label from an image path.
SKIP_DIRS = {
    "data", "dataset", "datasets", "images", "imgs", "img", "train",
    "val", "valid", "validation", "test", "all", "photos", "pictures",
}

def smart_label_from_path(p):
    """Return the innermost directory name of *p* that is not in SKIP_DIRS.

    Directory names are compared case-insensitively. When every directory
    component is in SKIP_DIRS (or there is none), the immediate parent
    directory name is returned instead.
    """
    components = os.path.normpath(p).split(os.sep)
    # Drop the filename, then scan directories from innermost to outermost.
    for folder in reversed(components[:-1]):
        if folder.lower() not in SKIP_DIRS:
            return folder
    return os.path.basename(os.path.dirname(p))

def load_and_preprocess(path):
    """Read an image file and return it as a float32 tensor resized to IMG_SIZE.

    The image is decoded to 3 channels (animations are not expanded), converted
    to float32 via ``tf.image.convert_image_dtype`` and resized.
    """
    # NOTE(review): confirm this matches the value range the model was trained on.
    raw_bytes = tf.io.read_file(path)
    decoded = tf.image.decode_image(raw_bytes, channels=3, expand_animations=False)
    as_float = tf.image.convert_image_dtype(decoded, tf.float32)
    return tf.image.resize(as_float, IMG_SIZE)

# ---------- Build gallery embeddings ----------
# Number of images embedded per forward pass.
BATCH = 32
paths = list_images(DATA_DIR)
if not paths:
    raise RuntimeError(f"No images found under {DATA_DIR}")

# Derive one identity label per image from its folder structure.
labels = [smart_label_from_path(p) for p in paths]
# filter to known classes (labels present in the preprocessor's label2id)
keep_mask = [lab in label2id for lab in labels]
paths  = [p for p,k in zip(paths, keep_mask) if k]
labels = [l for l,k in zip(labels, keep_mask) if k]
if not paths:
    # Fail here with a clear message; otherwise the embedding step would die
    # later on an empty np.vstack with an unrelated-looking error.
    raise RuntimeError(
        f"Found images under {DATA_DIR}, but none has a folder label "
        f"listed in the preprocessor ({PP_PATH})"
    )

print(f"[INFO] Gallery images: {len(paths)} across {len(set(labels))} IDs")

def batched(iterable, n):
    """Yield consecutive slices of *iterable*, each holding at most *n* items.

    *iterable* must support ``len()`` and slicing (list, tuple, str, ...).
    """
    total = len(iterable)
    for start in range(0, total, n):
        stop = start + n
        yield iterable[start:stop]

# Embed the gallery batch by batch; each batch is re-normalized on the host
# in case the embedding head did not L2-normalize its output.
emb_list = []
for chunk in batched(paths, BATCH):
    images = [load_and_preprocess(p) for p in chunk]
    batch = tf.stack(images, axis=0)
    embs = emb_model.predict(batch, verbose=0)
    # Small epsilon keeps an all-zero vector from dividing by zero.
    norms = np.linalg.norm(embs, axis=1, keepdims=True)
    emb_list.append((embs / (norms + 1e-12)).astype(np.float32))

# One row per gallery image, in the same order as `paths`.
embeddings = np.concatenate(emb_list, axis=0)
assert embeddings.shape[0] == len(paths)
np.save(EMB_NPY, embeddings)
print(f"[INFO] Saved embeddings -> {EMB_NPY}  shape={embeddings.shape}")

# ---------- Save metadata ----------
# One CSV row per gallery image; `index` is the row's position in embeddings.npy.
with open(META_CSV, "w", newline="", encoding="utf-8") as meta_file:
    meta_writer = csv.writer(meta_file)
    meta_writer.writerow(["index","path","label","label_id"])
    meta_writer.writerows(
        [row_no, img_path, lab, label2id[lab]]
        for row_no, (img_path, lab) in enumerate(zip(paths, labels))
    )
print(f"[INFO] Saved metadata -> {META_CSV}")

# ---------- Build FAISS (if available) or NumPy index ----------
# Stays False unless every FAISS step below succeeds; search() branches on it.
use_faiss = False
try:
    import faiss  # faiss-cpu
    d = embeddings.shape[1]  # embedding dimensionality
    index = faiss.IndexFlatIP(d)  # cosine if vectors are L2-normalized
    index.add(embeddings)
    faiss.write_index(index, FAISS_IDX)
    # Flip the flag only after the index has been written to disk.
    use_faiss = True
    print(f"[INFO] Saved FAISS index -> {FAISS_IDX}")
except Exception as e:
    # Any failure above (missing package or a FAISS error) lands here and
    # falls back to storing the raw embeddings for brute-force NumPy search.
    np.savez(NP_IDX, embeddings=embeddings)
    print(f"[WARN] FAISS not available ({e}). Saved NumPy index -> {NP_IDX}")

# ---------- Search function ----------
def search(query_path, top_k=5):
    """Return the gallery images most similar to the image at *query_path*.

    Args:
        query_path: Path of the query image file.
        top_k: Maximum number of hits; clamped to the range
            ``[0, len(paths)]``.

    Returns:
        A list of dicts ordered by decreasing similarity, each with the keys
        ``rank`` (1-based), ``index``, ``path``, ``label``, ``label_id`` and
        ``similarity`` (inner product of the normalized vectors).
    """
    # Clamp the request: a negative value would mis-slice the NumPy ranking,
    # and asking FAISS for more hits than it holds pads the result with
    # index -1, which would silently alias the last gallery row.
    top_k = max(0, min(int(top_k), len(paths)))
    if top_k == 0:
        return []

    qimg = load_and_preprocess(query_path)[None, ...]
    qemb = emb_model.predict(qimg, verbose=0)[0]
    qemb = (qemb / (np.linalg.norm(qemb) + 1e-12)).astype(np.float32)

    if use_faiss:
        # use_faiss is only True when the module-level `index` was built, so
        # reuse it instead of re-reading the index file on every query.
        D, I = index.search(qemb[None, :], top_k)
        sims = D[0].tolist()
        inds = I[0].tolist()
    else:
        # Brute force: one dot product per gallery row, then take the best.
        sims_all = embeddings @ qemb
        inds = np.argsort(sims_all)[::-1][:top_k].tolist()
        sims = [float(sims_all[i]) for i in inds]

    results = []
    for rnk,(i,s) in enumerate(zip(inds, sims), 1):
        lab = labels[i]
        results.append({
            "rank": rnk,
            "index": int(i),
            "path": paths[i],
            "label": lab,
            "label_id": int(label2id[lab]),
            "similarity": float(s)
        })
    return results

# Quick smoke test (optional): uncomment to try one random image
# import random; sample = random.choice(paths); print(search(sample, top_k=5))

# ---------- Write FastAPI server (app.py) ----------
# Generate a standalone FastAPI server next to the index artifacts. The text
# below is an f-string: single-brace fields are filled in from this script's
# paths and image size, doubled braces are literal braces in the output.
# NOTE(review): the generated load_classifier has no "rebuild + load weights by
# name" fallback like the one in this script, so the server cannot start when
# both direct model loads fail -- confirm a directly loadable model file exists.
APP_PY = os.path.join(OUTPUT_DIR, "app.py")
with open(APP_PY, "w", encoding="utf-8") as f:
    f.write(f'''# FastAPI visual search for WildTrack (SeaTurtleID)
import os, io, csv, base64, json
from pathlib import Path
import numpy as np
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import uvicorn
import tensorflow as tf
from tensorflow import keras

OUTPUT_DIR = r"{OUTPUT_DIR}"
DATA_DIR   = r"{DATA_DIR}"
EMB_NPY    = r"{EMB_NPY}"
META_CSV   = r"{META_CSV}"
FAISS_IDX  = r"{FAISS_IDX}"
NP_IDX     = r"{NP_IDX}"
MODEL_KERAS = r"{MODEL_KERAS}"
MODEL_H5    = r"{MODEL_H5}"
IMG_SIZE    = {list(IMG_SIZE)}

# ---- L2Normalize layer (for loading) ----
try:
    from tensorflow.keras.utils import register_keras_serializable
except Exception:
    try:
        from keras.utils import register_keras_serializable
    except Exception:
        def register_keras_serializable(package="WildTrack"):
            def deco(obj): return obj
            return deco

@register_keras_serializable(package="WildTrack")
class L2Normalize(keras.layers.Layer):
    def __init__(self, axis=-1, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis
    def call(self, x):
        return tf.math.l2_normalize(x, axis=self.axis)
    def get_config(self):
        cfg = super().get_config()
        cfg.update({{"axis": self.axis}})
        return cfg

# ---- Load meta and embeddings ----
paths, labels = [], []
with open(META_CSV, newline="", encoding="utf-8") as fcsv:
    r = csv.DictReader(fcsv)
    for row in r:
        paths.append(row["path"])
        labels.append(row["label"])

if os.path.exists(EMB_NPY):
    embeddings = np.load(EMB_NPY).astype(np.float32)
else:
    raise FileNotFoundError("Missing embeddings.npy. Run the index builder first.")

use_faiss = False
try:
    import faiss
    if os.path.exists(FAISS_IDX):
        index = faiss.read_index(FAISS_IDX)
        use_faiss = True
except Exception:
    if os.path.exists(NP_IDX):
        npz = np.load(NP_IDX)
        # embeddings already loaded

# ---- Load classifier & build embedding model ----
def load_classifier():
    custom = {{"L2Normalize": L2Normalize}}
    if os.path.exists(MODEL_KERAS):
        try:
            return keras.models.load_model(MODEL_KERAS, custom_objects=custom)
        except Exception as e:
            print("[WARN] model.keras load failed:", e)
    if os.path.exists(MODEL_H5):
        try:
            return keras.models.load_model(MODEL_H5, compile=False, custom_objects=custom)
        except Exception as e:
            print("[WARN] model.h5 load failed:", e)
    raise FileNotFoundError("No model file found.")

clf = load_classifier()

def build_embedding_model(classifier):
    # Prefer l2norm if present
    try:
        l = classifier.get_layer("l2norm")
        return keras.Model(classifier.inputs, l.output)
    except Exception:
        pass
    # Pre-softmax fallback
    softmax_layer = None
    for l in classifier.layers[::-1]:
        if isinstance(l, keras.layers.Dense) and getattr(l, "activation", None) == keras.activations.softmax:
            softmax_layer = l
            break
        if l.name.lower() == "softmax":
            softmax_layer = l
            break
    if softmax_layer is None:
        return keras.Model(classifier.inputs, classifier.layers[-1].output)
    emb_tensor = softmax_layer.input
    inp = classifier.inputs
    x = keras.Model(inp, emb_tensor)(inp)
    x = tf.math.l2_normalize(x, axis=-1)
    return keras.Model(inp, x)

emb_model = build_embedding_model(clf)

# ---- Helpers ----
def load_and_preprocess_bytes(b: bytes):
    img = Image.open(io.BytesIO(b)).convert("RGB")
    img = img.resize(tuple(IMG_SIZE), Image.BILINEAR)
    arr = np.array(img).astype(np.float32) / 255.0
    return arr

def encode_thumb(path, max_side=256):
    try:
        img = Image.open(path).convert("RGB")
        w, h = img.size
        scale = max_side / max(w, h)
        if scale < 1:
            img = img.resize((int(w*scale), int(h*scale)), Image.BILINEAR)
        buf = io.BytesIO()
        img.save(buf, format="JPEG", quality=85)
        return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode("ascii")
    except Exception:
        return None

# ---- Search core ----
def search_vector(qemb: np.ndarray, top_k=5):
    # Clamp top_k to the gallery size: FAISS pads missing hits with index -1,
    # which would silently alias the last gallery row.
    top_k = max(0, min(int(top_k), len(paths)))
    if top_k == 0:
        return []
    if use_faiss:
        import faiss
        D, I = index.search(qemb[None, :].astype(np.float32), top_k)
        sims = D[0].tolist()
        inds = I[0].tolist()
    else:
        sims_all = embeddings @ qemb.astype(np.float32)
        inds = np.argsort(sims_all)[::-1][:top_k].tolist()
        sims = [float(sims_all[i]) for i in inds]
    out = []
    for rnk,(i,s) in enumerate(zip(inds, sims), 1):
        out.append({{
            "rank": rnk,
            "index": int(i),
            "path": paths[i],
            "label": labels[i],
            "similarity": float(s),
            "thumbnail": encode_thumb(paths[i])
        }})
    return out

# ---- FastAPI app ----
app = FastAPI(title="WildTrack Search API", version="1.0")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

@app.get("/")
def root():
    html = (Path(OUTPUT_DIR) / "index.html").read_text(encoding="utf-8")
    return HTMLResponse(html)

@app.post("/search")
async def do_search(file: UploadFile = File(...), top_k: int = Form(5)):
    b = await file.read()
    arr = load_and_preprocess_bytes(b)
    qemb = emb_model.predict(arr[None, ...], verbose=0)[0]
    qemb = qemb / (np.linalg.norm(qemb) + 1e-12)
    results = search_vector(qemb, top_k=top_k)
    return JSONResponse({{"results": results}})

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)
''')
print(f"[INFO] Wrote FastAPI app -> {APP_PY}")

# ---------- Write tiny HTML client ----------
# Static demo page served by the generated app at "/". It posts the chosen
# image to /search and renders one card per hit. Result text is inserted via
# textContent (never innerHTML) so labels and paths cannot inject markup, and
# a failed request is reported in the status line.
INDEX_HTML = os.path.join(OUTPUT_DIR, "index.html")
with open(INDEX_HTML, "w", encoding="utf-8") as f:
    f.write("""<!doctype html>
<html>
<head>
<meta charset="utf-8"/>
<title>WildTrack Visual Search</title>
<style>
body{font-family:system-ui,Arial;margin:20px;max-width:900px}
#grid{display:grid;grid-template-columns:repeat(auto-fill,minmax(180px,1fr));gap:12px;margin-top:16px}
.card{border:1px solid #ddd;border-radius:10px;padding:8px;box-shadow:0 1px 4px rgba(0,0,0,.05)}
.card img{width:100%;border-radius:8px}
small{color:#555}
</style>
</head>
<body>
<h2>WildTrack — SeaTurtleID Search</h2>
<p>Upload an image to find visually similar individuals.</p>
<input type="file" id="file" accept="image/*"/>
<input type="number" id="k" min="1" max="20" value="5" style="width:72px;margin-left:8px"/> top-k
<button id="btn">Search</button>
<div id="status"></div>
<div id="grid"></div>
<script>
const btn = document.getElementById('btn');
const statusEl = document.getElementById('status');
const grid = document.getElementById('grid');

// Append <tag>text</tag> to parent, separated from earlier lines by <br/>.
function addLine(parent, tag, text) {
  if (parent.childNodes.length) { parent.appendChild(document.createElement('br')); }
  const el = document.createElement(tag);
  el.textContent = text;
  parent.appendChild(el);
}

btn.onclick = async () => {
  const file = document.getElementById('file').files[0];
  const k = document.getElementById('k').value;
  if(!file){ alert('Choose an image first'); return; }
  const fd = new FormData();
  fd.append('file', file);
  fd.append('top_k', k);
  statusEl.innerText = 'Searching...';
  grid.innerHTML = '';
  try {
    const res = await fetch('/search', { method:'POST', body: fd });
    if(!res.ok){ throw new Error('HTTP ' + res.status); }
    const js = await res.json();
    statusEl.innerText = '';
    js.results.forEach(r => {
      const div = document.createElement('div');
      div.className = 'card';
      const img = document.createElement('img');
      img.src = r.thumbnail || '';
      const p = document.createElement('div');
      addLine(p, 'b', r.label);
      addLine(p, 'small', r.path);
      addLine(p, 'small', 'sim=' + r.similarity.toFixed(3));
      div.appendChild(img);
      div.appendChild(p);
      grid.appendChild(div);
    });
  } catch (err) {
    statusEl.innerText = 'Search failed: ' + err.message;
  }
};
</script>
</body>
</html>
""")
print(f"[INFO] Wrote demo HTML -> {INDEX_HTML}")

print("""
[INFO] Index build complete.
- embeddings.npy, meta.csv
- index.faiss (if faiss available) OR index.npz
- app.py (FastAPI) and index.html written to OUTPUT_DIR

Run the API:
  python app.py
Then open:
  http://127.0.0.1:8000
""")



[WARN] Direct H5 load failed: Exception encountered when calling Lambda.call().

[1mWe could not automatically infer the shape of the Lambda's output. Please specify the `output_shape` argument for this Lambda layer.[0m

Arguments received by Lambda.call():
  • args=('<KerasTensor shape=(None, 256), dtype=float32, sparse=False, ragged=False, name=keras_tensor_502>',)
  • kwargs={'mask': 'None'}
[INFO] Loaded weights by_name from H5.
[INFO] Using 'l2norm' layer output for embeddings.
[INFO] Gallery images: 8728 across 437 IDs
[INFO] Saved embeddings -> C:\Users\sagni\Downloads\WildTrack\embeddings.npy  shape=(8728, 256)
[INFO] Saved metadata -> C:\Users\sagni\Downloads\WildTrack\meta.csv
[WARN] FAISS not available (No module named 'faiss'). Saved NumPy index -> C:\Users\sagni\Downloads\WildTrack\index.npz
[INFO] Wrote FastAPI app -> C:\Users\sagni\Downloads\WildTrack\app.py
[INFO] Wrote demo HTML -> C:\Users\sagni\Downloads\WildTrack\index.html

[INFO] Index build complete.
- embeddi