In [1]:
from torchvision.models import resnet50
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelP6P7
#from torchvision.models.detection.backbone_utils import LastLevelP6P7
from dl_toolbox.networks.fcos import Head
import torch.nn as nn
import torch
import torchvision
INF = 100000000
import torch.nn.functional as F


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

def _compute_centerness_targets(reg_targets):
    """Compute the FCOS centerness target of every sample.

    Args:
        reg_targets: tensor of shape (N, 4) whose columns are the
            (left, top, right, bottom) distances of a location to its box.

    Returns:
        Tensor of shape (N,) equal to
        ``sqrt(min(l, r) / max(l, r) * min(t, b) / max(t, b))``.
        An empty input yields an empty tensor of zeros.
    """
    num_samples = len(reg_targets)
    if num_samples == 0:
        return reg_targets.new_zeros(num_samples)

    horizontal = reg_targets[:, [0, 2]]  # (left, right)
    vertical = reg_targets[:, [1, 3]]    # (top, bottom)
    horizontal_ratio = horizontal.min(dim=-1)[0] / horizontal.max(dim=-1)[0]
    vertical_ratio = vertical.min(dim=-1)[0] / vertical.max(dim=-1)[0]
    return torch.sqrt(horizontal_ratio * vertical_ratio)


def _calculate_reg_targets(xs, ys, bbox_targets):
    """Distances from every location to the four sides of every target box.

    Args:
        xs: tensor of shape (L,), first coordinate of each location.
        ys: tensor of shape (L,), second coordinate of each location.
        bbox_targets: tensor of shape (T, 4) read as (x0, y0, x1, y1).

    Returns:
        Tensor of shape (L, T, 4) holding (left, top, right, bottom); an entry
        is negative when the location lies outside the box on that side.
    """
    loc_x = xs.unsqueeze(1)  # L x 1
    loc_y = ys.unsqueeze(1)
    # Each box coordinate becomes a 1 x T row so it broadcasts against L x 1.
    x_min, y_min, x_max, y_max = (
        bbox_targets[:, column].unsqueeze(0) for column in range(4)
    )
    distances = (loc_x - x_min, loc_y - y_min, x_max - loc_x, y_max - loc_y)
    return torch.stack(distances, dim=2)


def _apply_distance_constraints(reg_targets, level_distances):
    """Tell which (location, target) pairs fall in the location's size range.

    Args:
        reg_targets: tensor of shape (L, T, 4) of (l, t, r, b) distances.
        level_distances: tensor of shape (L, 2) holding, for each location,
            the inclusive (min, max) bounds on the largest of the 4 distances.

    Returns:
        Boolean tensor of shape (L, T).
    """
    largest_distance = reg_targets.max(dim=2).values  # L x T
    lower_bound = level_distances[:, 0].unsqueeze(1)  # L x 1
    upper_bound = level_distances[:, 1].unsqueeze(1)
    return (largest_distance >= lower_bound) & (largest_distance <= upper_bound)

def _match_pred_format(cls_targets, reg_targets, locations):
    """Regroup per-image targets into per-level targets.

    Args:
        cls_targets: one entry per image, each a sequence indexed by level of
            class-target tensors.
        reg_targets: same layout, with regression-target tensors.
        locations: sequence with one element per level; only its length is used.

    Returns:
        ``(cls_per_level, reg_per_level)``: two lists with one tensor per
        level, obtained by concatenating that level's targets of all images
        along dim 0 (image 0 first).
    """
    def _concat_level(per_image_targets, level):
        return torch.cat(
            [image_targets[level] for image_targets in per_image_targets], dim=0)

    levels = range(len(locations))
    cls_per_level = [_concat_level(cls_targets, level) for level in levels]
    reg_per_level = [_concat_level(reg_targets, level) for level in levels]
    return cls_per_level, reg_per_level


def _get_positive_samples(cls_labels, reg_labels, box_cls_preds, box_reg_preds,
                          centerness_preds, num_classes):
    """Flatten predictions and targets over all levels and pick the positives.

    Args:
        cls_labels: list (one per level) of class-target tensors.
        reg_labels: list (one per level) of regression-target tensors.
        box_cls_preds: list (one per level) of B x num_classes x H x W logits.
        box_reg_preds: list (one per level) of B x 4 x H x W regressions.
        centerness_preds: list (one per level) of centerness predictions.
        num_classes: number of channels of the classification maps.

    Returns:
        ``(reg_preds, reg_targets, cls_preds, cls_targets, centerness_preds,
        pos_inds)``. Regression and centerness entries are restricted to the
        positive samples (class target > 0); the classification entries keep
        every sample. ``pos_inds`` indexes the positives in the flat layout.
    """
    levels = range(len(cls_labels))

    # Channels-last, so each row is one (image, location) sample.
    cls_preds = torch.cat(
        [box_cls_preds[lvl].permute(0, 2, 3, 1).reshape(-1, num_classes)
         for lvl in levels], dim=0)
    cls_targets = torch.cat(
        [cls_labels[lvl].reshape(-1) for lvl in levels], dim=0)
    all_reg_preds = torch.cat(
        [box_reg_preds[lvl].permute(0, 2, 3, 1).reshape(-1, 4)
         for lvl in levels], dim=0)
    all_reg_targets = torch.cat(
        [reg_labels[lvl].reshape(-1, 4) for lvl in levels], dim=0)
    all_centerness = torch.cat(
        [centerness_preds[lvl].reshape(-1) for lvl in levels], dim=0)

    # Samples whose class target is not background (0).
    pos_inds = (cls_targets > 0).nonzero().squeeze(1)

    return (all_reg_preds[pos_inds], all_reg_targets[pos_inds], cls_preds,
            cls_targets, all_centerness[pos_inds], pos_inds)

class LossEvaluator(nn.Module):
    """FCOS training criterion: focal classification, DIoU regression, centerness.

    Args:
        locs_info: tuple ``(locs_per_level, bb_sizes_per_level,
            num_locs_per_level)``: per-level location tensors (n x 2), per-level
            (min, max) regression-size bounds (n x 2) and the number of
            locations of each level.
        num_classes: number of foreground classes. Class target 0 means
            background; foreground labels are expected in ``[1, num_classes]``.
    """

    def __init__(self, locs_info, num_classes):
        super().__init__()
        locs_per_level, bb_sizes_per_level, num_locs_per_level = locs_info
        self.centerness_loss_func = nn.BCEWithLogitsLoss(reduction="sum")
        # Buffers so that they follow the module across devices.
        self.register_buffer('locations', torch.cat(locs_per_level, dim=0))
        self.register_buffer('bb_sizes', torch.cat(bb_sizes_per_level, dim=0))
        self.num_locs_per_level = num_locs_per_level
        self.num_classes = num_classes

    def _get_cls_loss(self, cls_preds, cls_targets, total_num_pos):
        """Sigmoid focal loss over every sample, normalised by the positives."""
        nc = cls_preds.shape[1]
        # Drop the background column: background rows become all-zero targets.
        onehot = F.one_hot(cls_targets.long(), nc + 1)[:, 1:].float()
        cls_loss = torchvision.ops.sigmoid_focal_loss(cls_preds, onehot)
        return cls_loss.sum() / total_num_pos

    @staticmethod
    def _ltrb_to_xyxy(distances):
        """Turn (l, t, r, b) distances into a box around the origin.

        A location at the origin with distances (l, t, r, b) spans the box
        (-l, -t, r, b). IoU and centre distance are translation invariant, so
        comparing prediction and target in this common frame gives the same
        result as decoding both at their real location.
        """
        return torch.cat([-distances[:, :2], distances[:, 2:]], dim=1)

    def _get_reg_loss(self, reg_preds, reg_targets, centerness_targets):
        """Centerness-weighted DIoU loss between predicted and target boxes.

        ``distance_box_iou_loss`` expects (x1, y1, x2, y2) boxes, so the
        (l, t, r, b) distances are converted first; feeding the raw distances
        would, for instance, turn a perfectly centred target into a zero-area
        box.
        """
        pred_boxes = self._ltrb_to_xyxy(reg_preds.reshape(-1, 4))
        target_boxes = self._ltrb_to_xyxy(reg_targets.reshape(-1, 4))
        reg_losses = torchvision.ops.distance_box_iou_loss(
            pred_boxes, target_boxes, reduction='none')
        sum_centerness_targets = centerness_targets.sum()
        return (reg_losses * centerness_targets).sum() / sum_centerness_targets

    def _get_centerness_loss(self, centerness_preds, centerness_targets,
                             total_num_pos):
        """Binary cross-entropy on centerness logits, averaged over positives."""
        centerness_loss = self.centerness_loss_func(centerness_preds,
                                                    centerness_targets)
        return centerness_loss / total_num_pos

    def _evaluate_losses(self, reg_preds, cls_preds, centerness_preds,
                         reg_targets, cls_targets, centerness_targets,
                         pos_inds):
        """Return ``(reg_loss, cls_loss, centerness_loss)`` as scalar tensors."""
        # Plain int (never 0) so that every loss is a 0-dim tensor, whether or
        # not the batch contains positives.
        total_num_pos = max(pos_inds.numel(), 1)

        cls_loss = self._get_cls_loss(cls_preds, cls_targets, total_num_pos)

        if pos_inds.numel() > 0:
            reg_loss = self._get_reg_loss(reg_preds, reg_targets,
                                          centerness_targets)
            centerness_loss = self._get_centerness_loss(centerness_preds,
                                                        centerness_targets,
                                                        total_num_pos)
        else:
            # Sums of empty tensors: zero-valued, but still tied to the
            # predictions in the autograd graph.
            reg_loss = reg_preds.sum()
            centerness_loss = centerness_preds.sum()

        return reg_loss, cls_loss, centerness_loss

    def _prepare_labels(self, targets_batch):
        """Assign a class and an (l, t, r, b) target to every location.

        Args:
            targets_batch: iterable with one dict per image holding ``'boxes'``
                (T x 4, read as x0, y0, x1, y1) and ``'labels'`` (T).

        Returns:
            ``(cls_per_level, reg_per_level)`` as built by
            ``_match_pred_format``: one tensor per level with the targets of
            all images concatenated along dim 0.
        """
        xs, ys = self.locations[:, 0], self.locations[:, 1]  # L and L
        num_locs = sum(self.num_locs_per_level)

        all_reg_targets = []
        all_cls_targets = []
        for targets in targets_batch:
            bbox_targets = targets['boxes']  # T x 4
            cls_targets = targets['labels']  # T
            num_targets = cls_targets.shape[0]

            # For each location and each target, the distances to the 4 sides.
            reg_targets = _calculate_reg_targets(xs, ys, bbox_targets)  # L x T x 4

            # A location is inside a box when its 4 distances are positive.
            is_in_boxes = reg_targets.min(dim=2)[0] > 0  # L x T
            fits_to_feature_level = _apply_distance_constraints(
                reg_targets, self.bb_sizes)  # L x T

            bbox_areas = torchvision.ops.box_area(bbox_targets)  # T

            # Area of each target repeated for each location, replaced by INF
            # where the location is outside the box or at the wrong level.
            locations_to_gt_area = bbox_areas[None].repeat(len(self.locations), 1)
            locations_to_gt_area[~is_in_boxes] = INF
            locations_to_gt_area[~fits_to_feature_level] = INF

            if num_targets > 0:
                # Ambiguous locations are given to the smallest candidate box.
                loc_min_area, loc_min_idxs = locations_to_gt_area.min(dim=1)  # L, L
                loc_idxs = torch.arange(len(self.locations),
                                        device=loc_min_idxs.device)
                reg_targets = reg_targets[loc_idxs, loc_min_idxs]  # L x 4
                cls_targets = cls_targets[loc_min_idxs]  # L
                # No valid candidate at this location: background.
                cls_targets[loc_min_area == INF] = 0
            else:
                cls_targets = cls_targets.new_zeros((num_locs,))
                reg_targets = reg_targets.new_zeros((num_locs, 4))

            all_cls_targets.append(
                torch.split(cls_targets, self.num_locs_per_level, dim=0))
            all_reg_targets.append(
                torch.split(reg_targets, self.num_locs_per_level, dim=0))

        locations = torch.split(self.locations, self.num_locs_per_level, dim=0)
        return _match_pred_format(all_cls_targets, all_reg_targets, locations)

    def forward(self, out, targets_batch):
        """Compute the FCOS losses of a batch.

        Implemented as ``forward`` (not ``__call__``) so that calling the
        module goes through ``nn.Module``'s regular call machinery.

        Args:
            out: ``(box_cls, box_regression, centerness)``, each a list with
                one prediction map per level.
            targets_batch: one target dict per image (see ``_prepare_labels``).

        Returns:
            Dict with ``"cls_loss"``, ``"reg_loss"``, ``"centerness_loss"`` and
            their sum ``"combined_loss"``.
        """
        cls_targets, reg_targets = self._prepare_labels(targets_batch)
        box_cls, box_regression, centerness = out
        reg_p, reg_t, cls_p, cls_t, centerness_p, pos_inds = _get_positive_samples(
            cls_targets,
            reg_targets,
            box_cls,
            box_regression,
            centerness,
            self.num_classes
        )
        centerness_t = _compute_centerness_targets(reg_t)
        reg_loss, cls_loss, centerness_loss = self._evaluate_losses(
            reg_p, cls_p, centerness_p, reg_t, cls_t, centerness_t, pos_inds)
        return {
            "cls_loss": cls_loss,
            "reg_loss": reg_loss,
            "centerness_loss": centerness_loss,
            "combined_loss": cls_loss + reg_loss + centerness_loss,
        }

In [3]:
class FCOSPostProcessor(nn.Module):
    """Turn raw FCOS head outputs into per-image detections.

    Args:
        locs_info: tuple ``(locs_per_level, bb_sizes_per_level,
            num_locs_per_level)``; the size bounds are not used here.
        pre_nms_thresh: minimum class probability for a candidate; also used
            as the final score threshold after NMS.
        pre_nms_top_n: maximum number of candidates kept per image and level.
        nms_thresh: IoU threshold of the non-maximum suppression.
        fpn_post_nms_top_n: maximum number of detections kept per image after
            NMS (no limit when <= 0).
        min_size: boxes with a side smaller than this are discarded.
        num_classes: number of foreground classes.
    """

    def __init__(self, locs_info, pre_nms_thresh, pre_nms_top_n, nms_thresh,
                 fpn_post_nms_top_n, min_size, num_classes):
        super().__init__()
        self.pre_nms_thresh = pre_nms_thresh
        self.pre_nms_top_n = pre_nms_top_n
        self.nms_thresh = nms_thresh
        self.fpn_post_nms_top_n = fpn_post_nms_top_n
        self.min_size = min_size
        self.num_classes = num_classes
        locs_per_level, bb_sizes_per_level, num_locs_per_level = locs_info
        self.register_buffer('locations', torch.cat(locs_per_level, dim=0))
        self.num_locs_per_level = num_locs_per_level

    @staticmethod
    def _as_hw(image_size):
        """Return ``(height, width)`` from a single side length or an (H, W) pair."""
        try:
            height, width = image_size
        except TypeError:
            # A single number: the image is square.
            height = width = image_size
        return (height, width)

    def forward_for_single_feature_map(self, locations, cls_preds, reg_preds,
                                       cness_preds, image_size):
        """Decode the candidates of one pyramid level.

        Args:
            locations: HW x 2 image coordinates of the level's locations.
            cls_preds: B x C x H x W class logits.
            reg_preds: B x 4 x H x W (l, t, r, b) distances.
            cness_preds: B x 1 x H x W centerness logits.
            image_size: side length of a square image, or an (H, W) pair.

        Returns:
            ``(bboxes, scores, cls_labels)``: three lists of B tensors that
            are aligned entry by entry (N x 4, N and N for each image).
        """
        B, C, _, _ = cls_preds.shape
        cls_preds = cls_preds.permute(0, 2, 3, 1).reshape(B, -1, C).sigmoid()  # B x HW x C in [0, 1]
        reg_preds = reg_preds.permute(0, 2, 3, 1).reshape(B, -1, 4)
        cness_preds = cness_preds.permute(0, 2, 3, 1).reshape(B, -1).sigmoid()

        # Candidates are selected on the class probability alone ...
        candidate_inds = cls_preds > self.pre_nms_thresh  # B x HW x C
        pre_nms_top_n = candidate_inds.reshape(B, -1).sum(1)  # B
        pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)

        # ... and only then weighted by the centerness.
        # NOTE(review): open question from the original author - why is the
        # threshold not applied after the multiplication by the centerness?
        cls_preds = cls_preds * cness_preds[:, :, None]  # B x HW x C

        image_hw = self._as_hw(image_size)

        bboxes = []
        cls_labels = []
        scores = []
        for i in range(B):
            # True where the probability of class c at location l passes the threshold.
            per_candidate_inds = candidate_inds[i]  # HW x C
            # N x 2 tensor of (location, class) indices of the candidates.
            per_candidate_nonzeros = per_candidate_inds.nonzero()
            per_box_loc = per_candidate_nonzeros[:, 0]  # N, in [0, HW)
            per_class = per_candidate_nonzeros[:, 1] + 1  # N, in [1, C]

            per_reg_preds = reg_preds[i][per_box_loc]  # N x 4
            per_locations = locations[per_box_loc]  # N x 2

            per_pre_nms_top_n = pre_nms_top_n[i]

            # Centerness-weighted scores of the candidates.
            per_cls_preds = cls_preds[i][per_candidate_inds]  # N

            # Too many candidates: keep the best-scoring ones.
            if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
                per_cls_preds, top_k_indices = per_cls_preds.topk(
                    per_pre_nms_top_n, sorted=False)
                per_class = per_class[top_k_indices]
                per_reg_preds = per_reg_preds[top_k_indices]
                per_locations = per_locations[top_k_indices]

            # Boxes (x0, y0, x1, y1) from distances (l, t, r, b), eq. (1) of the paper.
            per_bboxes = torch.stack([
                per_locations[:, 0] - per_reg_preds[:, 0],
                per_locations[:, 1] - per_reg_preds[:, 1],
                per_locations[:, 0] + per_reg_preds[:, 2],
                per_locations[:, 1] + per_reg_preds[:, 3],
            ], dim=1)
            per_bboxes = torchvision.ops.clip_boxes_to_image(per_bboxes, image_hw)

            # The same selection must be applied to boxes, classes and scores,
            # otherwise they stop lining up as soon as one box is dropped.
            keep = torchvision.ops.remove_small_boxes(per_bboxes, self.min_size)
            bboxes.append(per_bboxes[keep])
            cls_labels.append(per_class[keep])
            scores.append(torch.sqrt(per_cls_preds[keep]))

        return bboxes, scores, cls_labels

    def forward(self, cls_preds, reg_preds, cness_preds, image_size):
        """Decode every level, merge the levels and run NMS per image.

        Args:
            cls_preds, reg_preds, cness_preds: lists with one prediction map
                per pyramid level.
            image_size: side length of a square image, or an (H, W) pair.

        Returns:
            ``(boxes, scores, classes)``: three lists with one tensor per image.
        """
        sampled_boxes = []
        all_scores = []
        all_classes = []
        locations = torch.split(self.locations, self.num_locs_per_level, dim=0)
        for level_locs, level_cls, level_reg, level_cness in zip(
                locations, cls_preds, reg_preds, cness_preds):
            boxes, scores, cls_labels = self.forward_for_single_feature_map(
                level_locs, level_cls, level_reg, level_cness, image_size)
            # Each of the three is a list of B tensors.
            sampled_boxes.append(boxes)
            all_scores.append(scores)
            all_classes.append(cls_labels)

        # Transpose level-major lists into image-major ones, then merge levels.
        all_bboxes = [torch.cat(per_image, dim=0) for per_image in zip(*sampled_boxes)]
        all_scores = [torch.cat(per_image, dim=0) for per_image in zip(*all_scores)]
        all_classes = [torch.cat(per_image, dim=0) for per_image in zip(*all_classes)]

        boxes, scores, classes = self.select_over_all_levels(
            all_bboxes, all_scores, all_classes)

        return boxes, scores, classes

    def select_over_all_levels(self, boxlists, scores, classes):
        """Apply NMS, the per-image detection cap and the score threshold.

        NOTE(review): the NMS is class-agnostic (labels are not passed to it),
        so overlapping boxes of different classes suppress each other -
        confirm that this is intended.
        """
        num_images = len(boxlists)
        all_picked_boxes, all_confidence_scores, all_classes = [], [], []
        for i in range(num_images):
            picked_indices = torchvision.ops.nms(boxlists[i], scores[i], self.nms_thresh)
            picked_boxes = boxlists[i][picked_indices]
            confidence_scores = scores[i][picked_indices]
            picked_classes = classes[i][picked_indices]

            number_of_detections = len(picked_indices)
            if number_of_detections > self.fpn_post_nms_top_n > 0:
                # Score of the fpn_post_nms_top_n-th best detection.
                image_thresh, _ = torch.kthvalue(
                    confidence_scores.cpu(),
                    number_of_detections - self.fpn_post_nms_top_n + 1)
                keep = confidence_scores >= image_thresh.item()

                keep = torch.nonzero(keep).squeeze(1)
                picked_boxes, confidence_scores, picked_classes = picked_boxes[
                    keep], confidence_scores[keep], picked_classes[keep]

            keep = confidence_scores >= self.pre_nms_thresh
            picked_boxes, confidence_scores, picked_classes = picked_boxes[
                keep], confidence_scores[keep], picked_classes[keep]

            all_picked_boxes.append(picked_boxes)
            all_confidence_scores.append(confidence_scores)
            all_classes.append(picked_classes)

        return all_picked_boxes, all_confidence_scores, all_classes

In [4]:
import pytorch_lightning as pl
from torchmetrics.detection.mean_ap import MeanAveragePrecision

class FCOS(pl.LightningModule):
    """Lightning module training and validating an FCOS detector.

    Args:
        num_classes: number of foreground classes.
        network: detector returning ``(box_cls, box_regression, centerness)``
            and exposing ``feat_sizes``, the (H, W) of each pyramid level.
        optimizer: callable building the optimizer from ``params=...``.
        scheduler: callable building the LR scheduler from the optimizer.
        pred_thresh: stored on the module; not used by the code in this class.
        tta: accepted for compatibility, unused here.
        sliding: stored on the module; not used by the code in this class.
        fpn_strides: stride, in input pixels, of each pyramid level. Defaults
            to ``[8, 16, 32, 64, 128]``; it has to match the strides of
            ``network`` for the locations to line up with the feature maps.
        bb_sizes: bounds of the regression sizes handled by each level; level
            ``l`` handles ``[bb_sizes[l], bb_sizes[l + 1]]``. Defaults to
            ``[-1, 64, 128, 256, 512, INF]``.
    """

    def __init__(
        self,
        num_classes,
        network,
        optimizer,
        scheduler,
        pred_thresh,
        tta=None,
        sliding=None,
        *args,
        fpn_strides=None,
        bb_sizes=None,
        **kwargs
    ):
        super().__init__()
        self.num_classes = num_classes
        self.network = network
        self.optimizer = optimizer
        self.scheduler = scheduler
        # Predictions come out of the post-processor as (x0, y0, x1, y1) and
        # the loss reads the target boxes the same way, so the metric has to
        # use the 'xyxy' convention as well.
        self.map_metric = MeanAveragePrecision(box_format='xyxy', backend='faster_coco_eval')
        self.sliding = sliding
        self.pred_thresh = pred_thresh

        if fpn_strides is None:
            fpn_strides = [8, 16, 32, 64, 128]
        if bb_sizes is None:
            bb_sizes = [-1, 64, 128, 256, 512, INF]
        # For each pyramid level: the image coordinates of its locations, the
        # size bounds of these locations and their number.
        anchors = self.get_anchors(network.feat_sizes, fpn_strides, bb_sizes)
        self.loss = LossEvaluator(
            anchors,
            num_classes
        )
        self.post_processor = FCOSPostProcessor(
            locs_info=anchors,
            pre_nms_thresh=0.3,
            pre_nms_top_n=1000,
            nms_thresh=0.45,
            fpn_post_nms_top_n=50,
            min_size=0,
            num_classes=num_classes
        )

    def configure_optimizers(self):
        """Build the optimizer over the trainable parameters, and the scheduler."""
        train_params = [p for p in self.parameters() if p.requires_grad]
        nb_train = sum(int(torch.numel(p)) for p in train_params)
        nb_tot = sum(int(torch.numel(p)) for p in self.parameters())
        print(f"Training {nb_train} params out of {nb_tot}")
        optimizer = self.optimizer(params=train_params)
        scheduler = self.scheduler(optimizer)
        return [optimizer], [scheduler]

    @classmethod
    def get_anchors(cls, feat_sizes, fpn_strides, bb_sizes):
        """Build the locations of every pyramid level.

        Args:
            feat_sizes: (H, W) of each pyramid level.
            fpn_strides: stride of each level, in input pixels.
            bb_sizes: size bounds; level ``l`` uses entries ``l`` and ``l + 1``.

        Returns:
            ``(anchors, anchor_sizes, num_anchors)``: per level, the HW x 2
            tensor of (x, y) cell centres, the HW x 2 tensor repeating the
            level's (min, max) size bounds, and the number of locations.
        """
        anchors, anchor_sizes, num_anchors = [], [], []

        def _locations_per_level(h, w, s):
            centers_x = [s / 2 + col * s for col in range(w)]
            centers_y = [s / 2 + row * s for row in range(h)]
            # Row-major (x, y) pairs: location k sits at row k // w and column
            # k % w, the order in which B x C x H x W maps are flattened after
            # permute(0, 2, 3, 1). Same values as before for square maps.
            return torch.tensor([(x, y) for y in centers_y for x in centers_x])

        for l, (h, w) in enumerate(feat_sizes):
            locs = _locations_per_level(h, w, fpn_strides[l])
            sizes = torch.tensor([bb_sizes[l], bb_sizes[l + 1]], dtype=torch.float32)
            sizes = sizes.repeat(len(locs)).view(len(locs), 2)
            anchors.append(locs)
            anchor_sizes.append(sizes)
            num_anchors.append(len(locs))
        return anchors, anchor_sizes, num_anchors

    def forward(self, x):
        """Run the wrapped network."""
        return self.network(x)

    def post_process(self, out, images):
        """Convert network outputs into a list of prediction dicts.

        The last dimension of ``images`` is passed as the image size, so the
        images are assumed to be square.
        """
        predicted_boxes, scores, all_classes = self.post_processor(*out, images.shape[-1])
        preds = [{'boxes': bb, 'scores': s, 'labels': l} for bb, s, l in zip(
            predicted_boxes, scores, all_classes
        )]
        return preds

    def training_step(self, batch, batch_idx):
        """Compute and log the combined loss of the ``"sup"`` batch."""
        x, targets, paths = batch["sup"]
        outputs = self.forward(x)
        losses = self.loss(outputs, targets)
        loss = losses["combined_loss"]
        # Explicit batch size: Lightning cannot infer it from this batch layout.
        self.log("loss/train", loss.detach().item(), batch_size=x.shape[0])
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and feed the mAP metric."""
        x, targets, paths = batch
        outputs = self.forward(x)
        losses = self.loss(outputs, targets)
        loss = losses["combined_loss"]
        self.log("Total loss/val", loss.detach().item(), batch_size=x.shape[0])
        preds = self.post_process(outputs, x)
        self.map_metric.update(preds, targets)

    def on_validation_epoch_end(self):
        """Log, print and reset the mAP accumulated over the epoch."""
        mapmetric = self.map_metric.compute()['map']
        self.log("map/val", mapmetric)
        print("\nMAP: ", mapmetric)
        self.map_metric.reset()

In [2]:
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from torchvision.ops.feature_pyramid_network import ExtraFPNBlock
from torchvision.ops.misc import Conv2dNormActivation
from typing import Callable, Dict, List, Optional, Tuple
from torch import nn, Tensor
import timm
from torchvision.ops.feature_pyramid_network import LastLevelP6P7
from dl_toolbox.networks.fcos import Head
import torch.nn as nn
from torchvision.models.feature_extraction import get_graph_node_names
from einops import rearrange
import torch

class LayerNorm2d(nn.LayerNorm):
    """Layer normalisation over the channel axis of NCHW tensors.

    Args:
        num_channels: size of the channel axis.
        eps: value added to the variance for numerical stability.
        affine: whether to learn a per-channel scale and shift.
    """

    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # layer_norm normalises the trailing axis: go channels-last and back.
        channels_last = x.permute(0, 2, 3, 1)
        normalised = F.layer_norm(
            channels_last, self.normalized_shape, self.weight, self.bias, self.eps)
        return normalised.permute(0, 3, 1, 2)

class SimpleFeaturePyramidNetwork(nn.Module):
    """
    Simple FPN built on top of a single feature map. This is based on
    `"Exploring Plain Vision Transformer Backbones for Object Detection" <https://arxiv.org/abs/2203.16527>`_.

    Unlike a regular FPN, the Simple FPN takes one feature map and derives
    four scales from it: x4 (two stride-2 transposed convolutions), x2 (one
    transposed convolution), x1 (unchanged) and x0.5 (max pooling). Each scale
    is then projected to ``out_channels`` by a 1x1 and a 3x3 convolution, both
    followed by the normalisation layer.

    Args:
        in_channels (int): number of channels of the input feature map
        out_channels (int): number of channels of every output feature map
        extra_blocks (ExtraFPNBlock or None): if provided, it is called with
            the list of output maps, ``[x]`` and the output names, and returns
            a new list of feature maps and their corresponding names
        norm_layer (callable, optional): builds a normalisation module for
            NCHW tensors from a number of channels. Default: ``LayerNorm2d``

    Raises:
        TypeError: if ``extra_blocks`` is neither ``None`` nor an ``ExtraFPNBlock``.

    Examples::

        >>> m = SimpleFeaturePyramidNetwork(10, 5)
        >>> # get some dummy data
        >>> x = torch.rand(1, 10, 16, 16)
        >>> # compute the Simple FPN on top of x
        >>> output = m(x)
        >>> print([(k, v.shape) for k, v in output.items()])
        >>> # returns
        >>>   [('0', torch.Size([1, 5, 64, 64])),
        >>>    ('1', torch.Size([1, 5, 32, 32])),
        >>>    ('2', torch.Size([1, 5, 16, 16])),
        >>>    ('3', torch.Size([1, 5, 8, 8]))]

    """

    _version = 2

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()
        if norm_layer is None:
            # The x4 branch calls norm_layer directly, so None cannot be used
            # as is; fall back to the channel-wise layer norm of this file.
            norm_layer = LayerNorm2d

        self.blocks = nn.ModuleList()
        for block_index in range(0, 4):
            layers = []
            current_in_channels = in_channels
            if block_index == 0:
                # x4 upsampling
                layers.extend([
                    nn.ConvTranspose2d(
                        in_channels,
                        in_channels // 2,
                        kernel_size=2,
                        stride=2,
                    ),
                    norm_layer(in_channels // 2),
                    nn.GELU(),
                    nn.ConvTranspose2d(
                        in_channels // 2,
                        in_channels // 4,
                        kernel_size=2,
                        stride=2,
                    ),
                ])
                current_in_channels = in_channels // 4
            elif block_index == 1:
                # x2 upsampling
                layers.append(
                    nn.ConvTranspose2d(
                        in_channels,
                        in_channels // 2,
                        kernel_size=2,
                        stride=2,
                    ),
                )
                current_in_channels = in_channels // 2
            elif block_index == 2:
                # nothing to do for this scale
                pass
            elif block_index == 3:
                # x0.5 downsampling
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))

            # Projection to out_channels, shared by every scale.
            layers.extend([
                Conv2dNormActivation(
                    current_in_channels,
                    out_channels,
                    kernel_size=1,
                    padding=0,
                    norm_layer=norm_layer,
                    activation_layer=None
                ),
                Conv2dNormActivation(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    norm_layer=norm_layer,
                    activation_layer=None
                )
            ])
            self.blocks.append(nn.Sequential(*layers))

        if extra_blocks is not None:
            if not isinstance(extra_blocks, ExtraFPNBlock):
                raise TypeError(f"extra_blocks should be of type ExtraFPNBlock not {type(extra_blocks)}")
        self.extra_blocks = extra_blocks

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """
        Computes the Simple FPN for a feature map.

        Args:
            x (Tensor): input feature map.

        Returns:
            results (OrderedDict[str, Tensor]): feature maps after the FPN
                layers, keyed ``"0"``, ``"1"``, ... and ordered from highest
                resolution first (plus whatever ``extra_blocks`` appends).
        """
        results = [block(x) for block in self.blocks]
        names = [f"{i}" for i in range(len(self.blocks))]

        if self.extra_blocks is not None:
            results, names = self.extra_blocks(results, [x], names)

        # make it back an OrderedDict
        out = OrderedDict([(k, v) for k, v in zip(names, results)])

        return out

class ViTDet(nn.Module):
    """Detector made of a timm ViT backbone, a Simple FPN and a detection head.

    Args:
        out_channels: number of channels of the pyramid feature maps.
        num_classes: number of classes, forwarded to the head.

    Attributes:
        feat_sizes: spatial size (H, W) of each pyramid level, measured once
            at construction on a 224x224 input.
    """

    def __init__(self, out_channels, num_classes):
        super().__init__()
        self.backbone = timm.create_model('samvit_base_patch16.sa1b', pretrained=True)
        self.sfpn = SimpleFeaturePyramidNetwork(
            in_channels=768,
            out_channels=out_channels,
            norm_layer=LayerNorm2d,
        )

        # One dry run to record the spatial size of every pyramid level.
        probe = torch.randn(2, 3, 224, 224)
        with torch.no_grad():
            pyramid = self.forward_feat(probe)
        self.feat_sizes = [level.shape[2:] for level in pyramid.values()]

        # NOTE(review): the head is built for 6 levels while the pyramid above
        # yields 4 maps - confirm this is what Head expects.
        self.head = Head(out_channels, num_classes, n_feat_levels=6)

    def forward_feat(self, x):
        """Return the pyramid (an ordered dict of feature maps) computed from ``x``."""
        # indices=1: presumably keeps only the last backbone block's feature
        # map - TODO confirm against the timm documentation.
        backbone_maps = self.backbone.forward_intermediates(
            x, indices=1, norm=False, intermediates_only=True)
        return self.sfpn(backbone_maps[0])

    def forward(self, x):
        """Return ``(box_cls, box_regression, centerness)`` as produced by the head."""
        pyramid = list(self.forward_feat(x).values())
        box_cls, box_regression, centerness = self.head(pyramid)
        return box_cls, box_regression, centerness
    
# Detector used by the smoke-test cells below: 256 pyramid channels, 4 classes.
# Building it creates the timm backbone with pretrained=True.
vitdet = ViTDet(256, 4)

In [3]:
# Smoke test: run the backbone + Simple FPN on a random batch of two 224x224
# images and print the shape of every pyramid level.
x = torch.rand(2, 3, 224, 224)
features = vitdet.forward_feat(x)
print(f'feat shapes: {[f.shape for f in features.values()] = }')

feat shapes: [f.shape for f in features.values()] = [torch.Size([2, 256, 56, 56]), torch.Size([2, 256, 28, 28]), torch.Size([2, 256, 14, 14]), torch.Size([2, 256, 7, 7])]


In [37]:
# Smoke test: run the full detector on a random batch and print the shape of
# the classification map of every pyramid level.
x = torch.rand(2, 3, 224, 224)
box_cls, box_regression, centerness = vitdet.forward(x)
print(f'box pred shapes: {[f.shape for f in box_cls] = }')

box pred shapes: [f.shape for f in box_cls] = [torch.Size([2, 4, 56, 56]), torch.Size([2, 4, 28, 28]), torch.Size([2, 4, 14, 14]), torch.Size([2, 4, 7, 7])]


In [38]:
import pytorch_lightning as pl
from dl_toolbox.callbacks import ProgressBar
import gc 

import pytorch_lightning as pl
from dl_toolbox.callbacks import ProgressBar
from dl_toolbox import datamodules
from dl_toolbox import modules
import torchvision.transforms.v2 as v2
from functools import partial

import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from dl_toolbox.callbacks import ProgressBar, Finetuning, Lora, TiffPredsWriter, CalibrationLogger
from functools import partial
import gc


# Training images: random 224x224 crops, normalised to roughly [-1, 1].
train_tf = v2.Compose(
    [
        v2.RandomCrop(224),
        v2.SanitizeBoundingBoxes(),
        v2.Normalize([0.5]*3, [0.5]*3)
    ]
)

# Validation images: 1000x1000 centre crop resized to the training resolution.
test_tf = v2.Compose(
    [
        v2.CenterCrop(1000),
        v2.Resize(224),
        v2.SanitizeBoundingBoxes(),
        v2.Normalize([0.5]*3, [0.5]*3)
    ]
)

dm = datamodules.xView(
    data_path='/data',
    merge='all',
    train_tf=train_tf,
    test_tf=test_tf,
    batch_tf=None,
    batch_size=2,
    num_workers=0,
    pin_memory=False
)

# Short debugging run: 5 training and 5 validation batches per epoch.
trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=20,
    limit_train_batches=5,
    limit_val_batches=5,
    callbacks=[ProgressBar()]
)

num_classes = dm.num_classes
network = ViTDet(num_classes=num_classes, out_channels=256)

# ViTDet measures its pyramid on 224x224 inputs and gets 56/28/14/7 maps,
# i.e. strides 4/8/16/32, whereas FCOS assumes 8/16/32/64/128 internally and
# would place most locations outside the image.
# NOTE(review): this only takes effect if FCOS reads the `fpn_strides` keyword
# rather than letting it fall into **kwargs - confirm.
fpn_strides = [224 // feat_h for feat_h, _ in network.feat_sizes]

module = FCOS(
    num_classes=num_classes,
    network=network,
    optimizer=partial(torch.optim.Adam, lr=0.001),
    scheduler=partial(torch.optim.lr_scheduler.ConstantLR, factor=1),
    pred_thresh=0.1,
    fpn_strides=fpn_strides,
)

# Free whatever earlier cells left on the GPU before training.
gc.collect()
torch.cuda.empty_cache()
gc.collect()

trainer.fit(
    module,
    datamodule=dm,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


loading annotations into memory...
Done (t=5.83s)
creating index...
index created!
loading annotations into memory...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Done (t=3.24s)
creating index...
index created!



  | Name           | Type                 | Params
--------------------------------------------------------
0 | network        | ViTDet               | 97.7 M
1 | map_metric     | MeanAveragePrecision | 0     
2 | loss           | LossEvaluator        | 0     
3 | post_processor | FCOSPostProcessor    | 0     
--------------------------------------------------------
97.7 M    Trainable params
0         Non-trainable params
97.7 M    Total params
390.972   Total estimated model params size (MB)


Training 97743111 params out of 97743111
Sanity Checking: |                                                                                                                        | 0/? [00:00<?, ?it/s]

/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0:   0%|                                                                                                       | 0/2 [00:00<?, ?it/s]

  return F.conv2d(input, weight, bias, self.stride,
/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 2. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Sanity Checking DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:21<00:00,  0.09it/s]
MAP:  tensor(0.)
                                                                                                                                                                

/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (5) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0:   0%|                                                                                                                            | 0/5 [00:00<?, ?it/s]

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:07<00:00,  0.65it/s, v_num=64]
Validation: |                                                                                                                             | 0/? [00:00<?, ?it/s][A
MAP:  tensor(0.)
Epoch 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:10<00:00,  0.50it/s, v_num=64]

/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
