In [1]:
# save as plot_cxr_metrics.py
# run:
#   python plot_cxr_metrics.py
# Requirements: matplotlib, numpy (json, os and pathlib are standard library)

import os, json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# ---------- CONFIG: directory holding the training/evaluation artifacts ----------
ARTIFACTS = Path(r"C:\Users\NXTWAVE\Downloads\COVID Radiography Detection\artifacts")
# JSON inputs read by this script, all located directly inside ARTIFACTS.
HIST_PATH = ARTIFACTS.joinpath("history.json")
CM_PATH = ARTIFACTS.joinpath("confusion_matrix.json")
METRICS_PATH = ARTIFACTS.joinpath("metrics.json")  # optional (not required)

def load_history(path):
    """Load the training-history JSON and return the four curves used for plotting.

    Args:
        path: Location of the history file, as a ``str`` or ``pathlib.Path``.

    Returns:
        A dict with exactly the keys ``"accuracy"``, ``"val_accuracy"``,
        ``"loss"`` and ``"val_loss"``. A metric that is missing (or ``null``)
        in the file maps to an empty list. The series are returned as stored;
        their lengths are NOT equalized.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the top-level JSON value is not an object.
    """
    path = Path(path)  # accept plain strings as well as Path objects
    if not path.exists():
        raise FileNotFoundError(f"Missing {path}. Train first to generate history.json.")
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(path, "r", encoding="utf-8") as f:
        hist = json.load(f)
    if not isinstance(hist, dict):
        raise ValueError(
            f"{path} must contain a JSON object of metric lists, got {type(hist).__name__}."
        )

    def _series(name, alias=None):
        # Return the stored series, falling back to the alias, else an empty list.
        values = hist.get(name)
        if values is None and alias is not None:
            values = hist.get(alias)
        return [] if values is None else values

    # Keras reports a metric under the name it was compiled with, so a history
    # may carry "acc"/"val_acc" instead of "accuracy"/"val_accuracy".
    return {
        "accuracy": _series("accuracy", "acc"),
        "val_accuracy": _series("val_accuracy", "val_acc"),
        "loss": _series("loss"),
        "val_loss": _series("val_loss"),
    }

def _plot_metric_curve(train, val, metric_name, out_path):
    """Draw one train/validation curve pair against epoch number and save it as a PNG."""
    fig, ax = plt.subplots(figsize=(7, 4.5))
    try:
        drew_any = False
        series = (
            (train, "o", f"Train {metric_name}"),
            (val, "s", f"Val {metric_name}"),
        )
        for values, marker, label in series:
            if len(values) > 0:
                # Epochs are numbered from 1; each series uses its own length.
                ax.plot(range(1, len(values) + 1), values, marker=marker, label=label)
                drew_any = True
        ax.set_xlabel("Epoch")
        ax.set_ylabel(metric_name)
        ax.set_title(f"Training & Validation {metric_name}")
        ax.grid(True, alpha=0.3)
        if drew_any:
            # A legend with no labelled artists only produces a warning.
            ax.legend()
        fig.tight_layout()
        fig.savefig(out_path, dpi=220)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    print(f"[Saved] {out_path}")


def plot_accuracy_loss(hist, out_dir):
    """Save the accuracy and loss training curves as PNG files.

    Args:
        hist: Mapping that may contain the keys ``"accuracy"``,
            ``"val_accuracy"``, ``"loss"`` and ``"val_loss"``, each a sequence
            of per-epoch values. Missing keys are treated as empty series and
            are simply not drawn.
        out_dir: Directory (``str`` or ``pathlib.Path``) that receives
            ``accuracy_curve.png`` and ``loss_curve.png``.

    Returns:
        None. The path of each written file is printed.
    """
    out_dir = Path(out_dir)  # accept plain strings as well as Path objects

    # ---- Accuracy curve
    _plot_metric_curve(
        hist.get("accuracy", []),
        hist.get("val_accuracy", []),
        "Accuracy",
        out_dir / "accuracy_curve.png",
    )

    # ---- Loss curve
    _plot_metric_curve(
        hist.get("loss", []),
        hist.get("val_loss", []),
        "Loss",
        out_dir / "loss_curve.png",
    )

def load_confusion_matrix(path):
    """Load class labels and the confusion matrix from a JSON file.

    The file is read as a JSON object with an optional ``"labels"`` list and a
    ``"matrix"`` entry (nested lists of numbers).

    Args:
        path: Location of the confusion-matrix file, as a ``str`` or
            ``pathlib.Path``.

    Returns:
        ``(labels, matrix)`` where ``matrix`` is a float ``numpy.ndarray``
        built from ``"matrix"`` (empty when the key is absent) and ``labels``
        is the stored list, or ``["0", "1", ...]`` (one per matrix row) when
        the file has no ``"labels"`` entry.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    path = Path(path)  # accept plain strings as well as Path objects
    if not path.exists():
        raise FileNotFoundError(f"Missing {path}. Evaluate first to generate confusion_matrix.json.")
    # JSON is UTF-8 by specification; class names may be non-ASCII, so do not
    # depend on the platform default encoding.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    matrix = np.array(data.get("matrix", []), dtype=float)
    labels = data.get("labels")
    if labels is None:
        # Fall back to row indices as class names.
        labels = [str(i) for i in range(matrix.shape[0])]
    return labels, matrix

def heatmap(ax, data, row_labels, col_labels, title="Confusion Matrix"):
    """Draw ``data`` as an annotated image with a colorbar on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        data: Two-dimensional array; its ``shape`` is unpacked as
            ``(rows, cols)`` and cells are read with ``data[i, j]``.
        row_labels: Tick labels for the y axis ("True").
        col_labels: Tick labels for the x axis ("Predicted").
        title: Axes title.

    Returns:
        ``(image, colorbar)`` as returned by ``ax.imshow`` and ``plt.colorbar``.
    """

    def _cell_text(value):
        # Whole numbers print as integers; fractional values get two decimals.
        if value % 1 != 0:
            return f"{value:.2f}"
        return f"{int(value)}"

    image = ax.imshow(data, interpolation="nearest", aspect="auto")
    ax.set_title(title)
    colorbar = plt.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

    # Axis captions and one tick per row/column.
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_xticks(np.arange(len(col_labels)))
    ax.set_yticks(np.arange(len(row_labels)))
    ax.set_xticklabels(col_labels, rotation=45, ha="right")
    ax.set_yticklabels(row_labels)

    # Write each cell's value at its centre (x = column, y = row).
    n_rows, n_cols = data.shape
    for r, c in np.ndindex(n_rows, n_cols):
        ax.text(c, r, _cell_text(data[r, c]), ha="center", va="center", fontsize=8)
    return image, colorbar

def plot_confusion_matrices(labels, cm, out_dir):
    """Save the confusion matrix as two heatmaps: raw counts and row-normalized.

    Args:
        labels: Class names, used for both axes.
        cm: Confusion-matrix array (must support ``.sum(axis=1, keepdims=True)``).
        out_dir: ``pathlib.Path`` directory that receives
            ``confusion_matrix.png`` and ``confusion_matrix_norm.png``.

    Returns:
        None. The path of each written file is printed.
    """

    def _save_heatmap(matrix, title, filename):
        # Render one matrix on a fresh figure, write it to disk, close the figure.
        _, axes = plt.subplots(figsize=(6.5, 5.5))
        heatmap(axes, matrix, labels, labels, title=title)
        plt.tight_layout()
        target = out_dir / filename
        plt.savefig(target, dpi=220)
        plt.close()
        print(f"[Saved] {target}")

    # Raw counts
    _save_heatmap(cm, "Confusion Matrix (Counts)", "confusion_matrix.png")

    # Row-normalized: each row divided by its total; all-zero rows stay zero.
    with np.errstate(invalid="ignore", divide="ignore"):
        per_row_total = cm.sum(axis=1, keepdims=True)
        normalized = np.where(per_row_total == 0, 0, cm / per_row_total)

    _save_heatmap(normalized, "Confusion Matrix (Row-Normalized)", "confusion_matrix_norm.png")

# Summary rows of a classification report that are not classes. "accuracy",
# "macro avg" and "weighted avg" are the ones this script expects; the
# micro/samples averages are included on the assumption that the report comes
# from scikit-learn's classification_report(output_dict=True) -- TODO confirm.
_AGGREGATE_REPORT_KEYS = frozenset(
    {"accuracy", "micro avg", "macro avg", "weighted avg", "samples avg"}
)


def _per_class_f1(report):
    """Return ``(class_names, f1_scores)`` for the non-aggregate rows of ``report``."""
    classes, f1s = [], []
    for name, row in report.items():
        if name in _AGGREGATE_REPORT_KEYS:
            continue
        if isinstance(row, dict) and "f1-score" in row:
            classes.append(name)
            f1s.append(row["f1-score"])
    return classes, f1s


def maybe_plot_per_class_f1(metrics_path, out_dir):
    """Optional: bar chart of per-class F1 from metrics.json if available.

    Reads ``met["test"]["classification_report"]`` and plots one bar per class.
    Aggregate rows (accuracy and the averaged rows) are skipped so that only
    real classes appear in the chart.

    Args:
        metrics_path: Location of the metrics file (``str`` or
            ``pathlib.Path``).
        out_dir: Directory (``str`` or ``pathlib.Path``) that receives
            ``per_class_f1.png``.

    Returns:
        None. Nothing is written when the file is missing, the report section
        is absent or malformed, or no per-class rows are found.
    """
    metrics_path = Path(metrics_path)  # accept plain strings as well
    if not metrics_path.exists():
        return
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(metrics_path, "r", encoding="utf-8") as f:
        met = json.load(f)

    # Tolerate a missing or malformed section: this plot is best-effort.
    test_section = met.get("test") if isinstance(met, dict) else None
    rep = test_section.get("classification_report") if isinstance(test_section, dict) else None
    if not isinstance(rep, dict):
        return

    classes, f1s = _per_class_f1(rep)
    if not classes:
        return

    out_path = Path(out_dir) / "per_class_f1.png"
    x = np.arange(len(classes))
    fig = plt.figure(figsize=(8, 4.5))
    try:
        plt.bar(x, f1s)
        plt.xticks(x, classes, rotation=45, ha="right")
        plt.ylabel("F1-score")
        plt.title("Per-class F1 (Test)")
        plt.tight_layout()
        plt.savefig(out_path, dpi=220)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    print(f"[Saved] {out_path}")

def main():
    """Generate every figure from the JSON artifacts and write them into ARTIFACTS."""
    out_dir = ARTIFACTS
    out_dir.mkdir(parents=True, exist_ok=True)

    # Training curves (accuracy and loss) from history.json.
    history = load_history(HIST_PATH)
    plot_accuracy_loss(history, out_dir)

    # Count and row-normalized heatmaps from confusion_matrix.json.
    class_names, matrix = load_confusion_matrix(CM_PATH)
    plot_confusion_matrices(class_names, matrix, out_dir)

    # Per-class F1 bar chart; silently skipped when metrics.json is absent.
    maybe_plot_per_class_f1(METRICS_PATH, out_dir)

# Run the plotting pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()


[Saved] C:\Users\NXTWAVE\Downloads\COVID Radiography Detection\artifacts\accuracy_curve.png
[Saved] C:\Users\NXTWAVE\Downloads\COVID Radiography Detection\artifacts\loss_curve.png
[Saved] C:\Users\NXTWAVE\Downloads\COVID Radiography Detection\artifacts\confusion_matrix.png
[Saved] C:\Users\NXTWAVE\Downloads\COVID Radiography Detection\artifacts\confusion_matrix_norm.png
[Saved] C:\Users\NXTWAVE\Downloads\COVID Radiography Detection\artifacts\per_class_f1.png
