In [3]:
# ============================================
# WildTrack — Accuracy graph + Confusion Matrix heatmap (robust loader)
# Fixes AttributeError: 'keras' has no attribute 'saving'
# Uses version-safe register_keras_serializable import
# ============================================
import os, pickle
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.metrics import confusion_matrix, accuracy_score
import matplotlib.pyplot as plt

try:
    import seaborn as sns
    USE_SNS = True
except Exception:
    USE_SNS = False

# -------------------------
# Version-safe registration
# -------------------------
try:
    from tensorflow.keras.utils import register_keras_serializable
except Exception:
    try:
        from keras.utils import register_keras_serializable  # pragma: no cover
    except Exception:  # very old TF/Keras
        def register_keras_serializable(package="Custom", name=None):
            def deco(obj): return obj
            return deco

# -------------------------
# Paths (edit if needed)
# -------------------------
DATA_DIR   = r"C:\Users\sagni\Downloads\WildTrack\archive\turtles-data\data"
OUTPUT_DIR = r"C:\Users\sagni\Downloads\WildTrack"
MODEL_KERAS = os.path.join(OUTPUT_DIR, "model.keras")
MODEL_H5    = os.path.join(OUTPUT_DIR, "model.h5")
PP_PATH     = os.path.join(OUTPUT_DIR, "preprocessor.pkl")

os.makedirs(OUTPUT_DIR, exist_ok=True)

# -------------------------
# Load preprocessor
# -------------------------
with open(PP_PATH, "rb") as f:
    preproc = pickle.load(f)

IMG_SIZE   = tuple(preproc.get("image_size", (224,224)))
VAL_SPLIT  = float(preproc.get("val_split", 0.15))
label2id   = dict(preproc["label2id"])
id2label   = {int(k): v for k, v in preproc["id2label"].items()}
SEED       = 42

# Optional model config (helps rebuilding when using H5)
backbone_name = str(preproc.get("backbone", "EfficientNetB0"))
embed_dim     = int(preproc.get("embed_dim", 256))
use_l2norm    = bool(preproc.get("use_l2norm", True))

# -------------------------
# Serializable L2 normalize (avoids Lambda pitfalls)
# -------------------------
@register_keras_serializable(package="WildTrack")
class L2Normalize(layers.Layer):
    def __init__(self, axis=-1, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis
    def call(self, x):
        return tf.math.l2_normalize(x, axis=self.axis)
    def get_config(self):
        cfg = super().get_config()
        cfg.update({"axis": self.axis})
        return cfg

def build_classifier(n_classes:int, image_size, backbone="EfficientNetB0",
                     embed_dim=256, use_l2=True):
    inputs = keras.Input(shape=(image_size[0], image_size[1], 3))
    bb = backbone.lower()
    if bb == "efficientnetb0":
        base = keras.applications.EfficientNetB0(include_top=False, weights=None, pooling="avg")
    elif bb == "resnet50":
        base = keras.applications.ResNet50(include_top=False, weights=None, pooling="avg")
    elif bb == "mobilenetv2":
        base = keras.applications.MobileNetV2(include_top=False, weights=None, pooling="avg")
    else:
        base = keras.applications.EfficientNetB0(include_top=False, weights=None, pooling="avg")
    x = base(inputs)
    x = layers.Dropout(0.2, name="dropout")(x)
    if embed_dim and embed_dim > 0:
        x = layers.Dense(embed_dim, name="emb")(x)
    if use_l2:
        x = L2Normalize(name="l2norm")(x)
    outputs = layers.Dense(n_classes, activation="softmax", name="softmax")(x)
    model = keras.Model(inputs, outputs, name="wildtrack_classifier")
    return model

# -------------------------
# Robust model loader
# -------------------------
def load_model_robust(n_classes):
    custom = {"L2Normalize": L2Normalize}
    # 1) Prefer native .keras (no Lambda issues)
    if os.path.exists(MODEL_KERAS):
        try:
            m = keras.models.load_model(MODEL_KERAS, custom_objects=custom)
            print(f"[INFO] Loaded native model: {MODEL_KERAS}")
            return m
        except Exception as e:
            print("[WARN] Could not load model.keras:", e)

    # 2) Rebuild arch and load weights from H5 (skip mismatches)
    if os.path.exists(MODEL_H5):
        try:
            # Try direct H5 load first (may fail if it had Lambda)
            m = keras.models.load_model(MODEL_H5, compile=False, custom_objects=custom)
            print(f"[INFO] Loaded H5 model directly: {MODEL_H5}")
            return m
        except Exception as e:
            print("[WARN] Direct H5 load failed (expected with legacy Lambda):", e)

        model = build_classifier(
            n_classes=n_classes,
            image_size=IMG_SIZE,
            backbone=backbone_name,
            embed_dim=embed_dim,
            use_l2=use_l2norm
        )
        try:
            model.load_weights(MODEL_H5, by_name=True, skip_mismatch=True)
            print(f"[INFO] Loaded weights from H5 into rebuilt architecture (by_name, skip_mismatch).")
            return model
        except Exception as e2:
            print("[ERROR] Could not load weights from H5:", e2)
            raise

    raise FileNotFoundError("No model file found. Expected model.keras or model.h5 in OUTPUT_DIR.")

# -------------------------
# Data helpers
# -------------------------
SKIP_DIRS = set([
    "data","dataset","datasets","images","imgs","img",
    "train","val","valid","validation","test","all","photos","pictures"
])

def smart_label_from_path(p):
    parts = os.path.normpath(p).split(os.sep)
    for i in range(len(parts)-2, -1, -1):
        name = parts[i]
        if name.lower() not in SKIP_DIRS:
            return name
    return os.path.basename(os.path.dirname(p))

def find_images(root, exts=(".jpg",".jpeg",".png",".bmp",".tif",".tiff")):
    out = []
    for dp, _, files in os.walk(root):
        for f in files:
            if f.lower().endswith(exts):
                out.append(os.path.join(dp, f))
    return out

# -------------------------
# Build validation split (deterministic)
# -------------------------
all_imgs = find_images(DATA_DIR)
if not all_imgs:
    raise RuntimeError(f"No images found under {DATA_DIR}")

paths, labels = [], []
for p in all_imgs:
    lab = smart_label_from_path(p)
    if lab in label2id:        # keep only classes seen at train time
        paths.append(p)
        labels.append(lab)

df = pd.DataFrame({"path": paths, "label": labels})
if df.empty:
    raise RuntimeError("After filtering to known classes, no images remain. Check DATA_DIR and label map.")

rng = np.random.RandomState(SEED)
df["rnd"] = rng.rand(len(df))
val_mask = df.groupby("label")["rnd"].transform(lambda s: s.rank(pct=True)) <= VAL_SPLIT
df_val   = df[val_mask].drop(columns=["rnd"]).reset_index(drop=True)

if df_val.empty:
    raise RuntimeError("Validation split is empty. Increase VAL_SPLIT or verify dataset layout.")

print(f"[INFO] Validation samples: {len(df_val)} | Classes in val: {df_val['label'].nunique()}")

# -------------------------
# tf.data pipeline
# -------------------------
AUTOTUNE = tf.data.AUTOTUNE

def decode_img(path):
    x = tf.io.read_file(path)
    x = tf.image.decode_image(x, channels=3, expand_animations=False)
    x = tf.image.convert_image_dtype(x, tf.float32)
    x = tf.image.resize(x, IMG_SIZE)
    return x

def tf_map(path, label_str):
    # Map label string to index via label2id
    def _map(s):
        return np.int32(label2id[s.decode("utf-8")])
    y = tf.numpy_function(_map, [label_str], tf.int32)
    y.set_shape([])  # scalar
    img = decode_img(path)
    return img, y

val_ds = tf.data.Dataset.from_tensor_slices((df_val["path"].values, df_val["label"].values))
val_ds = val_ds.map(tf_map, num_parallel_calls=AUTOTUNE).batch(32).prefetch(AUTOTUNE)

# -------------------------
# Load model (robust)
# -------------------------
n_classes = len(label2id)
model = load_model_robust(n_classes)

# -------------------------
# Predict & metrics
# -------------------------
y_true, y_prob = [], []
for bx, by in val_ds:
    pr = model.predict(bx, verbose=0)
    y_prob.append(pr)
    y_true.append(by.numpy())
y_prob = np.concatenate(y_prob, axis=0)
y_true = np.concatenate(y_true, axis=0)

# Handle binary/single-class heads gracefully
multi_class = (y_prob.ndim == 2 and y_prob.shape[1] > 1)
if multi_class:
    y_pred = y_prob.argmax(axis=1)
    label_names = [id2label[i] for i in range(n_classes)]
else:
    # Sigmoid/one-logit case
    y_pred = (y_prob.reshape(-1) >= 0.5).astype(int)
    # Derive label set actually present
    uniq_ids = np.unique(np.concatenate([y_true, y_pred]))
    label_names = [id2label.get(int(i), f"class_{int(i)}") for i in uniq_ids]

acc = float(accuracy_score(y_true, y_pred))
print(f"[INFO] Validation accuracy (Top-1): {acc:.4f}")

# -------------------------
# Confusion Matrix Heatmap
# -------------------------
if multi_class:
    label_ids = np.arange(n_classes)
else:
    label_ids = np.unique(np.concatenate([y_true, y_pred]))

cm = confusion_matrix(y_true, y_pred, labels=label_ids)

plt.figure(figsize=(max(6, 0.5*len(label_names)), max(5, 0.5*len(label_names))))
if USE_SNS:
    sns.heatmap(cm, annot=len(label_names) <= 30, fmt="d", cmap="Blues",
                xticklabels=label_names, yticklabels=label_names, cbar=True)
else:
    plt.imshow(cm, cmap="Blues")
    plt.colorbar()
    plt.xticks(range(len(label_names)), label_names, rotation=90)
    plt.yticks(range(len(label_names)), label_names)
    if len(label_names) <= 30:
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                plt.text(j, i, str(cm[i, j]),
                         ha="center", va="center",
                         color="white" if cm[i, j] > cm.max()/2 else "black")
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.tight_layout()
cm_path = os.path.join(OUTPUT_DIR, "confusion_matrix.png")
plt.savefig(cm_path, dpi=160)
plt.close()
print(f"[INFO] Saved heatmap -> {cm_path}")

# -------------------------
# Per-class Accuracy Bar Chart
# -------------------------
if multi_class:
    per_class_acc = []
    for i in range(n_classes):
        idx = (y_true == i)
        if idx.sum() == 0:
            per_class_acc.append(np.nan)
        else:
            per_class_acc.append((y_pred[idx] == i).mean())

    plt.figure(figsize=(max(8, 0.6*n_classes), 5))
    xs = np.arange(n_classes)
    vals = np.array(per_class_acc)
    plt.bar(xs, np.nan_to_num(vals, nan=0.0))
    plt.xticks(xs, [id2label[i] for i in range(n_classes)], rotation=90)
    plt.ylim(0, 1)
    plt.ylabel("Accuracy")
    plt.title("Per-class Accuracy (Validation)")
    plt.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    acc_path = os.path.join(OUTPUT_DIR, "accuracy_per_class.png")
    plt.savefig(acc_path, dpi=160)
    plt.close()
    print(f"[INFO] Saved per-class accuracy graph -> {acc_path}")
else:
    print("[INFO] Single-class head detected; skipping per-class bar chart.")


[INFO] Validation samples: 1095 | Classes in val: 337
[WARN] Direct H5 load failed (expected with legacy Lambda): Exception encountered when calling Lambda.call().

[1mWe could not automatically infer the shape of the Lambda's output. Please specify the `output_shape` argument for this Lambda layer.[0m

Arguments received by Lambda.call():
  • args=('<KerasTensor shape=(None, 256), dtype=float32, sparse=False, ragged=False, name=keras_tensor_1007>',)
  • kwargs={'mask': 'None'}
[INFO] Loaded weights from H5 into rebuilt architecture (by_name, skip_mismatch).
[INFO] Validation accuracy (Top-1): 0.0018
[INFO] Saved heatmap -> C:\Users\sagni\Downloads\WildTrack\confusion_matrix.png
[INFO] Saved per-class accuracy graph -> C:\Users\sagni\Downloads\WildTrack\accuracy_per_class.png
