In [2]:
import os
import sys
import json
import logging
import re
from typing import List, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib

import tensorflow as tf
from tensorflow.keras.models import load_model

# =========================
# User knobs for Notebook
# =========================
INPUT_PATH  = r""  # e.g., r"C:\Users\sagni\Downloads\Art Genie\archive\wikiart_scraped.csv"
OUTPUT_PATH = r""  # e.g., r"C:\Users\sagni\Downloads\Art Genie\predictions.csv"
RUN_NOW     = False  # set True in notebooks to execute immediately

# =========================
# Artifacts & Outputs
# =========================
ARTIFACT_DIR   = r"C:\Users\sagni\Downloads\Art Genie"
PKL_PATH       = os.path.join(ARTIFACT_DIR, "artgenie_textclf.pkl")
H5_PATH        = os.path.join(ARTIFACT_DIR, "artgenie_textclf.h5")
HISTORY_CSV    = os.path.join(ARTIFACT_DIR, "history.csv")

DEFAULT_OUT    = os.path.join(ARTIFACT_DIR, "predictions.csv")
ACC_PNG        = os.path.join(ARTIFACT_DIR, "accuracy_curve.png")
CONF_FULL_CSV  = os.path.join(ARTIFACT_DIR, "confusion_matrix_full.csv")
HEATMAP_PNG    = os.path.join(ARTIFACT_DIR, "confusion_heatmap_top25.png")

logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")

# =========================
# Helpers (match training)
# =========================
POSSIBLE_TEXT_COLS = [
    "title","description","caption","tags","genre","style","artist",
    "movement","content","about","wiki","text","meta","materials","subject"
]
POSSIBLE_LABEL_COLS = ["style","genre","artist","label"]

def normalize_cols(df: pd.DataFrame) -> pd.DataFrame:
    out = df.copy()
    out.columns = [str(c).strip().lower().replace(" ", "_") for c in df.columns]
    return out

def canonicalize(s: str) -> str:
    return re.sub(r"\s+", " ", str(s).lower()).strip()

def collect_text_columns(df: pd.DataFrame) -> List[str]:
    cols = [c for c in POSSIBLE_TEXT_COLS if c in df.columns]
    if not cols:
        cols = [c for c in df.columns if df[c].dtype == object]
    return cols[:8]

def build_text_series(df: pd.DataFrame, text_cols: List[str]) -> pd.Series:
    if not text_cols:
        raise ValueError("No text columns found to build input.")
    parts = [df[c].astype(str) for c in text_cols]
    txt = parts[0]
    for p in parts[1:]:
        txt = txt.str.cat(p, sep=" . ", na_rep="")
    return txt.fillna("").map(lambda s: re.sub(r"\s+", " ", str(s)).strip())

def pick_label_column(df: pd.DataFrame) -> Optional[str]:
    for c in POSSIBLE_LABEL_COLS:
        if c in df.columns and df[c].notna().any():
            return c
    return None

# =========================
# Plotting
# =========================
def plot_accuracy_curve(history_csv: str, out_png: str, also_show: bool = True):
    if not os.path.exists(history_csv):
        logging.warning(f"[PLOT] history.csv not found at {history_csv} — skipping accuracy plot.")
        return
    hist = pd.read_csv(history_csv)
    plt.figure(figsize=(9, 6))
    if "accuracy" in hist.columns:
        plt.plot(hist["epoch"], hist["accuracy"], label="Train Accuracy")
    if "val_accuracy" in hist.columns:
        plt.plot(hist["epoch"], hist["val_accuracy"], label="Val Accuracy")
    if "top3_acc" in hist.columns:
        plt.plot(hist["epoch"], hist["top3_acc"], label="Train Top-3")
    if "val_top3_acc" in hist.columns:
        plt.plot(hist["epoch"], hist["val_top3_acc"], label="Val Top-3")
    plt.title("Training Accuracy")
    plt.xlabel("Epoch"); plt.ylabel("Accuracy")
    plt.legend(); plt.grid(True, linestyle="--", linewidth=0.5)
    plt.tight_layout(); plt.savefig(out_png, dpi=160)
    if also_show:
        try: plt.show()
        except Exception: pass
    plt.close()
    logging.info(f"[PLOT] Saved accuracy curve → {out_png}")

def plot_confusion_heatmap(y_true_idx, y_pred_idx, class_names, out_png: str, full_csv: str, top_k: int = 25, also_show: bool = True):
    from sklearn.metrics import confusion_matrix
    cm_full = confusion_matrix(y_true_idx, y_pred_idx, labels=np.arange(len(class_names)))
    # Save full confusion matrix
    df_full = pd.DataFrame(cm_full, index=class_names, columns=class_names)
    df_full.to_csv(full_csv, encoding="utf-8")
    logging.info(f"[PLOT] Saved full confusion matrix CSV → {full_csv}")

    # Top-K by true frequency
    true_counts = cm_full.sum(axis=1)
    top_idx = np.argsort(true_counts)[::-1][:min(top_k, len(class_names))]
    sub = cm_full[np.ix_(top_idx, top_idx)]
    sub_labels = [class_names[i] for i in top_idx]

    plt.figure(figsize=(12, 10))
    im = plt.imshow(sub, interpolation="nearest", aspect="auto")
    plt.colorbar(im, fraction=0.046, pad=0.04)
    plt.title(f"Confusion Matrix Heatmap (Top {len(sub_labels)} Classes)")
    plt.xticks(range(len(sub_labels)), sub_labels, rotation=90)
    plt.yticks(range(len(sub_labels)), sub_labels)
    plt.tight_layout(); plt.savefig(out_png, dpi=160)
    if also_show:
        try: plt.show()
        except Exception: pass
    plt.close()
    logging.info(f"[PLOT] Saved confusion heatmap → {out_png}")

# =========================
# Core: predict + plots
# =========================
def predict_and_plot(input_csv: str, output_csv: str = DEFAULT_OUT, include_ids=True):
    # Load artifacts
    if not os.path.exists(PKL_PATH):
        raise FileNotFoundError(f"Missing sklearn pipeline PKL: {PKL_PATH}")
    if not os.path.exists(H5_PATH):
        raise FileNotFoundError(f"Missing Keras model H5: {H5_PATH}")

    bundle = joblib.load(PKL_PATH)  # {"pipeline": ..., "label_encoder": ...}
    sk_pipe = bundle["pipeline"]
    label_encoder = bundle["label_encoder"]
    class_names = list(label_encoder.classes_)
    num_classes = len(class_names)
    logging.info(f"[LOAD] Classes: {num_classes}")

    keras_model = load_model(H5_PATH)  # contains TextVectorization layer

    # Read input
    if not os.path.exists(input_csv):
        raise FileNotFoundError(f"Input CSV not found: {input_csv}")
    logging.info(f"[READ] {input_csv}")
    df = pd.read_csv(input_csv, engine="python")
    df = normalize_cols(df)

    # Build text like training
    text_cols = collect_text_columns(df)
    logging.info(f"[TEXT] Using columns: {text_cols}")
    df["__text__"] = build_text_series(df, text_cols)

    # Optional ground-truth for heatmap
    lbl_col = pick_label_column(df)
    if lbl_col:
        df[lbl_col] = df[lbl_col].astype(str).map(canonicalize)

    id_cols = [c for c in ["style","genre","artist","title"] if c in df.columns] if include_ids else []
    texts = df["__text__"].fillna("").astype(str).values

    # 1) sklearn predictions
    logging.info("[PRED] sklearn pipeline…")
    prob_sk = None
    try:
        prob_sk = sk_pipe.predict_proba(texts)
        idx_sk = np.argmax(prob_sk, axis=1)
        conf_sk = prob_sk[np.arange(len(idx_sk)), idx_sk]
    except Exception:
        idx_sk = sk_pipe.predict(texts)
        conf_sk = np.full_like(idx_sk, np.nan, dtype=float)
    pred_sk = label_encoder.inverse_transform(idx_sk)

    # 2) Keras predictions
    logging.info("[PRED] Keras .h5 model…")
    ds = tf.data.Dataset.from_tensor_slices(texts).batch(256)
    prob_k = keras_model.predict(ds, verbose=0)
    idx_k = np.argmax(prob_k, axis=1)
    conf_k = prob_k[np.arange(len(idx_k)), idx_k]
    pred_k = label_encoder.inverse_transform(idx_k)

    # 3) Simple average ensemble (if both prob available & aligned)
    if prob_sk is not None and prob_sk.shape == prob_k.shape == (len(texts), num_classes):
        avg_prob = (prob_sk + prob_k) / 2.0
        idx_avg = np.argmax(avg_prob, axis=1)
        conf_avg = avg_prob[np.arange(len(idx_avg)), idx_avg]
        pred_avg = label_encoder.inverse_transform(idx_avg)
    else:
        idx_avg, conf_avg, pred_avg = idx_k, conf_k, pred_k

    # Save predictions
    out = pd.DataFrame({
        "pred_sklearn": pred_sk,
        "pred_sklearn_conf": conf_sk,
        "pred_keras": pred_k,
        "pred_keras_conf": conf_k,
        "pred_avg": pred_avg,
        "pred_avg_conf": conf_avg
    })
    if id_cols:
        out = pd.concat([df[id_cols].reset_index(drop=True), out.reset_index(drop=True)], axis=1)

    os.makedirs(os.path.dirname(output_csv) or ".", exist_ok=True)
    out.to_csv(output_csv, index=False, encoding="utf-8")
    logging.info(f"[SAVE] Predictions → {output_csv}")

    # Plots
    plot_accuracy_curve(HISTORY_CSV, ACC_PNG, also_show=True)

    if lbl_col:
        logging.info(f"[EVAL] Found ground-truth column '{lbl_col}' — building confusion heatmap.")
        y_true_names = df[lbl_col].values
        is_known = np.isin(y_true_names, class_names)
        if np.any(is_known):
            y_true_idx = label_encoder.transform(y_true_names[is_known])
            y_pred_idx = idx_k[is_known]
            plot_confusion_heatmap(
                y_true_idx, y_pred_idx, class_names,
                out_png=HEATMAP_PNG, full_csv=CONF_FULL_CSV, top_k=25, also_show=True
            )
        else:
            logging.warning("[EVAL] None of the true labels match training classes — skipping heatmap.")
    else:
        logging.info("[EVAL] No true label column — skipping confusion heatmap.")

    return output_csv

# =========================
# CLI (Notebook-safe)
# =========================
def _maybe_parse_cli_args(argv):
    args = {"input": None, "out": DEFAULT_OUT, "include_ids": True}
    toks = list(argv[1:])
    i = 0
    while i < len(toks):
        t = toks[i]
        if t == "--input" and i + 1 < len(toks):
            args["input"] = toks[i+1]; i += 2
        elif t == "--out" and i + 1 < len(toks):
            args["out"] = toks[i+1]; i += 2
        elif t == "--no_ids":
            args["include_ids"] = False; i += 1
        else:
            i += 1
    return args

def _running_in_ipython() -> bool:
    try:
        from IPython import get_ipython  # noqa
        return get_ipython() is not None
    except Exception:
        return False

def main():
    # Notebook one-click run
    if _running_in_ipython() and RUN_NOW:
        if not INPUT_PATH:
            raise ValueError("Set INPUT_PATH at the top before RUN_NOW=True.")
        out_path = OUTPUT_PATH or DEFAULT_OUT
        predict_and_plot(INPUT_PATH, out_path, include_ids=True)
        return

    # Terminal / default path
    args = _maybe_parse_cli_args(sys.argv)
    if not args["input"]:
        script_name = os.path.basename(sys.argv[0]) if "__file__" not in globals() else os.path.basename(__file__)
        print(f"Usage:\n  python {script_name} --input <CSV> [--out <CSV>] [--no_ids]")
        print(f"Defaults:\n  Artifacts dir: {ARTIFACT_DIR}\n  Output: {DEFAULT_OUT}")
        sys.exit(2)
    predict_and_plot(args["input"], args["out"], include_ids=args["include_ids"])

if __name__ == "__main__":
    main()


Usage:
  python ipykernel_launcher.py --input <CSV> [--out <CSV>] [--no_ids]
Defaults:
  Artifacts dir: C:\Users\sagni\Downloads\Art Genie
  Output: C:\Users\sagni\Downloads\Art Genie\predictions.csv


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
