In [None]:
!pip install albumentations
!pip install tensorflow

In [None]:
!pip install ultralytics

In [2]:
# Import here as well: this cell appears before the import cell, so a fresh
# "Restart & Run All" would otherwise fail with NameError on `torch`.
import torch
import torchvision

print("PyTorch version:", torch.__version__)
print("torchvision version:", torchvision.__version__)

PyTorch version: 2.6.0+cu124
torchvision version: 0.21.0+cu124


In [1]:
import torch
import torchvision

In [6]:
import os
import xml.etree.ElementTree as ET
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import cv2
import matplotlib.pyplot as plt
# Import necessary libraries
from PIL import Image as PILImage
from IPython.display import Image, display
from torchvision.models.detection import fasterrcnn_resnet50_fpn, fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import torchvision
import random
from sklearn.metrics import precision_recall_fscore_support



class HomelessDataset(Dataset):
    """Pascal-VOC-style detection dataset for four homelessness-related classes.

    Each entry of ``xml_files`` is an annotation file relative to ``data_dir``;
    its ``<filename>`` element names the image, also relative to ``data_dir``.
    ``__getitem__`` returns ``(image_tensor, target, is_valid)``. ``is_valid`` is
    False when no box survived filtering; ``target`` then holds one placeholder
    box labelled 0 so that every sample has the same structure.

    Args:
        data_dir: Directory holding both the XML annotations and the images.
        xml_files: Annotation file names relative to ``data_dir``.
        transform: Optional callable applied to a PIL image that returns a
            ``(C, H, W)`` tensor. Boxes are rescaled to the tensor's spatial
            size, so the transform may resize but must not flip or crop: boxes
            are not updated for such geometric changes.
        augmentation_type: ``"basic"``, ``"mosaic_mixup"`` or ``"mosaic_only"``.
    """

    # Accepted values of ``augmentation_type``.
    _AUGMENTATION_TYPES = ("basic", "mosaic_mixup", "mosaic_only")

    def __init__(self, data_dir, xml_files, transform=None, augmentation_type="basic"):
        self.data_dir = data_dir
        self.xml_files = xml_files
        self.transform = transform
        self.augmentation_type = augmentation_type  # "basic", "mosaic_mixup" or "mosaic_only"
        # XML <name> value -> class id; id 0 is used for the placeholder/background label.
        self.class_dict = {
            'Homeless_People': 1,
            'Homeless_Encampments': 2,
            'Homeless_Homeless Cart': 3,
            'Homeless_Homeless Bike': 4
        }
        # Per-image object counts by class id, consumed by get_sample_weights().
        self.class_counts_per_image = self._count_classes_per_image()

    def _count_classes_per_image(self):
        """Count the objects of each known class in every annotation file.

        Returns:
            One ``{class_id: count}`` dict per entry of ``self.xml_files``. A file
            that fails to parse is reported and contributes all-zero counts.
        """
        class_counts_per_image = []
        for xml_file in self.xml_files:
            xml_path = os.path.join(self.data_dir, xml_file)
            try:
                root = ET.parse(xml_path).getroot()
                counts = {1: 0, 2: 0, 3: 0, 4: 0}
                for obj in root.findall('object'):
                    class_name = obj.find('name').text
                    if class_name in self.class_dict:
                        counts[self.class_dict[class_name]] += 1
                class_counts_per_image.append(counts)
            except Exception as e:
                # Best effort: a broken annotation should not abort dataset construction.
                print(f"Error processing {xml_file}: {e}")
                class_counts_per_image.append({1: 0, 2: 0, 3: 0, 4: 0})
        return class_counts_per_image

    def get_sample_weights(self, target_distribution=None):
        """Compute one sampling weight per image to rebalance the class mix.

        For every class present in an image the weight adds
        ``count * target_fraction / actual_fraction``, where ``actual_fraction`` is
        the class's share of all objects in the dataset. Images without any known
        object get a neutral weight of 1.0.

        Args:
            target_distribution: ``{class_id: fraction}``; uniform (0.25 each) if None.

        Returns:
            A list of floats aligned with ``self.xml_files``.
        """
        if target_distribution is None:
            target_distribution = {1: 0.25, 2: 0.25, 3: 0.25, 4: 0.25}

        total_counts = {1: 0, 2: 0, 3: 0, 4: 0}
        for counts in self.class_counts_per_image:
            for cls, count in counts.items():
                total_counts[cls] += count
        # Loop-invariant: the grand total is the same for every image.
        grand_total = sum(total_counts.values())

        weights = []
        for counts in self.class_counts_per_image:
            weight = 0
            for cls, count in counts.items():
                if count > 0 and total_counts[cls] > 0:
                    cls_weight = target_distribution[cls] / (total_counts[cls] / grand_total)
                    weight += cls_weight * count
            if weight == 0:
                weight = 1.0
            weights.append(weight)
        return weights

    def __len__(self):
        return len(self.xml_files)

    def is_valid_box(self, box):
        """Return True if ``[x1, y1, x2, y2]`` is at least 5 pixels wide and tall."""
        x1, y1, x2, y2 = box
        return x2 > x1 and y2 > y1 and x2 - x1 >= 5 and y2 - y1 >= 5  # Minimum size check

    def _get_image_and_targets(self, idx):
        """Load image ``idx`` (RGB uint8) and its valid, clipped boxes.

        Returns:
            ``(image, boxes, class_ids, (h, w))``. ``boxes`` holds integer
            ``[x1, y1, x2, y2]`` lists in original pixel coordinates, clipped to the
            image and filtered with ``is_valid_box``. ``class_ids`` is aligned with it.
            A missing image is replaced by a 512x512 black placeholder.
        """
        xml_path = os.path.join(self.data_dir, self.xml_files[idx])
        root = ET.parse(xml_path).getroot()

        img_name = root.find('filename').text
        img_path = os.path.join(self.data_dir, img_name)
        image = cv2.imread(img_path)
        if image is None:
            print(f"Warning: Image {img_name} not found, using placeholder")
            image = np.zeros((512, 512, 3), dtype=np.uint8)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        boxes = []
        class_ids = []
        h, w = image.shape[:2]
        for obj in root.findall('object'):
            class_name = obj.find('name').text
            if class_name not in self.class_dict:
                continue
            bbox = obj.find('bndbox')
            box = [
                max(0, int(float(bbox.find('xmin').text))),
                max(0, int(float(bbox.find('ymin').text))),
                min(w, int(float(bbox.find('xmax').text))),
                min(h, int(float(bbox.find('ymax').text))),
            ]
            if self.is_valid_box(box):
                boxes.append(box)
                class_ids.append(self.class_dict[class_name])

        return image, boxes, class_ids, (h, w)

    def _to_tensor(self, image_np):
        """Convert an RGB uint8 array to a tensor via ``self.transform`` or plain ToTensor."""
        if self.transform:
            return self.transform(PILImage.fromarray(image_np))
        return transforms.ToTensor()(image_np)

    def _scale_boxes(self, boxes, class_ids, scale_x, scale_y, limit_x, limit_y, offset_x=0, offset_y=0):
        """Scale, offset and clip boxes, keeping labels aligned with surviving boxes.

        Each coordinate becomes ``int(coord * scale) + offset``, clamped to
        ``[0, limit]``. Boxes failing ``is_valid_box`` afterwards are dropped,
        together with their labels.

        Returns:
            ``(kept_boxes, kept_labels)`` as parallel lists.
        """
        kept_boxes, kept_labels = [], []
        for (x1, y1, x2, y2), label in zip(boxes, class_ids):
            scaled = [
                min(max(int(x1 * scale_x) + offset_x, 0), limit_x),
                min(max(int(y1 * scale_y) + offset_y, 0), limit_y),
                min(max(int(x2 * scale_x) + offset_x, 0), limit_x),
                min(max(int(y2 * scale_y) + offset_y, 0), limit_y),
            ]
            if self.is_valid_box(scaled):
                kept_boxes.append(scaled)
                kept_labels.append(label)
        return kept_boxes, kept_labels

    def _pack_sample(self, image, boxes, labels, image_id):
        """Build the ``(image, target, is_valid)`` tuple from box/label lists.

        With no boxes, returns a placeholder target (one 10x10 box, label 0) and
        ``is_valid=False``.
        """
        if len(boxes) == 0:
            target = {
                "boxes": torch.FloatTensor([[0, 0, 10, 10]]),
                "labels": torch.LongTensor([0]),  # Background class
                "image_id": torch.tensor([image_id]),
                "area": torch.tensor([100.0]),
                "iscrowd": torch.zeros((1,), dtype=torch.int64)
            }
            return image, target, False

        boxes_t = torch.FloatTensor(boxes)
        target = {
            "boxes": boxes_t,
            "labels": torch.LongTensor(labels),
            "image_id": torch.tensor([image_id]),
            "area": (boxes_t[:, 3] - boxes_t[:, 1]) * (boxes_t[:, 2] - boxes_t[:, 0]),
            "iscrowd": torch.zeros((len(boxes),), dtype=torch.int64)
        }
        return image, target, True

    def _apply_basic_transform(self, image, boxes, class_ids, img_shape, image_id=0):
        """Transform one image and rescale its boxes to the produced tensor's size.

        Args:
            image: RGB uint8 array.
            boxes, class_ids: Parallel lists in original pixel coordinates.
            img_shape: ``(h, w)`` of ``image``.
            image_id: Value stored in ``target["image_id"]`` (default 0).
        """
        orig_h, orig_w = img_shape
        tensor = self._to_tensor(image)
        # Scale to what the transform actually produced. Without a transform the
        # image keeps its size, so boxes must not be rescaled to 512x512.
        new_h, new_w = tensor.shape[-2:]
        kept_boxes, kept_labels = self._scale_boxes(
            boxes, class_ids, new_w / orig_w, new_h / orig_h, new_w, new_h
        )
        return self._pack_sample(tensor, kept_boxes, kept_labels, image_id)

    def _apply_mosaic_augmentation(self, idx):
        """Tile image ``idx`` and three random images into a 2x2, 512x512 mosaic.

        Each image is resized into its quadrant, and its boxes are scaled and
        offset to match. The transform is applied to the finished mosaic.
        """
        indices = [idx] + [random.randint(0, len(self.xml_files) - 1) for _ in range(3)]
        samples = [self._get_image_and_targets(mosaic_idx) for mosaic_idx in indices]

        mosaic_size = 512
        cx, cy = mosaic_size // 2, mosaic_size // 2
        mosaic_img = np.zeros((mosaic_size, mosaic_size, 3), dtype=np.uint8)
        placements = [(0, 0), (cx, 0), (0, cy), (cx, cy)]  # TL, TR, BL, BR

        combined_boxes, combined_labels = [], []
        for (image, boxes, class_ids, (h, w)), (x_offset, y_offset) in zip(samples, placements):
            mosaic_img[y_offset:y_offset + cy, x_offset:x_offset + cx] = cv2.resize(image, (cx, cy))
            kept_boxes, kept_labels = self._scale_boxes(
                boxes, class_ids, cx / w, cy / h,
                mosaic_size - 1, mosaic_size - 1, x_offset, y_offset
            )
            combined_boxes.extend(kept_boxes)
            combined_labels.extend(kept_labels)

        return self._pack_sample(self._to_tensor(mosaic_img), combined_boxes, combined_labels, idx)

    def _apply_mixup_augmentation(self, idx, alpha=0.5):
        """Blend image ``idx`` with a random image using a Beta(alpha, alpha) ratio.

        Both images are resized to 512x512. If the mixing weight ``lam`` is above
        0.7 only the first image's boxes are kept, below 0.3 only the second's,
        and otherwise both sets are kept. Falls back to the basic path when
        neither image keeps any box after rescaling.
        """
        idx2 = random.randint(0, len(self.xml_files) - 1)
        image1, boxes1, class_ids1, img_shape1 = self._get_image_and_targets(idx)
        image2, boxes2, class_ids2, img_shape2 = self._get_image_and_targets(idx2)

        h, w = 512, 512
        image1_resized = cv2.resize(image1, (w, h))
        image2_resized = cv2.resize(image2, (w, h))
        (h1, w1), (h2, w2) = img_shape1, img_shape2

        # Labels are filtered together with boxes so they stay aligned.
        scaled_boxes1, labels1 = self._scale_boxes(boxes1, class_ids1, w / w1, h / h1, w, h)
        scaled_boxes2, labels2 = self._scale_boxes(boxes2, class_ids2, w / w2, h / h2, w, h)

        # Only when *neither* image keeps a box do we fall back to the basic path.
        if not scaled_boxes1 and not scaled_boxes2:
            return self._apply_basic_transform(image1, boxes1, class_ids1, img_shape1, image_id=idx)

        lam = np.random.beta(alpha, alpha)
        mixed_img_np = cv2.addWeighted(image1_resized, lam, image2_resized, 1 - lam, 0)

        if lam > 0.7:
            combined_boxes, combined_labels = scaled_boxes1, labels1
        elif lam < 0.3:
            combined_boxes, combined_labels = scaled_boxes2, labels2
        else:
            combined_boxes, combined_labels = scaled_boxes1 + scaled_boxes2, labels1 + labels2

        return self._pack_sample(self._to_tensor(mixed_img_np), combined_boxes, combined_labels, idx)

    def _basic_sample(self, idx):
        """Load sample ``idx`` and run it through the basic (non-mosaic) path."""
        image, boxes, class_ids, img_shape = self._get_image_and_targets(idx)
        return self._apply_basic_transform(image, boxes, class_ids, img_shape, image_id=idx)

    def __getitem__(self, idx):
        """Return ``(image, target, is_valid)`` using the configured augmentation mix.

        Raises:
            ValueError: if ``augmentation_type`` is not one of ``_AUGMENTATION_TYPES``.
        """
        if self.augmentation_type == "basic":
            return self._basic_sample(idx)
        if self.augmentation_type == "mosaic_mixup":
            p = random.random()
            if p < 0.4:  # 40% mosaic
                return self._apply_mosaic_augmentation(idx)
            if p < 0.8:  # 40% mixup (not combined with mosaic)
                return self._apply_mixup_augmentation(idx)
            return self._basic_sample(idx)  # 20% basic
        if self.augmentation_type == "mosaic_only":
            # 50% mosaic, 50% basic
            if random.random() < 0.5:
                return self._apply_mosaic_augmentation(idx)
            return self._basic_sample(idx)
        raise ValueError(
            f"Unknown augmentation_type {self.augmentation_type!r}; "
            f"expected one of {self._AUGMENTATION_TYPES}"
        )

def get_basic_transform(train, normalize=True):
    """Build the image-only preprocessing pipeline for HomelessDataset.

    The pipeline takes a PIL image and returns a normalized (3, 512, 512) float
    tensor. Training adds photometric jitter only. Geometric augmentations such
    as RandomHorizontalFlip are deliberately absent: these transforms see the
    image alone, so they would move pixels without moving the bounding boxes.

    Args:
        train: Add training-time colour jitter when True.
        normalize: Apply ImageNet mean/std normalization (default True).
            NOTE: torchvision's Faster R-CNN models normalize their inputs
            internally and expect images in [0, 1]. Pass False to avoid
            normalizing twice.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    steps = [transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BILINEAR)]
    if train:
        steps.append(transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1))
    steps.append(transforms.ToTensor())
    if normalize:
        steps.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225]))
    return transforms.Compose(steps)

def get_model(num_classes, pretrained=True):
    """Build a Faster R-CNN ResNet-50 FPN v2 detector with a fresh box predictor.

    Args:
        num_classes: Number of output classes *including* background (index 0).
            That is 5 for the four object classes used in this notebook.
        pretrained: Load torchvision's default COCO weights (default True). This
            uses the ``weights=`` API, because the ``pretrained=`` keyword is
            deprecated in torchvision.

    Returns:
        The detector, with its classification/box-regression head replaced.
    """
    model = fasterrcnn_resnet50_fpn_v2(weights="DEFAULT" if pretrained else None)

    # Replace the classifier head so its output size matches num_classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model

def collate_fn(batch):
    """Collate ``(image, target, is_valid)`` samples for detection (no masks).

    Invalid samples are dropped, and images/targets are returned as lists. If
    every sample is invalid, a single all-zero 3x512x512 image with an *empty*
    target is returned so the model still gets an input. torchvision detection
    models accept such negative images. A fake box labelled 0 would instead be
    learned as foreground by the class-agnostic RPN.

    Returns:
        ``(images, targets)`` lists of equal length.
    """
    images = [img for img, _, is_valid in batch if is_valid]
    targets = [target for _, target, is_valid in batch if is_valid]

    if not images:
        images = [torch.zeros((3, 512, 512))]
        targets = [{
            "boxes": torch.zeros((0, 4), dtype=torch.float32),
            "labels": torch.zeros((0,), dtype=torch.int64),
            "image_id": torch.tensor([0]),
            "area": torch.zeros((0,), dtype=torch.float32),
            "iscrowd": torch.zeros((0,), dtype=torch.int64),
        }]

    return images, targets


def calculate_iou(box1, box2):
    """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    Tensors are moved to the CPU and converted to NumPy first. Any other
    indexable sequence is used as-is. The union is floored at 1e-6, so
    degenerate boxes never cause a division by zero.
    """
    a = box1.cpu().numpy() if isinstance(box1, torch.Tensor) else box1
    b = box2.cpu().numpy() if isinstance(box2, torch.Tensor) else box2

    overlap_w = max(0, min(a[2], b[2]) - max(a[0], b[0]))
    overlap_h = max(0, min(a[3], b[3]) - max(a[1], b[1]))
    overlap = overlap_w * overlap_h

    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return overlap / max(area_a + area_b - overlap, 1e-6)

# Recall levels for 11-point interpolated AP, built as k / 10. np.arange(0, 1.1, 0.1)
# accumulates rounding error (0.30000000000000004, 0.6000000000000001, ...), so
# exact recalls such as 3/10 failed the `recall >= level` test and AP came out low.
_AP_RECALL_LEVELS = np.arange(11) / 10


def _eleven_point_ap(precision, recall):
    """Return 11-point interpolated average precision from cumulative P/R arrays."""
    ap = 0.0
    for level in _AP_RECALL_LEVELS:
        reached = recall >= level
        ap += (np.max(precision[reached]) if np.any(reached) else 0.0) / 11
    return ap


def calculate_metrics(pred_boxes, pred_labels, pred_scores, gt_boxes, gt_labels, iou_threshold=0.6):
    """
    Calculate per-class precision, recall, F1, 11-point AP and mean IoU for one image.

    Predictions are ranked by score and matched greedily to the best-overlapping
    still-unmatched ground-truth box of the same class.

    Args:
        pred_boxes: Predicted ``[x1, y1, x2, y2]`` boxes (tensor rows or sequences).
        pred_labels: Predicted class id per box.
        pred_scores: Confidence per predicted box, used for ranking.
        gt_boxes: Ground-truth boxes.
        gt_labels: Ground-truth class id per box.
        iou_threshold: Minimum IoU for a match to count as a true positive.

    Returns:
        ``{"mAP": float, "IoU": float, "class_metrics": {cls: {...}}}`` for classes 1-4.
        A class with neither ground truth nor predictions scores 1.0 on every metric
        and is included in the mAP average. A class with only ground truth, or only
        predictions, scores 0.0. "IoU" averages per-class mean IoUs that are positive.
    """
    def _uniform(value):
        return {"precision": value, "recall": value, "f1": value, "AP": value, "IoU": value}

    if len(gt_boxes) == 0:
        value = 1.0 if len(pred_boxes) == 0 else 0.0
        return {
            "mAP": value,
            "IoU": value,
            "class_metrics": {cls: _uniform(value) for cls in range(1, 5)}
        }

    class_metrics = {}
    overall_iou = 0.0
    num_valid_ious = 0

    for class_id in range(1, 5):
        class_pred_indices = [i for i, lbl in enumerate(pred_labels) if lbl == class_id]
        class_gt_indices = [i for i, lbl in enumerate(gt_labels) if lbl == class_id]

        if not class_gt_indices:
            # Nothing to find: perfect if nothing was predicted, otherwise all zeros.
            class_metrics[class_id] = _uniform(0.0 if class_pred_indices else 1.0)
            continue
        if not class_pred_indices:
            class_metrics[class_id] = _uniform(0.0)
            continue

        # Rank this class's predictions by confidence (stable for ties).
        ranked = sorted(
            ((pred_boxes[i], pred_scores[i]) for i in class_pred_indices),
            key=lambda pair: pair[1], reverse=True
        )
        c_pred_boxes = [box for box, _ in ranked]
        c_gt_boxes = [gt_boxes[i] for i in class_gt_indices]

        true_positives = np.zeros(len(c_pred_boxes))
        false_positives = np.zeros(len(c_pred_boxes))
        gt_matched = [False] * len(c_gt_boxes)
        class_ious = []

        for pred_idx, pred_box in enumerate(c_pred_boxes):
            best_iou = 0
            best_gt_idx = -1
            for gt_idx, gt_box in enumerate(c_gt_boxes):
                if gt_matched[gt_idx]:
                    continue
                iou = calculate_iou(pred_box, gt_box)
                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = gt_idx

            if best_iou > 0:
                class_ious.append(best_iou)

            # best_gt_idx is -1 when nothing overlaps. Guard it so a threshold <= 0
            # cannot wrongly "match" the last ground-truth box.
            if best_gt_idx >= 0 and best_iou >= iou_threshold:
                true_positives[pred_idx] = 1
                gt_matched[best_gt_idx] = True
            else:
                false_positives[pred_idx] = 1

        cumsum_tp = np.cumsum(true_positives)
        cumsum_fp = np.cumsum(false_positives)
        precision = cumsum_tp / (cumsum_tp + cumsum_fp + 1e-10)
        recall = cumsum_tp / len(c_gt_boxes)

        ap = _eleven_point_ap(precision, recall)
        final_precision = float(precision[-1])
        final_recall = float(recall[-1])
        f1 = 2 * (final_precision * final_recall) / (final_precision + final_recall + 1e-10)
        avg_iou = np.mean(class_ious) if class_ious else 0.0

        class_metrics[class_id] = {
            "precision": final_precision,
            "recall": final_recall,
            "f1": float(f1),
            "AP": float(ap),
            "IoU": float(avg_iou)
        }

        if avg_iou > 0:
            overall_iou += avg_iou
            num_valid_ious += 1

    mAP = sum(m["AP"] for m in class_metrics.values()) / len(class_metrics)
    overall_iou = overall_iou / max(1, num_valid_ious)

    return {
        "mAP": float(mAP),
        "IoU": float(overall_iou),
        "class_metrics": class_metrics
    }

def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10, max_grad_norm=0.5):
    """Train a detection model for one epoch, without loss reweighting.

    ``model(images, targets)`` must return a dict of scalar loss tensors, and their
    sum is optimized. Batches whose summed loss is not finite are skipped. In
    epoch 0 the learning rate is warmed up linearly from 1/1000 of its value over
    at most 1000 iterations.

    Args:
        model: Detection model returning a loss dict in training mode.
        optimizer: Optimizer over ``model``'s parameters.
        data_loader: Iterable of ``(images, targets)`` batches supporting ``len()``.
        device: Device that images and target tensors are moved to.
        epoch: Epoch index; warm-up is applied only when it is 0.
        print_freq: Print the batch loss every ``print_freq`` batches.
        max_grad_norm: Gradient-norm clipping threshold (default 0.5).

    Returns:
        Mean summed loss over the batches that were optimized (0.0 if none were).
    """
    model.train()
    running_loss = 0.0
    num_steps = 0

    lr_scheduler = None
    if epoch == 0:
        # Clamp to >= 1: with total_iters=0, LinearLR leaves the LR stuck at
        # start_factor * base_lr for the rest of training.
        warmup_iters = max(1, min(1000, len(data_loader) - 1))
        lr_scheduler = optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1.0 / 1000, total_iters=warmup_iters
        )

    for i, (images, targets) in enumerate(data_loader):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        if not torch.isfinite(losses):
            print(f"Loss is {losses}, skipping batch")
            optimizer.zero_grad()
            # Keep warm-up tied to the iteration count so it still completes.
            if lr_scheduler is not None:
                lr_scheduler.step()
            continue

        optimizer.zero_grad()
        losses.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        running_loss += losses.item()
        num_steps += 1

        if i % print_freq == 0:
            print(f"Epoch {epoch}, Batch {i}/{len(data_loader)}, Loss: {losses.item():.4f}")

    return running_loss / max(1, num_steps)

def train_one_epoch_reweighed(model, optimizer, data_loader, device, epoch, class_weights,
                              print_freq=10, max_grad_norm=0.5):
    """One training epoch that scales the classification loss by the batch's class weights.

    For each batch, ``loss_dict["loss_classifier"]`` (if present) is multiplied
    by the mean of ``class_weights.get(label, 1.0)`` over every ground-truth label
    in the batch. Otherwise the loop matches the unweighted version: non-finite
    losses are skipped, and epoch 0 uses a linear LR warm-up of at most 1000
    iterations.

    Args:
        model: Detection model returning a loss dict in training mode.
        optimizer: Optimizer over ``model``'s parameters.
        data_loader: Iterable of ``(images, targets)`` batches supporting ``len()``.
        device: Device that images and target tensors are moved to.
        epoch: Epoch index; warm-up is applied only when it is 0.
        class_weights: ``{class_id: weight}``; labels without an entry weigh 1.0.
        print_freq: Print the batch loss every ``print_freq`` batches.
        max_grad_norm: Gradient-norm clipping threshold (default 0.5).

    Returns:
        Mean (weighted) summed loss over the batches that were optimized (0.0 if none).
    """
    model.train()
    running_loss = 0.0
    num_steps = 0

    lr_scheduler = None
    if epoch == 0:
        # Clamp to >= 1: with total_iters=0, LinearLR leaves the LR stuck at
        # start_factor * base_lr for the rest of training.
        warmup_iters = max(1, min(1000, len(data_loader) - 1))
        lr_scheduler = optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1.0 / 1000, total_iters=warmup_iters
        )

    for i, (images, targets) in enumerate(data_loader):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)

        if "loss_classifier" in loss_dict:
            batch_labels = [label for t in targets for label in t["labels"].tolist()]
            if batch_labels:
                weight_factor = sum(class_weights.get(label, 1.0) for label in batch_labels) / len(batch_labels)
                # Out-of-place multiply: avoids mutating a tensor in the autograd graph.
                loss_dict["loss_classifier"] = loss_dict["loss_classifier"] * weight_factor

        losses = sum(loss for loss in loss_dict.values())

        if not torch.isfinite(losses):
            print(f"Loss is {losses}, skipping batch")
            optimizer.zero_grad()
            # Keep warm-up tied to the iteration count so it still completes.
            if lr_scheduler is not None:
                lr_scheduler.step()
            continue

        optimizer.zero_grad()
        losses.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        running_loss += losses.item()
        num_steps += 1

        if i % print_freq == 0:
            print(f"Epoch {epoch}, Batch {i}/{len(data_loader)}, Loss: {losses.item():.4f}")

    return running_loss / max(1, num_steps)

def evaluate_model(model, data_loader, device):
    """Evaluate a detection model by averaging per-image metrics over ``data_loader``.

    Each image's predictions are scored with ``calculate_metrics`` at IoU 0.5. The
    returned numbers are means of those per-image scores, not a dataset-level mAP.

    Args:
        model: Detection model. It is put in eval mode, called as ``model(images)``,
            and expected to return one dict with "boxes", "labels" and "scores" per image.
        data_loader: Iterable of ``(images, targets)`` batches.
        device: Device on which inference runs.

    Returns:
        ``{"mAP": float, "IoU": float, "class_metrics": {cls: {metric: float}}}`` for
        classes 1-4. An average over no values is 0.0.
    """
    metric_names = ("precision", "recall", "f1", "AP", "IoU")
    model.eval()
    map_scores, iou_scores = [], []
    per_class = {cls: {name: [] for name in metric_names} for cls in range(1, 5)}

    with torch.no_grad():
        for images, targets in data_loader:
            # Ground truth is only compared on the CPU, so only images go to the device.
            results = model([img.to(device) for img in images])

            for result, target in zip(results, targets):
                metrics = calculate_metrics(
                    result['boxes'].cpu(), result['labels'].cpu(), result['scores'].cpu(),
                    target['boxes'].cpu(), target['labels'].cpu(),
                    iou_threshold=0.5
                )
                map_scores.append(metrics["mAP"])
                iou_scores.append(metrics["IoU"])
                for cls in range(1, 5):
                    for name, value in metrics["class_metrics"].get(cls, {}).items():
                        per_class[cls][name].append(value)

    def _mean(values):
        return np.mean(values) if values else 0.0

    return {
        "mAP": _mean(map_scores),
        "IoU": _mean(iou_scores),
        "class_metrics": {
            cls: {name: _mean(per_class[cls][name]) for name in metric_names}
            for cls in range(1, 5)
        }
    }

def visualize_predictions(model, dataset, device, num_images=5, confidence_thresholds=None, save_path=None):
    """
    Visualize model predictions next to ground truth using class-specific confidence thresholds.

    One random sample is drawn per subplot row. Ground-truth boxes are labelled
    "GT: <class>" and kept predictions are labelled "<class>: <score>". The raw
    per-class detection scores of every sample are printed before filtering.

    Parameters:
    - model: Trained detection model; called as ``model([image])`` and indexed as a
      list of dicts with 'boxes', 'labels' and 'scores'
    - dataset: Dataset whose items are ``(image_tensor, target_dict, valid_flag)``;
      invalid samples are skipped (their row stays empty)
    - device: Device to run inference on
    - num_images: Maximum number of random samples to visualize
    - confidence_thresholds: Dictionary mapping class IDs to confidence thresholds;
      classes missing from the dict fall back to 0.5
    - save_path: Path to save visualization (if None, uses "mask_rcnn_predictions.png")

    Returns:
    - Path to saved visualization
    """
    model.eval()

    # Default confidence thresholds if none provided
    if confidence_thresholds is None:
        confidence_thresholds = {1: 0.55, 2: 0.8, 3: 0.45, 4: 0.65}

    # Class names and colors for display
    class_names = {
        1: 'People',
        2: 'Encampments',
        3: 'Cart',
        4: 'Bike'
    }

    print(f"Using confidence thresholds: {confidence_thresholds}")

    class_colors = {
        1: (255, 0, 0),    # Red for people
        2: (0, 255, 0),    # Green for encampments
        3: (0, 0, 255),    # Blue for carts
        4: (255, 255, 0)   # Yellow for bikes
    }

    # Per-channel statistics used to undo the input normalization for display
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # Random samples without replacement, never more than the dataset holds
    indices = np.random.choice(len(dataset), min(num_images, len(dataset)), replace=False)
    # Size the grid by the samples actually drawn so a small dataset does not
    # leave trailing empty rows.
    n_rows = len(indices)

    plt.figure(figsize=(20, 20))

    for i, idx in enumerate(indices):
        img, target, valid = dataset[idx]
        if not valid:
            continue

        with torch.no_grad():
            prediction = model([img.to(device)])[0]

        pred_boxes = prediction['boxes'].cpu()
        pred_labels = prediction['labels'].cpu()
        pred_scores = prediction['scores'].cpu()

        print(f"\nImage {idx} detection scores:")
        for class_id in range(1, 5):
            scores = pred_scores[pred_labels == class_id].tolist()
            print(f"Class {class_id} ({class_names[class_id]}): {scores}")

        # Each detection is compared against the threshold of its own class
        per_det_thresholds = torch.tensor(
            [confidence_thresholds.get(int(label), 0.5) for label in pred_labels],
            dtype=pred_scores.dtype,
        )
        keep = pred_scores >= per_det_thresholds

        img_h, img_w = img.shape[1], img.shape[2]
        boxes = pred_boxes[keep].numpy()
        # Clip to the image before converting to int so drawn boxes stay on-canvas
        boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, img_w - 1)
        boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, img_h - 1)
        boxes = boxes.astype(np.int32)
        labels = pred_labels[keep].numpy()
        scores = pred_scores[keep].numpy()

        # Convert the normalized CHW tensor back to a displayable uint8 HWC image
        image_np = img.permute(1, 2, 0).cpu().numpy()
        image_np = np.clip(std * image_np + mean, 0, 1)
        image_with_boxes = (image_np * 255).astype(np.uint8)

        gt_boxes = target['boxes'].cpu().numpy().astype(np.int32)
        gt_labels = target['labels'].cpu().numpy()

        # Draw ground truth first so predictions are drawn on top
        for box, label in zip(gt_boxes, gt_labels):
            color = class_colors.get(label.item(), (255, 255, 255))
            cv2.rectangle(image_with_boxes, (box[0], box[1]), (box[2], box[3]),
                          color, 2, cv2.LINE_AA)
            cv2.putText(image_with_boxes, f"GT: {class_names.get(label.item(), 'Unknown')}",
                        (box[0], box[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        # Draw predictions
        for box, label, score in zip(boxes, labels, scores):
            color = class_colors.get(label.item(), (255, 255, 255))
            cv2.rectangle(image_with_boxes, (box[0], box[1]), (box[2], box[3]),
                          color, 2, cv2.LINE_AA)
            cv2.putText(image_with_boxes,
                        f"{class_names.get(label.item(), 'Unknown')}: {score:.2f}",
                        (box[0], box[1] - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        plt.subplot(n_rows, 1, i + 1)
        plt.imshow(image_with_boxes)
        plt.title(f"Sample {idx} (Inference with Class-Specific Thresholds)")
        plt.axis('off')

    plt.tight_layout()

    if save_path is None:
        save_path = "mask_rcnn_predictions.png"

    plt.savefig(save_path)
    plt.close()

    return save_path


def _plot_experiment_metric(ax, experiment_results, history_key, colors, title, ylabel):
    """Plot ``history[history_key]`` for every experiment on ``ax`` (colors cycle)."""
    for i, result in enumerate(experiment_results):
        ax.plot(result["history"][history_key],
                label=f"{result['experiment_name']}",
                color=colors[i % len(colors)])
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()


def _plot_class_metric(ax, experiment_results, metric, class_names, title, ylabel):
    """Plot a per-class metric for every experiment; only the first experiment is labelled."""
    for i, result in enumerate(experiment_results):
        linestyle = '-' if i % 2 == 0 else '--'
        for cls in class_names:
            ax.plot(result["history"]["class_metrics"][cls][metric],
                    label=f"{class_names[cls]} - {result['experiment_name']}" if i == 0 else "",
                    color=f'C{cls}', linestyle=linestyle, alpha=0.7)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()


def plot_training_history(experiment_results, save_path="training_history.png"):
    """
    Plot training history comparing multiple experiments in a 2x3 grid.

    Top row: training loss, validation mAP, validation IoU (one line per experiment).
    Bottom row: per-class F1, AP and IoU; classes are distinguished by color and
    experiments by alternating solid/dashed linestyle.

    Parameters:
    - experiment_results: List of dictionaries, each with 'experiment_name' and a
      'history' dict holding 'train_loss', 'val_mAP', 'val_IoU' and
      'class_metrics'[class_id]['f1' | 'AP' | 'IoU'] series for class ids 1-4
    - save_path: Path to save the plot

    Returns:
    - Path to the saved plot
    """
    fig, axs = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle('Training Results Comparison', fontsize=16)

    # Base experiment colors; cycled so more than four experiments no longer
    # raises IndexError.
    colors = ['b', 'r', 'g', 'c']
    class_names = {1: 'People', 2: 'Encampments', 3: 'Cart', 4: 'Bike'}

    _plot_experiment_metric(axs[0, 0], experiment_results, "train_loss", colors,
                            'Training Loss', 'Loss')
    _plot_experiment_metric(axs[0, 1], experiment_results, "val_mAP", colors,
                            'Validation mAP', 'mAP')
    _plot_experiment_metric(axs[0, 2], experiment_results, "val_IoU", colors,
                            'Validation IoU', 'IoU')

    # Class F1: one empty proxy line per class gives a class-only legend;
    # the real curves are unlabelled and differ by linestyle per experiment.
    ax = axs[1, 0]
    for cls, name in class_names.items():
        ax.plot([], [], label=name, color=f'C{cls}')
    for i, result in enumerate(experiment_results):
        linestyle = '-' if i % 2 == 0 else '--'
        for cls in class_names:
            ax.plot(result["history"]["class_metrics"][cls]["f1"],
                    color=f'C{cls}', linestyle=linestyle, alpha=0.7)
    ax.set_title('Class F1 Scores')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('F1 Score')
    ax.legend()

    _plot_class_metric(axs[1, 1], experiment_results, "AP", class_names,
                       'Class AP Scores', 'AP')
    _plot_class_metric(axs[1, 2], experiment_results, "IoU", class_names,
                       'Class IoU Scores', 'IoU')

    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)

    return save_path

def _draw_solid_boxes(image, boxes, class_ids, class_colors, class_names, label_suffix=""):
    """Draw solid rectangles with class-name labels onto ``image`` in place and return it."""
    for box, class_id in zip(boxes, class_ids):
        # cv2 requires integer pixel coordinates; cast in case raw boxes are floats
        x1, y1, x2, y2 = (int(v) for v in box)
        color = class_colors.get(class_id, (255, 255, 255))
        cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
        cv2.putText(image, f"{class_names.get(class_id, 'Unknown')}{label_suffix}",
                    (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return image


def _scale_boxes_to_canvas(boxes, class_ids, scale_x, scale_y, width, height, min_size=5):
    """Scale boxes onto a ``width`` x ``height`` canvas, clip them, and drop boxes under ``min_size`` px."""
    scaled_boxes, scaled_ids = [], []
    for box_idx, (x1, y1, x2, y2) in enumerate(boxes):
        x1s = max(0, int(x1 * scale_x))
        y1s = max(0, int(y1 * scale_y))
        x2s = min(width, int(x2 * scale_x))
        y2s = min(height, int(y2 * scale_y))
        if x2s > x1s and y2s > y1s and (x2s - x1s) >= min_size and (y2s - y1s) >= min_size:
            scaled_boxes.append([x1s, y1s, x2s, y2s])
            # Always append an id so boxes and ids stay aligned for zip();
            # falls back to class 1 like the mosaic path.
            scaled_ids.append(class_ids[box_idx] if box_idx < len(class_ids) else 1)
    return scaled_boxes, scaled_ids


def visualize_augmentations_with_boxes(data_dir, xml_files, transform=None):
    """
    Visualize basic, mosaic and mixup views of the data with bounding boxes drawn.

    Builds an un-normalized HomelessDataset over the first 50 XML files, picks a
    random sample (retrying up to 10 times to find one with boxes) and saves a
    3-row figure: the raw image, a 2x2 mosaic of 512x512 pixels, and a 50/50
    mixup with a second random image (its boxes drawn dashed).

    Parameters:
    - data_dir: Directory containing images and XML files
    - xml_files: List of XML files
    - transform: Currently unused (the visualization dataset is always built with
      ``transform=None`` so images are not normalized); kept for API compatibility

    Returns:
    - Path of the saved figure ("augmentation_examples_with_boxes.png")

    Raises:
    - ValueError: if ``xml_files`` yields an empty dataset
    """
    # Create a dataset just for visualization (no normalization)
    viz_dataset = HomelessDataset(
        data_dir,
        xml_files[:50],  # Use a subset
        transform=None,
        augmentation_type="basic"
    )
    if len(viz_dataset) == 0:
        raise ValueError("visualize_augmentations_with_boxes needs at least one XML file")

    plt.figure(figsize=(15, 15))

    # Pick a random index with objects (up to 10 attempts; keeps the last pick otherwise)
    for _ in range(10):
        idx = np.random.randint(0, len(viz_dataset))
        raw_image, boxes, class_ids, img_shape = viz_dataset._get_image_and_targets(idx)
        if len(boxes) > 0:
            break

    class_colors = {
        1: (0, 255, 0),     # Green for People
        2: (0, 0, 255),     # Blue for Encampments
        3: (255, 0, 0),     # Red for Cart
        4: (255, 255, 0)    # Yellow for Bike
    }

    class_names = {
        1: 'People',
        2: 'Encampments',
        3: 'Cart',
        4: 'Bike'
    }

    canvas_size = 512

    # 1. Basic image with boxes
    plt.subplot(3, 1, 1)
    plt.imshow(_draw_solid_boxes(raw_image.copy(), boxes, class_ids, class_colors, class_names))
    plt.title("Basic (No Augmentation)")
    plt.axis('off')

    # 2. Mosaic: four images resized into the quadrants of a square canvas
    plt.subplot(3, 1, 2)

    indices = [idx] + [random.randint(0, len(viz_dataset.xml_files) - 1) for _ in range(3)]
    mosaic_img = np.zeros((canvas_size, canvas_size, 3), dtype=np.uint8)
    mosaic_boxes = []
    mosaic_class_ids = []

    cx, cy = canvas_size // 2, canvas_size // 2  # Center point / quadrant size
    placements = [(0, 0), (cx, 0), (0, cy), (cx, cy)]  # TL, TR, BL, BR

    for (x_offset, y_offset), mosaic_idx in zip(placements, indices):
        img, img_boxes, img_class_ids, (h, w) = viz_dataset._get_image_and_targets(mosaic_idx)

        mosaic_img[y_offset:y_offset + cy, x_offset:x_offset + cx] = cv2.resize(img, (cx, cy))

        scale_x = cx / w
        scale_y = cy / h
        last = canvas_size - 1

        for box_idx, box in enumerate(img_boxes):
            if len(box) != 4:  # Skip boxes not in (x1, y1, x2, y2) form
                continue
            x1, y1, x2, y2 = box

            # Scale into the quadrant, offset, then clip to the canvas
            x1_new = max(0, min(last, int(x1 * scale_x) + x_offset))
            y1_new = max(0, min(last, int(y1 * scale_y) + y_offset))
            x2_new = max(0, min(last, int(x2 * scale_x) + x_offset))
            y2_new = max(0, min(last, int(y2 * scale_y) + y_offset))

            if x2_new > x1_new and y2_new > y1_new and (x2_new - x1_new) >= 5 and (y2_new - y1_new) >= 5:
                mosaic_boxes.append([x1_new, y1_new, x2_new, y2_new])
                # Default to class 1 if the id list is shorter than the box list
                mosaic_class_ids.append(img_class_ids[box_idx] if box_idx < len(img_class_ids) else 1)

    plt.imshow(_draw_solid_boxes(mosaic_img.copy(), mosaic_boxes, mosaic_class_ids,
                                 class_colors, class_names))
    plt.title("Mosaic Augmentation")
    plt.axis('off')

    # 3. Mixup: blend two resized images 50/50 and draw both sets of boxes
    plt.subplot(3, 1, 3)

    idx2 = random.randint(0, len(viz_dataset.xml_files) - 1)
    img1, boxes1, class_ids1, (h1, w1) = viz_dataset._get_image_and_targets(idx)
    img2, boxes2, class_ids2, (h2, w2) = viz_dataset._get_image_and_targets(idx2)

    mix_w = mix_h = canvas_size
    img1_resized = cv2.resize(img1, (mix_w, mix_h))
    img2_resized = cv2.resize(img2, (mix_w, mix_h))

    scaled_boxes1, scaled_class_ids1 = _scale_boxes_to_canvas(
        boxes1, class_ids1, mix_w / w1, mix_h / h1, mix_w, mix_h)
    scaled_boxes2, scaled_class_ids2 = _scale_boxes_to_canvas(
        boxes2, class_ids2, mix_w / w2, mix_h / h2, mix_w, mix_h)

    # Fixed alpha for visualization clarity
    alpha = 0.5
    mixed_img = cv2.addWeighted(img1_resized, alpha, img2_resized, 1.0 - alpha, 0)

    # First image: solid boxes
    mixed_with_boxes = _draw_solid_boxes(mixed_img.copy(), scaled_boxes1, scaled_class_ids1,
                                         class_colors, class_names, label_suffix=" (img1)")

    # Second image: dashed boxes (3px dash, 2px gap), label below the box
    for (x1, y1, x2, y2), class_id in zip(scaled_boxes2, scaled_class_ids2):
        color = class_colors.get(class_id, (255, 255, 255))
        for pos in range(x1, x2, 5):
            cv2.line(mixed_with_boxes, (pos, y1), (min(pos + 3, x2), y1), color, 2)
            cv2.line(mixed_with_boxes, (pos, y2), (min(pos + 3, x2), y2), color, 2)
        for pos in range(y1, y2, 5):
            cv2.line(mixed_with_boxes, (x1, pos), (x1, min(pos + 3, y2)), color, 2)
            cv2.line(mixed_with_boxes, (x2, pos), (x2, min(pos + 3, y2)), color, 2)

        cv2.putText(mixed_with_boxes, f"{class_names.get(class_id, 'Unknown')} (img2)",
                    (x1, y2 + 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    plt.imshow(mixed_with_boxes)
    plt.title("Mixup Augmentation")
    plt.axis('off')

    plt.tight_layout()
    plt.savefig("augmentation_examples_with_boxes.png")
    plt.close()

    return "augmentation_examples_with_boxes.png"


def grid_search_thresholds(data_dir, train_xml_files, val_xml_files, num_epochs=5,
                           threshold_ranges=None, batch_size=4, lr=0.005, num_workers=2,
                           model_save_path="grid_search_model.pth"):
    """
    Train a model briefly, then grid-search a per-class confidence threshold maximizing F1.

    Each class is searched independently with ``evaluate_class_threshold``. The other
    classes stay at 0.5; they do not affect the score for the searched class. If no
    candidate threshold gives F1 > 0, that class keeps the default 0.5.

    Parameters:
    - data_dir: Directory containing images and XML files
    - train_xml_files: List of XML files for training
    - val_xml_files: List of XML files for validation
    - num_epochs: Number of training epochs for the grid search model
    - threshold_ranges: Optional dict mapping class id (1-4) to an iterable of
      candidate thresholds; defaults to 0.30-0.75 (0.30-0.85 for Encampments) in 0.05 steps
    - batch_size: DataLoader batch size for training and validation
    - lr: Initial SGD learning rate
    - num_workers: DataLoader worker processes
    - model_save_path: Where the trained model's state_dict is saved

    Returns:
    - Dictionary with optimal confidence thresholds for each class
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(f"Using device: {device}")

    print(f"\n{'='*80}")
    print("GRID SEARCH FOR OPTIMAL CONFIDENCE THRESHOLDS")
    print(f"{'='*80}\n")

    # Basic augmentation, no special sampling
    train_dataset = HomelessDataset(data_dir, train_xml_files,
                                    transform=get_basic_transform(train=True),
                                    augmentation_type="basic")
    val_dataset = HomelessDataset(data_dir, val_xml_files,
                                  transform=get_basic_transform(train=False),
                                  augmentation_type="basic")

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=num_workers
    )

    num_classes = 5  # Background + 4 classes
    model = get_model(num_classes)
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=0.0005)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=1e-6
    )

    print("\nTraining model for grid search...")
    for epoch in range(num_epochs):
        train_loss = train_one_epoch(
            model, optimizer, train_loader, device, epoch, print_freq=10
        )
        lr_scheduler.step()
        print(f"Epoch {epoch}, Loss: {train_loss:.4f}")

    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

    if threshold_ranges is None:
        # Built from integer percentages so every candidate is the exact decimal
        # (0.30, 0.35, ...) and the stop value is reliably excluded; a float
        # step in np.arange does not guarantee either.
        threshold_ranges = {
            1: np.arange(30, 80, 5) / 100.0,  # People
            2: np.arange(30, 90, 5) / 100.0,  # Encampments
            3: np.arange(30, 80, 5) / 100.0,  # Cart
            4: np.arange(30, 80, 5) / 100.0,  # Bike
        }

    class_names = {
        1: 'People',
        2: 'Encampments',
        3: 'Cart',
        4: 'Bike'
    }

    best_thresholds = {cls: 0.5 for cls in range(1, 5)}
    best_metrics = {cls: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'AP': 0.0} for cls in range(1, 5)}

    model.eval()

    for class_id in range(1, 5):
        print(f"\n{'='*80}")
        print(f"GRID SEARCH FOR CLASS {class_id} ({class_names[class_id]})")
        print(f"{'='*80}")

        print(f"{'Threshold':<10} {'Precision':<10} {'Recall':<10} {'F1':<10} {'AP':<10}")
        print("-" * 50)

        current_thresholds = {cls: 0.5 for cls in range(1, 5)}

        for threshold in threshold_ranges[class_id]:
            threshold = float(threshold)
            current_thresholds[class_id] = threshold

            class_metrics = evaluate_class_threshold(
                model, val_loader, device, class_id, current_thresholds
            )

            precision = class_metrics.get('precision', 0.0)
            recall = class_metrics.get('recall', 0.0)
            f1 = class_metrics.get('f1', 0.0)
            ap = class_metrics.get('AP', 0.0)

            print(f"{threshold:<10.2f} {precision:<10.4f} "
                  f"{recall:<10.4f} {f1:<10.4f} "
                  f"{ap:<10.4f}")

            # Strictly greater: ties keep the earlier (lower) threshold
            if f1 > best_metrics[class_id]['f1']:
                best_metrics[class_id] = class_metrics.copy()
                best_thresholds[class_id] = threshold

        best = best_metrics[class_id]
        print(f"\nBest threshold for {class_names[class_id]}: {best_thresholds[class_id]:.2f}")
        print(f"Best F1: {best.get('f1', 0.0):.4f}, "
              f"Precision: {best.get('precision', 0.0):.4f}, "
              f"Recall: {best.get('recall', 0.0):.4f}, "
              f"AP: {best.get('AP', 0.0):.4f}")

    print(f"\n{'='*80}")
    print("OPTIMAL CONFIDENCE THRESHOLDS")
    print(f"{'='*80}")

    print(f"{'Class':<15} {'Threshold':<10}")
    print("-" * 25)

    for class_id in range(1, 5):
        print(f"{class_names[class_id]:<15} {best_thresholds[class_id]:<10.2f}")

    return best_thresholds

def evaluate_class_threshold(model, data_loader, device, target_class, thresholds):
    """
    Evaluate precision/recall/F1 for a single class at that class's confidence threshold.

    Detections labelled ``target_class`` with score >= ``thresholds[target_class]`` are
    matched one-to-one against that class's ground-truth boxes using ``match_boxes``
    (default IoU threshold). True positives, false positives and false negatives are
    summed over the whole loader, and the metrics are computed from those totals.

    Parameters:
    - model: Trained model; called on a list of images, returns per-image dicts with
      'boxes', 'labels' and 'scores'
    - data_loader: Validation data loader yielding ``(images, targets)`` batches
    - device: Device to run on
    - target_class: Class ID to evaluate
    - thresholds: Dictionary of thresholds for each class (only the target class's is used)

    Returns:
    - Dictionary with 'precision', 'recall', 'f1' and 'AP' for the target class. 'AP' is
      a simplified stand-in (precision * recall), not the area under the PR curve.
      Any metric with a zero denominator is reported as 0.0.
    """
    threshold = thresholds[target_class]
    true_pos = 0
    false_pos = 0
    false_neg = 0

    with torch.no_grad():
        for images, targets in data_loader:
            images = [img.to(device) for img in images]
            results = model(images)

            for result, target in zip(results, targets):
                pred_labels = result['labels'].cpu()
                pred_scores = result['scores'].cpu()
                # Keep only confident detections of the class under evaluation
                keep = (pred_labels == target_class) & (pred_scores >= threshold)
                pred_boxes = result['boxes'].cpu()[keep]

                gt_boxes = target['boxes'].cpu()
                gt_boxes = gt_boxes[target['labels'].cpu() == target_class]

                # One bool per kept prediction; each GT can be matched at most once
                matches = match_boxes(pred_boxes, gt_boxes)
                tp = sum(bool(m) for m in matches)

                true_pos += tp
                false_pos += len(matches) - tp   # unmatched predictions
                false_neg += len(gt_boxes) - tp  # unmatched ground truths

    precision = true_pos / (true_pos + false_pos) if (true_pos + false_pos) else 0.0
    recall = true_pos / (true_pos + false_neg) if (true_pos + false_neg) else 0.0
    # Equivalent to 2PR / (P + R), but exact when TP == 0
    f1 = 2 * true_pos / (2 * true_pos + false_pos + false_neg) if true_pos else 0.0

    return {
        'precision': float(precision),
        'recall': float(recall),
        'f1': float(f1),
        'AP': float(precision * recall)
    }



# First define the helper functions
def tensor_to_image(tensor):
    """
    Turn a normalized CHW image tensor into a displayable HWC uint8 array.

    The per-channel mean/std normalization is reversed, values are clipped to
    [0, 1] and scaled to 0-255. The uint8 cast truncates fractional parts.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    # CHW -> HWC on the host so the per-channel stats broadcast over the last axis
    hwc = tensor.permute(1, 2, 0).cpu().numpy()
    restored = np.clip(channel_std * hwc + channel_mean, 0, 1)
    return (restored * 255).astype(np.uint8)

def draw_boxes(image, target):
    """
    Draw the boxes of ``target`` onto a copy of ``image``.

    Class ids follow the dataset's 4-class mapping (1 People, 2 Encampments,
    3 Cart, 4 Bike). The earlier 2-class mapping here mislabelled every box.
    Unknown ids are drawn in white and labelled "Unknown".

    Parameters:
    - image: image array to draw on (e.g. the output of tensor_to_image)
    - target: dict with 'boxes' (one x1, y1, x2, y2 row per object) and 'labels' tensors

    Returns:
    - A new image with boxes and class names drawn; ``image`` itself is not modified
    """
    import cv2
    image = image.copy()

    # Same colors as visualize_predictions
    class_colors = {
        1: (255, 0, 0),    # Red for people
        2: (0, 255, 0),    # Green for encampments
        3: (0, 0, 255),    # Blue for carts
        4: (255, 255, 0)   # Yellow for bikes
    }

    class_names = {
        1: 'People',
        2: 'Encampments',
        3: 'Cart',
        4: 'Bike'
    }

    boxes = target['boxes'].cpu().numpy().astype(np.int32)
    labels = target['labels'].cpu().numpy()

    for box, label in zip(boxes, labels):
        color = class_colors.get(label.item(), (255, 255, 255))
        cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]),
                      color, 2, cv2.LINE_AA)
        cv2.putText(image, f"{class_names.get(label.item(), 'Unknown')}",
                    (box[0], box[1] - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    return image

def match_boxes(pred_boxes, gt_boxes, iou_threshold=0.5):
    """
    Greedily match predicted boxes to ground-truth boxes, one-to-one.

    Predictions are processed in the order given. Each one takes the still-unmatched
    ground-truth box with the highest IoU (from ``calculate_iou``). It counts as a
    match only if that IoU is at least ``iou_threshold``; the GT box is then consumed.

    Parameters:
    - pred_boxes: Predicted bounding boxes
    - gt_boxes: Ground truth bounding boxes
    - iou_threshold: IoU threshold for considering a match

    Returns:
    - List of bools, one per prediction, True when that prediction matched a ground truth
    """
    if len(pred_boxes) == 0:
        return []

    if len(gt_boxes) == 0:
        return [False] * len(pred_boxes)

    matches = []
    matched_gt = set()

    for pred_box in pred_boxes:
        best_iou = 0
        best_gt_idx = -1

        for gt_idx, gt_box in enumerate(gt_boxes):
            if gt_idx in matched_gt:
                continue  # Each ground truth may be claimed once

            iou = calculate_iou(pred_box, gt_box)
            if iou > best_iou:
                best_iou = iou
                best_gt_idx = gt_idx

        # best_gt_idx stays -1 when no unmatched GT overlaps at all. Without this
        # check, iou_threshold <= 0 would count such a prediction as a match.
        is_match = bool(best_gt_idx >= 0 and best_iou >= iou_threshold)
        if is_match:
            matched_gt.add(best_gt_idx)
        matches.append(is_match)

    return matches

def run_experiment(data_dir, train_xml_files, val_xml_files,
                  augmentation_type="basic", sampling_strategy="class_aware",
                  loss_reweighting=False, num_epochs=15, experiment_name="experiment",
                  confidence_thresholds=None):
    """
    Run a training experiment with the specified configuration.

    Workflow:
    - Build a detector with get_model(5) (background + 4 classes).
    - Train it with SGD and a cosine-annealed learning rate.
    - Evaluate on the validation set after every epoch and save the
      state_dict with the best validation mAP.
    - Reload that best checkpoint and render prediction visualizations.

    Parameters:
    - data_dir: Directory containing images and XML files
    - train_xml_files: List of XML files for training
    - val_xml_files: List of XML files for validation
    - augmentation_type: "basic" or "mosaic_mixup" (training set only;
      validation always uses "basic")
    - sampling_strategy: "class_aware" uses a WeightedRandomSampler driven by
      train_dataset.get_sample_weights(); any other value uses plain shuffling
    - loss_reweighting: Whether to train with inverse-frequency class weights
      via train_one_epoch_reweighed
    - num_epochs: Number of training epochs (must be >= 1)
    - experiment_name: Name used for the checkpoint and visualization files
    - confidence_thresholds: Optional dict of class-specific confidence
      thresholds used for visualization

    Returns:
    - Dictionary with training results (configuration, best validation
      metrics, per-epoch history, checkpoint path, visualization path)

    Raises:
    - ValueError: if num_epochs < 1. No checkpoint would be written, so the
      best-model reload at the end could not succeed.
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be at least 1, got {num_epochs}")

    class_ids = (1, 2, 3, 4)
    class_labels = {1: "People", 2: "Encampments", 3: "Cart", 4: "Bike"}
    model_path = f"maskrcnn_{experiment_name}_weights.pth"

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(f"Using device: {device}")

    print(f"\n{'='*80}")
    print(f"EXPERIMENT: {experiment_name}")
    print(f"Augmentation: {augmentation_type}")
    print(f"Sampling Strategy: {sampling_strategy}")
    print(f"Loss Reweighting: {loss_reweighting}")

    # Default confidence thresholds if none provided
    if confidence_thresholds is None:
        confidence_thresholds = {1: 0.55, 2: 0.8, 3: 0.45, 4: 0.65}

    print(f"Confidence Thresholds: {confidence_thresholds}")
    print(f"{'='*80}\n")

    train_dataset = HomelessDataset(data_dir, train_xml_files,
                                   transform=get_basic_transform(train=True),
                                   augmentation_type=augmentation_type)
    val_dataset = HomelessDataset(data_dir, val_xml_files,
                                 transform=get_basic_transform(train=False),
                                 augmentation_type="basic")  # Always use basic for validation

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    # Count instances on an un-augmented view of the training files. The
    # training dataset applies random transforms (and mosaic/mixup pulls in
    # boxes from other images), which made the counts noisy and
    # non-reproducible.
    print("Calculating class distribution...")
    count_dataset = HomelessDataset(data_dir, train_xml_files,
                                    transform=get_basic_transform(train=False),
                                    augmentation_type="basic")
    class_counts = _count_class_instances(count_dataset, class_ids)
    class_weights = _inverse_frequency_class_weights(class_counts)

    print("\nClass distribution and weights:")
    print(f"{'Class':<10} {'Count':<10} {'Weight':<10}")
    print("-" * 30)
    for class_id, count in class_counts.items():
        if count > 0:
            print(f"{class_id:<10} {count:<10} {class_weights[class_id]:<10.4f}")
        else:
            print(f"{class_id:<10} {count:<10} {'1.0 (no instances)':<10}")

    loader_kwargs = dict(batch_size=4, collate_fn=collate_fn, num_workers=2)
    if sampling_strategy == "class_aware":
        sample_weights = train_dataset.get_sample_weights()
        weighted_sampler = WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(sample_weights),
            replacement=True
        )
        train_loader = DataLoader(train_dataset, sampler=weighted_sampler, **loader_kwargs)
        print("Using class-aware sampling")
    else:  # any other value falls back to random sampling
        train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
        print("Using random sampling")

    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    num_classes = 5  # Background + 4 classes
    model = get_model(num_classes)
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(
        params,
        lr=0.005,
        momentum=0.9,
        weight_decay=0.0005
    )

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=num_epochs,
        eta_min=1e-6
    )

    best_metrics = None
    history = {
        "train_loss": [],
        "val_mAP": [],
        "val_IoU": [],
        "class_metrics": {cls: {"f1": [], "AP": [], "IoU": []} for cls in class_ids}
    }

    if loss_reweighting:
        print(f"Using loss reweighting with weights: {class_weights}")
    else:
        print("Not using loss reweighting")

    print("\nStarting training...")
    for epoch in range(num_epochs):
        if loss_reweighting:
            train_loss = train_one_epoch_reweighed(
                model, optimizer, train_loader, device, epoch, class_weights, print_freq=10
            )
        else:
            train_loss = train_one_epoch(
                model, optimizer, train_loader, device, epoch, print_freq=10
            )

        lr_scheduler.step()

        val_metrics = evaluate_model(model, val_loader, device)

        history["train_loss"].append(train_loss)
        history["val_mAP"].append(val_metrics["mAP"])
        history["val_IoU"].append(val_metrics["IoU"])
        for cls in class_ids:
            for key in ("f1", "AP", "IoU"):
                history["class_metrics"][cls][key].append(val_metrics["class_metrics"][cls][key])

        ap_summary = ", ".join(
            f"{name}: {val_metrics['class_metrics'][cls]['AP']:.4f}" for cls, name in class_labels.items()
        )
        f1_summary = ", ".join(
            f"{name}: {val_metrics['class_metrics'][cls]['f1']:.4f}" for cls, name in class_labels.items()
        )
        print(f"\nEpoch {epoch} completed:")
        print(f"  Training Loss: {train_loss:.4f}")
        print(f"  Validation mAP@0.5: {val_metrics['mAP']:.4f}, IoU: {val_metrics['IoU']:.4f}")
        print(f"  Class AP: {ap_summary}")
        print(f"  Class F1: {f1_summary}")
        # Printed after scheduler.step(), so this is the LR for the next epoch.
        print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

        if best_metrics is None or val_metrics["mAP"] > best_metrics["mAP"]:
            best_metrics = val_metrics
            torch.save(model.state_dict(), model_path)
            print(f"  New best mAP: {val_metrics['mAP']:.4f}, model saved to {model_path}")

    print(f"\nTraining complete! Best mAP: {best_metrics['mAP']:.4f}")

    # Reload the best checkpoint; map_location keeps this working if the file
    # is later loaded on a machine with a different device.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))

    print("Generating prediction visualizations...")
    viz_path = visualize_predictions(
        model, val_dataset, device, num_images=5,
        confidence_thresholds=confidence_thresholds,
        save_path=f"predictions_{experiment_name}.png"
    )
    print(f"Visualization saved to {viz_path}")

    return {
        "experiment_name": experiment_name,
        "augmentation_type": augmentation_type,
        "sampling_strategy": sampling_strategy,
        "loss_reweighting": loss_reweighting,
        "confidence_thresholds": confidence_thresholds,
        "best_metrics": best_metrics,
        "history": history,
        "model_path": model_path,
        "viz_path": viz_path
    }


def _count_class_instances(dataset, class_ids):
    """Count annotated instances per class id over every valid sample of a (img, target, valid) dataset."""
    counts = {class_id: 0 for class_id in class_ids}
    for idx in tqdm(range(len(dataset))):
        _, target, is_valid = dataset[idx]
        if not is_valid or 'labels' not in target:
            continue
        for label in target['labels']:
            class_id = label.item()
            if class_id in counts:
                counts[class_id] += 1
    return counts


def _inverse_frequency_class_weights(class_counts):
    """Weight each class by total / (count * num_classes); classes with no instances get 1.0."""
    total_instances = sum(class_counts.values())
    num_classes = len(class_counts)
    return {
        class_id: (total_instances / (count * num_classes) if count > 0 else 1.0)
        for class_id, count in class_counts.items()
    }

def create_train_val_test_split(xml_files, val_size=0.15, test_size=0.15, random_state=42):
    """
    Create a 3-way split of the data into training, validation, and test sets.

    The test set is carved off first. The validation fraction is then
    rescaled against the remaining files, so it is still ``val_size`` of the
    *original* data. Either fraction may be 0: that split is skipped and an
    empty list is returned for it. (sklearn's train_test_split rejects a
    test_size of 0.)

    Parameters:
    - xml_files: List of XML filenames
    - val_size: Proportion of all data to use for validation (>= 0)
    - test_size: Proportion of all data to use for testing (>= 0)
    - random_state: Random seed for reproducibility

    Returns:
    - train_xml_files, val_xml_files, test_xml_files: Lists of XML filenames

    Raises:
    - ValueError: if a proportion is negative or val_size + test_size >= 1
    """
    if val_size < 0 or test_size < 0 or val_size + test_size >= 1:
        raise ValueError(
            f"val_size and test_size must be non-negative and sum to less than 1, "
            f"got val_size={val_size}, test_size={test_size}"
        )

    # First split off the test set
    if test_size > 0:
        train_val_files, test_xml_files = train_test_split(
            xml_files,
            test_size=test_size,
            random_state=random_state
        )
    else:
        train_val_files, test_xml_files = list(xml_files), []

    # Then split the remaining data into train and validation. The fraction
    # is rescaled so it refers to the original dataset size.
    if val_size > 0:
        effective_val_size = val_size / (1 - test_size)
        train_xml_files, val_xml_files = train_test_split(
            train_val_files,
            test_size=effective_val_size,
            random_state=random_state
        )
    else:
        train_xml_files, val_xml_files = list(train_val_files), []

    return train_xml_files, val_xml_files, test_xml_files

# Function to evaluate on test set
def evaluate_on_test_set(model, test_loader, device, confidence_thresholds=None):
    """
    Evaluate model performance on the test set.

    For every image, the predictions and ground truth are passed to
    ``calculate_metrics`` with an IoU threshold of 0.5. Predictions are
    threshold-filtered first if thresholds are given. The per-image results
    are then averaged over all images. "mAP" is therefore the mean of the
    per-image values from calculate_metrics, not a dataset-level mAP over
    pooled detections.

    Parameters:
    - model: Trained model (put into eval mode by this function)
    - test_loader: Iterable yielding (images, targets) batches
    - device: Device to run evaluation on
    - confidence_thresholds: Optional dict mapping class id -> minimum score.
      When non-empty, a prediction is kept only if its label is in the dict
      and its score meets that class's threshold. Labels absent from the
      dict are dropped.

    Returns:
    - Dict with "mAP", "IoU" and "class_metrics" (class id 1-4 ->
      precision/recall/f1/AP/IoU). A metric with no reported values averages
      to 0.0.
    """
    model.eval()

    class_ids = range(1, 5)
    metric_names = ["precision", "recall", "f1", "AP", "IoU"]
    per_image_map = []
    per_image_iou = []
    per_class_values = {cls: {name: [] for name in metric_names} for cls in class_ids}

    with torch.no_grad():
        for images, targets in test_loader:
            images = [img.to(device) for img in images]
            predictions = model(images)

            if confidence_thresholds:
                predictions = [
                    _filter_predictions_by_class_threshold(prediction, confidence_thresholds)
                    for prediction in predictions
                ]

            # Targets are only read on the CPU, so they are never moved to `device`.
            for prediction, target in zip(predictions, targets):
                metrics = calculate_metrics(
                    prediction['boxes'].cpu(), prediction['labels'].cpu(), prediction['scores'].cpu(),
                    target['boxes'].cpu(), target['labels'].cpu(), iou_threshold=0.5
                )

                per_image_map.append(metrics["mAP"])
                per_image_iou.append(metrics["IoU"])

                for cls in class_ids:
                    for metric_name, value in metrics["class_metrics"].get(cls, {}).items():
                        # setdefault tolerates extra metric names from calculate_metrics
                        per_class_values[cls].setdefault(metric_name, []).append(value)

    def _mean_or_zero(values):
        return np.mean(values) if values else 0.0

    return {
        "mAP": _mean_or_zero(per_image_map),
        "IoU": _mean_or_zero(per_image_iou),
        "class_metrics": {
            cls: {name: _mean_or_zero(per_class_values[cls][name]) for name in metric_names}
            for cls in class_ids
        },
    }


def _filter_predictions_by_class_threshold(prediction, confidence_thresholds):
    """Keep boxes/labels/scores whose label is in `confidence_thresholds` and whose score meets it."""
    labels = prediction['labels']
    scores = prediction['scores']
    keep = torch.zeros_like(labels, dtype=torch.bool)
    for label_id, threshold in confidence_thresholds.items():
        keep |= (labels == label_id) & (scores >= threshold)
    # Boolean-mask indexing yields correctly shaped empty tensors when nothing survives.
    return {
        'boxes': prediction['boxes'][keep],
        'labels': labels[keep],
        'scores': scores[keep],
    }

# Function to log test results
def log_test_results(experiment_name, test_metrics, confidence_thresholds=None):
    """
    Write a human-readable summary of test-set metrics to
    ``test_results_<experiment_name>.txt`` in the current directory,
    overwriting any existing file.

    Parameters:
    - experiment_name: Name of the experiment; used in the header and file name
    - test_metrics: Dict with "mAP", "IoU" and "class_metrics" keyed by class
      id 1-4, each holding precision/recall/f1/AP/IoU values
    - confidence_thresholds: Optional dict; written to the file only when truthy

    Returns:
    - The name of the file that was written
    """
    log_file = f"test_results_{experiment_name}.txt"
    display_names = {1: 'People', 2: 'Encampments', 3: 'Cart', 4: 'Bike'}
    # (label shown in the file, key in the per-class metrics dict)
    per_class_rows = (
        ("Precision", "precision"),
        ("Recall", "recall"),
        ("F1", "f1"),
        ("AP", "AP"),
        ("IoU", "IoU"),
    )

    with open(log_file, 'w') as out:
        emit = out.write
        emit(f"TEST RESULTS FOR EXPERIMENT: {experiment_name}\n")
        emit("=" * 80 + "\n\n")

        if confidence_thresholds:
            emit(f"Confidence Thresholds: {confidence_thresholds}\n\n")

        emit("Overall Metrics:\n")
        emit(f"  mAP@0.5: {test_metrics['mAP']:.4f}\n")
        emit(f"  IoU: {test_metrics['IoU']:.4f}\n\n")

        emit("Class-Specific Metrics:\n")
        for cls in range(1, 5):
            emit(f"  Class {cls} ({display_names[cls]}):\n")
            cls_metrics = test_metrics['class_metrics'][cls]
            for label, key in per_class_rows:
                emit(f"    {label}: {cls_metrics[key]:.4f}\n")
            emit("\n")  # blank line separating class sections

    print(f"Test results saved to {log_file}")
    return log_file
def visualize_augmentations_with_boxes(data_dir, xml_files, transform=None,
                                       save_path="augmentation_examples_with_boxes.png"):
    """
    Save a 3-panel figure comparing a raw image, a 2x2 mosaic, and a 50/50
    mixup, each with its bounding boxes drawn.

    Images are read through HomelessDataset._get_image_and_targets on the
    first 50 XML files with transform=None, so pixels are shown without
    normalization.

    Parameters:
    - data_dir: Directory containing images and XML files
    - xml_files: List of XML files (only the first 50 are used)
    - transform: Accepted for backward compatibility but NOT used; the
      visualization deliberately skips normalization
    - save_path: Output image path

    Returns:
    - save_path: Path where the figure was written

    Raises:
    - ValueError: if no XML files are provided
    """
    viz_dataset = HomelessDataset(
        data_dir,
        xml_files[:50],  # a subset is enough for a visual spot-check
        transform=None,  # Important: don't normalize during visualization
        augmentation_type="basic"
    )
    if len(viz_dataset) == 0:
        raise ValueError("visualize_augmentations_with_boxes requires at least one XML file")

    class_colors = {
        1: (0, 255, 0),     # Green for People
        2: (0, 0, 255),     # Blue for Encampments
        3: (255, 0, 0),     # Red for Cart
        4: (255, 255, 0)    # Yellow for Bike
    }
    class_names = {1: 'People', 2: 'Encampments', 3: 'Cart', 4: 'Bike'}
    canvas_size = 512  # side length of the mosaic and mixup canvases

    # Pick a random image that has at least one box (up to 10 attempts).
    idx, sample = None, None
    for _ in range(10):
        try_idx = np.random.randint(0, len(viz_dataset))
        try:
            result = viz_dataset._get_image_and_targets(try_idx)
            unpacked = _aug_viz_unpack(result)
            if unpacked is None:
                print(f"Unexpected number of return values: {len(result)}")
                continue
            if len(unpacked[1]) > 0:
                idx, sample = try_idx, unpacked
                break
        except Exception as e:
            # Best-effort search: report the bad sample and keep trying.
            print(f"Error with index {try_idx}: {e}")
            continue

    if idx is None:
        print("Could not find any images with valid boxes. Using first image.")
        idx = 0
        result = viz_dataset._get_image_and_targets(idx)
        sample = _aug_viz_unpack(result)
        if sample is None:
            print(f"Still got unexpected number of values: {len(result)}")
            sample = _aug_viz_blank_sample(canvas_size)

    # 1. Basic image with boxes
    raw_image, boxes, class_ids, _ = sample
    basic_panel = raw_image.copy()
    _aug_viz_draw_boxes(basic_panel, boxes, class_ids, class_colors, class_names)

    # 2. Mosaic: the chosen image plus three random ones, one per quadrant
    half = canvas_size // 2
    mosaic_indices = [idx] + [random.randint(0, len(viz_dataset.xml_files) - 1) for _ in range(3)]
    mosaic_panel = np.zeros((canvas_size, canvas_size, 3), dtype=np.uint8)
    mosaic_boxes, mosaic_class_ids = [], []
    quadrant_offsets = [(0, 0), (half, 0), (0, half), (half, half)]  # TL, TR, BL, BR

    for (x_offset, y_offset), mosaic_idx in zip(quadrant_offsets, mosaic_indices):
        tile = _aug_viz_unpack(viz_dataset._get_image_and_targets(mosaic_idx))
        if tile is None:
            continue  # skip tiles with an unexpected return shape
        tile_img, tile_boxes, tile_class_ids, (tile_h, tile_w) = tile

        mosaic_panel[y_offset:y_offset + half, x_offset:x_offset + half] = cv2.resize(tile_img, (half, half))

        scale_x, scale_y = half / tile_w, half / tile_h
        for box_idx, box in enumerate(tile_boxes):
            if len(box) != 4:
                continue
            x1, y1, x2, y2 = box
            x1n = min(max(int(x1 * scale_x) + x_offset, 0), canvas_size - 1)
            y1n = min(max(int(y1 * scale_y) + y_offset, 0), canvas_size - 1)
            x2n = min(max(int(x2 * scale_x) + x_offset, 0), canvas_size - 1)
            y2n = min(max(int(y2 * scale_y) + y_offset, 0), canvas_size - 1)
            if x2n - x1n >= 5 and y2n - y1n >= 5:
                mosaic_boxes.append([x1n, y1n, x2n, y2n])
                # Fall back to class 1 when fewer ids than boxes were returned.
                mosaic_class_ids.append(tile_class_ids[box_idx] if box_idx < len(tile_class_ids) else 1)

    _aug_viz_draw_boxes(mosaic_panel, mosaic_boxes, mosaic_class_ids, class_colors, class_names)

    # 3. Mixup: blend the chosen image with a random one at alpha = 0.5
    idx2 = random.randint(0, len(viz_dataset.xml_files) - 1)
    first = _aug_viz_unpack(viz_dataset._get_image_and_targets(idx))
    if first is None:
        first = _aug_viz_blank_sample(canvas_size)
    second = _aug_viz_unpack(viz_dataset._get_image_and_targets(idx2))
    if second is None:
        second = _aug_viz_blank_sample(canvas_size)

    img1, boxes1, class_ids1, shape1 = first
    img2, boxes2, class_ids2, shape2 = second
    scaled_boxes1, scaled_ids1 = _aug_viz_scale_boxes(boxes1, class_ids1, shape1, canvas_size, canvas_size)
    scaled_boxes2, scaled_ids2 = _aug_viz_scale_boxes(boxes2, class_ids2, shape2, canvas_size, canvas_size)

    alpha = 0.5  # fixed for visualization clarity
    mixup_panel = cv2.addWeighted(
        cv2.resize(img1, (canvas_size, canvas_size)), alpha,
        cv2.resize(img2, (canvas_size, canvas_size)), 1.0 - alpha, 0
    )

    # First image's boxes: solid lines
    _aug_viz_draw_boxes(mixup_panel, scaled_boxes1, scaled_ids1, class_colors, class_names,
                        label_suffix=" (img1)")

    # Second image's boxes: dashed lines, label below the box, to tell them apart
    for box, class_id in zip(scaled_boxes2, scaled_ids2):
        x1, y1, x2, y2 = box
        color = class_colors.get(class_id, (255, 255, 255))
        for i in range(x1, x2, 5):
            cv2.line(mixup_panel, (i, y1), (min(i + 3, x2), y1), color, 2)
            cv2.line(mixup_panel, (i, y2), (min(i + 3, x2), y2), color, 2)
        for i in range(y1, y2, 5):
            cv2.line(mixup_panel, (x1, i), (x1, min(i + 3, y2)), color, 2)
            cv2.line(mixup_panel, (x2, i), (x2, min(i + 3, y2)), color, 2)
        cv2.putText(mixup_panel, f"{class_names.get(class_id, 'Unknown')} (img2)",
                    (x1, y2 + 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    fig, axes = plt.subplots(3, 1, figsize=(15, 15))
    try:
        panels = [
            (basic_panel, "Basic (No Augmentation)"),
            (mosaic_panel, "Mosaic Augmentation"),
            (mixup_panel, "Mixup Augmentation"),
        ]
        for ax, (panel, title) in zip(axes, panels):
            ax.imshow(panel)
            ax.set_title(title)
            ax.axis('off')
        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        plt.close(fig)  # release the figure even if drawing/saving fails

    return save_path


def _aug_viz_unpack(result):
    """Normalize a _get_image_and_targets result to (image, boxes, class_ids, img_shape); None if neither 4 nor 5 values."""
    if len(result) == 5:  # Mask R-CNN variant: masks are ignored here
        image, boxes, class_ids, _masks, img_shape = result
    elif len(result) == 4:  # Faster R-CNN variant
        image, boxes, class_ids, img_shape = result
    else:
        return None
    return image, boxes, class_ids, img_shape


def _aug_viz_blank_sample(size):
    """Placeholder sample (black size x size RGB image, no boxes) used when loading fails."""
    return np.zeros((size, size, 3), dtype=np.uint8), [], [], (size, size)


def _aug_viz_draw_boxes(image, boxes, class_ids, class_colors, class_names, label_suffix=""):
    """Draw solid, labelled rectangles onto `image` in place (coordinates cast to int for OpenCV)."""
    for box, class_id in zip(boxes, class_ids):
        x1, y1, x2, y2 = (int(v) for v in box)
        color = class_colors.get(class_id, (255, 255, 255))
        cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
        cv2.putText(image, f"{class_names.get(class_id, 'Unknown')}{label_suffix}",
                    (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)


def _aug_viz_scale_boxes(boxes, class_ids, src_shape, dst_w, dst_h, min_size=5):
    """Rescale boxes from an (h, w) image to a dst_w x dst_h canvas, clip, and keep those >= min_size px per side."""
    src_h, src_w = src_shape
    scale_x, scale_y = dst_w / src_w, dst_h / src_h
    kept_boxes, kept_ids = [], []
    for box_idx, (x1, y1, x2, y2) in enumerate(boxes):
        x1s = max(0, int(x1 * scale_x))
        y1s = max(0, int(y1 * scale_y))
        x2s = min(dst_w, int(x2 * scale_x))
        y2s = min(dst_h, int(y2 * scale_y))
        if x2s > x1s and y2s > y1s and x2s - x1s >= min_size and y2s - y1s >= min_size:
            kept_boxes.append([x1s, y1s, x2s, y2s])
            if box_idx < len(class_ids):
                kept_ids.append(class_ids[box_idx])
    return kept_boxes, kept_ids

# Function to visualize predictions on test set
def visualize_test_predictions(model, test_dataset, device, num_images=5,
                             confidence_thresholds=None, save_path=None):
    """
    Visualize model predictions and ground truth on random test samples.

    Samples flagged invalid by the dataset are skipped, which leaves their row
    empty. A prediction is drawn only if its label appears in
    ``confidence_thresholds`` and its score meets that class's threshold.

    Parameters:
    - model: Trained model (put into eval mode by this function)
    - test_dataset: Dataset whose items are (image_tensor, target, valid)
    - device: Device to run inference on
    - num_images: Maximum number of images to visualize
    - confidence_thresholds: Dict mapping class IDs to confidence thresholds
      (default 0.5 for classes 1-4)
    - save_path: Path to save visualization image (default "test_predictions.png")

    Returns:
    - save_path: Path where visualization was saved
    """
    model.eval()

    if confidence_thresholds is None:
        confidence_thresholds = {1: 0.5, 2: 0.5, 3: 0.5, 4: 0.5}
    if save_path is None:
        save_path = "test_predictions.png"

    class_names = {
        1: 'People',
        2: 'Encampments',
        3: 'Cart',
        4: 'Bike'
    }

    class_colors = {
        1: (255, 0, 0),    # Red for people
        2: (0, 255, 0),    # Green for encampments
        3: (0, 0, 255),    # Blue for carts
        4: (255, 255, 0)   # Yellow for bikes
    }

    # ImageNet statistics. This assumes the dataset transform normalizes with
    # these values (TODO confirm against get_basic_transform).
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # Lay out one row per sample actually drawn, not per requested image,
    # so a small dataset does not produce empty trailing panels.
    num_shown = max(0, min(num_images, len(test_dataset)))
    indices = np.random.choice(len(test_dataset), num_shown, replace=False)

    fig = plt.figure(figsize=(20, 20))
    try:
        for row, idx in enumerate(indices):
            img, target, valid = test_dataset[idx]
            if not valid:
                continue

            with torch.no_grad():
                prediction = model([img.to(device)])[0]

            # Vectorized class-specific thresholding (no per-element .item() syncs)
            pred_labels = prediction['labels']
            pred_scores = prediction['scores']
            keep = torch.zeros_like(pred_labels, dtype=torch.bool)
            for label_id, threshold in confidence_thresholds.items():
                keep |= (pred_labels == label_id) & (pred_scores >= threshold)

            img_h, img_w = img.shape[1], img.shape[2]
            boxes = prediction['boxes'][keep].cpu().numpy()
            # Clip to the image before the int cast so drawing stays in bounds.
            boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, img_w - 1)
            boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, img_h - 1)
            boxes = boxes.astype(np.int32)
            labels = pred_labels[keep].cpu().numpy()
            scores = pred_scores[keep].cpu().numpy()

            # CHW tensor -> HWC uint8 for OpenCV / matplotlib
            image_np = img.permute(1, 2, 0).cpu().numpy()
            image_np = np.clip(std * image_np + mean, 0, 1)
            image_with_boxes = (image_np * 255).astype(np.uint8)

            gt_boxes = target['boxes'].cpu().numpy().astype(np.int32)
            gt_labels = target['labels'].cpu().numpy()

            # Ground truth first, labelled a little higher than predictions
            for box, label in zip(gt_boxes, gt_labels):
                color = class_colors.get(label.item(), (255, 255, 255))
                cv2.rectangle(image_with_boxes, (box[0], box[1]), (box[2], box[3]),
                              color, 2, cv2.LINE_AA)
                cv2.putText(image_with_boxes, f"GT: {class_names.get(label.item(), 'Unknown')}",
                            (box[0], box[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

            for box, label, score in zip(boxes, labels, scores):
                color = class_colors.get(label.item(), (255, 255, 255))
                cv2.rectangle(image_with_boxes, (box[0], box[1]), (box[2], box[3]),
                              color, 2, cv2.LINE_AA)
                cv2.putText(image_with_boxes,
                            f"{class_names.get(label.item(), 'Unknown')}: {score:.2f}",
                            (box[0], box[1] - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

            ax = fig.add_subplot(num_shown, 1, row + 1)
            ax.imshow(image_with_boxes)
            ax.set_title(f"Test Sample {idx} (Inference with Class-Specific Thresholds)")
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        plt.close(fig)  # always release the figure, even on error

    return save_path

# Test-set comparison plots: overall mAP / IoU plus per-class F1 and AP.
def _plot_per_class_bars(ax, experiment_names, values_by_class, class_names, title, ylabel):
    """
    Draw grouped bars on `ax`: one group per experiment, one bar per class.

    Parameters:
    - ax: matplotlib Axes to draw on
    - experiment_names: tick labels, one per experiment
    - values_by_class: {class_id: [value per experiment]} (insertion order = bar order)
    - class_names: {class_id: legend label}
    - title, ylabel: axis title and y-axis label
    """
    x = np.arange(len(experiment_names))
    n_classes = len(values_by_class)
    # Total group width of 0.8 split evenly; gives 0.2 per bar for 4 classes.
    width = 0.8 / max(n_classes, 1)

    for i, (cls, values) in enumerate(values_by_class.items()):
        # Centre the group of bars on each experiment's tick for any class count.
        offset = (i - (n_classes - 1) / 2) * width
        ax.bar(x + offset, values, width, label=class_names[cls], color=f'C{i}')

    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xticks(x)
    ax.set_xticklabels(experiment_names, rotation=45, ha='right')
    ax.set_ylim(0, 1.0)
    ax.legend()


def plot_test_comparison(test_results, save_path="test_comparison.png"):
    """
    Plot a comparison of model performance on the test set.

    Produces a 2x2 figure: overall mAP@0.5, overall IoU, per-class F1 and
    per-class AP, one bar (or bar group) per experiment.

    Parameters:
    - test_results: list of dicts, each with "experiment_name" and "metrics";
      "metrics" must hold "mAP", "IoU" and "class_metrics" keyed by class
      ids 1-4, each with "f1" and "AP" entries.
    - save_path: path to save the visualization

    Returns:
    - save_path: path where the visualization was saved
    """
    class_names = {1: 'People', 2: 'Encampments', 3: 'Cart', 4: 'Bike'}
    class_ids = sorted(class_names)

    experiment_names = [result["experiment_name"] for result in test_results]
    metrics = [result["metrics"] for result in test_results]

    fig, axs = plt.subplots(2, 2, figsize=(15, 10))
    try:
        fig.suptitle('Test Set Performance Comparison', fontsize=16)

        # Top row: one bar per experiment for the overall metrics.
        overall_panels = (
            (axs[0, 0], "mAP", 'Test mAP@0.5', 'skyblue'),
            (axs[0, 1], "IoU", 'Test IoU', 'lightgreen'),
        )
        for ax, key, title, color in overall_panels:
            ax.bar(experiment_names, [m[key] for m in metrics], color=color)
            ax.set_title(title)
            ax.set_ylabel(key)
            ax.set_ylim(0, 1.0)
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        # Bottom row: per-class F1 and AP grouped by experiment.
        f1_by_class = {
            cls: [m["class_metrics"][cls]["f1"] for m in metrics] for cls in class_ids
        }
        ap_by_class = {
            cls: [m["class_metrics"][cls]["AP"] for m in metrics] for cls in class_ids
        }
        _plot_per_class_bars(axs[1, 0], experiment_names, f1_by_class, class_names,
                             'Test F1 Scores by Class', 'F1 Score')
        _plot_per_class_bars(axs[1, 1], experiment_names, ap_by_class, class_names,
                             'Test AP by Class', 'AP')

        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Always release this figure, even if plotting or saving failed.
        plt.close(fig)

    return save_path

# Experiment grid: each scenario pairs an augmentation scheme with a sampling
# strategy and a loss-reweighting choice. Keys are the scenario numbers shown
# in banners and comparison tables.
EXPERIMENT_SCENARIOS = {
    1: {
        "title": "BASIC AUGMENTATION + CLASS-AWARE SAMPLING (NO LOSS REWEIGHTING)",
        "augmentation_type": "basic",
        "sampling_strategy": "class_aware",
        "loss_reweighting": False,
        "experiment_name": "basic_classaware",
    },
    2: {
        "title": "BASIC AUGMENTATION + RANDOM SAMPLING (WITH LOSS REWEIGHTING)",
        "augmentation_type": "basic",
        "sampling_strategy": "random",
        "loss_reweighting": True,
        "experiment_name": "basic_random_reweighted",
    },
    3: {
        "title": "MOSAIC AUGMENTATION + CLASS-AWARE SAMPLING (NO LOSS REWEIGHTING)",
        "augmentation_type": "mosaic_mixup",
        "sampling_strategy": "class_aware",
        "loss_reweighting": False,
        "experiment_name": "mosaic_classaware",
    },
    4: {
        "title": "MOSAIC AUGMENTATION + RANDOM SAMPLING (WITH LOSS REWEIGHTING)",
        "augmentation_type": "mosaic_mixup",
        "sampling_strategy": "random",
        "loss_reweighting": True,
        "experiment_name": "mosaic_random_reweighted",
    },
}

DEFAULT_DATA_DIR = "/content/Annotated-Training-Images-For-Fast-RCNN_1"


def _print_section_banner(title):
    """Print `title` between two 80-character '=' rules, preceded by a blank line."""
    print("\n" + "=" * 80)
    print(title)
    print("=" * 80)


def _print_class_metric_table(rows, metric_key):
    """
    Print one line per scenario with the per-class value of `metric_key`.

    rows: list of (scenario_id, test_result) pairs, where
    test_result["metrics"]["class_metrics"] is keyed by class ids 1-4.
    """
    print(f"{'Scenario':<10} {'People':<10} {'Encampments':<15} {'Cart':<10} {'Bike':<10}")
    print("-" * 60)
    for scenario_id, test_result in rows:
        class_metrics = test_result['metrics']['class_metrics']
        print(f"{scenario_id:<10} "
              f"{class_metrics[1][metric_key]:<10.4f} "
              f"{class_metrics[2][metric_key]:<15.4f} "
              f"{class_metrics[3][metric_key]:<10.4f} "
              f"{class_metrics[4][metric_key]:<10.4f}")


def main(scenario_ids=(4,), data_dir=DEFAULT_DATA_DIR, num_epochs=30):
    """
    Run the experiment pipeline with a train / validation / test split.

    Steps:
      1. Split the annotated XML files with create_train_val_test_split
         (val_size=0.15, test_size=0.15, random_state=42).
      2. Build one non-augmented test dataset/loader shared by every scenario.
      3. Save a visualization of the training augmentations.
      4. Grid-search confidence thresholds on the validation split (5 epochs).
      5. For each requested scenario: train it, reload its saved weights,
         then evaluate, log and visualize predictions on the test split.
      6. Print validation/test comparison tables and save comparison plots.

    Parameters:
    - scenario_ids: scenario numbers (keys of EXPERIMENT_SCENARIOS, 1-4) to run,
      in order. The default (4,) runs only the mosaic + random sampling + loss
      reweighting scenario.
    - data_dir: directory containing the annotation XML files.
    - num_epochs: training epochs for each scenario.

    Raises:
    - ValueError: if scenario_ids is empty or contains an unknown scenario.
    """
    # Validate up front so a typo fails before any expensive training starts.
    scenario_ids = tuple(scenario_ids)
    unknown = [s for s in scenario_ids if s not in EXPERIMENT_SCENARIOS]
    if not scenario_ids or unknown:
        raise ValueError(
            f"scenario_ids must be a non-empty subset of {sorted(EXPERIMENT_SCENARIOS)}; "
            f"got {scenario_ids}"
        )

    # Get XML files
    xml_files = [f for f in os.listdir(data_dir) if f.endswith('.xml')]
    print(f"Found {len(xml_files)} XML files")

    # Create a 3-way split: train, validation, and test
    train_xml_files, val_xml_files, test_xml_files = create_train_val_test_split(
        xml_files, val_size=0.15, test_size=0.15, random_state=42
    )
    print(f"Split: {len(train_xml_files)} training, {len(val_xml_files)} validation, {len(test_xml_files)} test")

    # Test dataset and loader are identical for every scenario
    test_dataset = HomelessDataset(
        data_dir,
        test_xml_files,
        transform=get_basic_transform(train=False),
        augmentation_type="basic"
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=4,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=2
    )

    # Visualize augmentations to understand the data transformations
    visualization_path = visualize_augmentations_with_boxes(
        data_dir, train_xml_files, transform=get_basic_transform(train=True)
    )
    print(f"Augmentation visualization saved to: {visualization_path}")

    # Grid search for confidence thresholds on the validation split
    _print_section_banner("RUNNING GRID SEARCH FOR OPTIMAL CONFIDENCE THRESHOLDS")
    # Only a few epochs here to save time
    optimal_thresholds = grid_search_thresholds(
        data_dir=data_dir,
        train_xml_files=train_xml_files,
        val_xml_files=val_xml_files,
        num_epochs=5
    )

    num_classes = 5  # background + 4 object classes
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    all_results = []
    all_test_results = []

    for scenario_id in scenario_ids:
        config = EXPERIMENT_SCENARIOS[scenario_id]
        _print_section_banner(f"SCENARIO {scenario_id}: {config['title']}")

        result = run_experiment(
            data_dir=data_dir,
            train_xml_files=train_xml_files,
            val_xml_files=val_xml_files,
            augmentation_type=config["augmentation_type"],
            sampling_strategy=config["sampling_strategy"],
            loss_reweighting=config["loss_reweighting"],
            num_epochs=num_epochs,
            experiment_name=config["experiment_name"],
            confidence_thresholds=optimal_thresholds
        )
        all_results.append(result)

        # Reload the saved weights; map_location lets GPU checkpoints load on CPU.
        model = get_model(num_classes)
        model.load_state_dict(torch.load(result['model_path'], map_location=device))
        model.to(device)

        print(f"\nEvaluating Scenario {scenario_id} on test set...")
        test_metrics = evaluate_on_test_set(
            model, test_loader, device, confidence_thresholds=optimal_thresholds
        )

        log_file = log_test_results(
            result['experiment_name'],
            test_metrics,
            confidence_thresholds=optimal_thresholds
        )

        test_viz_path = visualize_test_predictions(
            model,
            test_dataset,
            device,
            num_images=5,
            confidence_thresholds=optimal_thresholds,
            save_path=f"test_predictions_{result['experiment_name']}.png"
        )

        all_test_results.append({
            "experiment_name": result["experiment_name"],
            "metrics": test_metrics,
            "log_file": log_file,
            "viz_path": test_viz_path
        })

    # Validation comparison, labelled with the real scenario numbers
    _print_section_banner("COMPREHENSIVE COMPARISON OF ALL SCENARIOS (VALIDATION)")
    print(f"{'Scenario':<10} {'Augmentation':<15} {'Sampling':<15} {'Loss Reweight':<15} {'Val mAP':<10}")
    print("-" * 70)
    for scenario_id, result in zip(scenario_ids, all_results):
        print(f"{scenario_id:<10} "
              f"{result['augmentation_type']:<15} "
              f"{result['sampling_strategy']:<15} "
              f"{'Yes' if result['loss_reweighting'] else 'No':<15} "
              f"{result['best_metrics']['mAP']:<10.4f}")

    # Test-set comparison
    _print_section_banner("COMPREHENSIVE COMPARISON OF ALL SCENARIOS (TEST SET)")
    print(f"{'Scenario':<10} {'Augmentation':<15} {'Sampling':<15} {'Loss Reweight':<15} {'Test mAP':<10} {'Test IoU':<10}")
    print("-" * 80)
    for scenario_id, result, test_result in zip(scenario_ids, all_results, all_test_results):
        print(f"{scenario_id:<10} "
              f"{result['augmentation_type']:<15} "
              f"{result['sampling_strategy']:<15} "
              f"{'Yes' if result['loss_reweighting'] else 'No':<15} "
              f"{test_result['metrics']['mAP']:<10.4f} "
              f"{test_result['metrics']['IoU']:<10.4f}")

    labelled_test_results = list(zip(scenario_ids, all_test_results))

    _print_section_banner("TEST SET CLASS-SPECIFIC AVERAGE PRECISION (AP)")
    _print_class_metric_table(labelled_test_results, 'AP')

    _print_section_banner("TEST SET CLASS-SPECIFIC F1 SCORES")
    _print_class_metric_table(labelled_test_results, 'f1')

    # Visualization comparing all scenarios (validation performance)
    val_plot_path = plot_training_history(all_results, "validation_comparison.png")
    print(f"\nValidation training history comparison plot for all scenarios saved to {val_plot_path}")

    # Visualization for test performance
    test_plot_path = plot_test_comparison(all_test_results, "test_comparison.png")
    print(f"Test performance comparison plot saved to {test_plot_path}")

In [None]:
    # Scenario 1: Basic augmentation + Class-aware sampling (no loss reweighting)
    print("\n" + "="*80)
    print("SCENARIO 1: BASIC AUGMENTATION + CLASS-AWARE SAMPLING (NO LOSS REWEIGHTING)")
    print("="*80)

    result1 = run_experiment(
        data_dir=data_dir,
        train_xml_files=train_xml_files,
        val_xml_files=val_xml_files,
        augmentation_type="basic",
        sampling_strategy="class_aware",
        loss_reweighting=False,
        num_epochs=num_epochs,
        experiment_name="basic_classaware",
        confidence_thresholds=optimal_thresholds
    )
    all_results.append(result1)

    # Load best model for test evaluation
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    num_classes = 5  # Background + 4 classes

    model1 = get_model(num_classes)
    model1.load_state_dict(torch.load(result1['model_path']))
    model1.to(device)

    # Evaluate on test set
    print("\nEvaluating Scenario 1 on test set...")
    test_metrics1 = evaluate_on_test_set(
        model1, test_loader, device, confidence_thresholds=optimal_thresholds
    )

    # Log test results
    log_file1 = log_test_results(
        result1['experiment_name'],
        test_metrics1,
        confidence_thresholds=optimal_thresholds
    )

    # Visualize predictions on test set
    test_viz_path1 = visualize_test_predictions(
        model1,
        test_dataset,
        device,
        num_images=5,
        confidence_thresholds=optimal_thresholds,
        save_path=f"test_predictions_{result1['experiment_name']}.png"
    )

    # Add test results to tracking
    test_result1 = {
        "experiment_name": result1["experiment_name"],
        "metrics": test_metrics1,
        "log_file": log_file1,
        "viz_path": test_viz_path1
    }
    all_test_results.append(test_result1)

    # Scenario 2: Basic augmentation + Random sampling (with loss reweighting)
    print("\n" + "="*80)
    print("SCENARIO 2: BASIC AUGMENTATION + RANDOM SAMPLING (WITH LOSS REWEIGHTING)")
    print("="*80)

    result2 = run_experiment(
        data_dir=data_dir,
        train_xml_files=train_xml_files,
        val_xml_files=val_xml_files,
        augmentation_type="basic",
        sampling_strategy="random",
        loss_reweighting=True,
        num_epochs=num_epochs,
        experiment_name="basic_random_reweighted",
        confidence_thresholds=optimal_thresholds
    )
    all_results.append(result2)

    # Load best model for test evaluation
    model2 = get_model(num_classes)
    model2.load_state_dict(torch.load(result2['model_path']))
    model2.to(device)

    # Evaluate on test set
    print("\nEvaluating Scenario 2 on test set...")
    test_metrics2 = evaluate_on_test_set(
        model2, test_loader, device, confidence_thresholds=optimal_thresholds
    )

    # Log test results
    log_file2 = log_test_results(
        result2['experiment_name'],
        test_metrics2,
        confidence_thresholds=optimal_thresholds
    )

    # Visualize predictions on test set
    test_viz_path2 = visualize_test_predictions(
        model2,
        test_dataset,
        device,
        num_images=5,
        confidence_thresholds=optimal_thresholds,
        save_path=f"test_predictions_{result2['experiment_name']}.png"
    )

    # Add test results to tracking
    test_result2 = {
        "experiment_name": result2["experiment_name"],
        "metrics": test_metrics2,
        "log_file": log_file2,
        "viz_path": test_viz_path2
    }
    all_test_results.append(test_result2)

    # Scenario 3: Mosaic+Mixup augmentation + Class-aware sampling (no loss reweighting)
    print("\n" + "="*80)
    print("SCENARIO 3: MOSAIC AUGMENTATION + CLASS-AWARE SAMPLING (NO LOSS REWEIGHTING)")
    print("="*80)

    result3 = run_experiment(
        data_dir=data_dir,
        train_xml_files=train_xml_files,
        val_xml_files=val_xml_files,
        augmentation_type="mosaic_mixup",
        sampling_strategy="class_aware",
        loss_reweighting=False,
        num_epochs=num_epochs,
        experiment_name="mosaic_classaware",
        confidence_thresholds=optimal_thresholds
    )
    all_results.append(result3)

    # Load best model for test evaluation
    model3 = get_model(num_classes)
    model3.load_state_dict(torch.load(result3['model_path']))
    model3.to(device)

    # Evaluate on test set
    print("\nEvaluating Scenario 3 on test set...")
    test_metrics3 = evaluate_on_test_set(
        model3, test_loader, device, confidence_thresholds=optimal_thresholds
    )

    # Log test results
    log_file3 = log_test_results(
        result3['experiment_name'],
        test_metrics3,
        confidence_thresholds=optimal_thresholds
    )

    # Visualize predictions on test set
    test_viz_path3 = visualize_test_predictions(
        model3,
        test_dataset,
        device,
        num_images=5,
        confidence_thresholds=optimal_thresholds,
        save_path=f"test_predictions_{result3['experiment_name']}.png"
    )

    # Add test results to tracking
    test_result3 = {
        "experiment_name": result3["experiment_name"],
        "metrics": test_metrics3,
        "log_file": log_file3,
        "viz_path": test_viz_path3
    }
    all_test_results.append(test_result3)

In [None]:
# Entry point. Inside a notebook kernel __name__ equals "__main__", so running
# this cell starts the full training/evaluation pipeline. Importing this code
# as a module does not trigger it.
if "__main__" == __name__:
    main()