In [None]:
# Standard Library Imports
import os
import math
import random
import xml.etree.ElementTree as ET

# Third-Party Library Imports
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_curve, average_precision_score

# PyTorch Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split

# torchvision Imports
import torchvision
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.models import resnet50, ResNet50_Weights
from torchvision import models
from torchvision.ops import complete_box_iou_loss

import albumentations as A
from albumentations.pytorch import ToTensorV2


In [None]:
class LocalizationDataset(Dataset):
    """Single-object bounding-box dataset read from ``images/`` + ``labels/``.

    Layout expected under ``dataset_folder``: ``images/<name>.jpg`` with a
    matching Pascal-VOC-style ``labels/<name>.xml``. Images without a label
    file are skipped; samples that fail to load are reported and skipped.

    Each item is ``{"image": Tensor, "bbox": Tensor}``:
      * ``image`` — the RGB image passed through ``base_transform``
        (``Resize(target_size)`` then ``ToTensor``);
      * ``bbox`` — float tensor ``[xmin, ymin, xmax, ymax]`` divided by the
        width/height recorded in the XML file.

    With ``augment=True``, 0-3 randomly augmented copies of every image are
    generated eagerly in ``__init__`` and kept in memory as PIL images.
    """

    def __init__(self, dataset_folder, target_size=(256, 256), augment=False):
        self.target_size = target_size
        self.augment = augment

        # Parallel lists. self.images holds either a str path (original image,
        # opened lazily in __getitem__) or a PIL image (augmented copy);
        # self.bboxes holds the matching normalized 4-element float tensors.
        self.images = []
        self.bboxes = []

        image_folder = os.path.join(dataset_folder, "images")
        label_folder = os.path.join(dataset_folder, "labels")

        # Applied to every sample in __getitem__ (transforms.Resize takes (h, w)).
        self.base_transform = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
            # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        for img_name in os.listdir(image_folder):
            if not img_name.endswith(".jpg"):
                continue

            img_path = os.path.join(image_folder, img_name)
            # Swap only the trailing extension; str.replace would also rewrite
            # an earlier ".jpg" inside the name (e.g. "a.jpg.v2.jpg").
            xml_path = os.path.join(label_folder, img_name[: -len(".jpg")] + ".xml")
            if not os.path.exists(xml_path):
                continue

            try:
                bbox, w, h = self.parse_xml(xml_path)
                nb = [bbox[0] / w, bbox[1] / h, bbox[2] / w, bbox[3] / h]

                # Original image: store the path, load lazily.
                self.images.append(img_path)
                self.bboxes.append(torch.tensor(nb, dtype=torch.float))

                if augment:
                    for _ in range(random.randint(0, 3)):
                        self._add_augmented(img_path, nb)
            except Exception as e:
                # Best-effort loading: report the bad sample and move on.
                print(f"Error processing {img_path}: {e}")

    def parse_xml(self, xml_file):
        """Read the first object's box and the image size from a VOC XML file.

        Returns:
            ``([xmin, ymin, xmax, ymax], width, height)`` as ints. Only the
            first ``<object>`` is used (one object per image is assumed).

        Raises:
            ValueError: if the file contains no ``<object>`` element.
        """
        root = ET.parse(xml_file).getroot()

        # int(float(...)) accepts both "123" and "123.0"; some annotation
        # tools write fractional pixel coordinates, which int() rejects.
        def _as_int(node, path):
            return int(float(node.find(path).text))

        width = _as_int(root, 'size/width')
        height = _as_int(root, 'size/height')

        obj = root.find('object')
        if obj is None:
            raise ValueError(f"No bounding box found in {xml_file}")

        box = obj.find('bndbox')
        coords = [_as_int(box, tag) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        return coords, width, height

    # ---------------------------------------------------------
    # ADD AUGMENTED DATA INTO THE LISTS
    # ---------------------------------------------------------
    def _add_augmented(self, image_path, bbox):
        """Append one augmented copy of ``image_path`` (1-3 random ops) to the lists."""
        img = Image.open(image_path).convert("RGB")
        bbox = bbox.copy()  # ops return new lists, but never alias the caller's

        ops = [
            self._aug_flip_horizontal,
            self._aug_flip_vertical,
            self._aug_color_jitter,
            self._aug_shift_xy,
            self._aug_small_rotate,
        ]

        for op in random.sample(ops, random.randint(1, 3)):
            img, bbox = op(img, bbox)

        # Stored as an in-memory PIL image (not a path).
        self.images.append(img)
        self.bboxes.append(torch.tensor(bbox, dtype=torch.float))

    # ---------------------------------------------------------
    # AUGMENTATION OPERATIONS
    # Each takes (PIL image, normalized [x1, y1, x2, y2]) and returns the
    # transformed pair.
    # ---------------------------------------------------------
    def _aug_flip_horizontal(self, img, bbox):
        img = TF.hflip(img)
        x1, y1, x2, y2 = bbox
        return img, [1 - x2, y1, 1 - x1, y2]

    def _aug_flip_vertical(self, img, bbox):
        img = TF.vflip(img)
        x1, y1, x2, y2 = bbox
        return img, [x1, 1 - y2, x2, 1 - y1]

    def _aug_color_jitter(self, img, bbox):
        # Photometric only: the box is unchanged.
        jitter = transforms.ColorJitter(
            brightness=0.3, contrast=0.3, saturation=0.3, hue=0.05
        )
        return jitter(img), bbox

    def _aug_shift_xy(self, img, bbox):
        """Translate by up to 10% of each side and shift/clip the box to match."""
        w, h = img.size
        max_dx, max_dy = int(w * 0.1), int(h * 0.1)

        dx = random.randint(-max_dx, max_dx)
        dy = random.randint(-max_dy, max_dy)

        img = TF.affine(img, angle=0, translate=(dx, dy), scale=1.0, shear=0)

        x1, y1, x2, y2 = bbox
        x1 += dx / w
        x2 += dx / w
        y1 += dy / h
        y2 += dy / h

        clip = lambda v: max(0, min(1, v))
        return img, [clip(x1), clip(y1), clip(x2), clip(y2)]

    def _aug_small_rotate(self, img, bbox):
        """Rotate by up to ±10° and replace the box with the rotated box's extent."""
        from PIL import ImageDraw  # local import: only Image is imported at module level

        angle = random.uniform(-10, 10)
        w, h = img.size
        x1, y1, x2, y2 = bbox

        # Rasterize the box, rotate it with the image, then take the bounding
        # extent of the rotated mask as the new axis-aligned box.
        mask = Image.new("L", (w, h), 0)
        ImageDraw.Draw(mask).rectangle([x1 * w, y1 * h, x2 * w, y2 * h], fill=255)

        img = TF.rotate(img, angle)
        mask = TF.rotate(mask, angle)

        # PIL images must go through NumPy; torch.tensor() cannot consume them.
        ys, xs = np.nonzero(np.asarray(mask))
        if xs.size == 0:
            return img, bbox

        return img, [
            float(xs.min()) / w,
            float(ys.min()) / h,
            float(xs.max()) / w,
            float(ys.max()) / h,
        ]

    # ---------------------------------------------------------
    # __getitem__
    # ---------------------------------------------------------
    def __getitem__(self, idx):
        item = self.images[idx]

        # item is either a path (original image) or a PIL image (augmented copy).
        if isinstance(item, str):
            img = Image.open(item).convert("RGB")
        else:
            img = item

        img = self.base_transform(img)
        bbox = self.bboxes[idx]

        return {"image": img, "bbox": bbox}

    def __len__(self):
        return len(self.images)


In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch
import os 

# IMPORTANT: Use the correct path for your Kaggle input.
# ResNet-50 state_dict used to initialise the CCLN encoder.
KAGGLE_RESNET_PATH = "/kaggle/input/resnet_50/pytorch/default/1/resnet50-11ad3fa6.pth"


class CCLN(nn.Module):
    """Box regressor: ResNet-50 encoder + 3-stage skip-connected decoder.

    ``forward`` maps an image batch ``(B, 3, H, W)`` to ``(B, 4)`` values in
    [0, 1]: the spatial mean of per-location sigmoid predictions for
    ``(x_min, y_min, x_max, y_max)``. Nothing in the head enforces
    ``x_min <= x_max`` or ``y_min <= y_max``.

    Args:
        pretrained_weights_path: torchvision ResNet-50 ``state_dict`` file.
            If falsy, or if the file does not exist, the backbone keeps its
            random initialisation (a warning is printed for a missing file).
    """

    def __init__(self, pretrained_weights_path=KAGGLE_RESNET_PATH):
        super().__init__()

        # weights=None builds the architecture without triggering a download.
        resnet50 = models.resnet50(weights=None)
        if pretrained_weights_path:
            if os.path.exists(pretrained_weights_path):
                state_dict = torch.load(pretrained_weights_path, map_location='cpu', weights_only=True)
                resnet50.load_state_dict(state_dict, strict=True)
            else:
                # Non-fatal, but training from a random backbone must not go unnoticed.
                print(f"Warning: ResNet-50 weights not found at {pretrained_weights_path}; "
                      "using a randomly initialized backbone.")

        # --- Encoder (attribute names are part of saved state_dict keys) ---
        self.conv1 = resnet50.conv1
        self.bn1 = resnet50.bn1
        self.relu = resnet50.relu
        self.maxpool = resnet50.maxpool
        self.layer1 = resnet50.layer1
        self.layer2 = resnet50.layer2
        self.layer3 = resnet50.layer3
        self.layer4 = resnet50.layer4

        self.leaky_relu = nn.LeakyReLU(negative_slope=0.01)

        # --- Decoder: per stage, upsample to the skip's size, concat,
        # 1x1 conv to fuse channels, then 3x3 conv to refine ---
        self.conv_concat1 = nn.Conv2d(2048 + 1024, 512, kernel_size=1)
        self.bn_concat1 = nn.BatchNorm2d(512)
        self.conv_up1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn_up1 = nn.BatchNorm2d(512)

        self.conv_concat2 = nn.Conv2d(512 + 512, 256, kernel_size=1)
        self.bn_concat2 = nn.BatchNorm2d(256)
        self.conv_up2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn_up2 = nn.BatchNorm2d(256)

        self.conv_concat3 = nn.Conv2d(256 + 256, 128, kernel_size=1)
        self.bn_concat3 = nn.BatchNorm2d(128)
        self.conv_up3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn_up3 = nn.BatchNorm2d(128)

        # --- Regression head: 1x1 conv -> 4 channels (x_min, y_min, x_max, y_max) per location ---
        self.conv_final = nn.Conv2d(128, 4, kernel_size=1)
        # Zero bias => sigmoid starts at 0.5 for every coordinate.
        nn.init.constant_(self.conv_final.bias, 0.0)

    def _upsample_concat(self, x, skip, conv, bn):
        """Resize ``x`` to ``skip``'s spatial size, concat on channels, then conv + BN + LeakyReLU."""
        # Resizing to the skip's exact size (not a fixed x2) keeps the concat
        # valid when the input sides are not multiples of 32. With
        # align_corners=True the result equals Upsample(scale_factor=2)
        # whenever the sizes are exactly doubled.
        x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=True)
        x = torch.cat([x, skip], dim=1)
        return self.leaky_relu(bn(conv(x)))

    def forward(self, x):
        # Encoder
        x = self.relu(self.bn1(self.conv1(x)))
        r0 = self.maxpool(x)
        r1 = self.layer1(r0)
        r2 = self.layer2(r1)
        r3 = self.layer3(r2)
        r4 = self.layer4(r3)

        # Decoder (LeakyReLU activations)
        x = self._upsample_concat(r4, r3, self.conv_concat1, self.bn_concat1)
        x = self.leaky_relu(self.bn_up1(self.conv_up1(x)))

        x = self._upsample_concat(x, r2, self.conv_concat2, self.bn_concat2)
        x = self.leaky_relu(self.bn_up2(self.conv_up2(x)))

        x = self._upsample_concat(x, r1, self.conv_concat3, self.bn_concat3)
        x = self.leaky_relu(self.bn_up3(self.conv_up3(x)))  # (B, 128, h1, w1) at r1's resolution

        # Head: per-location sigmoid boxes, averaged over space -> (B, 4).
        x = torch.sigmoid(self.conv_final(x))
        output = F.adaptive_avg_pool2d(x, (1, 1)).view(x.size(0), -1)

        return output

In [None]:
class CCLN3(nn.Module):
    """Simpler box regressor: ResNet-50 encoder + 3 upsample/conv stages (no skips).

    ``forward`` maps ``(B, 3, H, W)`` to ``(B, 4)`` values in [0, 1]: the
    global average of per-location sigmoid predictions for
    ``(x_min, y_min, x_max, y_max)``.

    Args:
        weights: if truthy (default), load ResNet-50 weights from
            ``state_dict_path``; if falsy, keep the backbone randomly
            initialised and read no file.
        state_dict_path: torchvision ResNet-50 ``state_dict`` file.
    """

    def __init__(self, weights=True,
                 state_dict_path='/kaggle/input/resnet_50/pytorch/default/1/resnet50-11ad3fa6.pth'):
        super().__init__()

        # Architecture only; weights=None never triggers a download.
        resnet = resnet50(weights=None)
        if weights:
            state_dict = torch.load(state_dict_path, map_location='cpu', weights_only=True)
            resnet.load_state_dict(state_dict)

        # Every child except avgpool and fc: a 2048-channel feature map.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        # Upsample x2 then 3x3 conv, halving channels each stage (no skip connections).
        self.upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_up1 = nn.Conv2d(2048, 1024, kernel_size=3, padding=1)

        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_up2 = nn.Conv2d(1024, 512, kernel_size=3, padding=1)

        self.upsample3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_up3 = nn.Conv2d(512, 256, kernel_size=3, padding=1)

        # 1x1 conv to the 4 box coordinates.
        self.conv_out = nn.Conv2d(256, 4, kernel_size=1)

    def forward(self, x):
        x = self.backbone(x)  # e.g. [B, 2048, 10, 10] for a 320x320 input

        x = F.relu(self.conv_up1(self.upsample1(x)))
        x = F.relu(self.conv_up2(self.upsample2(x)))
        x = F.relu(self.conv_up3(self.upsample3(x)))

        x = torch.sigmoid(self.conv_out(x))  # per-location boxes in [0, 1]

        # Global average pooling, then flatten to [B, 4].
        x = F.adaptive_avg_pool2d(x, (1, 1))
        return x.view(x.size(0), -1)


In [None]:
def display_image_with_boxes(image, true_box, pred_box, iou_loss):
    """Show ``image`` with the ground-truth box (red) and predicted box (green).

    Also prints both boxes and their IoU.

    Args:
        image: ``(C, H, W)`` tensor, shown with ``plt.imshow`` as-is.
        true_box: normalized ``[x_min, y_min, x_max, y_max]``, shape ``(4,)``
            or ``(1, 4)``, on any device.
        pred_box: same format as ``true_box``.
        iou_loss: value shown in the overlay; a Python number, NumPy scalar
            or one-element tensor.
    """
    # box_iou needs (N, 4) tensors on one device; accept (4,) or (1, 4) anywhere.
    true_box = true_box.detach().cpu().reshape(-1, 4)
    pred_box = pred_box.detach().cpu().reshape(-1, 4)

    print("Prediction:", pred_box)
    print("Label:", true_box)
    print("IoU:", torchvision.ops.box_iou(true_box, pred_box).item())

    # (C, H, W) tensor -> (H, W, C) NumPy array for imshow.
    image = image.detach().permute(1, 2, 0).cpu().numpy()
    height, width = image.shape[:2]

    plt.figure(figsize=(12, 12))
    plt.imshow(image)

    boxes = ((true_box[0], 'red', 'True Box'), (pred_box[0], 'green', 'Predicted Box'))
    for box, color, label in boxes:
        x_min, y_min, x_max, y_max = (float(c) for c in box)

        # Normalized -> pixel coordinates (truncated to int).
        x_min, x_max = int(x_min * width), int(x_max * width)
        y_min, y_max = int(y_min * height), int(y_max * height)

        plt.gca().add_patch(plt.Rectangle(
            (x_min, y_min),
            x_max - x_min,
            y_max - y_min,
            edgecolor=color,
            facecolor='none',
            linewidth=2,
            label=label,
        ))

    # IoU loss overlay near the bottom-right corner.
    plt.text(
        width - 200, height - 50,
        f"IoU Loss: {torch.as_tensor(iou_loss).item():.4f}",
        color='white',
        fontsize=14,
        bbox=dict(facecolor='black', alpha=0.8, edgecolor='none', pad=6)
    )

    plt.legend()
    plt.axis('off')
    plt.show()

In [None]:
def calculate_metrics(pred_boxes, true_boxes, iou_threshold=0.5):
    """Precision, recall and F1 for one-box-per-image localization.

    Prediction ``i`` is scored only against ground truth ``i`` (boxes in
    ``[x_min, y_min, x_max, y_max]`` order). It is a true positive when that
    paired IoU exceeds ``iou_threshold``. Otherwise it counts both as a false
    positive and as a missed ground truth, so precision equals recall.

    Args:
        pred_boxes: predicted boxes, shape ``(N, 4)`` or ``(4,)``.
        true_boxes: ground-truth boxes, same number of boxes as ``pred_boxes``.
        iou_threshold: IoU strictly above which a prediction is correct.

    Returns:
        ``(precision, recall, f1)``; each is 0 when there are no boxes.

    Raises:
        ValueError: if the inputs hold different numbers of boxes.
    """
    pred = pred_boxes.reshape(-1, 4).float()
    true = true_boxes.reshape(-1, 4).float()
    if pred.shape[0] != true.shape[0]:
        raise ValueError(
            f"Expected one prediction per ground-truth box, got {pred.shape[0]} and {true.shape[0]}"
        )

    # Paired IoU in O(N). The full pairwise matrix also compared every
    # prediction with other images' boxes, inflating FP and allowing a
    # negative FN. Inverted/degenerate boxes get zero area.
    top_left = torch.maximum(pred[:, :2], true[:, :2])
    bottom_right = torch.minimum(pred[:, 2:], true[:, 2:])
    inter = (bottom_right - top_left).clamp(min=0).prod(dim=1)
    area_pred = (pred[:, 2:] - pred[:, :2]).clamp(min=0).prod(dim=1)
    area_true = (true[:, 2:] - true[:, :2]).clamp(min=0).prod(dim=1)
    union = area_pred + area_true - inter
    ious = inter / union.clamp(min=1e-12)

    n = pred.shape[0]
    tp = int((ious > iou_threshold).sum().item())
    fp = n - tp  # predictions that missed their own target
    fn = n - tp  # targets without a matching prediction

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return precision, recall, f1

def calculate_average_precision(pred_boxes, true_boxes, iou_threshold=0.5):
    """Precision-recall curve and average precision from paired IoUs.

    Prediction ``i`` is scored only against ground truth ``i``. Each pair's
    label is ``IoU > iou_threshold`` and its score is the IoU itself.

    NOTE(review): because the label is derived from the same IoU that is used
    as the score, the ranking is always perfect. AP is therefore 1.0 whenever
    any positive exists, and not meaningful (sklearn warns) when none does.
    A real AP needs an independent confidence score per prediction.

    Args:
        pred_boxes: ``(N, 4)`` predicted ``[x_min, y_min, x_max, y_max]`` boxes.
        true_boxes: ``(N, 4)`` ground-truth boxes, paired by index.
        iou_threshold: IoU strictly above which a pair is labelled positive.

    Returns:
        ``(precision, recall, average_precision)`` from sklearn.

    Raises:
        ValueError: if the inputs hold different numbers of boxes.
    """
    if pred_boxes.shape[0] != true_boxes.shape[0]:
        raise ValueError(
            f"Expected one prediction per ground-truth box, got {pred_boxes.shape[0]} and {true_boxes.shape[0]}"
        )

    # Diagonal = IoU of each prediction with its own target; flattening the
    # whole pairwise matrix also scored predictions against other images' boxes.
    ious = torchvision.ops.box_iou(pred_boxes, true_boxes).diagonal().detach().cpu().numpy()

    labels = (ious > iou_threshold).astype(int)

    precision, recall, _ = precision_recall_curve(labels, ious)
    average_precision = average_precision_score(labels, ious)

    return precision, recall, average_precision

def plot_precision_recall_curve(precision, recall, average_precision):
    """Draw recall (x) vs precision (y) in a new figure, with AP in the legend, and show it."""
    plt.figure()
    curve_label = f'Average Precision = {average_precision:.2f}'
    plt.plot(recall, precision, label=curve_label)

    axis_text = (
        (plt.xlabel, 'Recall'),
        (plt.ylabel, 'Precision'),
        (plt.title, 'Precision-Recall Curve'),
    )
    for apply_text, text in axis_text:
        apply_text(text)

    plt.legend()
    plt.show()

def compute_accuracy(predictions, targets, threshold=0.5):
    """
    Computes the fraction of predictions whose IoU with their own target exceeds a threshold.

    Prediction ``i`` is compared only with target ``i`` (paired IoU).

    Args:
        predictions: Predicted bounding boxes, tensor of shape [N, 4] or [4]
            in [x_min, y_min, x_max, y_max] order.
        targets: Ground truth bounding boxes, same shape as ``predictions``.
        threshold: IoU strictly above which a prediction counts as correct.

    Returns:
        accuracy: correct / N, or 0.0 when there are no boxes.

    Raises:
        ValueError: if the inputs hold different numbers of boxes.
    """
    # Promote a single box to a batch of one.
    if predictions.ndim == 1:
        predictions = predictions.unsqueeze(0)
    if targets.ndim == 1:
        targets = targets.unsqueeze(0)
    if predictions.shape[0] != targets.shape[0]:
        raise ValueError(
            f"Expected one prediction per target, got {predictions.shape[0]} and {targets.shape[0]}"
        )

    pred = predictions.float()
    true = targets.float()

    # Paired IoU (previously a tuple was built here instead of IoUs, so the
    # threshold comparison raised TypeError). Inverted boxes get zero area.
    top_left = torch.maximum(pred[:, :2], true[:, :2])
    bottom_right = torch.minimum(pred[:, 2:], true[:, 2:])
    inter = (bottom_right - top_left).clamp(min=0).prod(dim=1)
    area_pred = (pred[:, 2:] - pred[:, :2]).clamp(min=0).prod(dim=1)
    area_true = (true[:, 2:] - true[:, :2]).clamp(min=0).prod(dim=1)
    ious = inter / (area_pred + area_true - inter).clamp(min=1e-12)

    correct_predictions = (ious > threshold).sum().item()
    total_predictions = predictions.size(0)

    return correct_predictions / total_predictions if total_predictions > 0 else 0.0

def plot_metrics_vs_iou_thresholds(pred_boxes, true_boxes, thresholds=None):
    """Plot precision, recall, F1, "accuracy" and AP against the IoU threshold.

    Args:
        pred_boxes: ``(N, 4)`` predicted boxes.
        true_boxes: ``(N, 4)`` ground-truth boxes.
        thresholds: IoU thresholds to sweep; defaults to
            ``np.linspace(0.5, 0.95, 10)`` (built per call rather than as a
            shared mutable default).
    """
    if thresholds is None:
        thresholds = np.linspace(0.5, 0.95, 10)

    # calculate_average_precision is called without the loop's threshold, so
    # its result is the same on every iteration: compute it once.
    avg_precision = calculate_average_precision(pred_boxes, true_boxes)[2]

    precisions, recalls, f1_scores, accuracies = [], [], [], []
    for iou_thresh in thresholds:
        precision, recall, f1 = calculate_metrics(pred_boxes, true_boxes, iou_thresh)

        # NOTE: "accuracy" here is the mean of precision and recall, not a
        # classification accuracy.
        accuracies.append((precision + recall) / 2)
        precisions.append(precision)
        recalls.append(recall)
        f1_scores.append(f1)

    average_precisions = [avg_precision] * len(precisions)

    plt.figure(figsize=(12, 8))
    plt.plot(thresholds, precisions, label='Precision')
    plt.plot(thresholds, recalls, label='Recall')
    plt.plot(thresholds, f1_scores, label='F1-Score')
    plt.plot(thresholds, accuracies, label='Accuracy')
    plt.plot(thresholds, average_precisions, label='Average Precision')

    plt.xlabel('IoU Threshold')
    plt.ylabel('Score')
    plt.title('Metrics vs IoU Threshold')
    plt.legend()
    plt.show()

In [None]:
def train_model(model, train_loader, val_loader, optimizer, device, epochs, checkpoint_path,
                scheduler=None, resume=False):
    """
    Trains a localization model with Complete-IoU loss and checkpoints the best epoch.

    Each epoch makes one pass over ``train_loader`` in train mode, then runs
    ``evaluate_model`` on ``val_loader``. It prints one summary line per epoch.
    When validation loss improves, a checkpoint is written to
    ``checkpoint_path``. If validation IoU also exceeds 0.75, a copy is written
    to ``/kaggle/working/clnn_<acc>.pth``.

    Args:
        model: The model to train.
        train_loader: DataLoader for training data (dicts with 'image' and 'bbox').
        val_loader: DataLoader for validation data.
        optimizer: Optimizer for training.
        device: Device to train on (e.g., 'cuda' or 'cpu').
        epochs: Index one past the last epoch to run (training covers
            ``range(start_epoch, epochs)``).
        checkpoint_path: Path to save the best model checkpoint.
        scheduler: Stepped once per epoch with the validation loss (e.g.
            ReduceLROnPlateau). When None, a module-level ``scheduler`` is
            used if one exists; otherwise no LR scheduling happens.
        resume: When True, restore model/optimizer/epoch from the first
            existing file among ``checkpoint_path`` and two Kaggle snapshots.

    Returns:
        (best_model_state, epoch, train_losses, val_losses, train_accuracies, val_accuracies).
        ``best_model_state`` is a cloned copy of the best epoch's state_dict,
        or None if validation never improved. ``epoch`` is the 0-based index of
        the last epoch run, or ``start_epoch - 1`` if none ran. Accuracies are
        mean paired IoU.
    """
    start_epoch = 0
    best_val_loss = float('inf')
    best_model_state = None

    if scheduler is None:
        # Fall back to the notebook-level scheduler so callers that don't pass
        # one keep LR scheduling.
        scheduler = globals().get('scheduler')

    if resume:
        resume_candidates = [
            checkpoint_path,
            "/kaggle/input/ccln_0.75.pth/pytorch/default/1/clnn.pth",
            "/kaggle/input/ccln_0.75.pth/pytorch/default/1/clnn_0.75.pth",
        ]
        for c_path in resume_candidates:
            if os.path.exists(c_path):
                model, best_val_loss = load_checkpoint(model, c_path, device)
                checkpoint = torch.load(c_path, map_location=device, weights_only=True)
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch'] + 1
                break

    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    epoch = start_epoch - 1  # returned as-is if the loop body never runs
    for epoch in range(start_epoch, epochs):
        # Re-enter train mode every epoch: evaluate_model() switches to eval
        # mode, which would otherwise freeze BatchNorm statistics from epoch 2 on.
        model.train()
        running_loss = 0.0
        running_iou = 0.0

        for batch in train_loader:
            images = batch['image'].to(device)
            bboxes = batch['bbox'].to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = torchvision.ops.complete_box_iou_loss(outputs, bboxes, reduction='mean')
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Paired IoU (diagonal): prediction i vs its own target i.
            running_iou += torchvision.ops.box_iou(outputs.detach(), bboxes).diagonal().mean().item()

        num_batches = max(len(train_loader), 1)
        train_loss = running_loss / num_batches
        train_losses.append(train_loss)
        train_accuracies.append(running_iou / num_batches)

        val_loss, val_accuracy = evaluate_model(model, val_loader, device)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        print(f'Epoch [{epoch + 1}/{epochs}] | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Accuracy (IoU): {val_accuracy:.4f}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Clone: state_dict() tensors share storage with the live
            # parameters and would keep changing as training continues.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            save_checkpoint(model, optimizer, checkpoint_path, best_val_loss, epoch)
            if val_accuracy > 0.75:
                save_checkpoint(model, optimizer, f"/kaggle/working/clnn_{val_accuracy:.2f}.pth", best_val_loss, epoch)

        if scheduler is not None:
            scheduler.step(val_loss)

    print(f'Training completed. Best model saved with validation loss: {best_val_loss:.4f}')
    return best_model_state, epoch, train_losses, val_losses, train_accuracies, val_accuracies

def evaluate_model(model, data_loader, device):
    """
    Evaluates the model on a validation dataset and computes IoU-based accuracy and loss.

    The model runs in eval mode without autograd. Its previous train/eval mode
    is restored on return, so a training loop that calls this between epochs
    keeps training in train mode.

    Args:
        model: The model to evaluate.
        data_loader: DataLoader yielding dicts with 'image' and 'bbox'.
        device: Device to evaluate on.

    Returns:
        val_loss (float): Mean over batches of the batch-mean Complete-IoU loss.
        accuracy (float): Mean over samples of the IoU between each prediction
            and its own ground-truth box.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    running_loss = 0.0
    num_batches = 0
    iou_sum = 0.0
    num_samples = 0

    try:
        with torch.no_grad():
            for batch in data_loader:
                images = batch['image'].to(device)
                bboxes = batch['bbox'].to(device)

                outputs = model(images)

                loss = torchvision.ops.complete_box_iou_loss(outputs, bboxes, reduction='mean')
                running_loss += loss.item()
                num_batches += 1

                # Diagonal = IoU of prediction i with its own target i; the
                # mean of the full pairwise matrix mixed in other images' boxes.
                iou_sum += torchvision.ops.box_iou(outputs, bboxes).diagonal().sum().item()
                num_samples += bboxes.shape[0]
    finally:
        model.train(was_training)

    if num_batches == 0 or num_samples == 0:
        raise ValueError("evaluate_model: data_loader produced no samples")

    val_loss = running_loss / num_batches
    val_accuracy = iou_sum / num_samples
    return val_loss, val_accuracy

def save_checkpoint(model, optimizer, checkpoint_path, best_loss, epoch):
    """Write model + optimizer state, best loss and epoch to ``checkpoint_path``.

    The file is a ``torch.save`` dict with keys ``model_state_dict``,
    ``optimizer_state_dict``, ``best_loss`` and ``epoch``, in that order.
    A confirmation line with ``best_loss`` is printed afterwards.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        checkpoint_path: Destination file.
        best_loss: Best validation loss so far.
        epoch: Epoch index recorded with the checkpoint.
    """
    payload = dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        best_loss=best_loss,
        epoch=epoch,
    )
    torch.save(payload, checkpoint_path)
    print(f"Model saved with Best Validation Loss: {best_loss:.4f}")

def load_checkpoint(model, checkpoint_path, device):
    """Restore ``model`` weights from a checkpoint written by ``save_checkpoint``.

    Tensors are mapped to ``device`` and loaded with ``weights_only=True``.
    Only ``model_state_dict`` and ``best_loss`` are read; the optimizer
    state is left to the caller.

    Args:
        model: Module to load the weights into (modified in place).
        checkpoint_path: File produced by ``save_checkpoint``.
        device: ``map_location`` for ``torch.load``.

    Returns:
        ``(model, best_loss)`` — the same module object and the stored best loss.
    """
    print(f"Loading model from: {checkpoint_path}")
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(ckpt['model_state_dict'])
    return model, ckpt['best_loss']
    
def plot_metrics(train_accuracies, val_accuracies, train_losses, val_losses, checkpoint_path, epoch):
    """
    Saves the training/validation accuracy and loss curves as two PNG files.

    Files: ``<checkpoint_path>/accuracy_plot_epoch_<epoch>.png`` and
    ``<checkpoint_path>/loss_plot_epoch_<epoch>.png``. Here ``checkpoint_path``
    is used as a directory.

    The x axis is derived from each list's length. ``epoch`` only names the
    files, so a 0-based epoch index or a resumed run's shorter history cannot
    cause a length mismatch in ``plt.plot``.

    Args:
        train_accuracies: List of training accuracies (one per epoch).
        val_accuracies: List of validation accuracies.
        train_losses: List of training losses.
        val_losses: List of validation losses.
        checkpoint_path: Directory to save the plots in.
        epoch: Value embedded in the file names.
    """
    _save_train_val_plot(
        train_accuracies, val_accuracies,
        title='Training and Validation Accuracy', ylabel='Accuracy',
        out_path=os.path.join(checkpoint_path, f'accuracy_plot_epoch_{epoch}.png'),
    )
    _save_train_val_plot(
        train_losses, val_losses,
        title='Training and Validation Loss', ylabel='Loss',
        out_path=os.path.join(checkpoint_path, f'loss_plot_epoch_{epoch}.png'),
    )


def _save_train_val_plot(train_values, val_values, title, ylabel, out_path):
    """Plot one train/validation curve pair against 1-based epoch numbers and save it to ``out_path``."""
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, len(train_values) + 1), train_values, label=f'Train {ylabel}', color='blue')
    plt.plot(range(1, len(val_values) + 1), val_values, label=f'Validation {ylabel}', color='orange')
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid()
    plt.savefig(out_path)
    plt.close()

def plot_top_bottom_images(model, test_loader, device, checkpoint_path, num_images=5):
    """
    Saves a 2-row grid of the best and worst test predictions by IoU loss.

    Per-sample IoU loss is ``1 - IoU(prediction_i, target_i)``. The top row
    shows the ``num_images`` lowest losses and the bottom row the highest.
    Each panel shows the true box (red) and the predicted box (green). The
    figure is written to ``<checkpoint_path>/top_bottom_images.png``, with
    ``checkpoint_path`` used as a directory.

    Args:
        model: The trained model (set to eval mode here).
        test_loader: DataLoader yielding dicts with 'image' and 'bbox'.
        device: Device to use for inference.
        checkpoint_path: Directory to save the plot in.
        num_images: Panels per row (capped at the number of samples).
    """
    model.eval()
    iou_losses, images, true_boxes, pred_boxes = [], [], [], []

    with torch.no_grad():
        for batch in test_loader:
            batch_images = batch['image'].to(device)
            batch_bboxes = batch['bbox'].to(device)
            outputs = model(batch_images)

            # One loss per sample (diagonal of the pairwise IoU matrix); a
            # batch mean is a 0-d value that cannot be extended into a list.
            batch_losses = 1 - torchvision.ops.box_iou(outputs, batch_bboxes).diagonal()
            iou_losses.extend(batch_losses.cpu().tolist())
            images.extend(batch_images.cpu())
            true_boxes.extend(batch_bboxes.cpu())
            # Keep every prediction so a dataset-wide index selects the right box.
            pred_boxes.extend(outputs.cpu())

    if not iou_losses:
        print("plot_top_bottom_images: test_loader produced no samples; nothing to plot.")
        return

    order = np.argsort(iou_losses)
    k = min(num_images, len(order))
    fig, axs = plt.subplots(2, k, figsize=(3 * k, 6), squeeze=False)

    rows = (("Top", order[:k]), ("Bottom", order[::-1][:k]))
    for row, (label, indices) in enumerate(rows):
        for col, index in enumerate(indices):
            ax = axs[row, col]
            # Draw into the grid axes (display_image_with_boxes opens its own figure).
            _draw_prediction_on_axis(ax, images[index], true_boxes[index], pred_boxes[index])
            ax.set_title(f"{label} {col + 1}, IoU Loss: {iou_losses[index]:.4f}")

    fig.savefig(os.path.join(checkpoint_path, 'top_bottom_images.png'))
    plt.close(fig)


def _draw_prediction_on_axis(ax, image, true_box, pred_box):
    """Draw a ``(C, H, W)`` CPU image on ``ax`` with its true (red) and predicted (green) normalized boxes."""
    pixels = image.permute(1, 2, 0).numpy()
    height, width = pixels.shape[:2]
    ax.imshow(pixels)
    for box, color in ((true_box, 'red'), (pred_box, 'green')):
        x_min, y_min, x_max, y_max = (float(v) for v in box)
        ax.add_patch(plt.Rectangle(
            (x_min * width, y_min * height),
            (x_max - x_min) * width,
            (y_max - y_min) * height,
            edgecolor=color,
            facecolor='none',
            linewidth=2,
        ))
    ax.axis('off')

In [None]:
# ---- Hyperparameters and settings ----

# Input geometry
img_channel = 3  # RGB
img_height, img_width = 256, 256

# Optimisation
learning_rate = 1e-3
batch_size = 16
num_epochs = 512

# Paths (Kaggle)
model_save_path = "/kaggle/working/clnn.pth"
data_path = '/kaggle/input/finale-dataset/final_dataset/localization'
data_path = '/kaggle/input/finale-dataset/final_dataset/localization'

In [None]:

# Check if a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the dataset (transforms.Resize expects (height, width)).
dataset = LocalizationDataset(data_path, target_size=(img_height, img_width), augment=False)

# 70% train / 10% validation / remainder test
train_size = int(0.7 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

# Seeded split so the partition is identical across notebook sessions;
# otherwise a checkpoint reloaded later may be evaluated on a "test" split
# that overlaps the data it was trained on.
split_seed = 42
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Create DataLoaders for batching
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4
)

val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# Initialize the model
model = CCLN().to(device)

# Shape sanity check on random input. Run in eval mode without autograd: a
# train-mode forward pass would fold the random batch into BatchNorm's
# running statistics before training even starts.
random_input = torch.randn(3, 3, img_height, img_width).to(device)
model.eval()
with torch.no_grad():
    output = model(random_input)
model.train()

print("Input shape:", random_input.shape)
print("Output shape:", output.shape)

# Define the optimizer and scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.9)

In [None]:
def denormalize_boxes(boxes, image_width, image_height):
    """Scale normalized ``[x_min, y_min, x_max, y_max]`` boxes to pixel units.

    Accepts any tensor whose last dimension holds the box in its first four
    entries: ``(4,)``, ``(N, 4)``, ``(B, N, 4)``, and so on. Extra trailing
    columns (e.g. a score) are left untouched.

    NOTE: ``boxes`` is modified in place and the same tensor is returned;
    pass ``boxes.clone()`` if the input must be preserved.
    """
    boxes[..., 0:4:2] *= image_width   # x_min, x_max
    boxes[..., 1:4:2] *= image_height  # y_min, y_max
    return boxes

def normalize_boxes(boxes, image_width, image_height):
    """Normalize pixel-unit ``[x_min, y_min, x_max, y_max]`` boxes to [0, 1].

    Inverse of ``denormalize_boxes``. Accepts any tensor whose last dimension
    holds the box in its first four entries: ``(4,)``, ``(N, 4)``,
    ``(B, N, 4)``, and so on. Extra trailing columns are left untouched.
    Requires a floating-point tensor.

    NOTE: ``boxes`` is modified in place and the same tensor is returned;
    pass ``boxes.clone()`` if the input must be preserved.
    """
    boxes[..., 0:4:2] /= image_width   # x_min, x_max
    boxes[..., 1:4:2] /= image_height  # y_min, y_max
    return boxes

# Spot-check one random sample from each split: predict its box and show it
# next to the ground truth.
loaders = [train_loader, val_loader, test_loader]
loader_names = ['Train', 'Validation', 'Test']

# Inference must use BatchNorm running statistics. In train mode a single
# image would be normalized with (and update the running stats from) its own
# batch statistics.
model.eval()

for loader, name in zip(loaders, loader_names):
    idx = random.randint(0, len(loader.dataset) - 1)
    sample = loader.dataset[idx]
    image = sample['image'].unsqueeze(0).to(device)

    with torch.no_grad():
        pred_boxes = model(image).cpu()

    # Copy: sample['bbox'] is the tensor stored inside the dataset, so any
    # in-place edit would corrupt it.
    true_boxes = sample['bbox'].unsqueeze(0).clone()

    # IoU is unchanged by scaling x by W and y by H, so normalized boxes give
    # the same value as pixel-space boxes; no de/re-normalization needed.
    loss = 1 - torchvision.ops.box_iou(pred_boxes, true_boxes)

    print(f"--- {name} sample #{idx} ---")
    display_image_with_boxes(sample['image'], true_boxes, pred_boxes, loss)


In [None]:

import time
import os
import torch
# Restore a previously trained CCLN snapshot when it is attached as a Kaggle input.
snapshot_path = "/kaggle/input/ccln/pytorch/default/1/ccln.pth"
if os.path.exists(snapshot_path):
    restored = load_checkpoint(model, snapshot_path, device)
    model = restored[0]  # restored[1] (best validation loss) is not needed here
    model.eval()

def calculate_model_complexity(model):
    """Report the trainable-parameter count and serialized weight size of ``model``.

    The size is the byte length of ``torch.save(model.state_dict())``, i.e.
    roughly what a weights-only checkpoint occupies on disk. It is measured in
    an in-memory buffer, so nothing is written to (or left behind in) the
    working directory.

    Returns:
        ``(total_params, model_size_mb)`` — number of parameters with
        ``requires_grad=True``, and the serialized size in MiB.
    """
    import io  # local import: io is not imported at module level

    # 1. Number of parameters (trainable only)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # 2. Serialized size of the state_dict (MB)
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    model_size_mb = buffer.getbuffer().nbytes / (1024 * 1024)

    print("\n--- Model Complexity Metrics ---")
    print(f"Number of parameters: {total_params:,}")
    print(f"Model size (MB): {model_size_mb:.2f} MB")

    return total_params, model_size_mb


def calculate_inference_speed(model, data_loader, device, num_warmup_batches=5,
                              show_predictions=True):
    """Measure inference throughput (FPS) and mean per-image latency.

    The first ``num_warmup_batches`` batches are run untimed so one-off
    start-up costs do not distort the measurement. The whole loader is then
    iterated again, and only the forward pass of each batch is timed (the
    host-to-device copy is excluded). On CUDA, timing uses CUDA events plus a
    synchronize. On CPU, it uses the high-resolution ``time.perf_counter``.

    Args:
        model: module whose forward pass is timed; it is put in eval mode.
        data_loader: iterable of batches. Each batch is a dict with an 'image'
            tensor (batch dimension first) and, when ``show_predictions`` is
            True, a 'bbox' tensor.
        device: ``torch.device`` or device string such as 'cuda' or 'cpu'.
        num_warmup_batches: number of untimed warm-up batches.
        show_predictions: when True (the default), every warm-up image is shown
            with its ground-truth and predicted box.

    Returns:
        ``(fps, latency_ms)``.

    Raises:
        ValueError: if the loader yields no images or no time was measured.
    """
    device = torch.device(device)  # accept strings as well as torch.device
    model.eval()

    # 1. Warm-up runs (not timed).
    print("Starting warm-up...")
    with torch.no_grad():
        for batch_idx, batch in enumerate(data_loader):
            if batch_idx >= num_warmup_batches:
                break
            images = batch['image'].to(device)
            outputs = model(images)
            if show_predictions:
                labels = batch['bbox'].to(device)
                for sample_idx in range(images.size(0)):
                    true_box = labels[sample_idx].unsqueeze(0)  # Shape [1, 4]
                    pred_box = outputs[sample_idx].unsqueeze(0)  # Shape [1, 4]
                    loss = 1 - torchvision.ops.box_iou(pred_box, true_box)
                    display_image_with_boxes(images[sample_idx].cpu(), true_box.cpu(),
                                             pred_box.cpu(), loss)
            if device.type == 'cuda':
                torch.cuda.synchronize()

    # 2. Main timing loop.
    total_time = 0.0
    total_images = 0

    use_cuda_events = device.type == 'cuda'
    if use_cuda_events:
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)

    print("Starting main timing...")
    with torch.no_grad():
        for batch in data_loader:
            images = batch['image'].to(device)

            # --- Start timing (forward pass only) ---
            if use_cuda_events:
                starter.record()
                model(images)
                ender.record()
                torch.cuda.synchronize()  # ensure ender has completed
                elapsed = starter.elapsed_time(ender) / 1000.0  # ms -> s
            else:
                start = time.perf_counter()
                model(images)
                elapsed = time.perf_counter() - start
            # --- End timing ---

            total_time += elapsed
            total_images += images.size(0)

    if total_images == 0:
        raise ValueError("calculate_inference_speed: data_loader yielded no images")
    if total_time <= 0:
        raise ValueError("calculate_inference_speed: no measurable inference time")

    # 3. Metrics.
    fps = total_images / total_time
    latency_ms = (total_time / total_images) * 1000

    print("\n--- Inference Speed Metrics ---")
    print(f"Total images processed: {total_images}")
    print(f"Total inference time: {total_time:.4f} seconds")
    print(f"FPS (Frames Per Second): {fps:.2f}")
    print(f"Latency (ms per image): {latency_ms:.2f} ms")

    return fps, latency_ms

# --- Usage Example ---

# Step 1: parameter count and serialized model size.
num_params, model_size = calculate_model_complexity(model)

# Step 2: throughput and latency on the test split.
fps, latency = calculate_inference_speed(model=model, data_loader=test_loader, device=device)

# Step 3: the full evaluation (e.g. evaluate_and_plot) can follow.


In [None]:

# Assuming the function returns train_losses, val_losses, train_accuracies, val_accuracies
#best_model_state, best_epoch, train_losses, val_losses, train_accuracies, val_accuracies = train_model(model, train_loader, val_loader, optimizer, device, num_epochs, model_save_path)


In [None]:
# Per-epoch history appended to by `train`: mean loss and mean IoU per epoch.
train_losses, validation_losses = [], []
train_accuracies, validation_accuracies = [], []

# Checkpoint destination used by `train`.
model_path = model_save_path
best_loss = float('inf')


def _run_train_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader`` using the Complete-IoU loss.

    Batches are moved to the module-level ``device``.

    Returns:
        ``(mean loss, mean IoU)`` over batches, as 0-dim tensors.
    """
    model.train()
    total_loss = 0
    total_iou = 0

    for batch in loader:
        images = batch['image'].to(device)
        bboxes = batch['bbox'].to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = torchvision.ops.complete_box_iou_loss(outputs, bboxes, reduction='mean')
        loss.backward()
        optimizer.step()

        total_loss += loss.detach()
        # IoU is only a metric, so compute it on detached outputs and record no
        # autograd graph for it. `.diag()` keeps each prediction's IoU with its
        # own ground-truth box.
        total_iou += torchvision.ops.box_iou(outputs.detach(), bboxes).diag().mean()

    return total_loss / len(loader), total_iou / len(loader)


def _run_validation_epoch(model, loader):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        ``(mean Complete-IoU loss, mean IoU)`` over batches, as 0-dim tensors.
    """
    model.eval()
    val_loss = 0
    val_iou = 0

    with torch.no_grad():
        for batch in loader:
            images = batch['image'].to(device)
            bboxes = batch['bbox'].to(device)

            outputs = model(images)
            val_loss += torchvision.ops.complete_box_iou_loss(outputs, bboxes, reduction='mean')
            val_iou += torchvision.ops.box_iou(outputs, bboxes).diag().mean()

    return val_loss / len(loader), val_iou / len(loader)


def _show_validation_predictions(model, loader, max_images=5):
    """Display up to ``max_images`` predictions from ``loader``, in loader order."""
    if max_images <= 0:
        return
    shown = 0
    with torch.no_grad():  # display only; no gradients needed
        for batch in loader:
            images = batch['image']
            bboxes = batch['bbox']
            outputs = model(images.to(device)).cpu()

            for i in range(images.size(0)):
                true_box = bboxes[i].unsqueeze(0)
                pred_box = outputs[i].unsqueeze(0)
                loss = 1 - torchvision.ops.box_iou(pred_box, true_box)
                display_image_with_boxes(images[i], true_box, pred_box, loss)
                shown += 1
                if shown >= max_images:
                    return


def train(model, train_loader, validation_loader, optimizer, scheduler,
          num_epochs=16, display_every=10):
    """Train ``model`` for bounding-box regression and keep the best weights.

    Each epoch performs the following steps:

    - Trains on ``train_loader`` with the Complete-IoU loss.
    - Evaluates on ``validation_loader``.
    - Steps ``scheduler`` with the mean validation loss.
    - Appends the epoch's mean loss and mean IoU to the module-level
      ``train_losses`` / ``train_accuracies`` / ``validation_losses`` /
      ``validation_accuracies`` lists.

    Whenever the validation loss decreases OR the validation IoU increases,
    a checkpoint dict (epoch, model/optimizer state, loss, IoU) is saved to
    the module-level ``model_path``. Every ``display_every`` epochs, five
    validation predictions are shown.

    Args:
        model, optimizer: the network and its optimizer.
        train_loader, validation_loader: iterables of dicts with 'image' and
            'bbox' tensors.
        scheduler: stepped with the validation loss. The call
            ``scheduler.step(metric)`` matches e.g. ``ReduceLROnPlateau``.
        num_epochs: number of epochs.
        display_every: epoch interval for showing predictions.

    Returns:
        A deep copy of the ``state_dict`` from the best epoch, or ``None`` if
        no epoch ran.
    """
    import copy

    best_model_state = None
    best_val_loss = float('inf')
    best_val_iou = 0.0

    for epoch in range(num_epochs):
        lr = optimizer.param_groups[0]['lr']

        # TRAIN
        avg_train_loss, avg_train_iou = _run_train_epoch(model, train_loader, optimizer)
        train_losses.append(avg_train_loss.item())
        train_accuracies.append(avg_train_iou.item())

        # VALIDATION
        avg_val_loss, avg_val_iou = _run_validation_epoch(model, validation_loader)
        validation_losses.append(avg_val_loss.item())
        validation_accuracies.append(avg_val_iou.item())

        scheduler.step(avg_val_loss)

        print(f"Epoch [{epoch+1}/{num_epochs}] | LR {lr:.2e} "
              f"| Train Loss {avg_train_loss:.4f} | Val Loss {avg_val_loss:.4f} "
              f"| Train IoU {avg_train_iou:.4f} | Val IoU {avg_val_iou:.4f}")

        # SAVE BEST MODEL
        is_best = (avg_val_loss < best_val_loss) or (avg_val_iou > best_val_iou)
        if is_best:
            print("  ✔ New best model! Saving...")
            # state_dict() returns references to the live tensors. Later
            # optimizer steps would overwrite them, so keep a real copy.
            best_model_state = copy.deepcopy(model.state_dict())
            best_val_loss = avg_val_loss
            best_val_iou = avg_val_iou

            torch.save(
                {
                    "epoch": epoch + 1,
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "loss": avg_val_loss,
                    "iou": avg_val_iou,
                },
                model_path  # global save path
            )

        # DISPLAY VALIDATION PREDICTIONS
        if (epoch + 1) % display_every == 0:
            print("  🖼 Displaying 5 random validation predictions...")
            _show_validation_predictions(model, validation_loader, max_images=5)

    return best_model_state


# --- Execution ---
# `model_save_path` (read into `model_path` above), `num_epochs`, `optimizer`
# and `scheduler` must be defined before this call, e.g.
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.9)
best_model_state = train(model, train_loader, val_loader, optimizer, scheduler, num_epochs=num_epochs)

epochs = range(1, len(train_losses) + 1)


def _plot_history(epoch_numbers, curves, ylabel, title):
    """Plot one or more per-epoch curves on a new figure.

    Args:
        epoch_numbers: x values (epoch indices).
        curves: iterable of ``(values, label, color)`` tuples.
        ylabel: y-axis label.
        title: figure title.
    """
    plt.figure(figsize=(12, 6))
    for values, label, color in curves:
        plt.plot(epoch_numbers, values, label=label, color=color)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid()
    plt.show()


# Loss vs epoch
_plot_history(
    epochs,
    [(train_losses, 'Training Loss', 'blue'),
     (validation_losses, 'Validation Loss', 'orange')],
    ylabel='Loss',
    title='Training and Validation Loss',
)

# `train` stores the mean IoU of each epoch in the *_accuracies lists, so the
# curves are labelled as IoU rather than accuracy.
_plot_history(
    epochs,
    [(train_accuracies, 'Training IoU', 'green'),
     (validation_accuracies, 'Validation IoU', 'red')],
    ylabel='Mean IoU',
    title='Training and Validation Mean IoU',
)


In [None]:
# Restore the saved weights and evaluate them on the test split.
snapshot_path = "/kaggle/input/ccln/pytorch/default/1/ccln.pth"
loaded_model, _ = load_checkpoint(model, snapshot_path, device)
# Alternative checkpoint:
#loaded_model, _ = load_checkpoint(model, "/kaggle/input/ccln_0.75.pth/pytorch/default/1/clnn_0.75.pth", device)
model = loaded_model
model.eval()

test_loss, test_iou = evaluate_model(model, test_loader, device)
print(f"Test Loss: {test_loss:.4f}, Test IoU: {test_iou:.4f}")



In [None]:
# Visual spot-check: show up to `num_samples` random test images with their
# ground-truth and predicted boxes.
model.eval()

num_samples = 5
# random.sample raises ValueError when asked for more items than exist, so
# cap the request at the dataset size.
idxs = random.sample(range(len(test_loader.dataset)),
                     min(num_samples, len(test_loader.dataset)))

for idx in idxs:
    sample = test_loader.dataset[idx]
    image = sample['image'].unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():  # inference only
        pred_boxes = model(image).cpu()  # back to CPU to match sample['bbox']

    # IoU loss of the prediction against this sample's ground truth.
    loss = 1 - torchvision.ops.box_iou(pred_boxes, sample['bbox'].unsqueeze(0))

    display_image_with_boxes(sample['image'], sample['bbox'].unsqueeze(0), pred_boxes, loss)


In [None]:
import torch
from torchvision.ops import box_iou
from sklearn.metrics import precision_recall_curve, average_precision_score
import matplotlib.pyplot as plt
import numpy as np
import random
import torchvision.transforms.functional as TF

# Metric calculation using torch's box_iou
def calculate_metrics(pred_boxes, true_boxes, iou_threshold=0.75):
    """Compute precision, recall, F1 and accuracy at a single IoU threshold.

    Predicted and ground-truth boxes are matched one-to-one, greedily in order
    of decreasing IoU. A match is a true positive when its IoU is strictly
    greater than ``iou_threshold``. Unmatched predictions are false positives
    and unmatched ground-truth boxes are false negatives. As a result
    ``tp <= min(len(pred_boxes), len(true_boxes))`` and fp/fn are never negative.

    Args:
        pred_boxes: (N, 4) tensor of boxes in (x_min, y_min, x_max, y_max) order.
        true_boxes: (M, 4) tensor in the same format.
        iou_threshold: IoU above which a matched pair counts as correct.

    Returns:
        ``(precision, recall, f1, accuracy)``, where
        ``accuracy = tp / (tp + fp + fn)`` (1.0 only when every box is matched).
    """
    # Clamp copies to [0, 1] so the caller's tensors are not modified.
    pred_boxes = validate_boxes(pred_boxes.clone())
    true_boxes = validate_boxes(true_boxes.clone())

    # Pairwise (N, M) IoU matrix.
    ious = torchvision.ops.box_iou(pred_boxes, true_boxes)

    tp = _count_one_to_one_matches(ious, iou_threshold)
    fp = pred_boxes.size(0) - tp  # predictions without a matching true box
    fn = true_boxes.size(0) - tp  # true boxes without a matching prediction

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    # Denominator is the union of predictions and ground truth; dividing by
    # N + M would cap a perfect single-box match at 0.5.
    accuracy = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0

    return precision, recall, f1, accuracy


def _count_one_to_one_matches(ious, iou_threshold):
    """Greedily pair rows/columns of ``ious`` (highest IoU first) above the threshold; return the pair count."""
    rows, cols = torch.nonzero(ious > iou_threshold, as_tuple=True)
    if rows.numel() == 0:
        return 0
    order = torch.argsort(ious[rows, cols], descending=True)
    used_rows, used_cols = set(), set()
    for r, c in zip(rows[order].tolist(), cols[order].tolist()):
        if r not in used_rows and c not in used_cols:
            used_rows.add(r)
            used_cols.add(c)
    return len(used_rows)

# Average Precision calculation with dynamic IoU
def calculate_average_precision(pred_boxes, true_boxes, iou_thresholds=None):
    """Return one average-precision value per IoU threshold.

    Every predicted box is compared with every ground-truth box. The flattened
    IoU values serve as scores, and a pair is labelled positive when its IoU
    exceeds the threshold.

    NOTE(review): the labels are derived by thresholding the same IoU values
    used as scores, so every positive outranks every negative. This makes the
    AP trivially perfect whenever any positive exists. Confirm this is the
    intended metric.

    Args:
        pred_boxes: (N, 4) tensor of predicted boxes.
        true_boxes: (M, 4) tensor of ground-truth boxes.
        iou_thresholds: iterable of thresholds; defaults to
            ``np.arange(0.0, 1.1, 0.1)`` (0.0, 0.1, ..., 1.0).

    Returns:
        list of AP values, in the order of ``iou_thresholds``.
    """
    if iou_thresholds is None:
        iou_thresholds = np.arange(0.0, 1.1, 0.1)

    # Clamp copies so the caller's tensors are not modified.
    pred_boxes = validate_boxes(pred_boxes.clone())
    true_boxes = validate_boxes(true_boxes.clone())

    ious = torchvision.ops.box_iou(pred_boxes, true_boxes).flatten().cpu().numpy()

    return [average_precision_score((ious > iou_thresh).astype(int), ious)
            for iou_thresh in iou_thresholds]

# Plot Precision-Recall curve
def plot_precision_recall_curve(precision, recall, average_precision):
    """Draw one precision-recall curve on a new figure, labelled with its AP."""
    plt.figure()
    curve_label = f'AP = {average_precision:.2f}'
    plt.plot(recall, precision, label=curve_label)
    for set_text, text in ((plt.xlabel, 'Recall'),
                           (plt.ylabel, 'Precision'),
                           (plt.title, 'Precision-Recall Curve')):
        set_text(text)
    plt.legend()
    plt.show()

# Plot metrics (Precision, Recall, F1-score, Accuracy, etc.) vs IoU thresholds
def plot_metrics_vs_iou_thresholds(pred_boxes, true_boxes, thresholds=None):
    """Plot precision, recall, F1, accuracy and average precision vs. IoU threshold.

    Args:
        pred_boxes: predicted boxes, passed to ``calculate_metrics`` and
            ``calculate_average_precision``.
        true_boxes: ground-truth boxes, passed to the same helpers.
        thresholds: IoU thresholds for the x-axis; defaults to
            ``np.linspace(0.5, 0.95, 10)``.

    NOTE(review): both helpers compare every predicted box with every
    ground-truth box. When the inputs are concatenated across a dataset, this
    includes pairs from different images. Confirm whether a paired
    (per-image) comparison is intended.
    """
    if thresholds is None:
        thresholds = np.linspace(0.5, 0.95, 10)

    # The plotted AP is the last entry (IoU 1.0) of calculate_average_precision's
    # own threshold grid. It does not depend on the loop variable below, so
    # compute it once instead of once per threshold.
    max_iou_avg_precision = calculate_average_precision(pred_boxes, true_boxes)[-1]

    precisions, recalls, f1_scores, accuracies = [], [], [], []
    for iou_thresh in thresholds:
        precision, recall, f1, accuracy = calculate_metrics(pred_boxes, true_boxes, iou_thresh)
        precisions.append(precision)
        recalls.append(recall)
        f1_scores.append(f1)
        accuracies.append(accuracy)
    avg_precisions = [max_iou_avg_precision] * len(precisions)

    plt.figure(figsize=(12, 8))
    plt.plot(thresholds, precisions, label='Precision')
    plt.plot(thresholds, recalls, label='Recall')
    plt.plot(thresholds, f1_scores, label='F1-Score')
    plt.plot(thresholds, accuracies, label='Accuracy')
    plt.plot(thresholds, avg_precisions, label='Average Precision')
    plt.xlabel('IoU Threshold')
    plt.ylabel('Score')
    plt.title('Metrics vs IoU Threshold')
    plt.legend()
    plt.show()

# Validate bounding boxes to ensure coordinates are within [0, 1]
def validate_boxes(boxes):
    """Return a copy of ``boxes`` with every coordinate clamped to [0, 1].

    The input tensor is not modified. Callers that share views of model
    outputs therefore do not see their values silently changed. Any shape is
    accepted, e.g. (N, 4) or (4,).
    """
    return torch.clamp(boxes, 0, 1)

# Evaluate and plot metrics
def evaluate_and_plot(test_loader, model, device, display=True):
    """Evaluate ``model`` on ``test_loader``, plot metric curves and print a summary.

    Each image's predicted box is scored against its own ground-truth box at
    IoU 0.75, and the per-image results are averaged for the printed summary.
    All boxes are then concatenated to plot metrics vs. IoU threshold and one
    precision-recall curve per threshold in 0.0, 0.1, ..., 1.0.

    Args:
        test_loader: iterable of dicts with 'image' and 'bbox' tensors.
        model: localization model; it is switched to eval mode.
        device: device the batches are moved to.
        display: when True (the default), show every test image with its boxes.
    """
    model.eval()

    all_precisions, all_recalls, all_f1s, all_accuracies = [], [], [], []
    all_true_boxes, all_pred_boxes = [], []

    with torch.no_grad():
        for batch in test_loader:
            images = batch['image'].to(device)
            true_boxes = batch['bbox'].to(device)
            predicted_boxes = model(images)

            for i in range(images.size(0)):
                true_box = true_boxes[i].unsqueeze(0)  # Shape [1, 4]
                pred_box = predicted_boxes[i].unsqueeze(0)  # Shape [1, 4]

                precision, recall, f1, accuracy = calculate_metrics(pred_box, true_box, iou_threshold=0.75)
                all_precisions.append(precision)
                all_recalls.append(recall)
                all_f1s.append(f1)
                all_accuracies.append(accuracy)

                all_true_boxes.append(true_box)
                all_pred_boxes.append(pred_box)

                if display:
                    display_image_with_boxes(images[i].cpu(), true_box.cpu(), pred_box.cpu())

    # Concatenate once and clamp to [0, 1], as the metric helpers do.
    all_pred = validate_boxes(torch.cat(all_pred_boxes))
    all_true = validate_boxes(torch.cat(all_true_boxes))

    # Metrics vs IoU thresholds
    plot_metrics_vs_iou_thresholds(all_pred, all_true)

    # Precision-recall curves, one per IoU threshold.
    pr_thresholds = np.arange(0.0, 1.1, 0.1)
    avg_precisions = calculate_average_precision(all_pred, all_true, pr_thresholds)
    # The IoU scores do not depend on the threshold, so compute them once.
    # NOTE(review): this is the full N x N matrix, including cross-image pairs.
    ious = box_iou(all_pred, all_true).flatten().cpu().numpy()
    for iou_thresh, avg_precision in zip(pr_thresholds, avg_precisions):
        labels = (ious > iou_thresh).astype(int)
        precision, recall, _ = precision_recall_curve(labels, ious)
        plot_precision_recall_curve(precision, recall, avg_precision)

    # Summary of per-image averages
    print("Average Metrics at IoU 0.75:")
    print(f"Precision: {np.mean(all_precisions):.2f}")
    print(f"Recall: {np.mean(all_recalls):.2f}")
    print(f"F1-Score: {np.mean(all_f1s):.2f}")
    print(f"Accuracy: {np.mean(all_accuracies):.2f}")

# Undo normalization for image display
def undo_normalization(image, mean, std):
    """Invert per-channel ``(x - mean) / std`` normalization, in place.

    Each channel ``c`` of ``image`` (paired with ``mean``/``std`` by position)
    is multiplied by ``std[c]``, then ``mean[c]`` is added. The same tensor
    object is returned.
    """
    channel_stats = zip(image, mean, std)
    for channel, channel_mean, channel_std in channel_stats:
        channel.mul_(channel_std).add_(channel_mean)
    return image

# Visualization helper
def display_image_with_boxes(image, true_box, pred_box, loss=None):
    """Show ``image`` with its ground-truth (red) and predicted (green) boxes.

    Args:
        image: (C, H, W) image tensor; it is permuted to HWC for matplotlib.
        true_box: box of shape (1, 4) or (4,) holding normalized
            (x_min, y_min, x_max, y_max). It is scaled by the image
            width/height before drawing.
        pred_box: predicted box in the same format as ``true_box``.
        loss: optional loss value (tensor or number) shown in the title.
            Several cells in this notebook pass it as a 4th positional argument.
    """
    print("Prediction:", pred_box)
    print("Label:", true_box)

    # The dataset's Normalize transform is commented out, so the image is shown
    # as-is. If normalization is re-enabled, undo it first, e.g.
    # image = undo_normalization(image.clone(), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # matplotlib needs an HWC numpy array on the CPU.
    image = image.detach().cpu().permute(1, 2, 0).numpy()
    height, width = image.shape[0], image.shape[1]

    plt.figure(figsize=(12, 12))
    plt.imshow(image)

    boxes = ((true_box, 'red', 'True Box'), (pred_box, 'green', 'Predicted Box'))
    for bbox, color, label in boxes:
        x_min, y_min, x_max, y_max = bbox.squeeze(0)
        # Scale normalized coordinates to pixels.
        x_min = int(x_min * width)
        y_min = int(y_min * height)
        x_max = int(x_max * width)
        y_max = int(y_max * height)

        plt.gca().add_patch(plt.Rectangle(
            (x_min, y_min),
            x_max - x_min,
            y_max - y_min,
            edgecolor=color,
            facecolor='none',
            linewidth=2,
            label=label
        ))

    if loss is not None:
        loss_value = float(torch.as_tensor(loss).float().mean())
        plt.title(f"Loss: {loss_value:.4f}")

    plt.legend()
    plt.axis('off')
    plt.show()

# Run the full test-set evaluation: per-image display, metric curves and summary.
evaluate_and_plot(test_loader=test_loader, model=model, device=device)
