
# Results & Plots
Combine TensorFlow and PyTorch metrics, generate comparison tables and charts.


In [3]:
import os, json, ast
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

# =====================================================================
# Directories
# =====================================================================
RESULTS_DIR = Path("results")
VIZ_DIR = RESULTS_DIR / "visualizations"
VIZ_DIR.mkdir(exist_ok=True, parents=True)

csv_path = RESULTS_DIR / "metrics.csv"
df = pd.read_csv(csv_path)

# =====================================================================
# Fix ROC Columns ‚Äî handles Infinity, -Infinity, NaN
# =====================================================================
def clean_list_string(s):
    """Convert stringified lists with Infinity/NaN into real Python lists."""
    if not isinstance(s, str):
        return s

    # Replace invalid JSON tokens
    s = s.replace("Infinity", "1e309")
    s = s.replace("-Infinity", "-1e309")
    s = s.replace("NaN", "null")

    # Try JSON parse
    try:
        return json.loads(s)
    except:
        pass

    # Try Python literal parse
    try:
        return ast.literal_eval(s)
    except:
        print("‚ùå Could not parse:", s[:80])
        return []

for col in ["roc_fpr", "roc_tpr", "roc_thresholds"]:
    df[col] = df[col].apply(clean_list_string)

print("‚úÖ ROC vectors cleaned successfully!")


# =====================================================================
# Helper: Save plot
# =====================================================================
def save_plot(filename):
    plt.tight_layout()
    plt.savefig(VIZ_DIR / filename, dpi=300)
    plt.close()


# =====================================================================
# Bar Plot
# =====================================================================
def plot_bar(metric, title=None, ylabel=None):
    if metric not in df.columns:
        print(f"‚ö†Ô∏è Metric '{metric}' not found in CSV ‚Äî skipping")
        return

    plt.figure(figsize=(10,5))
    plt.bar(df["model"], df[metric])
    plt.xticks(rotation=45, ha="right")
    plt.title(title or metric)
    plt.ylabel(ylabel or metric)
    save_plot(f"{metric}.png")


# =====================================================================
# ROC Curve for ALL models in one plot
# =====================================================================
def plot_all_roc():
    plt.figure(figsize=(8,6))

    for _, row in df.iterrows():
        fpr = row["roc_fpr"]
        tpr = row["roc_tpr"]

        if isinstance(fpr, list) and isinstance(tpr, list) and len(fpr) == len(tpr):
            plt.plot(fpr, tpr, label=row["model"])

    plt.plot([0,1], [0,1], "k--", alpha=0.4)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curves ‚Äî All Models")
    plt.legend()

    save_plot("roc_all_models.png")


# =====================================================================
# Radar Chart for each model
# =====================================================================
def plot_radar(model_name):
    row = df[df["model"] == model_name].iloc[0]

    metrics = ["test_accuracy","test_precision","test_recall","test_f1","test_auc"]
    values = [row[m] for m in metrics]
    values += values[:1]  # close radar loop

    angles = np.linspace(0, 2*np.pi, len(metrics) + 1)

    plt.figure(figsize=(6,6))
    ax = plt.subplot(111, polar=True)

    ax.plot(angles, values, linewidth=2)
    ax.fill(angles, values, alpha=0.25)
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(metrics)

    plt.title(f"Radar Chart ‚Äî {model_name}")
    save_plot(f"radar_{model_name}.png")


# =====================================================================
# Generate All Visualizations
# =====================================================================
bar_metrics = [
    # validation
    "val_accuracy", "val_precision", "val_recall", "val_f1", "val_auc",
    "val_loss",

    # test
    "test_accuracy", "test_precision", "test_recall", "test_f1", "test_auc",
    "test_loss",

    # resource usage
    "train_time_sec", "memory_mb", "gpu_used_memory_mb",
]

for metric in bar_metrics:
    plot_bar(metric, f"{metric} Comparison", metric)

# ROC
plot_all_roc()

# Radar per model
for model in df["model"]:
    plot_radar(model)

print(f"üéâ All visualizations generated in: {VIZ_DIR}")


‚úÖ ROC vectors cleaned successfully!
üéâ All visualizations generated in: results/visualizations


In [4]:
RESULTS_DIR = Path("results")

csv_path = RESULTS_DIR / "metrics.csv"
df = pd.read_csv(csv_path)
df

Unnamed: 0,model,val_loss,val_accuracy,val_precision,val_recall,val_f1,val_auc,test_loss,test_accuracy,test_precision,...,test_f1,test_auc,roc_fpr,roc_tpr,roc_thresholds,train_time_sec,memory_mb,gpu_name,gpu_total_memory_mb,gpu_used_memory_mb
0,mobilenet_v2,0.121545,0.954144,0.991968,0.94419,0.967489,0.994532,0.116739,0.960044,0.995223,...,0.97188,0.995889,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0002532286654849329, 0.003038743985819...","[Infinity, 0.9999982118606567, 0.9999909400939...",1139.772988,1232.859375,Apple MPS,0,0
1,tf_efficientnet_lite4,0.05837,0.980479,0.992518,0.980377,0.98641,0.99817,0.060741,0.980851,0.992066,...,0.986762,0.998054,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.29754368194479613, 0.3535072170169663,...","[Infinity, 1.0, 0.9999998807907104, 0.99999976...",2086.465915,2230.09375,Apple MPS,0,0
2,shufflenet_v2,0.323561,0.921915,0.982892,0.907747,0.943826,0.985057,0.323396,0.925796,0.987892,...,0.946855,0.987645,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0002532286654849329, 0.122562674094707...","[Infinity, 0.9842752814292908, 0.8996083736419...",812.475008,2357.546875,Apple MPS,0,0
3,googlenet,0.145177,0.953407,0.981379,0.953619,0.9673,0.990848,0.147475,0.952311,0.983998,...,0.966628,0.9914,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0002532286654849329, 0.013674347936186...","[Infinity, 0.9999465942382812, 0.9994543194770...",924.96101,2641.15625,Apple MPS,0,0
4,alexnet,0.032671,0.990055,0.997685,0.988532,0.993088,0.999009,0.029783,0.9884,0.99795,...,0.991976,0.999605,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.4193466700430489, 0.4591035705241833, ...","[Infinity, 1.0, 0.9999998807907104, 0.99999976...",1138.084349,2159.21875,Apple MPS,0,0
5,vgg16,0.0355,0.990424,0.992872,0.993884,0.993377,0.999286,0.034475,0.989873,0.992662,...,0.993039,0.999055,"[0.0, 0.0006747638326585695, 0.000674763832658...","[0.0, 0.7991896682704482, 0.8214737908331223, ...","[Infinity, 1.0, 0.9999998807907104, 0.99999976...",3800.241618,1085.140625,Apple MPS,0,0
6,resnet50,0.090814,0.972928,0.98585,0.976555,0.98118,0.995645,0.091074,0.972013,0.98742,...,0.980622,0.99572,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0002532286654849329, 0.000759685996454...","[Infinity, 0.9999994039535522, 0.9999992847442...",1981.659394,2409.140625,Apple MPS,0,0
7,densenet121,0.107681,0.968324,0.98778,0.968145,0.977864,0.995241,0.100272,0.973486,0.991221,...,0.98159,0.996507,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0002532286654849329, 0.006583945302608...","[Infinity, 0.9999980926513672, 0.9999660253524...",2045.957682,2376.96875,Apple MPS,0,0
8,squeezenet1_0,0.37618,0.832044,0.813098,0.996687,0.89558,0.97262,0.377987,0.838704,0.820037,...,0.899886,0.975445,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0002532286654849329, 0.002025829323879...","[Infinity, 0.9999998807907104, 0.9999973773956...",1113.83872,1748.703125,Apple MPS,0,0
9,mnasnet1_0,0.354169,0.855801,0.997151,0.802752,0.889454,0.988027,0.34098,0.854171,0.998107,...,0.888733,0.991156,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0002532286654849329, 0.011901747277791...","[Infinity, 0.9999991655349731, 0.9999428987503...",992.070379,2091.203125,Apple MPS,0,0
