In [None]:
import os
import json
import random
import math
import time
import copy
import pickle # For caching splits

In [None]:
import torch
import torchvision
from torchvision import tv_tensors # Box-aware wrappers so v2 transforms update targets too
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, FasterRCNN_ResNet50_FPN_Weights
from torchvision.ops import box_iou
import torchvision.transforms.v2 as T # Use new v2 transforms API

from torch.utils.data import Dataset, DataLoader, Subset

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm.notebook import tqdm # Use notebook version for better display

In [None]:
try:
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval
    print("pycocotools found.")
except ImportError:
    print("ERROR: pycocotools not found. Evaluation will fail.")
    print("Install it: pip install pycocotools")
    COCO = None
    COCOeval = None

# Suppress specific warnings if needed
import warnings
warnings.filterwarnings("ignore", category=UserWarning) # Filter UserWarnings if they become noisy

In [None]:
BASE_DATA_DIR = "data"
# Frame folders per video; image file_names look like "03/frame_000000.jpg"
FRAME_BASE_DIR = os.path.join(BASE_DATA_DIR, "frame")
# One "<video_id>.json" COCO file per video
ANNOTATION_BASE_DIR = os.path.join(BASE_DATA_DIR, "annotation", "coco")
MODEL_SAVE_DIR = "models"
CACHE_DIR = "cache"

# Output directories must exist before anything is saved into them
for _output_dir in (MODEL_SAVE_DIR, CACHE_DIR):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir

In [None]:
# Data Splitting & Caching
VIDEO_IDS = ["%02d" % video_num for video_num in range(1, 59)]  # "01" .. "58" (assumed video numbering)
TRAIN_RATIO = 0.70
VAL_RATIO = 0.15  # The test split receives the remainder: 1.0 - TRAIN_RATIO - VAL_RATIO
SPLIT_CACHE_FILE = os.path.join(CACHE_DIR, "data_splits.pkl")
FORCE_REGENERATE_SPLITS = False  # True: rebuild splits even when SPLIT_CACHE_FILE already exists

In [None]:
# Model & Training Hyperparameters
NUM_CLASSES = 1 + 1  # foreground classes + 1 background -- keep in sync with the annotation categories!
BATCH_SIZE = 6       # images per step; tune to available GPU memory
NUM_EPOCHS = 15
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0005
LR_STEP_SIZE = 5     # epochs between learning-rate decays
LR_GAMMA = 0.1       # multiplicative LR decay factor
NUM_WORKERS = 0      # DataLoader worker processes (0 = load in the main process)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {DEVICE}")
print(f"Number of object classes (excluding background): {NUM_CLASSES - 1}")

In [None]:
# Fail fast when either input directory is missing
if not all(os.path.isdir(path) for path in (FRAME_BASE_DIR, ANNOTATION_BASE_DIR)):
    raise FileNotFoundError("Frame or Annotation directory not found. Please check paths.")

In [None]:
def load_coco_metadata(annotation_dir, video_ids):
    """Load image references and category metadata from one COCO JSON file per video.

    Each file is parsed completely before anything from it is recorded. A file with a
    malformed entry is therefore skipped as a whole, instead of being counted in
    ``video_frame_counts`` (and later assigned to a split) with none of its images.

    Args:
        annotation_dir (str): Directory containing ``<video_id>.json`` files.
        video_ids (list of str): Video IDs to load. Missing or unreadable files are
            skipped with a warning.

    Returns:
        tuple:
            - all_image_info (list of dict): one entry per image with keys
              ``'image_id_in_file'`` (image id inside its own JSON, not globally
              unique), ``'coco_file'`` (path of that JSON) and ``'img_info'`` (the raw
              image dict, later read for 'file_name', 'width' and 'height').
            - video_frame_counts (dict): video_id -> number of images in its JSON.
            - cat_name_to_contiguous_id (dict): category name -> 1-based contiguous id,
              assigned in alphabetical name order (0 is left for background).

    Raises:
        ValueError: if no categories were found in any file.
    """
    all_image_info = []
    video_frame_counts = {}
    categories = {}      # original category id -> first name seen with that id
    cat_name_to_id = {}  # category name -> first original id seen for that name

    print(f"Loading metadata from {len(video_ids)} COCO files...")
    for video_id in tqdm(video_ids, desc="Loading JSONs"):
        coco_file = os.path.join(annotation_dir, f"{video_id}.json")
        if not os.path.exists(coco_file):
            print(f"Warning: Annotation file not found: {coco_file}. Skipping.")
            continue
        try:
            with open(coco_file, 'r') as f:
                data = json.load(f)
            # Parse everything this file contributes before touching shared state.
            file_categories = [(cat['id'], cat['name']) for cat in data.get('categories', [])]
            file_images = [
                {
                    'image_id_in_file': img['id'],  # ID within its original JSON
                    'coco_file': coco_file,
                    'img_info': img,  # the whole image dict
                }
                for img in data.get('images', [])
            ]
        except json.JSONDecodeError:
            print(f"Warning: Error decoding JSON file: {coco_file}. Skipping.")
            continue
        except Exception as e:
            print(f"Warning: Error processing file {coco_file}: {e}. Skipping.")
            continue

        # Classes are identified by *name* (the dataset maps each file's ids via the
        # name), so every name must be registered. This holds even when its id
        # collides with a different name from an earlier file. Otherwise that class's
        # annotations would be silently dropped.
        for cat_id, cat_name in file_categories:
            first_name = categories.setdefault(cat_id, cat_name)
            if first_name != cat_name:
                print(f"Warning: Category Name mismatch for ID {cat_id} in {coco_file} ('{cat_name}' vs '{first_name}'). Both names are kept as separate classes.")
            first_id = cat_name_to_id.setdefault(cat_name, cat_id)
            if first_id != cat_id:
                print(f"Warning: Category ID mismatch for '{cat_name}' in {coco_file} ({cat_id} vs {first_id}). Classes are matched by name, so both IDs map to the same class.")

        video_frame_counts[video_id] = len(file_images)
        all_image_info.extend(file_images)

    print(f"Total images found across all files: {len(all_image_info)}")
    if not categories:
        raise ValueError("No categories found in any annotation file. Cannot proceed.")

    # Contiguous 1-based ids in alphabetical name order, independent of per-file ids
    sorted_cat_names = sorted(cat_name_to_id)
    cat_name_to_contiguous_id = {name: i + 1 for i, name in enumerate(sorted_cat_names)}
    print("\nCategories found:", cat_name_to_contiguous_id)

    # Verify NUM_CLASSES (global config) matches detected categories + background
    if NUM_CLASSES != (len(cat_name_to_contiguous_id) + 1):
        print(f"\nWARNING: NUM_CLASSES configured ({NUM_CLASSES}) does not match detected categories ({len(cat_name_to_contiguous_id)} + 1 background).")
        print("Please update NUM_CLASSES in the configuration.")

    return all_image_info, video_frame_counts, cat_name_to_contiguous_id

In [None]:
all_image_references, video_frame_counts, category_map = load_coco_metadata(
    annotation_dir=ANNOTATION_BASE_DIR,
    video_ids=VIDEO_IDS,
)

In [None]:
def create_balanced_splits(video_ids, video_frame_counts, train_ratio, val_ratio, seed=None):
    """Split video IDs so each split's share of frames tracks its share of videos.

    The number of videos per split is fixed by the ratios: floor for train and val,
    with the remainder going to test. Videos are then assigned largest-first to the
    open split with the fewest frames *per target slot*, i.e. the split furthest
    below its target frame share.

    Comparing raw frame totals instead would give the biggest videos to the small
    val/test splits and inflate their frame share well above their ratio.

    Args:
        video_ids (list of str): Videos to split.
        video_frame_counts (dict): video_id -> number of frames. Missing ids count as 0.
        train_ratio (float): Fraction of videos for training.
        val_ratio (float): Fraction of videos for validation. Test gets the rest.
        seed (int, optional): Seed for the shuffle that orders videos with equal frame
            counts. ``None`` uses the global ``random`` state.

    Returns:
        tuple of set: (train_video_ids, val_video_ids, test_video_ids).
    """
    num_videos = len(video_ids)
    num_train = math.floor(train_ratio * num_videos)
    num_val = math.floor(val_ratio * num_videos)
    num_test = num_videos - num_train - num_val

    print(f"\nTarget split sizes (by video count): Train={num_train}, Val={num_val}, Test={num_test}")

    rng = random if seed is None else random.Random(seed)
    shuffled_video_ids = rng.sample(video_ids, num_videos)
    # Stable sort: videos with equal frame counts keep their shuffled order.
    ordered_videos = sorted(shuffled_video_ids, key=lambda vid: video_frame_counts.get(vid, 0), reverse=True)

    split_names = ('train', 'val', 'test')
    capacity = {'train': num_train, 'val': num_val, 'test': num_test}
    assigned_vids = {name: [] for name in split_names}
    assigned_frames = {name: 0 for name in split_names}

    def _priority(name):
        # Lowest frames-per-slot first. Ties go to the split with the most free slots,
        # then (via min's first-wins) to train/val/test order.
        return (assigned_frames[name] / capacity[name], -(capacity[name] - len(assigned_vids[name])))

    for video_id in ordered_videos:
        frames = video_frame_counts.get(video_id, 0)
        # Open splits always have capacity >= 1, so the division in _priority is safe.
        open_splits = [name for name in split_names if len(assigned_vids[name]) < capacity[name]]
        if open_splits:
            target = min(open_splits, key=_priority)
        else:  # Only reachable if the ratios leave no free slot at all
            print(f"Warning: Could not assign video {video_id}. Adding to train set.")
            target = 'train'
        assigned_vids[target].append(video_id)
        assigned_frames[target] += frames

    print("\nActual split counts:")
    print(f"  Train: {len(assigned_vids['train'])} videos, {assigned_frames['train']} frames")
    print(f"  Val:   {len(assigned_vids['val'])} videos, {assigned_frames['val']} frames")
    print(f"  Test:  {len(assigned_vids['test'])} videos, {assigned_frames['test']} frames")

    return set(assigned_vids['train']), set(assigned_vids['val']), set(assigned_vids['test'])

In [None]:
# Reuse cached splits unless regeneration is forced. If the cache cannot be read, fall
# back to regenerating the splits instead of crashing. This covers a file truncated by
# an interrupted write, an incompatible pickle, or missing expected keys.
split_data = None
if not FORCE_REGENERATE_SPLITS and os.path.exists(SPLIT_CACHE_FILE):
    print(f"Loading cached splits from {SPLIT_CACHE_FILE}...")
    try:
        # NOTE: pickle.load can execute arbitrary code -- only load cache files this notebook wrote.
        with open(SPLIT_CACHE_FILE, 'rb') as f:
            split_data = pickle.load(f)
        cached_splits = (split_data['train'], split_data['val'], split_data['test'])
        cached_category_map = split_data['category_map']
    except Exception as e:
        print(f"Could not read cached splits ({e!r}); regenerating them.")
        split_data = None

if split_data is not None:
    train_image_refs, val_image_refs, test_image_refs = cached_splits
    category_map = cached_category_map  # The cached splits were built with this map
    print(f"Loaded splits: Train={len(train_image_refs)}, Val={len(val_image_refs)}, Test={len(test_image_refs)} images.")
    # Re-verify NUM_CLASSES from cached map
    if NUM_CLASSES != (len(category_map) + 1):
        print(f"\nWARNING: NUM_CLASSES configured ({NUM_CLASSES}) does not match cached category map ({len(category_map)} + 1 background). Using map from cache.")
        NUM_CLASSES = len(category_map) + 1

else:
    print("Generating new data splits...")
    train_vids, val_vids, test_vids = create_balanced_splits(
        list(video_frame_counts.keys()),  # Only use videos we found annotations for
        video_frame_counts,
        TRAIN_RATIO,
        VAL_RATIO,
    )

    # Route each image to the split of its video; the COCO file is named "<video_id>.json"
    train_image_refs, val_image_refs, test_image_refs = [], [], []
    for ref in all_image_references:
        video_id = os.path.splitext(os.path.basename(ref['coco_file']))[0]
        if video_id in train_vids:
            train_image_refs.append(ref)
        elif video_id in val_vids:
            val_image_refs.append(ref)
        elif video_id in test_vids:
            test_image_refs.append(ref)

    print(f"\nCaching splits to {SPLIT_CACHE_FILE}...")
    split_data_to_cache = {
        'train': train_image_refs,
        'val': val_image_refs,
        'test': test_image_refs,
        'category_map': category_map,  # Save the category map used for this split
        'timestamp': time.time()
    }
    # Dump to a temp file and atomically swap it into place, so an interrupted write
    # never leaves a truncated cache that would be picked up on the next run.
    tmp_cache_file = SPLIT_CACHE_FILE + ".tmp"
    try:
        with open(tmp_cache_file, 'wb') as f:
            pickle.dump(split_data_to_cache, f)
        os.replace(tmp_cache_file, SPLIT_CACHE_FILE)
        print("Splits cached successfully.")
    except Exception as e:
        print(f"Error caching splits: {e}")

In [None]:
# Sanity check: the three splits together should contain every image that was found
print(f"\nFinal split sizes (images): Train={len(train_image_refs)}, Val={len(val_image_refs)}, Test={len(test_image_refs)}")
total_split = sum(map(len, (train_image_refs, val_image_refs, test_image_refs)))
print(f"Total images in splits: {total_split} (Should match total found images if all videos were assigned)")
if total_split != len(all_image_references):
    print("Warning: Total images in splits don't match initial count. Some video IDs might be missing from splits.")

In [None]:
class CocoMultiJsonDataset(Dataset):
    """Detection dataset whose images come from several per-video COCO JSON files.

    ``__getitem__`` returns ``(image, target)`` in the torchvision detection format:
    ``target['boxes']`` is an (N, 4) float32 tensor of [xmin, ymin, xmax, ymax] and
    ``target['labels']`` an (N,) int64 tensor of contiguous ids from ``category_map``.
    ``target['image_id']`` holds the item's index in *this* dataset, which
    ``evaluate`` uses to look the item up in ``image_refs``.
    """

    def __init__(self, image_references, frame_base_dir, category_map, transforms=None):
        """
        Args:
            image_references (list): List of dicts {'image_id_in_file', 'coco_file', 'img_info'}.
            frame_base_dir (str): Base directory that the images' 'file_name' paths are relative to.
            category_map (dict): Mapping from category name to contiguous ID (1-based).
            transforms (callable, optional): v2-style transform called as
                ``transforms(image, target)`` and returning the transformed pair.
        """
        self.image_refs = image_references
        self.frame_base_dir = frame_base_dir
        self.transforms = transforms
        self.category_map = category_map
        # Unused; kept so code reading this attribute keeps working. Per-file
        # original-id -> contiguous-id maps live in _coco_cache instead.
        self.cat_id_to_contiguous_id = {}
        # coco_file path -> {'annotations': {image_id: [ann, ...]}, 'cat_id_map': {orig_cat_id: contiguous_id}}
        self._coco_cache = {}

    def _get_coco_annotations(self, coco_file_path):
        """Return (and cache) one COCO file's annotations, indexed by image id.

        Returns:
            dict with 'annotations' ({image_id: [ann, ...]}) and 'cat_id_map'
            ({original category id: contiguous id}). Categories whose name is not in
            ``category_map`` are omitted. On any read/parse error both are empty, and
            the failure is cached so the file is not retried.
        """
        if coco_file_path not in self._coco_cache:
            try:
                with open(coco_file_path, 'r') as f:
                    data = json.load(f)
                ann_by_img_id = {}
                for ann in data.get('annotations', []):
                    ann_by_img_id.setdefault(ann['image_id'], []).append(ann)
                # Map by *name* so files that number their categories differently agree.
                local_cat_id_map = {
                    cat['id']: self.category_map[cat['name']]
                    for cat in data.get('categories', [])
                    if cat['name'] in self.category_map
                }
                self._coco_cache[coco_file_path] = {
                    'annotations': ann_by_img_id,
                    'cat_id_map': local_cat_id_map,
                }
            except Exception as e:
                print(f"Error loading or processing COCO file {coco_file_path}: {e}")
                self._coco_cache[coco_file_path] = {'annotations': {}, 'cat_id_map': {}}  # Mark as failed

        return self._coco_cache[coco_file_path]

    @staticmethod
    def _empty_target(idx):
        """Target with zero boxes but every key a real item has (evaluate reads image_id/iscrowd)."""
        return {
            'boxes': torch.zeros((0, 4), dtype=torch.float32),
            'labels': torch.zeros((0,), dtype=torch.int64),
            'image_id': torch.tensor([idx]),
            'area': torch.zeros((0,), dtype=torch.float32),
            'iscrowd': torch.zeros((0,), dtype=torch.uint8),
        }

    def __len__(self):
        return len(self.image_refs)

    def __getitem__(self, idx):
        """Load item ``idx``, apply transforms and drop boxes with non-positive width/height.

        If the image file is missing, a black float image of the metadata size
        (default 64x64) is returned with an empty target. No transforms are applied
        in that case.
        """
        img_ref = self.image_refs[idx]
        img_metadata = img_ref['img_info']
        coco_file = img_ref['coco_file']
        img_id_in_file = img_ref['image_id_in_file']  # Use the ID specific to that JSON

        # file_name may use either separator, e.g. "03\\frame_000000.jpg" or "03/frame_000000.jpg"
        relative_img_path = img_metadata['file_name'].replace('\\', os.sep).replace('/', os.sep)
        img_path = os.path.join(self.frame_base_dir, relative_img_path)

        try:
            # The context manager closes the file handle; convert() returns an independent image.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")
        except FileNotFoundError:
            print(f"Error: Image file not found at {img_path}")
            # Dummy item so the DataLoader keeps going: zeros == a black RGB image as float [0, 1].
            w, h = img_metadata.get('width', 64), img_metadata.get('height', 64)
            return torch.zeros((3, h, w), dtype=torch.float32), self._empty_target(idx)

        coco_data = self._get_coco_annotations(coco_file)
        annotations = coco_data['annotations'].get(img_id_in_file, [])
        local_cat_id_map = coco_data['cat_id_map']

        boxes, labels, areas, iscrowd = [], [], [], []
        for ann in annotations:
            label = local_cat_id_map.get(ann['category_id'])
            if label is None:
                continue  # Category not in the global map: skip this annotation
            # COCO bbox [xmin, ymin, width, height] -> [xmin, ymin, xmax, ymax]
            x, y, bw, bh = ann['bbox'][:4]
            boxes.append([x, y, x + bw, y + bh])
            labels.append(label)
            areas.append(ann.get('area', bw * bh))
            iscrowd.append(ann.get('iscrowd', 0))

        target = {
            # reshape(-1, 4) keeps images without annotations at shape (0, 4); a bare
            # as_tensor([]) is (0,), which the detection model rejects.
            # Wrapping as BoundingBoxes lets v2 geometric transforms (e.g. the flip)
            # move the boxes with the image; plain tensors are passed through untouched.
            "boxes": tv_tensors.BoundingBoxes(
                torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
                format="XYXY",
                canvas_size=(image.height, image.width),
            ),
            "labels": torch.as_tensor(labels, dtype=torch.int64),
            "image_id": torch.tensor([idx]),  # Index in *this* split; used by evaluate()
            "area": torch.as_tensor(areas, dtype=torch.float32),
            "iscrowd": torch.as_tensor(iscrowd, dtype=torch.uint8),
        }

        if self.transforms is not None:
            # v2 transforms handle the image and the target dict together
            image, target = self.transforms(image, target)

        if "boxes" in target:
            # Plain tensor for the model regardless of whether transforms unwrapped it
            box_tensor = target["boxes"].as_subclass(torch.Tensor)
            keep = (box_tensor[:, 2] > box_tensor[:, 0]) & (box_tensor[:, 3] > box_tensor[:, 1])
            target["boxes"] = box_tensor[keep]
            for key in ("labels", "area", "iscrowd"):
                if key in target:
                    target[key] = target[key][keep]

        return image, target

In [None]:
def get_transform(train):
    """Build the v2 transform pipeline applied jointly to ``(image, target)``.

    Args:
        train (bool): If True, add random horizontal flipping for augmentation.

    Returns:
        T.Compose: maps (PIL image, target dict) to (float32 tensor in [0, 1] of
        shape (3, H, W), target dict with tv_tensors unwrapped to plain tensors).
    """
    transforms = [
        # Tag the PIL image explicitly as a tv_tensors.Image (uint8, [0, 255]). Later
        # transforms then know which input is the image and do not rely on the
        # "first plain tensor is the image" heuristic that PILToTensor's output needs.
        T.ToImage(),
    ]

    if train:
        # Flips the image, and target['boxes'] too when it is a tv_tensors.BoundingBoxes.
        transforms.append(T.RandomHorizontalFlip(p=0.5))
        # transforms.append(T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1))

    transforms += [
        T.ToDtype(torch.float, scale=True),  # uint8 [0, 255] -> float [0, 1]; images only
        T.ToPureTensor(),  # Strip tv_tensors wrappers; the model expects plain tensors
    ]
    return T.Compose(transforms)

In [None]:
# Only the training split gets augmentation; val/test use the deterministic pipeline
dataset_train, dataset_val, dataset_test = (
    CocoMultiJsonDataset(refs, FRAME_BASE_DIR, category_map, get_transform(train=is_train))
    for refs, is_train in (
        (train_image_refs, True),
        (val_image_refs, False),
        (test_image_refs, False),
    )
)

print(f"\nDataset sizes: Train={len(dataset_train)}, Val={len(dataset_val)}, Test={len(dataset_test)}")

In [None]:
def collate_fn(batch):
    """Transpose a list of (image, target) pairs into an (images, targets) tuple.

    Detection images can differ in size and targets are dicts, so they are not stacked
    into one tensor; the model takes a sequence of each instead.
    """
    transposed = zip(*batch)
    return tuple(transposed)

In [None]:
def _make_loader(dataset, shuffle):
    """DataLoader with the shared batch size, worker count and detection collate_fn."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        collate_fn=collate_fn,
        pin_memory=torch.cuda.is_available(),  # Pin host memory only when a GPU is present
    )

data_loader_train = _make_loader(dataset_train, shuffle=True)
data_loader_val = _make_loader(dataset_val, shuffle=False)
data_loader_test = _make_loader(dataset_test, shuffle=False)

print(f"\nDataLoaders created with Batch Size={BATCH_SIZE}, Num Workers={NUM_WORKERS}")

In [None]:
def get_faster_rcnn_model(num_classes):
    """COCO-pretrained Faster R-CNN (ResNet-50 FPN) with a fresh box predictor.

    Args:
        num_classes (int): Number of output classes *including* background.

    Returns:
        tuple: (model, weights). The weights enum is returned so its transforms or
        metadata can be used later if needed.
    """
    pretrained_weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=pretrained_weights)

    # Swap the COCO-specific classification/regression head for one sized to our classes
    head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes)

    return detector, pretrained_weights

In [None]:
model, model_weights = get_faster_rcnn_model(NUM_CLASSES)
model = model.to(DEVICE)  # nn.Module.to moves parameters in place and returns the same module

print("Faster R-CNN model loaded and modified for custom classes.")

In [None]:
# Optimise only the parameters that are not frozen
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=LEARNING_RATE,
    momentum=0.9,
    weight_decay=WEIGHT_DECAY,
)

# Alternative: AdamW
# optimizer = torch.optim.AdamW(params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

In [None]:
# Multiply the LR by LR_GAMMA every LR_STEP_SIZE scheduler steps (presumably stepped once per epoch)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=LR_STEP_SIZE,
    gamma=LR_GAMMA,
)

In [None]:
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=50):
    """Run one training epoch over ``data_loader``.

    A batch that raises is logged and skipped. A summary of skipped batches is printed
    at the end, so an epoch in which every batch failed does not look like success.

    Args:
        model: torchvision detection model (returns a loss dict in train mode).
        optimizer: optimizer over the model's trainable parameters.
        data_loader: yields (images, targets) batches as produced by collate_fn.
        device: torch.device to move images and tensor-valued target entries to.
        epoch (int): epoch number, used only for logging.
        print_freq (int): log every ``print_freq`` iterations.

    Returns:
        MetricLogger with the smoothed losses and learning rate, or None if a
        non-finite loss was encountered (the caller should stop training).
    """
    model.train()  # Set model to training mode
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = f'Epoch: [{epoch}]'
    num_batches = 0
    skipped_batches = 0

    for i, (images, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        num_batches += 1
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]

        try:
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.item()  # Single device: no cross-process reduction needed

            if not math.isfinite(loss_value):
                print(f"Loss is {loss_value}, stopping training")
                print(loss_dict)
                return None  # Indicate failure

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            metric_logger.update(loss=losses, **loss_dict)
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        except Exception as e:
            skipped_batches += 1
            print(f"Error during training iteration {i}: {e}")
            print("Skipping this batch.")
            # Drop any gradients left by a partial backward pass before the next batch
            optimizer.zero_grad()
            continue

    if skipped_batches:
        print(f"Warning: skipped {skipped_batches}/{num_batches} batches in epoch {epoch} due to errors.")
    return metric_logger

In [None]:
def _append_coco_ground_truth(targets_batch, dataset, img_id_map, gt_images, gt_annotations):
    """Append one batch of dataset targets to COCO 'images'/'annotations' lists (in place).

    ``img_id_map`` maps a target's ``image_id`` (index into ``dataset.image_refs``) to
    a sequential COCO image id. Each new image gets the next id and an 'images' entry
    built from its stored metadata. Boxes are converted from [x1, y1, x2, y2] to COCO
    [x, y, w, h], and every value is cast to a plain Python number.
    """
    for target in targets_batch:
        dataset_img_idx = int(target['image_id'].item())
        if dataset_img_idx not in img_id_map:
            coco_img_id = len(img_id_map)
            img_id_map[dataset_img_idx] = coco_img_id
            img_info = dataset.image_refs[dataset_img_idx]['img_info']
            gt_images.append({
                'id': coco_img_id,
                'width': img_info.get('width', 0),
                'height': img_info.get('height', 0),
                'file_name': img_info.get('file_name', f'img_{coco_img_id}'),
            })
        coco_img_id = img_id_map[dataset_img_idx]

        boxes = target['boxes'].cpu().tolist()
        labels = target['labels'].cpu().tolist()
        areas = target['area'].cpu().tolist() if 'area' in target else [0.0] * len(boxes)
        crowd_flags = target['iscrowd'].cpu().tolist() if 'iscrowd' in target else [0] * len(boxes)
        for (x1, y1, x2, y2), label, area, crowd in zip(boxes, labels, areas, crowd_flags):
            gt_annotations.append({
                # Annotation ids start at 1: COCOeval records matches by id and treats 0 as "unmatched"
                'id': len(gt_annotations) + 1,
                'image_id': coco_img_id,
                'category_id': int(label),
                'bbox': [x1, y1, x2 - x1, y2 - y1],
                'area': float(area),
                'iscrowd': int(crowd),
            })


def _outputs_to_coco_detections(outputs, targets, img_id_map, score_threshold=0.05):
    """Convert one batch of model outputs into COCO result dicts ([x, y, w, h] boxes)."""
    detections = []
    for output, target in zip(outputs, targets):
        coco_image_id = img_id_map.get(int(target['image_id'].item()))
        if coco_image_id is None:
            continue
        for (x1, y1, x2, y2), label, score in zip(output['boxes'].tolist(),
                                                    output['labels'].tolist(),
                                                    output['scores'].tolist()):
            if score > score_threshold:
                detections.append({
                    "image_id": coco_image_id,
                    "category_id": label,
                    "bbox": [x1, y1, x2 - x1, y2 - y1],
                    "score": score,
                })
    return detections


def _run_coco_bbox_eval(gt_images, gt_annotations, detections, category_map):
    """Build an in-memory COCO ground truth and run bbox COCOeval.

    Returns:
        ``COCOeval.stats`` (mAP/mAR summary), or None if evaluation raised.
    """
    coco_gt = COCO()
    coco_gt.dataset = {
        'images': gt_images,
        'annotations': gt_annotations,
        'categories': [{'id': cat_id, 'name': name, 'supercategory': 'object'}
                       for name, cat_id in category_map.items()],
    }
    coco_gt.createIndex()
    print("Validation ground truth COCO object created.")

    print("\nRunning COCO evaluation on validation predictions...")
    try:
        coco_dt = coco_gt.loadRes(detections)
        coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
        return coco_eval.stats
    except Exception as e:
        print(f"Error during COCO evaluation: {e}")
        return None


@torch.inference_mode()  # More efficient than torch.no_grad() for inference
def evaluate(model, data_loader, device, category_map=None):
    """Compute the average validation loss and, if pycocotools is available, COCO bbox metrics.

    Ground truth is collected from the same batches used for inference, so each image is
    loaded once rather than in an extra pass over the loader.

    Args:
        model: torchvision detection model.
        data_loader: DataLoader over a CocoMultiJsonDataset (targets carry 'image_id',
            the index into ``data_loader.dataset.image_refs``).
        device: torch.device used for inference.
        category_map (dict, optional): category name -> contiguous id used to build the
            COCO categories. Defaults to ``data_loader.dataset.category_map``.

    Returns:
        tuple: (average validation loss, or inf if there were no batches;
        ``COCOeval.stats`` array, or None if COCO evaluation was skipped or failed).
    """
    model.eval()
    metric_logger = MetricLogger(delimiter="  ")
    header = 'Validate:'
    use_coco_eval = COCO is not None and COCOeval is not None

    gt_images, gt_annotations, coco_predictions = [], [], []
    img_id_map = {}  # dataset index -> COCO image id
    val_loss_total = 0.0
    val_batches = 0

    if use_coco_eval:
        print("Collecting COCO ground truth during validation inference...")
    print("\nRunning validation inference...")
    for images, targets in metric_logger.log_every(data_loader, 50, header):
        if use_coco_eval:
            _append_coco_ground_truth(targets, data_loader.dataset, img_id_map, gt_images, gt_annotations)

        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]

        # Detection models only return losses in train mode; gradients stay disabled by the
        # inference_mode decorator. try/finally guarantees eval mode is restored even if
        # the loss forward raises.
        model.train()
        try:
            loss_dict = model(images, targets)
        finally:
            model.eval()
        losses = sum(loss for loss in loss_dict.values())
        val_loss_total += losses.item()
        val_batches += 1

        outputs = model(images)
        outputs = [{k: v.to(torch.device('cpu')) for k, v in t.items()} for t in outputs]
        if use_coco_eval:
            coco_predictions.extend(_outputs_to_coco_detections(outputs, targets, img_id_map))

        metric_logger.update(val_loss=losses)

    eval_summary = None
    if use_coco_eval:
        if not gt_annotations:
            print("Warning: No ground truth annotations found for validation set. COCO evaluation will be skipped.")
        elif not coco_predictions:
            print("Skipping COCO evaluation: Ground truth or predictions missing.")
        else:
            if category_map is None:
                category_map = data_loader.dataset.category_map
            eval_summary = _run_coco_bbox_eval(gt_images, gt_annotations, coco_predictions, category_map)

    metric_logger.synchronize_between_processes()
    print('Validation Result: {}'.format(metric_logger))

    avg_val_loss = val_loss_total / val_batches if val_batches > 0 else float('inf')
    print(f"Average Validation Loss: {avg_val_loss:.4f}")

    return avg_val_loss, eval_summary

In [None]:
import collections
import datetime
import pickle
import time

class SmoothedValue(object):
    """Running statistic over a sliding window plus an all-time weighted average.

    ``update`` pushes each value into a deque bounded to the ``window_size`` most recent
    entries (used by ``median``/``avg``/``max``/``value``). ``total``/``count``
    accumulate over every update for ``global_avg``. ``str()`` renders via ``fmt``.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = collections.deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        """Record ``value``; ``n`` weights it in the global average only."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """No-op placeholder for distributed training (would not synchronize the deque anyway)."""
        return

    @property
    def median(self):
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        if self.count <= 0:
            return float('nan')
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        if self.count == 0:
            return "N/A"  # Nothing recorded yet
        stats = {
            'median': self.median,
            'avg': self.avg,
            'global_avg': self.global_avg,
            'max': self.max,
            'value': self.value,
        }
        return self.fmt.format(**stats)


class MetricLogger(object):
    """Collect named scalar meters and print periodic progress lines.

    Meters are created on first use (``self.meters`` is a ``defaultdict`` of
    ``SmoothedValue``) and are also reachable as attributes, so
    ``logger.loss`` returns ``logger.meters['loss']`` once it exists.
    """

    def __init__(self, delimiter="\t"):
        self.meters = collections.defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword argument into the meter of the same name.

        Tensors are converted with ``.item()``; anything that is not then an
        ``int`` or ``float`` raises ``TypeError``.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            if not isinstance(v, (float, int)):
                raise TypeError(
                    "This logger accepts only int or float values but received {} of type {}".format(v, type(v))
                )
            self.meters[k].update(v)

    def __getattr__(self, attr):
        """Expose existing meters as attributes (only called when normal lookup fails).

        ``meters`` is read from ``self.__dict__`` directly. An instance created
        without ``__init__`` (``copy.deepcopy`` / ``pickle`` do this, then probe
        for ``__setstate__``) has no ``meters`` yet. Writing ``self.meters`` there
        would re-enter ``__getattr__`` and recurse until ``RecursionError``.
        """
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:  # `in` does not create defaultdict entries
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))

    def __str__(self):
        return self.delimiter.join(
            "{}: {}".format(name, str(meter)) for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items from ``iterable``, printing a progress line every ``print_freq`` items.

        ``iterable`` must support ``len()``. The last item always triggers a line.
        Each line shows the index, an ETA from the mean iteration time, all meters,
        iteration/data-loading time, and peak CUDA memory (MB) when CUDA is available.
        A total-time summary is printed after the iterable is exhausted.
        """
        if not header:
            header = ''
        num_items = len(iterable)
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')  # full time per iteration (load + consumer work)
        data_time = SmoothedValue(fmt='{avg:.4f}')  # time spent waiting on the iterable
        space_fmt = ':' + str(len(str(num_items))) + 'd'
        use_cuda = torch.cuda.is_available()
        fields = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}',
        ]
        if use_cuda:
            fields.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(fields)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_items - 1:
                eta_seconds = iter_time.global_avg * (num_items - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                extra = {}
                if use_cuda:
                    extra['memory'] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(
                    i, num_items, eta=eta_string,
                    meters=str(self),
                    time=str(iter_time), data=str(data_time),
                    **extra))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # Guard against ZeroDivisionError when the iterable was empty.
        time_per_it = total_time / num_items if num_items > 0 else 0.0
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, time_per_it))


In [None]:
print("\n--- Starting Training ---")
start_training_time = time.time()
best_val_metric = float('inf')  # Lowest val loss (loss mode); in mAP mode, the val loss of the best-mAP epoch
best_val_map = 0.0  # Highest val mAP@[.50:.95]; only updated when selecting by mAP
best_epoch = None  # 0-based epoch of the checkpoint saved during THIS run (None = nothing saved yet)
best_model_path = os.path.join(MODEL_SAVE_DIR, "fasterrcnn_best.pth")
use_map_for_best = COCO is not None  # Select by mAP if pycocotools is available, else by val loss

# Per-epoch history for later inspection/plotting
history = {'train_loss': [], 'val_loss': [], 'val_map': []}

for epoch in range(NUM_EPOCHS):
    print(f"\n===== Epoch {epoch+1}/{NUM_EPOCHS} =====")

    # Train for one epoch; a None result means training failed
    train_logger = train_one_epoch(model, optimizer, data_loader_train, DEVICE, epoch)
    if train_logger is None:
        print("Stopping training due to high loss.")
        break

    # Update the learning rate
    lr_scheduler.step()

    # Evaluate on the validation set
    avg_val_loss, coco_stats = evaluate(model, data_loader_val, DEVICE)

    # Store history (global averages over the whole epoch)
    train_loss = train_logger.meters['loss'].global_avg
    # coco_stats[0] is mAP @ IoU=0.50:0.95; falls back to 0.0 when COCO eval was skipped
    val_map_50_95 = coco_stats[0] if coco_stats is not None else 0.0
    history['train_loss'].append(train_loss)
    history['val_loss'].append(avg_val_loss)
    history['val_map'].append(val_map_50_95)

    print(f"Epoch {epoch+1} Summary: Train Loss={train_loss:.4f}, Val Loss={avg_val_loss:.4f}, Val mAP@.5-.95={val_map_50_95:.4f}")

    # --- Save the best model ---
    current_metric = val_map_50_95 if use_map_for_best else avg_val_loss
    # Always save the first completed epoch. Otherwise, if mAP never rises above its
    # 0.0 starting value (early training, or COCO eval skipped), no checkpoint from
    # this run would exist and the test cell would use a stale file or none at all.
    if best_epoch is None:
        is_better = True
    elif use_map_for_best:
        is_better = current_metric > best_val_map
    else:
        is_better = current_metric < best_val_metric

    if is_better:
        if use_map_for_best:
            print(f"Validation mAP improved ({best_val_map:.4f} --> {current_metric:.4f}). Saving model...")
            best_val_map = current_metric
            best_val_metric = avg_val_loss  # Loss at the best-mAP epoch (not necessarily the lowest loss)
        else:
            print(f"Validation loss improved ({best_val_metric:.4f} --> {current_metric:.4f}). Saving model...")
            best_val_metric = current_metric

        try:
            # Save model state dictionary and other useful info
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                'best_val_loss': best_val_metric,
                'best_val_map': best_val_map,
                'num_classes': NUM_CLASSES,
                'category_map': category_map  # Save category map with model
            }, best_model_path)
            best_epoch = epoch  # Only record the epoch once the file was actually written
            print(f"Best model saved to {best_model_path}")
        except Exception as e:
            print(f"Error saving model: {e}")
    else:
        print("Validation metric did not improve.")


total_training_time = time.time() - start_training_time
print("\n--- Training Finished ---")
print(f"Total Training Time: {str(datetime.timedelta(seconds=int(total_training_time)))}")
# Report from tracked state instead of re-loading the full checkpoint (model +
# optimizer) from disk, which could also describe a file left over from an earlier run.
if best_epoch is None:
    print("No checkpoint was saved during this run.")
else:
    print(f"Best model saved at epoch {best_epoch+1} with Validation Loss: {best_val_metric:.4f} and mAP: {best_val_map:.4f}")

In [None]:
print("\n--- Evaluating Best Model on Test Set ---")
best_model_path = os.path.join(MODEL_SAVE_DIR, "fasterrcnn_best.pth")

if not os.path.exists(best_model_path):
    print("Error: Best model file not found. Skipping test evaluation.")
else:
    print(f"Loading best model from: {best_model_path}")
    # map_location remaps tensors saved on another device onto the current DEVICE
    checkpoint = torch.load(best_model_path, map_location=DEVICE)

    # Ensure NUM_CLASSES matches the saved model
    saved_num_classes = checkpoint.get('num_classes', NUM_CLASSES)  # Default to config if not saved
    if saved_num_classes != NUM_CLASSES:
        print(f"Warning: NUM_CLASSES in config ({NUM_CLASSES}) differs from saved model ({saved_num_classes}). Using saved value.")
        NUM_CLASSES = saved_num_classes

    eval_model, _ = get_faster_rcnn_model(NUM_CLASSES)  # Create model structure
    eval_model.load_state_dict(checkpoint['model_state_dict'])
    eval_model.to(DEVICE)
    eval_model.eval()  # Set to evaluation mode

    print("Best model loaded successfully.")

    # --- Evaluate on Test Set ---
    # Same evaluate() as for validation, but with the test loader, so the
    # ground-truth COCO object is built from the test set this time.
    if COCO is None or COCOeval is None:
        print("pycocotools not available. Skipping COCO evaluation on test set.")
    else:
        print("\nRunning evaluation on the Test Set...")
        _, test_coco_stats = evaluate(eval_model, data_loader_test, DEVICE)

        # Explicit None/length check (as in the training cell): plain truthiness
        # raises ValueError if evaluate() returns a multi-element NumPy array,
        # such as COCOeval.stats.
        if test_coco_stats is not None and len(test_coco_stats) > 0:
            print("\n--- Test Set Evaluation Summary (COCO Metrics) ---")
            # Labels for the first six COCOeval.stats entries (AP); indices 6-11 are AR stats.
            stat_labels = [
                "mAP @ IoU=0.50:0.95 | area=all | maxDets=100",
                "mAP @ IoU=0.50      | area=all | maxDets=100",
                "mAP @ IoU=0.75      | area=all | maxDets=100",
                "mAP @ IoU=0.50:0.95 | area=small | maxDets=100",
                "mAP @ IoU=0.50:0.95 | area=medium| maxDets=100",
                "mAP @ IoU=0.50:0.95 | area=large | maxDets=100",
            ]
            for label, stat in zip(stat_labels, test_coco_stats):
                print(f"{label}: {stat:.4f}")
        else:
            print("COCO evaluation on test set could not be completed.")

In [None]:
def plot_image_with_boxes(image_tensor, gt_boxes=None, pred_boxes=None, pred_scores=None, pred_labels=None, gt_labels=None, category_map=None, score_threshold=0.5, title="Image"):
    """Plot one image with ground-truth (green) and predicted (red) bounding boxes.

    Args:
        image_tensor: Image tensor convertible by ``T.ToPILImage`` (presumably CHW — TODO confirm).
        gt_boxes: Tensor of ground-truth boxes, one ``[xmin, ymin, xmax, ymax]`` row per box, or None.
        pred_boxes: Tensor of predicted boxes in the same layout, or None.
        pred_scores: Per-prediction score tensor. Predictions scoring below
            ``score_threshold`` are not drawn; if None, every prediction is drawn with score 1.0.
        pred_labels: Per-prediction label-id tensor, or None (id 0 is shown).
        gt_labels: Per-ground-truth label-id tensor, or None (id 0 is shown).
        category_map: Mapping of name -> id (inverted here for display), or None.
        score_threshold: Minimum score for a prediction to be drawn.
        title: Figure title.
    """
    def _to_cpu(t):
        # Tensor.cpu() returns the tensor itself when it is already on the CPU.
        return t.cpu() if t is not None else None

    image_tensor = image_tensor.cpu()
    gt_boxes = _to_cpu(gt_boxes)
    pred_boxes = _to_cpu(pred_boxes)
    pred_scores = _to_cpu(pred_scores)
    pred_labels = _to_cpu(pred_labels)
    gt_labels = _to_cpu(gt_labels)

    # Convert tensor image back to PIL format for plotting
    img = T.ToPILImage()(image_tensor)
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(img)
    ax.set_title(title)

    # Reverse category map {id: name}
    id_to_name = {v: k for k, v in category_map.items()} if category_map else {}

    # Plot Ground Truth boxes (Green)
    if gt_boxes is not None and len(gt_boxes) > 0:
        for i, box in enumerate(gt_boxes):
            xmin, ymin, xmax, ymax = box.tolist()  # plain Python floats for matplotlib
            rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, linewidth=2, edgecolor='g', facecolor='none')
            ax.add_patch(rect)
            label_id = gt_labels[i].item() if gt_labels is not None else 0
            label_name = id_to_name.get(label_id, f'ID:{label_id}')
            ax.text(xmin, ymin - 5, f'GT: {label_name}', color='g', fontsize=9, bbox=dict(facecolor='white', alpha=0.5, pad=0))

    # Plot Predicted boxes (Red), skipping those below the score threshold
    if pred_boxes is not None and len(pred_boxes) > 0:
        for i, box in enumerate(pred_boxes):
            score = pred_scores[i].item() if pred_scores is not None else 1.0
            if score < score_threshold:
                continue
            xmin, ymin, xmax, ymax = box.tolist()
            rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, linewidth=2, edgecolor='r', facecolor='none')
            ax.add_patch(rect)
            label_id = pred_labels[i].item() if pred_labels is not None else 0
            label_name = id_to_name.get(label_id, f'ID:{label_id}')
            ax.text(xmax, ymin - 5, f'Pred: {label_name} ({score:.2f})', color='r', fontsize=9, ha='right', bbox=dict(facecolor='white', alpha=0.5, pad=0))

    ax.axis('off')
    plt.show()
    # Close explicitly. The visualization loop can create many figures, and
    # non-inline backends keep every open figure alive in memory.
    plt.close(fig)


In [None]:
if 'eval_model' in locals():  # eval_model only exists if the test-set cell loaded a checkpoint
    print("\nVisualizing some Test Set predictions...")
    # NOTE(review): this caps the number of *batches* (one figure per batch), not
    # individual samples; 5000 likely covers every test batch — lower it for a quick look.
    num_samples_to_show = 5000
    eval_model.eval()

    # Running count of test images already seen. The loader's batch size is not
    # visible here, so the index is not derived from BATCH_SIZE.
    images_seen = 0
    with torch.no_grad():
        for i, (images, targets) in enumerate(data_loader_test):
            if i >= num_samples_to_show:
                break

            images_device = [img.to(DEVICE) for img in images]
            outputs = eval_model(images_device)

            # Plot the first image of the batch
            img_idx = 0
            image_tensor = images[img_idx]  # Original tensor before moving to device
            target = targets[img_idx]
            output = outputs[img_idx]

            plot_image_with_boxes(
                image_tensor,
                gt_boxes=target['boxes'],
                gt_labels=target['labels'],
                pred_boxes=output['boxes'].cpu(),
                pred_scores=output['scores'].cpu(),
                pred_labels=output['labels'].cpu(),
                category_map=category_map,  # Use the loaded/cached map
                score_threshold=0.5,
                title=f"Test Image {images_seen + img_idx}"
            )
            images_seen += len(images)

else:
    print("Best model not loaded. Skipping visualization.")