"""
Test script for RGB image classifier with 4-quadrant patching per image.

Each image is split into: top-left (TL), top-right (TR), bottom-left (BL), bottom-right (BR).
Image-level prediction = 1 if any quadrant predicts 1, else 0.

Args:
    --model_path: Path to the trained model .pth file
Returns:
    Prints classification report and confusion matrix.
"""
import argparse
import numpy as np
from collections import Counter

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report, confusion_matrix
from torchvision.transforms import InterpolationMode
from config import config

from torchvision import models
from torchvision.models import ResNet18_Weights
import torch.nn as nn

def load_model(image_width=64, image_height=64, *, pretrained=True):
    """Build a ResNet-18 whose classifier head emits a single logit (binary task).

    Args:
        image_width: Unused; accepted for backward compatibility.
        image_height: Unused; accepted for backward compatibility.
        pretrained: If True (the default, matching previous behavior), initialise
            the backbone from ``ResNet18_Weights.DEFAULT``. torchvision may need to
            download these weights. Pass False when a trained ``state_dict`` will
            be loaded on top anyway, so the download is skipped.

    Returns:
        A torchvision ResNet-18 with ``fc`` replaced by
        ``Sequential(Linear(in_features, 1))``.
    """
    weights = ResNet18_Weights.DEFAULT if pretrained else None
    model = models.resnet18(weights=weights)

    # Keep the Sequential wrapper: it makes the head's state_dict keys
    # "fc.0.weight"/"fc.0.bias", which saved checkpoints for this model expect.
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 1)
    )

    return model

# Quadrant labels, in the order QuadrantPatchDataset crops and stacks them.
PATCH_NAMES = "TL TR BL BR".split()


def get_test_transforms(width, height):
    """Build the test-time preprocessing pipeline applied to each PATCH.

    The pipeline runs on every quadrant crop individually, not on the whole image:
    force RGB, resize to ``(height, width)`` (bilinear, antialiased), convert to a
    tensor, and normalise with the standard ImageNet mean/std.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Lambda(lambda patch: patch.convert("RGB")),
        # Every patch is resized to the model input size so the four can be stacked.
        transforms.Resize(
            (height, width),
            interpolation=InterpolationMode.BILINEAR,
            antialias=True,
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)


class QuadrantPatchDataset(Dataset):
    """Wrap an ImageFolder so that each item holds an image's four quadrant patches.

    Item ``idx`` is ``(patches, label, path)``. ``patches`` stacks the TL, TR, BL
    and BR crops, in that order, along a new first dimension. That gives
    ``[4, C, H, W]`` when ``transform`` returns ``[C, H, W]`` tensors.
    ``label`` and ``path`` come straight from ``ImageFolder.samples``.

    Args:
        root: ImageFolder-style directory (one sub-folder per class).
        transform: Callable applied to each PIL quadrant. It must return tensors of
            identical shape so they can be stacked, e.g. by including a fixed-size
            Resize. Without a transform, ``torch.stack`` fails on the raw PIL crops.
    """

    def __init__(self, root, transform=None):
        self.base = ImageFolder(root)
        self.transform = transform
        # Re-export ImageFolder metadata so callers need not reach into ``base``.
        self.classes = self.base.classes
        self.class_to_idx = self.base.class_to_idx
        self.samples = self.base.samples
        self.image_paths = [p for p, _ in self.base.samples]

    def __len__(self):
        return len(self.base)

    def _quadrants(self, img):
        """Return the TL, TR, BL, BR crops of PIL image *img*.

        The split point is ``(W // 2, H // 2)``. For odd sizes, the right and
        bottom quadrants get the extra column or row.
        """
        W, H = img.size
        xm, ym = W // 2, H // 2
        # PIL boxes are (left, upper, right, lower); order must match PATCH_NAMES.
        boxes = [
            (0,   0,   xm, ym),    # TL
            (xm,  0,   W,  ym),    # TR
            (0,   ym,  xm, H),     # BL
            (xm,  ym,  W,  H),     # BR
        ]
        return [img.crop(b) for b in boxes]

    def __getitem__(self, idx):
        path, label = self.base.samples[idx]
        img = self.base.loader(path)  # PIL image
        quads = self._quadrants(img)
        if self.transform is not None:
            quads = [self.transform(q) for q in quads]  # list of [C, H, W] tensors
        # Stack to [4, C, H, W]
        patches = torch.stack(quads, dim=0)
        return patches, label, path


def test_model(config, model_path: str, *, threshold: float = 0.5):
    """Evaluate the binary RGB classifier with quadrant patching and print metrics.

    Each test image is split into four quadrants (TL, TR, BL, BR). Every quadrant
    is classified independently, and the image is predicted positive (1) when at
    least one quadrant's sigmoid probability exceeds ``threshold``.

    Args:
        config: Mapping providing ``"image_width"`` and ``"image_height"`` (the
            patch size fed to the model) and ``"test_dir"`` (an ImageFolder-style
            directory).
        model_path: Path to a ``state_dict`` for the model built by ``load_model``.
        threshold: Probability above which a patch counts as positive. Defaults
            to 0.5, the previously hard-coded value.

    Returns:
        dict with the image-level ``"accuracy"``, ``"precision"``, ``"recall"``
        and ``"f1"``. Precision, recall and F1 use sklearn's defaults
        (positive label 1).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset & loader. batch_size=1 because each item already holds the four
    # patches of one image, which are then fed to the model as a batch of 4.
    transform = get_test_transforms(config["image_width"], config["image_height"])
    test_set = QuadrantPatchDataset(config["test_dir"], transform=transform)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

    # Load model. NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model = load_model()
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state)
    model.eval().to(device)

    # Evaluation
    y_true, y_pred = [], []
    wrong_details = []  # (image path, names of the patches that predicted 1)

    with torch.no_grad():
        for patches, labels, paths in test_loader:
            # Default collation turns the int label into a shape-[1] tensor and
            # the path string into a 1-element list. Unwrap both so the report
            # shows plain paths rather than "['...']".
            label = int(labels[0])
            path = paths[0]

            # patches: [1, 4, C, H, W] -> [4, C, H, W], run as a batch of 4.
            logits = model(patches.squeeze(0).to(device))
            if logits.ndim == 2 and logits.shape[1] == 1:
                logits = logits.squeeze(1)  # [4, 1] -> [4]
            probs = torch.sigmoid(logits)
            preds_patch = (probs > threshold).int().cpu().numpy()  # [4] in {0, 1}

            # Image-level prediction: positive if any quadrant is positive.
            img_pred = int(np.any(preds_patch == 1))
            y_pred.append(img_pred)
            y_true.append(label)

            if img_pred != label:
                fired = [name for name, v in zip(PATCH_NAMES, preds_patch.tolist()) if v == 1]
                wrong_details.append((path, fired))

    # Reporting
    print("🧪 Test Results")
    # Pass labels explicitly so the rows line up with target_names even when a
    # class appears in neither y_true nor y_pred. Without labels=, sklearn
    # raises ValueError in that case.
    print(classification_report(
        y_true,
        y_pred,
        labels=list(range(len(test_set.classes))),
        target_names=test_set.classes,
        zero_division=0,
    ))

    print("🧮 Confusion Matrix")
    print(confusion_matrix(y_true, y_pred, labels=[0, 1]))

    print("\nClass balance (y_true):", Counter(y_true))
    print("Class balance (y_pred):", Counter(y_pred))

    if wrong_details:
        print("\nMisclassified images:")
        for path, fired in wrong_details:
            fired_str = ", ".join(fired) if fired else "none"
            print(f" - {path} | patches predicted 1: [{fired_str}]")
    else:
        print("\nNo misclassified images 🎉")

    return {
        "accuracy": float(np.mean(np.array(y_true) == np.array(y_pred))) if y_true else float("nan"),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0)
    }


if __name__ == "__main__":
    # CLI entry point: evaluate the checkpoint given on the command line.
    parser = argparse.ArgumentParser(description="Test RGB Classifier Model (Quadrant Patching)")
    parser.add_argument(
        "--model_path",
        required=True,
        help="Path to the trained model .pth file",
    )
    cli_args = parser.parse_args()
    test_model(config, model_path=cli_args.model_path)
