# test_onnx_quadrants.py
"""
ONNX test script for RGB image classifier with 4-quadrant patching inside the ONNX graph.

Assumes the ONNX model:
- Accepts NHWC uint8 full images of arbitrary H×W
- Splits into TL/TR/BL/BR, resizes each tile to trained patch size
- Outputs:
    class_probabilities: [4B, 2]  (per-tile probs: [NoKnot, Knot])
    class_predictions:   [4B, 1]  (per-tile binary after sigmoid>0.5)

Image-level decision = 1 if ANY tile's class-1 probability >= threshold, else 0.
"""

import argparse
import os
from collections import Counter

import numpy as np
from PIL import Image
import onnxruntime as ort
from torchvision.datasets import ImageFolder
from PIL import ImageDraw
from sklearn.metrics import classification_report, confusion_matrix, precision_score, recall_score, f1_score


# Display names for the four quadrant tiles, in the order the per-tile output
# rows are assumed to come out of the ONNX graph (TL, TR, BL, BR).
# NOTE(review): this order must match the export wrapper's tiling — confirm there.
PATCH_NAMES = ["TL", "TR", "BL", "BR"]


def create_session(onnx_path: str, providers=None) -> ort.InferenceSession:
    """
    Create an ONNX Runtime inference session, preferring CUDA when available.

    Args:
        onnx_path: path to the .onnx model file.
        providers: execution providers in priority order, as accepted by
            ``ort.InferenceSession`` (names or ``(name, options)`` tuples).
            Defaults to ["CUDAExecutionProvider", "CPUExecutionProvider"].

    Returns:
        An ``ort.InferenceSession``. If creation with the requested providers
        fails, the error is reported and a CPU-only session is tried instead.

    Raises:
        Whatever ``ort.InferenceSession`` raises when the CPU-only attempt also
        fails, or when CPU was already the only usable provider.
    """
    if providers is None:
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]

    # Only request providers this onnxruntime build actually ships.
    available = set(ort.get_available_providers())

    def _provider_name(p):
        # Providers may be given as plain names or as (name, options) tuples.
        return p[0] if isinstance(p, (tuple, list)) else p

    usable = [p for p in providers if _provider_name(p) in available]
    if not usable:
        usable = ["CPUExecutionProvider"]

    try:
        return ort.InferenceSession(onnx_path, providers=usable)
    except Exception as err:
        # Retrying with the same CPU-only configuration would just fail again.
        if [_provider_name(p) for p in usable] == ["CPUExecutionProvider"]:
            raise
        print(f"[WARN] Could not create session with {usable} ({err}); "
              f"falling back to CPUExecutionProvider.")
        return ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])


def load_image_nhwc_uint8(path: str, full_size: tuple[int, int]) -> np.ndarray:
    """
    Load an image as a single-image NHWC uint8 batch for the ONNX model.

    The image is converted to 3-channel RGB, so grayscale, palette and RGBA
    files all yield exactly 3 channels. It is then resized to ``full_size``
    when its size differs.

    Args:
        path: image file path.
        full_size: target size as (W, H), following PIL's size convention.

    Returns:
        C-contiguous uint8 array of shape [1, H, W, 3].
    """
    target = (int(full_size[0]), int(full_size[1]))  # (W, H)
    # Close the file handle deterministically. convert() loads the pixel data
    # into a new image, so it stays valid after the file is closed.
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    if rgb.size != target:
        # Uses Pillow's default resampling filter.
        # NOTE(review): confirm it matches the preprocessing used in training.
        rgb = rgb.resize(target)
    arr = np.asarray(rgb, dtype=np.uint8)          # [H, W, 3]
    return np.ascontiguousarray(arr[np.newaxis])   # [1, H, W, 3]


def infer_image(sess: ort.InferenceSession, x_nhwc_uint8: np.ndarray, threshold: float):
    """
    Run one image through the ONNX model and derive the image-level decision.

    Args:
        sess: ONNX Runtime session; the image is fed to its first input.
        x_nhwc_uint8: image batch, expected shape [1, H, W, 3], uint8.
        threshold: a tile "fires" when its class-1 probability is >= threshold.

    Returns:
        (overall_pred, pos_scores, fired_idxs):
          overall_pred -- 1 if any tile fired, else 0.
          pos_scores   -- per-tile class-1 probabilities, one entry per tile
                          row (length 4 for a single image).
          fired_idxs   -- list of row indices of the tiles that fired.

    Raises:
        RuntimeError: if the probability output is not [N, 2], or N does not
            equal ``len(PATCH_NAMES)`` times the batch size.
    """
    in_name = sess.get_inputs()[0].name
    feed = np.ascontiguousarray(x_nhwc_uint8)
    outputs = sess.run(None, {in_name: feed})
    by_name = {meta.name: arr for meta, arr in zip(sess.get_outputs(), outputs)}

    # Prefer the named probability output. Otherwise assume the first output
    # holds per-tile probabilities.
    # NOTE(review): if that output is logits, `threshold` is not a probability.
    probs = np.asarray(by_name.get("class_probabilities", outputs[0]))
    if probs.ndim != 2 or probs.shape[1] != 2:
        raise RuntimeError(f"Unexpected output shape {probs.shape}; expected [4,2].")
    # Callers map tile rows to PATCH_NAMES, so a different tile count would
    # otherwise surface later as an IndexError.
    expected_tiles = len(PATCH_NAMES) * feed.shape[0]
    if probs.shape[0] != expected_tiles:
        raise RuntimeError(
            f"Expected {expected_tiles} tile rows in output, got {probs.shape[0]}."
        )

    pos_scores = probs[:, 1]  # class-1 probability per tile
    fired_mask = pos_scores >= threshold
    overall_pred = int(np.any(fired_mask))
    fired_idxs = np.flatnonzero(fired_mask).tolist()
    return overall_pred, pos_scores, fired_idxs


def _evaluate(sess, samples, full_size, threshold):
    """
    Run every (path, label) sample through the model.

    Images that fail to load or infer are reported and skipped.

    Returns:
        (y_true, y_pred, wrong_details), where wrong_details holds
        (path, true, pred, fired_names, pos_scores) for each misclassified image.
    """
    y_true, y_pred = [], []
    wrong_details = []
    for path, label in samples:
        label = int(label)
        try:
            # Loading sits inside the try so one unreadable file cannot abort the run.
            x = load_image_nhwc_uint8(path, full_size)
            pred, pos_scores, fired_idxs = infer_image(sess, x, threshold)
        except Exception as e:
            print(f"[ERROR] {path}: {e}")
            continue

        pred = int(pred)
        y_true.append(label)
        y_pred.append(pred)
        if pred != label:
            fired_names = [PATCH_NAMES[i] for i in fired_idxs]
            wrong_details.append((path, label, pred, fired_names, np.asarray(pos_scores).tolist()))
    return y_true, y_pred, wrong_details


def _report(y_true, y_pred, classes, wrong_details):
    """Print metrics and misclassified images, and return the summary metrics dict."""
    labels = [0, 1]
    # Name both binary labels even when only one class folder exists, so
    # target_names always lines up with labels. Without explicit labels,
    # sklearn raises when a class is absent from both y_true and y_pred.
    target_names = [classes[i] if i < len(classes) else str(i) for i in labels]

    print("\n🧪 Test Results")
    print(classification_report(y_true, y_pred, labels=labels,
                                target_names=target_names, zero_division=0))

    print("🧮 Confusion Matrix")
    print(confusion_matrix(y_true, y_pred, labels=labels))

    print("\nClass balance (y_true):", Counter(y_true))
    print("Class balance (y_pred):", Counter(y_pred))

    if wrong_details:
        print("\nMisclassified images (tile probs shown as [TL, TR, BL, BR]):")
        for path, y, p, fired, pos_scores in wrong_details:
            # NOTE(review): assumes the graph emits tiles in PATCH_NAMES order.
            tile_str = ", ".join(f"{name}={score:.3f}" for name, score in zip(PATCH_NAMES, pos_scores))
            fired_str = ", ".join(fired) if fired else "none"
            print(f" - {path} | true={y} pred={p} | fired: [{fired_str}] | {tile_str}")
    else:
        print("\nNo misclassified images 🎉")

    # Binary metrics with label 1 as the positive class (sklearn default pos_label=1).
    acc = float(np.mean(np.array(y_true) == np.array(y_pred)))
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    print(f"\nSummary: acc={acc:.4f} precision={prec:.4f} recall={rec:.4f} f1={f1:.4f}")
    return {"accuracy": acc, "precision": float(prec), "recall": float(rec), "f1": float(f1)}


def main():
    """
    CLI entry point: evaluate an ONNX quadrant-tiling classifier on an ImageFolder tree.

    Returns:
        A dict with "accuracy", "precision", "recall" and "f1" when at least one
        image was evaluated; otherwise None.
    """
    ap = argparse.ArgumentParser(description="Test ONNX model with quadrant tiling inside the graph")
    ap.add_argument("--onnx", required=True, help="Path to model.onnx")
    ap.add_argument("--root", required=True, help="Path to test ImageFolder root")
    ap.add_argument("--full_width", type=int, default=2448, help="Full image width fed to ONNX")
    ap.add_argument("--full_height", type=int, default=2048, help="Full image height fed to ONNX")
    ap.add_argument("--threshold", type=float, default=0.5, help="Tile-level positive threshold")
    args = ap.parse_args()

    # Build dataset (labels & paths from folder names)
    ds = ImageFolder(args.root)
    classes = ds.classes
    samples = ds.samples  # list of (path, label)

    print(f"Discovered {len(samples)} image(s) across classes: {classes}")
    if not samples:
        print("No images found; exiting.")
        return None
    if len(classes) > 2:
        # The model makes a binary decision; binary sklearn metrics fail on multiclass labels.
        print(f"Expected at most 2 class folders for a binary model, found {len(classes)}; exiting.")
        return None

    # ImageFolder assigns label indices in sorted folder-name order, and label 1
    # is compared against the model's class-1 (positive) probability.
    # NOTE(review): folders named "Knot"/"NoKnot" would sort to Knot=0, which
    # inverts the labels — verify this mapping.
    print(f"Label mapping: {ds.class_to_idx} (label 1 = positive)")

    sess = create_session(args.onnx)
    in_name = sess.get_inputs()[0].name
    print(f"ONNX input: {in_name} | providers: {sess.get_providers()}")

    y_true, y_pred, wrong_details = _evaluate(
        sess, samples, (args.full_width, args.full_height), args.threshold
    )
    if not y_true:
        print("\nNo images were evaluated successfully; nothing to report.")
        return None

    return _report(y_true, y_pred, classes, wrong_details)


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
