In [1]:
import os
import ast
import json
import math
import warnings
from collections import Counter

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix
)

# ==============================
# Paths (EDIT if needed)
# ==============================
# The defaults are one machine's local paths; set EVAL_INPUT_CSV and/or
# EVAL_OUTPUT_DIR in the environment to run elsewhere without editing the file.
INPUT_CSV = os.environ.get(
    "EVAL_INPUT_CSV",
    r"C:\Users\sagni\Downloads\RAG based Medical FAQ Chatbot\archive\train.csv",  # or your predictions CSV
)
OUTPUT_DIR = os.environ.get(
    "EVAL_OUTPUT_DIR",
    r"C:\Users\sagni\Downloads\RAG based Medical FAQ Chatbot",
)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ==============================
# Column name candidates
# ==============================
# Each list is tried in order and the first name present in the CSV wins.
# Keep the names lowercase: the CSV headers are lowercased before matching.
GT_CANDIDATES   = ["label", "target", "answer", "ground_truth", "gt"]
PRED_CANDIDATES = ["pred", "prediction", "predicted", "response", "pred_label"]

# Retrieval columns (optional)
QUERY_COL_CAND   = ["query", "question"]
GOLD_ID_CAND     = ["gold_id", "relevant_id", "doc_id", "answer_id"]
RETR_IDS_CAND    = ["retrieved_ids", "retrieved", "neighbors", "ctx_ids"]

# ==============================
# Load data
# ==============================
print(f"[INFO] Loading: {INPUT_CSV}")
df = pd.read_csv(INPUT_CSV)
print(f"[INFO] Loaded shape: {df.shape}")

def find_col(candidates, dfcols):
    """Return the first name in *candidates* that appears in *dfcols*.

    Candidates are checked in priority order; None is returned when none
    of them is present.
    """
    return next((name for name in candidates if name in dfcols), None)

# Normalise header names (strip surrounding whitespace, lowercase) so the
# lowercase candidate lists match headers such as "Label" or " Answer ".
# cols_lower maps each normalised name back to the first original header
# that produced it.
cols_lower = {}
for _orig_name in df.columns:
    cols_lower.setdefault(str(_orig_name).strip().lower(), _orig_name)

df.columns = [str(c).strip().lower() for c in df.columns]
if df.columns.duplicated().any():
    # Two headers collapsed to the same name (e.g. "Label" and "label");
    # df[name] would then return a DataFrame instead of a Series and break
    # the metrics below, so keep only the first occurrence of each name.
    _dupes = sorted(set(df.columns[df.columns.duplicated()]))
    print(f"[WARN] Duplicate columns after lowercasing {_dupes}; keeping the first of each.")
    df = df.loc[:, ~df.columns.duplicated()]

gt_col   = find_col(GT_CANDIDATES, df.columns)
pred_col = find_col(PRED_CANDIDATES, df.columns)

# Label-cardinality cap for the confusion-matrix outputs. Free-text columns
# (e.g. "answer"/"response") can hold thousands of distinct values: an N x N
# count matrix then needs O(N^2) memory, and the heatmap (0.35 in per label)
# grows far too large to render.
MAX_CM_LABELS = 100


def normalize(x):
    """Canonicalise one label cell to a stripped string; missing -> "".

    Integer-valued floats lose their ".0" (1.0 -> "1"): pandas upcasts an
    integer column to float when it contains NaN, and without this the same
    label would compare unequal between the ground-truth and prediction
    columns.
    """
    if pd.isna(x):
        return ""
    if isinstance(x, (float, np.floating)) and float(x).is_integer():
        return str(int(x))
    return str(x).strip()


def _save_cm_heatmap(matrix, labels, title, fmt, thresh, out_png):
    """Save *matrix* as an annotated heatmap (rows = true, cols = predicted).

    *fmt* turns a cell value into its annotation text; cells above *thresh*
    are annotated in white so they stay readable on the bright end of the map.
    """
    side = 0.35 * len(labels)
    fig, ax = plt.subplots(figsize=(max(6, side), max(5, side)), dpi=140)
    im = ax.imshow(matrix, interpolation="nearest")
    ax.set_title(title)
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    ticks = np.arange(len(labels))
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.set_yticklabels(labels)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    for (i, j), val in np.ndenumerate(matrix):
        ax.text(j, i, fmt(val),
                ha="center", va="center",
                color="white" if val > thresh else "black",
                fontsize=8)
    fig.tight_layout()
    fig.savefig(out_png, bbox_inches="tight")
    plt.close(fig)


if gt_col is None or pred_col is None:
    print("[WARN] Could not find ground-truth or prediction columns by common names.")
    print("[HINT] Ensure your CSV has columns like:")
    print(f"  GT  : {GT_CANDIDATES}")
    print(f"  Pred: {PRED_CANDIDATES}")
    print(f"  Found: {list(df.columns)}")
    # Nothing to score; continue with the optional retrieval section.
else:
    # ==============================
    # Clean + align labels
    # ==============================
    y_true = df[gt_col].map(normalize).to_numpy()
    y_pred = df[pred_col].map(normalize).to_numpy()

    # Score only rows where both the label and the prediction are present.
    mask = (y_true != "") & (y_pred != "")
    if not mask.any():
        raise ValueError("No valid rows after filtering empty ground-truth/prediction.")
    y_true = y_true[mask]
    y_pred = y_pred[mask]
    print(f"[INFO] Scoring {int(mask.sum())} of {len(mask)} rows "
          "(the rest had an empty label or prediction).")

    # ==============================
    # Accuracy
    # ==============================
    acc = accuracy_score(y_true, y_pred)
    print(f"[OK] Accuracy: {acc:.4f}")

    # ==============================
    # Classification report
    # ==============================
    report_dict = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
    rep_df = pd.DataFrame(report_dict).transpose()
    rep_csv = os.path.join(OUTPUT_DIR, "classification_report.csv")
    rep_df.to_csv(rep_csv, index=True, encoding="utf-8")
    print(f"[OK] Saved classification report -> {rep_csv}")

    # ==============================
    # Plot: Accuracy bar
    # ==============================
    acc_png = os.path.join(OUTPUT_DIR, "accuracy.png")
    fig, ax = plt.subplots(figsize=(5, 4), dpi=140)
    ax.bar(["Accuracy"], [acc])
    ax.set_ylim(0, 1)
    ax.set_title("Model Accuracy")
    ax.set_ylabel("Accuracy")
    ax.grid(axis="y", linestyle="--", alpha=0.4)
    ax.text(0, acc + 0.02, f"{acc:.3f}", ha="center", fontsize=10)
    fig.tight_layout()
    fig.savefig(acc_png, bbox_inches="tight")
    plt.close(fig)
    print(f"[OK] Saved accuracy graph -> {acc_png}")

    # ==============================
    # Confusion matrix (CSV + heatmaps)
    # ==============================
    labels = sorted(set(y_true) | set(y_pred))
    if len(labels) > MAX_CM_LABELS:
        print(f"[WARN] {len(labels)} distinct labels (> MAX_CM_LABELS={MAX_CM_LABELS}); "
              "skipping confusion matrices and heatmaps.")
    else:
        cm = confusion_matrix(y_true, y_pred, labels=labels)
        cm_csv = os.path.join(OUTPUT_DIR, "confusion_matrix_raw.csv")
        pd.DataFrame(cm, index=labels, columns=labels).to_csv(cm_csv, encoding="utf-8")
        print(f"[OK] Saved confusion matrix (raw counts) -> {cm_csv}")

        # Row-normalise; rows with no true samples stay all-zero instead of NaN.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm_norm = np.divide(cm, row_sums,
                            out=np.zeros(cm.shape, dtype=float),
                            where=row_sums > 0)
        cmn_csv = os.path.join(OUTPUT_DIR, "confusion_matrix_normalized.csv")
        pd.DataFrame(cm_norm, index=labels, columns=labels).to_csv(cmn_csv, encoding="utf-8")
        print(f"[OK] Saved confusion matrix (normalized) -> {cmn_csv}")

        cm_png = os.path.join(OUTPUT_DIR, "confusion_matrix_heatmap.png")
        _save_cm_heatmap(cm, labels, "Confusion Matrix (Counts)", str,
                         cm.max() / 2.0 if cm.max() > 0 else 0.5, cm_png)
        print(f"[OK] Saved confusion-matrix heatmap (counts) -> {cm_png}")

        cmn_png = os.path.join(OUTPUT_DIR, "confusion_matrix_heatmap_normalized.png")
        _save_cm_heatmap(cm_norm, labels, "Confusion Matrix (Row-normalized)",
                         lambda v: f"{v * 100:.1f}%", 0.5, cmn_png)
        print(f"[OK] Saved confusion-matrix heatmap (normalized) -> {cmn_png}")

# ==============================
# OPTIONAL: Retrieval Recall@K
# Runs only when the CSV has a query column, a gold-ID column AND a
# retrieved-IDs column; computes Recall@1/3/5 and plots a bar chart.
# ==============================
q_col, g_col, r_col = (
    find_col(candidates, df.columns)
    for candidates in (QUERY_COL_CAND, GOLD_ID_CAND, RETR_IDS_CAND)
)

def parse_id_list(x):
    """Parse one retrieved-IDs cell into an ordered list of ID strings.

    Accepted inputs:
      * an already-parsed list/tuple (e.g. a DataFrame built in memory);
      * a Python or JSON list literal: "[1, 2]", "['a', 'b']", '["a", null]';
      * a delimited string "a|b|c" or "a,b,c", optionally wrapped in [..]/(..);
      * a single bare ID such as "doc7".
    Missing values (NaN/None) and blank strings give []. Order is preserved
    because Recall@K only looks at the list's prefix.
    """
    # pd.isna() on a list returns an array (ambiguous truth value), so real
    # sequences are handled before the scalar missing-value check.
    if isinstance(x, (list, tuple)):
        return [str(z) for z in x]
    if pd.isna(x):
        return []
    s = str(x).strip()
    if not s:
        return []

    # Python literal such as "[1, 2]" or "('a', 'b')". literal_eval only builds
    # literals (no code execution); these are the errors it raises for text
    # that is not a valid literal.
    try:
        v = ast.literal_eval(s)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        v = None
    if isinstance(v, (list, tuple)):
        return [str(z) for z in v]

    # JSON list using null/true/false, which literal_eval rejects; nulls are dropped.
    try:
        v = json.loads(s)
    except (ValueError, RecursionError):  # json.JSONDecodeError subclasses ValueError
        v = None
    if isinstance(v, list):
        return [str(z) for z in v if z is not None]

    # Delimited fallback. Drop an enclosing [..]/(..) and per-token quotes so an
    # unquoted list like "[a, b]" yields ["a", "b"] rather than ["[a", "b]"].
    if len(s) >= 2 and s[0] + s[-1] in ("[]", "()"):
        s = s[1:-1]
    for sep in ("|", ","):
        if sep in s:
            tokens = (t.strip().strip("'\"").strip() for t in s.split(sep))
            return [t for t in tokens if t]
    s = s.strip().strip("'\"").strip()
    return [s] if s else []

def _normalize_id(x):
    """Canonicalise a gold-ID cell to a stripped string; missing -> "".

    Integer-valued floats lose their ".0" (3.0 -> "3"): pandas reads an
    integer ID column as float when any cell is empty, while retrieved IDs
    parsed from text such as "[3, 5]" come out as "3" and "5".
    """
    if pd.isna(x):
        return ""
    if isinstance(x, (float, np.floating)) and float(x).is_integer():
        return str(int(x))
    return str(x).strip()


if all(c is not None for c in [q_col, g_col, r_col]):
    print("[INFO] Retrieval columns detected. Computing Recall@K...")
    # Missing golds become "" and are skipped below (astype(str) alone would
    # turn NaN into the literal string "nan" and count it as a query).
    golds = df[g_col].map(_normalize_id)
    rets  = df[r_col].apply(parse_id_list)

    def recall_at_k(k):
        """Return (Recall@k, number of evaluable queries).

        A query is evaluable when it has a non-empty gold ID and a non-empty
        retrieved list; it is a hit when the gold ID is among the first k
        retrieved IDs.
        """
        hits = 0
        total = 0
        for g, lst in zip(golds, rets):
            if not g or not lst:
                continue
            total += 1
            if g in lst[:k]:
                hits += 1
        return (hits / total) if total > 0 else 0.0, total

    ks = [1, 3, 5]
    results = [recall_at_k(k) for k in ks]
    recalls = [r for r, _ in results]
    base_total = results[0][1]  # the evaluable set does not depend on k
    if base_total == 0:
        print("[WARN] No rows have both a gold ID and retrieved IDs; Recall@K is reported as 0.")

    # Save table
    recall_df = pd.DataFrame({"K": ks, "Recall": recalls})
    recall_csv = os.path.join(OUTPUT_DIR, "retrieval_recall_at_k.csv")
    recall_df.to_csv(recall_csv, index=False, encoding="utf-8")
    print(f"[OK] Saved retrieval Recall@K -> {recall_csv} (evaluated over {base_total} queries)")

    # Plot bar
    recall_png = os.path.join(OUTPUT_DIR, "retrieval_recall_at_k.png")
    fig, ax = plt.subplots(figsize=(6, 4), dpi=140)
    ax.bar([str(k) for k in ks], recalls)
    ax.set_ylim(0, 1)
    ax.set_title("Retrieval Recall@K")
    ax.set_xlabel("K")
    ax.set_ylabel("Recall")
    for i, v in enumerate(recalls):
        ax.text(i, v + 0.02, f"{v:.3f}", ha="center", fontsize=10)
    ax.grid(axis="y", linestyle="--", alpha=0.4)
    fig.tight_layout()
    fig.savefig(recall_png, bbox_inches="tight")
    plt.close(fig)
    print(f"[OK] Saved retrieval Recall@K chart -> {recall_png}")
else:
    print("[INFO] Retrieval columns not found; skipping Recall@K (optional).")

# Point the user at the folder holding every artifact written above.
print("\n✅ Done. Check your output folder:", OUTPUT_DIR, sep="\n")


[INFO] Loading: C:\Users\sagni\Downloads\RAG based Medical FAQ Chatbot\archive\train.csv
[INFO] Loaded shape: (16407, 3)
[WARN] Could not find ground-truth or prediction columns by common names.
[HINT] Ensure your CSV has columns like:
  GT  : ['label', 'target', 'answer', 'ground_truth', 'gt']
  Pred: ['pred', 'prediction', 'predicted', 'response', 'pred_label']
[INFO] Retrieval columns not found; skipping Recall@K (optional).

✅ Done. Check your output folder:
C:\Users\sagni\Downloads\RAG based Medical FAQ Chatbot
