In [1]:
import random
import os
import time
import gc 
import math

from matplotlib import pyplot as plt
from IPython import display
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, RandomSampler
import numpy as np
from scipy.optimize import linear_sum_assignment
import torchvision
from torchvision.ops import box_convert, generalized_box_iou
from torchvision.io import read_image
from torchvision.ops.boxes import masks_to_boxes
from torchvision import tv_tensors
from torchvision.transforms import v2 as T
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelP6P7
import timm
from timm.layers import resample_abs_pos_embed 
from tqdm.auto import tqdm
from pprint import pformat
from torchmetrics.detection.mean_ap import MeanAveragePrecision

from dl_toolbox.transforms import NormalizeBB
from dl_toolbox.utils import list_of_dicts_to_dict_of_lists

  from .autonotebook import tqdm as notebook_tqdm


### Dataset

In [2]:
class PennFudanDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to
        # ensure that they are aligned
        self.imgs = list(sorted(os.listdir(os.path.join(root, "PNGImages"))))
        self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks"))))

    def __getitem__(self, idx):
        # load images and masks
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        img = read_image(img_path)
        mask = read_image(mask_path)
        # instances are encoded as different colors
        obj_ids = torch.unique(mask)
        # first id is the background, so remove it
        obj_ids = obj_ids[1:]
        num_objs = len(obj_ids)

        # split the color-encoded mask into a set
        # of binary masks
        masks = (mask == obj_ids[:, None, None]).to(dtype=torch.uint8)

        # get bounding box coordinates for each mask
        boxes = masks_to_boxes(masks)

        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)

        image_id = idx
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        # Wrap sample and targets into torchvision tv_tensors:
        img = tv_tensors.Image(img)
        h, w = T.functional.get_size(img)
        target = {}
        target["boxes"] = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(h,w))
        target["masks"] = tv_tensors.Mask(masks)
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return {'image': img, 'target': target}
    
    def __len__(self):
        return len(self.imgs)
    
def collate(batch):
    batch = list_of_dicts_to_dict_of_lists(batch)
    batch['image'] = torch.stack(batch['image'])
    return batch

### Model

In [3]:
class Scale(nn.Module):

    def __init__(self, init_value=1.0):
        super(Scale, self).__init__()
        self.scale = nn.Parameter(torch.FloatTensor([init_value]))

    def forward(self, x):
        return x * self.scale


class Head(nn.Module):

    def __init__(self, in_channels, n_classes, n_share_convs=4, n_feat_levels=5):
        super().__init__()
        
        tower = []
        for _ in range(n_share_convs):
            tower.append(
                nn.Conv2d(in_channels,
                          in_channels,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=True))
            tower.append(nn.GroupNorm(32, in_channels))
            tower.append(nn.ReLU())
        self.shared_layers = nn.Sequential(*tower)

        self.cls_logits = nn.Conv2d(in_channels,
                                    n_classes,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1)
        self.bbox_pred = nn.Conv2d(in_channels,
                                   4,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1)
        self.ctrness = nn.Conv2d(in_channels,
                                 1,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1)

        self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)])

        #cls_tower = []
        #bbox_tower = []
        #for _ in range(n_share_convs):
        #    cls_tower.append(
        #        nn.Conv2d(in_channels,
        #                  in_channels,
        #                  kernel_size=3,
        #                  stride=1,
        #                  padding=1,
        #                  bias=True))
        #    cls_tower.append(nn.GroupNorm(32, in_channels))
        #    cls_tower.append(nn.ReLU())
        #    bbox_tower.append(
        #        nn.Conv2d(in_channels,
        #                  in_channels,
        #                  kernel_size=3,
        #                  stride=1,
        #                  padding=1,
        #                  bias=True))
        #    bbox_tower.append(nn.GroupNorm(32, in_channels))
        #    bbox_tower.append(nn.ReLU())
        #self.cls_layers = nn.Sequential(*cls_tower)
        #self.bbox_layers = nn.Sequential(*bbox_tower)
#
        #self.cls_logits = nn.Conv2d(in_channels,
        #                            n_classes,
        #                            kernel_size=3,
        #                            stride=1,
        #                            padding=1)
        #self.bbox_pred = nn.Conv2d(in_channels,
        #                           4,
        #                           kernel_size=3,
        #                           stride=1,
        #                           padding=1)
        #self.ctrness = nn.Conv2d(in_channels,
        #                         1,
        #                         kernel_size=3,
        #                         stride=1,
        #                         padding=1)
        #self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(n_feat_levels)])

        # initialize the bias for focal loss
        # Not in basara repo but it is done in the paper fcos repo
        #prior_prob = 0.01
        #bias_value = -math.log((1 - prior_prob) / prior_prob)
        #torch.nn.init.constant_(self.cls_logits.bias, bias_value)

    def forward(self, x):
        cls_logits = []
        bbox_preds = []
        cness_preds = []
        for l, features in enumerate(x):
            features = self.shared_layers(features)
            cls_logits.append(self.cls_logits(features).flatten(-2))
            cness_preds.append(self.ctrness(features).flatten(-2))
            reg = self.bbox_pred(features)
            reg = self.scales[l](reg)
            bbox_preds.append(nn.functional.relu(reg).flatten(-2))
        #for l, features in enumerate(x):
        #    cls_features = self.cls_layers(features) # BxinChannelsxfeatSize
        #    bbox_features = self.bbox_layers(features)
        #    # This flatten must flatten things in the same order anchors are flattened 
        #    cls_logits.append(self.cls_logits(cls_features).flatten(-2)) # BxNumClsxHf*Wf
        #    cness_preds.append(self.ctrness(bbox_features).flatten(-2)) # Bx1xHf*Wf
        #    reg = self.bbox_pred(bbox_features) # Bx4xFeatSize
        #    reg = self.scales[l](reg)
        #    bbox_preds.append(nn.functional.relu(reg).flatten(-2))
        all_logits = torch.cat(cls_logits, dim=-1).permute(0,2,1) # BxNumAnchorsxC
        all_box_regs = torch.cat(bbox_preds, dim=-1).permute(0,2,1) # BxNumAnchorsx4
        all_cness = torch.cat(cness_preds, dim=-1).permute(0,2,1) # BxNumAnchorsx1
        return all_logits, all_box_regs, all_cness
    
class FCOS(nn.Module):

    def __init__(self, n_classes, input_size):
        super().__init__()
        self.num_classes = n_classes
        self.input_size = input_size
        self.feature_extractor = create_feature_extractor(
            torchvision.models.convnext_tiny(
                weights=torchvision.models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1
            ), 
            {
                'features.3.2.add': 'layer2', # 1/8th feat map
                'features.5.8.add': 'layer3', # 1/16
                'features.7.2.add': 'layer4', # 1/32
            }
        )        
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=[192, 384, 768],
            out_channels=256,
            extra_blocks=LastLevelP6P7(256,256) # adds feature maps at strides 64, 128
        )
        self.fm_strides = [8, 16, 32, 64, 128] 
        features = nn.Sequential(self.feature_extractor, self.fpn)
        self.head = Head(256, self.num_classes)
        
        
    def forward(self, x):
        features = self.feature_extractor(x)
        feature_maps = list(self.fpn(features).values()) # feature maps from FPN
        box_cls, box_regression, centerness = self.head(feature_maps)
        return box_cls, box_regression, centerness

### Loss

In [4]:
class LossEvaluator(nn.Module):

    def __init__(self, num_classes):
        super(LossEvaluator, self).__init__()
        self.centerness_loss_func = nn.BCEWithLogitsLoss(reduction="sum")
        self.num_classes = num_classes
                
    def __call__(self, cls_logits, reg_preds, cness_preds, cls_tgts, reg_tgts):
        pos_inds_b, pos_inds_loc = torch.nonzero(cls_tgts > 0, as_tuple=True)
        num_pos = len(pos_inds_b)
        reg_preds = reg_preds[pos_inds_b, pos_inds_loc, :]
        reg_tgts = reg_tgts[pos_inds_b, pos_inds_loc, :]
        cness_preds = cness_preds[pos_inds_b, pos_inds_loc, :].squeeze(-1)
        cness_tgts = self._compute_centerness_targets(reg_tgts)
        cls_loss = self._get_cls_loss(cls_logits, cls_tgts, max(num_pos, 1.))
        reg_loss, centerness_loss = 0,0
        if num_pos > 0:
            reg_loss = self._get_reg_loss(
                reg_preds, reg_tgts, cness_tgts)
            centerness_loss = self._get_centerness_loss(
                cness_preds, cness_tgts, num_pos)
        losses = {}
        losses["cls_loss"] = cls_loss
        losses["reg_loss"] = reg_loss
        losses["centerness_loss"] = centerness_loss
        losses["combined_loss"] = cls_loss + reg_loss + centerness_loss
        return losses
    
    def _compute_centerness_targets(self, reg_tgts):
        """
        Args:
            reg_tgts: l, t, r, b values to regress, shape BxNumAx4
        Returns:
            A tensor of shape BxNumA giving how centered each anchor is for the bbox it must regress
        """
        if len(reg_tgts) == 0:
            return reg_tgts.new_zeros(len(reg_tgts))
        left_right = reg_tgts[..., [0, 2]]
        top_bottom = reg_tgts[..., [1, 3]]
        centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
                    (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
        return torch.sqrt(centerness)

    def _get_cls_loss(self, cls_preds, cls_targets, num_pos_samples):
        """
        cls_targets takes values in 0...C, 0 only when there is no obj to be detected for the anchor
        """
        onehot = nn.functional.one_hot(cls_targets.long(), self.num_classes+1)[...,1:].float()
        cls_loss = torchvision.ops.sigmoid_focal_loss(cls_preds, onehot)
        return cls_loss.sum() / num_pos_samples

    def _get_reg_loss(self, reg_preds, reg_targets, centerness_targets):
        ltrb_preds = reg_preds.reshape(-1, 4)
        ltrb_tgts = reg_targets.reshape(-1, 4)
        xyxy_preds = torch.cat([-ltrb_preds[:,:2], ltrb_preds[:,2:]], dim=1) 
        xyxy_tgts = torch.cat([-ltrb_tgts[:,:2], ltrb_tgts[:,2:]], dim=1)
        reg_losses = torchvision.ops.distance_box_iou_loss(xyxy_preds, xyxy_tgts, reduction='none')
        sum_centerness_targets = centerness_targets.sum()
        reg_loss = (reg_losses * centerness_targets).sum() / sum_centerness_targets
        return reg_loss

    def _get_centerness_loss(self, centerness_preds, centerness_targets,
                             num_pos_samples):
        centerness_loss = self.centerness_loss_func(centerness_preds,
                                                    centerness_targets)
        return centerness_loss / num_pos_samples

### Matching

In [5]:
INF = 100000000

def get_fm_anchors(h, w, s):
    """
    Args:
        h, w: height, width of the feat map
        s: stride of the featmap = size reduction factor relative to image
    Returns:
        Tensor NumAnchorsInFeatMap x 2, ordered by column 
        TODO: check why: DONE: it corresponds to how locs are computed in 
        https://github.com/tianzhi0549/FCOS/blob/master/fcos_core/modeling/rpn/fcos/fcos.py
        When flattening feat maps, we see first the line at H(=y) fixed and W(=x) moving
        
    """
    locs_x = [s / 2 + x * s for x in range(w)]
    locs_y = [s / 2 + y * s for y in range(h)]
    locs = [(x, y) for y in locs_y for x in locs_x] # order !
    return torch.tensor(locs)

def get_all_anchors_bb_sizes(fm_sizes, fm_strides, bb_sizes):
    """
    Args:
        fm_sizes: seq of feature_maps sizes
        fm_strides: seq of corresponding strides
        bb_sizes: seq of bbox sizes feature maps are associated with, len = len(fm) + 1
    Returns:
        anchors: list of num_featmaps elem, where each elem indicates the tensor of anchors of size Nx2 in the original image corresponding to each location in the feature map at this level
        anchors_bb_sizes: sizes of the bbox each anchor is authorized/supposed to detect
    """
    bb_sizes = [-1] + bb_sizes + [INF]
    anchors, anchors_bb_sizes = [], []
    for l, ((h,w), s) in enumerate(zip(fm_sizes, fm_strides)):
        fm_anchors = get_fm_anchors(h, w, s)
        sizes = torch.tensor([bb_sizes[l], bb_sizes[l+1]], dtype=torch.float32)
        sizes = sizes.repeat(len(fm_anchors)).view(len(fm_anchors), 2)
        anchors.append(fm_anchors)
        anchors_bb_sizes.append(sizes)
    return torch.cat(anchors, 0), torch.cat(anchors_bb_sizes, 0)

In [6]:
def calculate_reg_targets(anchors, bbox):
    """
    Args:
        anchors: Lx2, anchors coordinates
        bbox: tensor of bbox Tx4, format should be xywh
    Returns:
        reg_tgt: l,t,r,b values to regress for each pair (anchor, bbox)
        anchor_in_box: whether anchor is in bbox for each pair (anchor, bbox)
    """
    xs, ys = anchors[:, 0], anchors[:, 1] # L & L, x & y reversed ?? x means position on x-axis
    l = xs[:, None] - bbox[:, 0][None] # Lx1 - 1xT -> LxT
    t = ys[:, None] - bbox[:, 1][None]
    r = bbox[:, 2][None] + bbox[:, 0][None] - xs[:, None]
    b = bbox[:, 3][None] + bbox[:, 1][None] - ys[:, None]  
    #print(xs[0], ys[0], l[0], t[0], r[0], b[0])
    return torch.stack([l, t, r, b], dim=2) # LxTx4

def apply_distance_constraints(reg_targets, anchor_sizes):
    """
    Args:
        reg_targets: LxTx4
        anchor_bb_sizes: Lx2
    Returns:
        A LxT tensor where value at (anchor, bbox) is true if the max value to regress at this anchor for this bbox is inside the bounds associated to this anchor
        If other values to regress than the max are negatives, it is dealt with anchor_in_boxes.
    """
    max_reg_targets, _ = reg_targets.max(dim=2) # LxT
    min_reg_targets, _ = reg_targets.min(dim=2) # LxT
    dist_constraints = torch.stack([
        min_reg_targets > 0,
        max_reg_targets >= anchor_sizes[:, None, 0],
        max_reg_targets <= anchor_sizes[:, None, 1]
    ])
    return torch.all(dist_constraints, dim=0)

def anchor_bbox_area(bbox, anchors, fits_to_feature_level):
    """
    Args: bbox is XYWH
    Returns: 
        Tensor LxT where value at (anchor, bbox) is the area of bbox if anchor is in bbox and anchor is associated with bbox of that size
        Else INF.
    """
    #bbox_areas = _calc_bbox_area(bbox_targets) # T
    bbox_areas = bbox[:, 2] * bbox[:, 3] # T
    # area of each target bbox repeated for each loc with inf where the the loc is not 
    # in the target bbox or if the loc is not at the right level for this bbox size
    anchor_bbox_area = bbox_areas[None].repeat(len(anchors), 1) # LxT
    anchor_bbox_area[~fits_to_feature_level] = INF
    return anchor_bbox_area

def associate_targets_to_anchors(targets_batch, anchors, anchors_bb_sizes):
    """
    Associate one target cls/bbox to regress ONLY to each anchor: among the bboxes that contain the anchor and have the right size, pick that of min area.
    If no tgt exists for an anchor, the tgt class is 0.
    inputs:
        targets_batch: list of dict of tv_tensors {'labels':, 'boxes':}; boxes should be in XYWH format
        anchors: 
        anchor_bb_sizes:
    outputs:
        all class targets: BxNumAnchors
        all bbox targets: BxNumAnchorsx4
    """
    all_reg_targets, all_cls_targets = [], []
    for targets in targets_batch:
        bbox_targets = targets['boxes'] # Tx4, format XYWH
        cls_targets = targets['labels'] # T
        reg_targets = calculate_reg_targets(
            anchors, bbox_targets) # LxTx4, LxT
        fits_to_feature_level = apply_distance_constraints(
            reg_targets, anchors_bb_sizes) # LxT
        locations_to_gt_area = anchor_bbox_area(
            bbox_targets, anchors, fits_to_feature_level)
        # Core of the anchor/target association
        if cls_targets.shape[0]>0:
            loc_min_area, loc_min_idxs = locations_to_gt_area.min(dim=1) #L,idx in [0,T-1],T must be>0
            reg_targets = reg_targets[range(len(anchors)), loc_min_idxs] # Lx4
            cls_targets = cls_targets[loc_min_idxs] # L
            cls_targets[loc_min_area == INF] = 0 # 0 is no-obj category
        else:
            cls_targets = cls_targets.new_zeros((len(anchors),))
            reg_targets = reg_targets.new_zeros((len(anchors),4))
        all_cls_targets.append(cls_targets)
        all_reg_targets.append(reg_targets)
    # BxL & BxLx4
    return torch.stack(all_cls_targets), torch.stack(all_reg_targets)

### Post-processing predictions to boxes

In [7]:
pre_nms_thresh=0.3
pre_nms_top_n=100000
nms_thresh=0.45
fpn_post_nms_top_n=50
min_size=0

def post_process(logits, ltrb, cness, input_size):
    probas = logits.sigmoid() # LxC
    high_probas = probas > pre_nms_thresh # LxC
    # Indices on L and C axis of high prob pairs anchor/class
    high_prob_anchors_idx, high_prob_cls = high_probas.nonzero(as_tuple=True) # dim l <= L*C
    high_prob_cls += 1 # 0 is for no object
    high_prob_ltrb = ltrb[high_prob_anchors_idx] # lx4
    high_prob_anchors = anchors[high_prob_anchors_idx] # lx2
    # Tensor shape l with values from logits*cness such that logits > pre_nms_thresh 
    cness_modulated_probas = probas * cness.sigmoid() # LxC
    high_prob_scores = cness_modulated_probas[high_probas] # l
    # si l est trop longue
    if high_probas.sum().item() > pre_nms_top_n:
        # Filter the pre_nms_top_n most probable pairs 
        high_prob_scores, top_k_indices = high_prob_scores.topk(
            pre_nms_top_n, sorted=False) 
        high_prob_cls = high_prob_cls[top_k_indices]
        high_prob_ltrb = high_prob_ltrb[top_k_indices]
        high_prob_anchors = high_prob_anchors[top_k_indices]

    # Rewrites bbox (x0,y0,x1,y1) from reg targets (l,t,r,b) following eq (1) in paper
    high_prob_boxes = torch.stack([
        high_prob_anchors[:, 0] - high_prob_ltrb[:, 0],
        high_prob_anchors[:, 1] - high_prob_ltrb[:, 1],
        high_prob_anchors[:, 0] + high_prob_ltrb[:, 2],
        high_prob_anchors[:, 1] + high_prob_ltrb[:, 3],
    ], dim=1)

    high_prob_boxes = torchvision.ops.clip_boxes_to_image(high_prob_boxes, input_size)
    big_enough_box_idxs = torchvision.ops.remove_small_boxes(high_prob_boxes, min_size)
    boxes = high_prob_boxes[big_enough_box_idxs]
    # Why not do that on scores and classes too ? 
    classes = high_prob_cls[big_enough_box_idxs]
    scores = high_prob_scores[big_enough_box_idxs]
    #high_prob_scores = torch.sqrt(high_prob_scores) # WHY SQRT ? REmOVED
    # NMS expects boxes to be in xyxy format
    nms_idxs = torchvision.ops.nms(boxes, scores, nms_thresh)
    boxes = boxes[nms_idxs]
    scores = scores[nms_idxs]
    classes = classes[nms_idxs]
    if len(nms_idxs) > fpn_post_nms_top_n:
        image_thresh, _ = torch.kthvalue(
            scores.cpu(),
            len(nms_idxs) - fpn_post_nms_top_n + 1)
        keep = scores >= image_thresh.item()
        #keep = torch.nonzero(keep).squeeze(1)
        boxes, scores, classes = boxes[keep], scores[keep], classes[keep]
    # Then back to xywh boxes for preds and metric computation
    boxes[:, 2] -= boxes[:, 0]
    boxes[:, 3] -= boxes[:, 1]
    
    # Isn't this cond auto valid from the beginning filter ?
    #keep = scores >= pre_nms_thresh
    #boxes, scores, classes = boxes[keep], scores[keep], classes[keep]
    return boxes, scores, classes 

def post_process_batch(
    cls_preds, # B x L x C 
    reg_preds, # B x L x 4
    cness_preds, # B x L x 1
    input_size
): 
    preds = []
    for logits, ltrb, cness in zip(cls_preds, reg_preds, cness_preds):
        boxes, scores, classes = post_process(logits, ltrb, cness, input_size)
        preds.append({'boxes': boxes, 'scores': scores, 'labels': classes})
    return preds

### Seeds

In [8]:
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

### Instanciations

In [9]:
tf = T.Compose(
    [
        T.ToDtype(torch.float, scale=True),
        T.Resize(size=480, max_size=640),
        T.RandomCrop(size=(640,640), pad_if_needed=True, fill=0),
        T.ConvertBoundingBoxFormat(format='XYWH'),
        T.SanitizeBoundingBoxes(),
        #T.ToPureTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

dataset = PennFudanDataset('/data/PennFudanPed', tf)
dataset_test = PennFudanDataset('/data/PennFudanPed', tf)
# split the dataset in train and test set
indices = torch.randperm(len(dataset)).tolist()
train_set = torch.utils.data.Subset(dataset, indices[:-50])
val_set = torch.utils.data.Subset(dataset_test, indices[-50:])

train_dataloader = DataLoader(
    batch_size=4,
    num_workers=4,
    pin_memory=True,
    dataset=train_set,
    sampler=RandomSampler(
        train_set,
        replacement=True,
        num_samples=100*2
    ),
    drop_last=True,
    collate_fn=collate
)

val_dataloader = DataLoader(
    batch_size=5,
    num_workers=4,
    pin_memory=True,
    dataset=val_set,
    shuffle=False,
    drop_last=False,
    collate_fn=collate
)

In [10]:
# Freeze params here if needed
    
#for param in model.feature_extractor.parameters():
#    param.requires_grad = False

anchors, anchor_sizes = get_all_anchors_bb_sizes(
    fm_sizes=[(80,80),(40,40),(20,20),(10,10),(5,5)],
    fm_strides=[8, 16, 32, 64, 128],
    bb_sizes=[64, 128, 256, 512]
)

model = FCOS(n_classes=1, input_size=(640,640))
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#dev = torch.device("cpu")
model.to(dev)
eval_losses = LossEvaluator(
    num_classes=1
)

train_params = list(filter(lambda p: p[1].requires_grad, model.named_parameters()))
nb_train = sum([int(torch.numel(p[1])) for p in train_params])
nb_tot = sum([int(torch.numel(p)) for p in model.parameters()])
print(f"Training {nb_train} params out of {nb_tot}")

#optimizer = torch.optim.SGD(
#    params=[p[1] for p in train_params],
#    lr=0.005,
#    momentum=0.9,
#    weight_decay=0.0005
#)

optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=1e-3,
    betas=(0.9,0.999),
    weight_decay=5e-2,
    eps=1e-8,
)

lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-3,
    steps_per_epoch=len(train_dataloader),
    epochs=100)
#lr_scheduler = torch.optim.lr_scheduler.LinearLR(
#    optimizer=optimizer,
#    start_factor=1.,
#    end_factor=0.1,
#    total_iters=20
#)

Training 33490027 params out of 33490027


### Training

In [11]:
#gc.collect()
#torch.cuda.empty_cache()
#gc.collect()

start_epoch = 0
for epoch in range(start_epoch, 100):
    time_ep = time.time()
    
    valid_loss = 0
    valid_cls_loss = 0
    valid_reg_loss = 0
    valid_cen_loss = 0
    model.eval()
    with torch.no_grad():
        map_metric = MeanAveragePrecision(
            box_format='xywh', # make sure your dataset outputs target in xywh format
            backend='faster_coco_eval'
        )
        for batch in tqdm(val_dataloader, total=len(val_dataloader)):
            cls_tgts, reg_tgts = associate_targets_to_anchors(
                batch['target'],
                anchors,
                anchor_sizes
            )
            image = batch["image"].to(dev)
            cls_logits, bbox_reg, centerness = model(image)
            losses = eval_losses(
                cls_logits,
                bbox_reg,
                centerness,
                cls_tgts.to(dev),
                reg_tgts.to(dev)
            )
            loss = losses['combined_loss']
            valid_loss += loss.detach().item()
            valid_cls_loss += losses["cls_loss"].detach().item()
            valid_reg_loss += losses["reg_loss"].detach().item()
            valid_cen_loss += losses["centerness_loss"].detach().item()
            b,c,h,w = image.shape
            preds = post_process_batch(
                cls_logits.to("cpu"),
                bbox_reg.to("cpu"),
                centerness.to("cpu"),
                (h,w)
            )
            map_metric.update(preds, batch['target'])
        valid_loss /= len(val_dataloader)
        valid_cls_loss /= len(val_dataloader)
        valid_reg_loss /= len(val_dataloader)
        valid_cen_loss /= len(val_dataloader)
        mapmetrics = map_metric.compute()
        print(f"{epoch = }")
        print(f"{valid_loss = }")
        print(f"{valid_cls_loss = }")
        print(f"{valid_reg_loss = }")
        print(f"{valid_cen_loss = }")
        print(pformat(mapmetrics))
        map_metric.reset()
    train_loss = 0
    train_cls_loss = 0
    train_reg_loss = 0
    train_cen_loss = 0
    model.train()
    for batch in tqdm(train_dataloader, total=len(train_dataloader)):
        image = batch["image"].to(dev)               
        optimizer.zero_grad()
        cls_logits, bbox_reg, centerness = model(image)
        cls_tgts, reg_tgts = associate_targets_to_anchors(
            batch['target'],
            anchors,
            anchor_sizes
        ) # BxNumAnchors, BxNumAnchorsx4    
        losses = eval_losses(
            cls_logits,
            bbox_reg,
            centerness,
            cls_tgts.to(dev),
            reg_tgts.to(dev)
        )
        loss = losses['combined_loss']
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        train_loss += loss.detach().item()
        train_cls_loss += losses["cls_loss"].detach().item()
        train_reg_loss += losses["reg_loss"].detach().item()
        train_cen_loss += losses["centerness_loss"].detach().item()
    train_loss /= len(train_dataloader)
    train_cls_loss /= len(train_dataloader)
    train_reg_loss /= len(train_dataloader)
    train_cen_loss /= len(train_dataloader)
    print(f"{epoch = }")
    print(f"lr = {lr_scheduler.get_last_lr()[0]}"),
    print(f"{train_loss = }")
    print(f"{train_cls_loss = }")
    print(f"{train_reg_loss = }")
    print(f"{train_cen_loss = }")
    
    time_ep = time.time() - time_ep

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:07<00:00,  1.41it/s]


epoch = 0
valid_loss = 14.82506217956543
valid_cls_loss = 13.09198055267334
valid_reg_loss = 1.0352440237998963
valid_cen_loss = 0.6978376269340515
{'classes': tensor(1, dtype=torch.int32),
 'map': tensor(0.),
 'map_50': tensor(0.),
 'map_75': tensor(0.),
 'map_large': tensor(0.),
 'map_medium': tensor(0.),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.),
 'mar_1': tensor(0.),
 'mar_10': tensor(0.),
 'mar_100': tensor(0.),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.),
 'mar_medium': tensor(0.),
 'mar_small': tensor(0.)}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:27<00:00,  1.83it/s]


epoch = 0
lr = 4.263299650727223e-05
train_loss = 2.417167649269104
train_cls_loss = 0.719734572172165
train_reg_loss = 1.0356085968017579
train_cen_loss = 0.6618244934082032


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.31it/s]


epoch = 1
valid_loss = 2.0221211314201355
valid_cls_loss = 0.3325942486524582
valid_reg_loss = 1.035195004940033
valid_cen_loss = 0.6543318688869476
{'classes': tensor(1, dtype=torch.int32),
 'map': tensor(0.),
 'map_50': tensor(0.),
 'map_75': tensor(0.),
 'map_large': tensor(0.),
 'map_medium': tensor(0.),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.),
 'mar_1': tensor(0.),
 'mar_10': tensor(0.),
 'mar_100': tensor(0.),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.),
 'mar_medium': tensor(0.),
 'mar_small': tensor(0.)}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:27<00:00,  1.85it/s]


epoch = 1
lr = 5.050309990155808e-05
train_loss = 1.9198297452926636
train_cls_loss = 0.2388559240102768
train_reg_loss = 1.0358405208587647
train_cen_loss = 0.645133306980133


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.62it/s]


epoch = 2
valid_loss = 1.920315182209015
valid_cls_loss = 0.2506189927458763
valid_reg_loss = 1.0346094131469727
valid_cen_loss = 0.6350867629051209
{'classes': tensor(1, dtype=torch.int32),
 'map': tensor(0.),
 'map_50': tensor(0.),
 'map_75': tensor(0.),
 'map_large': tensor(0.),
 'map_medium': tensor(0.),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.),
 'mar_1': tensor(0.),
 'mar_10': tensor(0.),
 'mar_100': tensor(0.),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.),
 'mar_medium': tensor(0.),
 'mar_small': tensor(0.)}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:27<00:00,  1.85it/s]


epoch = 2
lr = 6.35239687047371e-05
train_loss = 1.8325779342651367
train_cls_loss = 0.18071975752711297
train_reg_loss = 1.0350324892997742
train_cen_loss = 0.616825704574585


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.48it/s]


epoch = 3
valid_loss = 1.8712484240531921
valid_cls_loss = 0.22511969804763793
valid_reg_loss = 1.0339999675750733
valid_cen_loss = 0.612128734588623
{'classes': tensor(1, dtype=torch.int32),
 'map': tensor(0.),
 'map_50': tensor(0.),
 'map_75': tensor(0.),
 'map_large': tensor(0.),
 'map_medium': tensor(0.),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.),
 'mar_1': tensor(0.),
 'mar_10': tensor(0.),
 'mar_100': tensor(0.),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.),
 'mar_medium': tensor(0.),
 'mar_small': tensor(0.)}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:26<00:00,  1.86it/s]


epoch = 3
lr = 8.155275332480708e-05
train_loss = 1.7905151915550233
train_cls_loss = 0.15047323271632196
train_reg_loss = 1.036021852493286
train_cen_loss = 0.604020105600357


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.57it/s]


epoch = 4
valid_loss = 1.8499295949935912
valid_cls_loss = 0.20602262988686562
valid_reg_loss = 1.032843542098999
valid_cen_loss = 0.611063402891159
{'classes': tensor(1, dtype=torch.int32),
 'map': tensor(0.),
 'map_50': tensor(0.),
 'map_75': tensor(0.),
 'map_large': tensor(0.),
 'map_medium': tensor(0.),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.),
 'mar_1': tensor(0.),
 'mar_10': tensor(0.),
 'mar_100': tensor(0.),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.),
 'mar_medium': tensor(0.),
 'mar_small': tensor(0.)}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:26<00:00,  1.86it/s]


epoch = 4
lr = 0.0001043916632328725
train_loss = 1.750505928993225
train_cls_loss = 0.12604007415473462
train_reg_loss = 1.0296986556053163
train_cen_loss = 0.5947671914100647


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.49it/s]


epoch = 5
valid_loss = 1.846329128742218
valid_cls_loss = 0.22150112241506575
valid_reg_loss = 1.0132998168468474
valid_cen_loss = 0.6115281820297241
{'classes': tensor(1, dtype=torch.int32),
 'map': tensor(0.),
 'map_50': tensor(0.),
 'map_75': tensor(0.),
 'map_large': tensor(0.),
 'map_medium': tensor(0.),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.),
 'mar_1': tensor(0.),
 'mar_10': tensor(0.),
 'mar_100': tensor(0.),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.),
 'mar_medium': tensor(0.),
 'mar_small': tensor(0.)}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:27<00:00,  1.85it/s]


epoch = 5
lr = 0.00013179013688719224
train_loss = 1.5944072461128236
train_cls_loss = 0.11711918726563454
train_reg_loss = 0.8811774945259094
train_cen_loss = 0.5961105728149414


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.45it/s]


epoch = 6
valid_loss = 1.4950928688049316
valid_cls_loss = 0.17357703447341918
valid_reg_loss = 0.7125472903251648
valid_cen_loss = 0.6089685082435607
{'classes': tensor(1, dtype=torch.int32),
 'map': tensor(0.0041),
 'map_50': tensor(0.0376),
 'map_75': tensor(0.),
 'map_large': tensor(0.0048),
 'map_medium': tensor(0.),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.),
 'mar_1': tensor(0.0090),
 'mar_10': tensor(0.0203),
 'mar_100': tensor(0.0263),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.0310),
 'mar_medium': tensor(0.),
 'mar_small': tensor(0.)}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:26<00:00,  1.86it/s]


epoch = 6
lr = 0.0001634475905984478
train_loss = 1.3139541459083557
train_cls_loss = 0.09564495906233787
train_reg_loss = 0.6324558925628662
train_cen_loss = 0.5858532929420471


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.55it/s]


epoch = 7
valid_loss = 1.3440670251846314
valid_cls_loss = 0.17379260882735253
valid_reg_loss = 0.5676292240619659
valid_cen_loss = 0.602645194530487
{'classes': tensor(1, dtype=torch.int32),
 'map': tensor(0.2407),
 'map_50': tensor(0.7173),
 'map_75': tensor(0.0755),
 'map_large': tensor(0.2793),
 'map_medium': tensor(0.0169),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.),
 'mar_1': tensor(0.1556),
 'mar_10': tensor(0.3263),
 'mar_100': tensor(0.3346),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.3832),
 'mar_medium': tensor(0.0800),
 'mar_small': tensor(0.)}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:27<00:00,  1.85it/s]


epoch = 7
lr = 0.00019901671617892735
train_loss = 1.1946789288520814
train_cls_loss = 0.09079403638839721
train_reg_loss = 0.5218484050035477
train_cen_loss = 0.5820364904403686


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.44it/s]


epoch = 8
valid_loss = 1.3013135194778442
valid_cls_loss = 0.20389912873506547
valid_reg_loss = 0.5009279400110245
valid_cen_loss = 0.5964864492416382
{'classes': tensor(1, dtype=torch.int32),
 'map': tensor(0.2426),
 'map_50': tensor(0.6824),
 'map_75': tensor(0.0845),
 'map_large': tensor(0.2806),
 'map_medium': tensor(0.0492),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.),
 'mar_1': tensor(0.1361),
 'mar_10': tensor(0.3774),
 'mar_100': tensor(0.3835),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.4425),
 'mar_medium': tensor(0.0667),
 'mar_small': tensor(0.)}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:26<00:00,  1.85it/s]


epoch = 8
lr = 0.00023810729119771383
train_loss = 1.1606859064102173
train_cls_loss = 0.09730872012674809
train_reg_loss = 0.47939042031764983
train_cen_loss = 0.5839867722988129


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.40it/s]


epoch = 9
valid_loss = 1.3110824942588806
valid_cls_loss = 0.23785126879811286
valid_reg_loss = 0.47374489307403567
valid_cen_loss = 0.5994863510131836
{'classes': tensor(1, dtype=torch.int32),
 'map': tensor(0.2586),
 'map_50': tensor(0.6904),
 'map_75': tensor(0.1205),
 'map_large': tensor(0.3040),
 'map_medium': tensor(0.),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.),
 'mar_1': tensor(0.1617),
 'mar_10': tensor(0.3331),
 'mar_100': tensor(0.3361),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.3956),
 'mar_medium': tensor(0.),
 'mar_small': tensor(0.)}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:27<00:00,  1.83it/s]


epoch = 9
lr = 0.00028029046004025805
train_loss = 1.076534835100174
train_cls_loss = 0.09240377202630043
train_reg_loss = 0.3995469397306442
train_cen_loss = 0.5845841181278228


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.59it/s]


epoch = 10
valid_loss = 1.122015082836151
valid_cls_loss = 0.13450675159692765
valid_reg_loss = 0.3954452157020569
valid_cen_loss = 0.5920631289482117
{'classes': tensor(1, dtype=torch.int32),
 'map': tensor(0.3304),
 'map_50': tensor(0.7984),
 'map_75': tensor(0.1482),
 'map_large': tensor(0.3835),
 'map_medium': tensor(0.0492),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.),
 'mar_1': tensor(0.1895),
 'mar_10': tensor(0.4038),
 'mar_100': tensor(0.4098),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.4699),
 'mar_medium': tensor(0.0933),
 'mar_small': tensor(0.)}


 38%|█████████████████████████████████████████████▌                                                                          | 19/50 [00:10<00:17,  1.75it/s]


KeyboardInterrupt: 