In [1]:
import torchvision
from pathlib import Path
data_path = Path("/data/coco")

ds = torchvision.datasets.CocoDetection(
    data_path/"train2017",
    data_path/"annotations/instances_train2017.json", 
    None
)

  from .autonotebook import tqdm as notebook_tqdm


loading annotations into memory...
Done (t=12.07s)
creating index...
index created!


In [2]:
sample = ds[0]
img, target = sample
print(f"{type(img) = }\n{type(target) = }\n{type(target[0]) = }\n{target[0].keys() = }")
print(target[0]['id'])

dataset = torchvision.datasets.wrap_dataset_for_transforms_v2(ds, target_keys=("boxes", "labels", "masks"))

sample = dataset[0]
img, target = sample
print(f"{type(img) = }\n{type(target) = }\n{target.keys() = }")
print(f"{type(target['boxes']) = }\n{type(target['labels']) = }\n{type(target['masks']) = }")
print(target['boxes'])

type(img) = <class 'PIL.Image.Image'>
type(target) = <class 'list'>
type(target[0]) = <class 'dict'>
target[0].keys() = dict_keys(['segmentation', 'area', 'iscrowd', 'image_id', 'bbox', 'category_id', 'id'])
1038967
type(img) = <class 'PIL.Image.Image'>
type(target) = <class 'dict'>
target.keys() = dict_keys(['boxes', 'masks', 'labels'])
type(target['boxes']) = <class 'torchvision.tv_tensors._bounding_boxes.BoundingBoxes'>
type(target['labels']) = <class 'torch.Tensor'>
type(target['masks']) = <class 'torchvision.tv_tensors._mask.Mask'>
BoundingBoxes([[  1.0800, 187.6900, 612.6700, 473.5300],
               [311.7300,   4.3100, 631.0100, 232.9900],
               [249.6000, 229.2700, 565.8400, 474.3500],
               [  0.0000,  13.5100, 434.4800, 388.6300],
               [376.2000,  40.3600, 451.7500,  86.8900],
               [465.7800,  38.9700, 523.8500,  85.6400],
               [385.7000,  73.6600, 469.7200, 144.1700],
               [364.0500,   2.4900, 458.8100,  73.5600]], 

In [3]:
from torch.utils.data import DataLoader, RandomSampler
from pathlib import Path
import pandas as pd
import numpy as np
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities import CombinedLoader
from functools import partial
from torchvision.transforms import v2
import torch

class Coco(LightningDataModule):
    
    def __init__(
        self,
        data_path,
        train_tf,
        test_tf,
        batch_size_s,
        steps_per_epoch,
        num_workers,
        pin_memory,
        *args,
        **kwargs
    ):
        super().__init__()
        self.data_path = Path(data_path)/"coco"
        self.train_tf = train_tf
        self.test_tf = test_tf
        self.batch_size_s = batch_size_s
        self.steps_per_epoch = steps_per_epoch
        self.num_workers = num_workers
        self.pin_memory = pin_memory
    
    def setup(self, stage=None):
        train_s_set = torchvision.datasets.CocoDetection(
            self.data_path/"train2017",
            self.data_path/"annotations/instances_train2017.json", 
            self.train_tf)
        self.train_s_set = torchvision.datasets.wrap_dataset_for_transforms_v2(
            train_s_set, target_keys=("boxes", "labels"))
        val_set = torchvision.datasets.CocoDetection(
            self.data_path/"val2017",
            self.data_path/"annotations/instances_val2017.json", 
            self.test_tf)
        self.val_set = torchvision.datasets.wrap_dataset_for_transforms_v2(
            val_set, target_keys=("boxes", "labels"))
    
    @staticmethod
    def _collate(batch):
        images_b, targets_b = list(zip(*batch))
        boxes = [t['boxes'] for t in targets_b]
        labels = [t['labels'] for t in targets_b]
        # don't stack bb because each batch elem may not have the same nb of bb
        return torch.stack(images_b), boxes, labels
                
    def _dataloader(self, dataset):
        return partial(
            DataLoader,
            dataset=dataset,
            collate_fn=self._collate,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory
        )
                       
    def train_dataloader(self):
        return self._dataloader(self.train_s_set)(
            sampler=RandomSampler(
                self.train_s_set,
                replacement=True,
                num_samples=self.steps_per_epoch*self.batch_size_s
            ),
            drop_last=True,
            batch_size=self.batch_size_s
        )
    
    def val_dataloader(self):
        return self._dataloader(self.val_set)(
            shuffle=False,
            drop_last=False,
            batch_size=self.batch_size_s
        )
    
train_tf = v2.Compose([
    v2.ToImage(),
    v2.Resize(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_tf = v2.Compose([
    v2.ToImage(),
    v2.Resize(size=(224, 224), antialias=True),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

dm = Coco(
    data_path='/data',
    train_tf=train_tf,
    test_tf=test_tf,
    batch_size_s=4,
    steps_per_epoch=10,
    num_workers=4,
    pin_memory=True,
)
#dm.prepare_data()
#dm.setup()

In [4]:
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from torchvision.ops.feature_pyramid_network import ExtraFPNBlock
from torchvision.ops.misc import Conv2dNormActivation
from typing import Callable, Dict, List, Optional, Tuple
from torch import nn, Tensor
import timm
from torchvision.ops.feature_pyramid_network import LastLevelP6P7
from dl_toolbox.networks.fcos import Head
import torch.nn as nn
from torchvision.models.feature_extraction import get_graph_node_names
from einops import rearrange

class LayerNorm2d(nn.LayerNorm):
    """ LayerNorm for channels of '2D' spatial NCHW tensors """
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(
            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)

class SimpleFeaturePyramidNetwork(nn.Module):
    """
    Module that adds a Simple FPN from on top of a set of feature maps. This is based on
    `"Exploring Plain Vision Transformer Backbones for Object Detection" <https://arxiv.org/abs/2203.16527>`_.

    Unlike regular FPN, Simple FPN expects a single feature map,
    on which the Simple FPN will be added.

    Args:
        in_channels (int): number of channels for the input feature map that
            is passed to the module
        out_channels (int): number of channels of the Simple FPN representation
        extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
            be performed. It is expected to take the fpn features, the original
            features and the names of the original features as input, and returns
            a new list of feature maps and their corresponding names
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: LayerNorm

    Examples::

        >>> m = torchvision.ops.SimpleFeaturePyramidNetwork(10, 5)
        >>> # get some dummy data
        >>> x = torch.rand(1, 10, 64, 64)
        >>> # compute the Simple FPN on top of x
        >>> output = m(x)
        >>> print([(k, v.shape) for k, v in output.items()])
        >>> # returns
        >>>   [('feat0', torch.Size([1, 5, 64, 64])),
        >>>    ('feat2', torch.Size([1, 5, 16, 16])),
        >>>    ('feat3', torch.Size([1, 5, 8, 8]))]

    """

    _version = 2

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()
        self.blocks = nn.ModuleList()
        if in_channels <= 0:
            raise ValueError("in_channels <= 0 is currently not supported")

        for block_index in range(1,4):
            layers = []

            current_in_channels = in_channels
            if block_index == 0:
                layers.extend([
                    nn.ConvTranspose2d(
                        in_channels,
                        in_channels // 2,
                        kernel_size=2,
                        stride=2,
                    ),
                    norm_layer(in_channels // 2),
                    nn.GELU(),
                    nn.ConvTranspose2d(
                        in_channels // 2,
                        in_channels // 4,
                        kernel_size=2,
                        stride=2,
                    ),
                ])
                current_in_channels = in_channels // 4
            elif block_index == 1:
                layers.append(
                    nn.ConvTranspose2d(
                        in_channels,
                        in_channels // 2,
                        kernel_size=2,
                        stride=2,
                    ),
                )
                current_in_channels = in_channels // 2
            elif block_index == 2:
                # nothing to do for this scale
                pass
            elif block_index == 3:
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))

            layers.extend([
                Conv2dNormActivation(
                    current_in_channels,
                    out_channels,
                    kernel_size=1,
                    padding=0,
                    norm_layer=norm_layer,
                    activation_layer=None
                ),
                Conv2dNormActivation(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    norm_layer=norm_layer,
                    activation_layer=None
                )
            ])

            self.blocks.append(nn.Sequential(*layers))

        if extra_blocks is not None:
            if not isinstance(extra_blocks, ExtraFPNBlock):
                raise TypeError(f"extra_blocks should be of type ExtraFPNBlock not {type(extra_blocks)}")
        self.extra_blocks = extra_blocks

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """
        Computes the Simple FPN for a feature map.

        Args:
            x (Tensor): input feature map.

        Returns:
            results (list[Tensor]): feature maps after FPN layers.
                They are ordered from highest resolution first.
        """
        results = [block(x) for block in self.blocks]
        names = [f"{i}" for i in range(len(self.blocks))]

        if self.extra_blocks is not None:
            results, names = self.extra_blocks(results, [x], names)

        # make it back an OrderedDict
        out = OrderedDict([(k, v) for k, v in zip(names, results)])

        return out
    
x = torch.rand(1, 10, 64, 64)
m = SimpleFeaturePyramidNetwork(10, 5, 
        extra_blocks=LastLevelP6P7(5,5),
        norm_layer=LayerNorm2d)
output = m(x)

In [5]:
#from dl_toolbox.networks import FCOS
#model = FCOS(num_classes=20)


#print(get_graph_node_names(resnet50())[0])

class ViTSimpleFPN(nn.Module):
    
    def __init__(self, num_classes=19, out_channels=256):
        super(ViTSimpleFPN, self).__init__()
        self.backbone = timm.create_model('vit_small_patch14_dinov2', pretrained=True, dynamic_img_size=True)
        self.sfpn = SimpleFeaturePyramidNetwork(
            in_channels=384,
            out_channels=256,
            extra_blocks=LastLevelP6P7(256,256),
            norm_layer=LayerNorm2d
        )
        
        inp = torch.randn(2, 3, 224, 224)
        with torch.no_grad():
            out = self.forward_feat(inp)
        self.feat_sizes = [o.shape[2:] for o in out.values()]
        self.head = Head(out_channels, num_classes, n_feat_levels=6)
    
    def forward_feat(self, x):
        H, W = x.size(2), x.size(3)
        GS = H // self.backbone.patch_embed.patch_size[0]
        x = self.backbone.forward_features(x)
        x = x[:,self.backbone.num_prefix_tokens:,...]
        x = rearrange(x, "b (h w) c -> b c h w", h=GS)
        x = self.sfpn(x)
        return x 
    
    def forward(self, x):
        feat_dict = self.forward_feat(x)
        features = list(feat_dict.values())
        box_cls, box_regression, centerness = self.head(features)
        # box_reg: lists of n_feat_level tensors BxHW(level)x4
        # why not tensor Bxsum_level(HW)x4 ?
        return features, box_cls, box_regression, centerness
        
    
network = ViTSimpleFPN(num_classes=19)
print(network.feat_sizes)

[torch.Size([32, 32]), torch.Size([16, 16]), torch.Size([8, 8]), torch.Size([4, 4]), torch.Size([2, 2])]


In [6]:
import torch.nn as nn
import torchvision

INF = 100000000
MAXIMUM_DISTANCES_PER_LEVEL = [-1, 64, 128, 256, 512, INF]

def _match_reg_distances_shape(MAXIMUM_DISTANCES_PER_LEVEL, num_locs_per_level):
    level_reg_distances = []
    for m in range(1, len(MAXIMUM_DISTANCES_PER_LEVEL)):
        level_distances = torch.tensor([
            MAXIMUM_DISTANCES_PER_LEVEL[m - 1], MAXIMUM_DISTANCES_PER_LEVEL[m]
        ],
                                       dtype=torch.float32)
        locs_per_level = num_locs_per_level[m - 1]
        level_distances = level_distances.repeat(locs_per_level).view(
            locs_per_level, 2)
        level_reg_distances.append(level_distances)
    # return tensor of size sum of locs_per_level x 2
    return torch.cat(level_reg_distances, dim=0)

def _compute_centerness_targets(reg_targets):
    if len(reg_targets) == 0:
        return reg_targets.new_zeros(len(reg_targets))
    left_right = reg_targets[:, [0, 2]]
    top_bottom = reg_targets[:, [1, 3]]
    centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
                (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
    return torch.sqrt(centerness)


def _calculate_reg_targets(xs, ys, bbox_targets):
    l = xs[:, None] - bbox_targets[:, 0][None] # Lx1 - 1xT -> LxT
    t = ys[:, None] - bbox_targets[:, 1][None]
    r = bbox_targets[:, 2][None] - xs[:, None]
    b = bbox_targets[:, 3][None] - ys[:, None]
    return torch.stack([l, t, r, b], dim=2) # LxTx4


def _apply_distance_constraints(reg_targets, level_distances):
    max_reg_targets, _ = reg_targets.max(dim=2)
    return torch.logical_and(max_reg_targets >= level_distances[:, None, 0], \
                             max_reg_targets <= level_distances[:, None, 1])


def _prepare_labels(locations, targets_batch):
    device = targets_batch[0].device
    # nb of locs for bbox in original image size
    num_locs_per_level = [len(l) for l in locations]
    # L = sum locs per level x 2 : for each loc in all_locs, the max size of bb authorized
    level_distances = _match_reg_distances_shape(MAXIMUM_DISTANCES_PER_LEVEL,
                                                 num_locs_per_level).to(device)
    all_locations = torch.cat(locations, dim=0).to(device) # Lx2
    xs, ys = all_locations[:, 0], all_locations[:, 1] # L & L

    all_reg_targets = []
    all_cls_targets = []
    for targets in targets_batch:
        bbox_targets = targets[:, :4] # Tx4
        cls_targets = targets[:, 4] # T
        
        # for each loc in L and each target in T, the reg target
        reg_targets = _calculate_reg_targets(xs, ys, bbox_targets) # LxTx4

        is_in_boxes = reg_targets.min(dim=2)[0] > 0 # min returns values and indices -> LxT

        fits_to_feature_level = _apply_distance_constraints(
            reg_targets, level_distances).to(device) # LxT

        #bbox_areas = _calc_bbox_area(bbox_targets) # T
        bbox_areas = torchvision.ops.box_area(bbox_targets) # compared to above, does not deal with 0dim bb
        
        # area of each target bbox repeated for each loc with inf where the the loc is not 
        # in the target bbox or if the loc is not at the right level for this bbox size
        locations_to_gt_area = bbox_areas[None].repeat(len(all_locations), 1) # LxT
        locations_to_gt_area[is_in_boxes == 0] = INF
        locations_to_gt_area[fits_to_feature_level == 0] = INF
        
        # for each loc, area and target idx of the target of min area at that loc
        loc_min_area, loc_mind_idxs = locations_to_gt_area.min(dim=1) # val&idx, size L, idx in [0,T-1]

        reg_targets = reg_targets[range(len(all_locations)), loc_mind_idxs] # Lx4

        cls_targets = cls_targets[loc_mind_idxs] # L
        cls_targets[loc_min_area == INF] = 0
        
        all_cls_targets.append(
            torch.split(cls_targets, num_locs_per_level, dim=0))
        all_reg_targets.append(
            torch.split(reg_targets, num_locs_per_level, dim=0))
    # all_cls_targets contains B lists of num levels elem of loc_per_levelsx1
    return _match_pred_format(all_cls_targets, all_reg_targets, locations)


def _match_pred_format(cls_targets, reg_targets, locations):
    cls_per_level = []
    reg_per_level = []
    for level in range(len(locations)):
        cls_per_level.append(torch.cat([ct[level] for ct in cls_targets],
                                       dim=0))

        reg_per_level.append(torch.cat([rt[level] for rt in reg_targets],
                                       dim=0))
    # reg_per_level is a list of num_levels tensors of size Bxnum_loc_per_levelx4
    return cls_per_level, reg_per_level


def _get_positive_samples(cls_labels, reg_labels, box_cls_preds, box_reg_preds,
                          centerness_preds, num_classes):
    box_cls_flatten = []
    box_regression_flatten = []
    centerness_flatten = []
    labels_flatten = []
    reg_targets_flatten = []
    for l in range(len(cls_labels)):
        box_cls_flatten.append(box_cls_preds[l].permute(0, 2, 3, 1).reshape(
            -1, num_classes))
        box_regression_flatten.append(box_reg_preds[l].permute(0, 2, 3,
                                                               1).reshape(
                                                                   -1, 4))
        labels_flatten.append(cls_labels[l].reshape(-1))
        reg_targets_flatten.append(reg_labels[l].reshape(-1, 4))
        centerness_flatten.append(centerness_preds[l].reshape(-1))

    cls_preds = torch.cat(box_cls_flatten, dim=0)
    cls_targets = torch.cat(labels_flatten, dim=0)
    reg_preds = torch.cat(box_regression_flatten, dim=0)
    reg_targets = torch.cat(reg_targets_flatten, dim=0)
    centerness_preds = torch.cat(centerness_flatten, dim=0)
    pos_inds = torch.nonzero(cls_targets > 0).squeeze(1) # dim #loc in all batches where there is one cls to pred not background

    reg_preds = reg_preds[pos_inds]
    reg_targets = reg_targets[pos_inds]
    centerness_preds = centerness_preds[pos_inds]

    return reg_preds, reg_targets, cls_preds, cls_targets, centerness_preds, pos_inds

In [7]:
import torch.nn.functional as F

class LossEvaluator:

    def __init__(self):
        self.centerness_loss_func = nn.BCEWithLogitsLoss(reduction="sum")

    def _get_cls_loss(self, cls_preds, cls_targets, total_num_pos):
        nc = cls_preds.shape[1]
        onehot = F.one_hot(cls_targets.long(), nc+1)[:,1:].float()
        cls_loss = torchvision.ops.sigmoid_focal_loss(cls_preds, onehot)
        return cls_loss.sum() / total_num_pos

    def _get_reg_loss(self, reg_preds, reg_targets, centerness_targets):
        reg_preds = reg_preds.reshape(-1, 4)
        reg_targets = reg_targets.reshape(-1, 4)
        reg_losses = torchvision.ops.distance_box_iou_loss(reg_preds, reg_targets, reduction='none')
        sum_centerness_targets = centerness_targets.sum()
        reg_loss = (reg_losses * centerness_targets).sum() / sum_centerness_targets
        return reg_loss

    def _get_centerness_loss(self, centerness_preds, centerness_targets,
                             total_num_pos):
        centerness_loss = self.centerness_loss_func(centerness_preds,
                                                    centerness_targets)
        return centerness_loss / total_num_pos

    def _evaluate_losses(self, reg_preds, cls_preds, centerness_preds,
                         reg_targets, cls_targets, centerness_targets,
                         pos_inds):
        total_num_pos = max(pos_inds.new_tensor([pos_inds.numel()]), 1.0)

        cls_loss = self._get_cls_loss(cls_preds, cls_targets, total_num_pos)

        if pos_inds.numel() > 0:
            reg_loss = self._get_reg_loss(reg_preds, reg_targets,
                                          centerness_targets)
            centerness_loss = self._get_centerness_loss(centerness_preds,
                                                        centerness_targets,
                                                        total_num_pos)
        else:
            reg_loss = reg_preds.sum() # 0 ??
            centerness_loss = centerness_preds.sum() # 0 ??

        return reg_loss, cls_loss, centerness_loss

    def __call__(self, locations, preds, targets_batch, num_classes):
        # reg_targets is a list of num_levels tensors of size Bxnum_loc_per_levelx4
        cls_targets, reg_targets = _prepare_labels(locations, targets_batch)

        cls_preds, reg_preds, centerness_preds = preds

        reg_preds, reg_targets, cls_preds, cls_targets, centerness_preds, pos_inds = _get_positive_samples(
            cls_targets, reg_targets, cls_preds, reg_preds, centerness_preds,
            num_classes)

        centerness_targets = _compute_centerness_targets(reg_targets)

        reg_loss, cls_loss, centerness_loss = self._evaluate_losses(
            reg_preds, cls_preds, centerness_preds, reg_targets, cls_targets,
            centerness_targets, pos_inds)

        return cls_loss, reg_loss, centerness_loss

In [8]:
import pytorch_lightning as pl
from torchmetrics.detection.mean_ap import MeanAveragePrecision

class FCOSPostProcessor(torch.nn.Module):

    def __init__(self, pre_nms_thresh, pre_nms_top_n, nms_thresh,
                 fpn_post_nms_top_n, min_size, num_classes):
        super(FCOSPostProcessor, self).__init__()
        self.pre_nms_thresh = pre_nms_thresh
        self.pre_nms_top_n = pre_nms_top_n
        self.nms_thresh = nms_thresh
        self.fpn_post_nms_top_n = fpn_post_nms_top_n
        self.min_size = min_size
        self.num_classes = num_classes

    def forward_for_single_feature_map(self, locations, cls_preds, reg_preds,
                                       cness_preds, image_size):
        B, C, _, _ = cls_preds.shape

        cls_preds = cls_preds.permute(0, 2, 3, 1).reshape(B, -1, C).sigmoid() # BxHWxC in [0,1]
        reg_preds = reg_preds.permute(0, 2, 3, 1).reshape(B, -1, 4)
        cness_preds = cness_preds.permute(0, 2, 3, 1).reshape(B, -1).sigmoid()

        candidate_inds = cls_preds > self.pre_nms_thresh # BxHWxC
        pre_nms_top_n = candidate_inds.reshape(B, -1).sum(1) # B
        pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)

        cls_preds = cls_preds * cness_preds[:, :, None] # BxHWxC
        
        # Conversion en liste de bbox,scores,cls par image du batch
        # POURQUOI le filtre cls_preds > nms_thresh arrive pas après la mul par cness_preds ?
        bboxes = []
        cls_labels = []
        scores = []
        for i in range(B):
            # Tensor with true where score for loc l and class c > pre_nms_thresh
            per_candidate_inds = candidate_inds[i] # HWxC
            # tenseur de taille Lx2 avec les indices des elem de cls_preds où > nms_thresh
            per_candidate_nonzeros = per_candidate_inds.nonzero() 
            # L : positions dans [0,HW] des elem dont cls_preds(c) > nms_thresh 
            per_box_loc = per_candidate_nonzeros[:, 0]
            # L : classe dans [1, C] des elem dont cls_preds(h,w) > nms_thresh
            per_class = per_candidate_nonzeros[:, 1] + 1

            per_reg_preds = reg_preds[i] # HWx4
            # liste des bb des elem dont cls_preds(c) > nms_thresh 
            per_reg_preds = per_reg_preds[per_box_loc] # Lx4
            per_locations = locations[per_box_loc] # Lx2

            per_pre_nms_top_n = pre_nms_top_n[i]
            
            per_cls_preds = cls_preds[i] # HWxC
            # tenseur de taille L avec les elem de cls_preds*centerness tels que cls_preds > nms_thresh
            per_cls_preds = per_cls_preds[per_candidate_inds] 
            # si y a plus de per_prenms_topn qui passe nms_thresh (si L est trop longue)
            if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
                per_cls_preds, top_k_indices = per_cls_preds.topk(
                    per_pre_nms_top_n, sorted=False)
                per_class = per_class[top_k_indices]
                per_reg_preds = per_reg_preds[top_k_indices]
                per_locations = per_locations[top_k_indices]
            
            # Rewrites bbox (x0,y0,x1,y1) from reg targets (l,t,r,b) following eq (1) in paper
            per_bboxes = torch.stack([
                per_locations[:, 0] - per_reg_preds[:, 0],
                per_locations[:, 1] - per_reg_preds[:, 1],
                per_locations[:, 0] + per_reg_preds[:, 2],
                per_locations[:, 1] + per_reg_preds[:, 3],
            ],
                                     dim=1)
            per_bboxes = torchvision.ops.clip_boxes_to_image(per_bboxes, (image_size, image_size))
            #detections = _clip_to_image(detections, (image_size, image_size))
            per_bboxes = per_bboxes[torchvision.ops.remove_small_boxes(per_bboxes, self.min_size)]
            #detections = remove_small_boxes(detections, self.min_size)
            bboxes.append(per_bboxes)
            cls_labels.append(per_class)
            scores.append(torch.sqrt(per_cls_preds))
        
        #bboxes is a list of B tensors of size Lx4 (potentially filtered with pre_nms_threshold)
        return bboxes, scores, cls_labels

    def forward(self, locations, cls_preds, reg_preds, cness_preds, image_size):
        # loc: list of n_feat_level tensors of size HW(level)
        # reg_preds: list of n_feat_level tensors BxHW(level)x4
        
        # list of n_feat_level lists of B tensors of size Lx4
        sampled_boxes = []
        all_scores = []
        all_classes = []
        for l, o, b, c in list(zip(locations, cls_preds, reg_preds,
                                   cness_preds)):
            boxes, scores, cls_labels = self.forward_for_single_feature_map(
                l, o, b, c, image_size)
            # boxes : list of B tensors Lx4
            sampled_boxes.append(boxes)
            all_scores.append(scores)
            all_classes.append(cls_labels)
        
        # list of B lists of n_feat_level bbox preds
        all_bboxes = list(zip(*sampled_boxes))
        all_scores = list(zip(*all_scores))
        all_classes = list(zip(*all_classes))
    
        # list of B tensors with all feature level bbox preds grouped
        all_bboxes = [torch.cat(bboxes, dim=0) for bboxes in all_bboxes]
        all_scores = [torch.cat(scores, dim=0) for scores in all_scores]
        all_classes = [torch.cat(classes, dim=0) for classes in all_classes]
        boxes, scores, classes = self.select_over_all_levels(
            all_bboxes, all_scores, all_classes)

        return boxes, scores, classes

    def select_over_all_levels(self, boxlists, scores, classes):
        num_images = len(boxlists)
        all_picked_boxes, all_confidence_scores, all_classes = [], [], []
        for i in range(num_images):
            picked_indices = torchvision.ops.nms(boxlists[i], scores[i], self.nms_thresh)
            picked_boxes = boxlists[i][picked_indices]
            confidence_scores = scores[i][picked_indices]
            picked_classes = classes[i][picked_indices]

            number_of_detections = len(picked_indices)
            if number_of_detections > self.fpn_post_nms_top_n > 0:
                image_thresh, _ = torch.kthvalue(
                    confidence_scores.cpu(),
                    number_of_detections - self.fpn_post_nms_top_n + 1)
                keep = confidence_scores >= image_thresh.item()

                keep = torch.nonzero(keep).squeeze(1)
                picked_boxes, confidence_scores, picked_classes = picked_boxes[
                    keep], confidence_scores[keep], picked_classes[keep]

            keep = confidence_scores >= self.pre_nms_thresh
            picked_boxes, confidence_scores, picked_classes = picked_boxes[
                keep], confidence_scores[keep], picked_classes[keep]

            all_picked_boxes.append(picked_boxes)
            all_confidence_scores.append(confidence_scores)
            all_classes.append(picked_classes)
        
        # all_picked_boxes : list of B tensors with all feature level bbox preds filtered by nms
        return all_picked_boxes, all_confidence_scores, all_classes

import schedulefree

class FCOS(pl.LightningModule):
    def __init__(
        self,
        network,
        num_classes,
        *args,
        **kwargs
    ):
        super().__init__()
        self.le = LossEvaluator()
        self.post_processor = FCOSPostProcessor(
            pre_nms_thresh=0.3,
            pre_nms_top_n=1000,
            nms_thresh=0.45,
            fpn_post_nms_top_n=50,
            min_size=0,
            num_classes=num_classes)
        self.fpn_strides = [8, 16, 32, 64, 128]
        self.feat_sizes = network.feat_sizes
        self.num_classes = num_classes
        self.network = network
        self.map_metric = MeanAveragePrecision()
        # locations is a list of num_feat_level elem, where each elem indicates the tensor of 
        # locations in the original image corresponding to each location in the feature map at this level
        self.locations = self._compute_locations()
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=1e-3,
            betas=(0.9, 0.999),
            weight_decay=5e-2,
            eps=1e-8,
        )
        #opt = schedulefree.AdamWScheduleFree(self.parameters(), lr=0.0025)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=1e-3,
            steps_per_epoch=10,
            epochs=20
        )
        #return opt
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step"
            },
        }
    
    def _compute_locations(self):
        locations = []
        
        def _locations_per_level(h, w, s):
            locs_x = [i for i in range(w)]
            locs_y = [i for i in range(h)]
            locs_x = [s / 2 + x * s for x in locs_x]
            locs_y = [s / 2 + y * s for y in locs_y]
            locs = [(y, x) for x in locs_x for y in locs_y]
            return torch.tensor(locs)
        
        for level, (h,w) in enumerate(self.feat_sizes):
            locs = _locations_per_level(h, w, self.fpn_strides[level])
            locations.append(locs)
        return locations
    
    def forward(self, images, targets_batch=None):
        features, box_cls, box_regression, centerness = self.network(images)
        locations = [l.to(features[0].device) for l in self.locations]
        image_size = images.shape[-1]
        outputs = {}
        predicted_boxes, scores, all_classes = self.post_processor(
            locations, box_cls, box_regression, centerness, image_size)
        
        
        if targets_batch != None:
            cls_loss, reg_loss, centerness_loss = self.le(
                locations, (box_cls, box_regression, centerness),
                targets_batch,
                num_classes=self.num_classes)
            outputs["cls_loss"] = cls_loss
            outputs["reg_loss"] = reg_loss
            outputs["centerness_loss"] = centerness_loss
            outputs["combined_loss"] = cls_loss + reg_loss + centerness_loss



        outputs["predicted_boxes"] = predicted_boxes
        outputs["scores"] = scores
        outputs["pred_classes"] = all_classes
        return outputs
    
    def training_step(self, batch, batch_idx):
        x, bboxes, labels = batch
        y = [torch.cat([bb, l[:,None]], dim=1) for bb, l in zip(bboxes, labels)]
        results = self.forward(x, y)
        loss = results["combined_loss"]
        self.log(f"loss/train", loss.detach().item())
        self.log(f"cls_loss/train", results["cls_loss"].detach().item())
        self.log(f"reg_loss/train", results["reg_loss"].detach().item())
        self.log(f"centerness_loss/train", results["centerness_loss"].detach().item())
        return loss
        
    def validation_step(self, batch, batch_idx):
        x, bboxes, labels = batch
        y = [torch.cat([bb, l[:,None]], dim=1) for bb, l in zip(bboxes, labels)]
        results = self.forward(x, y)
        loss = results["combined_loss"]
        preds = [{'boxes': bb, 'scores': s, 'labels': l} for bb,s,l in zip(
            results["predicted_boxes"], results["scores"], results["pred_classes"]
        )]
        target_bb = [t[:, :4] for t in y]
        target_l = [t[:, 4] for t in y]
        targets = [{'boxes': bb, 'labels': l} for bb,l in zip(target_bb, target_l)]
        self.map_metric.update(preds, targets)
        self.log(f"loss/val", loss.detach().item())
        self.log(f"cls_loss/val", results["cls_loss"].detach().item())
        self.log(f"reg_loss/val", results["reg_loss"].detach().item())
        self.log(f"centerness_loss/val", results["centerness_loss"].detach().item())
        
    def on_validation_epoch_end(self):
        mapmetric = self.map_metric.compute()['map']
        self.log("map/val", mapmetric)
        print("\nMAP: ", mapmetric)
        self.map_metric.reset()

In [9]:
import pytorch_lightning as pl
from dl_toolbox.callbacks import ProgressBar
import gc

trainer = pl.Trainer(
    accelerator='cpu',
    devices=1,
    max_epochs=20,
    limit_train_batches=5,
    limit_val_batches=5,
    callbacks=[ProgressBar()]
)

module = FCOS(
    network,
    num_classes=19
)

gc.collect()
torch.cuda.empty_cache()
gc.collect()


trainer.fit(
    module,
    datamodule=dm,
)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


loading annotations into memory...
Done (t=10.01s)
creating index...
index created!
loading annotations into memory...



  | Name           | Type                 | Params
--------------------------------------------------------
0 | post_processor | FCOSPostProcessor    | 0     
1 | network        | ViTSimpleFPN         | 28.0 M
2 | map_metric     | MeanAveragePrecision | 0     
--------------------------------------------------------
28.0 M    Trainable params
0         Non-trainable params
28.0 M    Total params
111.870   Total estimated model params size (MB)


Done (t=0.35s)
creating index...
index created!
Sanity Checking DataLoader 0:   0%|                                                                                                       | 0/2 [00:00<?, ?it/s]

RuntimeError: Class values must be smaller than num_classes.