In [28]:
import timm
import torch
import torchvision
from timm.layers import resample_abs_pos_embed     
import torch.nn as nn
from torchvision.ops import box_convert, generalized_box_iou
from scipy.optimize import linear_sum_assignment
import torch.nn.functional as F
import pytorch_lightning as pl
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision import tv_tensors
from torchvision.tv_tensors import BoundingBoxFormat
from dl_toolbox.utils import *


class SetCriterion(nn.Module):
    """Set-prediction loss for DETR/YOLOS-style detectors.

    Each image's predictions are matched one-to-one to its ground-truth boxes
    with the Hungarian algorithm (``hungarian_matching``). The returned loss is
    the unweighted sum of:

    * ``loss_ce``: cross-entropy over ``num_classes + 1`` classes, where every
      unmatched prediction targets the extra "no-object" class (index
      ``num_classes``), whose weight is ``eos_coef``;
    * ``loss_bbox``: L1 distance between matched boxes;
    * ``loss_giou``: ``1 - GIoU`` between matched boxes.

    Both box losses are divided by the number of ground-truth boxes in the
    batch (clamped to at least 1).
    """

    def __init__(self, num_classes, eos_coef):
        """Create the criterion.

        Parameters:
            num_classes: number of object categories, omitting the special no-object category.
            eos_coef: relative classification weight applied to the no-object category.
        """
        super().__init__()
        self.num_classes = num_classes
        self.eos_coef = eos_coef
        # CE class weights: 1 for every real class, eos_coef for the trailing no-object class.
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer('empty_weight', empty_weight)

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Weighted cross-entropy classification loss.

        Matched predictions target their ground-truth label; every other
        prediction targets the no-object class ``num_classes``.
        ``num_boxes`` and ``log`` are accepted for signature symmetry with
        ``loss_boxes`` but are not used.

        Returns:
            ``{'loss_ce': scalar tensor}``.
        """
        src_logits = outputs['pred_logits']
        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        # Default every query to "no object", then overwrite the matched ones.
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o
        # cross_entropy expects the class dimension second: [B, num_classes + 1, num_queries].
        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        return {'loss_ce': loss_ce}

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """L1 regression loss and GIoU loss between matched predicted and target boxes.

        Boxes (``outputs['pred_boxes']`` and ``targets[i]['boxes']``) are in
        torchvision's ``'xywh'`` format (top-left x, top-left y, width, height),
        as required by the ``box_convert`` calls below, and are expected to be
        normalized by the image size (targets are normalized in ``norm_targets``).

        Returns:
            ``{'loss_bbox': tensor, 'loss_giou': tensor}``, both summed and divided by ``num_boxes``.
        """
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
        # generalized_box_iou returns the full pairwise matrix; its diagonal pairs
        # each prediction with the target it was matched to.
        loss_giou = 1 - torch.diag(generalized_box_iou(
            box_convert(src_boxes, 'xywh', 'xyxy'),
            box_convert(target_boxes, 'xywh', 'xyxy')))
        return {
            'loss_bbox': loss_bbox.sum() / num_boxes,
            'loss_giou': loss_giou.sum() / num_boxes,
        }

    def _get_src_permutation_idx(self, indices):
        """Turn per-image matched prediction indices into (batch_idx, query_idx) for advanced indexing."""
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def norm_targets(self, targets):
        """Return new target dicts whose boxes are normalized and in xywh format.

        Uses the ``to_xyxy`` / ``norm`` / ``to_xywh`` helpers star-imported from
        ``dl_toolbox.utils`` (their exact semantics are not visible here).
        Only the ``'labels'`` and ``'boxes'`` keys are kept.
        """
        return [{'labels': t['labels'], 'boxes': to_xywh(norm(to_xyxy(t['boxes'])))} for t in targets]

    @staticmethod
    def _num_boxes(targets, device):
        """Total number of target boxes in the batch as a float tensor of shape [1], clamped to >= 1.

        The clamp keeps the box losses finite (instead of 0/0 = NaN) when no
        image in the batch has a ground-truth box.
        """
        num_boxes = sum(len(t["labels"]) for t in targets)
        return torch.as_tensor([max(num_boxes, 1)], dtype=torch.float, device=device)

    def forward(self, outputs, targets):
        """Compute the total set loss.

        Parameters:
            outputs: dict with 'pred_logits' [B, num_queries, num_classes + 1] and
                'pred_boxes' [B, num_queries, 4].
            targets: list of B dicts with 'labels' [n_i] and 'boxes' [n_i, 4].

        Returns:
            Single-element tensor holding loss_ce + loss_bbox + loss_giou.
        """
        targets = self.norm_targets(targets)
        matches = self.hungarian_matching(outputs, targets)
        # Local (single-process) normalizer; no cross-device reduction is done.
        dev = next(iter(outputs.values())).device
        num_boxes = self._num_boxes(targets, dev)
        losses = {}
        losses.update(self.loss_labels(outputs, targets, matches, num_boxes))
        losses.update(self.loss_boxes(outputs, targets, matches, num_boxes))
        return sum(losses.values())

    @staticmethod
    def _assign(C, sizes):
        """Solve one linear assignment problem per image.

        Params:
            C: cost tensor [B, num_queries, total_num_targets] (CPU).
            sizes: number of targets of each image; sums to C.shape[-1].

        Returns:
            List of B (pred_indices, target_indices) int64 tensor pairs.
        """
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

    @torch.no_grad()
    def hungarian_matching(self, outputs, targets):
        """Match predictions to targets independently for each image.

        Params:
            outputs: dict with
                 "pred_logits": Tensor [B, num_queries, num_classes + 1] with the class logits
                 "pred_boxes": Tensor [B, num_queries, 4] with the predicted boxes ('xywh')
            targets: list (len(targets) == B) of dicts, each with
                 "labels": Tensor [num_target_boxes] of class labels
                 "boxes": Tensor [num_target_boxes, 4] of target boxes ('xywh')

        Returns:
            A list of size B, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element: len(index_i) == len(index_j) == min(num_queries, num_target_boxes).
            Images (or whole batches) without target boxes yield empty index tensors.
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # Flatten to compute all cost matrices in one batch.
        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [B * num_queries, num_classes + 1]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [B * num_queries, 4]

        # Concatenate the targets of every image.
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_bbox = torch.cat([v["boxes"] for v in targets])

        # Classification cost: approximate the NLL by 1 - proba[target class];
        # the constant 1 does not change the matching, so it is omitted.
        cost_class = -out_prob[:, tgt_ids]

        # L1 cost between boxes.
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # GIoU cost between boxes.
        cost_giou = -generalized_box_iou(
            box_convert(out_bbox, 'xywh', 'xyxy'),
            box_convert(tgt_bbox, 'xywh', 'xyxy'))

        # Explicit last dim (instead of -1): a batch with no target box at all
        # must reshape to [B, num_queries, 0] rather than raise.
        C = cost_bbox + cost_class + cost_giou
        C = C.view(bs, num_queries, tgt_ids.numel()).cpu()

        sizes = [len(v["boxes"]) for v in targets]
        return self._assign(C, sizes)
    
class Yolos(pl.LightningModule):
    """YOLOS-style detector built on a timm Vision Transformer.

    ``det_token_num`` learnable detection tokens, each with its own learnable
    position embedding, are inserted between the backbone's class token and the
    patch tokens. After the transformer, the detection-token outputs feed two
    MLP heads: ``class_embed`` (``num_classes + 1`` logits, the last one being
    the no-object class expected by ``SetCriterion``) and ``bbox_embed``
    (4 values squashed by a sigmoid). Each detection token yields one prediction.
    """
    def __init__(
        self,
        num_classes,
        det_token_num,
        backbone,
        optimizer,
        scheduler,
        pred_thresh,
        tta=None,
        sliding=None,
        *args,
        **kwargs
    ):
        """
        Parameters:
            num_classes: number of object categories (no-object class excluded).
            det_token_num: number of detection tokens, i.e. predictions per forward input.
            backbone: timm model name; created with pretrained weights and dynamic_img_size=True.
            optimizer: callable called as ``optimizer(params=...)`` returning an optimizer.
            scheduler: callable called as ``scheduler(optimizer)`` returning an LR scheduler.
            pred_thresh: score threshold used by ``filter_preds``.
            tta: accepted but not used by this class.
            sliding: optional tiler (e.g. ``SlidingDet``) used in validation and prediction.
            *args, **kwargs: accepted but not used.
        """
        super().__init__()
        self.backbone = timm.create_model(backbone, pretrained=True, dynamic_img_size=True)
        hdim = self.backbone.embed_dim 
        self.det_token_num = det_token_num
        self.add_det_tokens_to_backbone()
        self.class_embed = torchvision.ops.MLP(hdim, [hdim,hdim,num_classes+1])
        self.bbox_embed = torchvision.ops.MLP(hdim, [hdim,hdim,4])
        # 0.5 is eos_coef: the cross-entropy weight of the no-object class.
        self.loss = SetCriterion(num_classes, 0.5)
        self.num_classes = num_classes
        self.optimizer = optimizer
        self.scheduler = scheduler
        # box_format='xywh' must match the box format produced by post_process.
        self.map_metric = MeanAveragePrecision(box_format='xywh', backend='faster_coco_eval')
        self.sliding = sliding
        self.pred_thresh = pred_thresh
        
    def add_det_tokens_to_backbone(self):
        """Create the detection tokens and their position embeddings, and register them as prefix tokens."""
        # trunc_normal_ fills the Parameter in place and returns it, so the
        # attributes below hold (and register) the nn.Parameter objects.
        det_token = nn.Parameter(torch.zeros(1, self.det_token_num, self.backbone.embed_dim))
        self.det_token = torch.nn.init.trunc_normal_(det_token, std=.02)
        det_pos_embed = nn.Parameter(torch.zeros(1, self.det_token_num, self.backbone.embed_dim))
        self.det_pos_embed = torch.nn.init.trunc_normal_(det_pos_embed, std=.02)
        # Counted as prefix tokens so that resample_abs_pos_embed (in
        # raw_logits_and_bboxs) keeps cls + det position embeddings out of the
        # spatial resampling.
        self.backbone.num_prefix_tokens += self.det_token_num
    
    def configure_optimizers(self):
        """Build the optimizer and scheduler over parameters with requires_grad=True only."""
        train_params = list(filter(lambda p: p[1].requires_grad, self.named_parameters()))
        nb_train = sum([int(torch.numel(p[1])) for p in train_params])
        nb_tot = sum([int(torch.numel(p)) for p in self.parameters()])
        print(f"Training {nb_train} params out of {nb_tot}")
        optimizer = self.optimizer(params=[p[1] for p in train_params])
        scheduler = self.scheduler(optimizer)
        return [optimizer], [scheduler]
    
    def raw_logits_and_bboxs(self, x):
        """Run the backbone with detection tokens inserted and apply the prediction heads.

        Parameters:
            x: image batch, presumably [B, C, H, W] — TODO confirm.

        Returns:
            dict with 'pred_logits' [B, det_token_num, num_classes + 1] and
            'pred_boxes' [B, det_token_num, 4] with values in (0, 1) (sigmoid).
        """
        x = self.backbone.patch_embed(x)
        # Token order is [cls, det tokens, patches]; build the matching position embedding.
        # Assumes the backbone has a class token whose position embedding is at index 0.
        cls_pos_embed = self.backbone.pos_embed[:, 0, :][:,None] # size 1x1xembed_dim
        patch_pos_embed = self.backbone.pos_embed[:, 1:, :] # 1xnum_patchxembed_dim
        pos_embed = torch.cat((cls_pos_embed, self.det_pos_embed, patch_pos_embed), dim=1)
        if self.backbone.dynamic_img_size:
            # patch_embed output is unpacked as a [B, H, W, C] grid here: resample the
            # patch position embeddings to this grid, then flatten the tokens.
            B, H, W, C = x.shape
            pos_embed = resample_abs_pos_embed(
                pos_embed,
                (H, W),
                num_prefix_tokens=self.backbone.num_prefix_tokens,
            )
            x = x.view(B, -1, C)
        # Assumes the backbone has a class token.
        cls_token = self.backbone.cls_token.expand(x.shape[0], -1, -1) 
        det_token = self.det_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_token, det_token, x], dim=1)
        x += pos_embed
        x = self.backbone.pos_drop(x)
        x = self.backbone.patch_drop(x)
        x = self.backbone.norm_pre(x)
        x = self.backbone.blocks(x)
        x = self.backbone.norm(x)
        # Keep only the detection-token outputs (index 0 is the class token).
        x = x[:,1:1+self.det_token_num,...]
        outputs_class = self.class_embed(x)
        outputs_coord = self.bbox_embed(x).sigmoid()
        out = {'pred_logits': outputs_class, 'pred_boxes': outputs_coord}
        return out
    
    def forward(self, x, sliding=None):
        """Predict on ``x``; with a ``sliding`` tiler, predict on each tile (without tiling) and merge the results."""
        if sliding is not None:
            auxs = [self.forward(aux) for aux in sliding(x)]
            return sliding.merge(auxs)
        else:
            return self.raw_logits_and_bboxs(x)
    
    @torch.no_grad()
    def post_process(self, outputs, images):
        """Turn raw outputs into one prediction dict per image, for MeanAveragePrecision.

        Box conversion relies on the ``to_xyxy`` / ``unnorm`` / ``to_xywh``
        helpers (not visible here); the result is presumably xywh in pixel units
        of ``images`` — TODO confirm. Each prediction's score/label is the max
        softmax probability over the real classes (no-object column excluded).
        When ``self.sliding`` is set, NMS with IoU 0.8 is applied per image,
        presumably to drop duplicates from overlapping tiles.

        Returns:
            list of dicts with 'scores', 'labels' and 'boxes'.
        """
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
        b,c,h,w = images.shape
        xywh = [tv_tensors.BoundingBoxes(
            bb,
            format=tv_tensors.BoundingBoxFormat.XYWH,
            canvas_size=(h,w)
        ) for bb in out_bbox]
        xyxy = [unnorm(to_xyxy(bb)) for bb in xywh]
        
        prob = F.softmax(out_logits, -1)
        # Drop the last (no-object) column before taking the best class.
        scores, labels = prob[..., :-1].max(-1)

        if self.sliding:
            idxs = [torchvision.ops.nms(bb, s, 0.8) for bb, s in zip(xyxy, scores)]
            xyxy = [bb[idx] for bb, idx in zip(xyxy, idxs)]
            scores = [s[idx] for s, idx in zip(scores, idxs)]
            labels = [l[idx] for l, idx in zip(labels, idxs)]

        xywh = [to_xywh(bb, old_format='XYXY') for bb in xyxy]
        
        results = [{'scores': s, 'labels': l, 'boxes': bb} for s, l, bb in zip(scores, labels, xywh)]
        return results
    
    def filter_preds(self, preds):
        """Keep only predictions whose score is strictly above ``self.pred_thresh``.

        Note: not called by any method of this class.
        """
        keep = [torch.where(p['scores']>self.pred_thresh) for p in preds]
        filtered = [{
            'scores': p['scores'][k], 
            'labels': p['labels'][k],
            'boxes': p['boxes'][k]
        } for p, k in zip(preds, keep)]
        return filtered

    def training_step(self, batch, batch_idx):
        """Compute and log the set loss on ``batch["sup"]`` = (images, targets, paths); no tiling."""
        x, targets, paths = batch["sup"]
        outputs = self.forward(x)
        loss = self.loss(outputs, targets)
        self.log(f"loss/train", loss.detach().item())
        return loss
        
    def validation_step(self, batch, batch_idx):
        """Log the set loss (with tiling if configured) and accumulate mAP on (images, targets, paths)."""
        x, targets, paths = batch
        outputs = self.forward(x, sliding=self.sliding)
        loss = self.loss(outputs, targets)
        self.log(f"Total loss/val", loss.detach().item())
        preds = self.post_process(outputs, x)
        self.map_metric.update(preds, targets)
        
    def on_validation_epoch_end(self):
        """Compute, log and print the epoch mAP, then reset the metric."""
        mapmetric = self.map_metric.compute()['map']
        self.log("map/val", mapmetric)
        print("\nMAP: ", mapmetric)
        self.map_metric.reset()
        
    def predict_step(self, batch, batch_idx):
        """Return post-processed predictions for (images, targets, paths); no score filtering is applied."""
        x, targets, paths = batch
        outputs = self.forward(x, sliding=self.sliding)
        preds = self.post_process(outputs, x)
        return preds

In [29]:
from dl_toolbox.utils import get_tiles

class SlidingDet:
    """Sliding-window tiler for detection on images larger than the model input.

    ``__call__`` cuts an image batch into tiles and ``merge`` stitches the
    per-tile predictions back into one prediction set per image, with boxes
    re-expressed relative to the full image.

    Parameters:
        nols: full image width in pixels (used as the canvas width).
        nrows: full image height in pixels (used as the canvas height).
        width, height: tile size, forwarded to ``get_tiles``.
        step_w, step_h: step between tiles, forwarded to ``get_tiles``.
    """

    def __init__(
        self,
        nols,
        nrows,
        width,
        height,
        step_w,
        step_h
    ):
        self.nols = nols
        self.nrows = nrows
        # Each tile is (col_offset, row_offset, tile_w, tile_h).
        self.tiles = list(get_tiles(nols, nrows, width, height, step_w, step_h))

    def __call__(self, img):
        """Return one crop ``img[..., ro:ro+h, co:co+w]`` per tile, in ``self.tiles`` order."""
        return [img[..., ro:ro + h, co:co + w] for co, ro, w, h in self.tiles]

    def offset_bb(self, xyxy, co, ro):
        """Shift xyxy pixel boxes by a tile offset, in place, and set their canvas to the full image.

        x coordinates (even columns) get ``co`` added, y coordinates (odd
        columns) get ``ro``. The same, mutated, object is returned.
        """
        xyxy[..., 0::2].add_(co)
        xyxy[..., 1::2].add_(ro)
        # canvas_size is (height, width).
        # NOTE(review): torchvision documents canvas_size as ints; floats are kept
        # as in the original until `norm`'s handling of it is confirmed.
        xyxy.canvas_size = (float(self.nrows), float(self.nols))
        return xyxy

    def merge(self, outputs):
        """Merge per-tile predictions into one prediction set per full image.

        Parameters:
            outputs: list with one dict per tile (same order as ``self.tiles``),
                each with 'pred_logits' [bs, nbb, num_cls] and 'pred_boxes'
                [bs, nbb, 4]; boxes are xywh normalized with respect to the tile size.

        Returns:
            dict with 'pred_logits' [bs, nbb * n_tiles, num_cls] and
            'pred_boxes' [bs, nbb * n_tiles, 4], boxes being xywh normalized
            with respect to the full image size.
        """
        boxes = []
        logits = []
        for (co, ro, w, h), out in zip(self.tiles, outputs):
            logits.append(out['pred_logits'])
            tv_bb = [tv_tensors.BoundingBoxes(
                bb,
                format=tv_tensors.BoundingBoxFormat.XYWH,
                canvas_size=(h, w)
            ) for bb in out['pred_boxes']]
            # tile-normalized xywh -> tile pixel xyxy -> full-image pixel xyxy
            # -> full-image normalized xyxy -> full-image normalized xywh
            unnorm_tv_bb = [unnorm(to_xyxy(bb)) for bb in tv_bb]
            offset_tv_bb = [self.offset_bb(bb, co, ro) for bb in unnorm_tv_bb]
            norm_tv_bb = [norm(bb) for bb in offset_tv_bb]
            boxes.append(torch.stack([to_xywh(bb).as_subclass(torch.Tensor) for bb in norm_tv_bb]))
        all_logits = torch.cat(logits, dim=1)
        all_bb = torch.cat(boxes, dim=1)
        return {'pred_logits': all_logits, 'pred_boxes': all_bb}

In [30]:
import pytorch_lightning as pl
from dl_toolbox.callbacks import ProgressBar
from dl_toolbox import datamodules
from dl_toolbox import modules
import torchvision.transforms.v2 as v2
from functools import partial

import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from dl_toolbox.callbacks import ProgressBar, Finetuning, Lora, TiffPredsWriter, CalibrationLogger
from functools import partial
import gc


def _build_tf(crop):
    """Chain ``crop`` with box sanitization and per-channel (x - 0.5) / 0.5 normalization."""
    return v2.Compose(
        [
            crop,
            v2.SanitizeBoundingBoxes(),
            v2.Normalize([0.5] * 3, [0.5] * 3),
        ]
    )


# Training sees random 224x224 crops; evaluation uses a fixed 400x400 center crop.
train_tf = _build_tf(v2.RandomCrop(224))
test_tf = _build_tf(v2.CenterCrop(400))

# Evaluation tiler: 224x224 windows with a 112 px step over the 400x400 evaluation crop.
sliding = SlidingDet(nols=400, nrows=400, width=224, height=224, step_w=112, step_h=112)

dm = datamodules.xView(
    data_path='/data',
    merge='building',
    train_tf=train_tf,
    test_tf=test_tf,
    batch_tf=None,
    batch_size=2,
    num_workers=0,
    pin_memory=False,
)

# Lora callback applied to the 'backbone' submodule; the meaning of the
# positional args (4, True) is defined in dl_toolbox.callbacks.
lora = Lora('backbone', 4, True)

# Smoke-test configuration: a single batch per stage, two epochs.
trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=2,
    limit_train_batches=1,
    limit_val_batches=1,
    limit_predict_batches=1,
    callbacks=[ProgressBar(), lora],
)

module = Yolos(
    num_classes=91,
    det_token_num=10,
    backbone='vit_small_patch14_dinov2',
    optimizer=partial(torch.optim.Adam, lr=0.001),
    scheduler=partial(torch.optim.lr_scheduler.ConstantLR, factor=1),
    pred_thresh=0.1,
    sliding=sliding,
)

# Free Python garbage and cached CUDA memory (e.g. from earlier notebook runs) before training.
gc.collect()
torch.cuda.empty_cache()
gc.collect()

trainer.fit(module, datamodule=dm)

# To predict from a saved run instead of the in-memory weights, pass e.g.
# ckpt_path='/data/outputs/coco_yolos/2024-06-05_180011/0/checkpoints/last.ckpt'
trainer.predict(module, datamodule=dm, return_predictions=False)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_predict_batches=1)` was configured so 1 batch will be used.


loading annotations into memory...
Done (t=2.46s)
creating index...
index created!
loading annotations into memory...
Done (t=2.71s)
creating index...
index created!


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type                 | Params
------------------------------------------------------
0 | backbone     | VisionTransformer    | 22.4 M
1 | class_embed  | MLP                  | 331 K 
2 | bbox_embed   | MLP                  | 297 K 
3 | loss         | SetCriterion         | 0     
4 | map_metric   | MeanAveragePrecision | 0     
  | other params | n/a                  | 7.7 K 
------------------------------------------------------
930 K     Trainable params
22.1 M    Non-trainable params
23.0 M    Total params
91.948    Total estimated model params size (MB)


Training 930912 params out of 22987104
Sanity Checking DataLoader 0:   0%|                                                                                                       | 0/1 [00:00<?, ?it/s][torch.Size([90, 4]), torch.Size([90, 4])]
[torch.Size([4, 4]), torch.Size([4, 4])]
Sanity Checking DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 11.46it/s]
MAP:  tensor(-1.)
Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  3.19it/s, v_num=144]
Validation: |                                                                                                                             | 0/? [00:00<?, ?it/s][A[torch.Size([90, 4]), torch.Size([90, 4])]
[torch.Size([5, 4]), torch.Size([4, 4])]

MAP:  tensor(-1.)
Epoch 1: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.41it/s, v_num=144]
loading annotations into memory...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Done (t=2.95s)
creating index...
index created!
Predicting: |                                                                                                                             | 0/? [00:00<?, ?it/s][torch.Size([90, 4]), torch.Size([90, 4])]
[torch.Size([5, 4]), torch.Size([6, 4])]
