# Configuration

In [1]:
import os
import json
import time
import datetime
from collections import defaultdict, deque
import torch
import torch.distributed as dist
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import functional as F
import random
from tqdm import tqdm
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from PIL import Image, ImageDraw  # Import ImageDraw

# Data Loading and Transformation

In [2]:
class CocoDataset(Dataset):
    """Detection dataset built from per-video frame folders and COCO JSON files.

    For every video number in ``video_range`` (inclusive on both ends) the
    loader looks for ``<frame_dir>/video_NNN`` and
    ``<coco_dir>/annotation_cvat_NNN``. Each ``.json`` file in the annotation
    folder is read; the first entry of its ``images`` list names the frame and
    every entry of ``annotations`` contributes one box.

    Args:
        frame_dir: Root directory holding the ``video_NNN`` frame folders.
        coco_dir: Root directory holding the ``annotation_cvat_NNN`` folders.
        video_range: ``(first, last)`` video numbers, both inclusive.
        transform: Optional callable ``(image, target) -> (image, target)``.
    """

    def __init__(self, frame_dir, coco_dir, video_range, transform=None):
        self.frame_dir = frame_dir
        self.coco_dir = coco_dir
        self.video_range = video_range
        self.transform = transform
        self.data = self._load_data()

    def _load_data(self):
        """Index all usable frames and return a list of sample dicts.

        Each dict has ``frame_path``, ``boxes`` (``[x1, y1, x2, y2]`` lists)
        and ``labels``. Frames that are missing on disk or end up with no
        valid box are skipped.
        """
        data = []
        for video_num in range(self.video_range[0], self.video_range[1] + 1):
            video_name = f"video_{video_num:03d}"
            frame_subdir = os.path.join(self.frame_dir, video_name)
            coco_subdir = os.path.join(self.coco_dir, f"annotation_cvat_{video_num:03d}")

            if not os.path.exists(frame_subdir) or not os.path.exists(coco_subdir):
                print(f"Warning: Data for {video_name} not found. Skipping.")
                continue

            # os.listdir returns entries in arbitrary order; sorting keeps
            # sample indices stable between runs and machines.
            for coco_file in sorted(os.listdir(coco_subdir)):
                if not coco_file.endswith(".json"):
                    continue

                coco_path = os.path.join(coco_subdir, coco_file)
                with open(coco_path, 'r') as f:
                    coco_data = json.load(f)

                frame_filename = coco_data['images'][0]['file_name']
                frame_path = os.path.join(frame_subdir, frame_filename)

                if not os.path.exists(frame_path):
                    print(f"Warning: Frame file {frame_filename} not found. Skipping.")
                    continue

                boxes = []
                labels = []
                for ann in coco_data['annotations']:
                    # The stored bbox is [x, y, width, height]; convert it to
                    # corner format [x1, y1, x2, y2].
                    x, y, width, height = ann['bbox'][:4]
                    if width <= 0 or height <= 0:
                        # Degenerate boxes are rejected by torchvision
                        # detection models during training, so drop them.
                        continue
                    boxes.append([x, y, x + width, y + height])
                    labels.append(ann['category_id'])

                if not boxes:
                    continue

                data.append({
                    'frame_path': frame_path,
                    'boxes': boxes,
                    'labels': labels
                })
        return data

    def __len__(self):
        """Return the number of indexed frames."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load frame ``idx`` as RGB and return ``(image, target)``.

        ``target`` holds ``boxes`` (float32) and ``labels`` (int64) tensors.
        Loading or transform errors are printed and then re-raised.
        """
        item = self.data[idx]

        try:
            img = Image.open(item['frame_path']).convert("RGB")
        except Exception as e:
            print(f"  ERROR loading image: {item['frame_path']}, Error: {e}")
            raise

        target = {
            "boxes": torch.as_tensor(item['boxes'], dtype=torch.float32),
            "labels": torch.as_tensor(item['labels'], dtype=torch.int64),
        }

        if self.transform:
            try:
                img, target = self.transform(img, target)
            except Exception as e:
                print(f"  ERROR applying transformations: {e}")
                raise

        return img, target

In [3]:
class Compose:
    """Chain several ``(image, target)`` transforms into one callable."""

    def __init__(self, transforms):
        # Sequence of callables, applied in the order given.
        self.transforms = transforms

    def __call__(self, image, target):
        """Feed the pair through every transform and return the final pair."""
        sample = (image, target)
        for step in self.transforms:
            new_image, new_target = step(*sample)
            sample = (new_image, new_target)
        return sample

class ToTensor:
    """Transform that converts the image with ``F.to_tensor``; the target is untouched."""

    def __call__(self, image, target):
        tensor_image = F.to_tensor(image)
        return tensor_image, target

class RandomHorizontalFlip:
    """Randomly mirror an image left-right and flip its boxes to match.

    Works for both tensor images (width is the last dimension) and PIL
    images. ``get_transform`` places this after ``ToTensor``, so the tensor
    case is the one exercised by the pipeline; the previous
    ``image.size[0]`` lookup failed there because ``Tensor.size`` is a
    method, not a ``(width, height)`` tuple.

    Args:
        prob: Probability of applying the flip.
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        """Return ``(image, target)``, flipped with probability ``prob``.

        ``target["boxes"]`` is expected in ``[x1, y1, x2, y2]`` format. When
        a flip happens it is replaced by a new tensor; the caller's original
        box tensor is not modified.
        """
        if random.random() >= self.prob:
            return image, target

        if isinstance(image, torch.Tensor):
            width = image.shape[-1]
            image = image.flip(-1)  # reverse the width axis
        else:
            width = image.size[0]  # PIL: size is (width, height)
            image = F.hflip(image)

        boxes = target["boxes"]
        if boxes.numel() > 0:
            flipped = boxes.clone()
            # Mirroring swaps the roles of the left and right edges.
            flipped[:, 0] = width - boxes[:, 2]
            flipped[:, 2] = width - boxes[:, 0]
            target["boxes"] = flipped
        return image, target

def get_transform(train=False):
    """Build the transform pipeline.

    The pipeline always starts with ``ToTensor``; when ``train`` is true a
    ``RandomHorizontalFlip`` with probability 0.5 is appended.
    """
    pipeline = [ToTensor()]
    if train:
        pipeline += [RandomHorizontalFlip(0.5)]
    return Compose(pipeline)

# Training

In [4]:
def get_model(num_classes):
    """Create a Faster R-CNN (ResNet-50 FPN v2) with a fresh box predictor.

    The detector is built with ``weights="DEFAULT"`` and its box predictor
    is replaced by a ``FastRCNNPredictor`` sized for ``num_classes`` outputs.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(weights="DEFAULT")

    roi_heads = detector.roi_heads
    head_in_features = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes)
    return detector

def collate_fn(batch):
    """Transpose a batch of samples into a tuple of per-field tuples.

    A list of ``(image, target)`` pairs becomes ``(images, targets)`` without
    stacking anything.
    """
    columns = zip(*batch)
    return tuple(columns)

def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    """Run one training epoch and return the populated ``MetricLogger``.

    Args:
        model: Detection model; called as ``model(images, targets)`` and
            expected to return a dict of loss tensors.
        optimizer: Optimizer stepped once per batch.
        data_loader: Iterable yielding ``(images, targets)`` batches.
        device: Device the images and target tensors are moved to.
        epoch: Epoch number, used only in the log header.
        print_freq: Log a progress line every ``print_freq`` batches.

    Returns:
        The ``MetricLogger`` holding the loss and learning-rate meters.

    Raises:
        RuntimeError: If the summed loss becomes NaN or infinite.
    """
    import math  # local import keeps this block self-contained

    model.train()
    print("Training...")
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = f"Epoch: [{epoch}]"

    print("Setup Metrics Completed")
    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # Reduced copies are for logging only; the backward pass uses `losses`.
        loss_dict_reduced = reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        loss_value = losses_reduced.item()

        # Stop before a NaN/inf loss is back-propagated into the weights.
        if not math.isfinite(loss_value):
            raise RuntimeError(
                f"Loss is {loss_value}, stopping training: {loss_dict_reduced}"
            )

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    return metric_logger

@torch.inference_mode()
def evaluate(model, data_loader, device):
    """Run inference over ``data_loader`` and report timing statistics.

    This pass does not score the predictions; it only checks that inference
    runs and records how long each forward pass takes. Previously no meter
    was ever updated, so the function always returned an empty dict.

    Args:
        model: Detection model, switched to eval mode here.
        data_loader: Iterable yielding ``(images, targets)`` batches; the
            targets are not used.
        device: Device the images are moved to.

    Returns:
        Dict mapping each meter name (``model_time``) to its global average.
    """
    model.eval()  # Set the model to evaluation mode
    metric_logger = MetricLogger(delimiter="  ")
    header = "images:"
    on_cuda = torch.device(device).type == "cuda"

    for images, _targets in metric_logger.log_every(data_loader, 100, header):
        images = [image.to(device) for image in images]

        # CUDA kernels run asynchronously; synchronise so the wall-clock
        # interval actually covers the forward pass.
        if on_cuda:
            torch.cuda.synchronize()
        started = time.time()
        model(images)
        if on_cuda:
            torch.cuda.synchronize()
        metric_logger.update(model_time=time.time() - started)

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

class SmoothedValue:
    """Track a stream of values with windowed and whole-series statistics.

    The last ``window_size`` values are kept in ``deque`` for the windowed
    statistics (median, avg, max, value); ``total`` and ``count`` accumulate
    over the whole series for ``global_avg``.
    """

    def __init__(self, window_size=20, fmt=None):
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        """Record ``value``, counting it ``n`` times towards the global average."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum ``count`` and ``total`` across processes; the window is not synced."""
        if is_dist_avail_and_initialized():
            packed = torch.tensor(
                [self.count, self.total], dtype=torch.float64, device="cuda"
            )
            dist.barrier()
            dist.all_reduce(packed)
            synced_count, synced_total = packed.tolist()
            self.count = int(synced_count)
            self.total = synced_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window (computed in float32)."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Average over every value recorded so far."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stats = {
            "median": self.median,
            "avg": self.avg,
            "global_avg": self.global_avg,
            "max": self.max,
            "value": self.value,
        }
        return self.fmt.format(**stats)

class MetricLogger:
    """Collect named ``SmoothedValue`` meters and print periodic progress lines.

    Args:
        delimiter: String placed between the fields of a log line.
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword; tensors are converted with ``.item()``."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        """Expose meters as attributes, e.g. ``logger.loss``.

        ``__getattr__`` only runs after normal attribute lookup has failed,
        so ``meters`` is read straight from ``__dict__``: going through
        ``self.meters`` would recurse forever on an instance whose
        ``__init__`` has not run (for example while being copied or
        unpickled).
        """
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{attr}'"
        )

    def __str__(self):
        return self.delimiter.join(
            f"{name}: {str(meter)}" for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        """Synchronize every meter across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items from ``iterable`` and print progress as they are consumed.

        A line is printed every ``print_freq`` items and on the last item; a
        total-time summary is printed once the iterable is exhausted.

        Args:
            iterable: Sized iterable (``len()`` is required).
            print_freq: Number of items between progress lines.
            header: Optional prefix for every printed line.
        """
        if not header:
            header = ""
        total_iters = len(iterable)
        use_cuda = torch.cuda.is_available()
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(total_iters))) + "d"

        fields = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if use_cuda:
            fields.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(fields)

        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            data_time.update(time.time() - end)
            yield obj
            # Code below runs after the consumer finished its loop body.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == total_iters - 1:
                eta_seconds = iter_time.global_avg * (total_iters - i)
                extra = {}
                if use_cuda:
                    extra["memory"] = torch.cuda.max_memory_allocated() / MB
                print(
                    log_msg.format(
                        i,
                        total_iters,
                        eta=str(datetime.timedelta(seconds=int(eta_seconds))),
                        meters=str(self),
                        time=str(iter_time),
                        data=str(data_time),
                        **extra,
                    )
                )
            end = time.time()

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # An empty iterable must not crash the summary line.
        per_iter = total_time / total_iters if total_iters else 0.0
        print(f"{header} Total time: {total_time_str} ({per_iter:.4f} s / it)")

def reduce_dict(input_dict, average=True):
    """Combine a dict of tensors across all processes.

    With fewer than two processes the input dict is returned unchanged.
    Otherwise the values are stacked in sorted-key order, all-reduced, and
    optionally divided by the world size.

    Args:
        input_dict (dict): values to reduce.
        average (bool): divide the reduced values by the world size when
            true; leave them as the all-reduce produced them otherwise.

    Returns:
        dict: same keys as ``input_dict``, holding the reduced values.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.inference_mode():
        # Sorted keys give every process the same stacking order.
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[name] for name in names], dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = dict(zip(names, values))
    return reduced_dict

def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()

def get_world_size():
    """Return the distributed world size, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1

# Training Workflow

In [5]:
def main():
    """Train the detector on videos 1-7, evaluate on 8-10, and save the weights.

    The weights are written to ``./output/faster_rcnn_model_<N>_epoch.pth``
    where ``N`` is the number of epochs, which is the name the evaluation
    cell loads. The output directory is created if it does not exist.
    """
    # --- Configuration ---
    frame_directory = "./data/frame/"
    coco_annotation_directory = "./data/annotation/coco/"
    output_directory = "./output"
    num_classes = 2
    batch_size = 4
    num_workers = 0
    num_epochs = 5
    learning_rate = 0.005
    momentum = 0.9
    weight_decay = 0.0005
    print_freq = 20

    # --- Device ---
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(f"Using device: {device}")
    # Pinned host memory only helps when batches are copied to a GPU.
    pin_memory = device.type == "cuda"

    # --- Datasets ---
    print("Loading dataset...")
    # NOTE(review): get_transform() is called with its default train=False, so
    # the training set gets no random-flip augmentation - confirm this is intended.
    train_dataset = CocoDataset(frame_directory, coco_annotation_directory, (1, 7), get_transform())
    test_dataset = CocoDataset(frame_directory, coco_annotation_directory, (8, 10), get_transform())

    # --- DataLoaders ---
    print("Adding dataset to dataloaders")
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory)

    # --- TEST DATA LOADER ---
    # Best-effort smoke test: a failure is reported but does not abort here,
    # the training loop below will surface the same error.
    print("Testing DataLoader...")
    try:
        for i, (images, targets) in enumerate(train_dataloader):
            print(f"Batch {i} loaded successfully.")
            print(f"  Images shape: {images[0].shape}")  # Check image shape
            print(f"  Targets: {targets}")  # Check the targets
            break  # Only test the first batch
    except Exception as e:
        print(f"Error loading batch: {e}")
    print("DataLoader test complete.")
    # --- END TEST ---

    # --- Model ---
    print("Getting model from remote")
    model = get_model(num_classes)
    model.to(device)

    # --- Optimizer ---
    print("Setting up optimizer")
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(params, lr=learning_rate, momentum=momentum, weight_decay=weight_decay)

    # Create the output directory up front so a missing folder fails fast
    # instead of after the whole training run.
    os.makedirs(output_directory, exist_ok=True)
    model_path = os.path.join(output_directory, f"faster_rcnn_model_{num_epochs}_epoch.pth")

    # --- Training Loop ---
    print("Training...")
    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch + 1, num_epochs))
        train_one_epoch(model, optimizer, train_dataloader, device, epoch, print_freq=print_freq)
        # Evaluate after each epoch (optional, but good practice)
        evaluate(model, test_dataloader, device)

    print("Training finished!")

    # --- Save Model ---
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")

In [None]:
# Start the training workflow only when this file is run as the main module.
if __name__ == "__main__":
    main()

# Model Evaluation

In [17]:
class CocoDataset(Dataset):
    """Evaluation variant of the detection dataset; targets carry ``image_id``.

    For every video number in ``video_range`` (inclusive on both ends) the
    loader looks for ``<frame_dir>/video_NNN`` and
    ``<coco_dir>/annotation_cvat_NNN``. Each ``.json`` file in the annotation
    folder is read; the first entry of its ``images`` list names the frame and
    every entry of ``annotations`` contributes one box.

    Args:
        frame_dir: Root directory holding the ``video_NNN`` frame folders.
        coco_dir: Root directory holding the ``annotation_cvat_NNN`` folders.
        video_range: ``(first, last)`` video numbers, both inclusive.
        transform: Optional callable ``(image, target) -> (image, target)``.
    """

    def __init__(self, frame_dir, coco_dir, video_range, transform=None):
        self.frame_dir = frame_dir
        self.coco_dir = coco_dir
        self.video_range = video_range
        self.transform = transform
        self.data = self._load_data()

    def _load_data(self):
        """Index all usable frames and return a list of sample dicts.

        Each dict has ``frame_path``, ``boxes`` (``[x1, y1, x2, y2]`` lists),
        ``labels`` and the ``image_id`` stored in the annotation file. Images
        are not opened here; they are only read in ``__getitem__``.
        """
        data = []
        for video_num in range(self.video_range[0], self.video_range[1] + 1):
            video_name = f"video_{video_num:03d}"
            frame_subdir = os.path.join(self.frame_dir, video_name)
            coco_subdir = os.path.join(self.coco_dir, f"annotation_cvat_{video_num:03d}")

            if not os.path.exists(frame_subdir) or not os.path.exists(coco_subdir):
                print(f"Warning: Data for {video_name} not found. Skipping.")
                continue

            # os.listdir returns entries in arbitrary order; sorting keeps
            # sample indices stable, which matters for index-based lookups.
            for coco_file in sorted(os.listdir(coco_subdir)):
                if not coco_file.endswith(".json"):
                    continue

                coco_path = os.path.join(coco_subdir, coco_file)
                with open(coco_path, 'r') as f:
                    coco_data = json.load(f)

                # The first image entry names the frame this file annotates.
                image_info = coco_data['images'][0]
                frame_filename = image_info['file_name']
                frame_path = os.path.join(frame_subdir, frame_filename)

                if not os.path.exists(frame_path):
                    print(f"Warning: Frame file {frame_filename} not found. Skipping.")
                    continue

                boxes = []
                labels = []
                for ann in coco_data['annotations']:
                    # The stored bbox is [x, y, width, height]; convert it to
                    # corner format [x1, y1, x2, y2].
                    x, y, width, height = ann['bbox'][:4]
                    if width <= 0 or height <= 0:
                        # Zero-area boxes carry no usable ground truth.
                        continue
                    boxes.append([x, y, x + width, y + height])
                    labels.append(ann['category_id'])

                if not boxes:  # Skip images without annotations
                    continue

                data.append({
                    'frame_path': frame_path,
                    'boxes': boxes,
                    'labels': labels,
                    'image_id': image_info['id'],
                })
        return data

    def __len__(self):
        """Return the number of indexed frames."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load frame ``idx`` as RGB and return ``(image, target)``.

        ``target`` holds ``boxes`` (float32), ``labels`` (int64) and
        ``image_id``, which is the dataset index ``idx`` rather than the id
        stored in the annotation file.
        """
        item = self.data[idx]
        img = Image.open(item['frame_path']).convert("RGB")

        target = {
            "boxes": torch.as_tensor(item['boxes'], dtype=torch.float32),
            "labels": torch.as_tensor(item['labels'], dtype=torch.int64),
            "image_id": torch.tensor([idx]),
        }

        if self.transform:
            img, target = self.transform(img, target)

        return img, target

In [7]:
from torchmetrics.detection.mean_ap import MeanAveragePrecision

In [18]:
@torch.inference_mode()
def evaluate(model, data_loader, device):
    """Compute bounding-box mean average precision over ``data_loader``.

    Args:
        model: Detection model, switched to eval mode here; called as
            ``model(images)`` and expected to return one dict per image with
            ``boxes``, ``scores`` and ``labels``.
        data_loader: Iterable yielding ``(images, targets)`` batches where
            each target has ``boxes`` and ``labels``.
        device: Device the images are moved to for inference.

    Returns:
        The result of ``MeanAveragePrecision.compute()``.
    """
    model.eval()

    metric = MeanAveragePrecision(box_format="xyxy", iou_type="bbox", class_metrics=True)

    # The inference_mode decorator already disables gradient tracking, so no
    # extra no_grad context is needed.
    for images, targets in tqdm(data_loader, desc="Evaluating"):
        images = [img.to(device) for img in images]

        outputs = model(images)

        preds = [
            {
                "boxes": output["boxes"].cpu(),
                "scores": output["scores"].cpu(),
                "labels": output["labels"].cpu(),
            }
            for output in outputs
        ]

        # The metric is fed CPU tensors, so the targets are never copied to
        # the device; .cpu() is kept in case the loader yields device tensors.
        targets_formatted = [
            {
                "boxes": target["boxes"].cpu(),
                "labels": target["labels"].cpu(),
            }
            for target in targets
        ]

        metric.update(preds, targets_formatted)

    result = metric.compute()
    print(f"TorchMetrics result inside evaluate: {result}")
    return result

In [None]:
# Evaluation entry point: load the saved weights and score them on the test videos.
if __name__ == "__main__":
    # --- Configuration ---
    frame_directory = "./data/frame/"
    coco_annotation_directory = "./data/annotation/coco/"
    num_classes = 2
    model_path = "./output/faster_rcnn_model_5_epoch.pth"
    batch_size = 4
    num_workers = 0

    # --- Device ---
    # Fall back to the CPU when no GPU is present, as the training workflow does.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(f"Using device: {device}")

    # --- Datasets ---
    #  Only use the test dataset for evaluation
    test_dataset = CocoDataset(frame_directory, coco_annotation_directory, (8, 11), get_transform())

    # --- DataLoaders ---
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn)

    # --- Model ---
    model = get_model(num_classes)
    # map_location lets weights saved on a GPU load on whichever device is in use.
    model.load_state_dict(torch.load(model_path, map_location=device))  # Load the saved weights
    model.to(device)

    # --- Evaluation ---
    print("Evaluating performance...")
    print(evaluate(model, test_dataloader, device))

In [None]:
import matplotlib.pyplot as plt

def visualize_predictions(image, predictions, threshold=0.5):
    """
    Draw the predicted boxes on an image and display it with matplotlib.

    The boxes are drawn directly onto ``image``, so the caller's image object
    is modified.

    Args:
        image: PIL Image.
        predictions: Dictionary with 'boxes', 'labels', and 'scores' keys.
        threshold: Predictions scoring below this value are not drawn.
    """
    canvas = ImageDraw.Draw(image)
    detections = zip(predictions['boxes'], predictions['labels'], predictions['scores'])
    for box, label, score in detections:
        if not (score >= threshold):
            continue

        # Move to CPU and convert to NumPy before handing values to PIL.
        box = box.cpu().numpy()
        label = label.cpu().numpy()
        score = score.cpu().numpy()

        x1, y1, x2, y2 = box
        canvas.rectangle([x1, y1, x2, y2], outline="red", width=2)
        caption = f"Label: {label}, Score: {score:.2f}"
        canvas.text((x1, y1 - 10), caption, fill="red")

    plt.imshow(image)
    plt.axis('off')
    plt.show()

# Position of the sample to visualise within the dataset built below.
index = 1750

test_dataset = CocoDataset(frame_directory, coco_annotation_directory, (11, 11), get_transform())

image, target = test_dataset[index]  # Image and target at position `index`
image = image.to(device)

# torchvision detection models only return predictions in eval mode, so set
# it explicitly instead of relying on an earlier cell having done so.
model.eval()

# Run inference
with torch.no_grad():  # Disable gradient calculations during inference
    predictions = model([image])[0]  # Pass a list containing the image, get first element of result


# Visualize the predictions
original_image = Image.open(test_dataset.data[index]['frame_path']).convert("RGB")  # Load PIL Image
visualize_predictions(original_image, predictions, threshold=0.5)  # Visualize with threshold