In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import torchvision
from torchvision.models._utils import IntermediateLayerGetter
from transformer import Encoder, Decoder
from box_ops import box_cxcywh_to_xyxy, generalized_box_iou
from scipy.optimize import linear_sum_assignment
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# Model

In [2]:
backbone = 'resnet50'

In [3]:
backbone = getattr(torchvision.models, backbone)(pretrained=True)
normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
backbone

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [4]:
return_layers = {'layer4': 0}
img_size = 256
channel = 2048
model_dim = 256
feature_size = int(256 / 32)
num_class = 101
N = 128

In [5]:
test_img = torch.rand((2, 3, img_size, img_size))
targets = {
    "labels" : torch.tensor([[1, 1, 3], [3, 3, 2]]),
    "boxes" : torch.rand((2, 3, 4))
}

In [6]:
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
positional_encoding = nn.Parameter(torch.rand(feature_size ** 2, channel))
linear = nn.Linear(channel, model_dim)
query = nn.Parameter(torch.rand(test_img.shape[0], N, model_dim))
ffn_box = nn.Sequential(
    nn.Linear(model_dim, model_dim),
    nn.ReLU(),
    nn.Linear(model_dim, model_dim),
    nn.ReLU(),
    nn.Linear(model_dim, 4),
    nn.ReLU()
)
ffn_class = nn.Sequential(
    nn.Linear(model_dim, model_dim),
    nn.ReLU(),
    nn.Linear(model_dim, model_dim),
    nn.ReLU(),
    nn.Linear(model_dim, num_class),
    nn.Softmax(dim=-1)
)

In [7]:
normalize(test_img)
features = backbone(test_img)[0]
features = features.view((features.shape[0], channel, -1)).contiguous().transpose(1,2)
features += positional_encoding
features = linear(features)

print(features.shape)

torch.Size([2, 64, 256])


In [8]:
encoder = Encoder()
decoder = Decoder()

In [9]:
encode = encoder(features)

print(encode.shape)

torch.Size([2, 64, 256])


In [10]:
decode_list = decoder(query, encode)

print(len(decode_list), decode_list[0].shape)

10 torch.Size([2, 128, 256])


In [11]:
out_list = []

for decode in decode_list:
    temp = {}
    temp['pred_boxes'] = ffn_box(decode)
    temp['pred_logits'] = ffn_class(decode)
    out_list.append(temp)

# Select Matching

In [12]:
class HungarianMatcher(nn.Module):
    def __init__(self, cost_class: float = 1, cost_box: float = 1, cost_giou: float = 1):
        super().__init__()
        self.cost_class = cost_class
        self.cost_box = cost_box
        self.cost_giou = cost_giou
        
    @torch.no_grad()
    def forward(self, out, targets):
        pred_logits = out["pred_logits"]
        pred_boxes = out["pred_boxes"]
        target_logits = targets["labels"]
        target_boxes = targets["boxes"]
        bs, num_queries = pred_logits.shape[:2]
        
        out_prob = pred_logits.flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_box = pred_boxes.flatten(0, 1)  # [batch_size * num_queries, 4]

        tgt_ids = torch.cat([v for v in target_logits]) # [batch_size * num_obj]
        tgt_box = torch.cat([v for v in target_boxes]) # [batch_size * num_obj, 4]

        # cost :
        #     row : pred_querys
        #     col : target_obj
        cost_class = -out_prob[:, tgt_ids]                                                          # [batch_size * num_queries, batch_size * num_obj]
        cost_box = torch.cdist(out_box, tgt_box, p=1)                                               # [batch_size * num_queries, batch_size * num_obj] 
        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_box), box_cxcywh_to_xyxy(tgt_box))  # [batch_size * num_queries, batch_size * num_obj]

        # Final cost matrix
        C = self.cost_box * cost_box + self.cost_class * cost_class + self.cost_giou * cost_giou
        C = C.view(bs, num_queries, -1).cpu() # [batch_size, num_queries, batch_size * num_obj]

        sizes = [len(v) for v in target_boxes]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        result = []
        for i, j in indices:
            i = torch.as_tensor(i, dtype=torch.int64)
            j = torch.as_tensor(j, dtype=torch.int64)
            result.append(i[j])
        return result


In [13]:
matcher = HungarianMatcher()

In [14]:
matchings = matcher(out_list[-1], targets)

In [15]:
matchings

[tensor([11, 22, 31]), tensor([31,  8, 72])]

# Loss

In [31]:
class HungarianLoss(nn.Module):
    def __init__(self, cost_class: float = 1, cost_box: float = 1, cost_giou: float = 1, cost_no_obj: float = 0.1):
        super().__init__()
        self.cost_class = cost_class
        self.cost_box = cost_box
        self.cost_giou = cost_giou
        self.cost_no_obj = cost_no_obj

    @torch.no_grad()
    def forward(self, out_list, targets, matchings):
        loss = 0
        for out in out_list:
            loss += self.cal_loss(out, targets, matchings)
        return loss / len(out_list)
    
    def cal_loss(self, out, targets, matchings):
        loss = 0
        for i, matching in enumerate(matchings):
            box_loss = self.cal_box_loss(out["pred_boxes"][i], targets["boxes"][i], matching)
            class_loss = self.cal_class_loss(out["pred_logits"][i], targets["labels"][i], matching)
            loss += box_loss
            loss += self.cost_class * class_loss
        return loss
    
    def cal_box_loss(self, out, target, matching):
        # out['pred_logits'] : [num_queries, 101]
        # out['pred_boxes'] : [num_queries, 4]
        # target['labels'] : [num_obj]
        # target['boxes'] : [num_obj, 4]
        # matching : [num_obj]

        num_boxes = matching.shape[0]
        pred_boxes = out[matching, :]
        l1_loss = F.l1_loss(pred_boxes, target)
        cost_giou = generalized_box_iou(box_cxcywh_to_xyxy(pred_boxes), box_cxcywh_to_xyxy(target))
        giou_loss = 1 - torch.diag(cost_giou)
        
        loss = (self.cost_box * l1_loss.sum() + self.cost_giou * giou_loss.sum()) / num_boxes
        return loss
        
    def cal_class_loss(self, out, target, matching):
        num_queries = out.shape[0]
        target_labels = torch.zeros((num_queries), dtype=torch.int64)
        cost_no_obj = torch.ones((num_queries)) * self.cost_no_obj
        target_labels[matching] = target
        cost_no_obj[matching] = 1
        
        class_loss = -torch.log(out[:, target_labels])
        class_loss *= cost_no_obj
        loss = class_loss.mean()
        return loss

In [32]:
Loss = HungarianLoss()

In [33]:
loss = Loss(out_list, targets, matchings)

In [34]:
loss

tensor(4.7395)