In [None]:
import os
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
import ijson
from torchvision.ops import nms
import random


# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set random seed for reproducibility.
# NumPy has to be seeded as well: the anchor subsampling in bbox_generation
# draws from np.random.choice, which is independent of the `random` and
# torch generators.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Input image size (height, width)
ISIZE = (720, 1280)

# ImageNet statistics (for VGG16)
# imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
# imagenet_std  = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def normalize_tensor(img):
    """Scale an image whose values lie in [0, 255] down to the [0, 1] range.

    The division is element-wise, so the input layout (documented as
    (C, H, W)) is returned unchanged.
    """
    return img / 255.0

def unnormalize_tensor(img):
    """Turn a [0, 1] tensor image into a uint8 NumPy array of pixel values.

    Values are multiplied by 255, clamped to [0, 255], truncated to bytes and
    moved to the CPU before conversion.
    """
    scaled = img * 255.0
    clipped = scaled.clamp(0, 255)
    return clipped.byte().cpu().numpy()

# Global anchor parameters
# Aspect factors: the anchor builders in this file use height ~ sqrt(ratio)
# and width ~ sqrt(1 / ratio).
ratios = [0.5, 1, 2]
# Size multipliers applied to the feature-map stride. Together with `ratios`
# this gives len(ratios) * len(anchor_scales) = 15 anchors per grid position,
# which is the default `n_anchor` of EnhancedRPNWithROI.
anchor_scales = [4, 8, 16, 32, 64] # generate more anchors at each location


#################################################################
def extract_first_n_labels(json_file_path, n):
    """Stream the first ``n`` records of a JSON array file, keeping only 2-D boxes.

    The file is parsed incrementally with ijson, so it is never loaded whole.
    For every record the "labels" entries that carry a "box2d" key are reduced
    to their "category" and "box2d" fields.

    :param json_file_path: path to a JSON file whose top level is an array
    :param n: maximum number of records to return
    :return: list of dicts with keys "name", "timestamp" and "labels"
    """
    collected = []
    with open(json_file_path, 'rb') as handle:
        for position, record in enumerate(ijson.items(handle, 'item')):
            if position >= n:
                break

            boxes_only = []
            for entry in record.get("labels", []):
                if "box2d" not in entry:
                    continue
                boxes_only.append({
                    "category": entry.get("category"),
                    "box2d": entry.get("box2d"),
                })

            collected.append({
                "name": record.get("name"),
                "timestamp": record.get("timestamp"),
                "labels": boxes_only,
            })
    return collected

def standardize_filename(path_or_name):
    """Return the file name with its directory part and last extension removed."""
    stem, _extension = os.path.splitext(os.path.basename(path_or_name))
    return stem


def contrast_stretch(image, low_percentile=10, high_percentile=90):
    """
    Linearly rescale an image so the chosen percentiles map to 0 and 1.

    Values below the low percentile become 0 and values above the high
    percentile become 1. When the two percentiles are (almost) equal the
    input is returned untouched, which avoids producing an all-black image.

    :param image: PyTorch tensor of shape (C, H, W)
    :param low_percentile: Lower percentile for clipping
    :param high_percentile: Upper percentile for clipping
    :return: Contrast-stretched float32 tensor (built on the CPU), or the
             original tensor when the percentile range is degenerate
    """
    pixels = image.cpu().numpy()

    # Percentiles are taken over the whole array, all channels together.
    floor = np.percentile(pixels, low_percentile)
    ceiling = np.percentile(pixels, high_percentile)

    if not (ceiling - floor >= 1e-6):
        print("Warning: Min and max values are too close! Returning original image.")
        return image

    # The small epsilon keeps the division well defined.
    rescaled = (pixels - floor) / (ceiling - floor + 1e-8)
    rescaled = rescaled.clip(0, 1)

    return torch.tensor(rescaled, dtype=torch.float32)


######################################################################

# 1) "CustomDataset" with [x1, y1, x2, y2]
class CustomDataset(torch.utils.data.Dataset):
    """Image dataset with boxes in [x1, y1, x2, y2] and an on-disk tensor cache.

    Every image in ``image_dir`` (``.jpg`` / ``.png`` / ``.jpeg``, matched
    case-insensitively) is decoded once, resized to ``ISIZE`` if needed and
    cached as ``<stem>.pt`` inside ``pt_dir``. Annotations are looked up by the
    image's file stem.

    Args:
        image_dir: directory that holds the image files.
        labels: iterable of dicts with a "name" key and a "labels" list whose
            entries may carry "category" and "box2d" (keys x1, y1, x2, y2).
        pt_dir: directory for the cached tensors; created if missing.
    """

    def __init__(self, image_dir, labels, pt_dir='pt_files'):
        self.image_dir = image_dir
        self.pt_dir = pt_dir
        os.makedirs(self.pt_dir, exist_ok=True)
        self.image_files = sorted([
            os.path.join(image_dir, f)
            for f in os.listdir(image_dir)
            if f.lower().endswith(('.jpg', '.png', '.jpeg'))
        ])
        # Annotation lookup keyed by file stem (name without dir / extension).
        self.label_dict = {}
        for item in labels:
            key = os.path.splitext(os.path.basename(item["name"]))[0]
            self.label_dict[key] = item

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return one sample as a dict (image, boxes, labels, names, index, img_name)."""
        image_path = self.image_files[idx]
        key = os.path.splitext(os.path.basename(image_path))[0]

        # Derive the cache name from the stem so it is right for any extension
        # casing ("IMG.JPG" -> "IMG.pt") and never rewrites a ".jpg"/".png"
        # substring that happens to sit in the middle of the file name.
        pt_path = os.path.join(self.pt_dir, key + '.pt')

        if os.path.exists(pt_path):
            image_tensor = torch.load(pt_path)
        else:
            image = Image.open(image_path).convert('RGB')
            if image.size != (ISIZE[1], ISIZE[0]):  # PIL size is (width, height)
                image = image.resize((ISIZE[1], ISIZE[0]))
            image_tensor = transforms.PILToTensor()(image).float()
            torch.save(image_tensor, pt_path)

        # Contrast Stretching
        # image_tensor = contrast_stretch(image_tensor)

        # NOTE: no normalisation is applied here; a freshly decoded image keeps
        # its raw pixel values (0..255) as float32.

        # Parse ground truth
        matched = self.label_dict.get(key, None)
        box_list = []
        cat_list = []
        if matched is not None:
            for obj in matched.get("labels", []):
                if "box2d" in obj:
                    b2d = obj["box2d"]
                    # Boxes are read as [x1, y1, x2, y2]; if the JSON really
                    # stores y before x, swap them here.
                    box_list.append([
                        float(b2d["x1"]),
                        float(b2d["y1"]),
                        float(b2d["x2"]),
                        float(b2d["y2"]),
                    ])
                    cat_list.append(obj["category"])

        if box_list:
            boxes = torch.tensor(box_list, dtype=torch.float32)
            # Single foreground class: every box gets label 1.
            labels = torch.ones((len(box_list),), dtype=torch.int64)
        else:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)

        return {
            "image": image_tensor,  # (C,H,W) float tensor, not normalised here
            "boxes": boxes,         # shape (N,4) in [x1,y1,x2,y2]
            "labels": labels,       # shape (N,)
            "names": cat_list,      # category string per box
            "index": idx,
            "img_name": os.path.basename(image_path)
        }

# 2) Collate
def custom_collate_fn(batch):
    """Merge a list of dataset samples into one batch dict.

    Images are stacked into a single (B, C, H, W) tensor; every other field
    stays a per-sample Python list because the number of boxes varies.
    Nothing is moved to a device here - that is left to the training loop.
    """
    def column(field):
        return [sample[field] for sample in batch]

    stacked_images = torch.stack(column("image"), dim=0)
    return {
        "images": stacked_images,
        "boxes": column("boxes"),
        "labels": column("labels"),
        "names": column("names"),
        "indices": column("index"),
        "img_ids": column("img_name"),
    }

# 3) Utility for decoding predicted offsets -> [x1,y1,x2,y2]
def pred_bbox_to_xywh(bbox_offsets, anchors):
    """
    Decode predicted anchor offsets into corner boxes (NumPy, float32).

    bbox_offsets: (N,4) tensor of offsets ordered [dy, dx, dh, dw]
    anchors: (N,4) tensor of anchors in [x1,y1,x2,y2]
    return: (N,4) array of boxes in [x1,y1,x2,y2] (despite the function name)
    """
    anchor_arr = anchors.detach().cpu().numpy()
    offset_arr = bbox_offsets.detach().cpu().numpy()

    # Anchor size and centre.
    widths = anchor_arr[:, 2] - anchor_arr[:, 0]
    heights = anchor_arr[:, 3] - anchor_arr[:, 1]
    center_x = anchor_arr[:, 0] + 0.5 * widths
    center_y = anchor_arr[:, 1] + 0.5 * heights

    # Centre shifts are relative to the anchor size; sizes are log-scaled.
    new_cy = offset_arr[:, 0] * heights + center_y
    new_cx = offset_arr[:, 1] * widths + center_x
    new_h = np.exp(offset_arr[:, 2]) * heights
    new_w = np.exp(offset_arr[:, 3]) * widths

    decoded = np.zeros_like(offset_arr, dtype=np.float32)
    decoded[:, 0] = new_cx - 0.5 * new_w  # x1
    decoded[:, 1] = new_cy - 0.5 * new_h  # y1
    decoded[:, 2] = new_cx + 0.5 * new_w  # x2
    decoded[:, 3] = new_cy + 0.5 * new_h  # y2
    return decoded

# 4) IoU with [x1,y1,x2,y2]
def compute_iou_vectorized(boxes1, boxes2):
    """Pairwise IoU between two box sets given as [x1,y1,x2,y2].

    Args:
        boxes1: array-like of shape (N, 4).
        boxes2: array-like of shape (M, 4).

    Returns:
        float32 array of shape (N, M). Pairs whose union area is not positive
        (e.g. two zero-area boxes) get IoU 0 instead of NaN.
    """
    boxes1 = np.asarray(boxes1, dtype=np.float32)
    boxes2 = np.asarray(boxes2, dtype=np.float32)

    # Corners of the intersection rectangle for every (i, j) pair.
    inter_x1 = np.maximum(boxes1[:, None, 0], boxes2[None, :, 0])
    inter_y1 = np.maximum(boxes1[:, None, 1], boxes2[None, :, 1])
    inter_x2 = np.minimum(boxes1[:, None, 2], boxes2[None, :, 2])
    inter_y2 = np.minimum(boxes1[:, None, 3], boxes2[None, :, 3])

    # Negative extents mean the boxes do not overlap.
    inter_w = np.maximum(inter_x2 - inter_x1, 0)
    inter_h = np.maximum(inter_y2 - inter_y1, 0)
    inter_area = inter_w * inter_h

    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    union = area1[:, None] + area2[None, :] - inter_area

    # Divide only where the union is positive; the remaining entries stay 0,
    # which avoids 0/0 -> NaN and the accompanying RuntimeWarning.
    iou = np.zeros_like(inter_area)
    np.divide(inter_area, union, out=iou, where=union > 0)
    return iou

# 5) Show boxes in [x1,y1,x2,y2]
def create_corner_rect(bb, color='red'):
    """Build an unfilled matplotlib rectangle patch for a [x1,y1,x2,y2] box."""
    left, top, right, bottom = bb
    return plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        color=color,
        fill=False,
        lw=2,
    )

def show_corner_bbs(img, bbs):
    """Display a (C,H,W) image scaled by 255 with [x1,y1,x2,y2] boxes drawn on it."""
    pixels = (img * 255.0).clamp(0, 255).byte().cpu().numpy()  # (C,H,W)
    plt.imshow(pixels.transpose(1, 2, 0))  # matplotlib wants (H,W,C)
    axes = plt.gca()
    for box in bbs:
        axes.add_patch(create_corner_rect(box))
    plt.show()

# 6) Example Anchor Generation in [x1,y1,x2,y2]
def generate_anchor_grid_np(X_FM, Y_FM, ratios, scales):
    """
    Build every anchor for an X_FM x Y_FM grid laid over an ISIZE image.

    Args:
        X_FM: number of grid cells along the image width.
        Y_FM: number of grid cells along the image height.
        ratios: aspect factors; height scales with sqrt(ratio) and width with
            sqrt(1 / ratio).
        scales: size multipliers applied to the per-cell stride.

    Returns:
        float32 array of shape (X_FM * Y_FM * len(ratios) * len(scales), 4),
        each row [x1,y1,x2,y2]. Rows are ordered cell by cell (y outer,
        x inner), then by ratio, then by scale.
    """
    H_IMG, W_IMG = ISIZE[0], ISIZE[1]
    sub_sampling_x = W_IMG / float(X_FM)
    sub_sampling_y = H_IMG / float(Y_FM)

    # Cell centres from an integer arange: np.arange with a non-integer float
    # step can yield one element too many or too few, which would change the
    # anchor count and misalign anchors with the network outputs.
    center_x = (np.arange(X_FM) + 0.5) * sub_sampling_x
    center_y = (np.arange(Y_FM) + 0.5) * sub_sampling_y
    grid_x, grid_y = np.meshgrid(center_x, center_y)
    cx = grid_x.ravel()[:, None]  # (positions, 1)
    cy = grid_y.ravel()[:, None]

    # Anchor heights/widths for every (ratio, scale) pair, flattened with the
    # ratio as the outer index and the scale as the inner index.
    ratio_arr = np.asarray(ratios, dtype=np.float64)
    scale_arr = np.asarray(scales, dtype=np.float64)
    heights = ((sub_sampling_y * scale_arr)[None, :] * np.sqrt(ratio_arr)[:, None]).ravel()
    widths = ((sub_sampling_x * scale_arr)[None, :] * np.sqrt(1.0 / ratio_arr)[:, None]).ravel()

    # Broadcast centres against shapes: (positions, shapes, 4) -> (N, 4).
    anchors = np.stack(
        [
            cx - 0.5 * widths,
            cy - 0.5 * heights,
            cx + 0.5 * widths,
            cy + 0.5 * heights,
        ],
        axis=2,
    )
    return anchors.reshape(-1, 4).astype(np.float32)

# 7) The RPN with CBAM, in-channels=512
class CBAM(nn.Module):
    """Channel attention followed by spatial attention on a (B, C, H, W) map.

    Channel step: global average- and max-pooled descriptors go through a
    shared two-layer bottleneck, are summed and squashed by a sigmoid to gate
    each channel. Spatial step: the per-pixel channel mean and channel max are
    concatenated, convolved to one map and squashed to gate each location.
    """

    def __init__(self, channels, reduction=4, kernel_size=3):
        super().__init__()
        hidden = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Bottleneck shared by the average and max descriptors.
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(),
            nn.Linear(hidden, channels, bias=False),
        )
        self.sigmoid = nn.Sigmoid()
        self.conv = nn.Conv2d(
            2, 1, kernel_size, padding=kernel_size // 2, bias=False
        )

    def forward(self, x):
        batch, chans, _, _ = x.shape

        # Channel gate
        avg_desc = self.avg_pool(x).view(batch, chans)
        max_desc = self.max_pool(x).view(batch, chans)
        channel_gate = self.sigmoid(self.fc(avg_desc) + self.fc(max_desc))
        x = x * channel_gate.view(batch, chans, 1, 1)

        # Spatial gate
        mean_map = x.mean(dim=1, keepdim=True)
        peak_map = x.max(dim=1, keepdim=True)[0]
        spatial_gate = self.sigmoid(self.conv(torch.cat([mean_map, peak_map], dim=1)))
        return x * spatial_gate

import torch
import torch.nn as nn
import torchvision.ops as ops

class ROIPooling(nn.Module):
    """RoI max-pooling wrapper that derives the spatial scale from the feature map."""

    def __init__(self, output_size, image_size=224):
        """
        Args:
            output_size (tuple or int): Size of the output (height, width)
            image_size (tuple or int): Size (height, width) of the image the
                RoI coordinates refer to. Defaults to 224, the value that was
                previously hard-coded; pass e.g. ``ISIZE`` for other inputs.
        """
        super(ROIPooling, self).__init__()
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size

        if isinstance(image_size, int):
            self.image_size = (image_size, image_size)
        else:
            self.image_size = tuple(image_size)

    def forward(self, features, rois):
        """
        Args:
            features (Tensor): Input features of shape (N, C, H, W)
            rois (Tensor): Regions of Interest in format
                          [batch_index, x1, y1, x2, y2] with shape (K, 5)

        Returns:
            Pooled features of shape (K, C, output_size[0], output_size[1])
        """
        # Ensure rois are on same device as features
        rois = rois.to(features.device)

        # Spatial scale = feature map size / image size. roi_pool takes a
        # single factor, so the smaller of the two axis ratios is used.
        image_h, image_w = self.image_size
        spatial_scale_h = features.size(2) / image_h
        spatial_scale_w = features.size(3) / image_w
        spatial_scale = min(spatial_scale_h, spatial_scale_w)

        # Perform ROI pooling
        pooled_features = ops.roi_pool(
            features,
            rois,
            output_size=self.output_size,
            spatial_scale=spatial_scale
        )

        return pooled_features

    def __repr__(self):
        return self.__class__.__name__ + '(output_size={})'.format(self.output_size)

class EnhancedRPNWithROI(nn.Module):
    """Region proposal network with CBAM attention, a residual skip and RoI pooling."""

    def __init__(self,
                 in_channels=512,
                 mid_channels=256,
                 n_anchor=15,  # must match len(ratios)*len(scales)
                 pool_size=(7,7),
                 nms_thresh=0.5,
                 conf_thresh=0.5,
                 top_n=400,
                 spatial_scale=1.0):
        """
        Args:
            in_channels: channels of the backbone feature map.
            mid_channels: channels used inside the RPN body.
            n_anchor: anchors per feature-map location.
            pool_size: (height, width) of each RoI-pooled output, or an int.
            nms_thresh: IoU threshold for NMS over proposals.
            conf_thresh: minimum objectness for a proposal to be kept.
            top_n: maximum proposals kept per image after NMS.
            spatial_scale: factor that maps proposal coordinates onto the
                pooled feature map. The default 1.0 keeps earlier behaviour;
                when proposals are in image pixels this should be
                feature-map size / image size.
        """
        super().__init__()

        # --- RPN body ---
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1)

        self.cbam1 = CBAM(mid_channels, reduction=4, kernel_size=3)
        self.cbam2 = CBAM(mid_channels, reduction=4, kernel_size=3)

        self.reg_layer = nn.Conv2d(mid_channels, n_anchor*4, kernel_size=1)
        self.cls_layer = nn.Conv2d(mid_channels, n_anchor*2, kernel_size=1)

        # 1x1 projection for the residual branch. When the channel counts
        # already match it is frozen as an identity mapping.
        self.skip_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
        if in_channels == mid_channels:
            # dirac_ is the identity initialiser for conv weights; eye_ only
            # accepts 2-D tensors and raises on this 4-D weight.
            nn.init.dirac_(self.skip_conv.weight)
            nn.init.zeros_(self.skip_conv.bias)
            self.skip_conv.weight.requires_grad = False
            self.skip_conv.bias.requires_grad = False

        self._init_weights()

        # --- ROI pooling layer ---
        # ops.RoIPool does max pooling; ops.RoIAlign is the bilinear alternative.
        self.pool_size = (pool_size, pool_size) if isinstance(pool_size, int) else tuple(pool_size)
        self.roi_pool = ops.RoIPool(output_size=pool_size, spatial_scale=spatial_scale)

        # Store thresholds for NMS, etc.
        self.nms_thresh = nms_thresh
        self.conf_thresh = conf_thresh
        self.top_n = top_n

    def _init_weights(self):
        """Gaussian(std=0.01) weights and zero biases for the body and head convs."""
        for layer in [self.conv1, self.conv2, self.conv3, self.reg_layer, self.cls_layer]:
            nn.init.normal_(layer.weight, std=0.01)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, x, anchors=None, do_roi=False):
        """
        Args:
            x: feature map from the backbone, shape (B, C=in_channels, H, W)
            anchors: (total_anchors, 4) if you want to decode proposals
            do_roi: bool. If True, we also do the ROI pooling step, returning pooled features.

        Returns:
            pred_locs: (B, #anchors, 4), offsets squashed into (-2, 2)
            pred_scores: (B, #anchors, 2), raw class logits
            objectness_score: (B, #anchors), softmax probability of class 1
            (optionally) pooled_feats: if do_roi==True and anchors is not None
        """
        residual = self.skip_conv(x)

        # block1 + CBAM
        x = F.relu(self.conv1(x))
        x = self.cbam1(x)

        # block2 + CBAM
        x = F.relu(self.conv2(x))
        x = self.cbam2(x)

        # block3 + residual
        x = F.relu(self.conv3(x) + residual)

        # RPN heads
        B = x.size(0)
        pred_anchor_locs = self.reg_layer(x)      # shape (B, n_anchor*4, H, W)
        pred_cls_scores  = self.cls_layer(x)      # shape (B, n_anchor*2, H, W)

        # Channels last, then flatten so anchors are ordered (row, col, anchor).
        pred_anchor_locs = pred_anchor_locs.permute(0,2,3,1).contiguous().view(B, -1, 4)
        pred_cls_scores  = pred_cls_scores.permute(0,2,3,1).contiguous().view(B, -1, 2)

        # tanh bounds every offset to (-2, 2)
        pred_anchor_locs = torch.tanh(pred_anchor_locs) * 2
        objectness_score = F.softmax(pred_cls_scores, dim=-1)[..., 1]  # shape (B, #anchors)

        if anchors is None or (not do_roi):
            # If we don't want ROI pooling, just return the normal RPN outputs
            return pred_anchor_locs, pred_cls_scores, objectness_score

        # else we want proposals + ROI pooling
        proposals = self._generate_proposals(pred_anchor_locs, anchors)  # shape (B, #anchors, 4)

        # _process_proposals does thresholding, NMS and ROI pooling on the
        # RPN's own feature map (the output of block3).
        pooled_feats = self._process_proposals(x, proposals, objectness_score)

        return pred_anchor_locs, pred_cls_scores, objectness_score, pooled_feats

    def _generate_proposals(self, pred_locs, anchors):
        """
        Convert anchor offsets to box coords [x1, y1, x2, y2].
        pred_locs: (B, N, 4) => offsets [dy, dx, dh, dw]
        anchors:   (N, 4) in [x1, y1, x2, y2]
        Returns a (B, N, 4) tensor of decoded boxes.
        """
        proposals = torch.zeros_like(pred_locs)  # (B, N, 4)

        # anchors on the same device as the predictions
        anchors = anchors.to(pred_locs.device)

        # anchor geometry
        anc_w = anchors[:, 2] - anchors[:, 0]  # x2 - x1
        anc_h = anchors[:, 3] - anchors[:, 1]  # y2 - y1
        anc_ctr_x = anchors[:, 0] + 0.5*anc_w
        anc_ctr_y = anchors[:, 1] + 0.5*anc_h

        dy = pred_locs[..., 0]
        dx = pred_locs[..., 1]
        dh = pred_locs[..., 2]
        dw = pred_locs[..., 3]

        # decode: centre shifts are relative to anchor size, sizes are log-scaled
        ctr_y = dy * anc_h[None, :] + anc_ctr_y[None, :]
        ctr_x = dx * anc_w[None, :] + anc_ctr_x[None, :]
        h = torch.exp(dh) * anc_h[None, :]
        w = torch.exp(dw) * anc_w[None, :]

        # final
        proposals[..., 0] = ctr_x - 0.5*w
        proposals[..., 1] = ctr_y - 0.5*h
        proposals[..., 2] = ctr_x + 0.5*w
        proposals[..., 3] = ctr_y + 0.5*h

        return proposals

    def _process_proposals(self, conv_features, proposals, scores):
        """
        For each image in the batch:
          - Filter proposals by self.conf_thresh
          - NMS
          - Keep top_n
          - ROI Pool
        Return the pooled features of all images concatenated along dim 0,
        shape (total_kept, C, pool_h, pool_w). Images without any surviving
        proposal contribute zero rows.
        """
        B = conv_features.size(0)
        channels = conv_features.size(1)
        pooled_list = []

        for b_idx in range(B):
            cur_scores = scores[b_idx]      # shape (#anchors,)
            cur_props = proposals[b_idx]    # shape (#anchors, 4)

            # 1) Confidence threshold
            conf_mask = cur_scores > self.conf_thresh
            filtered_boxes  = cur_props[conf_mask]
            filtered_scores = cur_scores[conf_mask]

            if filtered_boxes.size(0) == 0:
                # No proposals left: add a correctly shaped empty tensor so the
                # concatenated result is 4-D even when every image is empty.
                pooled_list.append(
                    conv_features.new_zeros((0, channels) + self.pool_size)
                )
                continue

            # 2) NMS (indices come back sorted by decreasing score)
            keep_idx = ops.nms(filtered_boxes, filtered_scores, self.nms_thresh)
            keep_idx = keep_idx[:self.top_n]  # top top_n after NMS
            final_boxes = filtered_boxes[keep_idx]

            # 3) ROI Pool
            # Format => [batch_ind, x1, y1, x2, y2]
            roi_input = torch.cat([
                torch.full((final_boxes.size(0),1), b_idx,
                           device=conv_features.device, dtype=final_boxes.dtype),
                final_boxes
            ], dim=1)

            # shape => (N_proposals, C, pool_size[0], pool_size[1])
            pooled = self.roi_pool(conv_features, roi_input)
            pooled_list.append(pooled)

        # Combine into a single tensor
        pooled_feats = torch.cat(pooled_list, dim=0)
        return pooled_feats



##################################

def compute_iou_torch(boxes1, boxes2):
    """
    Compute IoU between two sets of boxes (PyTorch).

    Both inputs are corner boxes. The result is the same whether the columns
    are [x1, y1, x2, y2] (the layout used elsewhere in this file) or
    [y1, x1, y2, x2], as long as both sets use the same layout.

    boxes1: (N,4)
    boxes2: (M,4)
    Returns an (N, M) tensor of IoU values; pairs whose union area is not
    positive (e.g. two zero-area boxes) get 0 instead of NaN.
    """
    # Intersection rectangle for every (i, j) pair.
    inter_a1 = torch.max(boxes1[:, None, 0], boxes2[None, :, 0])  # (N, M)
    inter_b1 = torch.max(boxes1[:, None, 1], boxes2[None, :, 1])
    inter_a2 = torch.min(boxes1[:, None, 2], boxes2[None, :, 2])
    inter_b2 = torch.min(boxes1[:, None, 3], boxes2[None, :, 3])

    # Negative extents mean no overlap.
    inter_area = (inter_a2 - inter_a1).clamp(min=0) * (inter_b2 - inter_b1).clamp(min=0)

    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])  # shape [N]
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])  # shape [M]
    union_area = area1[:, None] + area2[None, :] - inter_area

    # Guard the division: substitute a harmless denominator where the union is
    # not positive, then zero those entries out.
    has_union = union_area > 0
    safe_union = torch.where(has_union, union_area, torch.ones_like(union_area))
    iou = torch.where(has_union, inter_area / safe_union, torch.zeros_like(inter_area))
    return iou

def hierarchical_sample_anchors(all_anchor_idxs, anchors, num_to_sample=256):
    """
    Given a set of anchor indices (e.g. pos or neg) and the anchor boxes themselves,
    split them by size (small/medium/large) and sample from each group equally.

    Inputs:
      all_anchor_idxs: 1D tensor of indices into 'anchors' that we want to sample from.
      anchors: (N, 4) the full anchor set as a torch.Tensor.
      num_to_sample: how many total anchors we want. Each size group is asked
        for num_to_sample // 3 anchors, so at most 3 * (num_to_sample // 3)
        indices are returned; an empty group contributes nothing.

    Returns:
      chosen_idxs: the indices (subset of all_anchor_idxs) that we keep.
      An empty 'all_anchor_idxs' is returned unchanged.
    """
    # torch.quantile raises on an empty tensor, so there is nothing to split.
    if all_anchor_idxs.numel() == 0:
        return all_anchor_idxs

    # Subset just the anchors we care about
    subset_anchors = anchors[all_anchor_idxs]  # shape: (K, 4)
    # Box area: (col2 - col0) * (col3 - col1)
    sizes = (subset_anchors[:, 2] - subset_anchors[:, 0]) * (subset_anchors[:, 3] - subset_anchors[:, 1])

    # Percentile thresholds for small/medium/large
    small_thresh = torch.quantile(sizes, 0.33)
    large_thresh = torch.quantile(sizes, 0.66)

    small_mask = sizes <= small_thresh
    large_mask = sizes > large_thresh
    medium_mask = (~small_mask) & (~large_mask)

    # Split into three groups
    small_idxs  = all_anchor_idxs[ small_mask ]
    medium_idxs = all_anchor_idxs[ medium_mask ]
    large_idxs  = all_anchor_idxs[ large_mask ]
    per_group = num_to_sample // 3

    # A group with too few anchors is topped up by sample_with_clamp
    chosen_small  = sample_with_clamp(small_idxs,  per_group)
    chosen_medium = sample_with_clamp(medium_idxs, per_group)
    chosen_large  = sample_with_clamp(large_idxs,  per_group)

    chosen_idxs = torch.cat([chosen_small, chosen_medium, chosen_large], dim=0)
    return chosen_idxs


def sample_with_clamp(indices, num_needed):
    """
    Randomly choose 'num_needed' items from 'indices' (1D tensor).

    - Empty input is returned unchanged (nothing to sample from).
    - With enough items, 'num_needed' distinct ones are drawn.
    - With too few, every item is kept once and the shortfall is filled by
      drawing with replacement, so the result always has 'num_needed' entries
      even when the shortfall exceeds len(indices).
    """
    available = len(indices)
    if available == 0:
        # If no anchors in that category, return empty
        return indices

    if available >= num_needed:
        picked = torch.randperm(available)[:num_needed]
        return indices[picked]

    # Sample the missing entries with replacement. A permutation could supply
    # at most 'available' extras and would leave the result short.
    shortfall = num_needed - available
    refill = torch.randint(available, (shortfall,))
    return torch.cat([indices, indices[refill]], dim=0)

####################

def create_ground_truth_rect(bb, color='blue'):
    """Build an unfilled matplotlib rectangle for a ground-truth [x1,y1,x2,y2] box."""
    left, top, right, bottom = np.array(bb, dtype=np.float32)
    width = right - left
    height = bottom - top
    # Anchor point is the top-left corner.
    return plt.Rectangle((left, top), width, height, color=color, fill=False, lw=3)

def show_ground_truth_bbs(im, bbs):
    """Display a (C,H,W) image tensor with its ground-truth boxes overlaid."""
    pixels = np.transpose(unnormalize_tensor(im), (1, 2, 0))  # -> (H,W,C)
    plt.imshow(pixels)
    axes = plt.gca()
    for box in bbs:
        axes.add_patch(create_ground_truth_rect(box))
    plt.show()

############################


def bbox_generation(images, targets, X_FM, Y_FM):
    """Build anchors and per-image RPN training targets.

    Args:
        images: batch of images; ``images[0].shape`` must be (C, H, W).
        targets: one dict per image with a "boxes" tensor of shape (M, 4) in
            [x1,y1,x2,y2].
        X_FM: number of grid cells along the image width.
        Y_FM: number of grid cells along the image height.

    Returns:
        anchor_locs_all: (B, total_anchors, 4) float32 regression targets
            ordered [dy, dx, dh, dw]; zero for non-positive anchors.
        anchor_labels_all: (B, total_anchors) int32 with 1 = positive,
            0 = negative, -1 = ignored.
        anchors: (total_anchors, 4) float32 in [x1,y1,x2,y2].

    Anchor shapes come from the module-level ``ratios`` and ``anchor_scales``.
    An image without ground-truth boxes keeps every anchor at -1.
    """
    B = len(images)
    _, H_IMG, W_IMG = images[0].shape  # (C, H, W)

    # Strides
    stride_x = float(W_IMG)/X_FM
    stride_y = float(H_IMG)/Y_FM

    # Grid centres from an integer arange: np.arange with a non-integer float
    # step can yield one element too many or too few, which would break the
    # anchor count that the label arrays rely on.
    ctr_x = (np.arange(X_FM) + 0.5) * stride_x
    ctr_y = (np.arange(Y_FM) + 0.5) * stride_y
    grid_x, grid_y = np.meshgrid(ctr_x, ctr_y)
    cx = grid_x.ravel()[:, None]  # (total_positions, 1), y outer / x inner
    cy = grid_y.ravel()[:, None]

    # Anchor heights/widths per (ratio, scale) pair, ratio-major order.
    ratio_arr = np.asarray(ratios, dtype=np.float64)
    scale_arr = np.asarray(anchor_scales, dtype=np.float64)
    anc_hs = ((stride_y * scale_arr)[None, :] * np.sqrt(ratio_arr)[:, None]).ravel()
    anc_ws = ((stride_x * scale_arr)[None, :] * np.sqrt(1.0 / ratio_arr)[:, None]).ravel()

    # Build anchors in [x1,y1,x2,y2]: (positions, shapes, 4) -> (N, 4)
    anchors = np.stack(
        [cx - 0.5*anc_ws, cy - 0.5*anc_hs, cx + 0.5*anc_ws, cy + 0.5*anc_hs],
        axis=2,
    ).reshape(-1, 4).astype(np.float32)
    total_anchors = anchors.shape[0]

    # Valid anchors fully in the image
    valid_idx = np.where(
        (anchors[:,0] >= 0) & (anchors[:,1] >= 0) &
        (anchors[:,2] <= W_IMG) & (anchors[:,3] <= H_IMG)
    )[0]

    # For each image in the batch, label anchors, etc.
    anchor_locs_all, anchor_labels_all = [], []
    pos_iou_th = 0.7
    neg_iou_th = 0.3
    n_sample = 256
    pos_ratio = 0.5

    for i in range(B):
        labels = -1*np.ones((total_anchors,), dtype=np.int32)
        locs   = np.zeros((total_anchors, 4), dtype=np.float32)

        gt_boxes = targets[i]["boxes"].cpu().numpy()  # shape (M,4) in [x1,y1,x2,y2]
        # Without valid anchors there is nothing to label (and the reductions
        # below would fail on an empty axis).
        if gt_boxes.shape[0] > 0 and valid_idx.size > 0:
            # IoU only among valid anchors
            valid_anchors = anchors[valid_idx]
            ious = compute_iou_vectorized(valid_anchors, gt_boxes)  # shape [N_valid, M]
            max_ious  = np.max(ious, axis=1)
            argmax_ious = np.argmax(ious, axis=1)

            # Label: pos => iou>=0.7, neg => iou<0.3
            valid_labels = -1*np.ones_like(max_ious, dtype=np.int32)
            valid_labels[max_ious >= pos_iou_th] = 1
            valid_labels[max_ious <  neg_iou_th] = 0

            # Force each gt box to have at least one positive anchor
            gt_max_ious = np.max(ious, axis=0)
            for j, g_iou in enumerate(gt_max_ious):
                # A gt box that overlaps no valid anchor has best IoU 0;
                # matching on equality would then flag every zero-IoU anchor
                # as positive, so such boxes are skipped (NaN is skipped too).
                if not g_iou > 0:
                    continue
                idxs = np.where(ious[:,j] == g_iou)[0]
                valid_labels[idxs] = 1

            # Subsample
            pos_inds = np.where(valid_labels == 1)[0]
            neg_inds = np.where(valid_labels == 0)[0]

            num_pos = int(n_sample*pos_ratio)
            if len(pos_inds) > num_pos:
                disable_pos = np.random.choice(pos_inds, size=len(pos_inds)-num_pos, replace=False)
                valid_labels[disable_pos] = -1

            # Negatives fill whatever the positives left of the budget.
            num_neg = n_sample - np.sum(valid_labels==1)
            if len(neg_inds) > num_neg:
                disable_neg = np.random.choice(neg_inds, size=len(neg_inds)-num_neg, replace=False)
                valid_labels[disable_neg] = -1

            # Compute loc for positives
            valid_locs = np.zeros((valid_anchors.shape[0], 4), dtype=np.float32)
            pos_final = np.where(valid_labels == 1)[0]
            if len(pos_final)>0:
                posA = valid_anchors[pos_final]
                # anchor center
                anc_w = posA[:,2] - posA[:,0]
                anc_h = posA[:,3] - posA[:,1]
                anc_ctr_x = posA[:,0] + 0.5*anc_w
                anc_ctr_y = posA[:,1] + 0.5*anc_h

                # matched GT (the gt box with the highest IoU for each anchor)
                match_gt = gt_boxes[argmax_ious[pos_final]]
                gt_w = match_gt[:,2] - match_gt[:,0]
                gt_h = match_gt[:,3] - match_gt[:,1]
                gt_ctr_x = match_gt[:,0] + 0.5*gt_w
                gt_ctr_y = match_gt[:,1] + 0.5*gt_h

                dx = (gt_ctr_x - anc_ctr_x)/anc_w
                dy = (gt_ctr_y - anc_ctr_y)/anc_h
                dw = np.log(gt_w/anc_w)
                dh = np.log(gt_h/anc_h)
                # Target order is [dy, dx, dh, dw], matching the decoders.
                valid_locs[pos_final] = np.stack([dy, dx, dh, dw], axis=1)

            # Fill in the big array
            labels[valid_idx] = valid_labels
            locs[valid_idx]   = valid_locs

        anchor_labels_all.append(labels)
        anchor_locs_all.append(locs)

    anchor_labels_all = np.stack(anchor_labels_all, axis=0)  # (B, total_anchors)
    anchor_locs_all   = np.stack(anchor_locs_all,   axis=0)  # (B, total_anchors, 4)

    return anchor_locs_all, anchor_labels_all, anchors  # anchors in [x1,y1,x2,y2]

#############################

def visualize_attention(feat, title=""):
    """Plot the channel-wise mean and max of the first sample of a feature map.

    Args:
        feat: tensor of shape (B, C, H, W), e.g. the output of a CBAM block.
        title: prefix for the two subplot titles.
    """
    plt.figure(figsize=(10,5))
    # Reduce over the channel axis without keepdim so that each map is a 2-D
    # (H, W) array; imshow does not accept a (1, H, W) array.
    avg_map = torch.mean(feat, dim=1)
    max_map, _ = torch.max(feat, dim=1)
    plt.subplot(1,2,1)
    plt.imshow(avg_map[0].cpu().detach().numpy(), cmap='jet')
    plt.title(f"{title} - Channel Avg")
    plt.subplot(1,2,2)
    plt.imshow(max_map[0].cpu().detach().numpy(), cmap='jet')
    plt.title(f"{title} - Channel Max")
    plt.show()

################################
# RPN head that is trained in this file; 512 input channels matches the
# `in_channels` default of EnhancedRPNWithROI.
rpn_model = EnhancedRPNWithROI(in_channels=512, mid_channels=256).to(device)

# VGG16 backbone loaded with pretrained weights and put in eval mode.
vgg_model = torchvision.models.vgg16(pretrained=True).to(device)
vgg_model.eval()
# Mark every backbone feature parameter as trainable.
# NOTE(review): the visible part of train_epochs runs the backbone under
# torch.no_grad(), so these parameters would receive no gradient from that
# path - confirm whether fine-tuning the backbone is intended.
for param in vgg_model.features.parameters():
    param.requires_grad = True

base_lr = 0.001

backbone_params = list(vgg_model.features.parameters())

new_params = list(rpn_model.parameters())

# First 30 modules of vgg_model.features, applied one after another as the
# feature extractor.
req_features = [layer for layer in list(vgg_model.features)[:30]]


# optimizer = torch.optim.Adam(rpn_model.parameters(), lr=0.001, weight_decay=1e-4)

# Two parameter groups: the backbone at a tenth of the base learning rate,
# the RPN at the full base learning rate.
optimizer = torch.optim.Adam([
    {'params': backbone_params, 'lr': base_lr * 0.1},
    {'params': new_params, 'lr': base_lr}
])

###################################

def train_epochs(req_features, rpn_model, optimizer, train_dl,
                 epochs=20, rpn_lambda=10, iou_threshold=0.2, top_k=400,
                 checkpoint_path="./1206.pth"):
    """
    Train the RPN, using hierarchical sampling in place of random anchor subsampling.

    Per batch the function:
      1) runs the backbone layers in ``req_features`` without gradients,
      2) generates the anchor grid and keeps the anchors inside the image,
      3) labels anchors per image by IoU, samples up to 256 of them with
         ``hierarchical_sample_anchors`` and builds regression targets,
      4) optimises ``cross_entropy + rpn_lambda * smooth_L1``,
      5) measures recall of the ``top_k`` proposals at ``iou_threshold``.

    Args:
        req_features: Sequence of backbone layers applied one after another.
        rpn_model: RPN called as ``rpn_model(feat, anchors=..., do_roi=True)``.
        optimizer: Optimizer stepped once per batch.
        train_dl: Iterable of batches with "images", "boxes" and "labels".
        epochs: Number of passes over ``train_dl``.
        rpn_lambda: Weight of the localisation loss.
        iou_threshold: IoU at which a GT box counts as recalled.
        top_k: Number of highest-objectness proposals used for recall.
        checkpoint_path: State dict loaded before training and overwritten
            with the final weights afterwards.

    Returns:
        The trained ``rpn_model``.
    """

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # rpn_model.load_state_dict(torch.load("./cbma_final.pth", map_location=device))
    rpn_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    rpn_model.train()
    rpn_model.to(device)

    # The backbone layers do not change device between batches; move them once.
    for m in req_features:
        m.to(device)

    # Anchor labelling / sampling hyper-parameters.
    pos_iou_threshold = 0.7
    neg_iou_threshold = 0.3
    n_sample = 256     # total anchors to keep per image
    pos_ratio = 0.5
    max_pos = int(pos_ratio * n_sample)

    epoch_train_recalls = []
    epoch_train_errors = []

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        n_batches = 0
        sum_loss = 0.0
        batch_recalls = []

        for batch in train_dl:
            images = torch.stack([img.to(device) for img in batch["images"]])
            targets = [{"boxes": b.to(device), "labels": l.to(device)}
                       for b, l in zip(batch["boxes"], batch["labels"])]

            B = images.size(0)
            n_batches += 1

            # ----- 1) Backbone features (no gradients) -----
            # NOTE(review): because of no_grad, no gradient reaches the
            # backbone; any backbone parameters handed to `optimizer` stay
            # unchanged. Confirm this is intended.
            with torch.no_grad():
                feat = images
                for m in req_features:
                    feat = m(feat)
            # NOTE(review): images.shape[2] is used as the image height below,
            # so feat.shape[2] is the feature-map height, yet it is passed as
            # X_FM, which generate_anchor_grid_np documents as the number of
            # columns. Confirm the argument order against the RPN's layout.
            X_FM, Y_FM = feat.shape[2], feat.shape[3]

            # ----- 2) Generate all anchors (no labeling yet) -----
            anchors = generate_anchor_grid_np(X_FM, Y_FM, ratios, anchor_scales)
            anchors = torch.from_numpy(anchors).float().to(device)  # (total_anchors, 4)

            # Indices of anchors lying fully inside the image.
            H_IMG, W_IMG = images.shape[2], images.shape[3]
            valid_idx = torch.where(
                (anchors[:, 0] >= 0) &            # x1 >= 0
                (anchors[:, 2] <= W_IMG) &        # x2 <= width
                (anchors[:, 1] >= 0) &            # y1 >= 0
                (anchors[:, 3] <= H_IMG)          # y2 <= height
            )[0]
            valid_anchors = anchors[valid_idx]

            # Full-length label/target tensors; only valid_idx entries are
            # filled, everything else stays ignored (-1) / zero.
            anchor_labels_all = -1 * torch.ones((B, anchors.size(0)), dtype=torch.int32, device=device)
            anchor_locs_all = torch.zeros((B, anchors.size(0), 4), dtype=torch.float32, device=device)

            # ----- 3) Per-image anchor labeling & hierarchical sampling -----
            for i in range(B):
                gt_boxes_i = targets[i]["boxes"]   # shape (M, 4)
                if gt_boxes_i.numel() == 0:
                    # No ground truth => every anchor of this image is ignored.
                    continue

                # 3a) IoU labeling for the valid anchors
                ious = compute_iou_torch(valid_anchors, gt_boxes_i)  # (N_valid, M)

                max_ious, argmax_ious = ious.max(dim=1)   # best GT for each anchor
                valid_labels = -1 * torch.ones_like(max_ious, dtype=torch.int32)
                valid_labels[max_ious >= pos_iou_threshold] = 1
                valid_labels[max_ious < neg_iou_threshold] = 0

                # Ensure each GT box has at least one positive anchor: the
                # anchor(s) achieving its best IoU. A GT whose best IoU is 0
                # overlaps nothing, and matching on equality would otherwise
                # mark every zero-IoU anchor as positive, so it is skipped.
                gt_max_ious, _ = ious.max(dim=0)          # best anchor IoU per GT
                best_for_some_gt = (
                    (ious == gt_max_ious.unsqueeze(0)) &
                    (gt_max_ious > 0).unsqueeze(0)
                ).any(dim=1)
                valid_labels[best_for_some_gt] = 1

                # 3b) Separate pos vs neg anchor indices
                pos_inds = torch.where(valid_labels == 1)[0]
                neg_inds = torch.where(valid_labels == 0)[0]

                # 3c) Keep at most max_pos positives, fill the rest of the
                # n_sample budget with negatives.
                n_pos = min(len(pos_inds), max_pos)
                n_neg = n_sample - n_pos

                chosen_pos = hierarchical_sample_anchors(pos_inds, valid_anchors, num_to_sample=n_pos)
                chosen_neg = hierarchical_sample_anchors(neg_inds, valid_anchors, num_to_sample=n_neg)

                # Anything not chosen => label = -1 (ignored by the loss)
                chosen_mask = torch.zeros_like(valid_labels, dtype=torch.bool)
                chosen_mask[chosen_pos] = True
                chosen_mask[chosen_neg] = True
                valid_labels[~chosen_mask] = -1

                # 3d) Regression targets for the chosen positives
                valid_locs = torch.zeros((valid_anchors.size(0), 4), dtype=torch.float32, device=device)
                pos_chosen_mask = (valid_labels == 1)
                if pos_chosen_mask.sum() > 0:
                    pos_anchors = valid_anchors[pos_chosen_mask]
                    assigned_gt = gt_boxes_i[argmax_ious[pos_chosen_mask]]  # matched GT

                    # NOTE(review): the validity test above reads anchors as
                    # [x1, y1, x2, y2], so the "*_h"/"*_y" quantities below are
                    # taken along columns 0/2 (the x axis) and "*_w"/"*_x"
                    # along columns 1/3. Confirm the resulting target order
                    # matches what pred_bbox_to_xywh decodes.
                    anc_h = pos_anchors[:, 2] - pos_anchors[:, 0]
                    anc_w = pos_anchors[:, 3] - pos_anchors[:, 1]
                    anc_ctr_y = pos_anchors[:, 0] + 0.5 * anc_h
                    anc_ctr_x = pos_anchors[:, 1] + 0.5 * anc_w

                    gt_h = assigned_gt[:, 2] - assigned_gt[:, 0]
                    gt_w = assigned_gt[:, 3] - assigned_gt[:, 1]
                    gt_ctr_y = assigned_gt[:, 0] + 0.5 * gt_h
                    gt_ctr_x = assigned_gt[:, 1] + 0.5 * gt_w

                    dy = (gt_ctr_y - anc_ctr_y) / anc_h
                    dx = (gt_ctr_x - anc_ctr_x) / anc_w
                    dh = torch.log(gt_h / anc_h)
                    dw = torch.log(gt_w / anc_w)

                    valid_locs[pos_chosen_mask] = torch.stack([dy, dx, dh, dw], dim=1)

                # 3e) Put back into the full-size tensors
                anchor_labels_all[i, valid_idx] = valid_labels
                anchor_locs_all[i, valid_idx] = valid_locs

            gt_scores = anchor_labels_all.to(torch.float32)  # (B, total_anchors)
            gt_locs = anchor_locs_all                        # (B, total_anchors, 4)

            # ----- 4) Forward pass through RPN -----
            pred_locs, pred_scores, objectness_score, pooled_feats = rpn_model(
                feat,
                anchors=anchors,
                do_roi=True
            )

            # Classification loss (anchors labelled -1 are ignored)
            cls_loss = F.cross_entropy(
                pred_scores.view(-1, 2),
                gt_scores.view(-1).long(),
                ignore_index=-1
            )

            # Smooth L1 over positive anchors, averaged per positive anchor
            pos_mask = (gt_scores == 1)
            if pos_mask.sum() > 0:
                pred_pos = pred_locs[pos_mask]
                gt_pos = gt_locs[pos_mask]
                diff = torch.abs(gt_pos - pred_pos)
                loc_loss = torch.where(diff < 1, 0.5 * diff**2, diff - 0.5).sum()
                loc_loss = loc_loss / pos_mask.sum().float()
            else:
                loc_loss = torch.tensor(0.0, device=device)

            loss = cls_loss + rpn_lambda * loc_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sum_loss += loss.item()

            # ----- 5) Measure recall in this batch -----
            with torch.no_grad():
                batch_recall = 0.0
                count = 0
                for i in range(B):
                    # Convert predicted offsets back to box coordinates
                    rois = pred_bbox_to_xywh(pred_locs[i], anchors)

                    # Top-k proposals by objectness
                    k = min(top_k, objectness_score[i].shape[0])
                    topk_inds = torch.topk(objectness_score[i], k=k).indices
                    proposals = rois[topk_inds.cpu().numpy()]

                    # NOTE(review): GT boxes are compared as stored; confirm
                    # they use the same coordinate order as the proposals.
                    gt_boxes = batch["boxes"][i].cpu().numpy()  # shape (M, 4)

                    if len(gt_boxes) > 0:
                        matched = 0
                        for gt in gt_boxes:
                            ious = compute_iou_vectorized(proposals, np.expand_dims(gt, 0))
                            best_iou = np.max(ious) if ious.size > 0 else 0.0
                            if best_iou >= iou_threshold:
                                matched += 1
                        batch_recall += matched / len(gt_boxes)
                        count += 1

                if count > 0:
                    batch_recalls.append(batch_recall / count)

        # End of epoch
        epoch_recall = np.mean(batch_recalls) if batch_recalls else 0.0
        epoch_train_recalls.append(epoch_recall)
        epoch_train_errors.append(1 - epoch_recall)

        # Each batch loss is already a mean, so average over batches (not over
        # samples); max() guards against an empty data loader.
        mean_loss = sum_loss / max(n_batches, 1)
        print(f"Epoch {epoch+1}: Loss {mean_loss:.3f} | "
              f"Recall: {epoch_recall:.3f} | Error: {1-epoch_recall:.3f}")

        if (epoch+1) % 5 == 0:
            torch.save(rpn_model.state_dict(), f"./cecilia_{epoch+1}.pth")

    torch.save(rpn_model.state_dict(), checkpoint_path)

    # ---- Plot recall/error over epochs ----
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(range(1, epochs+1), epoch_train_recalls, 'b-o', label='Recall')
    plt.xlabel('Epoch')
    plt.ylabel('Recall')
    plt.title('Training Recall Over Epochs')
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot(range(1, epochs+1), epoch_train_errors, 'r-o', label='Error')
    plt.xlabel('Epoch')
    plt.ylabel('Error (1 - Recall)')
    plt.title('Training Error Over Epochs')
    plt.grid(True)

    plt.tight_layout()
    plt.show()

    # torch.save(rpn_model.state_dict(), "./cbma_final_more_anchor.pth")
    return rpn_model
###################

In [None]:

# Locations of the raw images, the .pt directory and the BDD100K label file.
image_dir = "C:/360/trainA_origin_700/trainA_original_700"
pt_dir = 'C:/360/trainA_testing2'
json_file_path =  "C:/360/bdd100k_labels_images_train.json"

# Read the first 20000 label records from the JSON file.
all_labels = extract_first_n_labels(json_file_path, 20000)

# Build the dataset from the image directory, the labels and the .pt directory.
dataset = CustomDataset(image_dir, all_labels, pt_dir)

# 70% / 15% / remainder split; the test split absorbs the rounding leftovers.
n_total = len(dataset)
train_size = int(0.7 * n_total)
val_size = int(0.15 * n_total)
test_size = n_total - train_size - val_size
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size, test_size],
)
print(f"Training dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Testing dataset size: {len(test_dataset)}")

# Loaders share the batch size and collate function; only training is shuffled.
batch_size = 8
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=custom_collate_fn,
    num_workers=0,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=custom_collate_fn,
    num_workers=0,
)

In [None]:
def compute_iou_vectorized_yxyx(boxes1, boxes2):
    """
    Compute pairwise IoU between two sets of boxes in [y1, x1, y2, x2] format.

    Args:
        boxes1: (N, 4) boxes as a torch tensor, NumPy array or nested sequence.
        boxes2: (M, 4) boxes in any of the same forms.

    Returns:
        NumPy array of shape (N, M); entry (i, j) is the IoU of boxes1[i] and
        boxes2[j]. A 1e-6 term in the denominator avoids division by zero.
    """
    def _as_box_array(boxes):
        # detach() is required for tensors that track gradients (e.g. raw
        # network outputs); reshape also accepts empty or flat inputs.
        if isinstance(boxes, torch.Tensor):
            boxes = boxes.detach().cpu().numpy()
        return np.asarray(boxes).reshape(-1, 4)

    b1 = _as_box_array(boxes1)
    b2 = _as_box_array(boxes2)

    # Intersection rectangle of every (box1, box2) pair via broadcasting.
    inter_y1 = np.maximum(b1[:, None, 0], b2[None, :, 0])
    inter_x1 = np.maximum(b1[:, None, 1], b2[None, :, 1])
    inter_y2 = np.minimum(b1[:, None, 2], b2[None, :, 2])
    inter_x2 = np.minimum(b1[:, None, 3], b2[None, :, 3])

    # Negative extents mean the boxes do not overlap along that axis.
    inter_h = np.maximum(inter_y2 - inter_y1, 0)
    inter_w = np.maximum(inter_x2 - inter_x1, 0)
    inter_area = inter_h * inter_w

    area1 = (b1[:, 2] - b1[:, 0]) * (b1[:, 3] - b1[:, 1])
    area2 = (b2[:, 2] - b2[:, 0]) * (b2[:, 3] - b2[:, 1])

    union = area1[:, None] + area2[None, :] - inter_area
    return inter_area / (union + 1e-6)


def compute_recall_at_threshold(proposals, gt_boxes, iou_thresh=0.3):
    """
    Compute recall at a fixed IoU threshold.

    A GT box counts as matched when its best IoU over all proposals is
    >= iou_thresh. Recall = (matched GT boxes) / (total GT boxes).

    Args:
        proposals: (N, 4) proposal boxes in [y1, x1, y2, x2] format.
        gt_boxes: (M, 4) ground-truth boxes in the same format.
        iou_thresh: Minimum IoU for a GT box to count as matched.

    Returns:
        Recall in [0, 1]; 0.0 when there are no GT boxes or no proposals.
    """
    # Without proposals nothing can be matched, and the max-reduction below
    # would fail on a zero-length axis.
    if len(gt_boxes) == 0 or len(proposals) == 0:
        return 0.0

    ious = compute_iou_vectorized_yxyx(proposals, gt_boxes)  # shape (N, M)
    best_ious = ious.max(axis=0)  # best IoU for each GT
    recall = np.sum(best_ious >= iou_thresh) / float(len(gt_boxes))
    return recall



def compute_mean_best_iou(proposals, gt_boxes):
    """
    Compute the mean best IoU across all GT boxes.

    For each GT box the single highest IoU achieved by any proposal is taken;
    the result is the average of those maxima.

    Args:
        proposals: (N, 4) proposal boxes in [y1, x1, y2, x2] format.
        gt_boxes: (M, 4) ground-truth boxes in the same format.

    Returns:
        Mean of the per-GT best IoUs; 0.0 when there are no GT boxes or no
        proposals.
    """
    # Without proposals every GT has best IoU 0, and the max-reduction below
    # would fail on a zero-length axis.
    if len(gt_boxes) == 0 or len(proposals) == 0:
        return 0.0

    ious = compute_iou_vectorized_yxyx(proposals, gt_boxes)  # shape (N, M)
    best_ious = ious.max(axis=0)
    return np.mean(best_ious)


def compute_proposal_density(proposals, gt_boxes, iou_thresh=0.5):
    """
    Average, over GT boxes, of the number of proposals with IoU >= iou_thresh.

    Both inputs are converted to float32 arrays and must have four columns
    ([y1, x1, y2, x2]). Returns 0.0 when either input is empty.
    """
    if len(gt_boxes) == 0:
        return 0.0
    if len(proposals) == 0:
        return 0.0

    prop_arr = np.asarray(proposals, dtype=np.float32)
    gt_arr = np.asarray(gt_boxes, dtype=np.float32)

    # Both arrays must be (*, 4).
    assert prop_arr.shape[1] == 4 and gt_arr.shape[1] == 4

    # IoU matrix is (num proposals, num GT); count hits down each GT column.
    hits = compute_iou_vectorized_yxyx(prop_arr, gt_arr) >= iou_thresh
    hits_per_gt = hits.sum(axis=0)

    return float(np.mean(hits_per_gt))

def plot_training_metrics(train_losses, val_metrics, eval_every):
    """Draw training loss and, when available, validation recall and mean best IoU.

    Args:
        train_losses: Sequence of loss values, plotted against their index.
        val_metrics: Sequence of dicts with 'recall' and 'mean_best_iou'
            entries; each entry is averaged with np.mean before plotting.
        eval_every: Spacing of validation points; the k-th entry (1-based) is
            placed at x = k * eval_every.
    """
    plt.figure(figsize=(15, 5))

    # Panel 1: training loss
    plt.subplot(1, 3, 1)
    plt.plot(train_losses, label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.legend()

    if len(val_metrics) > 0:
        # x positions of the validation points
        val_epochs = [eval_every * step for step in range(1, len(val_metrics) + 1)]

        # (subplot slot, metric key, legend label, y-axis label, panel title)
        panels = (
            (2, 'recall', 'Avg Recall', 'Average Recall', 'Validation Recall'),
            (3, 'mean_best_iou', 'Mean IoU', 'Mean Best IoU', 'Localization Quality'),
        )
        for slot, key, legend_label, y_label, panel_title in panels:
            plt.subplot(1, 3, slot)
            series = [np.mean(entry[key]) for entry in val_metrics]
            plt.plot(val_epochs, series, 'o-', label=legend_label)
            plt.xlabel('Epoch')
            plt.ylabel(y_label)
            plt.title(panel_title)
            plt.ylim(0, 1)
            plt.legend()

    plt.tight_layout()
    plt.show()


def decode_boxes(anchors, pred_locs):
    """
    Turn regression offsets into box coordinates relative to their anchors.

    Offsets are read as [dy, dx, dh, dw]: the centre moves by the offset times
    the anchor size, and the size is scaled by exp(offset).

    Args:
        anchors (Tensor): (N, 4) anchor boxes in [y1, x1, y2, x2] format.
        pred_locs (Tensor): (N, 4) regression offsets.

    Returns:
        Tensor: (N, 4) decoded boxes in [y1, x1, y2, x2] format.
    """
    anchors = anchors.float()

    # Columns are handled pairwise: index 0 is the y/height part, 1 the x/width part.
    anchor_sizes = anchors[:, 2:4] - anchors[:, 0:2]          # [h, w]
    anchor_centers = anchors[:, 0:2] + 0.5 * anchor_sizes     # [ctr_y, ctr_x]

    center_offsets = pred_locs[:, 0:2]                        # [dy, dx]
    log_scales = pred_locs[:, 2:4]                            # [dh, dw]

    new_centers = anchor_centers + center_offsets * anchor_sizes
    new_sizes = anchor_sizes * torch.exp(log_scales)

    # Centre/size back to corner coordinates.
    top_left = new_centers - 0.5 * new_sizes                  # [y1, x1]
    bottom_right = new_centers + 0.5 * new_sizes              # [y2, x2]
    return torch.cat([top_left, bottom_right], dim=1)


In [None]:
from torchvision.ops import box_iou

def recursive_nms(boxes, scores, iou_threshold=0.5, recursion_limit=10):
    """
    Custom NMS that repeatedly merges overlapping boxes instead of dropping them.

    In each pass the pairwise IoU matrix is computed once. Boxes are visited
    in index order; a box that is still alive absorbs every still-alive box
    whose IoU with it exceeds ``iou_threshold`` (score-weighted average via
    ``combine_box_group``), and the absorbed boxes are removed. Passes repeat
    until nothing merges or ``recursion_limit`` passes have run.

    Args:
        boxes: Tensor of shape [N, 4] (x1, y1, x2, y2 format). Not modified.
        scores: Tensor of shape [N] containing confidence scores.
        iou_threshold: IoU above which two boxes are merged.
        recursion_limit: Maximum number of merge passes.

    Returns:
        combined_boxes: Tensor of merged boxes (float32).
        keep_indices: For each returned box, the index in the original input
            of the box it grew from.
    """
    if len(boxes) == 0:
        return boxes, torch.empty(0, dtype=torch.long, device=boxes.device)

    # Work on a private float32 copy: merged boxes are written in place below
    # and must not leak into the caller's tensor.
    boxes = boxes.to(torch.float32, copy=True)

    # Original position of every surviving box, filtered alongside `boxes`.
    orig_indices = torch.arange(len(boxes), device=boxes.device)

    changed = True
    passes = 0

    while changed and passes < recursion_limit:
        changed = False
        iou_matrix = box_iou(boxes, boxes)  # [N, N] matrix

        # Zero out diagonal (self-comparisons)
        iou_matrix.fill_diagonal_(0)
        overlaps = iou_matrix > iou_threshold

        keep = torch.ones(len(boxes), dtype=torch.bool, device=boxes.device)

        for i in range(len(boxes)):
            if not keep[i]:
                continue

            # Only boxes that are still alive may be absorbed; one already
            # merged earlier in this pass must not be counted a second time.
            neighbours = torch.where(overlaps[i] & keep)[0]
            if len(neighbours) == 0:
                continue

            # Score-weighted average of the current box and its neighbours.
            boxes[i] = combine_box_group(
                torch.cat([boxes[i].unsqueeze(0), boxes[neighbours]]),
                torch.cat([scores[i].unsqueeze(0), scores[neighbours]])
            )
            keep[neighbours] = False
            changed = True

        # Drop the absorbed boxes before the next pass.
        boxes = boxes[keep]
        scores = scores[keep]
        orig_indices = orig_indices[keep]
        passes += 1

    return boxes, orig_indices

def combine_box_group(boxes, scores):
    """
    Combine a group of boxes into one representative box using a score-weighted average.

    Args:
        boxes: Tensor of shape [K, 4].
        scores: Tensor of shape [K]; normalised to sum to 1 and used as weights.

    Returns:
        Tensor of shape [4], the weighted average of the rows of ``boxes``.
        When the scores sum to zero the plain (unweighted) mean is returned
        instead of NaNs.
    """
    total = scores.sum()
    if scores.numel() > 0 and total == 0:
        # Dividing by a zero total would yield NaN/inf weights; fall back to
        # equal weights so the result stays a valid box.
        weights = torch.full(
            scores.shape, 1.0 / scores.numel(),
            dtype=boxes.dtype, device=boxes.device,
        )
    else:
        weights = scores / total
    combined_box = torch.sum(boxes * weights.view(-1, 1), dim=0)
    return combined_box

In [None]:

#####################

# Reduced subsets for quick experiments: the first 3000 training and first 300
# validation samples. The sizes are capped at the length of each split so a
# small dataset does not produce out-of-range Subset indices.
small_train_dataset = torch.utils.data.Subset(
    train_dataset, list(range(min(3000, len(train_dataset)))))
small_train_loader = torch.utils.data.DataLoader(small_train_dataset, batch_size=batch_size, shuffle=True,
                                           collate_fn=custom_collate_fn, num_workers=0)

small_val_dataset = torch.utils.data.Subset(
    val_dataset, list(range(min(300, len(val_dataset)))))
small_val_loader = torch.utils.data.DataLoader(small_val_dataset, batch_size=batch_size, shuffle=True,
                                           collate_fn=custom_collate_fn, num_workers=0)

###########################

def generate_anchor_grid_np(X_FM, Y_FM, ratios, scales, image_size=None):
    """
    Generate a (N, 4) NumPy array of anchor boxes over an X_FM x Y_FM feature map.

    One anchor per (ratio, scale) pair is centred on every feature-map cell.
    Rows are ordered cell by cell (row-major: y outer, x inner); within a cell
    ratios vary slowest and scales fastest.

    Args:
        X_FM: Number of feature-map columns (cells along the image width).
        Y_FM: Number of feature-map rows (cells along the image height).
        ratios: Aspect ratios; height scales with sqrt(ratio), width with
            sqrt(1 / ratio).
        scales: Anchor sizes expressed in feature-map cells.
        image_size: Optional (height, width) of the input image. Defaults to
            the global ISIZE.

    Returns:
        float32 array of shape (X_FM * Y_FM * len(ratios) * len(scales), 4)
        where each row is [x1, y1, x2, y2].
    """
    if image_size is None:
        image_size = ISIZE
    H_IMG, W_IMG = image_size[0], image_size[1]

    # Size of one feature-map cell in image pixels.
    sub_sampling_x = W_IMG / float(X_FM)
    sub_sampling_y = H_IMG / float(Y_FM)

    # Cell centres. Built from integer indices rather than a float-step
    # np.arange, whose length can be off by one due to rounding -- that would
    # change the number of anchors.
    centers_x = (np.arange(X_FM) + 0.5) * sub_sampling_x
    centers_y = (np.arange(Y_FM) + 0.5) * sub_sampling_y
    grid_x, grid_y = np.meshgrid(centers_x, centers_y)   # each (Y_FM, X_FM)
    ctr_x = grid_x.ravel()[:, None]                      # (positions, 1)
    ctr_y = grid_y.ravel()[:, None]

    # Anchor heights/widths for every (ratio, scale) pair, ratio-major.
    ratio_col = np.asarray(ratios, dtype=np.float64)[:, None]
    scale_row = np.asarray(scales, dtype=np.float64)[None, :]
    heights = ((sub_sampling_y * scale_row) * np.sqrt(ratio_col)).ravel()
    widths = ((sub_sampling_x * scale_row) * np.sqrt(1.0 / ratio_col)).ravel()
    half_h = (heights * 0.5)[None, :]                    # (1, anchors per cell)
    half_w = (widths * 0.5)[None, :]

    # Broadcast centres against sizes: (positions, anchors per cell, 4).
    anchors = np.stack(
        [ctr_x - half_w, ctr_y - half_h, ctr_x + half_w, ctr_y + half_h],
        axis=-1,
    )
    return anchors.reshape(-1, 4).astype(np.float32)

#########################

import matplotlib.pyplot as plt
import numpy as np

def debug_check_image_and_boxes(dataloader, num_images=1):
    """
    Display the first `num_images` samples of one batch with their GT boxes.

    For every sample the tensor shape, the number of boxes and each box's
    corners and width/height are printed, and the image is shown with the
    boxes (read as [x1, y1, x2, y2]) outlined in green.
    """
    # Only the first batch of the loader is inspected.
    batch = next(iter(dataloader))

    images = batch["images"][:num_images]
    boxes_list = batch["boxes"][:num_images]
    indices = batch["indices"][:num_images]

    for i, img_tensor in enumerate(images):
        gt_boxes = boxes_list[i].cpu().numpy()  # (N, 4)

        # Shape / count summary
        C, H, W = img_tensor.shape
        print(f"\n--- Debug: Image index {indices[i]} ---")
        print(f"  Tensor shape: (C={C}, H={H}, W={W}) => (should be 3,720,1280 if that is your resize?)")
        print(f"  Num GT boxes: {len(gt_boxes)}")

        # Scale to 0..255 and reorder (C, H, W) -> (H, W, C) for display
        scaled = (img_tensor * 255.0).clamp(0, 255).byte()
        img_np = np.transpose(scaled.cpu().numpy(), (1, 2, 0))
        plt.figure(figsize=(10, 6))
        plt.imshow(img_np)

        # Report and outline each GT box
        for b_idx, box in enumerate(gt_boxes):
            x1, y1, x2, y2 = box
            box_w = x2 - x1
            box_h = y2 - y1
            print(f"  Box {b_idx}: [x1={x1:.1f}, y1={y1:.1f}, x2={x2:.1f}, y2={y2:.1f}]"
                  f" => w={box_w:.1f}, h={box_h:.1f}")

            outline = plt.Rectangle(
                (x1, y1),
                box_w,
                box_h,
                edgecolor='green',
                fill=False,
                lw=2,
            )
            plt.gca().add_patch(outline)

        plt.title(f"Image idx {indices[i]} - Resized: W={W},H={H}")
        plt.show()


# Sanity check: display the first two samples of one training batch together
# with their ground-truth boxes.
debug_check_image_and_boxes(train_loader, num_images=2)

######################
import json
def validate(rpn_model, data_loader, req_features,
             n_images=7, top_k=40, iou_threshold=0.5):
    """
    Validate the RPN on up to `n_images` from each batch.
    1) Generates anchors with bbox_generation(...)
    2) Runs the RPN to get predicted offsets & objectness
    3) Decodes bounding boxes with pred_bbox_to_xywh(...)
    4) Keeps the top-k proposals and merges overlaps with recursive_nms(...)
    5) Saves visualisations of proposals and ground truth
    6) Measures recall, error (1 - recall), and average best IoU per image

    The merged proposals of every image are also written to
    "C:/360/exported_anchors.json", keyed by image id.

    Args:
        rpn_model: the RPN to evaluate; switched to eval mode for the run and
            back to train mode before returning
        data_loader: DataLoader whose batches provide "images", "boxes",
            "labels" and "img_ids"
        req_features: the list of backbone layers applied one after another
        n_images: how many images to process from each batch
        top_k: how many proposals to keep (by objectness); None keeps all
        iou_threshold: IoU threshold used to decide if a GT is "matched"

    Returns:
        errors, recalls, avg_ious (lists, one entry per image with GT boxes)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    rpn_model.eval().to(device)
    save_json_path = "C:/360/exported_anchors.json"
    save_path = "C:/360/output_roi_file"
    anchor_results = {}
    errors = []     # 1 - recall, per image
    recalls = []    # recall, per image
    avg_ious = []   # average best IoU, per image

    # The backbone layers do not change device between batches; move them once.
    for m in req_features:
        m.to(device)

    with torch.no_grad():
        # Running offset used to name the saved images: it advances by
        # n_images per batch so that count + i is unique across batches.
        count = 0
        for batch in data_loader:
            count += n_images
            print(count)
            print("Batch keys:", list(batch.keys()))

            # Collect up to n_images from this batch. Stacking per-image
            # tensors (as train_epochs does) works whether the collate
            # function yields a list of tensors or one batched tensor.
            batch_images = batch["images"][:n_images]
            if len(batch_images) == 0:
                continue
            images = torch.stack([img.to(device) for img in batch_images])

            targets = [
                {"boxes": b.to(device), "labels": l.to(device)}
                for (b, l) in zip(batch["boxes"], batch["labels"])
            ][:n_images]
            # NOTE(review): debug_check_image_and_boxes reads batch["indices"];
            # confirm the collate function also provides "img_ids".
            image_ids = batch["img_ids"][:n_images]

            B = images.shape[0]

            # Forward pass through the backbone
            feats = images.clone()
            for m in req_features:
                feats = m(feats)

            # The shape of the feature map
            X_FM, Y_FM = feats.shape[2], feats.shape[3]

            # bbox_generation returns (anchor_locs, anchor_labels, anchors);
            # only the anchors are needed here for decoding.
            _, _, anchors = bbox_generation(
                [img for img in images],  # pass as list
                targets,
                X_FM, Y_FM
            )
            anchors_torch = torch.from_numpy(anchors).float().to(device)

            # Run the RPN to get predicted offsets & classification
            pred_locs, pred_scores, objectness_score, pooled_feats = rpn_model(
                feats,
                anchors_torch,
                do_roi=True
            )

            # Process each image in this batch
            for i in range(B):
                # Decode from anchors + predicted offsets => actual boxes
                rois = pred_bbox_to_xywh(pred_locs[i], anchors_torch)

                # Choose top-k by objectness
                if top_k is not None:
                    k = min(top_k, objectness_score[i].shape[0])
                    topk_inds = torch.topk(objectness_score[i], k=k).indices
                    proposals = rois[topk_inds.cpu().numpy()]
                    scores = objectness_score[i][topk_inds].cpu().numpy()
                else:
                    proposals = rois
                    scores = objectness_score[i].cpu().numpy()

                # Apply recursive NMS to reduce overlapping proposals. The
                # pass limit follows top_k; without a top_k the number of
                # proposals is used, since each productive pass removes at
                # least one box.
                nms_passes = top_k if top_k is not None else max(len(proposals), 1)
                proposals_tensor = torch.from_numpy(proposals).float().to(device)
                scores_tensor = torch.from_numpy(scores).float().to(device)
                proposals, _ = recursive_nms(
                    proposals_tensor, scores_tensor,
                    iou_threshold=0.5, recursion_limit=nms_passes
                )
                proposals = proposals.cpu().numpy()  # back to numpy for visualization

                anchor_results[str(image_ids[i])] = proposals.tolist()

                # Visualize the proposals
                print(count+i)
                show_corner_bbs(images[i], proposals, save_path, (count+i))

                # Grab ground-truth boxes & visualize
                gt_boxes = targets[i]["boxes"].cpu().numpy()  # shape (M,4)
                if len(gt_boxes) > 0:
                    show_ground_truth_bbs(images[i], gt_boxes, save_path, (count+i))

                    # Best IoU per GT box over all remaining proposals.
                    # NOTE(review): GT boxes are compared as stored; confirm
                    # they use the same coordinate order as the proposals.
                    matched = 0
                    image_ious = []
                    for gt_box in gt_boxes:
                        ious = compute_iou_vectorized(proposals, np.expand_dims(gt_box, axis=0))
                        best_iou = np.max(ious) if ious.size > 0 else 0.0
                        image_ious.append(best_iou)
                        if best_iou >= iou_threshold:
                            matched += 1

                    # Summarize for this image
                    recall = matched / len(gt_boxes)
                    error = 1 - recall
                    avg_iou = np.mean(image_ious)

                    recalls.append(recall)
                    errors.append(error)
                    avg_ious.append(avg_iou)

                    print(f"[Val] Image {i} metrics:")
                    print(f"  - Recall: {recall:.3f}")
                    print(f"  - Error : {error:.3f}")
                    print(f"  - AvgIoU: {avg_iou:.3f}")
                else:
                    print(f"[Val] Image {i}: no ground-truth boxes available")

    # After we finish, plot the results
    if errors:
        plt.figure(figsize=(15, 5))

        # Plot 1: Recall
        plt.subplot(1, 3, 1)
        plt.plot(range(len(recalls)), recalls, 'b-o')
        plt.ylim(0, 1.05)
        plt.xlabel('Image Index')
        plt.ylabel('Recall')
        plt.title(f'Recall (IoU ≥ {iou_threshold})')
        plt.grid(True)

        # Plot 2: Error
        plt.subplot(1, 3, 2)
        plt.plot(range(len(errors)), errors, 'r-o')
        plt.ylim(0, 1.05)
        plt.xlabel('Image Index')
        plt.ylabel('Error (1 - Recall)')
        plt.title('Error per Image')
        plt.grid(True)

        # Plot 3: Average IoU
        plt.subplot(1, 3, 3)
        plt.plot(range(len(avg_ious)), avg_ious, 'g-o')
        plt.ylim(0, 1.05)
        plt.xlabel('Image Index')
        plt.ylabel('Avg Best IoU')
        plt.title('Proposal Quality (Best IoU)')
        plt.grid(True)

        plt.tight_layout()
        plt.show()

        print(f"\nValidation Summary:")
        print(f"  Mean Recall:  {np.mean(recalls):.3f} ± {np.std(recalls):.3f}")
        print(f"  Mean Error:   {np.mean(errors):.3f} ± {np.std(errors):.3f}")
        print(f"  Mean AvgIoU:  {np.mean(avg_ious):.3f} ± {np.std(avg_ious):.3f}")
    else:
        print("[Val] No ground-truth boxes found in the processed images.")

    # Export the merged proposals of every processed image.
    os.makedirs(os.path.dirname(save_json_path), exist_ok=True)
    with open(save_json_path, "w") as f:
        json.dump(anchor_results, f, indent=2)

    # Switch the model back to train mode for the caller.
    rpn_model.train().to(device)
    return errors, recalls, avg_ious


######################


In [None]:
# Train the RPN for 5 epochs on small_train_loader, with the localisation loss
# weighted by rpn_lambda=1.
trained_rpn = train_epochs(req_features, rpn_model, optimizer, small_train_loader, epochs=5, rpn_lambda=1)

# Validation (visualising predictions) on the training and validation subsets
# is currently disabled:
# print("Validation on training data:")
# validate(trained_rpn, small_train_loader, req_features)
# print("Validation on validation data:")
# validate(trained_rpn, small_val_loader, req_features)

In [None]:

# Module-level directory for saved result images.
# NOTE(review): validate() writes to its own hard-coded
# "C:/360/output_roi_file" rather than this value -- confirm which is intended.
save_path = "C:/360/output_roi"

import os
import matplotlib.pyplot as plt
from PIL import Image

def show_corner_bbs(img, bbs, save_path=None, image_id=None, title="Predicted Boxes"):
    """
    Render predicted bounding boxes on an image and optionally save the figure.

    The figure is always closed before returning and ``plt.show()`` is never
    called, so the saved PNG (if any) is the only lasting output. This is a
    best-effort helper: any failure is printed instead of being raised.

    Args:
        img: Tensor image (C, H, W) normalized to [0, 1]
        bbs: Iterable of bounding boxes in [x1, y1, x2, y2]; each one is
            passed to create_corner_rect to build the patch that is drawn
        save_path: Directory to save into; created if missing (optional)
        image_id: Identifier used in the file name ``pred_<image_id>.png``.
            Nothing is saved unless both save_path and image_id are given.
        title: Title of the plot (optional)

    Returns:
        None
    """
    fig = None
    try:
        # Scale [0, 1] floats to uint8 pixels and move channels last, which
        # is the layout imshow expects.
        img_np = (img * 255.0).clamp(0, 255).byte().cpu().numpy()
        img_np = np.transpose(img_np, (1, 2, 0))  # (H, W, C)

        fig = plt.figure(figsize=(10, 6))
        ax = fig.gca()  # looked up once instead of once per box
        ax.imshow(img_np)
        for bb in bbs:
            ax.add_patch(create_corner_rect(bb))
        ax.set_title(title)
        ax.axis('off')

        if save_path and image_id is not None:
            os.makedirs(save_path, exist_ok=True)
            filename = os.path.join(save_path, f"pred_{image_id}.png")
            fig.savefig(filename, bbox_inches='tight', pad_inches=0)
    except Exception as e:
        # Deliberately broad: visualisation must never abort the caller.
        print(f"[Error saving/showing predicted boxes]: {e}")
    finally:
        # Close on the failure path too; otherwise every failing call leaves
        # a figure registered with pyplot and memory grows over a long run.
        if fig is not None:
            plt.close(fig)


def show_ground_truth_bbs(img, bbs, save_path=None, image_id=None, title="Ground Truth Boxes"):
    """
    Render ground-truth bounding boxes on an image and optionally save the figure.

    The figure is always closed before returning and ``plt.show()`` is never
    called, so the saved PNG (if any) is the only lasting output. This is a
    best-effort helper: any failure is printed instead of being raised.

    Args:
        img: Tensor image (C, H, W) normalized to [0, 1]
        bbs: Iterable of ground truth boxes in [x1, y1, x2, y2]; each one is
            passed to create_ground_truth_rect to build the patch that is drawn
        save_path: Directory to save into; created if missing (optional)
        image_id: Identifier used in the file name ``gt_<image_id>.png``.
            Nothing is saved unless both save_path and image_id are given.
        title: Title of the plot (optional)

    Returns:
        None
    """
    fig = None
    try:
        img_np = unnormalize_tensor(img)  # expected to be in [0, 255]
        img_np = np.transpose(img_np, (1, 2, 0))  # (H, W, C)

        fig = plt.figure(figsize=(10, 6))
        ax = fig.gca()  # looked up once instead of once per box
        # The cast is kept because the pixel dtype is owned by the helper above.
        ax.imshow(img_np.astype(np.uint8))
        for bb in bbs:
            ax.add_patch(create_ground_truth_rect(bb))
        ax.set_title(title)
        ax.axis('off')

        if save_path and image_id is not None:
            os.makedirs(save_path, exist_ok=True)
            filename = os.path.join(save_path, f"gt_{image_id}.png")
            fig.savefig(filename, bbox_inches='tight', pad_inches=0)
    except Exception as e:
        # Deliberately broad: visualisation must never abort the caller.
        print(f"[Error saving/showing ground truth boxes]: {e}")
    finally:
        # Close on the failure path too; otherwise every failing call leaves
        # a figure registered with pyplot and memory grows over a long run.
        if fig is not None:
            plt.close(fig)


In [None]:
# Run validation with the trained RPN over the small training loader.
# The value returned by validate is not captured here.
print("Validation on training data:")
validate(trained_rpn, small_train_loader, req_features)