## Import

In [1]:
import os
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW, lr_scheduler
import transforms as T
from torchvision import datasets, ops
from torchvision.models.feature_extraction import create_feature_extractor
from einops import rearrange
from pycocotools import mask as coco_mask

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import linear_sum_assignment

## build dataset

In [2]:
def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks

class ConvertCocoPolysToMask(object):
    def __init__(self, return_masks=False):
        self.return_masks = return_masks

    def __call__(self, image, target):
        w, h = image.size

        image_id = target["image_id"]
        image_id = torch.tensor([image_id])

        anno = target["annotations"]

        anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        if self.return_masks:
            segmentations = [obj["segmentation"] for obj in anno]
            masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if self.return_masks:
            masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        if self.return_masks:
            target["masks"] = masks
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])

        return image, target

class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, transforms, return_masks):
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)

    def __getitem__(self, idx):
        img, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}
        img, target = self.prepare(img, target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target

def make_coco_transforms(image_set):

    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]

    if image_set == 'train':
        return T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomSelect(
                T.RandomResize(scales, max_size=1333),
                T.Compose([
                    T.RandomResize([400, 500, 600]),
                    T.RandomSizeCrop(384, 600),
                    T.RandomResize(scales, max_size=1333),
                ])
            ),
            normalize,
        ])

    if image_set == 'val':
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            normalize,
        ])

    raise ValueError(f'unknown {image_set}')

def build(image_set):
    mode = 'instances'

    PATHS = {
        "train": (os.path.join('tiny_coco', "train2017"), os.path.join('tiny_coco', "annotations", f'{mode}_train2017.json')),
        "val": (os.path.join('tiny_coco', "val2017"), os.path.join('tiny_coco', "annotations", f'{mode}_val2017.json')),
    }
    
    img_folder, ann_file = PATHS[image_set]
    dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=False)
    return dataset

## build DataLoader

In [3]:
from typing import Optional, List
from torch import Tensor

class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            assert mask is not None
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)

def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        
        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)

def collate_fn(batch):
    '''
    original dataloader return like: [(image1, annotations1), (image2, annotations2), ...]
    
    list(zip(*batch)) => [(image1, image2, ...), (annotations1, annotations2, ...)]
    
    then to  nested_tensor
    '''
    batch = list(zip(*batch))
    batch[0] = nested_tensor_from_tensor_list(batch[0])
    return tuple(batch)

DataLoader Return: NestedTensor, Tuple of (annotations) => Tuple(Tensors, masks), Tuple of (annotations)

In [4]:
train_dataset = build(image_set='train')
val_dataset = build(image_set='val')

train_dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=collate_fn, num_workers=0, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=2, collate_fn=collate_fn, num_workers=0)

print(f'\nNumber of training samples: {len(train_dataloader)}')
print(f'\nNumber of training samples: {len(val_dataloader)}')

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!

Number of training samples: 25

Number of training samples: 25


### Test code

In [5]:
# NestedTensor, Tuple of (annotations) => Tuple(Tensors, masks), Tuple of (annotations)
samples, target = next(iter(train_dataloader))

print(samples.tensors.shape)
print(samples.mask.shape)

print(target[0])
print(target[1])

torch.Size([2, 3, 813, 1056])
torch.Size([2, 813, 1056])
{'boxes': tensor([[0.3879, 0.7730, 0.1526, 0.2436],
        [0.5000, 0.5572, 1.0000, 0.5269],
        [0.4271, 0.7010, 0.2176, 0.3278],
        [0.2444, 0.3577, 0.1440, 0.1080]]), 'labels': tensor([ 2,  7,  1, 13]), 'image_id': tensor([483108]), 'area': tensor([  9380.2393, 103913.7656,  13548.4082,   5039.4087]), 'iscrowd': tensor([0, 0, 0, 0]), 'orig_size': tensor([640, 428]), 'size': tensor([813, 544])}
{'boxes': tensor([[0.0812, 0.4834, 0.0296, 0.1382],
        [0.8932, 0.5715, 0.0330, 0.1398],
        [0.4591, 0.5126, 0.0257, 0.0382],
        [0.8218, 0.5319, 0.0368, 0.0502],
        [0.3590, 0.1519, 0.0466, 0.0247],
        [0.6389, 0.7607, 0.1588, 0.3618],
        [0.6016, 0.5017, 0.0522, 0.0675],
        [0.1422, 0.2492, 0.0227, 0.0537],
        [0.3495, 0.5074, 0.0245, 0.0581],
        [0.1414, 0.3356, 0.0226, 0.0198],
        [0.7196, 0.3524, 0.1794, 0.1756],
        [0.6877, 0.5000, 0.0508, 0.1117],
        [0.6620, 0.

## The DETR Architecture

### position encoding

LearnedPositionEmbedding 仅适用于图像长宽较短( < 50)的情况，如果在backbone提取特征时采用空洞卷积(dilation convolution)

往往会导致特征图宽高不够小导致越界，更大的特征图推荐使用sine来表示

In [6]:
import math

class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return pos

def build_position_encoding(hidden_dim=256, position_embedding = 'sine'):
    N_steps = hidden_dim // 2
    if position_embedding in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
    elif position_embedding in ('v3', 'learned'):
        position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"not supported {position_embedding}")

    return position_embedding

### Test position encoding

In [7]:
test_position_encoding_model = build_position_encoding()
# samples: torch.Size([2, 3, 1199, 800])
test_position_encoding = test_position_encoding_model(samples)
print(test_position_encoding.shape)

torch.Size([2, 256, 813, 1056])


### Backbone

In [8]:
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List
import torch.nn.functional as F

class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rqsrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias


class BackboneBase(nn.Module):

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
                parameter.requires_grad_(False)
        if return_interm_layers:
            return_layers = {"layer1": "layer1", "layer2": "layer2", "layer3": "layer3", "layer4": "layer4"}
        else:
            return_layers = {'layer4': "layer4"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        xs = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for name, x in xs.items():
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(x, mask)
        return out


class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        
        replace_stride_with_dilation = [False, False, dilation]
        norm_layer = FrozenBatchNorm2d

        weights = None
        if name == 'resnet18':
            weights = torchvision.models.ResNet18_Weights.DEFAULT
            backbone = torchvision.models.resnet18(weights=weights, replace_stride_with_dilation=replace_stride_with_dilation, norm_layer=norm_layer)
            num_channels = 512
        elif name == 'resnet34':
            weights = torchvision.models.ResNet34_Weights.DEFAULT
            backbone = torchvision.models.resnet34(weights=weights, replace_stride_with_dilation=replace_stride_with_dilation, norm_layer=norm_layer)
            num_channels = 512
        elif name == 'resnet50':
            weights = torchvision.models.ResNet50_Weights.DEFAULT
            backbone = torchvision.models.resnet50(weights=weights, replace_stride_with_dilation=replace_stride_with_dilation, norm_layer=norm_layer)
            num_channels = 2048
        elif name == 'resnet101':
            weights = torchvision.models.ResNet101_Weights.DEFAULT
            backbone = torchvision.models.resnet101(weights=weights, replace_stride_with_dilation=replace_stride_with_dilation, norm_layer=norm_layer)
            num_channels = 2048
        else:
            raise ValueError(f"Unsupported backbone: {name}")
        
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)

#### test backbone

In [9]:
backbone = Backbone(name='resnet50', train_backbone=False, return_interm_layers=True, dilation=True)

# Dict of {layer_{i}: nestedTensor}
backbone_out_nestedTensor = backbone(samples)

for k, v in backbone_out_nestedTensor.items():
    print(f"{k} tensor shape : {v.tensors.shape}, mask shape: {v.mask.shape}")

layer1 tensor shape : torch.Size([2, 256, 204, 264]), mask shape: torch.Size([2, 204, 264])
layer2 tensor shape : torch.Size([2, 512, 102, 132]), mask shape: torch.Size([2, 102, 132])
layer3 tensor shape : torch.Size([2, 1024, 51, 66]), mask shape: torch.Size([2, 51, 66])
layer4 tensor shape : torch.Size([2, 2048, 51, 66]), mask shape: torch.Size([2, 51, 66])


#### test position with backbone_feature

In [10]:
backbone_feature_position_enc = test_position_encoding_model(backbone_out_nestedTensor['layer4'])
print(backbone_feature_position_enc.shape)

torch.Size([2, 256, 51, 66])


### Joiner position & backbone

经过 backbone 返回的是层数名称和对应的 NestedTensor => 

```python
layer1 tensor shape : torch.Size([2, 256, 300, 200]), mask shape: torch.Size([2, 300, 200])

layer2 tensor shape : torch.Size([2, 512, 150, 100]), mask shape: torch.Size([2, 150, 100])

layer3 tensor shape : torch.Size([2, 1024, 75, 50]), mask shape: torch.Size([2, 75, 50])

layer4 tensor shape : torch.Size([2, 2048, 75, 50]), mask shape: torch.Size([2, 75, 50])
```

把每层的输出存到list里

In [11]:
class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for name, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self[1](x).to(x.tensors.dtype))

        return out, pos

### test joiner

In [12]:
test_joiner = Joiner(backbone, test_position_encoding_model)

test_joiner_out, test_joiner_pos = test_joiner(samples)

print(f"test_joiner_out has length: {len(test_joiner_out)}")
for i in range(len(test_joiner_out)):
    print(f'test_joiner_out layer_{i+1} has shape: {test_joiner_out[i].tensors.shape}')
    
    
print(f"\ntest_joiner_pos has length: {len(test_joiner_pos)}")
for i in range(len(test_joiner_pos)):
    print(f'test_joiner_pos layer_{i+1} has shape: {test_joiner_pos[i].shape}')


test_joiner_out has length: 4
test_joiner_out layer_1 has shape: torch.Size([2, 256, 204, 264])
test_joiner_out layer_2 has shape: torch.Size([2, 512, 102, 132])
test_joiner_out layer_3 has shape: torch.Size([2, 1024, 51, 66])
test_joiner_out layer_4 has shape: torch.Size([2, 2048, 51, 66])

test_joiner_pos has length: 4
test_joiner_pos layer_1 has shape: torch.Size([2, 256, 204, 264])
test_joiner_pos layer_2 has shape: torch.Size([2, 256, 102, 132])
test_joiner_pos layer_3 has shape: torch.Size([2, 256, 51, 66])
test_joiner_pos layer_4 has shape: torch.Size([2, 256, 51, 66])


### build backbone (joiner)

In [13]:
def build_backbone(
    hidden_dim: int,
    position_embedding: str = 'sine',
    backbone: str = 'resnet50',
    return_interm_layers: bool = False,
    dilation: bool = False
):
    position_embedding = build_position_encoding(hidden_dim=256, position_embedding='sine')
    train_backbone = False
    backbone = Backbone(backbone, train_backbone, return_interm_layers, dilation)
    model = Joiner(backbone, position_embedding)
    return model

### Transformer

#### Encoder

In [14]:
import copy

def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

In [15]:
class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)

class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src

        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output

#### Test Encoder

In [16]:
d_model = 256
normalize_before = True

test_encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=8, dim_feedforward=2048, activation="relu", normalize_before=normalize_before)
test_last_encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
# 采用pre_norm需要在最后的输出后面再进行一次归一化
test_encoder = TransformerEncoder(test_encoder_layer, num_layers=6, norm=test_last_encoder_norm)

print(test_encoder)

TransformerEncoder(
  (layers): ModuleList(
    (0-5): 6 x TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      )
      (linear1): Linear(in_features=256, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=256, bias=True)
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
  )
  (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
)


In [17]:
num_channels = 2048
d_model = 256

encoder_test_conv1x1 = nn.Conv2d(num_channels, d_model, kernel_size=1, stride=1)
# samples: NestedTensor => joiner(samples) => out: list[NestedTensor],  pos: list[tensor]   (NestedTensor: (Tuple(Tensors, masks)) )
encoder_test_joiner_out, encoder_test_joiner_pos = test_joiner(samples)

encoder_test_last_feature, encoder_test_last_mask = encoder_test_joiner_out[-1].decompose()
encoder_test_last_feature_uni_channel = encoder_test_conv1x1(encoder_test_last_feature)

print(f"tensor of shape: {encoder_test_last_feature_uni_channel.shape} with mask: {encoder_test_last_mask.shape} will go in encoder, \
      and its pos got {encoder_test_joiner_pos[-1].shape}")

# RuntimeError: Given normalized_shape=[256], expected input with shape [*, 256], but got input of size[2, 256, 35, 34]
# flatten NxCxHxW to HWxNxC
encoder_test_last_feature_uni_channel_flatten = rearrange(encoder_test_last_feature_uni_channel, "b c h w -> (h w) b c")
encoder_test_last_mask_flatten = rearrange(encoder_test_last_mask, 'b h w -> b (h w)')
encoder_test_joiner_pos_flatten = rearrange(encoder_test_joiner_pos[-1], "b c h w -> (h w) b c")

memory = test_encoder.forward(src=encoder_test_last_feature_uni_channel_flatten, mask=None, src_key_padding_mask=encoder_test_last_mask_flatten, pos=encoder_test_joiner_pos_flatten)

print(memory.shape)

tensor of shape: torch.Size([2, 256, 51, 66]) with mask: torch.Size([2, 51, 66]) will go in encoder,       and its pos got torch.Size([2, 256, 51, 66])
torch.Size([3366, 2, 256])


#### Decoder

In [18]:
class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)

#### Test Decoder

In [20]:
normalize_before = True
num_queries = 100

query_embed = nn.Embedding(num_queries, 256)
print(query_embed.weight.shape)


test_decoder_layer = TransformerDecoderLayer(d_model=256, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=True)
test_decoder_norm = nn.LayerNorm(256) if normalize_before else None

# if return_intermediate=False : out hs shape: torch.Size([1, 2, 100, 256]),out memory shape: torch.Size([2, 256, 2024])
# else True : 
test_decoder = TransformerDecoder(test_decoder_layer, num_layers=6, norm=test_decoder_norm, return_intermediate=True)

sim_memory = torch.rand(size=(3366, 2, 256))

# samples: NestedTensor => joiner(samples) => out: list[NestedTensor],  pos: list[tensor]   (NestedTensor: (Tuple(Tensors, masks)) )
encoder_test_joiner_out, encoder_test_joiner_pos = test_joiner(samples)

encoder_test_last_feature, encoder_test_last_mask = encoder_test_joiner_out[-1].decompose()

# RuntimeError: Given normalized_shape=[256], expected input with shape [*, 256], but got input of size[2, 256, 35, 34]
# flatten NxCxHxW to HWxNxC

encoder_test_last_mask_flatten = rearrange(encoder_test_last_mask, 'b h w -> b (h w)')
encoder_test_joiner_pos_flatten = rearrange(encoder_test_joiner_pos[-1], "b c h w -> (h w) b c")
encoder_test_query_embed = query_embed.weight.unsqueeze(1).repeat(1, 2, 1)

tgt = torch.zeros_like(encoder_test_query_embed)

print(f"tgt shape: {tgt.shape}, memory shape: {sim_memory.shape}, memory_key_padding_mask shape: {encoder_test_last_mask_flatten.shape} \
      \npos shape: {encoder_test_joiner_pos_flatten.shape}, query_pos shape: {encoder_test_query_embed.shape}")

hs = test_decoder.forward(tgt=tgt, memory=sim_memory, memory_key_padding_mask=encoder_test_last_mask_flatten,
                          pos=encoder_test_joiner_pos_flatten, query_pos=encoder_test_query_embed)

print(f'hs shape: {hs.shape}')

# memory need to reshape(b, c, h, w)
out_hs, out_memory = hs.transpose(1, 2), sim_memory.permute(1, 2, 0)

print(f'out hs shape: {out_hs.shape},out memory shape: {out_memory.shape}')


torch.Size([100, 256])
tgt shape: torch.Size([100, 2, 256]), memory shape: torch.Size([3366, 2, 256]), memory_key_padding_mask shape: torch.Size([2, 3366])       
pos shape: torch.Size([3366, 2, 256]), query_pos shape: torch.Size([100, 2, 256])
hs shape: torch.Size([6, 100, 2, 256])
out hs shape: torch.Size([6, 2, 100, 256]),out memory shape: torch.Size([2, 256, 3366])


#### Transformer All in one

In [21]:
class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, query_embed, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = rearrange(src, "b c h w -> (h w) b c")
        pos_embed = rearrange(pos_embed, "b c h w -> (h w) b c")
        mask = rearrange(mask, 'b h w -> b (h w)')
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        
        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                          pos=pos_embed, query_pos=query_embed)
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)

#### Test Transformer

In [22]:
d_model = 256
num_queries = 100

test_Transformer = Transformer(d_model=d_model, nhead=8, num_encoder_layers=6, num_decoder_layers=6,
                               dim_feedforward=1024, dropout=0.1,activation="relu", normalize_before=True, return_intermediate_dec=True)

transformer_features, transformer_pos = test_joiner(samples)
transformer_src, transformer_mask = transformer_features[-1].decompose()
print(f'src shape: {transformer_src.shape}')

assert transformer_mask is not None
transformer_conv1x1 = nn.Conv2d(2048, d_model, kernel_size=1, stride=1)
transformer_embed = nn.Embedding(num_queries, d_model)
transformer_hs, transformer_memory = test_Transformer.forward(transformer_conv1x1(transformer_src), transformer_mask,
                                                              transformer_embed.weight, transformer_pos[-1])

# (num_decoder_layer, batch_size, num_queries, d_model), memory : like src
print(f'hs shape: {transformer_hs.shape}, memory shape: {transformer_memory.shape}')

src shape: torch.Size([2, 2048, 51, 66])
hs shape: torch.Size([6, 2, 100, 256]), memory shape: torch.Size([2, 256, 51, 66])


## DETR

In [23]:
class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

class DETR(nn.Module):
    """ This is the DETR module that performs object detection """
    def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):
        """ Initializes the model.
        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            num_classes: number of object classes
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.input_proj = nn.Conv2d(2048, hidden_dim, kernel_size=1)
        self.backbone = backbone
        self.aux_loss = aux_loss

    def forward(self, samples: NestedTensor):
        """ The forward expects a NestedTensor, which consists of:
               - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
               - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels

            It returns a dict with the following elements:
               - "pred_logits": the classification logits (including no-object) for all queries.
                                Shape= [batch_size x num_queries x (num_classes + 1)]
               - "pred_boxes": The normalized boxes coordinates for all queries, represented as
                               (center_x, center_y, height, width). These values are normalized in [0, 1],
                               relative to the size of each individual image (disregarding possible padding).
                               See PostProcess for information on how to retrieve the unnormalized bounding box.
               - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
                                dictionnaries containing the two above keys for each decoder layer.
        """
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        src, mask = features[-1].decompose()
        assert mask is not None
        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]

        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        # outputs_class[-1] 得到最后一层的decoder输出 outputs_class: [6, 2, 100, 92] => [-1] : [2, 100, 92]
        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_logits': a, 'pred_boxes': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]


#### Test DETR

In [24]:
d_model = 256

backbone = build_backbone(hidden_dim=d_model, return_interm_layers=True, dilation=True)
transformer = Transformer(d_model=d_model, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
                          activation='relu', normalize_before=True, return_intermediate_dec=True)

test_detr = DETR(backbone, transformer, num_classes=91, num_queries=100, aux_loss=True)

# out: {
#       'pred_logits': _,
#       'pred_boxes': _
#       'aux_outputs'[Optional]: ['pred_logits': _, 'pred_boxes': _] x num_decoder_layers - 1
# }
test_detr_out = test_detr(samples)

outputs_without_aux = {k: v for k, v in test_detr_out.items() if k != 'aux_outputs'}

for k, v in outputs_without_aux.items():
    print(f'outputs_without_aux: {k}: {v.shape}')

idx = 0

for k, v in test_detr_out.items():
    if k == 'aux_outputs':
        for i in test_detr_out['aux_outputs']:
            for k2, v2 in i.items():
                print(f'aux_outputs: decoder_layer_{idx}: {k2}: {v2.shape}')
            idx += 1
    else:
        print(f'last decoder {k} : {v.shape}')

print(test_detr_out['aux_outputs'])

outputs_without_aux: pred_logits: torch.Size([2, 100, 92])
outputs_without_aux: pred_boxes: torch.Size([2, 100, 4])
last decoder pred_logits : torch.Size([2, 100, 92])
last decoder pred_boxes : torch.Size([2, 100, 4])
aux_outputs: decoder_layer_0: pred_logits: torch.Size([2, 100, 92])
aux_outputs: decoder_layer_0: pred_boxes: torch.Size([2, 100, 4])
aux_outputs: decoder_layer_1: pred_logits: torch.Size([2, 100, 92])
aux_outputs: decoder_layer_1: pred_boxes: torch.Size([2, 100, 4])
aux_outputs: decoder_layer_2: pred_logits: torch.Size([2, 100, 92])
aux_outputs: decoder_layer_2: pred_boxes: torch.Size([2, 100, 4])
aux_outputs: decoder_layer_3: pred_logits: torch.Size([2, 100, 92])
aux_outputs: decoder_layer_3: pred_boxes: torch.Size([2, 100, 4])
aux_outputs: decoder_layer_4: pred_logits: torch.Size([2, 100, 92])
aux_outputs: decoder_layer_4: pred_boxes: torch.Size([2, 100, 4])
[{'pred_logits': tensor([[[ 0.1052,  0.5491, -0.2823,  ..., -0.0963,  1.0068,  0.8477],
         [-0.0434,  0.16

## HungarianMatcher

In [33]:
from torchvision.ops import generalized_box_iou, box_convert

class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"

    @torch.no_grad()
    def forward(self, outputs, targets):
        """ Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, it can be ommitted.
        cost_class = -out_prob[:, tgt_ids]

        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # Compute the giou cost betwen boxes
        cost_giou = -generalized_box_iou(box_convert(out_bbox, 'cxcywh', 'xyxy'), box_convert(tgt_bbox, 'cxcywh', 'xyxy'))

        # Final cost matrix
        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
        C = C.view(bs, num_queries, -1).cpu()

        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]


def build_matcher():
    return HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)


### Test matcher

In [34]:
random_pred_logits = torch.rand(size=(2, 100, 92))
random_pred_boxes = torch.rand(size=(2, 100, 4))

sim_out = {
    'pred_logits': random_pred_logits,
    'pred_boxes': random_pred_boxes,
    'aux_outputs': [
        {
            'pred_logits': random_pred_logits,
            'pred_boxes': random_pred_boxes,
        } for _ in range(5)
    ]
}


idx = 1
for k, v in sim_out.items():
    if k == 'aux_outputs':
        for i in sim_out['aux_outputs']:
            for k2, v2 in i.items():
                print(f'aux_outputs: decoder_layer_{idx}: {k2}: {v2.shape}')
            idx += 1
    else:
        print(f'last decoder {k} : {v.shape}')

# print(target[0]['boxes'])

last decoder pred_logits : torch.Size([2, 100, 92])
last decoder pred_boxes : torch.Size([2, 100, 4])
aux_outputs: decoder_layer_1: pred_logits: torch.Size([2, 100, 92])
aux_outputs: decoder_layer_1: pred_boxes: torch.Size([2, 100, 4])
aux_outputs: decoder_layer_2: pred_logits: torch.Size([2, 100, 92])
aux_outputs: decoder_layer_2: pred_boxes: torch.Size([2, 100, 4])
aux_outputs: decoder_layer_3: pred_logits: torch.Size([2, 100, 92])
aux_outputs: decoder_layer_3: pred_boxes: torch.Size([2, 100, 4])
aux_outputs: decoder_layer_4: pred_logits: torch.Size([2, 100, 92])
aux_outputs: decoder_layer_4: pred_boxes: torch.Size([2, 100, 4])
aux_outputs: decoder_layer_5: pred_logits: torch.Size([2, 100, 92])
aux_outputs: decoder_layer_5: pred_boxes: torch.Size([2, 100, 4])


In [35]:
bs, num_queries = sim_out["pred_logits"].shape[:2]

# We flatten to compute the cost matrices in a batch
out_prob = sim_out["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
out_bbox = sim_out["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

# Also concat the target labels and boxes
tgt_ids = torch.cat([v["labels"] for v in target])
tgt_bbox = torch.cat([v["boxes"] for v in target])

print(f'out_prob shape: {out_prob.shape}, out_bbox shape: {out_bbox.shape}')

# 假设是batch_size = 2即每次两张图片，这里总共有28个物体(类别), 需要用sizes确定界限
print(tgt_ids.shape)
print(tgt_bbox.shape)

sizes = [len(v["boxes"]) for v in target]
for i in range(len(sizes)):
    print(f'object_{i} has {sizes[i]} classes/objects')

out_prob shape: torch.Size([200, 92]), out_bbox shape: torch.Size([200, 4])
torch.Size([8])
torch.Size([8, 4])
object_0 has 4 classes/objects
object_1 has 4 classes/objects


In [36]:
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
# but approximate it in 1 - proba[target class].
# The 1 is a constant that doesn't change the matching, it can be ommitted.
print(tgt_ids)
cost_class = -out_prob[:, tgt_ids]

# Compute the L1 cost between boxes
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

# Compute the giou cost betwen boxes
cost_giou = -generalized_box_iou(box_convert(out_bbox, 'cxcywh', 'xyxy'), box_convert(tgt_bbox, 'cxcywh', 'xyxy'))

print(f'cost_class shape: {cost_class.shape}')
print(f'cost_bbox shape: {cost_bbox.shape}')
print(f'cost_giou shape: {cost_giou.shape}')

tensor([ 1, 49, 61, 81,  1, 81, 27, 81])
cost_class shape: torch.Size([200, 8])
cost_bbox shape: torch.Size([200, 8])
cost_giou shape: torch.Size([200, 8])


In [37]:
# Final cost matrix
C = 1 * cost_class + 5 * cost_bbox + 2 * cost_giou
C = C.view(bs, num_queries, -1).cpu()
print(C.shape)

C_split = torch.split(C, sizes, dim=-1)

# 

for i, c in enumerate(C_split):
    print(i)
    print(c[i].shape)
    
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C_split)]

print(indices)

# 用于遍历一个 列表 (list)，其中列表的每个元素都是一个 元组 (tuple) 或其他可迭代的对象，并且这个元组（或可迭代对象）包含 两个元素
# i 将会是 linear_sum_assignment 返回的 行索引数组 (预测框的索引)。
# j 将会是 linear_sum_assignment 返回的 列索引数组 (真实目标框的索引)。
for i, j in indices:
    print(i)
    print(j)

torch.Size([2, 100, 8])
0
torch.Size([100, 4])
1
torch.Size([100, 4])
[(array([ 6, 69, 73, 74]), array([1, 2, 0, 3])), (array([59, 72, 80, 86]), array([1, 0, 2, 3]))]
[ 6 69 73 74]
[1 2 0 3]
[59 72 80 86]
[1 0 2 3]


In [38]:
matcher_result = [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
print(matcher_result)

[(tensor([ 6, 69, 73, 74]), tensor([1, 2, 0, 3])), (tensor([59, 72, 80, 86]), tensor([1, 0, 2, 3]))]


## Loss

In [39]:
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res



In [40]:
def _get_src_permutation_idx(indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

def _get_tgt_permutation_idx(indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

# 0和1用来区分第1个图像和第二个图像，src_idx表示预测框的索引
# batch_idx[0] = 0 和 src_idx[0] = 1 表示批次中第一个图像的第 1 个预测框被匹配到
# batch_idx[15] = 1 和 src_idx[15] = 45 表示批次中第二个图像的第 45 个预测框被匹配到
pred_idx = _get_src_permutation_idx(matcher_result)
target_idx = _get_tgt_permutation_idx(matcher_result)

print(pred_idx)
print(target_idx)


(tensor([0, 0, 0, 0, 1, 1, 1, 1]), tensor([ 6, 69, 73, 74, 59, 72, 80, 86]))
(tensor([0, 0, 0, 0, 1, 1, 1, 1]), tensor([1, 2, 0, 3, 1, 0, 2, 3]))


### loss labels

In [46]:
import torch.nn.functional as F

def loss_labels(outputs, targets, indices, num_boxes, log=True):
    '''
    src_logits: 预测输出: (2, 100, 91)
    idx: 预测图像中的100个框, 有哪个被匹配到的索引
    target_classes_0: 匹配上的目标类别
    '''
    eos_coef = 0.1
    num_classes = 91
    empty_weight = torch.ones(num_classes + 1)
    empty_weight[-1] = 0.1
    
    assert 'pred_logits' in outputs
    # (bs, num_query, 91)
    src_logits = outputs['pred_logits']
    
    idx = _get_src_permutation_idx(indices)
    # print(f'idx: {idx}')
    # 取的是目标的索引
    target_classes_o = torch.cat([t['labels'][J] for t, (_, J) in zip(targets, indices)])
    # tensor([14,  1,  3,  5, 11,  2,  7, 12, 10, 13,  8,  6,  4,  9,  0,  1,  0])
    # print(f'target_1 labels: {targets[0]['labels']}')
    # print(f'target_2 labels: {targets[1]['labels']}')
    # print(f'matcher target result: {target_classes_o}')
    
    target_classes = torch.full(src_logits.shape[:2], fill_value=num_classes, dtype=torch.int64, device=src_logits.device)
    # print(target_classes.shape)
    target_classes[idx] = target_classes_o
    # print(src_logits.transpose(1, 2).shape)
    
    loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, weight=empty_weight)
    losses = {'loss_ce': loss_ce}
    
    if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
    return losses
    
loss_ce = loss_labels(sim_out, target, matcher_result, num_boxes=None)
for k, v in loss_ce.items():
    print(f'{k}: {v}')

loss_ce: 4.4762139320373535
class_error: 100.0


### loss cardinality

In [42]:
@torch.no_grad()
def loss_cardinality(outputs, targets, indices, num_boxes = None):
    """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
    This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
    """
    pred_logits = outputs['pred_logits']
    device = pred_logits.device
    tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
    # Count the number of predictions thar are NOT "no-object" (which is the last class) 统计每个图像中预测为非“无物体”的框的数量
    card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
    # print(f"max_pred_logits: {pred_logits.argmax(-1)}")
    # print(pred_logits.shape[-1] - 1)
    # print(f'card_pred: {card_pred}')
    # print(f'tgt_lengths: {tgt_lengths}')

    card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
    losses = {'cardinality_error': card_err}
    
    return losses
loss_card = loss_cardinality(sim_out, target, matcher_result)
for k, v in loss_card.items():
    print(f'{k}: {v}')


cardinality_error: 93.5


### loss boxes

In [43]:
from torchvision.ops import generalized_box_iou, box_convert

def loss_boxes(outputs, targets, indices, num_boxes):
    """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
    """
    assert 'pred_boxes' in outputs
    idx = _get_src_permutation_idx(indices)
    src_boxes = outputs['pred_boxes'][idx]
    
    # print(target[0]['boxes'].shape)
    target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
    # print(f'src_boxes shape: {src_boxes.shape}')
    # print(target_boxes)
    
    loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

    losses = {}
    losses['loss_bbox'] = loss_bbox.sum() / num_boxes

    loss_giou = 1 - torch.diag(generalized_box_iou(
        box_convert(src_boxes, in_fmt='cxcywh', out_fmt='xyxy'),
        box_convert(target_boxes, in_fmt='cxcywh', out_fmt='xyxy')))
    losses['loss_giou'] = loss_giou.sum() / num_boxes
    return losses

loss_box = loss_boxes(sim_out, target, matcher_result, num_boxes=17)
for k, v in loss_box.items():
    print(f'{k}: {v}')

loss_bbox: 0.15365788340568542
loss_giou: 0.3403002917766571


### loss all in one

In [44]:
matcher = build_matcher()

def get_loss(loss, outputs, targets, indices, num_boxes, **kwargs):
    loss_map = {
        'labels': loss_labels,
        'cardinality': loss_cardinality,
        'boxes': loss_boxes
    }
    
    assert loss in loss_map, f'do you really want to compute {loss} loss ?'
    return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

def loss_forward(outputs, targets):
    """ This performs the loss computation.
    Parameters:
        outputs: dict of tensors, see the output specification of the model for the format
        targets: list of dicts, such that len(targets) == batch_size.
                 The expected keys in each dict depends on the losses applied, see each loss' doc
    """
    outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
    
    # Retrieve the matching between the outputs of the last layer and the targets
    indices = matcher(outputs_without_aux, targets)
    
    # Compute the average number of target boxes accross all nodes, for normalization purposes
    num_boxes = sum(len(t["labels"]) for t in targets)
    num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
    
    losses_dict = {}
    
    losses = ['labels', 'boxes', 'cardinality']
    for loss in losses:
        losses_dict.update(get_loss(loss, outputs, targets, indices, num_boxes))
        
    if 'aux_outputs' in outputs:
        for i, aux_outputs in enumerate(outputs['aux_outputs']):
            indices = matcher(aux_outputs, targets)
            for loss in losses:
                kwargs = {}
                if loss == 'labels':
                    kwargs = {'log': False}
                l_dict = get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                losses_dict.update(l_dict)
    
    return losses_dict

#### test all losses

In [47]:
losses_dict = loss_forward(sim_out, target)

for k, v in losses_dict.items():
    print(f'{k}: {v}')

loss_ce: 4.4762139320373535
class_error: 100.0
loss_bbox: tensor([0.3265])
loss_giou: tensor([0.7231])
cardinality_error: 93.5
loss_ce_0: 4.4762139320373535
loss_bbox_0: tensor([0.3265])
loss_giou_0: tensor([0.7231])
cardinality_error_0: 93.5
loss_ce_1: 4.4762139320373535
loss_bbox_1: tensor([0.3265])
loss_giou_1: tensor([0.7231])
cardinality_error_1: 93.5
loss_ce_2: 4.4762139320373535
loss_bbox_2: tensor([0.3265])
loss_giou_2: tensor([0.7231])
cardinality_error_2: 93.5
loss_ce_3: 4.4762139320373535
loss_bbox_3: tensor([0.3265])
loss_giou_3: tensor([0.7231])
cardinality_error_3: 93.5
loss_ce_4: 4.4762139320373535
loss_bbox_4: tensor([0.3265])
loss_giou_4: tensor([0.7231])
cardinality_error_4: 93.5
