### Dataset

In [1]:
import os
import torch

from torchvision.io import read_image
from torchvision.ops.boxes import masks_to_boxes
from torchvision import tv_tensors
import torchvision.transforms.v2 as v2

class PennFudanDataset(torch.utils.data.Dataset):
    """Penn-Fudan pedestrian dataset yielding ``{'image': ..., 'target': ...}`` samples.

    Images are read from ``<root>/PNGImages`` and instance masks from
    ``<root>/PedMasks``. Both listings are sorted so that the i-th image is
    paired with the i-th mask.

    Args:
        root: dataset directory containing ``PNGImages`` and ``PedMasks``.
        transforms: optional callable ``(img, target) -> (img, target)``;
            skipped when ``None``.
    """

    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # Sort both listings so that images and masks stay aligned by index.
        self.imgs = sorted(os.listdir(os.path.join(root, "PNGImages")))
        self.masks = sorted(os.listdir(os.path.join(root, "PedMasks")))

    def __getitem__(self, idx):
        """Load sample ``idx`` and build its detection target.

        Returns:
            dict with
              * ``'image'``: the image wrapped as ``tv_tensors.Image``;
              * ``'target'``: dict with ``boxes`` (XYXY ``BoundingBoxes``),
                ``masks`` (one binary uint8 mask per instance), ``labels``
                (all zeros, int64), ``image_id`` (int64 tensor holding ``idx``),
                ``orig_size`` and ``size`` (both ``[h, w]`` of the loaded image).
            Both entries are passed through ``self.transforms`` when it is set.
        """
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        img = read_image(img_path)
        mask = read_image(mask_path)

        # Each instance is encoded as a distinct value in the mask. torch.unique
        # returns sorted values, and the smallest one is treated as background.
        obj_ids = torch.unique(mask)
        obj_ids = obj_ids[1:]
        num_objs = len(obj_ids)

        # Split the id-encoded mask into one binary mask per instance.
        masks = (mask == obj_ids[:, None, None]).to(dtype=torch.uint8)

        # Tight bounding box around each binary mask.
        boxes = masks_to_boxes(masks)

        # Single class, encoded as label 0.
        # TODO: this differs from the FCOS label convention and should be unified.
        labels = torch.zeros((num_objs,), dtype=torch.int64)

        # Wrap the sample into torchvision tv_tensors so that the v2 transforms
        # update image, boxes and masks consistently.
        img = tv_tensors.Image(img)
        h, w = v2.functional.get_size(img)

        target = {}
        target["boxes"] = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(h, w))
        target["masks"] = tv_tensors.Mask(masks)
        target["labels"] = labels
        # Integer id: the legacy torch.Tensor([...]) constructor would silently
        # produce a float32 tensor here.
        target["image_id"] = torch.tensor([idx], dtype=torch.int64)
        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return {'image': img, 'target': target}

    def __len__(self):
        """Number of images found in ``<root>/PNGImages``."""
        return len(self.imgs)

### Model

In [2]:
import timm
import torch.nn as nn
import torchvision
from timm.layers import resample_abs_pos_embed     

    
class YOLOS(nn.Module):
    """ViT-based detector that prepends learnable detection tokens to the sequence.

    A timm vision transformer is used as backbone. ``det_token_num`` extra
    tokens (with their own position embeddings) are placed in front of the class
    token; after the transformer blocks each detection token is decoded by two
    MLP heads into class logits and a box.
    """

    def __init__(
        self,
        num_classes,
        det_token_num,
        backbone,
    ):
        super().__init__()
        # dynamic_img_size=True lets the ViT handle inputs whose size differs
        # from the pretraining resolution.
        self.backbone = timm.create_model(backbone, pretrained=True, dynamic_img_size=True)
        self.num_classes = num_classes
        self.embed_dim = self.backbone.embed_dim
        self.det_token_num = det_token_num
        self.add_det_tokens_to_backbone()

        hidden_dims = [self.embed_dim, self.embed_dim]
        # num_classes + 1 outputs: the extra one is the "no object" category.
        self.class_embed = torchvision.ops.MLP(self.embed_dim, hidden_dims + [self.num_classes + 1])
        self.bbox_embed = torchvision.ops.MLP(self.embed_dim, hidden_dims + [4])

    def add_det_tokens_to_backbone(self):
        """Create the detection tokens and their position embeddings.

        Both are ``(1, det_token_num, embed_dim)`` parameters initialised with a
        truncated normal (std 0.02). The backbone's ``num_prefix_tokens`` is
        increased so that it knows how many leading tokens are not patches.
        """
        token_shape = (1, self.det_token_num, self.backbone.embed_dim)

        self.det_token = nn.Parameter(torch.zeros(token_shape))
        nn.init.trunc_normal_(self.det_token, std=.02)

        self.det_pos_embed = nn.Parameter(torch.zeros(token_shape))
        nn.init.trunc_normal_(self.det_pos_embed, std=.02)

        self.backbone.num_prefix_tokens += self.det_token_num

    def forward(self, x):
        """Return ``(class_logits, boxes)`` for the detection tokens.

        This code relies on class_token=True in the ViT: index 0 of the
        backbone position embedding is taken to be the class-token slot.
        Boxes are squashed to [0, 1] with a sigmoid.
        """
        vit = self.backbone
        x = vit.patch_embed(x)

        # Position embedding layout: [detection tokens | class token | patches].
        cls_pos = vit.pos_embed[:, :1, :]    # 1 x 1 x embed_dim
        patch_pos = vit.pos_embed[:, 1:, :]  # 1 x num_patch x embed_dim
        pos_embed = torch.cat((self.det_pos_embed, cls_pos, patch_pos), dim=1)
        if vit.dynamic_img_size:
            # Patch embedding keeps the spatial grid here; resample the patch
            # position embeddings to that grid, then flatten the grid.
            batch, grid_h, grid_w, channels = x.shape
            pos_embed = resample_abs_pos_embed(
                pos_embed,
                (grid_h, grid_w),
                num_prefix_tokens=vit.num_prefix_tokens,
            )
            x = x.view(batch, -1, channels)

        # Token layout must mirror the position embedding layout above.
        batch_size = x.shape[0]
        det_tokens = self.det_token.expand(batch_size, -1, -1)
        cls_tokens = vit.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([det_tokens, cls_tokens, x], dim=1)

        # Remaining ViT forward pass.
        x = x + pos_embed
        for stage in (vit.pos_drop, vit.patch_drop, vit.norm_pre, vit.blocks, vit.norm):
            x = stage(x)

        # Only the processed detection tokens feed the heads.
        det_out = x[:, :self.det_token_num, ...]
        return self.class_embed(det_out), self.bbox_embed(det_out).sigmoid()

  from .autonotebook import tqdm as notebook_tqdm


### Loss

In [33]:
from torchvision.ops import box_convert, generalized_box_iou

class SetCriterion(nn.Module):
    """Set-prediction loss: classification + L1 box + GIoU box terms.

    The criterion does not compute the assignment itself; it expects the
    prediction/target ``matches`` (e.g. from a Hungarian matcher) as input.
    """

    def __init__(self, num_classes, eos_coef):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            eos_coef: relative classification weight applied to the no-object category
        """
        super().__init__()
        self.num_classes = num_classes
        self.eos_coef = eos_coef
        # Weights used to build 'combined_loss' in forward().
        self.weight_dict = {
            'loss_ce': 1,
            'loss_bbox': 5,
            'loss_giou': 2
        }
        # Per-class cross-entropy weights; the last index is the no-object class.
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer('empty_weight', empty_weight)

    @staticmethod
    def _matched_prediction_indices(matches):
        """Return ``(batch_idxs, pred_idxs)`` addressing every matched prediction.

        ``batch_idxs[n]`` is the position in the batch of the image owning the
        n-th matched pair and ``pred_idxs[n]`` the index of the matched
        prediction inside that image.
        """
        batch_idxs = torch.cat([torch.full_like(pred, i) for i, (pred, _) in enumerate(matches)])
        pred_idxs = torch.cat([pred for (pred, _) in matches])
        return batch_idxs, pred_idxs

    def loss_labels(self, pred_logits, tgt_cls, matches):
        """Weighted cross-entropy over all prediction slots.

        Params:
            pred_logits: tensor indexed as [batch, prediction, class]
            tgt_cls: one label tensor per image
            matches: list of batch_size pairs (I, J) of arrays such that output bbox I[n] must be matched with target bbox J[n]
        Returns:
            ``{'loss_ce': scalar tensor}``
        """
        # Target labels reordered to follow the matches, concatenated over the batch.
        reordered_labels = torch.cat([t[J] for t, (_, J) in zip(tgt_cls, matches)])
        batch_idxs, pred_idxs = self._matched_prediction_indices(matches)

        # Every slot defaults to num_classes (= no object); matched slots get
        # the label of their target. Unmatched slots are kept in the loss on
        # purpose: otherwise nothing would push their logits towards no-object.
        target_classes = torch.full(
            pred_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=pred_logits.device
        )
        target_classes[(batch_idxs, pred_idxs)] = reordered_labels
        loss_ce = nn.functional.cross_entropy(
            pred_logits.transpose(1, 2),  # cross_entropy wants the class dim second
            target_classes,
            weight=self.empty_weight
        )
        return {'loss_ce': loss_ce}

    def loss_boxes(self, pred_boxes, tgt_boxes, matches, num_boxes):
        """Compute the L1 regression loss and the GIoU loss of the matched boxes.

        Params:
            pred_boxes: tensor indexed as [batch, prediction, 4]
            tgt_boxes: one [nb_target_boxes, 4] tensor per image
            matches: same structure as in ``loss_labels``
            num_boxes: normalisation constant; values below 1 are treated as 1
        Both box sets are converted from cxcywh to xyxy for the GIoU term.
        Returns:
            ``{'loss_bbox': scalar tensor, 'loss_giou': scalar tensor}``
        """
        reordered_target_boxes = torch.cat([t[j] for t, (_, j) in zip(tgt_boxes, matches)])  # Nx4
        batch_idxs, pred_idxs = self._matched_prediction_indices(matches)
        reordered_pred_boxes = pred_boxes[(batch_idxs, pred_idxs)]  # Nx4

        if reordered_pred_boxes.shape[0] == 0:
            # Nothing matched (batch without targets): both terms are zero.
            # Summing the empty selection keeps the result on the right
            # device/dtype and attached to the autograd graph.
            zero = reordered_pred_boxes.sum()
            return {'loss_bbox': zero, 'loss_giou': zero.clone()}

        # Guard the normalisation so that a zero count can never produce NaN.
        num_boxes = max(float(num_boxes), 1.0)

        losses = {}
        loss_bbox = nn.functional.l1_loss(
            reordered_pred_boxes,
            reordered_target_boxes,
            reduction='none'
        )
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes
        # generalized_box_iou is pairwise; the diagonal holds the matched pairs.
        loss_giou = 1 - torch.diag(generalized_box_iou(
            box_convert(reordered_pred_boxes, 'cxcywh', 'xyxy'),
            box_convert(reordered_target_boxes, 'cxcywh', 'xyxy')))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses

    def forward(self, pred_cls, pred_boxes, tgt_cls, tgt_boxes, matches):
        """ This performs the loss computation.
        Parameters:
            pred_cls: class logits, indexed as [batch, prediction, class]
            pred_boxes: predicted boxes, indexed as [batch, prediction, 4]
            tgt_cls: list (len = batch size) of target label tensors
            tgt_boxes: list (len = batch size) of [nb_target_boxes, 4] tensors
            matches: list of batch_size pairs (I, J) of arrays such that for pair (I,J) output bbox I[n] must be matched with target bbox J[n]
        Returns:
            dict with 'loss_ce', 'loss_bbox', 'loss_giou' and 'combined_loss'
            (the sum of the three weighted by ``self.weight_dict``).
        """
        # Total number of target boxes in the batch, used to normalise the box
        # losses; clamped to 1 so that a batch without targets gives 0, not NaN.
        num_boxes = max(sum(len(boxes) for boxes in tgt_boxes), 1)

        losses = {}
        losses.update(self.loss_labels(pred_cls, tgt_cls, matches))
        losses.update(self.loss_boxes(pred_boxes, tgt_boxes, matches, float(num_boxes)))
        losses['combined_loss'] = sum(
            value * self.weight_dict[key] for key, value in losses.items() if key in self.weight_dict
        )
        return losses

### Utils

In [34]:
from scipy.optimize import linear_sum_assignment

@torch.no_grad()
def hungarian_matching(pred_cls, pred_boxes, target_cls, target_boxes,
                       cost_class=1.0, cost_bbox=5.0, cost_giou=2.0):
    """Find the minimum-cost one-to-one assignment between predictions and targets.

    Params:
        pred_cls: Tensor of dim [B, num_queries, num_logits] with the class logits
            (a softmax is taken over the last dim)
        pred_boxes: Tensor of dim [B, num_queries, 4] with the pred box coord (cxcywh)
        target_cls: list (len=batchsize) of Tensors of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth objects in the target) containing the class labels
        target_boxes: list (len=batchsize) of Tensors of dim [num_target_boxes, 4] containing the target box coord (cxcywh)
        cost_class: weight of the classification term of the matching cost
        cost_bbox: weight of the L1 box term of the matching cost
        cost_giou: weight of the GIoU term of the matching cost

    Returns:
        A list of size batch_size, containing tuples of (index_i, index_j) where:
            - index_i is the indices of the selected predictions (in order)
            - index_j is the indices of the corresponding selected targets (in order)
        For each batch element, it holds:
            len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        Indices are int64 tensors (built from the scipy result, hence on CPU).
    """
    batch_size, num_queries = pred_cls.shape[:2]

    # Concatenate the targets of the whole batch: costs are computed for every
    # (prediction, target) pair at once and split per image afterwards.
    tgt_ids = torch.cat(target_cls)
    tgt_bbox = torch.cat(target_boxes)

    # Classification cost. Contrary to the loss, this is not the NLL but the
    # approximation 1 - proba[target class]; the constant 1 does not change the
    # matching and is omitted.
    all_query_probs = pred_cls.flatten(0, 1).softmax(-1)  # [B * num_queries, num_logits]
    class_term = -all_query_probs[:, tgt_ids]  # [B * num_queries, total num targets]

    # L1 cost between boxes.
    out_bbox = pred_boxes.flatten(0, 1)  # [B * num_queries, 4]
    bbox_term = torch.cdist(out_bbox, tgt_bbox, p=1)

    # GIoU cost between boxes (computed on xyxy corners).
    giou_term = -generalized_box_iou(
        box_convert(out_bbox, 'cxcywh', 'xyxy'),
        box_convert(tgt_bbox, 'cxcywh', 'xyxy'),
    )

    # Final cost matrix, moved to CPU for scipy.
    C = cost_bbox * bbox_term + cost_class * class_term + cost_giou * giou_term
    C = C.view(batch_size, num_queries, -1).cpu()

    # Number of targets per image: image i only owns the i-th column chunk.
    sizes = [len(bbox) for bbox in target_boxes]
    indices = [linear_sum_assignment(chunk[i]) for i, chunk in enumerate(C.split(sizes, -1))]
    return [
        (torch.as_tensor(rows, dtype=torch.int64), torch.as_tensor(cols, dtype=torch.int64))
        for rows, cols in indices
    ]

### Post-processing predictions to boxes

In [35]:
@torch.no_grad()
def post_process(logits, boxes, size):
    """Turn raw model outputs into one ``{'scores', 'labels', 'boxes'}`` dict per image.

    Params:
        logits: class logits; the last class index is the no-object slot and is
            excluded when picking the best class of each detection token
        boxes: relative boxes, multiplied here by ``[w, h, w, h]``
            (no format conversion is applied)
        size: ``(h, w)`` used to scale the boxes
    """
    class_probs = logits.softmax(-1)  # B x NdetTok x Ncls
    # Best real class (no-object column dropped) and its probability, per token.
    best = class_probs[..., :-1].max(-1)
    scores, labels = best.values, best.indices  # B x NdetTok

    # From relative [0, 1] coordinates to absolute ones.
    h, w = size
    scale = torch.tensor([w, h, w, h], device=boxes.device).view(1, 1, -1)
    scaled_boxes = boxes * scale

    results = []
    for img_scores, img_labels, img_boxes in zip(scores, labels, scaled_boxes):
        results.append({'scores': img_scores, 'labels': img_labels, 'boxes': img_boxes})
    return results

### Training

In [36]:
import time
import random
import numpy as np
from tqdm.auto import tqdm
from pprint import pformat
from torchmetrics.detection.mean_ap import MeanAveragePrecision
import gc 

import random
import torch
import numpy as np

from torchvision.transforms import v2 as T
import torch
from torch.utils.data import DataLoader, Subset, RandomSampler

In [37]:
# Seed every random number generator used by the notebook (Python, NumPy,
# PyTorch CPU and all CUDA devices) with the same value.
seed = 0
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)

In [38]:
from dl_toolbox.transforms import NormalizeBB

# Root directory of the Penn-Fudan dataset. Defaults to the original location
# and can be overridden with the PENNFUDAN_ROOT environment variable.
DATA_ROOT = os.environ.get('PENNFUDAN_ROOT', '/data/PennFudanPed')
# Number of samples held out for validation.
NUM_VAL_SAMPLES = 50

# Preprocessing shared by both splits: resize, crop/pad to 560x560, scale the
# image to float, drop invalid boxes, express boxes as CXCYWH, apply
# NormalizeBB (project helper) and finally normalise the image channels.
tf = T.Compose(
    [
        T.Resize(size=480, max_size=560),
        T.RandomCrop(size=(560,560), pad_if_needed=True, fill=0),
        T.ToDtype(torch.float, scale=True),
        T.SanitizeBoundingBoxes(),
        T.ConvertBoundingBoxFormat(format='CXCYWH'),
        NormalizeBB(),
        #T.ToPureTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

# NOTE(review): dataset_test uses the same randomised pipeline `tf`
# (T.RandomCrop), so validation samples are not deterministic - confirm this
# is intended.
dataset = PennFudanDataset(DATA_ROOT, tf)
dataset_test = PennFudanDataset(DATA_ROOT, tf)
# Split the dataset into train and validation sets using one random permutation
# so that the two subsets are disjoint.
indices = torch.randperm(len(dataset)).tolist()
train_set = torch.utils.data.Subset(dataset, indices[:-NUM_VAL_SAMPLES])
val_set = torch.utils.data.Subset(dataset_test, indices[-NUM_VAL_SAMPLES:])

In [39]:
from collections import defaultdict

def list_of_dicts_to_dict_of_lists(list_of_dicts):
    """Transpose a sequence of dicts into a dict mapping each key to the list of its values.

    Values are appended in input order; a key missing from some dicts simply
    gets a shorter list. Returns a plain ``dict``.
    """
    merged = {}
    for entry in list_of_dicts:
        for key, value in entry.items():
            merged.setdefault(key, []).append(value)
    return merged

def collate(batch):
    """Collate samples: stack the ``'image'`` entries, keep every other key as a list."""
    collated = list_of_dicts_to_dict_of_lists(batch)
    collated['image'] = torch.stack(collated['image'])
    return collated

# Training loader: randomly sampled batches of 4, incomplete last batch dropped.
train_dataloader = DataLoader(
    train_set,
    batch_size=4,
    sampler=RandomSampler(train_set),
    collate_fn=collate,
    drop_last=True,
    num_workers=0,
    pin_memory=False,
)

# Validation loader: fixed order, batches of 2, every sample kept.
val_dataloader = DataLoader(
    val_set,
    batch_size=2,
    shuffle=False,
    collate_fn=collate,
    drop_last=False,
    num_workers=0,
    pin_memory=False,
)

In [43]:
import logging
import math
import numpy as np
import torch

from typing import Dict, Any

import torch


class Scheduler:
    """ Parameter Scheduler Base Class
    Base class for schedules acting on one field of an optimizer's parameter groups.

    Unlike the builtin PyTorch schedulers, it is meant to be called consistently
    * at the END of each epoch, before incrementing the epoch count, to compute next epoch's value
    * at the END of each optimizer update, after incrementing the update count, to compute next update's value

    Schedulers built on this should stay as stateless as possible (for simplicity).

    This family of schedulers avoids the confusion around 'last_epoch' and
    special -1 values: epoch and update counts are tracked by the training code
    and passed explicitly to ``step`` / ``step_update``.

    Based on ideas from:
     * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
     * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 param_group_field: str,
                 noise_range_t=None,
                 noise_type='normal',
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=None,
                 initialize: bool = True) -> None:
        self.optimizer = optimizer
        self.param_group_field = param_group_field
        self._initial_param_group_field = f"initial_{param_group_field}"
        initial_key = self._initial_param_group_field
        for idx, group in enumerate(self.optimizer.param_groups):
            if initialize:
                # Record the starting value of the field (unless already recorded).
                if param_group_field not in group:
                    raise KeyError(f"{param_group_field} missing from param_groups[{idx}]")
                group.setdefault(initial_key, group[param_group_field])
            elif initial_key not in group:
                # Not initialising: the recorded starting value must already exist.
                raise KeyError(f"{initial_key} missing from param_groups[{idx}]")
        self.base_values = [group[initial_key] for group in self.optimizer.param_groups]
        self.metric = None  # any point to having this for all?
        self.noise_range_t = noise_range_t
        self.noise_pct = noise_pct
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.noise_seed = 42 if noise_seed is None else noise_seed
        self.update_groups(self.base_values)

    def state_dict(self) -> Dict[str, Any]:
        """Return every instance attribute except the optimizer."""
        state = {}
        for name, value in vars(self).items():
            if name != 'optimizer':
                state[name] = value
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Overwrite instance attributes with the entries of ``state_dict``."""
        vars(self).update(state_dict)

    def get_epoch_values(self, epoch: int):
        """Values to apply at ``epoch``; ``None`` (the default) means no change."""
        return None

    def get_update_values(self, num_updates: int):
        """Values to apply after ``num_updates`` updates; ``None`` (the default) means no change."""
        return None

    def step(self, epoch: int, metric: float = None) -> None:
        """Epoch-level step: store ``metric`` and apply the (noised) epoch values, if any."""
        self.metric = metric
        values = self.get_epoch_values(epoch)
        if values is None:
            return
        self.update_groups(self._add_noise(values, epoch))

    def step_update(self, num_updates: int, metric: float = None):
        """Update-level step: store ``metric`` and apply the (noised) update values, if any."""
        self.metric = metric
        values = self.get_update_values(num_updates)
        if values is None:
            return
        self.update_groups(self._add_noise(values, num_updates))

    def update_groups(self, values):
        """Write ``values`` into the scheduled field; a scalar is broadcast to all groups."""
        groups = self.optimizer.param_groups
        if not isinstance(values, (list, tuple)):
            values = [values] * len(groups)
        for group, value in zip(groups, values):
            group[self.param_group_field] = value

    def _add_noise(self, lrs, t):
        """Scale ``lrs`` by ``1 + noise`` when ``t`` lies in the configured noise range.

        ``noise_range_t`` is either a ``[start, end)`` pair or a single start
        value. The noise is drawn from a generator seeded with
        ``noise_seed + t``, so it is reproducible for a given ``t``.
        """
        if self.noise_range_t is None:
            return lrs
        if isinstance(self.noise_range_t, (list, tuple)):
            in_range = self.noise_range_t[0] <= t < self.noise_range_t[1]
        else:
            in_range = t >= self.noise_range_t
        if not in_range:
            return lrs

        generator = torch.Generator()
        generator.manual_seed(self.noise_seed + t)
        if self.noise_type == 'normal':
            # Resample until the draw is within the percent limit; brute force
            # but should not spin much.
            noise = torch.randn(1, generator=generator).item()
            while not abs(noise) < self.noise_pct:
                noise = torch.randn(1, generator=generator).item()
        else:
            noise = 2 * (torch.rand(1, generator=generator).item() - 0.5) * self.noise_pct
        return [v + v * noise for v in lrs]


class CosineLRScheduler(Scheduler):
    """
    Cosine decay with restarts.
    This is described in the paper https://arxiv.org/abs/1608.03983.

    The schedule drives the ``lr`` field of the optimizer's parameter groups.
    It is evaluated per epoch when ``t_in_epochs`` is true (via ``step``) and
    per optimizer update otherwise (via ``step_update``).

    Inspiration from
    https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 t_mul: float = 1.,
                 lr_min: float = 0.,
                 decay_rate: float = 1.,
                 warmup_t=0,
                 warmup_lr_init=0,
                 warmup_prefix=False,
                 cycle_limit=0,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True) -> None:
        """
        Parameters:
            optimizer: optimizer whose ``lr`` is scheduled
            t_initial: length of the first cosine cycle (must be > 0)
            t_mul: factor applied to the cycle length after each restart
            lr_min: lower bound of the cosine (must be >= 0)
            decay_rate: factor applied to the peak (and minimum) lr after each restart
            warmup_t: number of linear warmup steps (0 disables warmup)
            warmup_lr_init: lr at the start of the warmup
            warmup_prefix: if true, the cosine clock starts after the warmup
            cycle_limit: number of cycles after which lr stays at ``lr_min`` (0 = unlimited)
            t_in_epochs: whether ``t`` counts epochs (true) or optimizer updates (false)
            noise_range_t, noise_pct, noise_std, noise_seed, initialize: forwarded to ``Scheduler``
        """
        super().__init__(
            optimizer, param_group_field="lr",
            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
            initialize=initialize)

        assert t_initial > 0
        assert lr_min >= 0
        if t_initial == 1 and t_mul == 1 and decay_rate == 1:
            # No module-level logger is defined in this notebook, so fetch one here.
            logging.getLogger(__name__).warning(
                "Cosine annealing scheduler will have no effect on the learning "
                "rate since t_initial = t_mul = decay_rate = 1.")
        self.t_initial = t_initial
        self.t_mul = t_mul
        self.lr_min = lr_min
        self.decay_rate = decay_rate
        self.cycle_limit = cycle_limit
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.warmup_prefix = warmup_prefix
        self.t_in_epochs = t_in_epochs
        if self.warmup_t:
            # Per-step lr increment of the linear warmup, one per param group;
            # the optimizer starts at warmup_lr_init.
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t):
        """Return the list of learning rates (one per param group) at time ``t``."""
        if t < self.warmup_t:
            # Linear warmup from warmup_lr_init towards the base values.
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            if self.warmup_prefix:
                t = t - self.warmup_t

            # i: index of the current cycle, t_i: its length, t_curr: time
            # elapsed inside it.
            if self.t_mul != 1:
                i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
                t_i = self.t_mul ** i * self.t_initial
                t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
            else:
                i = t // self.t_initial
                t_i = self.t_initial
                t_curr = t - (self.t_initial * i)

            # Peak and floor shrink by decay_rate at every restart.
            gamma = self.decay_rate ** i
            lr_min = self.lr_min * gamma
            lr_max_values = [v * gamma for v in self.base_values]

            if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
                lrs = [
                    lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values
                ]
            else:
                # Past the last allowed cycle: stay at the (undecayed) minimum.
                lrs = [self.lr_min for _ in self.base_values]

        return lrs

    def get_epoch_values(self, epoch: int):
        """Learning rates for ``epoch`` when scheduling per epoch, else ``None``."""
        if self.t_in_epochs:
            return self._get_lr(epoch)
        else:
            return None

    def get_update_values(self, num_updates: int):
        """Learning rates after ``num_updates`` when scheduling per update, else ``None``."""
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        else:
            return None

    def get_cycle_length(self, cycles=0):
        """Total length of ``cycles`` cycles (``cycle_limit`` when 0, at least one cycle)."""
        if not cycles:
            cycles = self.cycle_limit
        cycles = max(1, cycles)
        if self.t_mul == 1.0:
            return self.t_initial * cycles
        else:
            return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)))

In [45]:
# Freeze params here if needed

# Single-class detector with 100 detection tokens on a tiny ViT backbone.
model = YOLOS(num_classes=1, det_token_num=100, backbone='vit_tiny_patch16_224')
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#dev = torch.device("cpu")
model.to(dev)

# Matching-based criterion; the no-object class is down-weighted to 0.1.
eval_losses = SetCriterion(num_classes=1, eos_coef=0.1)
eval_losses.to(dev)

# Example of freezing everything except the detection tokens and the heads:
#for n, p in model.named_parameters():
#    if not (n.startswith('det') or n.startswith('class_embed') or n.startswith('bbox_embed')):
#        p.requires_grad = False

# (name, parameter) pairs that will actually be optimised.
train_params = [(name, param) for name, param in model.named_parameters() if param.requires_grad]
nb_train = sum(int(torch.numel(param)) for _, param in train_params)
nb_tot = sum(int(torch.numel(param)) for param in model.parameters())
print(f"Training {nb_train} params out of {nb_tot}")

# Alternative tried earlier: SGD(lr=0.005, momentum=0.9, weight_decay=0.0005).
optimizer = torch.optim.AdamW(
    [param for _, param in train_params],
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=1e-4,
)

# Alternatives tried earlier: torch OneCycleLR (max_lr=1e-3, 100 epochs) and
# LinearLR (factor 1.0 -> 0.1 over 10000 iterations).
# Single cosine cycle of 150 epochs, no warmup, no noise.
lr_scheduler = CosineLRScheduler(
    optimizer,
    t_initial=150,
    t_mul=1.,
    lr_min=1e-7,
    decay_rate=0.1,
    cycle_limit=1,
    warmup_t=0,
    warmup_lr_init=1e-6,
    t_in_epochs=True,
    noise_range_t=None,
    noise_pct=0.67,
    noise_std=1.,
    noise_seed=42,
)

Training 5905198 params out of 5905198


In [41]:
def unnorm_bounding_boxes(inpt):
    """Scale boxes by the canvas size and return them wrapped like ``inpt``.

    Each box is multiplied element-wise by ``[W, H, W, H]`` taken from
    ``inpt.canvas_size``. The computation is done in floating point and the
    result is cast back to the input dtype. ``inpt`` itself is not modified.
    """
    raw = inpt.as_subclass(torch.Tensor)
    original_dtype = raw.dtype
    if raw.is_floating_point():
        work = raw.clone()
    else:
        work = raw.float()
    # canvas_size is (H, W): duplicated to HWHW, then flipped to WHWH.
    height_width = torch.Tensor(inpt.canvas_size)
    scale = height_width.repeat(2).flip(dims=(0,)).to(inpt.device)
    scaled = (work * scale).to(original_dtype)
    return tv_tensors.wrap(scaled, like=inpt)

In [50]:
gc.collect()
#torch.cuda.empty_cache()
gc.collect()

# Each epoch first evaluates the current model on the validation set, then
# trains for one pass over the training set.
start_epoch = 0
for epoch in range(start_epoch, 150):
    time_ep = time.time()

    # ---- Validation -------------------------------------------------------
    valid_loss = 0
    valid_bbox_loss = 0
    valid_giou_loss = 0
    valid_ce_loss = 0
    model.eval()
    with torch.no_grad():
        map_metric = MeanAveragePrecision(
            box_format='cxcywh', # predictions and targets below are both given as cxcywh
            backend='faster_coco_eval'
        )
        for batch in tqdm(val_dataloader, total=len(val_dataloader)):
            image = batch["image"].to(dev)
            targets = batch['target']
            # Call the module (not .forward) so that hooks are honoured.
            pred_logits, pred_boxes = model(image)
            tgt_cls = [tgt["labels"].to(dev) for tgt in targets]
            tgt_boxes = [tgt["boxes"].to(dev) for tgt in targets]
            matches = hungarian_matching(
                pred_logits,
                pred_boxes,
                tgt_cls,
                tgt_boxes
            )
            losses = eval_losses(
                pred_logits,
                pred_boxes,
                tgt_cls,
                tgt_boxes,
                matches
            )

            loss = losses['combined_loss']
            valid_loss += loss.detach().item()
            valid_bbox_loss += losses["loss_bbox"].detach().item()
            valid_giou_loss += losses["loss_giou"].detach().item()
            valid_ce_loss += losses["loss_ce"].detach().item()

            # The metric is fed absolute coordinates: predictions are scaled by
            # the image size, targets are un-normalised in place (the losses
            # above have already consumed the normalised targets).
            b,c,h,w = image.shape
            preds = post_process(
                pred_logits.to("cpu"),
                pred_boxes.to("cpu"),
                (h,w)
            )
            for t in batch['target']:
                t['boxes'] = unnorm_bounding_boxes(t['boxes'])
            map_metric.update(preds, batch['target'])
        valid_loss /= len(val_dataloader)
        valid_bbox_loss /= len(val_dataloader)
        valid_giou_loss /= len(val_dataloader)
        valid_ce_loss /= len(val_dataloader)
        mapmetrics = map_metric.compute()
        print(f"{epoch = }")
        print(f"{valid_loss = }")
        print(f"{valid_bbox_loss = }")
        print(f"{valid_giou_loss = }")
        print(f"{valid_ce_loss = }")
        print(pformat(mapmetrics))
        map_metric.reset()

    # ---- Training ---------------------------------------------------------
    train_loss = 0
    train_bbox_loss = 0
    train_giou_loss = 0
    train_ce_loss = 0
    model.train()
    for batch in tqdm(train_dataloader, total=len(train_dataloader)):
        image = batch["image"].to(dev)
        targets = batch['target']
        pred_logits, pred_boxes = model(image)
        tgt_cls = [tgt["labels"].to(dev) for tgt in targets]
        tgt_boxes = [tgt["boxes"].to(dev) for tgt in targets]
        matches = hungarian_matching(
            pred_logits,
            pred_boxes,
            tgt_cls,
            tgt_boxes
        )
        losses = eval_losses(
            pred_logits,
            pred_boxes,
            tgt_cls,
            tgt_boxes,
            matches
        )
        loss = losses['combined_loss']
        # Clear the gradients of the previous step: without this they would
        # accumulate over every batch of every epoch.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.detach().item()
        train_bbox_loss += losses["loss_bbox"].detach().item()
        train_giou_loss += losses["loss_giou"].detach().item()
        train_ce_loss += losses["loss_ce"].detach().item()
    train_loss /= len(train_dataloader)
    train_bbox_loss /= len(train_dataloader)
    train_giou_loss /= len(train_dataloader)
    train_ce_loss /= len(train_dataloader)
    print(f"{epoch = }")
    print(f"{train_loss = }")
    print(f"{train_bbox_loss = }")
    print(f"{train_giou_loss = }")
    print(f"{train_ce_loss = }")
    # Called at the end of the epoch to set the learning rate of the NEXT one,
    # hence epoch + 1 (step(epoch) would lag the schedule by one epoch).
    lr_scheduler.step(epoch + 1)
    time_ep = time.time() - time_ep

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 25.70it/s]


epoch = 0
valid_loss = 21.601595420837402
valid_bbox_loss = 1.681334180831909
valid_giou_loss = 1.8176218461990357
valid_ce_loss = 9.559680948257446
{'classes': tensor(0, dtype=torch.int32),
 'map': tensor(0.),
 'map_50': tensor(0.),
 'map_75': tensor(0.),
 'map_large': tensor(0.),
 'map_medium': tensor(0.),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.),
 'mar_1': tensor(0.),
 'mar_10': tensor(0.),
 'mar_100': tensor(0.),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.),
 'mar_medium': tensor(0.),
 'mar_small': tensor(0.)}


 30%|████████████████████████████████▋                                                                            | 9/30 [00:01<00:03,  6.38it/s]


KeyboardInterrupt: 