# In [2]:
import os
import json
import joblib
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import pandas as pd
import numpy as np

# -------- CONFIG --------
# Project root. Defaults to the original author's path; set the
# AD_INSIGHT_DIR environment variable to run the script from another location.
BASE_DIR = os.environ.get("AD_INSIGHT_DIR", r"C:\Users\NXTWAVE\Downloads\Ad Insight")
# All generated figures (heatmaps and the accuracy bar chart) are written here.
VISUALS_DIR = os.path.join(BASE_DIR, "visuals")
os.makedirs(VISUALS_DIR, exist_ok=True)

TRAIN_CSV = os.path.join(BASE_DIR, "archive", "train.csv")

# Load train data (the same rows are later used to score the models).
df = pd.read_csv(TRAIN_CSV)

# Categorical encoding (same as pipeline)
from sklearn.preprocessing import LabelEncoder
def preprocess(df):
    """Label-encode text/categorical columns and fill remaining gaps.

    Works on a copy; the caller's DataFrame is not modified.

    Each object / string / category column has missing values replaced by
    the literal ``"UNK"`` and is then integer-encoded with a freshly fitted
    ``LabelEncoder``. All other missing values in the frame are filled with 0.

    Args:
        df: Raw DataFrame (e.g. as returned by ``pd.read_csv``).

    Returns:
        A new DataFrame with encoded categorical columns and no NaNs.
    """
    df = df.copy()
    # "string" and "category" columns are not selected by "object" alone; left
    # unencoded they would reach the model as raw text.
    cat_cols = df.select_dtypes(include=["object", "string", "category"]).columns
    for c in cat_cols:
        # Go through plain object dtype so fillna("UNK") also works for
        # category columns. Then cast every value to str: read_csv can yield
        # object columns mixing str and numbers (its "mixed types" warning),
        # and LabelEncoder cannot sort such mixed values.
        values = df[c].astype(object).fillna("UNK").astype(str)
        df[c] = LabelEncoder().fit_transform(values)
    return df.fillna(0)

df = preprocess(df)

target_cols = ["engagement", "mood", "recall"]
# NOTE(review): every non-target column is used as a feature; confirm this is
# the same column set (and order) the saved models were trained on.
feature_cols = [c for c in df.columns if c not in target_cols]

X = df[feature_cols].values
# Only targets present in the CSV get an entry; the evaluation loop skips
# (and warns about) models whose target column is absent.
y_dict = {c: df[c].values for c in target_cols if c in df}

# Load models
models = {
    "Engagement": os.path.join(BASE_DIR, "engagement_model.pkl"),
    "Mood": os.path.join(BASE_DIR, "mood_model.pkl"),
    "Recall": os.path.join(BASE_DIR, "recall_model.pkl"),
}
# joblib.load unpickles arbitrary objects: only point these paths at model
# files you produced or otherwise trust.
loaded_models = {}
for model_name, model_path in models.items():
    if os.path.exists(model_path):
        loaded_models[model_name] = joblib.load(model_path)
    else:
        # Previously a missing file was dropped silently, so that model just
        # never showed up in the outputs with no explanation.
        print(f"[WARN] Model file not found for {model_name}: {model_path}, skipping.")

# Store accuracies: model display name -> accuracy, plotted after the loop.
acc_scores = {}

# -------- Generate Accuracy Graph & Heatmaps --------
# NOTE: the models are scored on the same train.csv rows loaded above, so
# these are training-set accuracies and will likely overstate performance on
# unseen data.
for name, model in loaded_models.items():
    # y_dict is keyed by the lower-case target column name.
    y_true = y_dict.get(name.lower())
    if y_true is None:
        print(f"[WARN] No labels for {name}, skipping heatmap.")
        continue

    # Flatten predictions: a column vector of shape (n, 1) would otherwise
    # broadcast against y_true into an (n, n) comparison and give a
    # meaningless accuracy.
    y_pred = np.asarray(model.predict(X)).ravel()
    acc = np.mean(y_true == y_pred)
    acc_scores[name] = acc

    # Fix the class order explicitly (sorted union of true and predicted
    # values, i.e. confusion_matrix's own default). The matrix and the
    # heatmap ticks then show the same actual class values, not bare
    # 0..k-1 positions.
    labels = np.union1d(y_true, y_pred)
    tick_labels = labels.tolist()
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    # Confusion matrix heatmap
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        ax=ax,
    )
    ax.set_title(f"{name} Confusion Matrix (Accuracy={acc:.2f})")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fig.tight_layout()
    heatmap_path = os.path.join(VISUALS_DIR, f"{name.lower()}_heatmap.png")
    fig.savefig(heatmap_path)
    plt.close(fig)
    print(f"[INFO] Saved heatmap for {name} at {heatmap_path}")

# Accuracy Bar Graph: one bar per evaluated model, annotated with its score.
if acc_scores:
    model_names = list(acc_scores)
    accuracies = list(acc_scores.values())

    fig, ax = plt.subplots(figsize=(6, 4))
    # seaborn >= 0.13 deprecates barplot(palette=...) without a hue. Draw the
    # bars with matplotlib using the same viridis colours, so the result does
    # not depend on the installed seaborn version.
    ax.bar(model_names, accuracies, color=sns.color_palette("viridis", len(model_names)))
    # Headroom above 1.0 so the label of a (near-)perfect score, common for
    # training-set accuracy, stays inside the axes instead of hitting the title.
    ax.set_ylim(0, 1.1)
    ax.set_ylabel("Accuracy")
    ax.set_title("Model Accuracies")
    for i, v in enumerate(accuracies):
        ax.text(i, v + 0.02, f"{v:.2f}", ha="center", fontsize=10)
    fig.tight_layout()  # matches the heatmap figures
    bar_path = os.path.join(VISUALS_DIR, "accuracy_bar.png")
    fig.savefig(bar_path)
    plt.close(fig)
    print(f"[INFO] Saved accuracy bar chart at {bar_path}")
