In [None]:
!pip install loguru torchmetrics albumentations_experimental



In [None]:
from torchvision.datasets.widerface import WIDERFace
# Download and extract the WIDER FACE archives into data/widerface (the repr below confirms the split).
# NOTE(review): WiderFaceDataset reads data/widerface/WIDER_<split>/labelv2.txt, which is presumably not
# part of this torchvision download — confirm it is provided separately.
WIDERFace(split='train', root='data', download=True)

1465602149it [00:10, 145850180.47it/s]
362752168it [00:08, 41324991.34it/s]
1844140520it [00:14, 128914963.66it/s]


Downloading http://shuoyang1213.me/WIDERFACE/support/bbx_annotation/wider_face_split.zip to data/widerface/wider_face_split.zip


100%|██████████| 3591642/3591642 [00:01<00:00, 1909871.30it/s]


Extracting data/widerface/wider_face_split.zip to data/widerface


Dataset WIDERFace
    Number of datapoints: 12880
    Root location: data/widerface
    Split: train

In [None]:
import itertools
import cv2
import torch
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset


def xyxy2xywh(x):
    """Convert (n, 4) corner boxes [x1, y1, x2, y2] into [x, y, w, h] with (x, y) the top-left corner.

    Accepts a numpy array or a torch tensor and returns a new object of the same kind and dtype.
    Only the first four columns are written; any extra columns come back as zeros.
    """
    out = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
    out[:, :2] = x[:, :2]                 # top-left corner is unchanged
    out[:, 2:4] = x[:, 2:4] - x[:, :2]    # (x2 - x1, y2 - y1) = (width, height)
    return out


def xywh2xyxy(x):
    """Convert (n, 4) boxes [x, y, w, h] (top-left + size) into corner form [x1, y1, x2, y2].

    Accepts a numpy array or a torch tensor and returns a new object of the same kind and dtype.
    Only the first four columns are written; any extra columns come back as zeros.
    """
    out = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
    out[:, :2] = x[:, :2]                 # top-left corner is unchanged
    out[:, 2:4] = x[:, :2] + x[:, 2:4]    # bottom-right = top-left + (w, h)
    return out


def simple_collate_fn(batch):
    """Collate ``(image, bboxes, key_points, img_id, img_path, img_shape)`` samples into a dict.

    Column 0 of each sample's bboxes / key_points tensor is overwritten in place with the sample's
    position in the batch. Images are stacked; bboxes and key points are concatenated along dim 0.
    The meta fields are returned as tuples.
    """
    images, bb_list, kp_list, img_ids, img_paths, img_shapes = zip(*batch)

    for batch_pos, (boxes, points) in enumerate(zip(bb_list, kp_list)):
        boxes[:, 0] = batch_pos
        points[:, 0] = batch_pos

    return dict(
        image=torch.stack(images, 0),
        bb=torch.cat(bb_list, 0),
        kp=torch.cat(kp_list, 0),
        img_id=img_ids,
        img_path=img_paths,
        img_shape=img_shapes,
    )


class SimpleCustomBatch:
    """Batch object built from ``(image, bboxes, key_points, img_id, img_path, img_shape)`` samples.

    Column 0 of each sample's bboxes / key_points tensor is overwritten in place with the sample's
    batch position. Fields are exposed as attributes and also via ``batch['field']``.
    """

    def __init__(self, data):
        images, bb_list, kp_list, self.img_id, self.img_path, self.img_shape = zip(*data)

        for batch_pos, (boxes, points) in enumerate(zip(bb_list, kp_list)):
            boxes[:, 0] = batch_pos
            points[:, 0] = batch_pos

        self.image = torch.stack(images, 0)
        self.bb = torch.cat(bb_list, 0)
        self.kp = torch.cat(kp_list, 0)

    def pin_memory(self):
        """Pin the tensor fields (DataLoader calls this on custom batch types when pin_memory=True)."""
        for field in ('image', 'bb', 'kp'):
            setattr(self, field, getattr(self, field).pin_memory())
        return self

    def __getitem__(self, item):
        """Dict-style access to attributes, e.g. ``batch['image']``."""
        return getattr(self, item)


def collate_fn(batch):
    """DataLoader collate function: wrap the list of samples in a ``SimpleCustomBatch``."""
    collated = SimpleCustomBatch(batch)
    return collated


def is_in_image(point, shape):
    """Return True if ``point`` = (x, y) lies inside an image whose spatial size is ``shape`` = (height, width).

    ``shape`` uses the (h, w) order of ``tensor.shape[-2:]`` / ``ndarray.shape[:2]`` (what callers pass),
    so x is bounded by the width and y by the height.
    """
    x, y = point[0], point[1]
    h, w = shape[0], shape[1]
    return 0 <= x < w and 0 <= y < h


class WiderFaceDataset(Dataset):
    """WIDER FACE detection dataset read from ``<ds_path>/WIDER_<mode>/labelv2.txt``.

    Images are read from ``<ds_path>/WIDER_<mode>/images/<filename>``. ``__getitem__`` returns
    ``(image, bboxes, key_points, img_id, img_path, (h0, w0))`` where

    * ``bboxes`` is a float tensor (n, 7) = (img_id, cat_id, weight, x1, y1, x2, y2); weight is 0 for ignored boxes;
    * ``key_points`` is a float tensor (n, 1 + 3 * NK) = (img_id, x1, y1, w1, ..., x5, y5, w5); weight is 0 for
      ignored, missing, out-of-box or out-of-image points;
    * the ``img_id`` columns are left at 0 here and filled in by the collate function.
    """

    NK = 5  # key points per face
    BB_CLASS_LABELS = ('Face', )
    KP_CLASS_LABELS = ['left_eye', 'right_eye', 'nose', 'left_mouth', 'right_mouth']

    def __init__(self, ds_path, mode, min_size=None, transforms=None, color_layout='RGB'):
        """
        Args:
            ds_path: dataset root containing ``WIDER_<mode>/``.
            mode: split name used in ``WIDER_<mode>`` (e.g. 'train', 'val').
            min_size: if set, boxes narrower or shorter than this (pixels) are marked as ignored.
            transforms: optional albumentations pipeline whose label fields match the keyword
                arguments passed in ``__getitem__``.
            color_layout: 'RGB' (case-insensitive) converts OpenCV's BGR images to RGB; anything else keeps BGR.
        """
        super().__init__()
        self.min_size = min_size
        self.ds_path = ds_path
        self.mode = mode
        self.transforms = transforms
        self.color_layout = color_layout

        self.gt_path = str(Path(ds_path) / f'WIDER_{self.mode}' / 'labelv2.txt')

        self.bb_cat2id = {cat: idx for idx, cat in enumerate(self.BB_CLASS_LABELS)}
        self.bb_id2cat = {idx: cat for idx, cat in enumerate(self.BB_CLASS_LABELS)}

        self.kp_cat2id = {cat: idx for idx, cat in enumerate(self.KP_CLASS_LABELS)}
        self.kp_id2cat = {idx: cat for idx, cat in enumerate(self.KP_CLASS_LABELS)}

        self.images, self.annotations = self.load_annotations(self.gt_path)

    def __len__(self):
        return len(self.annotations)

    def _load_image(self, idx):
        """Read image ``idx`` from disk.

        Returns:
            (img, img_path, (h0, w0), idx): the HWC image in ``self.color_layout``, its path,
            its original (height, width) and the index echoed back as the image id.
        """
        info = self.images[idx]
        img_path = str(Path(self.ds_path) / f'WIDER_{self.mode}' / 'images' / info['filename'])
        img = cv2.imread(img_path)  # BGR; None if the file is missing or unreadable
        # check before converting: cvtColor on None fails with an opaque OpenCV error
        assert img is not None, f'Image Not Found {img_path}'

        if self.color_layout.lower() == 'rgb':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        h0, w0 = img.shape[:2]
        return img, img_path, (h0, w0), idx

    def __getitem__(self, index):
        """Load image ``index``, apply the transforms and pack the targets (see the class docstring)."""
        image, fp, shape, img_id = self._load_image(index)
        bb_data, kp_data = self.annotations[index]

        bb, bb_labels, bb_ignore, bb_ids = bb_data
        kp, kp_labels, kp_ignore, kp2bb_ids = kp_data

        if self.transforms:
            # apply albumentations transform
            transformed = self.transforms(
                image=image,
                bboxes=bb,
                bb_classes=bb_labels,
                bb_ignore=bb_ignore,
                bb_id=bb_ids,
                keypoints=kp,
                kp_classes=kp_labels,
                kp_ignore=kp_ignore,
                kp2bb_id=kp2bb_ids,
            )
            image = transformed['image']
            bb, kp = transformed['bboxes'], transformed['keypoints']
            bb_labels, kp_labels = transformed['bb_classes'], transformed['kp_classes']
            bb_ignore, kp_ignore = transformed['bb_ignore'], transformed['kp_ignore']
            bb_ids = transformed['bb_id']
            kp2bb_ids = transformed['kp2bb_id']

        # (h, w) of the image actually returned: CHW tensor after ToTensorV2, HWC numpy array otherwise
        img_hw = tuple(image.shape[-2:]) if isinstance(image, torch.Tensor) else image.shape[:2]

        bboxes = np.zeros((len(bb), 7), dtype=np.float32)  # (img_id, cat_id, weight, x1, y1, x2, y2)
        key_points = np.zeros((len(bb), 1 + 3 * self.NK), dtype=np.float32)  # (img_id, x1, y1, w1, ... x5, y5, w5)

        if len(bb):
            bboxes[:, 1] = np.array(bb_labels)
            bboxes[:, 2] = np.array(bb_ignore) == 0.0
            bboxes[:, 3:] = xywh2xyxy(np.array(bb))

            # box ids travel through the transform as labels; map them back to rows of `bboxes`
            id2row = {box_id: row for row, box_id in enumerate(bb_ids)}

            for point, label, ignore, box_id in zip(kp, kp_labels, kp_ignore, kp2bb_ids):
                row = id2row.get(box_id)
                if row is None:
                    continue  # the owning box was dropped by the transform
                # use the carried label, not the position: removed key points would shift positions
                start = 1 + 3 * int(label)
                weight = 1.0 if is_in_image(point, img_hw) and not ignore else 0.0
                key_points[row, start:start + 3] = [point[0], point[1], weight]

        bboxes = torch.tensor(bboxes, dtype=torch.float)
        key_points = torch.tensor(key_points, dtype=torch.float)

        return image, bboxes, key_points, img_id, fp, shape

    def _parse_ann_line(self, line):
        """Parse one object line: ``x1 y1 x2 y2`` optionally followed by either 15 key-point values
        (NK x ``(x, y, flag)``) or a single ignore flag.

        Returns:
            dict(bbox=(4,) float32 x1y1x2y2, kps=(NK, 3) float32 whose last column is rewritten to a 0/1
            weight, ignore=bool, cat='Face')
        """
        values = [float(x) for x in line.strip().split()]
        bbox = np.array(values[0:4], dtype=np.float32)
        kps = np.zeros((self.NK, 3), dtype=np.float32)
        ignore = False
        if self.min_size is not None:
            # size filtering only makes sense for annotated splits
            assert self.mode != 'test', 'min_size filtering is not supported for the test split'
            w = bbox[2] - bbox[0]
            h = bbox[3] - bbox[1]
            if w < self.min_size or h < self.min_size:
                ignore = True
        if len(values) > 4:
            if len(values) > 5:
                kps = np.array(values[4:19], dtype=np.float32).reshape((self.NK, 3))
                for li in range(kps.shape[0]):
                    if (kps[li, :] == -1).all():
                        kps[li][2] = 0.0    # missing key point: weight 0 (ignored)
                    else:
                        assert kps[li][2] >= 0
                        kps[li][2] = 1.0    # annotated key point: weight 1
            elif not ignore:
                ignore = (values[4] == 1)

        return dict(bbox=bbox, kps=kps, ignore=ignore, cat='Face')

    def load_annotations(self, ann_file):
        """Load annotations from a ``labelv2.txt`` file.

        The file is a sequence of image headers ``# <filename> <width> <height>``, each followed by one
        line per object (parsed by ``_parse_ann_line``). Blank lines are skipped.

        Args:
            ann_file (str): Path of annotation file.

        Returns:
            (images, annotations): ``images`` is a list of dicts with ``filename``/``width``/``height``;
            ``annotations`` holds one ``(bb_data, kp_data)`` pair per image, where
            ``bb_data = [bb (n, 4) xywh, labels (n,), ignore (n,), box_ids (n,)]`` and
            ``kp_data = [kp (n*NK, 2) xy, labels (n*NK,), ignore (n*NK,), box_ids (n*NK,)]``.
        """
        name = None
        bbox_map = {}
        with open(ann_file, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                if line.startswith('#'):
                    value = line[1:].strip().split()
                    name = value[0]
                    bbox_map[name] = dict(width=int(value[1]), height=int(value[2]), objs=[])
                    continue

                assert name is not None
                assert name in bbox_map
                bbox_map[name]['objs'].append(line)

        images = []
        out_ann = []
        for name, item in bbox_map.items():
            parsed = (self._parse_ann_line(obj_line) for obj_line in item['objs'])
            objects = [obj for obj in parsed if obj is not None]
            images.append(dict(filename=name, width=item['width'], height=item['height']))
            out_ann.append(self._build_image_ann(objects, (item['height'], item['width'])))

        return images, out_ann

    def _build_image_ann(self, objects, img_shape):
        """Pack one image's parsed objects into the flat ``(bb_data, kp_data)`` arrays (see ``load_annotations``)."""
        n_anns = len(objects)

        bb = np.zeros((n_anns, 4), dtype=np.float32)
        bb_labels = np.zeros((n_anns, 1), dtype=np.int32)
        bb_ignore = np.zeros((n_anns, 1), dtype=np.bool_)
        bb_ids = np.zeros((n_anns, 1), dtype=np.int32)

        kp = np.zeros((n_anns, self.NK, 2), dtype=np.float32)
        kp_labels = np.zeros((n_anns, self.NK), dtype=np.int32)
        kp_ignore = np.zeros((n_anns, self.NK), dtype=np.bool_)
        kp_bb_ids = np.zeros((n_anns, self.NK), dtype=np.int32)

        kp_label_ids = [self.kp_cat2id[cat] for cat in self.KP_CLASS_LABELS]
        for idx, obj in enumerate(objects):
            bb[idx, :] = obj['bbox']
            bb_labels[idx, :] = self.bb_cat2id[obj['cat']]
            bb_ignore[idx, :] = obj['ignore']
            bb_ids[idx, :] = idx

            kp[idx, :, :] = obj['kps'][:, :2]
            kp_labels[idx, :] = kp_label_ids
            kp_ignore[idx, :] = obj['kps'][:, 2] == 0
            kp_bb_ids[idx, :] = idx

        bb = xyxy2xywh(bb)
        kp[kp == -1] = 0  # zero out -1 placeholders; fully missing points already got weight 0 when parsed

        bb = self.validate_bb(bb, img_shape)
        kp, kp_out_of_box = self.validate_kp(kp, bb)
        # out-of-box points were clipped onto the box edge, so they are no longer reliable targets
        kp_ignore |= kp_out_of_box

        bb_data = [bb, bb_labels.flatten(), bb_ignore.flatten(), bb_ids.flatten()]
        kp_data = [kp.reshape(-1, 2), kp_labels.flatten(), kp_ignore.flatten(), kp_bb_ids.flatten()]
        return bb_data, kp_data

    def validate_bb(self, bb, img_shape):
        """Clip xywh boxes to an image of size ``img_shape`` = (h, w); returns xywh boxes."""
        bb = xywh2xyxy(np.array(bb))
        h, w = img_shape
        bb[:, 0::2] = bb[:, 0::2].clip(0, w - 1)
        bb[:, 1::2] = bb[:, 1::2].clip(0, h - 1)
        return xyxy2xywh(bb)

    def validate_kp(self, kp, bb):
        """Clip key points into their boxes and flag the ones that were out of bounds.

        A point is flagged when it lies more than 1 px left of / above its box, or beyond the box's
        right / bottom edge. All points are then clipped to ``[x1, x2 - 1] x [y1, y2 - 1]``.

        Args:
            kp: np.ndarray (n, NK, 2) xy key points; modified in place.
            bb: np.ndarray (n, 4) xywh boxes, one per key-point group.

        Returns:
            (kp, ignore): the clipped key points and a bool mask of shape (n, NK).
        """
        x1 = bb[:, 0:1]
        y1 = bb[:, 1:2]
        x2 = x1 + bb[:, 2:3]
        y2 = y1 + bb[:, 3:4]

        xs, ys = kp[:, :, 0], kp[:, :, 1]
        ignore = (xs < x1 - 1) | (xs > x2) | (ys < y1 - 1) | (ys > y2)

        kp[:, :, 0] = np.clip(xs, x1, x2 - 1)
        kp[:, :, 1] = np.clip(ys, y1, y2 - 1)
        return kp, ignore

    @property
    def get_ds_name(self):
        return str(Path(self.ds_path).name)

    @staticmethod
    def collate_fn(batch):
        return collate_fn(batch)

In [None]:
from src.upd_scrfd.scrfd_10g import SCRFD_10G
from src.upd_scrfd.scrfd import SCRFD_500M
from src.losses.reviewkd_loss import build_kd_trans

class KDSCRFD_500M(SCRFD_500M):
    """SCRFD-500M student for ReviewKD distillation.

    In training mode ``forward`` also returns the neck features passed through ``kd_trans`` so the
    training loop can compare them with the teacher's features.
    """

    def __init__(self, nc, strides=(8, 16, 32), export=True):
        super().__init__(nc, strides, export)
        # feature transform from build_kd_trans(None); only applied in training mode
        self.kd_trans = build_kd_trans(None)

    def forward(self, x):
        """Run backbone -> neck -> detection head.

        Returns:
            eval mode: ``(cls_scores, bboxes, key_points, anchors)``.
            training mode: ``(kd_trans(neck_features), (cls_scores, bboxes, key_points, anchors))``.
        """
        neck_feat = self.neck(self.backbone(x))
        anchors = self.anchor_generator(x, neck_feat, x.device)
        cls_scores, bboxes, key_points = self.bbox_head(neck_feat)
        head_out = (cls_scores, bboxes, key_points, anchors)

        if not self.training:
            return head_out
        return self.kd_trans(neck_feat), head_out

class KDSCRFD_10G(SCRFD_10G):
    """SCRFD-10G teacher whose ``forward`` mirrors ``KDSCRFD_500M``: in training mode it also returns
    neck features for the distillation loss; in eval mode it returns the plain detection outputs.
    """

    def forward(self, x):
        """Run backbone -> neck -> detection head.

        Returns:
            eval mode: ``(cls_scores, bboxes, key_points, anchors)``.
            training mode: ``(features, (cls_scores, bboxes, key_points, anchors))``.
        """
        device = x.device
        backbone_feats = self.backbone(x)
        neck_feat = self.neck(backbone_feats)
        anchors = self.anchor_generator(x, neck_feat, device)
        cls_scores, bboxes, key_points = self.bbox_head(neck_feat)

        if self.training:
            # Unlike KDSCRFD_500M, this class never builds `kd_trans`; apply one only if the base class
            # provides it instead of failing with AttributeError.
            # NOTE(review): ReviewKD normally compares the student's transformed features with the
            # teacher's *raw* features — confirm whether the teacher should apply any transform at all.
            kd_trans = getattr(self, 'kd_trans', None)
            if kd_trans is not None:
                neck_feat = kd_trans(neck_feat)
            return neck_feat, (cls_scores, bboxes, key_points, anchors)
        return cls_scores, bboxes, key_points, anchors

In [None]:
import torch
import torchvision
import albumentations as A
import albumentations_experimental as AE

from tqdm import tqdm
from loguru import logger
from albumentations.pytorch import ToTensorV2
from torchmetrics import MultioutputWrapper, MeanMetric

from src.evaluator import DSMetrics, Metrics
from src.losses.reviewkd_loss import hcl
from src.evaluator import WiderFaceEvaluator
from src.transforms import RandomSmartCrop
from src.losses.detection_loss import MultiOutputDetectionLoss
from src.losses.qfl import QualityFocalLoss
from src.losses.iou_loss import DIoULoss
from src.losses.smooth_l1 import SmoothL1Loss
from src.upd_scrfd.utils import distance2kps, distance2bbox, kps2distance

In [None]:

def setup_tqdm_loader(data_loader, mode, epoch, max_epochs, loader_ind=None):
    """Wrap ``data_loader`` in a tqdm bar described as ``<split padded to 10><epoch>/<max_epochs>``.

    ``mode`` must be 'train' or 'val' (anything else raises NotImplementedError). For 'val', a given
    ``loader_ind`` turns the split name into ``Val_<loader_ind>``.
    """
    if mode == 'train':
        split = 'Train'
    elif mode == 'val':
        split = 'Val' if loader_ind is None else f'Val_{loader_ind}'
    else:
        raise NotImplementedError

    desc = split.ljust(10) + f'{epoch}/{max_epochs}'
    return tqdm(
        data_loader,
        disable=False,
        bar_format='{l_bar}{bar}{r_bar}',
        total=len(data_loader),
        desc=desc,
    )

def get_anchor_centers(fm_h, fm_w, stride, na):
    """Pixel (x, y) positions of every anchor on an ``fm_h`` x ``fm_w`` feature map.

    Returns a float tensor of shape (fm_h * fm_w * na, 2). Rows run row-major over the map (y outer,
    x inner), with the ``na`` anchors of one location adjacent. This matches a head output permuted
    to (h, w, na). The previous meshgrid axes were swapped, which only coincided with this order on
    square maps.
    """
    ys, xs = torch.meshgrid(torch.arange(fm_h), torch.arange(fm_w), indexing='ij')
    centers = torch.stack([xs, ys], dim=-1).float().reshape(-1, 2) * stride
    return centers.repeat_interleave(na, dim=0)


def _flatten_head_map(t, channels):
    """Reshape a ``(b, na * channels, h, w)`` head output into ``(b * h * w * na, channels)`` rows.

    Rows are ordered (image, feature-map row, feature-map column, anchor). Also returns ``(b, na, h, w)``.
    """
    b, a_x_c, h, w = t.shape
    na = a_x_c // channels
    flat = t.view(b, na, channels, h, w).permute(0, 3, 4, 1, 2).reshape(-1, channels)
    return flat, (b, na, h, w)


def postprocess(raw_pred, iou_thresh, conf_thresh, strides=(8, 16, 32), num_classes=1):
    """Decode raw head outputs into per-image detections.

    Args:
        raw_pred: ``(cls_scores, bboxes, key_points, anchors)`` as returned by the model in eval mode.
            The first three are per-level lists of ``(b, na * C, h, w)`` maps with C = num_classes, 4 and 10.
            ``anchors`` is not used.
        iou_thresh: IoU threshold for NMS, applied separately within each image.
        conf_thresh: candidates whose best sigmoid score is not above this are dropped before NMS.
        strides: stride of each level, in the same order as the per-level lists.
        num_classes: classification channels per anchor.

    Returns:
        ``(scores, labels, bboxes, key_points)``: four lists with one entry per image, each a list of
        per-detection tensors in NMS output order. Boxes come from ``distance2bbox``; key points come
        from ``distance2kps``.
    """
    cls_scores, bboxes, key_points, _ = raw_pred

    bb_per_lvl, kp_per_lvl, scores_per_lvl, labels_per_lvl, batch_ind_per_lvl = [], [], [], [], []
    for lvl, stride in enumerate(strides):
        # regressions are scaled by the stride (out of place), then decoded around this level's anchor centres
        bbox, (b, na, h, w) = _flatten_head_map(bboxes[lvl], 4)
        device = bbox.device
        anchor_centers = get_anchor_centers(h, w, stride, na).repeat(b, 1).to(device)
        bb_per_lvl.append(distance2bbox(anchor_centers, bbox * stride))

        kps, _ = _flatten_head_map(key_points[lvl], 10)
        kp_per_lvl.append(distance2kps(anchor_centers, kps * stride))

        scores, (b, na, h, w) = _flatten_head_map(cls_scores[lvl], num_classes)
        scores, labels = scores.sigmoid().max(dim=1)
        scores_per_lvl.append(scores)
        labels_per_lvl.append(labels)

        # image index of every candidate: NMS is run per image and results are split by it
        batch_ind_per_lvl.append(torch.arange(b, device=device).repeat_interleave(na * h * w))

    all_bboxes = torch.cat(bb_per_lvl)
    all_kps = torch.cat(kp_per_lvl)
    all_scores = torch.cat(scores_per_lvl)
    all_labels = torch.cat(labels_per_lvl)
    all_idxs = torch.cat(batch_ind_per_lvl)

    keep = torch.where(all_scores > conf_thresh)[0]
    keep = keep[torchvision.ops.batched_nms(all_bboxes[keep], all_scores[keep], all_idxs[keep],
                                            iou_threshold=iou_thresh)]

    bboxes_kept, kps_kept = all_bboxes[keep], all_kps[keep]
    scores_kept, labels_kept, idxs_kept = all_scores[keep], all_labels[keep], all_idxs[keep]

    out_scores, out_labels, out_bboxes, out_key_points = [], [], [], []
    for img_idx in range(b):
        in_img = idxs_kept == img_idx
        out_scores.append(list(scores_kept[in_img]))
        out_labels.append(list(labels_kept[in_img]))
        out_bboxes.append(list(bboxes_kept[in_img]))
        out_key_points.append(list(kps_kept[in_img]))
    # cls_score, labels, bbox_pred, kps
    return out_scores, out_labels, out_bboxes, out_key_points


def get_transforms(mode, img_size=640):
    """Build the albumentations pipeline for ``mode``.

    'train': RandomSmartCrop, resize the longest side to ``img_size``, pad to an ``img_size`` square
    (top-left aligned, black), colour jitter, keypoint-symmetric horizontal flip, normalise, to tensor.
    Any other mode: resize + pad + normalise + to tensor only.

    Boxes use COCO (x, y, w, h) format and key points use 'xy'. The label-field names match the keyword
    arguments passed by ``WiderFaceDataset.__getitem__``. Invisible key points are kept for 'train'
    and removed otherwise.
    """
    resize_and_pad = [
        A.LongestMaxSize(max_size=img_size),
        A.PadIfNeeded(min_height=img_size, min_width=img_size, value=[0, 0, 0], border_mode=0, position='top_left'),
    ]
    to_model_input = [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.50196, 0.50196, 0.50196]),
        ToTensorV2(),
    ]

    if mode == 'train':
        steps = [
            RandomSmartCrop(p=1),
            *resize_and_pad,
            A.ColorJitter(hue=0.0705, saturation=[0.5, 1.5], contrast=[0.5, 1.5], brightness=0.1254, p=0.3),
            # swap left/right eye and mouth corners when flipping; the nose (2) maps to itself
            AE.HorizontalFlipSymmetricKeypoints(symmetric_keypoints=[[0, 1], [2, 2], [3, 4]], p=0.5),
            *to_model_input,
        ]
        remove_invisible = False
    else:
        steps = [*resize_and_pad, *to_model_input]
        remove_invisible = True

    return A.Compose(
        steps,
        bbox_params=A.BboxParams(format='coco', min_visibility=0.7,
                                 label_fields=['bb_classes', 'bb_ignore', 'bb_id']),
        keypoint_params=A.KeypointParams(format='xy', label_fields=['kp_classes', 'kp2bb_id', 'kp_ignore'],
                                         remove_invisible=remove_invisible),
    )

In [None]:
def train_epoch(model, dataloader, optimizer, lr_scheduler, criterion, epoch, max_epochs, device, debug=False):
    """Train ``model`` for one epoch with the detection loss ``criterion``.

    If a batch's loss items contain NaN, the batch is logged and skipped: no backward pass and no
    optimizer step. ``lr_scheduler`` (if given) is stepped once per iteration. With ``debug=True``
    the epoch stops after a few iterations.

    Returns:
        (model, optimizer, lr_scheduler, train_metrics): ``train_metrics`` wraps the epoch-mean loss items.
    """
    model.train()
    loss_names = criterion.loss_items_names
    loss_avg = MultioutputWrapper(base_metric=MeanMetric(), num_outputs=criterion.num_loss_items).to(device)

    pbar = setup_tqdm_loader(dataloader, mode='train', epoch=epoch, max_epochs=max_epochs)

    # Targets as built by WiderFaceDataset + collate:
    #   bb: [N, 7]   (img_id, cat_id, weight, x1, y1, x2, y2)
    #   kp: [N, 16]  (img_id, x1, y1, w1, ... x5, y5, w5)
    for step, batch in enumerate(pbar):
        img_shape = batch['image'].size()[-2:]
        img = batch['image'].to(device)
        targets = {'bb': batch['bb'].to(device), 'kp': batch['kp'].to(device)}

        raw_output = model(img)
        loss, loss_items = criterion(raw_output, targets, img_shape)

        if torch.isnan(loss_items).any():
            logger.warning('Nan Loss encountered')
        else:
            loss_avg.update(loss_items.unsqueeze(0))
            loss.backward()
            # Step only after a successful backward, so a skipped batch cannot move the weights
            # (e.g. through optimizer momentum or weight decay).
            optimizer.step()
        optimizer.zero_grad()

        # lr scheduler step (per iteration)
        if lr_scheduler:
            lr_scheduler.step()

        avg_losses = {name: f'{value:.4f}' for name, value in zip(loss_names, loss_avg.compute())}
        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0  # (GB)
        pbar.set_postfix({**avg_losses, 'gpu_mem': f'{mem:.2f}Gb', 'img_size': list(img_shape)})

        if debug and step > 5:
            break

    loss_items = dict(zip(loss_names, map(float, loss_avg.compute())))
    train_metrics = Metrics(metrics=[DSMetrics(losses=loss_items)], split_names=['Train'])
    return model, optimizer, lr_scheduler, train_metrics

In [None]:
def nd_loss(ND_loss, s_emb, t_emb, T_EMB, labels):
    """Call the ``ND_loss`` callable on the student/teacher embeddings and return ``{'loss_nd': value}``."""
    value = ND_loss(s_emb, t_emb, T_EMB, labels)
    return dict(loss_nd=value)


def distillation_train_epoch(s_model, t_model, dataloader, optimizer, lr_scheduler, criterion, epoch, max_epochs,
                             device, debug=False, kd_weight=2.0):
    """One epoch of ReviewKD distillation of student ``s_model`` from teacher ``t_model``.

    Both models are called in training mode, where they return ``(features, raw_detection_output)``.
    The optimised loss is ``criterion(student_output, targets) + kd_weight * hcl(s_features, t_features)``.
    The teacher runs under ``torch.no_grad()`` with its BatchNorm layers in eval mode, so it normalises
    with its pretrained running statistics and does not update them.

    If a batch's loss is NaN (detection items or total), the batch is logged and skipped: no backward
    pass and no optimizer step. ``lr_scheduler`` (if given) is stepped once per iteration.

    Returns:
        (s_model, optimizer, lr_scheduler, train_metrics): ``train_metrics`` wraps the epoch-mean detection
        loss items. The KD loss is shown in the progress bar only.
    """
    s_model.train()
    # The teacher must stay in training mode (that is when its forward returns features), but its
    # BatchNorm layers must not switch to batch statistics or update the pretrained running stats.
    t_model.train()
    for module in t_model.modules():
        if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d,
                               torch.nn.SyncBatchNorm)):
            module.eval()

    loss_names = criterion.loss_items_names
    loss_avg = MultioutputWrapper(base_metric=MeanMetric(), num_outputs=criterion.num_loss_items).to(device)
    kd_avg = MeanMetric().to(device)

    pbar = setup_tqdm_loader(dataloader, mode='train', epoch=epoch, max_epochs=max_epochs)

    # Targets as built by WiderFaceDataset + collate:
    #   bb: [N, 7]   (img_id, cat_id, weight, x1, y1, x2, y2)
    #   kp: [N, 16]  (img_id, x1, y1, w1, ... x5, y5, w5)
    for step, batch in enumerate(pbar):
        img_shape = batch['image'].size()[-2:]
        img = batch['image'].to(device)
        targets = {'bb': batch['bb'].to(device), 'kp': batch['kp'].to(device)}

        s_features, s_raw_output = s_model(img)
        with torch.no_grad():
            t_features, _ = t_model(img)

        det_loss, loss_items = criterion(s_raw_output, targets, img_shape)
        reviewkd_loss = hcl(s_features, t_features) * kd_weight
        loss = det_loss + reviewkd_loss

        # also check the total: a NaN in the KD term would poison the gradients on its own
        if torch.isnan(loss_items).any() or torch.isnan(loss).any():
            logger.warning('Nan Loss encountered')
        else:
            loss_avg.update(loss_items.unsqueeze(0))
            kd_avg.update(float(reviewkd_loss))
            loss.backward()
            optimizer.step()  # only after a successful backward
        optimizer.zero_grad()

        # lr scheduler step (per iteration)
        if lr_scheduler:
            lr_scheduler.step()

        avg_losses = {name: f'{value:.4f}' for name, value in zip(loss_names, loss_avg.compute())}
        avg_losses['ReviewKD'] = f'{float(kd_avg.compute()):.4f}'
        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0  # (GB)
        pbar.set_postfix({**avg_losses, 'gpu_mem': f'{mem:.2f}Gb', 'img_size': list(img_shape)})

        if debug and step > 5:
            break

    loss_items = dict(zip(loss_names, map(float, loss_avg.compute())))
    train_metrics = Metrics(metrics=[DSMetrics(losses=loss_items)], split_names=['Train'])
    return s_model, optimizer, lr_scheduler, train_metrics

In [None]:
import time
def eval(model, dataloader, postprocess, device='cpu', debug=False,
         gt_dir='data/widerface/WIDER_val/gt', iou_thresh=0.5):
    """Evaluate ``model`` on ``dataloader`` with ``WiderFaceEvaluator``.

    Args:
        model: detector; switched to eval mode and called under ``torch.no_grad()``.
        dataloader: yields batches with 'image', 'bb', 'kp', 'img_id', 'img_shape' and 'img_path'.
        postprocess: callable mapping the raw model output to the detections passed to the evaluator.
        device: device the images and targets are moved to.
        debug: stop after a few batches.
        gt_dir: ground-truth directory handed to ``WiderFaceEvaluator``.
        iou_thresh: IoU threshold handed to ``WiderFaceEvaluator``.

    Returns:
        Whatever ``evaluator.compute()`` returns.

    Note: the name shadows the builtin ``eval`` in this notebook; it is kept for existing callers.
    """
    model.eval()
    pbar = tqdm(dataloader, bar_format='{l_bar}{bar}{r_bar}', total=len(dataloader), desc='Evaluating')
    evaluator = WiderFaceEvaluator(gt_dir, iou_thresh=iou_thresh)

    with torch.no_grad():
        for i, batch in enumerate(pbar):
            img_shape = batch['image'].size()[-2:]
            img = batch['image'].to(device)
            targets = {'bb': batch['bb'].to(device), 'kp': batch['kp'].to(device)}

            output = postprocess(model(img))

            meta_data = {'img_id': batch['img_id'], 'img0_shape': batch['img_shape'], 'img1_shape': img_shape,
                         'img_path': batch['img_path']}
            evaluator.add_batch(output, targets, meta_data)

            if debug and i > 5:
                break

        metrics = evaluator.compute()
    return metrics

In [None]:
debug = False
lr = 0.003
max_epochs = 30
device = 'cuda:0'
ds_path = 'data/widerface'
ckpt_dir = Path('runs1')
ckpt_dir.mkdir(parents=True, exist_ok=True)  # torch.save does not create missing directories

# train split (reshuffled every epoch)
t_transforms = get_transforms('train')
t_dataset = WiderFaceDataset(ds_path, 'train', min_size=None, transforms=t_transforms, color_layout='RGB')
train_dataloader = torch.utils.data.DataLoader(t_dataset, batch_size=16, shuffle=True, num_workers=2,
                                               pin_memory=True, collate_fn=t_dataset.collate_fn)

# val split
v_transforms = get_transforms('val')
v_dataset = WiderFaceDataset(ds_path, 'val', min_size=None, transforms=v_transforms, color_layout='RGB')
val_dataloader = torch.utils.data.DataLoader(v_dataset, batch_size=16, num_workers=2, pin_memory=True,
                                             collate_fn=v_dataset.collate_fn)

postproc = lambda x: postprocess(x, conf_thresh=0.02, iou_thresh=0.45)

# init teacher model
t_model = KDSCRFD_10G(nc=1).to(device)
t_model.load_from_checkpoint('weights/upd_SCRFD_10G_KPS.pth')

# init student model
s_model = KDSCRFD_500M(nc=1).to(device)
s_model.load_from_checkpoint('weights/upd_SCRFD_500M_KPS.pth', strict=False)

# init optimizer
p_groups = s_model.get_param_groups(wd=1e-4, no_decay_bn_filter_bias=True)
optimizer = torch.optim.Adam(p_groups, lr=lr)

# lr_scheduler
lr_scheduler = None

# loss
cls_loss = QualityFocalLoss(use_sigmoid=True, beta=2.0, loss_weight=1.0)
bb_loss = DIoULoss(loss_weight=2.0)
kp_loss = SmoothL1Loss(beta=0.1111111111111111, loss_weight=0.1)
criterion = MultiOutputDetectionLoss(cls_loss=cls_loss, bb_loss=bb_loss, kp_loss=kp_loss)

# (score, epoch) of the best model so far, seeded with the reference easy/medium/hard APs
# (epoch -1 means "reference value, not produced by this run")
best_score = ((0.9071 + 0.8805 + 0.6768) / 3, -1)
for i in range(max_epochs):
    s_model, optimizer, lr_scheduler, train_metrics = distillation_train_epoch(
        s_model,
        t_model,
        train_dataloader,
        optimizer,
        lr_scheduler,
        criterion,
        epoch=i,
        max_epochs=max_epochs,
        device=device,
        debug=debug
    )

    metrics = eval(s_model, val_dataloader, postproc, debug=debug, device=device)
    score = metrics.get_fitness_score(labels=['bb_easy_AP', 'bb_medium_AP', 'bb_hard_AP'], weights=[1.0, 1.0, 1.0])
    print(f'Current score: {score}, best_score: {best_score[0]}')
    print(f'Metrics: {metrics}')

    # only record the epoch when it actually improves on the best score
    if score > best_score[0]:
        best_score = (score, i)
    ckpt = {
        'model': s_model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': i,
        'metrics': metrics,
    }
    torch.save(ckpt, ckpt_dir / f'reviewkd_ckpt_{i}_{score:.4f}.pt')


Train     0/30: 100%|██████████| 805/805 [09:18<00:00,  1.44it/s, cls_QFL=0.1288, bb_DIoU=0.2465, kp_SmoothL1=0.3094, Total=0.2282, gpu_mem=7.01Gb, img_size=[640, 640]]
Evaluating: 100%|██████████| 202/202 [13:53<00:00,  4.13s/it]
  _pr_curve[i, 0] = pr_curve[i, 1] / pr_curve[i, 0]


Current score: 0.3428982120005228, best_score: 0.8214666666666667
Metrics: Detection Metrics:
easy_AP: 0.4760, medium_AP: 0.3795, hard_AP: 0.1731


Train     1/30: 100%|██████████| 805/805 [09:11<00:00,  1.46it/s, cls_QFL=0.1288, bb_DIoU=0.2394, kp_SmoothL1=0.2934, Total=0.2206, gpu_mem=7.01Gb, img_size=[640, 640]]
Evaluating: 100%|██████████| 202/202 [08:39<00:00,  2.57s/it]


Current score: 0.370182983358564, best_score: 0.8214666666666667
Metrics: Detection Metrics:
easy_AP: 0.5268, medium_AP: 0.4032, hard_AP: 0.1805


Train     2/30: 100%|██████████| 805/805 [09:10<00:00,  1.46it/s, cls_QFL=0.1286, bb_DIoU=0.2349, kp_SmoothL1=0.2776, Total=0.2137, gpu_mem=7.01Gb, img_size=[640, 640]]
Evaluating: 100%|██████████| 202/202 [10:06<00:00,  3.00s/it]


Current score: 0.3242337038640175, best_score: 0.8214666666666667
Metrics: Detection Metrics:
easy_AP: 0.4469, medium_AP: 0.3627, hard_AP: 0.1632


Train     3/30: 100%|██████████| 805/805 [09:08<00:00,  1.47it/s, cls_QFL=0.1265, bb_DIoU=0.2405, kp_SmoothL1=0.2925, Total=0.2198, gpu_mem=7.01Gb, img_size=[640, 640]]
Evaluating: 100%|██████████| 202/202 [06:25<00:00,  1.91s/it]


Current score: 0.47637578063184577, best_score: 0.8214666666666667
Metrics: Detection Metrics:
easy_AP: 0.6336, medium_AP: 0.5385, hard_AP: 0.2571


Train     4/30:  80%|████████  | 648/805 [07:27<01:40,  1.57it/s, cls_QFL=0.1264, bb_DIoU=0.2364, kp_SmoothL1=0.2762, Total=0.2130, gpu_mem=7.01Gb, img_size=[640, 640]]

In [None]:
# Fitness of the most recently evaluated epoch (easy / medium / hard AP, weight 1.0 each).
score = metrics.get_fitness_score(weights=[1.0, 1.0, 1.0], labels=['bb_easy_AP', 'bb_medium_AP', 'bb_hard_AP'])

# Reference APs used to seed `best_score` in the training cell: easy_AP: 0.9071, medium_AP: 0.8805, hard_AP: 0.6768