# Imports

In [None]:
import os
import datetime
from copy import deepcopy

import numpy as np
import pandas as pd

from pprint import pprint

import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch import Tensor
import torchvision.transforms as T
from torch.nn import Conv2d, Linear, AvgPool2d, Sigmoid
from torch import optim
from torchvision.ops.misc import ConvNormActivation

from PIL import Image
from typing import Union, Dict, List, Tuple, Any, Optional

# Logging preparation

In [None]:
def log(message, end='\n', print_out=False, file_path='log.txt'):
    """Append a message to a log file and optionally echo it to stdout.

    Args:
        message: anything printable; it is converted with ``str()`` so
            tensors, dicts, numbers, etc. can be logged directly.
        end: terminator written (and printed) after the message.
        print_out: if true, also ``print`` the message with the same ``end``.
        file_path: log file to append to (created if missing, UTF-8).
    """
    text = str(message)
    if print_out:
        print(text, end=end)
    # Append-only: the file is never read here, so 'a' is enough.
    with open(file_path, mode='a', encoding='utf-8') as f:
        f.write(text + end)

# Model preparation

In [None]:
class DilatedModule(nn.Module):
    """Dilated 3x3 block with a concatenation shortcut (Simpeff layers 10-13 / 14-17).

    Computes ``conv2(cat(x, conv1(dilated_conv(x))))``:
    a 32 -> 64 dilated 3x3 conv (replicate padding), a 1x1 conv back to
    32 channels, concatenation with the 32-channel input (64 channels) and a
    1x1 fusion conv down to 32 channels.

    Args:
        dilation: dilation of the 3x3 convolution.
        padding: padding of the 3x3 convolution. It must equal ``dilation``
            so the branch keeps the input's height and width, which the
            channel concatenation requires.
    """

    def __init__(self, dilation, padding):
        super().__init__()
        self.dilation = dilation
        self.dilated_conv = self._replicate_padding(ConvNormActivation(
            in_channels=32, out_channels=64, kernel_size=3, stride=1,
            padding=padding, dilation=dilation))
        self.conv1 = ConvNormActivation(in_channels=64, out_channels=32, kernel_size=1, stride=1)
        self.conv2 = ConvNormActivation(in_channels=64, out_channels=32, kernel_size=1, stride=1)

    def forward(self, x):
        """Map a 32-channel feature map to a 32-channel map of the same size."""
        y = self.conv1(self.dilated_conv(x))
        z = torch.cat((x, y), 1)
        return self.conv2(z)

    @staticmethod
    def _replicate_padding(block):
        """Switch the Conv2d inside a ConvNormActivation block to replicate padding.

        torchvision's ConvNormActivation takes no ``padding_mode`` argument
        (check the installed version's signature). The conv is always the
        block's first layer, and Conv2d reads ``padding_mode`` at forward time.
        So the mode is set on it after construction.
        """
        block[0].padding_mode = 'replicate'
        return block

In [None]:
class PassthroughModule(torch.nn.Module):
    """Passthrough (space-to-depth) fusion of a high- and a low-resolution map.

    ``forward(x, y)`` works as follows:

    * ``x`` has 64 channels and twice the height/width of ``y`` (both even).
    * ``y`` has 128 channels.
    * ``x`` is reduced to 16 channels by a 1x1 conv.
    * Each 2x2 neighbourhood of ``x`` is folded into channels (16 * 4 = 64).
    * The result is concatenated with ``y`` (64 + 128 = 192 channels).
    * A replicate-padded 3x3 conv fuses this to 128 channels at ``y``'s resolution.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = ConvNormActivation(in_channels=64, out_channels=16, kernel_size=1, stride=1)
        self.conv2 = self._replicate_padding(ConvNormActivation(
            in_channels=192, out_channels=128, kernel_size=3, stride=1, padding=1))

    def forward(self, x, y):
        x = self.conv1(x)
        # Space-to-depth: the four phase-shifted subsamplings become channels.
        x = torch.cat(
            [x[:, :, ::2, ::2], x[:, :, ::2, 1::2], x[:, :, 1::2, ::2], x[:, :, 1::2, 1::2]], 1)
        x = torch.cat((x, y), 1)
        return self.conv2(x)

    @staticmethod
    def _replicate_padding(block):
        """Switch the Conv2d inside a ConvNormActivation block to replicate padding.

        torchvision's ConvNormActivation takes no ``padding_mode`` argument
        (check the installed version's signature). The conv is the block's
        first layer, so the mode is set on it after construction.
        """
        block[0].padding_mode = 'replicate'
        return block

In [None]:
class Simpeff(torch.nn.Module):  # fixed
    """Small single-box regressor for 1-channel SIRST images.

    Layers are numbered after the original design (0-36).

    * The backbone downsamples by 8 overall (stride-2 convs 1, 4 and 21).
    * The passthrough module (27-30) fuses conv8's stride-4 features with
      conv25's stride-8 features.
    * ``conv32`` yields N = boxes * (classes + 4 + 1) = 1 * (1 + 4 + 1) = 6 channels.
    * These are globally averaged and mapped by ``fc35`` + sigmoid to 4 values
      in (0, 1), which the training loop compares with normalised XYXY boxes.

    Input is a ``[B, 1, H, W]`` tensor. The passthrough concatenation needs the
    stride-4 feature map to have even height and width, which holds when H and
    W are multiples of 8.

    torchvision's ConvNormActivation takes no ``padding_mode`` argument (check
    the installed version's signature). Replicate padding is therefore set on
    the wrapped Conv2d by ``_replicate_padding``.
    """

    def __init__(self):
        super().__init__()
        rep = self._replicate_padding
        # 0-8: backbone (stride-2 convs 1 and 4 -> stride 4)
        self.conv0 = rep(ConvNormActivation( 1, 16, (3, 3), 1, padding=1))
        self.conv1 = rep(ConvNormActivation(16, 32, (3, 3), 2, padding=1))
        self.conv2 = ConvNormActivation(32, 16, (1, 1), 1)
        self.conv3 = rep(ConvNormActivation(16, 32, (3, 3), 1, padding=1))
        self.conv4 = rep(ConvNormActivation(32, 64, (3, 3), 2, padding=1))

        self.conv5 = ConvNormActivation(64, 32, (1, 1), 1)
        self.conv6 = rep(ConvNormActivation(32, 64, (3, 3), 1, padding=1))
        self.conv7 = ConvNormActivation(64, 32, (1, 1), 1)
        self.conv8 = rep(ConvNormActivation(32, 64, (3, 3), 1, padding=1))
        # 9
        self.conv9 = ConvNormActivation(64, 32, (1, 1), 1)
        # 10-13 and 14-17: dilated blocks (padding == dilation keeps H, W)
        self.dila10_13 = DilatedModule(2, 2)
        self.dila14_17 = DilatedModule(4, 4)
        # 18: concat(conv9, dila10_13) -> 64 channels
        # 19
        self.conv19 = ConvNormActivation(64, 32, (1, 1), 1)
        # 20: concat(dila14_17, conv19) -> 64 channels
        # 21-25 (conv21 is stride 2 -> stride 8 overall)
        self.conv21 = rep(ConvNormActivation(64, 128, (3, 3), 2, padding=1))
        self.conv22 = ConvNormActivation(128, 64, (1, 1), 1)
        self.conv23 = rep(ConvNormActivation(64, 128, (3, 3), 1, padding=1))
        self.conv24 = ConvNormActivation(128, 64, (1, 1), 1)
        self.conv25 = rep(ConvNormActivation(64, 128, (3, 3), 1, padding=1))
        # 26: route conv8's output into the passthrough
        # 27-30
        self.pass27_30 = PassthroughModule()
        # 31
        self.conv31 = rep(ConvNormActivation(128, 256, (3, 3), 1, padding=1))
        # 32: N = boxes * (classes + 4 + 1) = 1 * (1 + 4 + 1)
        self.conv32 = ConvNormActivation(256, 6, (1, 1), 1)
        # 33: global average pool. AvgPool2d((128, 160)) only matched the
        # stride-8 map of a 1024x1280 input; the adaptive pool gives the same
        # result there and also works for other input sizes (no parameters,
        # so saved state_dicts are unaffected).
        self.avgpool33 = nn.AdaptiveAvgPool2d(1)
        # 34: flatten (in forward)
        self.fc35 = Linear(6, 4)
        self.sigmoid36 = Sigmoid()

    def forward(self, x):
        """Map a ``[B, 1, H, W]`` batch to ``[B, 4]`` sigmoid outputs."""
        x = self.conv0(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        x = self.conv8(x)           # stride 4, 64 channels; reused by the passthrough
        y = self.conv9(x)
        z = self.dila10_13(y)
        m = self.dila14_17(z)
        n = torch.cat((y, z), 1)    # 18
        p = self.conv19(n)
        q = torch.cat((m, p), 1)    # 20
        q = self.conv21(q)
        q = self.conv22(q)
        q = self.conv23(q)
        q = self.conv24(q)
        q = self.conv25(q)          # stride 8, 128 channels
        q = self.pass27_30(x, q)    # 26-30
        q = self.conv31(q)
        q = self.conv32(q)
        q = self.avgpool33(q)       # [B, 6, 1, 1]
        q = torch.flatten(q, 1)     # 34 -> [B, 6]
        q = self.fc35(q)
        return self.sigmoid36(q)

    @staticmethod
    def _replicate_padding(block):
        """Switch the Conv2d (first layer) of a ConvNormActivation block to replicate padding."""
        block[0].padding_mode = 'replicate'
        return block

# Loss preparation

In [None]:
# build loss function DIoU
class DIoULoss(torch.nn.Module):
    """
    Distance Intersection over Union Loss (Zhaohui Zheng et. al)
    https://arxiv.org/abs/1911.08287

    loss = 1 - IoU + rho^2(b, b_gt) / c^2

    Here rho is the distance between the two box centres, and c is the diagonal
    of the smallest box enclosing both.

    Args:
        eps (float): small number to prevent division by zero.
        reduction: None | 'none' | 'mean' | 'sum'
                 None / 'none': No reduction will be applied to the output.
                 'mean': The output will be averaged (0 for empty input).
                 'sum': The output will be summed.
        weights (Tensor, optional): multiplied element-wise into the per-box
            loss before reduction (must broadcast against it).

    Forward args:
        input, target (Tensor): box locations in XYXY format, shape (N, 4) or (4,).

    Raises:
        ValueError: if ``reduction`` is not one of the values above.
    """
    # Accepted string values for ``reduction`` (None is accepted as well).
    __constant__ = ["none", "sum", "mean"]

    def __init__(
        self,
        eps: float = 1e-7,
        reduction: Optional[str] = None,
        weights: Optional[Tensor] = None
        ):
        super().__init__()
        if reduction is not None and reduction not in self.__constant__:
            raise ValueError(
                f"reduction must be None, 'none', 'mean' or 'sum'; got {reduction!r}")
        self.eps = eps
        self.reduction = reduction
        self.weights = weights

    def forward(
        self,
        input: Tensor,
        target: Tensor
        ) -> Tensor:
        intsct, union = self._loss_inter_union(input, target)
        iou = intsct / (union + self.eps)

        x1, y1, x2, y2 = input.unbind(dim=-1)
        x1g, y1g, x2g, y2g = target.unbind(dim=-1)

        # Smallest enclosing box and its squared diagonal (c^2).
        xc1 = torch.min(x1, x1g)
        yc1 = torch.min(y1, y1g)
        xc2 = torch.max(x2, x2g)
        yc2 = torch.max(y2, y2g)
        diagonal_distance_squared = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + self.eps

        # Squared distance between the boxes' centres (rho^2).
        x_p = (x2 + x1) / 2
        y_p = (y2 + y1) / 2
        x_g = (x1g + x2g) / 2
        y_g = (y1g + y2g) / 2
        centers_distance_squared = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)

        # eqn. (7)
        loss = 1 - iou + (centers_distance_squared / diagonal_distance_squared)
        if self.weights is not None:
            loss = loss * self.weights

        if self.reduction == "mean":
            # 0.0 * sum() keeps the result attached to the graph for empty input.
            loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss

    def _loss_inter_union(
        self,
        boxes1: torch.Tensor,
        boxes2: torch.Tensor
        ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (intersection area, union area) of matching XYXY boxes."""
        x1, y1, x2, y2 = boxes1.unbind(dim=-1)
        x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)

        # Intersection keypoints
        xkis1 = torch.max(x1, x1g)
        ykis1 = torch.max(y1, y1g)
        xkis2 = torch.min(x2, x2g)
        ykis2 = torch.min(y2, y2g)

        # Only overlapping pairs get a non-zero intersection.
        intsctk = torch.zeros_like(x1)
        mask = (ykis2 > ykis1) & (xkis2 > xkis1)
        intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
        unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk

        return intsctk, unionk

# Data preparation

In [None]:
# build data transforms
class TransformComposer:
    """Apply a shared "default" transform list followed by a mode-specific one.

    ``transforms`` maps ``"default"``, ``"train"`` and ``"eval"`` to lists of
    callables. Calling the composer runs ``default + train`` when ``training``
    is true and ``default + eval`` otherwise. Missing keys count as empty lists.

    Args:
        training: selects the ``"train"`` or ``"eval"`` list.
        transforms: the transform dict.
            * If omitted, a fresh dict that converts a PIL image to a float32
              tensor is built for this instance.
            * If ``None``, all lists are empty.
            * A dict passed explicitly is used as-is.
    """

    # Marks "argument omitted". An explicit transforms=None keeps meaning
    # "no transforms", while the default dict is built per instance instead of
    # being one mutable object shared by every TransformComposer.
    _DEFAULT = object()

    def __init__(
        self,
        training: bool = True,
        transforms: Union[Dict[str, List], None] = _DEFAULT
        ):
        self.training = training
        if transforms is TransformComposer._DEFAULT:
            transforms = {"default": [T.ToTensor(), T.ConvertImageDtype(torch.float32)], "train": [], "eval": []}
        elif transforms is None:
            transforms = {"default": [], "train": [], "eval": []}
        self.transforms = transforms

    def __call__(self, sample):
        mode = "train" if self.training else "eval"
        pipeline = T.Compose(self.transforms.get("default", []) + self.transforms.get(mode, []))
        return pipeline(sample)
        
        
class XywhToXyxy:
    """Convert boxes from (x, y, width, height) to (xmin, ymin, xmax, ymax).

    Accepts any sequence of rows whose first four items are x, y, w, h, and
    returns a numpy array with one [xmin, ymin, xmax, ymax] row per box.
    """

    def __call__(self, xywh_boxes):
        corners = [
            [box[0], box[1], box[2] + box[0], box[3] + box[1]]
            for box in xywh_boxes
        ]
        return np.array(corners)


class SIRSTDataset(Dataset):
    """Base class for SIRST datasets backed by an annotation CSV.

    Reads the whole CSV with pandas (first line treated as the header) into
    ``img_labels`` as a numpy array, and stores the image directory and transforms.
    When a transform is not given (falsy), the image side gets a default
    ``TransformComposer()`` and the box side gets one that applies no
    transforms. Subclasses implement ``__len__`` and ``__getitem__``.
    """

    def __init__(
        self,
        annotations_file: str,
        img_dir: str,
        transform: Union[TransformComposer, None] = None,
        target_transform: Union[TransformComposer, None] = None
        ):
        labels_frame = pd.read_csv(annotations_file)
        self.img_labels = labels_frame.to_numpy()
        self.img_dir = img_dir
        self.transform = transform or TransformComposer()
        self.target_transform = target_transform or TransformComposer(transforms=None)

        
class NUDTSIRSTDataset(SIRSTDataset):
    """NUDT-SIRST dataset: one grayscale image and one box per annotation row.

    Each row of ``img_labels`` is read as follows:

    * column 0: the image path relative to ``img_dir`` (backslash separators
      are converted to ``os.sep``);
    * columns 1..-2: the box, converted by ``target_transform``;
    * last column: contrast, which is ignored.

    ``__getitem__`` returns ``(img, target)``. ``target`` holds:

    * ``boxes``: shape [4], divided by ``image_width`` / ``image_height``;
    * ``labels``: int64 ones;
    * ``image_id``: ``[ix]``;
    * ``scores``: ``[1.0]``.
    """

    # Image size assumed when normalising box coordinates (1024 x 1280).
    image_width = 1280
    image_height = 1024

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        super().__init__(annotations_file, img_dir, transform, target_transform)
        # Toggled by train()/eval(); item construction does not depend on it.
        self.training = True

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(
        self, ix: int
        ) -> Tuple[Tensor, Dict[str, Tensor]]:
        # Join img_dir exactly once; joining it twice duplicates a relative dir.
        img_path = os.path.join(self.img_dir, self.img_labels[ix, 0]).replace('\\', os.sep)
        # convert() returns a new image, so the file can be closed right away.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('L')
        boxes = self.img_labels[ix:ix + 1, 1:-1]  # 2-D (1, k) slice; drops path and contrast

        if self.transform:
            img = self.transform(img)
        if self.target_transform:
            boxes = self.target_transform(boxes)

        # Scale pixel coordinates by the image size.
        scale = np.array([self.image_width, self.image_height, self.image_width, self.image_height])
        boxes = boxes / scale

        target = {
            'boxes': torch.from_numpy(boxes.astype('float32')).reshape(4),
            'labels': torch.ones(len(boxes), dtype=torch.int64),
            'image_id': torch.tensor([ix]),
            'scores': torch.ones(1),
        }
        return img, target

    def train(self, mode=True):
        """Set the ``training`` flag (mirrors ``eval``)."""
        self.training = mode

    def eval(self):
        """Clear the ``training`` flag."""
        self.training = False

In [None]:
# Set up the training split.
# Images: the default TransformComposer (PIL image -> float32 tensor).
# Boxes: XywhToXyxy turns each (x, y, w, h) annotation into (x1, y1, x2, y2).
nudtsirst_train = NUDTSIRSTDataset(
   '/kaggle/input/nudtsirst/annotation_train.csv', 
   '/kaggle/input/nudtsirst/nudtsirst', 
   transform=TransformComposer(), 
   target_transform=TransformComposer(
       transforms={"default": [XywhToXyxy()],}
   )
)

# Display the dataset object as the cell output.
nudtsirst_train

In [None]:
# Set up the test split with the same image/box transforms as training.
# NOTE: these TransformComposers keep training=True; their "train"/"eval"
# lists are empty here, so the mode makes no difference.
nudtsirst_test = NUDTSIRSTDataset(
   '/kaggle/input/nudtsirst/annotation_test.csv', 
   '/kaggle/input/nudtsirst/nudtsirst', 
   transform=TransformComposer(), 
   target_transform=TransformComposer(
       transforms={"default": [XywhToXyxy()],}
   )
)

# Only sets the dataset's `training` flag to False.
nudtsirst_test.eval()

# Test preparation

In [None]:
# build the custom metrics for SIRST 
class SIRSTMetrics:
    """Accumulate detection and false-alarm rates for box predictions.

    For every image passed to :meth:`update`, each (prediction, ground-truth)
    pair is scored by IoU. For each threshold ``t`` in ``iou_thresholds``:

    * a pair with ``IoU > t`` counts as a true positive;
    * any other pair counts as a false positive.

    :meth:`compute` then reports:

    * ``detection_rate_<t>   = true_pos  / n_ground_truths``
    * ``false_alarm_rate_<t> = false_pos / n_predictions``

    Counts accumulate across :meth:`update` calls until :meth:`reset`.

    Args:
        iou_thresholds: IoU thresholds to evaluate (copied, never shared).
        eps: added to the denominators to avoid division by zero.
    """

    def __init__(
        self,
        iou_thresholds: Union[List[float], Tuple[float, ...]] = (0.0, 0.5, 1.0),
        eps: float = 1e-7
        ):
        self.iou_thresholds = list(iou_thresholds)
        self.eps = eps
        # Device of the latest targets. CPU until update() sees data, so that
        # compute() also works on an empty metric.
        self._device = torch.device("cpu")
        self.reset()

    def reset(self) -> None:
        """Clear all accumulated counts (thresholds and eps are kept)."""
        self.true_pos = [0] * len(self.iou_thresholds)
        self.false_pos = [0] * len(self.iou_thresholds)
        self.n_preds = 0
        self.n_gts = 0

    def compute(self) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
        """Return ``(detection_rates, false_alarm_rates)``, each keyed by threshold."""
        true_pos = torch.tensor(self.true_pos, dtype=torch.float32, device=self._device)
        false_pos = torch.tensor(self.false_pos, dtype=torch.float32, device=self._device)
        detect_rate = true_pos / (self.n_gts + self.eps)  # true / n_targets
        false_alarm = false_pos / (self.n_preds + self.eps)  # false / n_preds
        detection = {f"detection_rate_{t}": detect_rate[i] for i, t in enumerate(self.iou_thresholds)}
        false_alarms = {f"false_alarm_rate_{t}": false_alarm[i] for i, t in enumerate(self.iou_thresholds)}
        return detection, false_alarms

    def update(
        self,
        preds: List[Dict[str, Tensor]],
        targets: List[Dict[str, Tensor]]
        ):
        """Add one batch.

        ``preds[i]["boxes"]`` is [N_i, 4] and ``targets[i]["boxes"]`` is
        [M_i, 4], both in XYXY format. N_i and M_i may differ between images.

        Raises:
            ValueError: if ``preds`` and ``targets`` differ in length.
        """
        if len(preds) != len(targets):
            raise ValueError(f"got {len(preds)} predictions for {len(targets)} targets")
        if len(targets) > 0:
            self._device = targets[0]["boxes"].device
        for pred, target in zip(preds, targets):
            pred_boxes = pred["boxes"].to(self._device)
            target_boxes = target["boxes"].to(self._device)
            # Per-image [N, M] matrix: no padding, so only real pairs are counted.
            ious = self._box_iou(pred_boxes, target_boxes)
            self.n_preds += pred_boxes.size(dim=0)
            self.n_gts += target_boxes.size(dim=0)
            for i, iou_threshold in enumerate(self.iou_thresholds):
                self.true_pos[i] += int((ious > iou_threshold).sum())
                self.false_pos[i] += int((ious <= iou_threshold).sum())

    @staticmethod
    def _box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
        """Pairwise IoU of XYXY boxes, [N, 4] x [M, 4] -> [N, M] (same formula as torchvision.ops.box_iou)."""
        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
        top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])
        bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])
        wh = (bottom_right - top_left).clamp(min=0)
        inter = wh[..., 0] * wh[..., 1]
        union = area1[:, None] + area2 - inter
        return inter / union

def collate_as_tuples(batch):
    """Collate ``[(img, target), ...]`` into ``(imgs_tuple, targets_tuple)`` without stacking.

    A named function, unlike a lambda, can be pickled if the loader ever uses
    worker processes.
    """
    return tuple(zip(*batch))


# BATCH_SIZE_TEST is (re)defined in the "Training preparation" cell further
# down. Fall back to the same value (8) so this cell also runs top to bottom.
BATCH_SIZE_TEST = globals().get('BATCH_SIZE_TEST', 8)
data_loader_test = DataLoader(nudtsirst_test, batch_size=BATCH_SIZE_TEST,
                              collate_fn=collate_as_tuples)
metric = SIRSTMetrics()

In [None]:
# build test loop
def test(model, device=None):
    """Evaluate ``model`` on ``data_loader_test`` and log the SIRST metrics.

    The model runs in eval mode, so BatchNorm uses (and does not update) its
    running statistics, and without autograd. Its previous train/eval mode
    is restored afterwards. Counts go into a fresh SIRSTMetrics built with the
    global ``metric``'s thresholds and eps, so repeated calls (e.g. once per
    epoch) do not accumulate earlier results.

    Args:
        model: maps a batch of images to one XYXY box per image, shape
            [B, 4] (as Simpeff produces).
        device: device to run on. Defaults to the device of the model's first
            parameter.

    Returns:
        The ``(detection_rates, false_alarm_rates)`` tuple from
        ``SIRSTMetrics.compute``. It is also logged.
    """
    if device is None:
        device = next(model.parameters()).device
    evaluator = SIRSTMetrics(metric.iou_thresholds, metric.eps)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch_test_no, (images, targets) in enumerate(data_loader_test):
                pred_boxes = model(torch.stack(images).to(device))  # [B, 4]

                # The metric expects per-image [N, 4] boxes. Build new dicts
                # instead of mutating the dataset's target dicts.
                preds = [{**tgt, 'boxes': pred_boxes[i:i + 1]} for i, tgt in enumerate(targets)]
                gts = [{**tgt, 'boxes': tgt['boxes'].unsqueeze(0).to(device)} for tgt in targets]
                evaluator.update(preds, gts)

                if batch_test_no % 50 == 0:
                    log(f'Updated batch_test_no={batch_test_no}', print_out=True)

                torch.cuda.empty_cache()
    finally:
        model.train(was_training)

    res = evaluator.compute()
    log(str(res), print_out=True)
    return res

# Saving and loading preparation

In [None]:
def save(model, epoch_no):
    """Save ``model``'s state_dict as ``saved_simpeff_model_<epoch_no>.pth`` in the working directory, then log it."""
    torch.save(model.state_dict(), f'saved_simpeff_model_{epoch_no}.pth')
    log(f'Saved model: epoch_no={epoch_no}')

# Training preparation

In [None]:
# Hyper-parameters and training objects.
EPOCHS = 8
BATCH_SIZE = 8
BATCH_SIZE_TEST = 8  # also read by the test DataLoader cell
data_loader_train = DataLoader(nudtsirst_train, batch_size=BATCH_SIZE, shuffle=True)
# DIoU loss averaged over the batch.
criterion = DIoULoss(reduction='mean')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Display the selected device as the cell output.
device

device(type='cuda')

In [None]:
# Fresh model on `device`; Adam with default hyper-parameters over the
# parameters that require gradients.
model = Simpeff().to(device)
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad])

In [None]:
training_losses = [[] for _ in range(EPOCHS)]  # one list of per-batch losses per epoch

In [None]:
for epoch_no in range(EPOCHS):
    log(f'Epoch: {epoch_no}', print_out=True)
    # An evaluation pass may have left the model in eval mode; BatchNorm must
    # use batch statistics while training.
    model.train()

    for batch_no, batch in enumerate(data_loader_train):
        log_check = batch_no % 200 == 0

        if log_check:
            log(f'Batch: {batch_no}', end='\t', print_out=True)

        optimizer.zero_grad()

        # Only the parameters need gradients; the input images do not.
        inp = batch[0].to(device)
        # Default collation stacks the per-image [4] boxes into [B, 4].
        target = batch[1]['boxes'].to(device)
        out = model(inp)

        loss_mean = criterion(out, target)  # scalar: the criterion averages over the batch
        loss_mean.backward()

        optimizer.step()

        # A plain float prints cleanly and keeps no graph or device memory alive.
        loss_value = loss_mean.item()
        if log_check:
            log(f'Loss mean: {loss_value}', print_out=True)

        training_losses[epoch_no].append(loss_value)

        torch.cuda.empty_cache()

    save(model, epoch_no)
    log(f'Training losses: {training_losses}')
    test(model, device)


Epoch: 0
Batch: 0	Loss mean: 2.3072590827941895
Batch: 200	Loss mean: 1.0716334581375122
Batch: 400	Loss mean: 1.0623656511306763
Batch: 600	Loss mean: 1.060208797454834
Batch: 800	Loss mean: 1.070434808731079
Batch: 1000	Loss mean: 1.05238676071167
Batch: 1200	Loss mean: 1.0686883926391602
Batch: 1400	Loss mean: 1.0557748079299927
Batch: 1600	Loss mean: 1.0439832210540771
Batch: 1800	Loss mean: 1.0894758701324463
Batch: 2000	Loss mean: 1.0449298620224
Batch: 2200	Loss mean: 1.0640153884887695
Batch: 2400	Loss mean: 1.056121826171875
Batch: 2600	Loss mean: 1.047910213470459
Batch: 2800	Loss mean: 1.0540080070495605
Batch: 3000	Loss mean: 1.0756124258041382
Batch: 3200	Loss mean: 1.044912338256836
Batch: 3400	Loss mean: 1.0450358390808105
Batch: 3600	Loss mean: 1.0475590229034424
Batch: 3800	Loss mean: 1.058254361152649
Batch: 4000	Loss mean: 1.0462555885314941
Batch: 4200	Loss mean: 1.0467901229858398
Batch: 4400	Loss mean: 1.0575460195541382
Batch: 4600	Loss mean: 1.0690003633499146
U

# Re-evaluate the model with SIRST Metrics

In [None]:
# Reload a saved checkpoint for evaluation, on GPU when available.
# map_location also lets a GPU-saved checkpoint load on a CPU-only machine.
# NOTE(review): an earlier run of this cell printed plain Conv2d layers (keys
# like "conv0.weight"). A Simpeff built from ConvNormActivation blocks uses
# keys like "conv0.0.weight", so this checkpoint may need re-training or key
# remapping before load_state_dict succeeds. Verify.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Simpeff()
model.load_state_dict(torch.load('simpeff_v0000_0101.pth', map_location=device))
model.eval()
model.to(device)

Simpeff(
  (conv0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=replicate)
  (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), padding_mode=replicate)
  (conv2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
  (conv3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=replicate)
  (conv4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), padding_mode=replicate)
  (conv5): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  (conv6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=replicate)
  (conv7): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  (conv8): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=replicate)
  (conv9): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  (dila10_13): DilatedModule(
    (dilated_conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), padding_m

In [None]:
# Pass the device explicitly; `device` is set when the model is loaded above.
test(model, device)

Updated batch_test_no=0
Updated batch_test_no=50
Updated batch_test_no=100
Updated batch_test_no=150
Updated batch_test_no=200
Updated batch_test_no=250
Updated batch_test_no=300
Updated batch_test_no=350
Updated batch_test_no=400
Updated batch_test_no=450
Updated batch_test_no=500
({'detection_rate_0.0': tensor(0.3951, device='cuda:0'), 'detection_rate_0.5': tensor(0.1572, device='cuda:0'), 'detection_rate_1.0': tensor(0., device='cuda:0')}, {'false_alarm_rate_0.0': tensor(0.6049, device='cuda:0'), 'false_alarm_rate_0.5': tensor(0.8428, device='cuda:0'), 'false_alarm_rate_1.0': tensor(1., device='cuda:0')})
