In [1]:
import os
import typing as ty
from itertools import islice

import cv2
import pandas as pd
import numpy as np


def batched_custom(iterable, n):
    """
    Lazily split ``iterable`` into consecutive tuples of at most ``n`` items.

    Backport-style stand-in for ``itertools.batched`` (Python 3.12+). The last
    tuple may be shorter than ``n``; an empty iterable yields nothing.

    Raises:
        ValueError: if ``n`` is smaller than 1. Because this is a generator, the
            error surfaces when iteration starts, not when the function is called.
    """
    if n < 1:
        raise ValueError('n must be at least one')
    source = iter(iterable)
    while True:
        chunk = tuple(islice(source, n))
        if not chunk:
            # The underlying iterator is exhausted.
            return
        yield chunk


class AnnotationDict(ty.TypedDict):
    """Shape of one sample as produced by ``AnnotationContainer``."""

    # Image path relative to the images directory (built with os.path.join).
    image: str
    # One entry per face: [top_left_x, top_left_y, bottom_right_x, bottom_right_y] as ints.
    boxes: list[list[int]]
    # One entry per face: five [x, y] float pairs. Empty list when key points were not parsed.
    key_points: list[list[list[float]]]


class AnnotationContainer:
    """
    Parses an annotation text file into per-image paths, metadata and labels.

    Expected file layout, as implemented by the parser below:
      * every sample starts with a ``# <name> <height> <width>`` header line
        (the three space-separated fields are unpacked in that order);
      * each following line describes one face as space-separated numbers: the
        first four are the box (top_left_x, top_left_y, bottom_right_x,
        bottom_right_y), the rest are key points in groups of three values.

    Attributes:
        parse_kps: whether key points are extracted in addition to boxes.
        base_dir: directory passed by the caller; stored only, parsed image
            paths are NOT prefixed with it.
        content: raw text of the annotation file.
        images: relative image paths, one per sample.
        meta: ``(height, width)`` tuples, kept as the raw strings from the header.
        labels: ``{"boxes": [...], "key_points": [...]}`` dicts, one per sample.

    The container supports ``len()``, integer indexing and iteration; both
    indexing and iteration produce ``{"image", "boxes", "key_points"}`` dicts.
    """

    def __init__(self, base_dir: str, anno_file_path: str, parse_kps: bool = False) -> None:
        """
        Read ``anno_file_path`` and parse it immediately.

        Raises:
            ValueError: if a box has a negative coordinate, or (with
                ``parse_kps=True``) a face does not have exactly 5 key points.
        """
        self.parse_kps = parse_kps
        self.base_dir = base_dir
        with open(anno_file_path, "r") as tf:
            self.content = tf.read()
        self.__parse()

    def __parse(self) -> None:
        """Fill ``self.images``, ``self.meta`` and ``self.labels`` from ``self.content``."""
        self.meta = []
        self.images = []
        self.labels = []

        # Everything before the first "# " marker is dropped; for a well-formed file it is an empty string.
        for sample_text in self.content.split("# ")[1:]:
            image_meta, *box_lines = sample_text.strip().split("\n")

            name, height, width = image_meta.split(" ")
            # Re-joining the "/"-separated name keeps the path valid on both Windows and Unix.
            image_path = os.path.join(*name.split("/"))
            self.images.append(image_path)
            self.meta.append((height, width))

            labels = {"boxes": [], "key_points": []}
            for point_set in box_lines:
                coords = list(map(float, point_set.strip().split(" ")))

                # The first 4 values are the box; int() truncates any fractional part.
                box = list(map(int, coords[:4]))
                if any(coord < 0 for coord in box):
                    msg = f"Image {image_path} has box with negative coords: {box}"
                    raise ValueError(msg)
                labels["boxes"].append(box)

                if self.parse_kps:
                    # Remaining values come in triples; only (x, y) is kept, the third component is dropped.
                    # Key point values are not validated (the data contains -1.0 placeholders).
                    kps = [list(point[:2]) for point in batched_custom(coords[4:], 3)]
                    if len(kps) != 5:
                        msg = f"Image {image_path} has more or less than 5 kps: {kps}"
                        raise ValueError(msg)
                    labels["key_points"].append(kps)

            self.labels.append(labels)

    def __len__(self) -> int:
        """Number of parsed samples."""
        return len(self.images)

    def __getitem__(self, index: int) -> "AnnotationDict":
        """Return the sample at ``index`` (same dict layout as iteration yields)."""
        label = self.labels[index]
        return {"image": self.images[index], "boxes": label["boxes"], "key_points": label["key_points"]}

    def __iter__(self):
        """Yield one ``{"image", "boxes", "key_points"}`` dict per sample, in file order."""
        for image_path, label in zip(self.images, self.labels):
            yield {"image": image_path, **label}
            

# Root of the WIDER FACE data, resolved relative to the notebook's working directory.
DATA_PATH = os.path.join(os.getcwd(), "data", "widerface")

# Build one frame per subset, then stack them in a single concat.
subset_frames = []
for subset in ("train", "val"):
    anno_file_path = os.path.join(DATA_PATH, "labelv2", subset, "labelv2.txt")
    image_dir = os.path.join(DATA_PATH, f"WIDER_{subset}", f"WIDER_{subset}", "images")
    # Key points are only parsed for the "train" subset.
    container = AnnotationContainer(image_dir, anno_file_path, subset == "train")
    dataframe = pd.DataFrame(data=container).assign(subset=subset)
    subset_frames.append(dataframe)

# The index is not reset, so it restarts at 0 for every subset.
dataset_df = pd.concat(subset_frames)
dataset_df.head()

Unnamed: 0,image,boxes,key_points,subset
0,0--Parade\0_Parade_marchingband_1_849.jpg,"[[449, 330, 571, 479]]","[[[488.90601, 373.64301], [542.08899, 376.4419...",train
1,0--Parade\0_Parade_Parade_0_904.jpg,"[[361, 98, 624, 437]]","[[[424.14301, 251.65601], [547.13397, 232.571]...",train
2,0--Parade\0_Parade_marchingband_1_799.jpg,"[[78, 221, 85, 229], [78, 238, 92, 255], [113,...","[[[-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0], [-...",train
3,0--Parade\0_Parade_marchingband_1_117.jpg,"[[69, 359, 119, 395], [227, 382, 283, 425], [2...","[[[92.232, 391.397], [94.451, 377.45099], [103...",train
4,0--Parade\0_Parade_marchingband_1_778.jpg,"[[27, 226, 60, 262], [63, 95, 79, 114], [64, 6...","[[[-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0], [-...",train


In [2]:
# Inspect the key-point column: "val" rows hold empty lists because key points are only parsed for "train".
dataset_df["key_points"]

0       [[[488.90601, 373.64301], [542.08899, 376.4419...
1       [[[424.14301, 251.65601], [547.13397, 232.571]...
2       [[[-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0], [-...
3       [[[92.232, 391.397], [94.451, 377.45099], [103...
4       [[[-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0], [-...
                              ...                        
3221                                                   []
3222                                                   []
3223                                                   []
3224                                                   []
3225                                                   []
Name: key_points, Length: 16106, dtype: object

In [3]:
# Show the combined train + val annotation frame.
dataset_df

Unnamed: 0,image,boxes,key_points,subset
0,0--Parade\0_Parade_marchingband_1_849.jpg,"[[449, 330, 571, 479]]","[[[488.90601, 373.64301], [542.08899, 376.4419...",train
1,0--Parade\0_Parade_Parade_0_904.jpg,"[[361, 98, 624, 437]]","[[[424.14301, 251.65601], [547.13397, 232.571]...",train
2,0--Parade\0_Parade_marchingband_1_799.jpg,"[[78, 221, 85, 229], [78, 238, 92, 255], [113,...","[[[-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0], [-...",train
3,0--Parade\0_Parade_marchingband_1_117.jpg,"[[69, 359, 119, 395], [227, 382, 283, 425], [2...","[[[92.232, 391.397], [94.451, 377.45099], [103...",train
4,0--Parade\0_Parade_marchingband_1_778.jpg,"[[27, 226, 60, 262], [63, 95, 79, 114], [64, 6...","[[[-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0], [-...",train
...,...,...,...,...
3221,9--Press_Conference\9_Press_Conference_Press_C...,"[[334, 182, 634, 582]]",[],val
3222,9--Press_Conference\9_Press_Conference_Press_C...,"[[316, 224, 586, 571]]",[],val
3223,9--Press_Conference\9_Press_Conference_Press_C...,"[[332, 172, 626, 544]]",[],val
3224,9--Press_Conference\9_Press_Conference_Press_C...,"[[336, 242, 488, 444], [712, 278, 838, 430]]",[],val


In [4]:
# NOTE(review): this cell repeats the display of the preceding cell; one of the two can be removed.
dataset_df

Unnamed: 0,image,boxes,key_points,subset
0,0--Parade\0_Parade_marchingband_1_849.jpg,"[[449, 330, 571, 479]]","[[[488.90601, 373.64301], [542.08899, 376.4419...",train
1,0--Parade\0_Parade_Parade_0_904.jpg,"[[361, 98, 624, 437]]","[[[424.14301, 251.65601], [547.13397, 232.571]...",train
2,0--Parade\0_Parade_marchingband_1_799.jpg,"[[78, 221, 85, 229], [78, 238, 92, 255], [113,...","[[[-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0], [-...",train
3,0--Parade\0_Parade_marchingband_1_117.jpg,"[[69, 359, 119, 395], [227, 382, 283, 425], [2...","[[[92.232, 391.397], [94.451, 377.45099], [103...",train
4,0--Parade\0_Parade_marchingband_1_778.jpg,"[[27, 226, 60, 262], [63, 95, 79, 114], [64, 6...","[[[-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0], [-...",train
...,...,...,...,...
3221,9--Press_Conference\9_Press_Conference_Press_C...,"[[334, 182, 634, 582]]",[],val
3222,9--Press_Conference\9_Press_Conference_Press_C...,"[[316, 224, 586, 571]]",[],val
3223,9--Press_Conference\9_Press_Conference_Press_C...,"[[332, 172, 626, 544]]",[],val
3224,9--Press_Conference\9_Press_Conference_Press_C...,"[[336, 242, 488, 444], [712, 278, 838, 430]]",[],val


In [5]:
# Optional sanity check (disabled): list the samples that have no boxes at all.
# dataset_df[dataset_df["boxes"].apply(len) == 0]
# Persist the combined annotations without the index column. The list-valued columns are written as
# their string representation, so they have to be parsed back when the CSV is loaded.
dataset_df.to_csv("widerface_main_2.csv", index=False)

In [None]:
from collections import defaultdict

import torch

from source.postprocessing import postprocess_predictions
from source.targets import generate_targets_batch
from source.dataset import build_dataloaders
from source.general import get_cpu_state_dict
from source.losses import DetectionLoss
from source.models.yunet import YuNet


def lr_lambda(current_iter):
    """
    Learning-rate multiplier for the linear warmup phase.

    Ramps linearly from ``warmup_ratio`` at iteration 0 up to 1.0 at
    ``warmup_iters`` and stays at 1.0 afterwards. Both values are read from
    module-level variables that are assigned further down in this cell.
    """
    if current_iter < warmup_iters:
        # Linear warmup from warmup_ratio to 1.0.
        progress = current_iter / float(warmup_iters)
        return warmup_ratio * (1 - progress) + progress
    return 1.0


# --- Training configuration ---
num_epochs = 80 * 8  # 640
milestones = [50 * 8, 68 * 8]  # [400, 544]; epochs at which MultiStepLR scales the LR by gamma
warmup_iters = 1500  # number of optimizer steps covered by the linear warmup (see lr_lambda)
warmup_ratio = 0.01  # LR multiplier at the very first iteration of the warmup
checkpoint_dir = "weights"  # one checkpoint per epoch is written here

device = torch.device("cuda:0")
yunet = YuNet().to(device)
# The dataloaders are built with a CPU device; batches are moved to `device` inside the loop below.
dataloaders = build_dataloaders(DATA_PATH, dataset_df, torch.device("cpu"))
optimizer = torch.optim.SGD(yunet.parameters(), lr=1e-3, momentum=0.9, weight_decay=0.0005)
# Two schedulers share the optimizer: `base_scheduler` is stepped once per epoch,
# `warmup_scheduler` once per iteration during the first `warmup_iters` iterations only.
base_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1,)
warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
criterion = DetectionLoss(obj_weight=1.0, cls_weight=1.0, box_weight=1.0, kps_weight=1.0)

# torch.save does not create missing directories. Create the target up front so a fresh
# checkout does not fail only after the first full epoch has been trained.
os.makedirs(checkpoint_dir, exist_ok=True)

train_dataloader = dataloaders["train"]
global_iter = 0
for epoch in range(num_epochs):
    # Per-epoch mean of every entry of the loss dict, keyed as "train_<name>".
    running_losses: defaultdict[str, float] = defaultdict(float)
    yunet.train()
    for batch in train_dataloader:
        optimizer.zero_grad()

        images = batch["image"].to(device, non_blocking=True)
        # Boxes and key points arrive as per-image collections, so each item is moved separately.
        boxes = [item.to(device, non_blocking=True) for item in batch["boxes"]]
        kps = [item.to(device, non_blocking=True) for item in batch["key_points"]]
        p8_out, p16_out, p32_out = yunet(images)
        # NOTE(review): (8, 16, 32) presumably are the strides of the three output levels -- confirm
        # against source.postprocessing.
        obj_preds, cls_preds, box_preds, kps_preds, grids = postprocess_predictions((p8_out, p16_out, p32_out), (8, 16, 32))
        foreground_mask, target_cls, target_obj, target_boxes, target_kps = generate_targets_batch(obj_preds, cls_preds, box_preds, grids, boxes, kps, device)
        loss_dict: dict[str, torch.Tensor] = criterion(
                (obj_preds, cls_preds, box_preds, kps_preds),
                (target_obj, target_cls, target_boxes, target_kps),
                foreground_mask,
            )
        loss = loss_dict["total_loss"]
        loss.backward()
        optimizer.step()

        # Warmup is iteration-based and stops being stepped once `warmup_iters` is reached.
        if global_iter < warmup_iters:
            warmup_scheduler.step()
        global_iter += 1

        # Dividing each contribution by the number of batches accumulates the epoch mean directly.
        for loss_name, loss_tensor in loss_dict.items():
            loss_value = loss_tensor.detach().cpu().item()
            running_losses[f"train_{loss_name}"] += loss_value / len(train_dataloader)

    base_scheduler.step()
    loss_str = ", ".join([f"{loss_name}={loss_value:.4f}" for loss_name, loss_value in running_losses.items()])
    print(f"[EPOCH {epoch + 1}/{num_epochs}] {loss_str}, LR: {optimizer.param_groups[0]['lr']}")
    # Checkpoint after every epoch; `epoch` is stored zero-based, matching the file name.
    ckpt = {
        "epoch": epoch,
        "model": get_cpu_state_dict(yunet.state_dict()),
        "optimizer": get_cpu_state_dict(optimizer.state_dict()),
        "warmup_scheduler": get_cpu_state_dict(warmup_scheduler.state_dict()),
        "base_scheduler": get_cpu_state_dict(base_scheduler.state_dict()),
    }
    torch.save(ckpt, os.path.join(checkpoint_dir, f"epoch_{epoch}_state_dict.pt"))

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
