In [1]:
import os
import typing as ty
from itertools import islice

import cv2
import pandas as pd
import numpy as np


def batched_custom(iterable, n):
    """
    Batches an iterable into chunks of size n.
    Equivalent to itertools.batched in Python 3.12+.
    """
    if n < 1:
        raise ValueError('n must be at least one')
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch


class AnnotationDict(ty.TypedDict):
    image: str
    boxes: list[int]
    key_points: list[list[int]]


class AnnotationContainer:
    def __init__(self, base_dir: str, anno_file_path: str, parse_kps: bool = False) -> None:
        self.parse_kps = parse_kps
        self.base_dir = base_dir
        with open(anno_file_path, "r") as tf:
            self.content = tf.read()
        self.__parse()

    def __parse(self) -> None:
        self.meta = []
        self.images = []
        self.labels = []

        samples_as_text = self.content.split("# ")[1:] # We skip 1 element since it's an empty string.
        samples_as_text = [sample.strip().split("\n") for sample in samples_as_text]
        for sample in samples_as_text:
            image_meta, *boxes_str = sample
            
            name, height, width = image_meta.split(" ")
            image_path = os.path.join(*name.split("/")) # This should be correct for Windows and Unix.
            self.images.append(image_path)
            self.meta.append((height, width))

            labels = {"boxes": [], "key_points": []}
            for i, point_set in enumerate(boxes_str):
                coords = list(map(float, point_set.strip().split(" ")))

                # Box is a first 4 coordinates in top_left_x, top_left_y, bottom_right_x, bottom_right_y format.
                box = list(map(int, coords[:4]))

                # Artificallly add class label to box.
                box = [0, *box]
                if any(coord < 0 for coord in box):
                    msg = f"Image {image_path} has box with negative coords: {box}"
                    raise ValueError(msg)
                labels["boxes"].append(box)

                # Key points are the rest points. It should be 5 in total, 3 components each (x, y, visibility flag).
                if self.parse_kps:
                    kps = []
                    for point in batched_custom(coords[4:], 3):
                        if all(coord == -1 for coord in point):
                            kps.append([0.0, 0.0, 0.0])
                        else:
                            point = list(point)
                            point[-1] = 1.0
                            kps.append(point)
                    if len(kps) != 5:
                        msg = f"Image {image_path} has more or less than 5 kps: {kps}"
                        raise ValueError(msg)
                    labels["key_points"].append(kps)
                    
            self.labels.append(labels)

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, index: int) -> AnnotationDict:
        image = self.images[index]
        label = self.labels[index]
        key_points = label["key_points"]
        boxes = label["boxes"]
        return {"image": image, "boxes": boxes, "key_points": key_points}

    def __iter__(self):
        for image_path, label in zip(self.images, self.labels):
            yield {"image": image_path, **label}
            

DATA_PATH = os.path.join(os.getcwd(), "data", "widerface")
dataset_df = None
for subset in ("train", "val"):
    anno_file_path = os.path.join(DATA_PATH, "labelv2", subset, "labelv2.txt")
    image_dir = os.path.join(DATA_PATH, f"WIDER_{subset}", f"WIDER_{subset}", "images")
    container = AnnotationContainer(image_dir, anno_file_path, subset == "train")
    dataframe = pd.DataFrame(data=container)
    dataframe["subset"] = subset
    if dataset_df is None:
        dataset_df = dataframe
    else:
        dataset_df = pd.concat((dataset_df, dataframe))

assert dataset_df is not None
train_df = dataset_df.query("subset == 'train'")
train_idx = train_df.sample(frac=0.8).index
val_idx = np.setdiff1d(train_df.index, train_idx)
train_df.loc[train_idx, "subset"] = "train"
train_df.loc[val_idx, "subset"] = "val"
val_df = dataset_df.query("subset == 'val'")
val_df["subset"] = "test"
dataset_df = pd.concat((train_df, val_df)).reset_index(drop=True)
dataset_df.to_csv("widerface_main_3.csv")


KeyboardInterrupt: 

In [None]:
import os

import torch
import pandas as pd
from tqdm import tqdm
from torchvision.ops import nms
from torch.utils.data import DataLoader
from torchmetrics.detection import MeanAveragePrecision

from train import read_config
from source.postprocessing import postprocess_predictions
from source.targets import generate_targets_batch
from source.dataset import build_dataloaders
from source.models.yunet import YuNet


def decode_keypoints(kps_preds: torch.Tensor, grids: torch.Tensor) -> torch.Tensor:
    num_points = kps_preds.shape[-1] // 2
    decoded = []
    for i in range(num_points):
        kp_encoded = kps_preds[:, [2 * i, 2 * i + 1]]
        kp_decoded = kp_encoded * grids[:, 2:] + grids[:, :2]
        decoded.append(kp_decoded)
    return torch.cat(decoded, dim=1)


@torch.no_grad()
def calculate_map_torchmetrics(
    model: YuNet,
    dataloader: DataLoader,
    device: torch.device,
    conf_thresh: float = 0.5,
    iou_thresh: float = 0.45,
    metric_names: tuple[str, ...] = ("map_50", "map_small", "map_medium", "map_large"),
) -> dict[str, float]:
    model.eval()
    map_calculator = MeanAveragePrecision(backend="faster_coco_eval").to(device)
    for batch in dataloader:
        images = batch["image"].to(device, non_blocking=True)
        gt_boxes = [box.to(device, non_blocking=True) for box in batch["boxes"]]
        gt_kps = [kp_set.to(device, non_blocking=True) for kp_set in batch["key_points"]]
        gt_labels = [bl.to(device, non_blocking=True) for bl in batch["box_labels"]]
        p8_out, p16_out, p32_out = model(images)
        obj_preds, cls_preds, box_preds, kp_preds, priors = postprocess_predictions(
            (p8_out, p16_out, p32_out), (8, 16, 32)
        )
        batch_size = images.shape[0]
        conf = (obj_preds.sigmoid() * cls_preds.squeeze(dim=-1).sigmoid()).sqrt()
        prep = []
        for batch_idx in range(batch_size):
            sample_conf = conf[batch_idx]
            sample_boxes = box_preds[batch_idx]
            sample_kps = kp_preds[batch_idx]
            sample_priors = priors[batch_idx]
            decoded_kps = decode_keypoints(sample_kps, sample_priors)

            sample_keep = sample_conf >= conf_thresh
            filt_conf = sample_conf[sample_keep]
            filt_boxes = sample_boxes[sample_keep]
            filt_kps = decoded_kps[sample_keep]

            keep_indices = nms(filt_boxes, filt_conf, iou_thresh)
            filt_conf = filt_conf[keep_indices]
            filt_boxes = filt_boxes[keep_indices]
            filt_kps = filt_kps[keep_indices]

            prep.append(
                {
                    "boxes": filt_boxes,
                    "scores": filt_conf.view(-1),
                    "labels": torch.zeros(len(filt_boxes), dtype=torch.long, device=device).view(-1),
                }
            )
        targ = [{"boxes": gt_boxes[i], "labels": gt_labels[i].view(-1).long()} for i in range(batch_size)]
        map_calculator.update(prep, targ)
    metrics = {name: tensor.detach().cpu().item() for name, tensor in map_calculator.compute().items()}
    return {name: value for name, value in metrics.items() if name in metric_names}


config = read_config("config.yml")
device = torch.device("cuda:0")
dataframe = pd.read_csv(config.path.csv)
dataloaders = build_dataloaders(config, dataframe)
model = YuNet(**config.model.model_dump()).to(device)
ckpt = torch.load(os.path.join("artifacts", "objective_blackburn", "checkpoints", "epoch_112_ckpt.pt"))
model.load_state_dict(ckpt["model"])
model.eval()
map_calculator = MeanAveragePrecision(backend="faster_coco_eval").to(device)
pbar = tqdm(dataloaders["val"], total=len(dataloaders["val"]))
with torch.no_grad():
    for batch in pbar:
        images = batch["image"].to(device)
        gt_boxes = [box.to(device) for box in batch["boxes"]]
        gt_kps = [kp_set.to(device) for kp_set in batch["key_points"]]
        gt_labels = [bl.to(device) for bl in batch["box_labels"]]
        p8_out, p16_out, p32_out = model(images)
        obj_preds, cls_preds, box_preds, kp_preds, priors = postprocess_predictions(
            (p8_out, p16_out, p32_out), (8, 16, 32)
        )

        batch_size = images.shape[0]
        conf = (obj_preds.sigmoid() * cls_preds.squeeze(dim=-1).sigmoid()).sqrt()
        prep = []
        for batch_idx in range(batch_size):
            sample_conf = conf[batch_idx]
            sample_boxes = box_preds[batch_idx]
            sample_kps = kp_preds[batch_idx]
            sample_priors = priors[batch_idx]
            decoded_kps = decode_keypoints(sample_kps, sample_priors)

            sample_keep = sample_conf >= 0.5
            filt_conf = sample_conf[sample_keep]
            filt_boxes = sample_boxes[sample_keep]
            filt_kps = decoded_kps[sample_keep]

            keep_indices = nms(filt_boxes, filt_conf, 0.45)
            filt_conf = filt_conf[keep_indices]
            filt_boxes = filt_boxes[keep_indices]
            filt_kps = filt_kps[keep_indices]

            prep.append(
                {
                    "boxes": filt_boxes,
                    "scores": filt_conf.view(-1),
                    "labels": torch.zeros(len(filt_boxes), dtype=torch.long, device=device).view(-1),
                }
            )
        targ = [{"boxes": gt_boxes[i], "labels": gt_labels[i].view(-1).long()} for i in range(batch_size)]
        map_calculator.update(prep, targ)
    _map = map_calculator.compute()
    print(_map)


100%|██████████| 161/161 [00:16<00:00,  9.48it/s]


{'map': tensor(0.1202), 'map_50': tensor(0.2936), 'map_75': tensor(0.0793), 'map_small': tensor(0.0807), 'map_medium': tensor(0.3921), 'map_large': tensor(0.4399), 'mar_1': tensor(0.0357), 'mar_10': tensor(0.1108), 'mar_100': tensor(0.1448), 'mar_small': tensor(0.0989), 'mar_medium': tensor(0.4388), 'mar_large': tensor(0.4650), 'map_per_class': tensor(-1.), 'mar_100_per_class': tensor(-1.), 'classes': tensor(0, dtype=torch.int32)}
