In [1]:
from torch.utils.data import Dataset
from PIL import Image
import torch
from torchvision.transforms import v2
from torchvision import tv_tensors

class ObjDetDataset(Dataset):
    """Map-style detection dataset that loads images from disk on access.

    Args:
        data: iterable of dicts, each with an ``'image_path'`` entry and a
            ``"target"`` entry. A target is a list of rows whose first four
            values are read as an XYXY box and whose remaining values are
            returned as labels.
        transforms: optional callable applied as ``transforms(image, bboxes)``
            and expected to return the transformed ``(image, bboxes)`` pair.

    Each item is the tuple ``(image, bboxes, labels, image_path)``.
    """

    def __init__(self, data, transforms=None):
        image_paths = []
        targets = []
        # Single pass so that one-shot iterables (generators) are supported.
        for instance in data:
            image_paths.append(instance['image_path'])
            targets.append(instance["target"])
        self.image_paths = image_paths
        self.targets = targets
        self.transforms = transforms

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        # Close the file handle deterministically; ``convert`` returns an
        # independent, fully loaded image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        w, h = image.size
        image = v2.functional.pil_to_tensor(image)
        targets = torch.tensor(self.targets[idx], dtype=torch.float32)
        if targets.numel() == 0:
            # An image without annotations would otherwise yield a 1-D tensor
            # and make the column slicing below raise an IndexError.
            targets = targets.reshape(0, 5)
        bboxes = tv_tensors.BoundingBoxes(targets[:, :4], format="XYXY", canvas_size=(h, w))
        labels = targets[:, 4:]
        if self.transforms:
            image, bboxes = self.transforms(image, bboxes)
        return image, bboxes, labels, image_path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch.utils.data import DataLoader, RandomSampler
from pathlib import Path
import pandas as pd
import numpy as np
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities import CombinedLoader
from functools import partial

class PascalVOC(LightningDataModule):
    """LightningDataModule serving detection samples to ``ObjDetDataset``.

    Annotations are read from a pickled DataFrame (``voc_combined.csv`` in the
    working directory) with ``filename``, ``labels`` and ``bboxes`` columns;
    images are resolved under
    ``<data_path>/PASCALVOC/VOCdevkit/VOC2012/JPEGImages``.

    Args:
        data_path: root directory containing the ``PASCALVOC`` folder.
        train_tf: transform applied to training samples.
        test_tf: transform applied to validation samples.
        batch_size_s: batch size used by both dataloaders.
        steps_per_epoch: number of training batches drawn per epoch.
        num_workers: forwarded to ``DataLoader``.
        pin_memory: forwarded to ``DataLoader``.
        *args, **kwargs: accepted and ignored.
    """

    # The position of a name in this tuple is its integer class id
    # (index 0 is the background entry).
    CLASS_NAMES = ('__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'tvmonitor')

    def __init__(
        self,
        data_path,
        train_tf,
        test_tf,
        batch_size_s,
        steps_per_epoch,
        num_workers,
        pin_memory,
        *args,
        **kwargs
    ):
        super().__init__()
        self.data_path = Path(data_path)
        self.train_tf = train_tf
        self.test_tf = test_tf
        self.batch_size_s = batch_size_s
        self.steps_per_epoch = steps_per_epoch
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def _load_instances(self):
        """Build the list of ``{"image_path", "target"}`` records from the annotation file."""
        img_dir = self.data_path/"PASCALVOC/VOCdevkit/VOC2012/JPEGImages"
        instances = []
        # NOTE(review): the file carries a .csv extension but is read as a
        # pickle. Unpickling executes arbitrary code, so only load a trusted
        # file here.
        for _, row in pd.read_pickle("voc_combined.csv").iterrows():
            image_path = f"{img_dir}/{row['filename']}"
            # One single-element row per object so it can be concatenated
            # column-wise with the boxes; unknown names raise ValueError.
            class_ids = [[self.CLASS_NAMES.index(name)] for name in row["labels"]]
            targets_ = np.concatenate([row["bboxes"], class_ids],
                                      axis=-1).tolist()
            instances.append({"image_path": image_path, "target": targets_})
        return instances

    def prepare_data(self):
        """Load the annotation records into ``self.instances``."""
        self.instances = self._load_instances()

    def setup(self, stage=None):
        """Split the records 95% / 5% (in file order) into train and validation sets."""
        # Lightning may call ``prepare_data`` on a single process only, so the
        # state it assigns is not guaranteed to exist here; load on demand.
        if not hasattr(self, "instances"):
            self.instances = self._load_instances()
        split = int(0.95*len(self.instances))
        train_data = self.instances[:split]
        val_data = self.instances[split:]
        self.train_s_set = ObjDetDataset(train_data, transforms=self.train_tf)
        self.val_set = ObjDetDataset(val_data, transforms=self.test_tf)

    @staticmethod
    def _collate(batch):
        """Stack images; keep boxes, labels and paths as per-sample tuples."""
        images_b, bboxes_b, labels_b, image_paths_b = list(zip(*batch))
        # Boxes are not stacked because each sample may hold a different
        # number of them.
        return torch.stack(images_b), bboxes_b, labels_b, image_paths_b

    def _dataloader(self, dataset):
        """Return a ``DataLoader`` factory pre-bound to the shared options."""
        return partial(
            DataLoader,
            dataset=dataset,
            collate_fn=self._collate,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory
        )

    def train_dataloader(self):
        """Training loader drawing ``steps_per_epoch * batch_size_s`` samples with replacement."""
        return self._dataloader(self.train_s_set)(
            sampler=RandomSampler(
                self.train_s_set,
                replacement=True,
                num_samples=self.steps_per_epoch*self.batch_size_s
            ),
            drop_last=True,
            batch_size=self.batch_size_s
        )

    def val_dataloader(self):
        """Sequential validation loader that keeps the last partial batch."""
        return self._dataloader(self.val_set)(
            shuffle=False,
            drop_last=False,
            batch_size=self.batch_size_s
        )
    
# Shared preprocessing settings, named once so the train and test pipelines
# cannot drift apart.
IMAGE_SIZE = (224, 224)
# Per-channel normalisation statistics (presumably the ImageNet mean/std — confirm).
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, random horizontal flip, scale to float, normalise.
train_tf = v2.Compose([
    v2.Resize(size=IMAGE_SIZE, antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Evaluation pipeline: same as training without the random flip.
test_tf = v2.Compose([
    v2.Resize(size=IMAGE_SIZE, antialias=True),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

dm = PascalVOC(
    data_path='/data',
    train_tf=train_tf,
    test_tf=test_tf,
    batch_size_s=4,
    steps_per_epoch=10,
    num_workers=4,
    pin_memory=True,
)
# Called by hand so the datasets exist before the Trainer is created.
dm.prepare_data()
dm.setup()

In [3]:
dl = dm.train_dataloader()
# Fetch only the first training batch and preview its contents.
step, (images, bboxes, labels, image_paths) = next(enumerate(dl))
for preview in (images.shape, bboxes[0], labels[0]):
    print(preview)

torch.Size([4, 3, 224, 224])
tensor([[  9.5727,  71.6800, 199.7493, 206.0800]])
tensor([[19.]])


In [4]:
#from dl_toolbox.networks import FCOS
#model = FCOS(num_classes=20)

from torchvision.models import resnet50
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelP6P7
#from torchvision.models.detection.backbone_utils import LastLevelP6P7
from dl_toolbox.networks.fcos import Head
import torch.nn as nn

#print(get_graph_node_names(resnet50())[0])

class FCOS(torch.nn.Module):
    """ResNet-50 + FPN feature extractor followed by the detection ``Head``.

    NOTE: a later cell defines a LightningModule that is also named ``FCOS``;
    once that cell has run, the name no longer refers to this class.

    Args:
        num_classes: number of classes passed to ``Head``.
        out_channels: channel width of every FPN output level.
        input_size: ``(height, width)`` of the probe input used to record
            ``feat_sizes``; defaults to the 224x224 size used in this notebook.

    Attributes:
        feat_sizes: spatial size ``(H, W)`` of each pyramid level for an input
            of ``input_size``.
    """

    def __init__(self, num_classes=19, out_channels=256, input_size=(224, 224)):
        # Zero-argument super() does not look up the global name ``FCOS``,
        # which is rebound by a later cell.
        super().__init__()
        backbone = resnet50()
        return_nodes = {
            'layer2.3.relu_2': 'layer2',
            'layer3.5.relu_2': 'layer3',
            'layer4.2.relu_2': 'layer4',
        }
        # Expose the outputs of the three stages listed above.
        backbone_features = create_feature_extractor(backbone, return_nodes)
        # Dry run to discover the channel count of each stage. It is done in
        # eval mode so the random probe does not leak into the BatchNorm
        # running statistics (no_grad alone does not prevent that).
        probe = torch.randn(1, 3, *input_size)
        backbone_features.eval()
        with torch.no_grad():
            backbone_out = backbone_features(probe)
        in_channels_list = [o.shape[1] for o in backbone_out.values()]
        # Build FPN
        fpn = FeaturePyramidNetwork(
            in_channels_list,
            out_channels=out_channels,
            extra_blocks=LastLevelP6P7(out_channels, out_channels)
        )
        # Reuse the backbone activations instead of running the backbone again.
        with torch.no_grad():
            pyramid = fpn(backbone_out)
        backbone_features.train()
        self.feature_extractor = nn.Sequential(backbone_features, fpn)
        self.feat_sizes = [o.shape[2:] for o in pyramid.values()]
        self.head = Head(out_channels, num_classes)

    def forward(self, images):
        """Return ``(features, box_cls, box_regression, centerness)`` for a batch of images."""
        features = list(self.feature_extractor(images).values())
        # The head returns one entry per pyramid level for each output
        # (assumed from how the outputs are consumed downstream — confirm).
        box_cls, box_regression, centerness = self.head(features)
        return features, box_cls, box_regression, centerness
    
# Build the detection network. ``FCOS`` is the nn.Module defined in this cell;
# a later cell rebinds the same name to a LightningModule, so this line only
# works when cells are run in order.
network = FCOS(num_classes=19)
# Spatial size of each pyramid level recorded during the constructor's dry run.
print(network.feat_sizes)
# Smoke-test a forward pass on the batch fetched in the previous cell.
features, box_cls, box_regression, centerness = network(images)

[torch.Size([28, 28]), torch.Size([14, 14]), torch.Size([7, 7]), torch.Size([4, 4]), torch.Size([2, 2])]


In [5]:
import torch.nn as nn
import torchvision

# Sentinel used both as the open upper bound of the last regression range and
# as the "no box assigned" area marker during label assignment.
INF = 100_000_000
# Consecutive entries form the [min, max] regression-distance range handled by
# each pyramid level (five levels -> six boundaries).
MAXIMUM_DISTANCES_PER_LEVEL = [
    -1,
    64,
    128,
    256,
    512,
    INF,
]

def _match_reg_distances_shape(MAXIMUM_DISTANCES_PER_LEVEL, num_locs_per_level):
    level_reg_distances = []
    for m in range(1, len(MAXIMUM_DISTANCES_PER_LEVEL)):
        level_distances = torch.tensor([
            MAXIMUM_DISTANCES_PER_LEVEL[m - 1], MAXIMUM_DISTANCES_PER_LEVEL[m]
        ],
                                       dtype=torch.float32)
        locs_per_level = num_locs_per_level[m - 1]
        level_distances = level_distances.repeat(locs_per_level).view(
            locs_per_level, 2)
        level_reg_distances.append(level_distances)
    # return tensor of size sum of locs_per_level x 2
    return torch.cat(level_reg_distances, dim=0)

def _compute_centerness_targets(reg_targets):
    if len(reg_targets) == 0:
        return reg_targets.new_zeros(len(reg_targets))
    left_right = reg_targets[:, [0, 2]]
    top_bottom = reg_targets[:, [1, 3]]
    centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
                (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
    return torch.sqrt(centerness)


def _calculate_reg_targets(xs, ys, bbox_targets):
    l = xs[:, None] - bbox_targets[:, 0][None] # Lx1 - 1xT -> LxT
    t = ys[:, None] - bbox_targets[:, 1][None]
    r = bbox_targets[:, 2][None] - xs[:, None]
    b = bbox_targets[:, 3][None] - ys[:, None]
    return torch.stack([l, t, r, b], dim=2) # LxTx4


def _apply_distance_constraints(reg_targets, level_distances):
    max_reg_targets, _ = reg_targets.max(dim=2)
    return torch.logical_and(max_reg_targets >= level_distances[:, None, 0], \
                             max_reg_targets <= level_distances[:, None, 1])


def _prepare_labels(locations, targets_batch):
    """Assign a class and a regression target to every location of every image.

    Args:
        locations: list with one ``(N_level, 2)`` tensor of ``(x, y)``
            coordinates per pyramid level.
        targets_batch: one ``(T, 5)`` tensor per image; columns 0-3 are the
            box as ``(x0, y0, x1, y1)`` and column 4 is the class id.
            ``T`` may be 0 for an image without objects.

    Returns:
        ``(cls_per_level, reg_per_level)`` as produced by
        ``_match_pred_format``. Locations that fall in no suitable box get
        class 0.
    """
    device = targets_batch[0].device
    num_locs_per_level = [len(l) for l in locations]
    # (L, 2): for every location, the [min, max] regression distance that its
    # pyramid level is responsible for.
    level_distances = _match_reg_distances_shape(MAXIMUM_DISTANCES_PER_LEVEL,
                                                 num_locs_per_level).to(device)
    all_locations = torch.cat(locations, dim=0).to(device)  # (L, 2)
    xs, ys = all_locations[:, 0], all_locations[:, 1]
    num_locs = len(all_locations)
    loc_index = torch.arange(num_locs, device=device)

    all_reg_targets = []
    all_cls_targets = []
    for targets in targets_batch:
        if len(targets) == 0:
            # No ground truth for this image: every location is background.
            # (Taking a min over an empty box dimension would raise.)
            cls_targets = targets.new_zeros(num_locs)
            reg_targets = targets.new_zeros((num_locs, 4))
        else:
            bbox_targets = targets[:, :4]  # (T, 4)
            cls_targets = targets[:, 4]  # (T,)

            # Distances from each location to the sides of each box.
            reg_targets = _calculate_reg_targets(xs, ys, bbox_targets)  # (L, T, 4)

            # A location lies strictly inside a box when all four distances
            # are positive.
            is_in_boxes = reg_targets.min(dim=2)[0] > 0  # (L, T)
            fits_to_feature_level = _apply_distance_constraints(
                reg_targets, level_distances)  # (L, T)

            bbox_areas = torchvision.ops.box_area(bbox_targets)  # (T,)

            # Area of each box at each location, replaced by INF where the
            # location is outside the box or the box does not belong to the
            # location's pyramid level.
            locations_to_gt_area = bbox_areas[None].repeat(num_locs, 1)  # (L, T)
            locations_to_gt_area[~(is_in_boxes & fits_to_feature_level)] = INF

            # Ambiguity is resolved in favour of the smallest eligible box.
            loc_min_area, loc_min_idxs = locations_to_gt_area.min(dim=1)  # (L,), (L,)

            reg_targets = reg_targets[loc_index, loc_min_idxs]  # (L, 4)

            # Advanced indexing copies, so the caller's tensor is not modified.
            cls_targets = cls_targets[loc_min_idxs]  # (L,)
            cls_targets[loc_min_area == INF] = 0

        all_cls_targets.append(
            torch.split(cls_targets, num_locs_per_level, dim=0))
        all_reg_targets.append(
            torch.split(reg_targets, num_locs_per_level, dim=0))
    # Each list holds, per image, a tuple with one tensor per level.
    return _match_pred_format(all_cls_targets, all_reg_targets, locations)


def _match_pred_format(cls_targets, reg_targets, locations):
    cls_per_level = []
    reg_per_level = []
    for level in range(len(locations)):
        cls_per_level.append(torch.cat([ct[level] for ct in cls_targets],
                                       dim=0))

        reg_per_level.append(torch.cat([rt[level] for rt in reg_targets],
                                       dim=0))
    # reg_per_level is a list of num_levels tensors of size Bxnum_loc_per_levelx4
    return cls_per_level, reg_per_level


def _get_positive_samples(cls_labels, reg_labels, box_cls_preds, box_reg_preds,
                          centerness_preds, num_classes):
    box_cls_flatten = []
    box_regression_flatten = []
    centerness_flatten = []
    labels_flatten = []
    reg_targets_flatten = []
    for l in range(len(cls_labels)):
        box_cls_flatten.append(box_cls_preds[l].permute(0, 2, 3, 1).reshape(
            -1, num_classes))
        box_regression_flatten.append(box_reg_preds[l].permute(0, 2, 3,
                                                               1).reshape(
                                                                   -1, 4))
        labels_flatten.append(cls_labels[l].reshape(-1))
        reg_targets_flatten.append(reg_labels[l].reshape(-1, 4))
        centerness_flatten.append(centerness_preds[l].reshape(-1))

    cls_preds = torch.cat(box_cls_flatten, dim=0)
    cls_targets = torch.cat(labels_flatten, dim=0)
    reg_preds = torch.cat(box_regression_flatten, dim=0)
    reg_targets = torch.cat(reg_targets_flatten, dim=0)
    centerness_preds = torch.cat(centerness_flatten, dim=0)
    pos_inds = torch.nonzero(cls_targets > 0).squeeze(1) # dim #loc in all batches where there is one cls to pred not background

    reg_preds = reg_preds[pos_inds]
    reg_targets = reg_targets[pos_inds]
    centerness_preds = centerness_preds[pos_inds]

    return reg_preds, reg_targets, cls_preds, cls_targets, centerness_preds, pos_inds

In [6]:
import torch.nn.functional as F

class LossEvaluator:
    """Computes the classification, box-regression and centerness losses.

    Calling an instance with ``(locations, preds, targets_batch, num_classes)``
    returns ``(cls_loss, reg_loss, centerness_loss)`` as scalar tensors.
    """

    def __init__(self):
        self.centerness_loss_func = nn.BCEWithLogitsLoss(reduction="sum")

    @staticmethod
    def _ltrb_to_xyxy(distances):
        """Turn ``(l, t, r, b)`` distances into boxes relative to their own location.

        The result ``(-l, -t, r, b)`` is a box in ``(x1, y1, x2, y2)`` form
        with the location at the origin, which is the format the torchvision
        box losses expect.
        """
        return torch.cat([-distances[:, :2], distances[:, 2:]], dim=1)

    def _get_cls_loss(self, cls_preds, cls_targets, total_num_pos):
        """Sigmoid focal loss over all locations, normalised by the positive count."""
        nc = cls_preds.shape[1]
        # Class 0 is background: drop its column so it maps to an all-zero row.
        onehot = F.one_hot(cls_targets.long(), nc+1)[:, 1:].float()
        cls_loss = torchvision.ops.sigmoid_focal_loss(cls_preds, onehot)
        return cls_loss.sum() / total_num_pos

    def _get_reg_loss(self, reg_preds, reg_targets, centerness_targets):
        """Distance-IoU loss on positive locations, weighted by centerness."""
        # Predictions and targets are side distances measured from the same
        # location; they must be expressed as boxes before being handed to the
        # box-IoU loss, otherwise (l, t, r, b) is misread as (x1, y1, x2, y2).
        pred_boxes = self._ltrb_to_xyxy(reg_preds.reshape(-1, 4))
        target_boxes = self._ltrb_to_xyxy(reg_targets.reshape(-1, 4))
        reg_losses = torchvision.ops.distance_box_iou_loss(
            pred_boxes, target_boxes, reduction='none')
        sum_centerness_targets = centerness_targets.sum()
        reg_loss = (reg_losses * centerness_targets).sum() / sum_centerness_targets
        return reg_loss

    def _get_centerness_loss(self, centerness_preds, centerness_targets,
                             total_num_pos):
        """Summed BCE-with-logits on centerness, normalised by the positive count."""
        centerness_loss = self.centerness_loss_func(centerness_preds,
                                                    centerness_targets)
        return centerness_loss / total_num_pos

    def _evaluate_losses(self, reg_preds, cls_preds, centerness_preds,
                         reg_targets, cls_targets, centerness_targets,
                         pos_inds):
        """Return ``(reg_loss, cls_loss, centerness_loss)`` for the flattened batch."""
        num_pos = pos_inds.numel()
        # Plain Python number: the losses then have the same (scalar) shape
        # whether or not the batch contains positives, and the normaliser is
        # never zero.
        total_num_pos = max(num_pos, 1)

        cls_loss = self._get_cls_loss(cls_preds, cls_targets, total_num_pos)

        if num_pos > 0:
            reg_loss = self._get_reg_loss(reg_preds, reg_targets,
                                          centerness_targets)
            centerness_loss = self._get_centerness_loss(centerness_preds,
                                                        centerness_targets,
                                                        total_num_pos)
        else:
            # Sums over empty tensors: zero-valued losses that are still
            # derived from the prediction tensors.
            reg_loss = reg_preds.sum()
            centerness_loss = centerness_preds.sum()

        return reg_loss, cls_loss, centerness_loss

    def __call__(self, locations, preds, targets_batch, num_classes):
        """Build the targets, gather the positive samples and evaluate the three losses.

        Args:
            locations: per-level location tensors.
            preds: ``(cls_preds, reg_preds, centerness_preds)``, each a
                per-level list.
            targets_batch: per-image target tensors.
            num_classes: number of classification channels.
        """
        cls_targets, reg_targets = _prepare_labels(locations, targets_batch)

        cls_preds, reg_preds, centerness_preds = preds

        (reg_preds, reg_targets, cls_preds, cls_targets, centerness_preds,
         pos_inds) = _get_positive_samples(
            cls_targets, reg_targets, cls_preds, reg_preds, centerness_preds,
            num_classes)

        centerness_targets = _compute_centerness_targets(reg_targets)

        reg_loss, cls_loss, centerness_loss = self._evaluate_losses(
            reg_preds, cls_preds, centerness_preds, reg_targets, cls_targets,
            centerness_targets, pos_inds)

        return cls_loss, reg_loss, centerness_loss

In [14]:
import pytorch_lightning as pl
from torchmetrics.detection.mean_ap import MeanAveragePrecision

class FCOSPostProcessor(torch.nn.Module):
    """Turns raw per-level predictions into final boxes, scores and classes.

    Args:
        pre_nms_thresh: minimum sigmoid class score for a candidate; also
            applied to the final scores.
        pre_nms_top_n: maximum number of candidates kept per image and level.
        nms_thresh: IoU threshold passed to NMS.
        fpn_post_nms_top_n: maximum number of detections per image after NMS
            (a value <= 0 disables the limit).
        min_size: minimum box side length passed to ``remove_small_boxes``.
        num_classes: stored on the instance; not used by the methods below.
    """

    def __init__(self, pre_nms_thresh, pre_nms_top_n, nms_thresh,
                 fpn_post_nms_top_n, min_size, num_classes):
        super().__init__()
        self.pre_nms_thresh = pre_nms_thresh
        self.pre_nms_top_n = pre_nms_top_n
        self.nms_thresh = nms_thresh
        self.fpn_post_nms_top_n = fpn_post_nms_top_n
        self.min_size = min_size
        self.num_classes = num_classes

    def forward_for_single_feature_map(self, locations, cls_preds, reg_preds,
                                       cness_preds, image_size):
        """Decode the candidates of one pyramid level.

        Returns three lists with one tensor per image: boxes ``(K, 4)``,
        scores ``(K,)`` and 1-based class ids ``(K,)``, all aligned.
        """
        B, C, _, _ = cls_preds.shape

        cls_preds = cls_preds.permute(0, 2, 3, 1).reshape(B, -1, C).sigmoid()  # (B, HW, C) in [0, 1]
        reg_preds = reg_preds.permute(0, 2, 3, 1).reshape(B, -1, 4)
        cness_preds = cness_preds.permute(0, 2, 3, 1).reshape(B, -1).sigmoid()

        # Candidates are selected on the raw class score, before it is
        # multiplied by centerness.
        candidate_inds = cls_preds > self.pre_nms_thresh  # (B, HW, C)
        pre_nms_top_n = candidate_inds.reshape(B, -1).sum(1)  # (B,)
        pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)

        cls_preds = cls_preds * cness_preds[:, :, None]  # (B, HW, C)

        # Build per-image lists of boxes, scores and classes.
        bboxes = []
        cls_labels = []
        scores = []
        for i in range(B):
            # True where the score of location l for class c exceeds the threshold.
            per_candidate_inds = candidate_inds[i]  # (HW, C)
            # (K, 2) indices (location, class) of the selected entries.
            per_candidate_nonzeros = per_candidate_inds.nonzero()
            per_box_loc = per_candidate_nonzeros[:, 0]
            # Class ids are shifted to [1, C] because 0 is background.
            per_class = per_candidate_nonzeros[:, 1] + 1

            per_reg_preds = reg_preds[i][per_box_loc]  # (K, 4)
            per_locations = locations[per_box_loc]  # (K, 2)
            # Boolean-mask indexing walks the entries in the same row-major
            # order as ``nonzero`` above, so the scores stay aligned.
            per_cls_preds = cls_preds[i][per_candidate_inds]  # (K,)

            # Keep only the best candidates when too many pass the threshold.
            per_pre_nms_top_n = int(pre_nms_top_n[i].item())
            if per_cls_preds.numel() > per_pre_nms_top_n:
                per_cls_preds, top_k_indices = per_cls_preds.topk(
                    per_pre_nms_top_n, sorted=False)
                per_class = per_class[top_k_indices]
                per_reg_preds = per_reg_preds[top_k_indices]
                per_locations = per_locations[top_k_indices]

            # Rebuild (x0, y0, x1, y1) from the location and the (l, t, r, b)
            # distances.
            per_bboxes = torch.stack([
                per_locations[:, 0] - per_reg_preds[:, 0],
                per_locations[:, 1] - per_reg_preds[:, 1],
                per_locations[:, 0] + per_reg_preds[:, 2],
                per_locations[:, 1] + per_reg_preds[:, 3],
            ], dim=1)
            # The image is treated as square with side ``image_size``.
            per_bboxes = torchvision.ops.clip_boxes_to_image(
                per_bboxes, (image_size, image_size))
            # Apply the size filter to boxes, classes AND scores so the three
            # tensors keep the same length and ordering.
            keep = torchvision.ops.remove_small_boxes(per_bboxes, self.min_size)
            bboxes.append(per_bboxes[keep])
            cls_labels.append(per_class[keep])
            scores.append(torch.sqrt(per_cls_preds[keep]))

        return bboxes, scores, cls_labels

    def forward(self, locations, cls_preds, reg_preds, cness_preds, image_size):
        """Decode every level, merge the levels per image and run NMS.

        All prediction arguments are per-level lists aligned with
        ``locations``. Returns per-image lists ``(boxes, scores, classes)``.
        """
        sampled_boxes = []
        all_scores = []
        all_classes = []
        for l, o, b, c in zip(locations, cls_preds, reg_preds, cness_preds):
            boxes, scores, cls_labels = self.forward_for_single_feature_map(
                l, o, b, c, image_size)
            sampled_boxes.append(boxes)
            all_scores.append(scores)
            all_classes.append(cls_labels)

        # Transpose from "per level, per image" to "per image, per level".
        all_bboxes = list(zip(*sampled_boxes))
        all_scores = list(zip(*all_scores))
        all_classes = list(zip(*all_classes))

        # One tensor per image gathering the candidates of all levels.
        all_bboxes = [torch.cat(bboxes, dim=0) for bboxes in all_bboxes]
        all_scores = [torch.cat(scores, dim=0) for scores in all_scores]
        all_classes = [torch.cat(classes, dim=0) for classes in all_classes]
        boxes, scores, classes = self.select_over_all_levels(
            all_bboxes, all_scores, all_classes)

        return boxes, scores, classes

    def select_over_all_levels(self, boxlists, scores, classes):
        """Apply NMS, the per-image detection cap and the final score threshold."""
        num_images = len(boxlists)
        all_picked_boxes, all_confidence_scores, all_classes = [], [], []
        for i in range(num_images):
            # NOTE(review): ``nms`` receives no class information, so
            # overlapping boxes of different classes suppress each other —
            # confirm this is intended.
            picked_indices = torchvision.ops.nms(boxlists[i], scores[i], self.nms_thresh)
            picked_boxes = boxlists[i][picked_indices]
            confidence_scores = scores[i][picked_indices]
            picked_classes = classes[i][picked_indices]

            number_of_detections = len(picked_indices)
            if number_of_detections > self.fpn_post_nms_top_n > 0:
                # Score of the weakest detection that still makes the cut.
                image_thresh, _ = torch.kthvalue(
                    confidence_scores.cpu(),
                    number_of_detections - self.fpn_post_nms_top_n + 1)
                keep = confidence_scores >= image_thresh.item()

                keep = torch.nonzero(keep).squeeze(1)
                picked_boxes, confidence_scores, picked_classes = picked_boxes[
                    keep], confidence_scores[keep], picked_classes[keep]

            keep = confidence_scores >= self.pre_nms_thresh
            picked_boxes, confidence_scores, picked_classes = picked_boxes[
                keep], confidence_scores[keep], picked_classes[keep]

            all_picked_boxes.append(picked_boxes)
            all_confidence_scores.append(confidence_scores)
            all_classes.append(picked_classes)

        return all_picked_boxes, all_confidence_scores, all_classes

import schedulefree

class FCOS(pl.LightningModule):
    """LightningModule wrapping the detection network, its loss and its post-processing.

    NOTE: this class reuses the name of the ``nn.Module`` network defined in
    an earlier cell and replaces it in the notebook namespace.

    Args:
        network: module returning ``(features, box_cls, box_regression,
            centerness)`` and exposing ``feat_sizes``.
        num_classes: number of object classes.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        network,
        num_classes,
        *args,
        **kwargs
    ):
        super().__init__()
        self.le = LossEvaluator()
        self.post_processor = FCOSPostProcessor(
            pre_nms_thresh=0.3,
            pre_nms_top_n=1000,
            nms_thresh=0.45,
            fpn_post_nms_top_n=50,
            min_size=0,
            num_classes=num_classes)
        self.fpn_strides = [8, 16, 32, 64, 128]
        self.feat_sizes = network.feat_sizes
        self.num_classes = num_classes
        self.network = network
        self.map_metric = MeanAveragePrecision()
        # One tensor per pyramid level giving, for every cell of the feature
        # map, the matching (x, y) position in the input image.
        self.locations = self._compute_locations()

    def configure_optimizers(self):
        """AdamW with a one-cycle learning-rate schedule stepped every batch."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=1e-3,
            betas=(0.9, 0.999),
            weight_decay=5e-2,
            eps=1e-8,
        )
        # Size the cycle from the attached Trainer so it matches the number of
        # optimizer steps that will really be taken.
        total_steps = self.trainer.estimated_stepping_batches
        if total_steps == float("inf") or total_steps <= 0:
            # Unbounded training: fall back to the previous fixed budget
            # (10 steps per epoch for 20 epochs).
            total_steps = 10 * 20
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=1e-3,
            total_steps=int(total_steps)
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step"
            },
        }

    def _compute_locations(self):
        """Return, per level, the ``(H*W, 2)`` image-space ``(x, y)`` centre of each cell.

        Rows are in row-major order (y outer, x inner), the same order in
        which ``(B, C, H, W)`` predictions are flattened elsewhere.
        """
        def _locations_per_level(h, w, s):
            xs = torch.arange(w, dtype=torch.float32) * s + s / 2
            ys = torch.arange(h, dtype=torch.float32) * s + s / 2
            grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
            return torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)], dim=1)

        return [
            _locations_per_level(h, w, self.fpn_strides[level])
            for level, (h, w) in enumerate(self.feat_sizes)
        ]

    def forward(self, images, targets_batch=None):
        """Run the network and post-processing; add the losses when targets are given.

        Returns a dict with ``predicted_boxes``, ``scores`` and
        ``pred_classes`` and, if ``targets_batch`` is provided, ``cls_loss``,
        ``reg_loss``, ``centerness_loss`` and ``combined_loss``.
        """
        features, box_cls, box_regression, centerness = self.network(images)
        locations = [l.to(features[0].device) for l in self.locations]
        # The image is assumed square: only the width is read.
        image_size = images.shape[-1]
        outputs = {}
        # Decoded detections are never back-propagated through.
        with torch.no_grad():
            predicted_boxes, scores, all_classes = self.post_processor(
                locations, box_cls, box_regression, centerness, image_size)

        if targets_batch is not None:
            cls_loss, reg_loss, centerness_loss = self.le(
                locations, (box_cls, box_regression, centerness),
                targets_batch,
                num_classes=self.num_classes)
            outputs["cls_loss"] = cls_loss
            outputs["reg_loss"] = reg_loss
            outputs["centerness_loss"] = centerness_loss
            outputs["combined_loss"] = cls_loss + reg_loss + centerness_loss

        outputs["predicted_boxes"] = predicted_boxes
        outputs["scores"] = scores
        outputs["pred_classes"] = all_classes
        return outputs

    def _run_batch(self, batch):
        """Forward a collated batch; return ``(results, per-image (T, 5) targets)``."""
        x, bboxes, labels, image_paths = batch
        y = [torch.cat([bb, l], dim=1) for bb, l in zip(bboxes, labels)]
        return self.forward(x, y), y

    def _log_losses(self, results, split):
        """Log the combined and individual losses under ``<name>/<split>``."""
        self.log(f"loss/{split}", results["combined_loss"].detach().item())
        self.log(f"cls_loss/{split}", results["cls_loss"].detach().item())
        self.log(f"reg_loss/{split}", results["reg_loss"].detach().item())
        self.log(f"centerness_loss/{split}", results["centerness_loss"].detach().item())

    def training_step(self, batch, batch_idx):
        results, _ = self._run_batch(batch)
        self._log_losses(results, "train")
        return results["combined_loss"]

    def validation_step(self, batch, batch_idx):
        results, y = self._run_batch(batch)
        preds = [{'boxes': bb, 'scores': s, 'labels': l} for bb, s, l in zip(
            results["predicted_boxes"], results["scores"], results["pred_classes"]
        )]
        # Ground-truth class ids are stored as floats next to the boxes; cast
        # them to integers so they have the same dtype as the predicted labels.
        targets = [{'boxes': t[:, :4], 'labels': t[:, 4].long()} for t in y]
        self.map_metric.update(preds, targets)
        self._log_losses(results, "val")

    def on_validation_epoch_end(self):
        mapmetric = self.map_metric.compute()['map']
        self.log("map/val", mapmetric)
        print("\nMAP: ", mapmetric)
        self.map_metric.reset()

In [15]:

import pytorch_lightning as pl
from dl_toolbox.callbacks import ProgressBar

# Short single-GPU run: 20 epochs of 5 training and 5 validation batches each.
trainer_options = dict(
    accelerator='gpu',
    devices=1,
    max_epochs=20,
    limit_train_batches=5,
    limit_val_batches=5,
    callbacks=[ProgressBar()],
)
trainer = pl.Trainer(**trainer_options)

# Wrap the network built earlier in the Lightning training module.
module = FCOS(network, num_classes=19)

trainer.fit(module, datamodule=dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


                                   

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type                 | Params
--------------------------------------------------------
0 | post_processor | FCOSPostProcessor    | 0     
1 | network        | FCOS                 | 29.8 M
2 | map_metric     | MeanAveragePrecision | 0     
--------------------------------------------------------
29.8 M    Trainable params
0         Non-trainable params
29.8 M    Total params
119.178   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]
 opt eval
Sanity Checking DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 20.41it/s]
MAP:  tensor(0.)
                                                                                                                                        

  rank_zero_warn(


Epoch 0:   0%|                                                                                                    | 0/5 [00:00<?, ?it/s]
 opt train
Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  6.30it/s, v_num=3]
Validation: 0it [00:00, ?it/s][A
 opt eval

MAP:  tensor(0.)
Epoch 1:   0%|                                                                                           | 0/5 [00:00<?, ?it/s, v_num=3]
 opt train
Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  6.01it/s, v_num=3]
Validation: 0it [00:00, ?it/s][A
 opt eval

MAP:  tensor(0.)
Epoch 2:   0%|                                                                                           | 0/5 [00:00<?, ?it/s, v_num=3]
 opt train
Epoch 2: 100%|███████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  6.43it/s, v_num=3]
Validation: 0it [00:0

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████████████████████████████████████████████████████████████████████████████| 5/5 [00:02<00:00,  2.07it/s, v_num=3]
