In [1]:
# ============================
# PlantDocX — Accuracy Graph + Confusion Matrix Heatmap
# ============================
import os, json, pickle, warnings
warnings.filterwarnings("ignore")

import numpy as np
import tensorflow as tf
from tensorflow import keras
from sklearn.metrics import confusion_matrix, accuracy_score
import matplotlib.pyplot as plt

try:
    import seaborn as sns
    USE_SNS = True
except Exception:
    USE_SNS = False

# ---------------------------
# Paths
# ---------------------------
DATA_DIR   = r"C:\Users\sagni\Downloads\PlantDocX\archive\PlantVillage"   # root with class subfolders
OUTPUT_DIR = r"C:\Users\sagni\Downloads\PlantDocX"
os.makedirs(OUTPUT_DIR, exist_ok=True)

PP_PATH    = os.path.join(OUTPUT_DIR, "preprocessor.pkl")
KERAS_PATH = os.path.join(OUTPUT_DIR, "cls_model.keras")
H5_PATH    = os.path.join(OUTPUT_DIR, "cls_model.h5")

# ---------------------------
# Load preprocessor + model
# ---------------------------
if not os.path.exists(PP_PATH):
    raise FileNotFoundError(f"preprocessor.pkl not found at {PP_PATH}. Train/export step must run first.")

with open(PP_PATH, "rb") as f:
    preproc = pickle.load(f)

class_names   = preproc.get("class_names", None)
if class_names is None:
    raise RuntimeError("class_names missing in preprocessor.pkl")

IMG_SIZE      = tuple(preproc.get("image_size", (224,224)))
SEED          = int(preproc.get("seed", 42))

model = None
if os.path.exists(KERAS_PATH):
    try:
        model = tf.keras.models.load_model(KERAS_PATH, safe_mode=False)
        print(f"[INFO] Loaded model: {KERAS_PATH}")
    except Exception as e:
        print("[WARN] Could not load .keras model:", e)

if model is None and os.path.exists(H5_PATH):
    model = tf.keras.models.load_model(H5_PATH, compile=False)
    print(f"[INFO] Loaded model: {H5_PATH}")

if model is None:
    raise FileNotFoundError("No model file found (.keras or .h5).")

num_classes = len(class_names)

# ---------------------------
# Build validation dataset (same split & class order)
# ---------------------------
BATCH = 32

def make_val_ds():
    # Prefer passing class_names to lock label indices to training order
    try:
        ds = tf.keras.utils.image_dataset_from_directory(
            DATA_DIR, labels="inferred", label_mode="int",
            class_names=class_names,               # keep same order as training
            color_mode="rgb", image_size=IMG_SIZE,
            batch_size=BATCH, shuffle=False, seed=SEED,
            validation_split=0.2, subset="validation",
        )
        derived_class_names = class_names[:]  # already forced
    except TypeError:
        # Older TF that doesn't accept class_names; derive then remap labels
        ds = tf.keras.utils.image_dataset_from_directory(
            DATA_DIR, labels="inferred", label_mode="int",
            color_mode="rgb", image_size=IMG_SIZE,
            batch_size=BATCH, shuffle=False, seed=SEED,
            validation_split=0.2, subset="validation",
        )
        derived_class_names = getattr(ds, "class_names", None)

    # Be tolerant of a few corrupt images
    try:
        ds = ds.ignore_errors()
    except Exception:
        try:
            ds = ds.apply(tf.data.experimental.ignore_errors())
        except Exception:
            pass

    return ds, derived_class_names

val_ds, derived_class_names = make_val_ds()

# Build mapping from dataset indices -> training indices if needed
index_remap = None
if derived_class_names is not None and derived_class_names != class_names:
    # Map e.g. ds_idx -> preproc_idx using names
    name2preproc = {n:i for i,n in enumerate(class_names)}
    index_remap = {i: name2preproc[n] for i,n in enumerate(derived_class_names)}
    print("[WARN] Directory class order differs from training order. Applying index remap.")
else:
    index_remap = None

# Prefetch for speed
val_ds = val_ds.prefetch(tf.data.AUTOTUNE)

# ---------------------------
# Collect predictions
# ---------------------------
y_true_raw, y_pred = [], []
for xb, yb in val_ds:
    probs = model.predict(xb, verbose=0)
    yhat = np.argmax(probs, axis=1)
    y_pred.extend(yhat.tolist())
    y_true_raw.extend(yb.numpy().tolist())

y_pred = np.array(y_pred, dtype=int)
y_true_raw = np.array(y_true_raw, dtype=int)

# Remap true labels to training order if needed
if index_remap is not None:
    remap_vec = np.vectorize(lambda idx: index_remap.get(int(idx), -1))
    y_true = remap_vec(y_true_raw).astype(int)
else:
    y_true = y_true_raw

# Filter out any -1 (unmappable) just in case
mask_valid = (y_true >= 0) & (y_true < num_classes)
y_true = y_true[mask_valid]
y_pred = y_pred[mask_valid]

# ---------------------------
# Overall accuracy
# ---------------------------
acc = float(accuracy_score(y_true, y_pred))
print(f"[INFO] Validation accuracy: {acc:.4f}")

# ---------------------------
# Confusion matrix heatmap
# ---------------------------
cm = confusion_matrix(y_true, y_pred, labels=np.arange(num_classes))
fig = plt.figure(figsize=(max(6, num_classes*0.4), max(5, num_classes*0.4)))
if USE_SNS:
    sns.heatmap(cm, annot=False, fmt="d", cmap="viridis")
    plt.title("Confusion Matrix")
else:
    plt.imshow(cm, interpolation="nearest")
    plt.title("Confusion Matrix")
    plt.colorbar()
plt.xlabel("Predicted")
plt.ylabel("Actual")
# Tick labels (sparse if many classes)
if num_classes <= 30:
    plt.xticks(ticks=np.arange(num_classes), labels=class_names, rotation=90)
    plt.yticks(ticks=np.arange(num_classes), labels=class_names)
else:
    plt.xticks([]); plt.yticks([])
plt.tight_layout()
cm_path = os.path.join(OUTPUT_DIR, "confusion_matrix.png")
plt.savefig(cm_path, dpi=150)
plt.close(fig)
print(f"[INFO] Saved heatmap -> {cm_path}")

# ---------------------------
# Per-class accuracy bar chart
# ---------------------------
per_acc = []
support = []
for i in range(num_classes):
    sel = (y_true == i)
    n = sel.sum()
    support.append(int(n))
    if n == 0:
        per_acc.append(0.0)
    else:
        per_acc.append(float((y_pred[sel] == i).mean()))

# Sort by accuracy (optional)
order = np.argsort(per_acc)[::-1]
sorted_names = [class_names[i] for i in order]
sorted_acc   = [per_acc[i] for i in order]
sorted_sup   = [support[i] for i in order]

fig = plt.figure(figsize=(max(8, num_classes*0.45), 6))
plt.bar(range(num_classes), sorted_acc)
plt.xticks(range(num_classes), sorted_names, rotation=90)
plt.ylim(0, 1.0)
plt.ylabel("Per-class accuracy")
plt.title(f"Per-class Accuracy (val) — overall={acc:.3f}")
# Optional: annotate supports for context
for i, (a, s) in enumerate(zip(sorted_acc, sorted_sup)):
    if num_classes <= 40 and s > 0:
        plt.text(i, a+0.01, str(s), ha="center", va="bottom", fontsize=7, rotation=90)
plt.tight_layout()
acc_path = os.path.join(OUTPUT_DIR, "accuracy_per_class.png")
plt.savefig(acc_path, dpi=150)
plt.close(fig)
print(f"[INFO] Saved per-class accuracy graph -> {acc_path}")


[INFO] Loaded model: C:\Users\sagni\Downloads\PlantDocX\cls_model.keras
Found 20638 files belonging to 15 classes.
Using 4127 files for validation.
[INFO] Validation accuracy: 0.0000
[INFO] Saved heatmap -> C:\Users\sagni\Downloads\PlantDocX\confusion_matrix.png
[INFO] Saved per-class accuracy graph -> C:\Users\sagni\Downloads\PlantDocX\accuracy_per_class.png
