
# Results & Plots
Combine the TensorFlow and PyTorch metrics and generate comparison tables and charts.


In [3]:
import os, json, ast
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

# =====================================================================
# Directories
# =====================================================================
# All inputs live under results/; figures go to results/visualizations,
# which is created up front (including parents) if it does not exist yet.
RESULTS_DIR = Path("results")
VIZ_DIR = RESULTS_DIR.joinpath("visualizations")
VIZ_DIR.mkdir(parents=True, exist_ok=True)

# Metrics table consumed by the plotting helpers below.
csv_path = RESULTS_DIR.joinpath("metrics.csv")
df = pd.read_csv(csv_path)

# =====================================================================
# Fix ROC Columns — handles Infinity, -Infinity, NaN
# =====================================================================
def clean_list_string(s):
    """Convert a stringified list (as stored in the metrics CSV) into a list.

    Encodings are tried in this order:

    1. JSON, including the non-standard ``Infinity`` / ``-Infinity`` / ``NaN``
       tokens. Python's ``json`` module parses these natively into
       ``float('inf')``, ``float('-inf')`` and ``float('nan')``.
    2. Python float-list ``repr`` text, where infinities and NaNs appear as
       the bare words ``inf`` / ``-inf`` / ``nan``. These words are rewritten
       to their JSON spellings and the text is parsed as JSON again.
    3. Any other Python literal accepted by ``ast.literal_eval`` (for example
       single-quoted strings).

    Non-string values, such as NaN cells that pandas already parsed as
    floats, are returned unchanged.

    Returns:
        The parsed value, or ``[]`` (after printing a diagnostic) if no
        parser succeeds.
    """
    if not isinstance(s, str):
        return s

    # 1. Plain JSON. No token substitution is needed here; the old
    #    "NaN" -> "null" rewrite produced None instead of a float NaN.
    try:
        return json.loads(s)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        pass

    # 2. Python repr of floats: map bare inf/nan words to JSON tokens, retry.
    #    \b keeps words such as "info" intact; matching is case-sensitive, so
    #    "Infinity" itself is never rewritten.
    import re
    json_like = re.sub(r"\binf\b", "Infinity", s)
    json_like = re.sub(r"\bnan\b", "NaN", json_like)
    if json_like != s:
        try:
            return json.loads(json_like)
        except ValueError:
            pass

    # 3. Generic Python literal, parsed from the ORIGINAL text.
    try:
        return ast.literal_eval(s)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        print("❌ Could not parse:", s[:80])
        return []

# Parse each stored ROC vector column into real Python lists.
for col in ["roc_fpr", "roc_tpr", "roc_thresholds"]:
    if col not in df.columns:
        # Mirror plot_bar: warn and continue instead of aborting the notebook.
        print(f"⚠️ Column '{col}' not found in CSV — skipping")
        continue
    df[col] = df[col].apply(clean_list_string)

print("✅ ROC vectors cleaned successfully!")


# =====================================================================
# Helper: Save plot
# =====================================================================
def save_plot(filename):
    """Finalize the active figure and write it to ``VIZ_DIR / filename``.

    The layout is tightened first and the image is written at 300 dpi. The
    figure is then closed, so open figures don't pile up across the many
    plots generated below.
    """
    destination = VIZ_DIR / filename
    plt.tight_layout()
    plt.savefig(destination, dpi=300)
    plt.close()


# =====================================================================
# Bar Plot
# =====================================================================
def plot_bar(metric, title=None, ylabel=None):
    """Save a bar chart comparing ``metric`` across all rows of ``df``.

    Each row gets its own bar, labelled with its ``model`` value. The chart
    is written to ``<metric>.png`` via ``save_plot``.

    Args:
        metric: Column of ``df`` to plot.
        title: Plot title; defaults to the metric name.
        ylabel: Y-axis label; defaults to the metric name.

    A metric column that is absent from the CSV, or that has no numeric
    values at all, is skipped with a warning instead of producing an error
    or a blank chart.
    """
    if metric not in df.columns:
        print(f"⚠️ Metric '{metric}' not found in CSV — skipping")
        return

    # Non-numeric cells (e.g. placeholder strings) become NaN instead of
    # crashing plt.bar.
    values = pd.to_numeric(df[metric], errors="coerce")
    if values.isna().all():
        print(f"⚠️ Metric '{metric}' has no numeric values — skipping")
        return

    # Positional bars: with categorical x values, duplicate model names (for
    # example the same architecture from two frameworks) would collapse onto
    # one slot, and numeric names would be used as x coordinates.
    positions = list(range(len(df)))

    plt.figure(figsize=(10, 5))
    plt.bar(positions, values)
    plt.xticks(positions, df["model"].astype(str), rotation=45, ha="right")
    plt.title(title or metric)
    plt.ylabel(ylabel or metric)
    save_plot(f"{metric}.png")


# =====================================================================
# ROC Curve for ALL models in one plot
# =====================================================================
def plot_all_roc():
    """Overlay the ROC curve of every model in ``df`` on one figure.

    Reads the ``roc_fpr`` / ``roc_tpr`` columns, which the cleaning loop
    above converts to lists with ``clean_list_string``. The figure is saved
    as ``roc_all_models.png`` via ``save_plot``.

    Rows are skipped when their vectors are missing, not lists, empty (the
    fallback for unparseable text), or of mismatched length.
    """
    plt.figure(figsize=(8, 6))

    plotted = 0
    for _, row in df.iterrows():
        # .get tolerates the column being absent (returns None -> skipped).
        fpr = row.get("roc_fpr")
        tpr = row.get("roc_tpr")

        if not (isinstance(fpr, list) and isinstance(tpr, list)):
            continue
        # An empty curve would only add a blank legend entry.
        if not fpr or len(fpr) != len(tpr):
            continue

        plt.plot(fpr, tpr, label=row["model"])
        plotted += 1

    # Chance-level diagonal for reference.
    plt.plot([0, 1], [0, 1], "k--", alpha=0.4)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curves — All Models")

    if plotted:
        plt.legend()
    else:
        # legend() would warn and draw an empty box with no labelled curves.
        print("⚠️ No valid ROC vectors found — saving diagonal only")

    save_plot("roc_all_models.png")


# =====================================================================
# Radar Chart for each model
# =====================================================================
def plot_radar(model_name):
    """Save a radar chart of one model's test metrics.

    Uses the first row of ``df`` whose ``model`` equals ``model_name`` and
    plots the test accuracy, precision, recall, F1 and AUC columns. Metric
    columns absent from the CSV are skipped with a warning, mirroring
    ``plot_bar``. The chart is written to ``radar_<model_name>.png`` via
    ``save_plot``; ``/`` and ``\\`` in the name are replaced with ``_`` so
    the file is not routed into a nonexistent subdirectory.
    """
    matches = df[df["model"] == model_name]
    if matches.empty:
        print(f"⚠️ Model '{model_name}' not found in CSV — skipping")
        return
    row = matches.iloc[0]

    wanted = ["test_accuracy", "test_precision", "test_recall", "test_f1", "test_auc"]
    metrics = [m for m in wanted if m in df.columns]
    missing = [m for m in wanted if m not in df.columns]
    if missing:
        print(f"⚠️ Radar '{model_name}': metrics {missing} not found in CSV — skipping them")
    if not metrics:
        return

    values = [row[m] for m in metrics]
    values += values[:1]  # repeat the first point to close the polygon

    # len(metrics) + 1 evenly spaced angles ending at 2*pi (== 0). The last
    # angle pairs with the repeated first value above.
    angles = np.linspace(0, 2 * np.pi, len(metrics) + 1)

    plt.figure(figsize=(6, 6))
    ax = plt.subplot(111, polar=True)

    ax.plot(angles, values, linewidth=2)
    ax.fill(angles, values, alpha=0.25)
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(metrics)

    # Anchor the radius at 0 so small differences are not exaggerated by an
    # autoscaled origin. Extend the top to at least 1 so charts share a
    # common scale when metrics are stored as fractions (assumed typical
    # here). Larger values, e.g. percentages, are still fully shown.
    ax.set_ylim(0, max(1.0, ax.get_ylim()[1]))

    plt.title(f"Radar Chart — {model_name}")
    safe_name = str(model_name).replace("/", "_").replace("\\", "_")
    save_plot(f"radar_{safe_name}.png")


# =====================================================================
# Generate All Visualizations
# =====================================================================
bar_metrics = [
    # validation
    "val_accuracy", "val_precision", "val_recall", "val_f1", "val_auc",
    "val_loss",

    # test
    "test_accuracy", "test_precision", "test_recall", "test_f1", "test_auc",
    "test_loss",

    # resource usage
    "train_time_sec", "memory_mb", "gpu_used_memory_mb",
]

for metric in bar_metrics:
    plot_bar(metric, f"{metric} Comparison", metric)

# ROC
plot_all_roc()

# Radar per model. Iterate distinct, non-missing names only:
#  - a NaN name matches no row (NaN != NaN), so plot_radar would find nothing;
#  - plot_radar only uses the first matching row, so repeated names would
#    just redraw and overwrite the same file.
for model in df["model"].dropna().unique():
    plot_radar(model)

print(f"🎉 All visualizations generated in: {VIZ_DIR}")


✅ ROC vectors cleaned successfully!
🎉 All visualizations generated in: results/visualizations
