In [1]:
import os
import typing as ty
from itertools import islice

import cv2
import pandas as pd
import numpy as np


def batched_custom(iterable, n):
    """
    Batches an iterable into chunks of size n.
    Equivalent to itertools.batched in Python 3.12+.
    """
    if n < 1:
        raise ValueError('n must be at least one')
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch


class AnnotationDict(ty.TypedDict):
    image: str
    boxes: list[int]
    key_points: list[list[int]]


class AnnotationContainer:
    def __init__(self, base_dir: str, anno_file_path: str, parse_kps: bool = False) -> None:
        self.parse_kps = parse_kps
        self.base_dir = base_dir
        with open(anno_file_path, "r") as tf:
            self.content = tf.read()
        self.__parse()

    def __parse(self) -> None:
        self.meta = []
        self.images = []
        self.labels = []

        samples_as_text = self.content.split("# ")[1:] # We skip 1 element since it's an empty string.
        samples_as_text = [sample.strip().split("\n") for sample in samples_as_text]
        for sample in samples_as_text:
            image_meta, *boxes_str = sample
            
            name, height, width = image_meta.split(" ")
            image_path = os.path.join(*name.split("/")) # This should be correct for Windows and Unix.
            self.images.append(image_path)
            self.meta.append((height, width))

            labels = {"boxes": [], "key_points": []}
            for i, point_set in enumerate(boxes_str):
                coords = list(map(float, point_set.strip().split(" ")))

                # Box is a first 4 coordinates in top_left_x, top_left_y, bottom_right_x, bottom_right_y format.
                box = list(map(int, coords[:4]))
                if any(coord < 0 for coord in box):
                    msg = f"Image {image_path} has box with negative coords: {box}"
                    raise ValueError(msg)
                labels["boxes"].append(box)

                # Key points are the rest points. It should be 5 in total, 3 components each (x, y, ...). I don't know
                # what is 3rd component.
                if self.parse_kps:
                    kps = []
                    for point in batched_custom(coords[4:], 3):
                        kps.append(list(point[:2]))
                    if len(kps) != 5:
                        msg = f"Image {image_path} has more or less than 5 kps: {kps}"
                        raise ValueError(msg)
                    labels["key_points"].append(kps)
                    
            self.labels.append(labels)

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, index: int) -> AnnotationDict:
        image = self.images[index]
        label = self.labels[index]
        key_points = label["key_points"]
        return {"image": image, "boxes": boxes, "key_points": key_points}

    def __iter__(self):
        for image_path, label in zip(self.images, self.labels):
            yield {"image": image_path, **label}
            

DATA_PATH = os.path.join(os.getcwd(), "data", "widerface")
dataset_df = None
for subset in ("train", "val"):
    anno_file_path = os.path.join(DATA_PATH, "labelv2", subset, "labelv2.txt")
    image_dir = os.path.join(DATA_PATH, f"WIDER_{subset}", f"WIDER_{subset}", "images")
    container = AnnotationContainer(image_dir, anno_file_path, subset == "train")
    dataframe = pd.DataFrame(data=container)
    dataframe["subset"] = subset
    if dataset_df is None:
        dataset_df = dataframe
    else:
        dataset_df = pd.concat((dataset_df, dataframe))
dataset_df.head()

Unnamed: 0,image,boxes,key_points,subset
0,0--Parade\0_Parade_marchingband_1_849.jpg,"[[449, 330, 571, 479]]","[[[488.90601, 373.64301], [542.08899, 376.4419...",train
1,0--Parade\0_Parade_Parade_0_904.jpg,"[[361, 98, 624, 437]]","[[[424.14301, 251.65601], [547.13397, 232.571]...",train
2,0--Parade\0_Parade_marchingband_1_799.jpg,"[[78, 221, 85, 229], [78, 238, 92, 255], [113,...","[[[-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0], [-...",train
3,0--Parade\0_Parade_marchingband_1_117.jpg,"[[69, 359, 119, 395], [227, 382, 283, 425], [2...","[[[92.232, 391.397], [94.451, 377.45099], [103...",train
4,0--Parade\0_Parade_marchingband_1_778.jpg,"[[27, 226, 60, 262], [63, 95, 79, 114], [64, 6...","[[[-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0], [-...",train


In [2]:
dataset_df["key_points"]

0       [[[488.90601, 373.64301], [542.08899, 376.4419...
1       [[[424.14301, 251.65601], [547.13397, 232.571]...
2       [[[-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0], [-...
3       [[[92.232, 391.397], [94.451, 377.45099], [103...
4       [[[-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0], [-...
                              ...                        
3221                                                   []
3222                                                   []
3223                                                   []
3224                                                   []
3225                                                   []
Name: key_points, Length: 16106, dtype: object

In [3]:
dataset_df

Unnamed: 0,image,boxes,key_points,subset
0,0--Parade\0_Parade_marchingband_1_849.jpg,"[[449, 330, 571, 479]]","[[[488.90601, 373.64301], [542.08899, 376.4419...",train
1,0--Parade\0_Parade_Parade_0_904.jpg,"[[361, 98, 624, 437]]","[[[424.14301, 251.65601], [547.13397, 232.571]...",train
2,0--Parade\0_Parade_marchingband_1_799.jpg,"[[78, 221, 85, 229], [78, 238, 92, 255], [113,...","[[[-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0], [-...",train
3,0--Parade\0_Parade_marchingband_1_117.jpg,"[[69, 359, 119, 395], [227, 382, 283, 425], [2...","[[[92.232, 391.397], [94.451, 377.45099], [103...",train
4,0--Parade\0_Parade_marchingband_1_778.jpg,"[[27, 226, 60, 262], [63, 95, 79, 114], [64, 6...","[[[-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0], [-...",train
...,...,...,...,...
3221,9--Press_Conference\9_Press_Conference_Press_C...,"[[334, 182, 634, 582]]",[],val
3222,9--Press_Conference\9_Press_Conference_Press_C...,"[[316, 224, 586, 571]]",[],val
3223,9--Press_Conference\9_Press_Conference_Press_C...,"[[332, 172, 626, 544]]",[],val
3224,9--Press_Conference\9_Press_Conference_Press_C...,"[[336, 242, 488, 444], [712, 278, 838, 430]]",[],val


In [4]:
dataset_df

Unnamed: 0,image,boxes,key_points,subset
0,0--Parade\0_Parade_marchingband_1_849.jpg,"[[449, 330, 571, 479]]","[[[488.90601, 373.64301], [542.08899, 376.4419...",train
1,0--Parade\0_Parade_Parade_0_904.jpg,"[[361, 98, 624, 437]]","[[[424.14301, 251.65601], [547.13397, 232.571]...",train
2,0--Parade\0_Parade_marchingband_1_799.jpg,"[[78, 221, 85, 229], [78, 238, 92, 255], [113,...","[[[-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0], [-...",train
3,0--Parade\0_Parade_marchingband_1_117.jpg,"[[69, 359, 119, 395], [227, 382, 283, 425], [2...","[[[92.232, 391.397], [94.451, 377.45099], [103...",train
4,0--Parade\0_Parade_marchingband_1_778.jpg,"[[27, 226, 60, 262], [63, 95, 79, 114], [64, 6...","[[[-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0], [-...",train
...,...,...,...,...
3221,9--Press_Conference\9_Press_Conference_Press_C...,"[[334, 182, 634, 582]]",[],val
3222,9--Press_Conference\9_Press_Conference_Press_C...,"[[316, 224, 586, 571]]",[],val
3223,9--Press_Conference\9_Press_Conference_Press_C...,"[[332, 172, 626, 544]]",[],val
3224,9--Press_Conference\9_Press_Conference_Press_C...,"[[336, 242, 488, 444], [712, 278, 838, 430]]",[],val


In [5]:
# dataset_df[dataset_df["boxes"].apply(len) == 0]
dataset_df.to_csv("widerface_main_2.csv", index=False)

In [None]:
from collections import defaultdict

import torch

from source.postprocessing import postprocess_predictions
from source.targets import generate_targets_batch
from source.dataset import build_dataloaders
from source.general import get_cpu_state_dict
from source.losses import DetectionLoss
from source.models.yunet import YuNet

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


import torch
from torch import nn
from torchvision.ops import box_iou, nms

from source.postprocessing import postprocess_predictions


@torch.no_grad()
def validate(
    model: nn.Module,
    dataloader,
    device: torch.device,
    iou_thr: float = 0.45,
    score_thr: float = 0.0,  # для mAP лучше не резать по score, максимум очень низкий порог
) -> dict[str, float]:
    """
    Валидация:
      - mAP@iou_thr по боксам (один класс "face");
      - NME по 5 кейпоинтам.
    """
    model.eval()

    # Для mAP: собираем все детекции и GT по всему датасету
    all_scores = []      # list[float]
    all_tp_flags = []    # list[int] 1/0 для каждой детекции
    num_gt_total = 0

    # Для NME
    nme_sum = 0.0
    nme_count = 0
    debug_faces = 0

    for batch in dataloader:
        images = batch["image"].to(device)
        gt_boxes_batch = batch["boxes"]          # list[Tensor (Mi, 4)]
        gt_kps_batch = batch["key_points"]       # list[Tensor (Mi, 5, 2)]

        p8_out, p16_out, p32_out = model(images)
        obj_logits, cls_logits, boxes_xyxy, kps_logits, priors = postprocess_predictions(
            (p8_out, p16_out, p32_out),
            (8, 16, 32),
        )

        obj_probs = obj_logits.sigmoid()         # (B, N)
        cls_probs = cls_logits.sigmoid()[..., 0] # (B, N)
        scores = obj_probs * cls_probs           # (B, N)

        B = images.shape[0]
        for b in range(B):
            gt_boxes = gt_boxes_batch[b].to(device).float()      # (M, 4)
            gt_kps = gt_kps_batch[b].to(device).float()          # (M, 5, 2)
            num_gt = gt_boxes.shape[0]
            num_gt_total += num_gt

            scores_b = scores[b]         # (N,)
            boxes_b = boxes_xyxy[b]      # (N, 4)
            kps_b = kps_logits[b]        # (N, 10)
            priors_b = priors[b]         # (N, 4)

            # очень мягкий порог по score, чисто чтобы выкинуть совсем мусор
            keep = scores_b >= score_thr
            if keep.sum() == 0:
                continue

            scores_b = scores_b[keep]
            boxes_b = boxes_b[keep]
            kps_b = kps_b[keep]
            priors_b = priors_b[keep]

            # NMS
            keep_nms = nms(boxes_b, scores_b, iou_thr)
            scores_b = scores_b[keep_nms]
            boxes_b = boxes_b[keep_nms]
            kps_b = kps_b[keep_nms]
            priors_b = priors_b[keep_nms]

            num_dets = boxes_b.shape[0]
            if num_dets == 0:
                continue

            # ===== mAP: TP/FP для каждой детекции =====
            if num_gt == 0:
                # все детекции — FP
                all_scores.extend(scores_b.tolist())
                all_tp_flags.extend([0] * num_dets)
            else:
                ious = box_iou(boxes_b, gt_boxes)  # (Ndets, Ngt)
                # чтобы одна GT не матчилась много раз
                gt_matched = torch.zeros(num_gt, dtype=torch.bool, device=device)

                for d in range(num_dets):
                    iou_row = ious[d]
                    max_iou, gi = iou_row.max(dim=0)
                    gi = gi.item()
                    if max_iou >= iou_thr and not gt_matched[gi]:
                        all_scores.append(float(scores_b[d].item()))
                        all_tp_flags.append(1)
                        gt_matched[gi] = True
                    else:
                        all_scores.append(float(scores_b[d].item()))
                        all_tp_flags.append(0)

            # ===== NME по кейпоинтам =====
            if num_gt == 0 or num_dets == 0 or gt_kps.numel() == 0:
                continue

            # декодируем keypoints из нормализованных в пиксели
            num_points = kps_b.shape[1] // 2  # 5
            kps_abs_list = []
            for i in range(num_points):
                kp_xy = kps_b[:, [2 * i, 2 * i + 1]]          # (Ndets, 2)
                kp_xy_abs = kp_xy * priors_b[:, 2:] + priors_b[:, :2]
                kps_abs_list.append(kp_xy_abs)
            kps_abs = torch.stack(kps_abs_list, dim=1)        # (Ndets, 5, 2)

            # матчинг для NME: для каждого GT берём лучшую детекцию по IoU
            ious_kps = box_iou(boxes_b, gt_boxes)      # (Ndets, Ngt)
            max_det_iou, det_idx_for_gt = ious_kps.max(dim=0)  # (Ngt,)

            for gi in range(num_gt):
                if max_det_iou[gi] < iou_thr:
                    continue

                d = det_idx_for_gt[gi].item()
                gt_kps_face = gt_kps[gi]          # (5, 2)
                det_kps_face = kps_abs[d]         # (5, 2)

                # если в GT есть -1 (нет валидных kps) — пропускаем
                if (gt_kps_face == -1).any():
                    continue

                # нормирующий размер: средняя сторона GT-бокса
                w = gt_boxes[gi, 2] - gt_boxes[gi, 0]
                h = gt_boxes[gi, 3] - gt_boxes[gi, 1]
                face_size = (w + h) / 2.0
                if face_size <= 0:
                    continue

                per_point_err = (det_kps_face - gt_kps_face).norm(dim=-1)  # (5,)
                nme_face = per_point_err.mean() / face_size                # скаляр

                nme_sum += float(nme_face)
                nme_count += 1
                debug_faces += 1

    # ===== mAP расчёт =====
    if len(all_scores) == 0 or num_gt_total == 0:
        ap = 0.0
    else:
        scores_t = torch.tensor(all_scores)
        tps = torch.tensor(all_tp_flags).float()

        # сортировка по score по убыванию
        sorted_idx = torch.argsort(scores_t, descending=True)
        tps = tps[sorted_idx]

        fps = 1.0 - tps
        tp_cum = torch.cumsum(tps, dim=0)
        fp_cum = torch.cumsum(fps, dim=0)

        recalls = tp_cum / max(1, num_gt_total)              # R(k)
        precisions = tp_cum / torch.clamp(tp_cum + fp_cum, min=1e-8)  # P(k)

        # классический вычислитель AP по PR-кривой (интерполяция по 11 точкам можно не делать)
        # делаем интеграл по ломаной: sum (R_i - R_{i-1}) * P_i
        # сначала добавим (0,1) точку
        recalls = torch.cat([torch.tensor([0.0]), recalls])
        precisions = torch.cat([torch.tensor([1.0]), precisions])

        # сглаживаем precision как в VOC: P_interp(r) = max_{r' >= r} P(r')
        for i in range(precisions.shape[0] - 2, -1, -1):
            precisions[i] = torch.maximum(precisions[i], precisions[i + 1])

        # AP = сумма площадей под кривой
        ap = torch.sum((recalls[1:] - recalls[:-1]) * precisions[1:]).item()

    nme = nme_sum / max(1, nme_count)
    print(f"[VAL DEBUG] faces_used_for_NME={debug_faces}, nme_count={nme_count}")

    return {
        f"mAP@{iou_thr:.2f}": ap,
        "NME": nme,
    }


def lr_lambda(current_iter):
    if current_iter >= warmup_iters:
        return 1.0
    # linear warmup от warmup_ratio до 1.0
    alpha = current_iter / float(warmup_iters)
    return warmup_ratio * (1 - alpha) + alpha


def nan_hook(name):
    def hook(module, inp, out):
        if isinstance(out, (tuple, list)):
            outs = out
        else:
            outs = (out,)
        for o in outs:
            if torch.isnan(o).any() or torch.isinf(o).any():
                print(f"NaN in module {name}")
                raise RuntimeError(f"NaN detected after {name}")
    return hook


num_epochs = 80 * 8  # 640
milestones = [50 * 8, 68 * 8]  # [400, 544]
warmup_iters = 1500
warmup_ratio = 0.001

device = torch.device("cuda:0")
yunet = YuNet().to(device)
handles = []
for name, module in yunet.named_modules():
    if len(list(module.children())) == 0:  # только "листья"
        handles.append(module.register_forward_hook(nan_hook(name)))

dataloaders = build_dataloaders(DATA_PATH, dataset_df, torch.device("cpu"))
optimizer = torch.optim.SGD(yunet.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
base_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1,)
warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
criterion = DetectionLoss(obj_weight=1.0, cls_weight=1.0, box_weight=5.0, kps_weight=0.1)

global_iter = 0
train_dataloader = dataloaders["train"]
for epoch in range(num_epochs):
    running_losses: defaultdict[str, float] = defaultdict(float)
    yunet.train()
    for batch in train_dataloader:
        optimizer.zero_grad()

        images = batch["image"].to(device, non_blocking=True)
        boxes = [item.to(device, non_blocking=True) for item in batch["boxes"]]
        kps = [item.to(device, non_blocking=True) for item in batch["key_points"]]
        p8_out, p16_out, p32_out = yunet(images)
        obj_preds, cls_preds, box_preds, kps_preds, grids = postprocess_predictions((p8_out, p16_out, p32_out), (8, 16, 32))
        foreground_mask, target_cls, target_obj, target_boxes, target_kps = generate_targets_batch(obj_preds, cls_preds, box_preds, grids, boxes, kps, device)
        loss_dict: dict[str, torch.Tensor] = criterion(
            (obj_preds, cls_preds, box_preds, kps_preds),
            (target_obj, target_cls, target_boxes, target_kps),
            foreground_mask,
            grids,
        )
        loss = loss_dict["total_loss"]
        loss.backward()
        # # Клиппинг по норме (рекомендуемо для твоего случая)
        # max_norm = 100
        # torch.nn.utils.clip_grad_norm_(yunet.parameters(), max_norm)
        optimizer.step()

        if global_iter < warmup_iters:
            warmup_scheduler.step()
        global_iter += 1

        for loss_name, loss_tensor in loss_dict.items():
            loss_value = loss_tensor.detach().cpu().item()
            running_losses[f"train_{loss_name}"] += loss_value / len(train_dataloader)

    val_results = validate(yunet, dataloaders["val"], device, score_thr=0.02, iou_thr=0.45)

    base_scheduler.step()
    loss_str = ", ".join([f"{loss_name}={loss_value:.4f}" for loss_name, loss_value in running_losses.items()])
    loss_str += "|" + ", ".join([f"{name}={val:.4f}" for name, val in val_results.items()])
    print(f"[EPOCH {epoch + 1}/{num_epochs}] {loss_str}")
    ckpt = {
        "epoch": epoch,
        "model": get_cpu_state_dict(yunet.state_dict()),
        "optimizer": optimizer.state_dict(),
        "warmup_scheduler": warmup_scheduler.state_dict(),
        "base_scheduler": base_scheduler.state_dict(),
    }
    torch.save(ckpt, f"weights/epoch_{epoch}_state_dict.pt")

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[VAL DEBUG] faces_used_for_NME=0, nme_count=0
[EPOCH 1/640] train_total_loss=3.4931, train_obj_loss=0.2484, train_cls_loss=0.3117, train_box_loss=2.8238, train_kps_loss=0.1093|mAP@0.45=0.0008, NME=0.0000
[VAL DEBUG] faces_used_for_NME=0, nme_count=0
[EPOCH 2/640] train_total_loss=2.2554, train_obj_loss=0.0293, train_cls_loss=0.0325, train_box_loss=2.1067, train_kps_loss=0.0869|mAP@0.45=0.0154, NME=0.0000
[VAL DEBUG] faces_used_for_NME=0, nme_count=0
[EPOCH 3/640] train_total_loss=1.9599, train_obj_loss=0.0219, train_cls_loss=0.0032, train_box_loss=1.8539, train_kps_loss=0.0809|mAP@0.45=0.1027, NME=0.0000
[VAL DEBUG] faces_used_for_NME=0, nme_count=0
[EPOCH 4/640] train_total_loss=1.7950, train_obj_loss=0.0206, train_cls_loss=0.0011, train_box_loss=1.6985, train_kps_loss=0.0748|mAP@0.45=0.1468, NME=0.0000
[VAL DEBUG] faces_used_for_NME=0, nme_count=0
[EPOCH 5/640] train_total_loss=1.6904, train_obj_loss=0.0195, train_cls_loss=0.0007, train_box_loss=1.5989, train_kps_loss=0.0714|mAP@0.45