In [1]:
from torchvision.models import resnet50
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelP6P7
#from torchvision.models.detection.backbone_utils import LastLevelP6P7
import torch.nn as nn
import torch
import torchvision
INF = 100000000
import torch.nn.functional as F

import pytorch_lightning as pl
from dl_toolbox.callbacks import ProgressBar
import gc 

import pytorch_lightning as pl
from dl_toolbox.callbacks import ProgressBar
from dl_toolbox import datamodules
from dl_toolbox import modules
import torchvision.transforms.v2 as v2
from functools import partial

import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from dl_toolbox.callbacks import ProgressBar, Finetuning, Lora, TiffPredsWriter, CalibrationLogger
from functools import partial
import gc


train_tf = v2.Compose(
    [
        v2.RandomCrop(224),
        v2.SanitizeBoundingBoxes(),
        v2.Normalize([0.5]*3, [0.5]*3)
    ]
)

test_tf = v2.Compose(
    [
        v2.CenterCrop(1000),
        v2.Resize(224),
        v2.SanitizeBoundingBoxes(),
        v2.Normalize([0.5]*3, [0.5]*3)
    ]
)

dm = datamodules.xView(
    data_path='/data',
    merge='all',
    train_tf=train_tf,
    test_tf=test_tf,
    batch_tf=None,
    batch_size=2,
    num_workers=0,
    pin_memory=False
)

trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=20,
    limit_train_batches=5,
    limit_val_batches=5,
    callbacks=[ProgressBar()]
)

num_classes = dm.num_classes

  from .autonotebook import tqdm as notebook_tqdm
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [8]:
def _compute_centerness_targets(reg_targets):
    if len(reg_targets) == 0:
        return reg_targets.new_zeros(len(reg_targets))
    left_right = reg_targets[:, [0, 2]]
    top_bottom = reg_targets[:, [1, 3]]
    centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
                (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
    return torch.sqrt(centerness)


def _calculate_reg_targets(xs, ys, bbox_targets):
    """ 
    Very important : what is the format of bbox in targets produced by the dataset ?
    """
    l = xs[:, None] - bbox_targets[:, 0][None] # Lx1 - 1xT -> LxT
    t = ys[:, None] - bbox_targets[:, 1][None]
    r = bbox_targets[:, 2][None] - xs[:, None]
    b = bbox_targets[:, 3][None] - ys[:, None]
    return torch.stack([l, t, r, b], dim=2) # LxTx4


def _apply_distance_constraints(reg_targets, level_distances):
    """
    reg_targets: LxTx4
    level_distances: Lx2
    """
    max_reg_targets, _ = reg_targets.max(dim=2)
    return torch.logical_and(max_reg_targets >= level_distances[:, None, 0], \
                             max_reg_targets <= level_distances[:, None, 1])

def _match_pred_format(cls_targets, reg_targets, locations):
    cls_per_level = []
    reg_per_level = []
    for level in range(len(locations)):
        
        cls_per_level.append(torch.cat([ct[level] for ct in cls_targets],
                                       dim=0))

        reg_per_level.append(torch.cat([rt[level] for rt in reg_targets],
                                       dim=0))
    # reg_per_level is a list of num_levels tensors of size Bxnum_loc_per_levelx4
    return cls_per_level, reg_per_level


def _get_positive_samples(cls_labels, reg_labels, box_cls_preds, box_reg_preds,
                          centerness_preds, num_classes):
    box_cls_flatten = []
    box_regression_flatten = []
    centerness_flatten = []
    labels_flatten = []
    reg_targets_flatten = []
    for l in range(len(cls_labels)):
        box_cls_flatten.append(box_cls_preds[l].permute(0, 2, 3, 1).reshape(
            -1, num_classes))
        box_regression_flatten.append(box_reg_preds[l].permute(0, 2, 3,
                                                               1).reshape(
                                                                   -1, 4))
        labels_flatten.append(cls_labels[l].reshape(-1))
        reg_targets_flatten.append(reg_labels[l].reshape(-1, 4))
        centerness_flatten.append(centerness_preds[l].reshape(-1))

    cls_preds = torch.cat(box_cls_flatten, dim=0)
    cls_targets = torch.cat(labels_flatten, dim=0)
    reg_preds = torch.cat(box_regression_flatten, dim=0)
    reg_targets = torch.cat(reg_targets_flatten, dim=0)
    centerness_preds = torch.cat(centerness_flatten, dim=0)
    pos_inds = torch.nonzero(cls_targets > 0).squeeze(1) # dim #loc in all batches where there is one cls to pred not background
    print(f'{pos_inds.shape = }')
    reg_preds = reg_preds[pos_inds]
    reg_targets = reg_targets[pos_inds]
    centerness_preds = centerness_preds[pos_inds]
    return reg_preds, reg_targets, cls_preds, cls_targets, centerness_preds, pos_inds

class LossEvaluator(nn.Module):

    def __init__(self, locs_info, num_classes):
        super(LossEvaluator, self).__init__()
        locs_per_level, bb_sizes_per_level, num_locs_per_level = locs_info
        self.centerness_loss_func = nn.BCEWithLogitsLoss(reduction="sum")
        # Let's call L the total nb of locs across all feature levels
        self.register_buffer('locations', torch.cat(locs_per_level, dim=0)) # Lx2
        self.register_buffer('bb_sizes', torch.cat(bb_sizes_per_level, dim=0)) # Lx2
        self.num_locs_per_level = num_locs_per_level # list len L
        self.num_classes = num_classes

    def _get_cls_loss(self, cls_preds, cls_targets, total_num_pos):
        nc = cls_preds.shape[1]
        onehot = F.one_hot(cls_targets.long(), nc+1)[:,1:].float()
        cls_loss = torchvision.ops.sigmoid_focal_loss(cls_preds, onehot)
        return cls_loss.sum() / total_num_pos

    def _get_reg_loss(self, reg_preds, reg_targets, centerness_targets):
        reg_preds = reg_preds.reshape(-1, 4)
        reg_targets = reg_targets.reshape(-1, 4)
        reg_losses = torchvision.ops.distance_box_iou_loss(reg_preds, reg_targets, reduction='none')
        sum_centerness_targets = centerness_targets.sum()
        reg_loss = (reg_losses * centerness_targets).sum() / sum_centerness_targets
        return reg_loss

    def _get_centerness_loss(self, centerness_preds, centerness_targets,
                             total_num_pos):
        centerness_loss = self.centerness_loss_func(centerness_preds,
                                                    centerness_targets)
        return centerness_loss / total_num_pos

    def _evaluate_losses(self, reg_preds, cls_preds, centerness_preds,
                         reg_targets, cls_targets, centerness_targets,
                         pos_inds):
        total_num_pos = max(pos_inds.new_tensor([pos_inds.numel()]), 1.0)

        cls_loss = self._get_cls_loss(cls_preds, cls_targets, total_num_pos)

        if pos_inds.numel() > 0:
            reg_loss = self._get_reg_loss(reg_preds, reg_targets,
                                          centerness_targets)
            centerness_loss = self._get_centerness_loss(centerness_preds,
                                                        centerness_targets,
                                                        total_num_pos)
        else:
            reg_loss = reg_preds.sum() # 0 ??
            centerness_loss = centerness_preds.sum() # 0 ??

        return reg_loss, cls_loss, centerness_loss
    
    def _prepare_labels(self, targets_batch):
        """
        inputs:
            targets_batch: list of dict of tv_tensors {'labels':, 'boxes':}
        outputs:
            target bb and cls for each anchor ?
        """
        # nb of locs for bbox in original image size
        # L = sum locs per level x 2 : for each loc in all_locs, the max size of bb authorized
        xs, ys = self.locations[:, 0], self.locations[:, 1] # L & L
        num_locs = sum(self.num_locs_per_level)

        all_reg_targets = []
        all_cls_targets = []
        for targets in targets_batch:
            
            bbox_targets = targets['boxes'] # Tx4
            # bbox targets which format ??? 
            # Code here expects xyxy, but dataset provides xywh so:
            bbox_targets[:, 2] += bbox_targets[:, 0]
            bbox_targets[:, 3] += bbox_targets[:, 1]
            
            cls_targets = targets['labels'] # T
            num_targets = cls_targets.shape[0]
            print(f'{num_targets =}')

            # for each loc in L and each target in T, the reg target
            reg_targets = _calculate_reg_targets(xs, ys, bbox_targets) # LxTx4
            
            # Which locs are contained in which tgt bb
            is_in_boxes = reg_targets.min(dim=2)[0] > 0 # min returns values and indices -> LxT
            
            # BUG ? Now : which pairs (loc, tgt bb) are such that the max value to regress at this loc for this bb 
            # is inside the bounds associated to this loc;
            # Nothing prevents that the pair requires regressing negative vals ??
            # NO BUG because the rest is filtered by is_in_boxes
            fits_to_feature_level = _apply_distance_constraints(
                reg_targets, self.bb_sizes) # LxT
            
            #bbox_areas = _calc_bbox_area(bbox_targets) # T
            bbox_areas = torchvision.ops.box_area(bbox_targets) # compared to above, does not deal with 0dim bb
            
            # area of each target bbox repeated for each loc with inf where the the loc is not 
            # in the target bbox or if the loc is not at the right level for this bbox size
            locations_to_gt_area = bbox_areas[None].repeat(len(self.locations), 1) # LxT
            locations_to_gt_area[is_in_boxes == 0] = INF
            locations_to_gt_area[fits_to_feature_level == 0] = INF

            # for each loc, area and target idx of the target of min area at that loc
            if num_targets>0:
                # Here the goal is to associate one target bbox ONLY to each loc/anchor
                # So, min over T for each loc to find the smallest tgt bb that:
                # - contains the loc
                # - and is of size in the limits associated to the loc
                loc_min_area, loc_min_idxs = locations_to_gt_area.min(dim=1) # val&idx, size L, idx in [0,T-1]
                reg_targets = reg_targets[range(len(self.locations)), loc_min_idxs] # Lx4
                cls_targets = cls_targets[loc_min_idxs] # L
                cls_targets[loc_min_area == INF] = 0 # 0 is no-obj category ? NO !! XVIEW outputs 0 for first class !!
            else:
                cls_targets = cls_targets.new_zeros((num_locs,))
                reg_targets = reg_targets.new_zeros((num_locs,4))

            all_cls_targets.append(
                torch.split(cls_targets, self.num_locs_per_level, dim=0))
            all_reg_targets.append(
                torch.split(reg_targets, self.num_locs_per_level, dim=0))
        # all_cls_targets contains B lists of num levels elem of loc_per_levelsx1
        locations = torch.split(self.locations, self.num_locs_per_level, dim=0)
        return _match_pred_format(all_cls_targets, all_reg_targets, locations)

    def __call__(self, out, targets_batch):
        
        # reg_targets is a list of num_levels tensors of size Bxnum_loc_per_levelx4
        cls_targets, reg_targets = self._prepare_labels(targets_batch)
        
        # box_regression is a list of num_levels tensors of size Bx4xnum_loc_per_level
        box_cls, box_regression, centerness = out
        reg_p, reg_t, cls_p, cls_t, centerness_p, pos_inds = _get_positive_samples(
            cls_targets,
            reg_targets,
            box_cls,
            box_regression,
            centerness,
            self.num_classes
        )
        centerness_t = _compute_centerness_targets(reg_t)
        losses = {}
        reg_loss, cls_loss, centerness_loss = self._evaluate_losses(
            reg_p, cls_p, centerness_p, reg_t, cls_t, centerness_t, pos_inds)
        losses["cls_loss"] = cls_loss
        losses["reg_loss"] = reg_loss
        losses["centerness_loss"] = centerness_loss
        losses["combined_loss"] = cls_loss + reg_loss + centerness_loss
        return losses

In [9]:
import pytorch_lightning as pl
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from post_processor import *

class FCOS(pl.LightningModule):
    def __init__(
        self,
        num_classes,
        network,
        optimizer,
        scheduler,
        pred_thresh,
        tta=None,
        sliding=None,
        *args,
        **kwargs
    ):
        super().__init__()
        self.num_classes = num_classes
        self.network = network
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.map_metric = MeanAveragePrecision(
            box_format='xywh', # make sure your dataset outputs target in xywh format
            backend='faster_coco_eval'
        )
        self.sliding = sliding
        self.pred_thresh = pred_thresh      

        fpn_strides = [8, 16, 32, 64, 128] # Feat size reduction factor for each FPN stage
        # BB size to be detected by FPN stages: P3=red factor 8 -> detects bb of size between -1 & 64
        bb_sizes = [-1, 64, 128, 256, 512, INF] 
        # anchors is a list of num_feat_level elem, where each elem indicates the tensor of 
        # anchors of size Nx2 in the original image corresponding to each location in the feature map at this level
        # For ex, the top-left anchor for the first feature level is centered on (4,4) and aims at
        # detecting objects of which the bbox sides are not further than 64 pixels, so at max of size 128
        anchors, anchor_sizes, num_anchors = self.get_anchors(
            network.feat_sizes,
            fpn_strides,
            bb_sizes
        )
        self.loss = LossEvaluator(
            (anchors, anchor_sizes, num_anchors),
            num_classes
        )
        self.post_processor = FCOSPostProcessor(
            locs_info = (anchors, anchor_sizes, num_anchors),
            pre_nms_thresh=0.3,
            pre_nms_top_n=1000,
            nms_thresh=0.45,
            fpn_post_nms_top_n=50,
            min_size=0,
            num_classes=num_classes
        )
    
    def configure_optimizers(self):
        train_params = list(filter(lambda p: p[1].requires_grad, self.named_parameters()))
        nb_train = sum([int(torch.numel(p[1])) for p in train_params])
        nb_tot = sum([int(torch.numel(p)) for p in self.parameters()])
        print(f"Training {nb_train} params out of {nb_tot}")
        optimizer = self.optimizer(params=[p[1] for p in train_params])
        scheduler = self.scheduler(optimizer)
        return [optimizer], [scheduler]
    
    @classmethod
    def get_anchors(cls, feat_sizes, fpn_strides, bb_sizes):
        """
        arguments:
            feat_sizes: feature maps sizes
            fpn_strides: Feat map size reduction factor for each FPN stage
            bb_sizes: bbox 
        """
        anchors, anchor_sizes, num_anchors = [], [], []
        
        def _anchors_per_level(h, w, s):
            locs_x = [i for i in range(w)]
            locs_y = [i for i in range(h)]
            locs_x = [s / 2 + x * s for x in locs_x]
            locs_y = [s / 2 + y * s for y in locs_y]
            locs = [(y, x) for x in locs_x for y in locs_y]
            return torch.tensor(locs)
        
        for l, (h,w) in enumerate(feat_sizes):
            # first level : l=0, stride=8, h=w=28
            locs = _anchors_per_level(h, w, fpn_strides[l])
            sizes = torch.tensor([bb_sizes[l], bb_sizes[l+1]], dtype=torch.float32)
            sizes = sizes.repeat(len(locs)).view(len(locs), 2)
            anchors.append(locs)
            anchor_sizes.append(sizes)
            num_anchors.append(len(locs))
        return anchors, anchor_sizes, num_anchors

    def forward(self, x):
        return self.network(x)
    
    def post_process(self, out, images):
        predicted_boxes, scores, all_classes = self.post_processor(*out, images.shape[-1])
        preds = [{'boxes': bb, 'scores': s, 'labels': l} for bb,s,l in zip(
            predicted_boxes, scores, all_classes
        )]
        return preds
    
    def training_step(self, batch, batch_idx):
        x, targets, paths = batch["sup"] #targets is a list of dict
        outputs = self.forward(x) # pred_cls, pred_bb, pred_centerness
        losses = self.loss(outputs, targets)
        loss = losses["combined_loss"]
        self.log(f"loss/train", loss.detach().item())
        return loss
        
    def validation_step(self, batch, batch_idx):
        x, targets, paths = batch
        outputs = self.forward(x)
        losses = self.loss(outputs, targets)
        loss = losses["combined_loss"]
        self.log(f"Total loss/val", loss.detach().item())
        preds = self.post_process(outputs, x)
        self.map_metric.update(preds, targets)
        
    def on_validation_epoch_end(self):
        mapmetric = self.map_metric.compute()['map']
        self.log("map/val", mapmetric)
        print("\nMAP: ", mapmetric)
        self.map_metric.reset()

In [10]:
from resnet_fcos import ResnetDet

network = ResnetDet(256, num_classes)
module = FCOS(
    num_classes=num_classes,
    network=network,
    optimizer=partial(torch.optim.Adam, lr=0.001),
    scheduler=partial(torch.optim.lr_scheduler.ConstantLR, factor=1),
    pred_thresh=0.1,
)
 
gc.collect()
torch.cuda.empty_cache()
gc.collect()

trainer.fit(
    module,
    datamodule=dm,
)

loading annotations into memory...
Done (t=2.45s)
creating index...
index created!
loading annotations into memory...


/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /d/pfournie/dl_toolbox/dl_toolbox/à ranger/fcos/lightning_logs/version_7/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Done (t=2.76s)
creating index...
index created!



  | Name           | Type                 | Params
--------------------------------------------------------
0 | network        | ResnetDet            | 29.9 M
1 | map_metric     | MeanAveragePrecision | 0     
2 | loss           | LossEvaluator        | 0     
3 | post_processor | FCOSPostProcessor    | 0     
--------------------------------------------------------
29.9 M    Trainable params
0         Non-trainable params
29.9 M    Total params
119.556   Total estimated model params size (MB)


Training 29888902 params out of 29888902
Epoch 16:  40%|█████████████████████████████████████████▌                                                              | 2/5 [00:43<01:04,  0.05it/s, v_num=7]

/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.



Sanity Checking DataLoader 0:   0%|                                                                                                     | 0/2 [00:00<?, ?it/s]num_targets =84
num_targets =55
pos_inds.shape = torch.Size([306])
Sanity Checking DataLoader 0:  50%|██████████████████████████████████████████████▌                                              | 1/2 [00:00<00:00, 29.87it/s]num_targets =0
num_targets =68
pos_inds.shape = torch.Size([100])
Sanity Checking DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  6.25it/s]
MAP:  tensor(0.0143)
                                                                                                                                                              

/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (5) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 16:   0%|                                                                                                                         | 0/5 [00:00<?, ?it/s]num_targets =0
num_targets =0
pos_inds.shape = torch.Size([0])
Epoch 16:  20%|████████████████████▊                                                                                   | 1/5 [00:00<00:01,  2.17it/s, v_num=7]num_targets =0
num_targets =0
pos_inds.shape = torch.Size([0])
Epoch 16:  40%|█████████████████████████████████████████▌                                                              | 2/5 [00:00<00:01,  2.14it/s, v_num=7]num_targets =0
num_targets =0
pos_inds.shape = torch.Size([0])
Epoch 16:  60%|██████████████████████████████████████████████████████████████▍                                         | 3/5 [00:01<00:00,  2.47it/s, v_num=7]

/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
