**Notice**: Rerun is suggested to carry out on Kaggle.

# **1 - Setup**

In [1]:
!git clone -b "main" "https://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@github.com/minhngt62/dl-sirst.git"
%mv "dl-sirst" "cv_sirst"

Cloning into 'dl-sirst'...
remote: Enumerating objects: 35, done.[K
remote: Counting objects: 100% (35/35), done.[K
remote: Compressing objects: 100% (29/29), done.[K
remote: Total 35 (delta 3), reused 31 (delta 3), pack-reused 0[K
Unpacking objects: 100% (35/35), 1.73 MiB | 2.29 MiB/s, done.


In [2]:
!pip install pycocotools

Collecting pycocotools
  Downloading pycocotools-2.0.6.tar.gz (24 kB)
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: pycocotools
  Building wheel for pycocotools (pyproject.toml) ... [?25ldone
[?25h  Created wheel for pycocotools: filename=pycocotools-2.0.6-cp37-cp37m-linux_x86_64.whl size=373762 sha256=a31790318bdaefce8f48cc4295bf590716396fc01706750e7c9eb0d69a9cff39
  Stored in directory: /root/.cache/pip/wheels/06/f6/f9/9cc49c6de8e3cf27dfddd91bf46595a057141d4583a2adaf03
Successfully built pycocotools
Installing collected packages: pycocotools
Successfully installed pycocotools-2.0.6
[0m

In [3]:
import time
import gc
import copy
import cv2
import numpy as np
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from typing import List, Union, Dict, Optional, Tuple, Any
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

import torch
from torch import nn, Tensor
import torchvision.transforms as T
import torch.nn.functional as F
import torchvision.ops as O
from torch.utils.data import Dataset, DataLoader, Subset
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.detection.mean_ap import MeanAveragePrecision as mAP
import torchvision

# **2 - Dataset**

In [4]:
# load NUDT-SIRST dataset (train)
from cv_sirst.datasets.datasets import NUDTSIRSTDataset
from cv_sirst.datasets.transforms import TransformComposer, XywhToXyxy

target_transform = TransformComposer(transforms={
    "default": [XywhToXyxy()],
})
transform = TransformComposer()

nudtsirst_train = NUDTSIRSTDataset('/kaggle/input/nudtsirst/annotation_train.csv', 
                                   '/kaggle/input/nudtsirst/nudtsirst', 
                                   transform=transform, 
                                   target_transform=target_transform)
nudtsirst_train[0]

(tensor([[[0.7451, 0.7451, 0.7451,  ..., 0.4118, 0.4078, 0.4039],
          [0.7451, 0.7451, 0.7451,  ..., 0.4118, 0.4078, 0.4039],
          [0.7490, 0.7490, 0.7490,  ..., 0.4118, 0.4078, 0.4039],
          ...,
          [0.6157, 0.6078, 0.6039,  ..., 0.6078, 0.6118, 0.6118],
          [0.6118, 0.6078, 0.6039,  ..., 0.6118, 0.6157, 0.6157],
          [0.6078, 0.6039, 0.6000,  ..., 0.6157, 0.6196, 0.6196]]]),
 {'boxes': tensor([[555.,  27., 562.,  34.]]), 'labels': tensor([1])})

In [5]:
# load NUDT-SIRST dataset (test)
target_transform = TransformComposer(transforms={
    "default": [XywhToXyxy()],
}, training=False)
transform = TransformComposer(training=False)

nudtsirst_test = NUDTSIRSTDataset('/kaggle/input/nudtsirst/annotation_test.csv', 
                                  '/kaggle/input/nudtsirst/nudtsirst', 
                                  transform=transform, 
                                  target_transform=target_transform)
nudtsirst_test.eval()
nudtsirst_test[0]

(tensor([[[0.4824, 0.4824, 0.4824,  ..., 0.4980, 0.4980, 0.4980],
          [0.4824, 0.4824, 0.4824,  ..., 0.4980, 0.4980, 0.4980],
          [0.4824, 0.4824, 0.4824,  ..., 0.4980, 0.4980, 0.4980],
          ...,
          [0.3882, 0.3843, 0.3765,  ..., 0.7922, 0.7725, 0.7569],
          [0.3922, 0.3882, 0.3765,  ..., 0.7412, 0.7333, 0.7216],
          [0.4039, 0.3961, 0.3843,  ..., 0.7176, 0.7137, 0.7137]]]),
 {'boxes': tensor([[1194.,  689., 1199.,  694.]]), 'labels': tensor([1])})

# **3 - Models**

## 3.1 - Corner Proposal

In [6]:
# build anchor generator based on corner detection algorithm
class AnchorGenerator(nn.Module):
    def __init__(
        self, 
        anc_size: Optional[Tuple[int, int]] = (31, 31), #  h, w
        max_corners: int = 600, 
        quality_level: int = 0.002, 
        min_distance: int = 31,
        ):
        super().__init__()
        self.anc_size = anc_size
        
        self.max_corners = max_corners
        self.quality_level = quality_level
        self.min_distance = min_distance
    
    def forward(
        self, 
        images: List[Tensor]
        ) -> Tensor:
        anc_bases = torch.neg(torch.ones(len(images), self.max_corners, 4)) * self.min_distance # [B, n_ancs, 4]
        for b, image in enumerate(images):
            image = image.squeeze().cpu().detach().numpy()
            corners = cv2.goodFeaturesToTrack(image, self.max_corners, self.quality_level, self.min_distance)
            corners = np.int0(corners)
            anc_centers = torch.from_numpy(corners).squeeze()
            for anc_id, (x, y) in enumerate(anc_centers):
                xmin = x - self.anc_size[1] // 2
                ymin = y - self.anc_size[0] // 2
                xmax = x + self.anc_size[1] // 2
                ymax = y + self.anc_size[0] // 2
                anc_boxes = torch.Tensor([xmin, ymin, xmax, ymax])
                anc_bases[b, anc_id, :] = O.clip_boxes_to_image(anc_boxes, size=(image.shape[0], image.shape[1]))
            del image
            del corners
            gc.collect()
        anc_bases = anc_bases.cuda() if torch.cuda.is_available() else anc_bases
        return anc_bases

In [7]:
# build corner proposal module
class CornerProposal(nn.Module):
    def __init__(
        self, 
        max_corners: int = 600, 
        min_distance: int = 31,
        pos_thresh: float = 0.8,
        neg_thresh: float = 0.2,
        ):
        super().__init__()
        self.max_corners = max_corners
        self.min_distance = min_distance

        self.pos_thresh = pos_thresh
        self.neg_thresh = neg_thresh
    
    def forward(
        self, 
        images: Tensor, # [B, c, h, w]
        anc_bases: Tensor, # [B, max_corners, 4]
        targets: Optional[Dict[str, Tensor]] = None
        ) -> Tuple[Tensor, Union[Tensor, None], Union[Tensor, None], Tensor]:
        B, c0, h, w = len(images), images[0].size(dim=0), images[0].size(dim=1), images[0].size(dim=2)
        
        if self.training:
            N = targets["labels"].size(dim=1) # max number of objects per image across batch
            gts, cls = targets["boxes"], targets["labels"] # [B, N, 4], [B, N]
            # compute an "ioa" matrix per image
            ioas_mat = torch.zeros((B, N, self.max_corners), device=gts.device) # [B, N, max_corners]
            for i in range(B):
                gt_boxes = gts[i] # [N, 4]
                anc_boxes = anc_bases[i] # [max_corners, 4]
                ioas_mat[i, :] = self._box_ioa(gt_boxes, anc_boxes) # [N, max_corners]
            ioas_mat = torch.transpose(ioas_mat, 1, 2) # [B, max_corners, N]
            gt_most_overlap_inds = torch.argmax(ioas_mat, dim=2) # [B, max_corners]
            max_iou_per_gt_box, _ = ioas_mat.max(dim=1, keepdim=True) # [B, 1, N]

            # get positive anchor boxes
            positive_anc_mask = torch.logical_and(ioas_mat == max_iou_per_gt_box, max_iou_per_gt_box > 0)
            positive_anc_mask = torch.logical_or(positive_anc_mask, ioas_mat > self.pos_thresh) # [B, max_corners, 1]
            pos_inds = positive_anc_mask.nonzero(as_tuple=True)[:-1] # 2 tensors of indices
            # get negative anchor boxes
            negative_anc_mask = ioas_mat < self.neg_thresh # [B, max_corners, 1]
            neg_inds = negative_anc_mask.nonzero(as_tuple=True)[:-1] # 2 tensors of indices

            if pos_inds[0].numel() != 0 or neg_inds[0].numel() != 0:
                # map gts to corr anchors
                gts_expand = gts.view(B, 1, N, 4).expand(B, self.max_corners, N, 4)
                ancs_to_gts = torch.gather(gts_expand, -2, gt_most_overlap_inds.reshape(B, self.max_corners, 1, 1).repeat(1, 1, 1, 4)) # [B, max_corners, 1, 4]
                ancs_to_gts = ancs_to_gts.flatten(start_dim=2) # [B, max_corners, 4]
                ancs_to_gts = torch.where(positive_anc_mask, ancs_to_gts, torch.tensor(-self.min_distance, device=gts.device).float())

                # map cls to corr anchors
                cls_expand = cls.view(B, 1, N).expand(B, self.max_corners, N)
                ancs_to_cls = torch.gather(cls_expand, -1, gt_most_overlap_inds.unsqueeze(-1)).squeeze(-1) # [B, max_corners]
                ancs_to_cls = torch.where(positive_anc_mask.flatten(start_dim=1), ancs_to_cls, 0)

                # extract a sastified anchor's indices
                pos_ind_selected = torch.stack(pos_inds)[:, :B] # [2, ~B]
                neg_ind_selected = torch.stack(neg_inds)[:, :pos_ind_selected.size(dim=1)] # [2, ~B]
                roi_inds = tuple(torch.cat((pos_ind_selected, neg_ind_selected), dim=1)) # [2, ~2xB]

                # extract ROIs
                roi_bases = anc_bases[roi_inds] # [~2xB, 4] -> xmin, ymin, xmax, ymax
                roi_centers = (roi_bases[:, :2] + roi_bases[:, 2:]) // 2 # [~2xB, 2]
                images = images[roi_inds[:-1]] # [~2xB, c, h, w]
                rois = self._extract_glimpse(
                    images, 
                    size=(self.min_distance, self.min_distance), 
                    offsets=roi_centers
                    ) # [~2xB, c, min_distance, min_distance]
                roi_cls = ancs_to_cls[roi_inds] # [~2xB]
                #roi_gts = (ancs_to_gts[roi_inds] - torch.cat((roi_bases[roi_inds[:-1]][:, :2], roi_bases[roi_inds[:-1]][:, :2]), dim=1)) / (self.min_distance - 1) # [~2xB, 4]
                roi_gts = ancs_to_gts[roi_inds]

                return rois, roi_cls, roi_gts, roi_bases[roi_inds[:-1]][:, :2] # [~2xB, 2]
        
        anc_centers = (anc_bases[:, :, :2] + anc_bases[:, :, :2]) // 2 # [B, max_corners, 2]
        rois = self._extract_glimpses(
            images, 
            size=(self.min_distance, self.min_distance), 
            offsets=anc_centers
            ) # [B, max_corners, c, min_distance, min_distance]
        return rois, None, None, anc_bases[:, :, :2] # [B, max_corners, 2]

    def _box_ioa(
        self, 
        gt_boxes: Tensor, 
        anc_boxes: Tensor
        ) -> Tensor:
        ioa_mat = torch.zeros((gt_boxes.size(dim=0), anc_boxes.size(dim=0)))
        for i in range(len(gt_boxes)):
            xmin = torch.max(gt_boxes[i, 0], anc_boxes[:, 0])
            ymin = torch.max(gt_boxes[i, 1], anc_boxes[:, 1])
            xmax = torch.min(gt_boxes[i, 2], anc_boxes[:, 2])
            ymax = torch.min(gt_boxes[i, 3], anc_boxes[:, 3])

            w = (xmax - xmin + 1).double()
            h = (ymax - ymin + 1).double()
            intersection = torch.where((w > 0) & (h > 0), w * h, 0.)
            gt_area = (gt_boxes[i, 2] - gt_boxes[i, 0] + 1) * (gt_boxes[i, 3] - gt_boxes[i, 1] + 1)
            ioa_mat[i, :] = intersection / gt_area
        return ioa_mat

    def _extract_glimpse(
        self,
        input: Tensor, # [B, C, H, W]
        size: Tuple[int, int],
        offsets: Tensor, # [B, 2]
        centered=False, 
        normalized=False, 
        mode='bilinear', 
        padding_mode='zeros'
        ) -> Tensor:
        W, H = input.size(-1), input.size(-2)

        if normalized and centered:
            offsets = (offsets + 1) * offsets.new_tensor([W/2, H/2])
        elif normalized:
            offsets = offsets * offsets.new_tensor([W, H])
        elif centered:
            raise ValueError(
                f'Invalid parameter that offsets centered but not normlized')

        h, w = size
        xs = torch.arange(0, w, dtype=input.dtype,
                        device=input.device) - (w - 1) / 2.0
        ys = torch.arange(0, h, dtype=input.dtype,
                        device=input.device) - (h - 1) / 2.0

        vy, vx = torch.meshgrid(ys, xs)
        grid = torch.stack([vx, vy], dim=-1)  # h, w, 2

        offsets_grid = offsets[:, None, None, :] + grid[None, ...]

        # normalised grid to [-1, 1]
        offsets_grid = (
            offsets_grid - offsets_grid.new_tensor([W/2, H/2])) / offsets_grid.new_tensor([W/2, H/2])

        return torch.nn.functional.grid_sample(
            input, offsets_grid, mode=mode, align_corners=False, padding_mode=padding_mode)

    def _extract_glimpses(
        self,
        input: Tensor, # [B, C, H, W]
        size: Tuple[int, int],
        offsets: Tensor, # [B, max_corners, 2]
        centered=False, 
        normalized=False, 
        mode='bilinear', 
        padding_mode='zeros'
        ) -> Tensor:
        patches = [] # [max_corners, B, c, size, size]
        for i in range(offsets.size(-2)):
            patch = self._extract_glimpse(input, size, offsets[:, i, :], centered, normalized, mode) # [B, c, size, size]
            patches.append(patch)
        return torch.stack(patches, dim=1) # [B, max_corners, c, size, size]

## 3.2 - Convolutional Neural Block 

In [8]:
# build the MLP head
class TwoMLPHead(nn.Module):
    def __init__(self, in_features=512):
        super().__init__()
        self.flatten = nn.Flatten()
        self.loc_fc = nn.Linear(in_features, 4)
        self.cls_fc = nn.Linear(in_features, 1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        x = self.flatten(x)
        box = self.sigmoid(self.loc_fc(x))
        score = self.sigmoid(self.cls_fc(x))
        return box, score

In [9]:
# setup resnet-34
resnet = torchvision.models.resnet34(pretrained=True)
resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
resnet.fc = TwoMLPHead()

# transfer learning
for param in resnet.parameters():
    param.requires_grad = False
resnet.conv1.weight.requires_grad = True
for name, weight in resnet.fc.named_parameters():
    weight.requires_grad = True

Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth


  0%|          | 0.00/83.3M [00:00<?, ?B/s]

In [10]:
# build a light-weight CNN --DEPRECATED
class LightWeightCNN(nn.Module):
    def __init__(self, in_channels, out_channels=4):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        
        conv_block = []
        channels = [in_channels, 8, 16, 32]
        for i in range(1, len(channels)):
            conv = nn.Sequential(*[
                nn.Conv2d(in_channels=channels[i-1], out_channels=channels[i], kernel_size=3, stride=1, padding="same", bias=False),
                nn.BatchNorm2d(num_features=channels[i]),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=channels[i], out_channels=channels[i], kernel_size=3, stride=1, padding="same", bias=False),
                nn.BatchNorm2d(num_features=channels[i]),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
            ])
            conv_block += list(conv)
        self.conv_block = nn.Sequential(*conv_block)
        self.bottleneck = O.misc.ConvNormActivation(in_channels=channels[-1], out_channels=out_channels, kernel_size=1, stride=1, padding="same")
        self.global_avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.mlp_head = TwoMLPHead(out_channels)
        
    def forward(self, x):
        x = self.conv_block(x)
        x = self.bottleneck(x)
        x = self.global_avg_pool(x)
        x = self.mlp_head(x)
        return x

## 3.3 - Main Model

In [11]:
# build loss function DIoU
class DIoULoss(nn.Module):
    """
    Distance Intersection over Union Loss (Zhaohui Zheng et. al)
    https://arxiv.org/abs/1911.08287
    Args:
        input, target (Tensor): box locations in XYXY format, shape (N, 4) or (4,).
        reduction: 'none' | 'mean' | 'sum'
                 'none': No reduction will be applied to the output.
                 'mean': The output will be averaged.
                 'sum': The output will be summed.
        eps (float): small number to prevent division by zero
    """
    __constant__ = ["none", "sum", "mean"]
    
    def __init__(
        self,
        eps: float = 1e-7,
        reduction: Optional[str] = None,
        weights: Optional[Tensor] = None
        ):
        super(DIoULoss, self).__init__()
        self.eps = eps
        self.reduction = reduction
        self.weights = weights
    
    def forward(
        self,
        input: Tensor,
        target: Tensor
        ) -> Tensor:
        intsct, union = self._loss_inter_union(input, target)
        iou = intsct / (union + self.eps)
        
        # smallest enclosing box
        x1, y1, x2, y2 = input.unbind(dim=-1)
        x1g, y1g, x2g, y2g = target.unbind(dim=-1)
        xc1 = torch.min(x1, x1g)
        yc1 = torch.min(y1, y1g)
        xc2 = torch.max(x2, x2g)
        yc2 = torch.max(y2, y2g)
        
        # the diagonal distance of the smallest enclosing box squared
        diagonal_distance_squared = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + self.eps
        
        # centers of boxes
        x_p = (x2 + x1) / 2
        y_p = (y2 + y1) / 2
        x_g = (x1g + x2g) / 2
        y_g = (y1g + y2g) / 2
        
        # the distance between boxes' centers squared.
        centers_distance_squared = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)
        
        # distance between boxes' centers squared.
        loss = 1 - iou + (centers_distance_squared / diagonal_distance_squared)
        
        # eqn. (7)
        loss = 1 - iou + (centers_distance_squared / diagonal_distance_squared)
        if self.weights is not None:
            loss = loss * self.weights
        loss = loss[torch.nonzero(loss, as_tuple=True)]
        
        if self.reduction == "mean":
            loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss
    
    def _loss_inter_union(
        self,
        boxes1: torch.Tensor,
        boxes2: torch.Tensor
        ) -> Tuple[torch.Tensor, torch.Tensor]:

        x1, y1, x2, y2 = boxes1.unbind(dim=-1)
        x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)

        # Intersection keypoints
        xkis1 = torch.max(x1, x1g)
        ykis1 = torch.max(y1, y1g)
        xkis2 = torch.min(x2, x2g)
        ykis2 = torch.min(y2, y2g)

        intsctk = torch.zeros_like(x1)
        mask = (ykis2 > ykis1) & (xkis2 > xkis1)
        intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
        unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk

        return intsctk, unionk

In [12]:
# build the total model
class CCNN(nn.Module):
    def __init__(
        self,
        predictor: nn.Module,
        roi_transform: Any = None,
        
        anchor_size: Optional[Tuple[int, int]] = (31, 31),
        max_corners: int = 600, 
        quality_level: int = 0.002, 
        min_distance: int = 11,
        
        pos_thresh: float = 0.7,
        neg_thresh: float = 0.2,
        
        score_thresh: float = 0.5,
        loss_reg_cls_ratio: float = 0.80,
        nms_thresh: float = 0.6
        ):
        super().__init__()
        self.score_thresh = score_thresh
        self.reg_ratio = loss_reg_cls_ratio
        self.nms_thresh = nms_thresh
        
        self.anchor_generator = AnchorGenerator(anchor_size, max_corners, quality_level, min_distance)
        self.corner_proposal = CornerProposal(max_corners, min_distance, pos_thresh, neg_thresh)
        self.roi_transform = roi_transform
        self.predictor = predictor
    
    def forward(
        self, 
        images: Optional[List[Tensor]], 
        targets: Optional[List[Dict[str, Tensor]]] = None
        ) -> Union[Tensor, List[Dict[str, Tensor]]]:
        torch._assert(images is not None, "[ERROR] Images cannot be missing")
        if self.training:
            torch._assert(targets is not None, "[ERROR] Targets should not be none during training")
            targets = self._stack_targets(targets) # [B, N, 4], [B, N]
        
        anc_bases = self.anchor_generator(images)
        if images is not Tensor:
            images = torch.stack(images)
        h, w = images.size(dim=-2), images.size(dim=-1)
        rois, roi_cls, roi_gts, roi_uplefts = self.corner_proposal(images, anc_bases, targets)
        if self.roi_transform:
            rois = self.roi_transform(rois)
        
        if self.training:
            locs, scores = self.predictor(rois) # [B, out_channels, 1, 1]
            locs = self._decode_loc(locs, roi_uplefts) # [~2xB, 4]
            scores = torch.flatten(scores, start_dim=0) # [~2xB]
            loss = self.compute_loss(locs, scores, roi_gts, roi_cls)
            return loss
        
        detections = []
        for b, ins_rois in enumerate(rois):
            locs, scores = self.predictor(ins_rois)
            locs = self._decode_loc(locs, roi_uplefts[b]) # [~2xB, 4]
            scores = torch.flatten(scores, start_dim=0) # [~2xB]
            locs, scores, labels = self.postprocess_detections(locs, scores, (h, w))
            detections.append(self._one_detection(locs, scores, labels))
        return detections
            
    def compute_loss(self, locs, scores, gts, labels) -> Tensor:
        loc_loss_fn = DIoULoss(reduction="mean", weights=labels)
        cls_loss_fn = nn.BCELoss()
        cls_loss = cls_loss_fn(scores, labels.float())
        loc_loss = loc_loss_fn(locs, gts)
        return (1 - self.reg_ratio) * cls_loss + self.reg_ratio * loc_loss
    
    def _stack_targets(
        self, 
        targets: List[Dict[str, Tensor]]
        ) -> Dict[str, Tensor]:
        target_stack = {}
        for k in targets[0].keys():
            tensors = (targets[0][k],)
            for i in range(1, len(targets)):
                tensors = tensors + (targets[i][k],)
            target_stack[k] = torch.stack(tensors)
        return target_stack
        
    def _one_detection(self, locs, scores, labels) -> Dict[str, Tensor]:
        return {
            "boxes": locs, # [~2xB, 4]
            "scores": scores, # [~2xB]
            "labels": labels # [~2xB]
        }
    
    def _decode_loc(self, locs, roi_uplefts):
        locs = locs * (self.corner_proposal.min_distance - 1) + torch.cat((roi_uplefts, roi_uplefts), dim=1) # [~2xB, 4]
        locs[:, 2:] = torch.ceil(locs[:, 2:])
        locs[:, :2] = torch.floor(locs[:, :2])
        return locs # [~2xB, 4]
    
    def postprocess_detections(self, locs, scores, image_shape):
        locs = O.clip_boxes_to_image(locs, image_shape)
        labels = torch.where(scores.double() > self.score_thresh, 1, 0) # [~2xB]

        # remove low scoring boxes
        keep = torch.where(scores > self.score_thresh)
        locs, labels, scores = locs[keep], labels[keep], scores[keep]

        # remove empty boxes
        keep = O.remove_small_boxes(locs, min_size=1e-3)
        locs, scores, labels = locs[keep], scores[keep], labels[keep]

        # non-maximum suppression
        keep = O.nms(locs, scores, self.nms_thresh)
        locs, scores, labels = locs[keep], scores[keep], labels[keep]

        return locs, scores, labels

# **4 - Training Phase**

## 4.1 - Metrics

In [13]:
# build the custom metrics for 
class SIRSTMetrics:
    def __init__(
        self, 
        iou_thresholds: List[float] = [0.0, 0.5, 1.0],
        eps: float = 1e-7
        ):
        self.true_pos = [0] * len(iou_thresholds)
        self.false_pos = [0] * len(iou_thresholds)
        self.n_preds = 0
        self.n_gts = 0
        self.iou_thresholds = iou_thresholds
        self.eps = eps
    
    def compute(self) -> Dict[str, Tensor]:
        true_pos = Tensor(self.true_pos, device=self._device)
        false_pos = Tensor(self.false_pos, device=self._device)
        detect_rate = true_pos / (self.n_gts + self.eps) # true / n_targets
        false_alarm = false_pos / (self.n_preds + self.eps) # false / n_preds
        return {f"detection_rate_{self.iou_thresholds[i]}": detect_rate[i] for i in range(len(self.iou_thresholds))}, {f"false_alarm_rate_{self.iou_thresholds[i]}": false_alarm[i] for i in range(len(self.iou_thresholds))}
    
    def update(
        self, 
        preds: List[Dict[str, Tensor]], # [N, ...]
        targets: List[Dict[str, Tensor]] # [M, ...]
        ):
        self._device = targets[0]["boxes"].device
        max_preds = len(max(preds, key=lambda r: len(r["boxes"]))["boxes"])
        max_targets = len(max(targets, key=lambda r: len(r["boxes"]))["boxes"])
        ious_mat = torch.zeros((len(preds), max_preds, max_targets), device=self._device) # [B, max_preds, max_targets]
        for i in range(len(preds)):
            pred = preds[i] # [N, 4]
            target = targets[i] # [M, 4]
            ious_mat[i, :] = O.box_iou(pred["boxes"], target["boxes"]) # [N, M]
            self.n_preds += pred["boxes"].size(dim=0)
            self.n_gts += target["boxes"].size(dim=0)
        for i, iou_threshold in enumerate(self.iou_thresholds):
            true_pos_inds = torch.where(ious_mat > iou_threshold)
            false_pos_inds = torch.where(ious_mat <= iou_threshold)
            self.true_pos[i] += len(true_pos_inds[0])
            self.false_pos[i] += len(false_pos_inds[0])

## 4.2 - Setup Live Tensorboard

In [14]:
# dowload ngrok to launch tensorboard
!wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
!unzip ./ngrok-stable-linux-amd64.zip

--2023-01-02 08:15:18--  https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
Resolving bin.equinox.io (bin.equinox.io)... 52.202.168.65, 54.237.133.81, 18.205.222.128, ...
Connecting to bin.equinox.io (bin.equinox.io)|52.202.168.65|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13832437 (13M) [application/octet-stream]
Saving to: ‘ngrok-stable-linux-amd64.zip’


2023-01-02 08:15:23 (2.63 MB/s) - ‘ngrok-stable-linux-amd64.zip’ saved [13832437/13832437]

Archive:  ./ngrok-stable-linux-amd64.zip
  inflating: ngrok                   


In [15]:
# add auth-token
!./ngrok authtoken 2IuQZfwfPwjKtQzNIBEzatlGJ1U_5qibyn8Ef9htdqoBus2Kw

Authtoken saved to configuration file: /root/.ngrok2/ngrok.yml


In [16]:
# launch and tunnel tensorboard
import os
import multiprocessing
 
pool = multiprocessing.Pool(processes = 10)
results_of_processes = [pool.apply_async(os.system, args=(cmd, ), callback = None )
                        for cmd in [
                        f"tensorboard --logdir ./runs/ --host 0.0.0.0 --port 6006 &",
                        "./ngrok http 6006 &"
                        ]]

2023-01-02 08:15:31.475112: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-01-02 08:15:31.554954: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-01-02 08:15:31.555883: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

TensorBoard 2.10.1 at http://0.0.0.0:6006/ (Press CTRL+C to quit)


In [17]:
# curl ngrok port
import time
time.sleep(10) # wait for tensorboard host
! curl -s http://localhost:4040/api/tunnels | python3 -c \
    "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"

http://5a94-35-233-141-223.ngrok.io


## 4.3 - Trainer: Model Wrapper

In [18]:
# build trainer to wrap model
from cv_sirst.datasets.datasets import SIRSTDataset
from cv_sirst.trainers import callbacks as cb

class Trainer:
    MAX_LOAD_SIZE = 8 # 15GB VRAM

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        dataset: SIRSTDataset,
        batch_size: int = 32,
        valid_ratio: int = 0.05,
        valid_freq: int  = 1,
        tensorboard: SummaryWriter = None,
        output: Union[str, None] = None,
        input: Union[str, None] = None,
        log_freq_gradient: int = 10,
        metric: Any = mAP(iou_thresholds=[0.1*i for i in range(1, 6)], rec_thresholds=list(np.linspace(0, 1, 11)), max_detection_thresholds=100),
        lr_scheduler=None
    ):
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        
        self.dataset = dataset
        self.batch_size = batch_size
        self.valid_ratio = valid_ratio
        self.valid_freq = valid_freq
        self.data_loader, self.data_loader_valid = self._train_test_split()

        self.input = input
        self.output = output
        
        self.tensorboard = tensorboard
        self.save_model = cb.BestCheckpoint(output=output)
        self.display = cb.Displayer()
        self.metric = metric

        self.log_freq_gradient = log_freq_gradient

        self.model.to(self.device)

    def fit(
        self,
        start_epoch: int = 0, # inclusive
        end_epoch: int = 20, # exclusive
        ):
        warm_up=True
        for epoch in range(start_epoch, end_epoch):
            if not warm_up:
                print() # vanish \r
            else:
                warm_up=False
            print(f"Epoch {epoch+1}/{end_epoch}")
            self.train_one_epoch(epoch)

    def train_one_epoch(self, epoch):
        for iter, (images, targets) in enumerate(self.data_loader):
            self.train_one_step(iter, images, targets, epoch)
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            
    def train_one_step(self, iter, images, targets, epoch):
        ts = time.time()
        self.model.train()
        
        self.optimizer.zero_grad()
        chunks = [(images[i:i+self.MAX_LOAD_SIZE], targets[i:i+self.MAX_LOAD_SIZE]) for i in range(0, len(images), self.MAX_LOAD_SIZE)]
        for imgs, tars in chunks:
            imgs = list(image.to(self.device) for image in imgs)
            tars = [{k: v.to(self.device) for k, v in t.items()} for t in tars]
            
            losses = self.model(imgs, tars)
            losses = losses / len(chunks) # normalize across chunks

            del imgs
            del tars
            gc.collect()
            torch.cuda.empty_cache()

            losses.backward()
        
        if self.log_freq_gradient != 0 and iter % self.log_freq_gradient == 0:
            for name, weight in self.model.named_parameters():
                if weight.requires_grad:
                    self.tensorboard.add_histogram(name + ".grad", weight.grad, epoch * len(self.data_loader) + iter)
        self.optimizer.step()

        te = time.time()
        valid_losses = None
        if iter % int(self.valid_freq * len(self.data_loader)) == 0 and iter != 0:
            valid_losses = self.validate()
            for i, valid_loss in enumerate(valid_losses):
                self.tensorboard.add_scalar("valid/metric_{i}", valid_loss, epoch * len(self.data_loader) + iter)
                self.save_model.quality_save((sum(valid_loss.values()) / len(valid_loss)).item(), epoch, self.model, self.optimizer, output=f"metric_{i}_" + self.output)
        
        valid_loss = None
        if valid_losses is not None:
            valid_loss = valid_losses[0]
            valid_loss = (sum(valid_loss.values()) / len(valid_loss)).item()
        
        self.display(len(self.data_loader), iter+1, te-ts, losses.item(), valid_loss=valid_loss)
        self.tensorboard.add_scalar("train/loss", losses.item(), epoch * len(self.data_loader) + iter)
        #self.tensorboard.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], epoch * len(self.data_loader) + iter)
        self.save_model(losses.item(), epoch, self.model, self.optimizer)

    def validate(self):
        self.model.eval()
        with torch.no_grad():
            for images, targets in self.data_loader_valid:
                chunks = [(images[i:i+self.MAX_LOAD_SIZE], targets[i:i+self.MAX_LOAD_SIZE]) for i in range(0, len(images), self.MAX_LOAD_SIZE)]
                for imgs, tars in chunks:
                    imgs = list(image.to(self.device) for image in imgs)
                    tars = [{k: v.to(self.device) for k, v in t.items()} for t in tars]
                    detections = self.model(imgs, tars)
                    
                    self.metric.update(detections, tars)
                    del imgs
                    del tars
                    gc.collect()
                    torch.cuda.empty_cache()
        return self.metric.compute()
    
    @torch.inference_mode()
    def evaluate(self, data_loader_test: DataLoader, print_freq=10) -> Dict[str, Tensor]:
        with torch.no_grad():
            for images, targets in self.data_loader_test:
                chunks = [(images[i:i+self.MAX_LOAD_SIZE], targets[i:i+self.MAX_LOAD_SIZE]) for i in range(0, len(images), self.MAX_LOAD_SIZE)]
                for imgs, tars in chunks:
                    imgs = list(image.to(self.device) for image in imgs)
                    tars = [{k: v.to(self.device) for k, v in t.items()} for t in tars]
                    detections = self.model(imgs, tars)
                    
                    self.metric.update(detections, tars)
                    del imgs
                    del tars
                    gc.collect()
                    torch.cuda.empty_cache()
        return self.metric.compute()
        
    def _train_test_split(self) -> Tuple[DataLoader]:
        dataset_valid = copy.deepcopy(self.dataset)
        dataset_valid.transform.training = False
        dataset_valid.target_transform.training = False
        
        torch.manual_seed(1)
        split_idx = int(self.valid_ratio * len(self.dataset))
        indices = torch.randperm(len(self.dataset)).tolist()
        dataset = Subset(self.dataset, indices[:-split_idx])
        dataset_valid = Subset(dataset_valid, indices[-split_idx:])

        data_loader = DataLoader(
            dataset, batch_size=self.batch_size, 
            shuffle=True, collate_fn=lambda batch: tuple(zip(*batch)))
        valid_data_loader = DataLoader(
            dataset_valid, batch_size=self.batch_size,
            shuffle=False, collate_fn=lambda batch: tuple(zip(*batch)))
        
        return data_loader, valid_data_loader
    
    def __setattr__(self, __name: str, __value: Any) -> None:
        if __name == "input":
            if __value is not None:
                try:
                    checkpoint = torch.load(__value, map_location="cpu")
                    self.model.load_state_dict(checkpoint['model_state_dict'])
                    self.model.to(self.device)
                    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                    cb.optimizer_to(self.optimizer, self.device)
                except:
                    if self.model is None or self.optimizer is None:
                        print("[ERROR] - Must have the desired model/optimizer")
                    print("[ERROR] - The path would be wrong.")

        super(Trainer, self).__setattr__(__name, __value)
        
        if __name in ("dataset", "batch_size", "valid_ratio"):
            if __name == "batch_size":
                assert self.batch_size % self.MAX_LOAD_SIZE == 0
            try:
                self.data_loader, self.data_loader_valid = self._train_test_split()
            except AttributeError:
                pass

## 4.4 - Training Models

In [19]:
!nvidia-smi

Mon Jan  2 08:16:00 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.82.01    Driver Version: 470.82.01    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P0    26W / 250W |      2MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [21]:
# constrast adjustment
class ContrastAdjustment:
    def __init__(self, factor=5):
        self.factor = factor
    
    def __call__(self, images):
        return torchvision.transforms.functional.adjust_contrast(images, self.factor)

# init model
model = CCNN(predictor=resnet, roi_transform=ContrastAdjustment(factor=5))

# init optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.NAdam(params, lr=0.001) # nadam
#optimizer = torch.optim.RMSprop(params, lr=0.001, weight_decay=0.9, momentum=0.9)
#lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95, last_epoch=-1)

# apply he initialization
def weights_init(module: nn.Module):
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.constant_(module.weight, 1)
        nn.init.constant_(module.bias, 0)
    elif isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
        nn.init.constant_(module.bias, 0)

model.apply(weights_init)

CCNN(
  (anchor_generator): AnchorGenerator()
  (corner_proposal): CornerProposal()
  (predictor): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
 

In [27]:
# setup tensorboard
import os
PRJ_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
import time
from torch.utils.tensorboard import SummaryWriter

root_logdir = "/kaggle/working/runs"
run_id = time.strftime("run_%Y_%m_%d-%H_%M_%S")
writer = SummaryWriter(os.path.join(root_logdir, run_id))

# compile model with the trainer
ccnn_compiled = Trainer(
    model, optimizer, nudtsirst_train,
    batch_size=32, 
    valid_ratio=0.003,
    valid_freq=0.3, # valid per epoch
    tensorboard=writer,
    output=os.path.join(PRJ_ROOT, "logs", "checkpoints", "ccnn_v1248_0201.pth"),
    input=os.path.join(PRJ_ROOT, "logs", "checkpoints", "ccnn_v0431_0201.pth"),
    metric=SIRSTMetrics()
)




In [30]:
# training the model
ccnn_compiled.MAX_LOAD_SIZE = 32
ccnn_compiled.fit() # 20 epochs

Epoch 1/20
1180/1180 [100%] - eta: 12.49s - train_loss: 1.6100 - valid_metrics: 0.0298
Epoch 2/20
1180/1180 [100%] - eta: 11.50s - train_loss: 1.2245 - valid_metrics: 0.1100
Epoch 3/20
1180/1180 [100%] - eta: 11.24s - train_loss: 1.1001 - valid_metrics: 0.1850
Epoch 4/20
1180/1180 [100%] - eta: 11.71s - train_loss: 0.0812 - valid_metrics: 0.2434
Epoch 5/20
1180/1180 [100%] - eta: 10.46s - train_loss: 0.0457 - valid_metrics: 0.2982
Epoch 6/20
1180/1180 [100%] - eta: 10.59s - train_loss: 0.0113 - valid_metrics: 0.3074
Epoch 7/20
1180/1180 [100%] - eta: 12.49s - train_loss: 0.0087 - valid_metrics: 0.3125
Epoch 8/20
1180/1180 [100%] - eta: 13.11s - train_loss: 0.0052 - valid_metrics: 0.3395
Epoch 9/20
1180/1180 [100%] - eta: 13.02s - train_loss: 0.0041 - valid_metrics: 0.3381
Epoch 10/20
1180/1180 [100%] - eta: 14.01s - train_loss: 0.0032 - valid_metrics: 0.3265
Epoch 11/20
1180/1180 [100%] - eta: 09.50s - train_loss: 0.0035 - valid_metrics: 0.3149
Epoch 12/20
1180/1180 [100%] - eta: 11.32

# **5 - Testing Phase**

In [26]:
# setup test dataloader
data_loader_test = DataLoader(
    nudtsirst_test, batch_size=32, 
    shuffle=False, collate_fn=lambda batch: tuple(zip(*batch)))

In [33]:
# testing the model
ccnn_compiled.MAX_LOAD_SIZE = 32
ccnn_compiled.evaluate(data_loader_test)

({'detection_rate_0.0': tensor(0.8130), 'detection_rate_0.5': tensor(0.3341), 'detection_rate_1.0': tensor(0.0025)}, {'false_alarm_rate_0.0': tensor(0.2271), 'false_alarm_rate_0.5': tensor(0.8120), 'false_alarm_rate_1.0': tensor(0.9999)})
