# import

In [3]:
import os
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW, lr_scheduler
import transforms as T
from torchvision import datasets, ops
from torchvision.models.feature_extraction import create_feature_extractor
from einops import rearrange
from pycocotools import mask as coco_mask

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import linear_sum_assignment

# build dataset

In [4]:
from pathlib import Path

import torch
import torch.utils.data
from pycocotools import mask as coco_mask

from torchvision_datasets import CocoDetection as TvCocoDetection
from util.misc import get_local_rank, get_local_size
import transforms as T


class CocoDetection(TvCocoDetection):
    def __init__(self, img_folder, ann_file, transforms, return_masks, cache_mode=False, local_rank=0, local_size=1):
        super(CocoDetection, self).__init__(img_folder, ann_file,
                                            local_rank=local_rank, local_size=local_size)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)

    def __getitem__(self, idx):
        img, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}
        img, target = self.prepare(img, target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target


def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks


class ConvertCocoPolysToMask(object):
    def __init__(self, return_masks=False):
        self.return_masks = return_masks

    def __call__(self, image, target):
        w, h = image.size

        image_id = target["image_id"]
        image_id = torch.tensor([image_id])

        anno = target["annotations"]

        anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        if self.return_masks:
            segmentations = [obj["segmentation"] for obj in anno]
            masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if self.return_masks:
            masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        if self.return_masks:
            target["masks"] = masks
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])

        return image, target


def make_coco_transforms(image_set):

    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]

    if image_set == 'train':
        return T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomSelect(
                T.RandomResize(scales, max_size=1333),
                T.Compose([
                    T.RandomResize([400, 500, 600]),
                    T.RandomSizeCrop(384, 600),
                    T.RandomResize(scales, max_size=1333),
                ])
            ),
            normalize,
        ])

    if image_set == 'val':
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            normalize,
        ])

    raise ValueError(f'unknown {image_set}')


def build(image_set):
    
    mode = 'instances'
    PATHS = {
        "train": (os.path.join('tiny_coco', "train2017"), os.path.join('tiny_coco', "annotations", f'{mode}_train2017.json')),
        "val": (os.path.join('tiny_coco', "val2017"), os.path.join('tiny_coco', "annotations", f'{mode}_val2017.json')),
    }

    img_folder, ann_file = PATHS[image_set]
    dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=False,
                            local_rank=get_local_rank(), local_size=get_local_size())
    return dataset


# build DataLoader

In [5]:
import util.misc as utils

dataset_train = build(image_set='train')
dataset_val = build(image_set='val')


data_loader_train = DataLoader(dataset_train, 2,
                                   collate_fn=utils.collate_fn, num_workers=0,
                                   pin_memory=True, shuffle=True)
data_loader_val = DataLoader(dataset_val, 2,
                                 drop_last=False, collate_fn=utils.collate_fn, num_workers=0,
                                 pin_memory=True)

print(f'\nNumber of train samples: {len(data_loader_train)}')
print(f'\nNumber of val samples: {len(data_loader_val)}')

loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!

Number of train samples: 25

Number of val samples: 25


## Test dataloader

In [6]:
# NestedTensor, Tuple of (annotations) => Tuple(Tensors, masks), Tuple of (annotations)
samples, target = next(iter(data_loader_train))

print(samples.tensors.shape)
print(samples.mask.shape)

print(target[0])
print(target[1])

torch.Size([2, 3, 981, 736])
torch.Size([2, 981, 736])
{'boxes': tensor([[0.6462, 0.5000, 0.2346, 1.0000],
        [0.3243, 0.5000, 0.6485, 1.0000],
        [0.2572, 0.6367, 0.5145, 0.7266]]), 'labels': tensor([ 2,  1, 17]), 'image_id': tensor([562150]), 'area': tensor([115069.0547, 318147.0000, 183393.0625]), 'iscrowd': tensor([0, 0, 0]), 'orig_size': tensor([516, 640]), 'size': tensor([672, 730])}
{'boxes': tensor([[0.5468, 0.7075, 0.6022, 0.5663]]), 'labels': tensor([70]), 'image_id': tensor([360772]), 'area': tensor([144803.0312]), 'iscrowd': tensor([0]), 'orig_size': tensor([640, 480]), 'size': tensor([981, 736])}


# model

## position encoding

In [7]:
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn

from util.misc import NestedTensor


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return pos


def build_position_encoding(hidden_dim=256, position_embedding = 'sine'):
    N_steps = hidden_dim // 2
    if position_embedding in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
    elif position_embedding in ('v3', 'learned'):
        position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"not supported {position_embedding}")

    return position_embedding

### Test position encoding

In [8]:
test_position_encoding_model = build_position_encoding()
# samples: torch.Size([2, 3, 1199, 800])
test_position_encoding = test_position_encoding_model(samples)
print(f'train samples shape: {samples.tensors.shape}')
print(f'position encoding shape: {test_position_encoding.shape}')

train samples shape: torch.Size([2, 3, 981, 736])
position encoding shape: torch.Size([2, 256, 981, 736])


## Backbone

In [9]:
"""
Backbone modules.
"""
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List

from util.misc import NestedTensor, is_main_process

class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rqsrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n, eps=1e-5):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))
        self.eps = eps

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = self.eps
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias


class BackboneBase(nn.Module):

    def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
                parameter.requires_grad_(False)
        if return_interm_layers:
            # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
            return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
            self.strides = [8, 16, 32]
            self.num_channels = [512, 1024, 2048]
        else:
            return_layers = {'layer4': "0"}
            self.strides = [32]
            self.num_channels = [2048]
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)

    def forward(self, tensor_list: NestedTensor):
        xs = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for name, x in xs.items():
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(x, mask)
        return out
    
class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        
        replace_stride_with_dilation = [False, False, dilation]
        norm_layer = FrozenBatchNorm2d

        weights = None
        if name == 'resnet18':
            weights = torchvision.models.ResNet18_Weights.DEFAULT
            backbone = torchvision.models.resnet18(weights=weights, replace_stride_with_dilation=replace_stride_with_dilation, norm_layer=norm_layer)
            num_channels = 512
        elif name == 'resnet34':
            weights = torchvision.models.ResNet34_Weights.DEFAULT
            backbone = torchvision.models.resnet34(weights=weights, replace_stride_with_dilation=replace_stride_with_dilation, norm_layer=norm_layer)
            num_channels = 512
        elif name == 'resnet50':
            weights = torchvision.models.ResNet50_Weights.DEFAULT
            backbone = torchvision.models.resnet50(weights=weights, replace_stride_with_dilation=replace_stride_with_dilation, norm_layer=norm_layer)
            num_channels = 2048
        elif name == 'resnet101':
            weights = torchvision.models.ResNet101_Weights.DEFAULT
            backbone = torchvision.models.resnet101(weights=weights, replace_stride_with_dilation=replace_stride_with_dilation, norm_layer=norm_layer)
            num_channels = 2048
        else:
            raise ValueError(f"Unsupported backbone: {name}")
        
        assert name not in ('resnet18', 'resnet34'), "number of channels are hard coded"
        super().__init__(backbone, train_backbone, return_interm_layers)
        if dilation:
            self.strides[-1] = self.strides[-1] // 2

### test backbone

In [10]:
backbone = Backbone(name='resnet50', train_backbone=False, return_interm_layers=True, dilation=False)

# Dict of {layer_{i}: nestedTensor}
backbone_out_nestedTensor = backbone(samples)

print(f'train samples shape: {samples.tensors.shape}')
for k, v in backbone_out_nestedTensor.items():
    print(f"{k} tensor shape : {v.tensors.shape}, mask shape: {v.mask.shape}")

train samples shape: torch.Size([2, 3, 981, 736])
0 tensor shape : torch.Size([2, 512, 123, 92]), mask shape: torch.Size([2, 123, 92])
1 tensor shape : torch.Size([2, 1024, 62, 46]), mask shape: torch.Size([2, 62, 46])
2 tensor shape : torch.Size([2, 2048, 31, 23]), mask shape: torch.Size([2, 31, 23])


### test position with backbone_feature

In [11]:
backbone_feature_position_enc = test_position_encoding_model(backbone_out_nestedTensor['2'])
print(backbone_feature_position_enc.shape)

torch.Size([2, 256, 31, 23])


## Joiner position & backbone

In [12]:
class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)
        self.strides = backbone.strides
        self.num_channels = backbone.num_channels

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for name, x in sorted(xs.items()):
            out.append(x)

        # position encoding
        for x in out:
            pos.append(self[1](x).to(x.tensors.dtype))

        return out, pos

### test joiner

In [13]:
test_joiner = Joiner(backbone, test_position_encoding_model)

test_joiner_out, test_joiner_pos = test_joiner(samples)

print(f"test_joiner_out has length: {len(test_joiner_out)}")
for i in range(len(test_joiner_out)):
    print(f'test_joiner_out layer_{i+1} has shape: {test_joiner_out[i].tensors.shape}')
    
    
print(f"\ntest_joiner_pos has length: {len(test_joiner_pos)}")
for i in range(len(test_joiner_pos)):
    print(f'test_joiner_pos layer_{i+1} has shape: {test_joiner_pos[i].shape}')


test_joiner_out has length: 3
test_joiner_out layer_1 has shape: torch.Size([2, 512, 123, 92])
test_joiner_out layer_2 has shape: torch.Size([2, 1024, 62, 46])
test_joiner_out layer_3 has shape: torch.Size([2, 2048, 31, 23])

test_joiner_pos has length: 3
test_joiner_pos layer_1 has shape: torch.Size([2, 256, 123, 92])
test_joiner_pos layer_2 has shape: torch.Size([2, 256, 62, 46])
test_joiner_pos layer_3 has shape: torch.Size([2, 256, 31, 23])


## build backbone (joiner)

In [14]:
def build_backbone(
    hidden_dim: int,
    position_embedding: str = 'sine',
    backbone: str = 'resnet50',
    return_interm_layers: bool = False,
    dilation: bool = False
):
    position_embedding = build_position_encoding(hidden_dim=256, position_embedding='sine')
    train_backbone = False
    backbone = Backbone(backbone, train_backbone, return_interm_layers, dilation)
    model = Joiner(backbone, position_embedding)
    return model

# DeformableTransformer

## Encoder

In [15]:
import copy

def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

### MSDeformAttn

In [16]:
import functools
import torch.nn.init as init

def deformable_attention_core_func_v2(
    value: List[torch.Tensor], # <-- Change type hint to List[Tensor]
    value_spatial_shapes,
    sampling_locations: torch.Tensor,
    attention_weights: torch.Tensor,
    num_points_list: List[int],
    method='default',
    value_shape='irrelevant', # <-- This parameter is now less relevant if value is pre-processed
    ):
    """
    Args:
        value (List[Tensor]): List of level features [(bs * n_head, c, H_0, W_0), ...]
        # ... other args
    """
    # bs, n_head, c, _ = value[0].shape # Incorrect way to get dims if value is list
    bs_times_n_head, c, _, _ = value[0].shape # Get shape from first level tensor
    # Need bs and n_head separately for reshaping output later
    # We don't know n_head directly here unless passed or inferred. Let's assume it's needed later.
    # sampling_locations shape: [bs, query_length, n_head, n_levels * n_points, 2]
    bs, Len_q, n_head, _, _ = sampling_locations.shape # Infer bs and n_head here

    # TODO: If value_shape logic was needed, reinstate carefully
    # if value_shape == 'default':
    #     pass # Assume value is already a list [bs*n_head, c, h, w]
    # elif value_shape == 'reshape': # This block might be redundant now
    #     # Original value shape was [bs, value_length, n_heads, c]
    #     # bs, _, n_head, c = value.shape # Can't do this if value is List
    #     # split_shape = [h * w for h, w in value_spatial_shapes]
    #     # # value needs permutation and splitting
    #     # value_permuted = value.permute(0, 2, 3, 1).flatten(0, 1) # [bs*n_head, c, value_length] - Incorrect permute
    #     # # Need to reshape to spatial first: [bs, n_heads, c, h, w] then flatten/split? Complex.
    #     # # --> It's better to do the splitting *before* calling this core function.
    #     pass # Remove or adapt this logic

    _, Len_q, _, _, _ = sampling_locations.shape

    sampling_grids = 2 * sampling_locations - 1 if method == 'default' else sampling_locations
    sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1) # Shape: [bs * n_head, Len_q, n_levels * n_points, 2]
    sampling_locations_list = sampling_grids.split(num_points_list, dim=-2) # List of [bs * n_head, Len_q, n_points_level, 2]

    sampling_value_list = []
    for level, (h, w) in enumerate(value_spatial_shapes):
        # value_l is already precomputed and passed in the list 'value'
        value_l = value[level] # Shape: [bs * n_head, c, h, w]
        sampling_grid_l: torch.Tensor = sampling_locations_list[level] # Shape: [bs * n_head, Len_q, n_points_level, 2]

        if method == 'default':
            # grid_sample input: [N, C, H_in, W_in], grid: [N, H_out, W_out, 2] -> output: [N, C, H_out, W_out]
            # Our value_l: [bs*n_head, c, h, w]
            # Our grid_l: [bs*n_head, Len_q, n_points_level, 2]
            # We need grid reshaped to [N, H_out, W_out, 2] where N=bs*n_head, H_out=Len_q, W_out=n_points_level
            sampling_grid_l_reshaped = sampling_grid_l.reshape(bs * n_head, Len_q, num_points_list[level], 2)

            sampling_value_l = F.grid_sample(
                value_l,
                sampling_grid_l_reshaped, # Use reshaped grid
                mode='bilinear',
                padding_mode='zeros',
                align_corners=False)
            # Output sampling_value_l shape: [bs*n_head, c, Len_q, n_points_level]

        elif method == 'discrete':
             # n * m, seq, n, 2
             sampling_coord = (sampling_grid_l * torch.tensor([[w, h]], device=value_l.device) + 0.5).to(torch.int64)

             # FIX ME? for rectangle input
             # Clamp uses h for both dimensions? Should use w for dim 0, h for dim 1
             # sampling_coord = sampling_coord.clamp(0, h - 1) # Original
             sampling_coord[..., 0] = sampling_coord[..., 0].clamp(0, w - 1) # Clamp x coord
             sampling_coord[..., 1] = sampling_coord[..., 1].clamp(0, h - 1) # Clamp y coord

             sampling_coord = sampling_coord.reshape(bs * n_head, Len_q * num_points_list[level], 2)

             s_idx = torch.arange(sampling_coord.shape[0], device=value_l.device).unsqueeze(-1).repeat(1, sampling_coord.shape[1])
             # Indexing: value_l[batch_idx, channel_idx, y_coord, x_coord]
             sampling_value_l: torch.Tensor = value_l[s_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]] # n l c

             sampling_value_l = sampling_value_l.permute(0, 2, 1).reshape(bs * n_head, c, Len_q, num_points_list[level])


        sampling_value_list.append(sampling_value_l)
        # sampling_value_list contains tensors of shape [bs*n_head, c, Len_q, n_points_level]

    # Concatenate sampling values across levels
    # List of [bs*n_head, c, Len_q, n_points_level] -> [bs*n_head, c, Len_q, sum(n_points_list)]
    concatenated_sampling_values = torch.cat(sampling_value_list, dim=-1)

    # Prepare attention weights
    # attention_weights shape: [bs, Len_q, n_head, sum(n_points_list)]
    # Needs shape [bs * n_head, 1, Len_q, sum(n_points_list)] for broadcasting
    attn_weights = attention_weights.permute(0, 2, 1, 3).reshape(bs * n_head, 1, Len_q, sum(num_points_list))

    # Apply weights: [bs*n_head, c, Len_q, total_points] * [bs*n_head, 1, Len_q, total_points]
    weighted_sample_locs = concatenated_sampling_values * attn_weights
    # Sum over points: -> [bs*n_head, c, Len_q]
    output = weighted_sample_locs.sum(dim=-1)

    # Reshape output: [bs * n_head, c, Len_q] -> [bs, n_head * c, Len_q]
    output = output.reshape(bs, n_head * c, Len_q) # Need head_dim here! Pass it in or infer C=head_dim

    # Let's assume C = self.head_dim passed somehow or calculated
    head_dim = c # Use inferred 'c' from value_l shape
    output = output.reshape(bs, n_head * head_dim, Len_q)

    # Permute to [bs, Len_q, C] which matches MSDeformAttn expected output
    return output.permute(0, 2, 1)

class MSDeformAttn(nn.Module):
    def __init__(
        self,
        embed_dim=256,
        num_heads=8,
        num_levels=4,
        num_points=4,
        method='default',
        offset_scale=0.5,
    ):
        """Multi-Scale Deformable Attention
        """
        super(MSDeformAttn, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_levels = num_levels
        self.offset_scale = offset_scale

        if isinstance(num_points, list):
            assert len(num_points) == num_levels, ''
            num_points_list = num_points
        else:
            num_points_list = [num_points for _ in range(num_levels)]

        self.num_points_list = num_points_list

        num_points_scale = [1/n for n in num_points_list for _ in range(n)]
        self.register_buffer('num_points_scale', torch.tensor(num_points_scale, dtype=torch.float32))

        self.total_points = num_heads * sum(num_points_list)
        self.method = method

        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2)
        self.attention_weights = nn.Linear(embed_dim, self.total_points)

        self.ms_deformable_attn_core = functools.partial(deformable_attention_core_func_v2, method=self.method)

        self._reset_parameters()

        if method == 'discrete':
            for p in self.sampling_offsets.parameters():
                p.requires_grad = False

    def _reset_parameters(self):
        # sampling_offsets
        init.constant_(self.sampling_offsets.weight, 0)
        thetas = torch.arange(self.num_heads, dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values
        grid_init = grid_init.reshape(self.num_heads, 1, 2).tile([1, sum(self.num_points_list), 1])
        scaling = torch.concat([torch.arange(1, n + 1) for n in self.num_points_list]).reshape(1, -1, 1)
        grid_init *= scaling
        self.sampling_offsets.bias.data[...] = grid_init.flatten()

        # attention_weights
        init.constant_(self.attention_weights.weight, 0)
        init.constant_(self.attention_weights.bias, 0)


    def forward(self,
                query: torch.Tensor,
                reference_points: torch.Tensor,
                value: torch.Tensor,            # value is likely flattened: [bs, value_length, C]
                value_spatial_shapes: List[int],
                level_start_index: torch.Tensor): # <-- 新增 level_start_index
        """
        Args:
            query (Tensor): [bs, query_length, C]
            reference_points (Tensor): [bs, query_length, n_levels, 2 or 4]
            value (Tensor): [bs, value_length, C]
            value_spatial_shapes (List): [n_levels, 2]
            level_start_index (Tensor): [n_levels]  <-- 新增
        Returns:
            output (Tensor): [bs, query_length, C]
        """
        bs, Len_q = query.shape[:2]
        _, Len_v = value.shape[:2] # 获取 value length

        # --- 计算 sampling offsets 和 attention weights ---
        sampling_offsets: torch.Tensor = self.sampling_offsets(query)
        # sampling_offsets shape: [bs, Len_q, num_heads * sum(num_points_list) * 2]
        # Reshape to: [bs, Len_q, num_heads, sum(num_points_list), 2]
        sampling_offsets = sampling_offsets.reshape(bs, Len_q, self.num_heads, sum(self.num_points_list), 2)

        attention_weights = self.attention_weights(query).reshape(bs, Len_q, self.num_heads, sum(self.num_points_list))
        attention_weights = F.softmax(attention_weights, dim=-1)
        # attention_weights shape: [bs, Len_q, num_heads, sum(num_points_list)]

        # --- 计算 sampling locations ---
        if reference_points.shape[-1] == 2:
            # Reshape sampling_offsets to separate levels and points
            # Assumes num_points is constant per level (self.num_points_list = [N, N, ..., N])
            # If num_points can vary, this needs adjustment using num_points_list
            num_points_per_level = self.num_points_list[0] # Assuming constant
            sampling_offsets = sampling_offsets.reshape(
                bs, Len_q, self.num_heads, self.num_levels, num_points_per_level, 2)
            # New sampling_offsets shape: [bs, Len_q, num_heads, num_levels, num_points_per_level, 2]

            # Prepare offset_normalizer
            # Flip H, W to W, H for normalization
            offset_normalizer = torch.as_tensor(value_spatial_shapes, dtype=query.dtype, device=query.device).flip([1])
            # Reshape for broadcasting: [1, 1, 1, num_levels, 1, 2]
            offset_normalizer = offset_normalizer.reshape(1, 1, 1, self.num_levels, 1, 2)

            # Normalize offsets: [..., num_heads, num_levels, num_points, 2] / [..., 1, num_levels, 1, 2]
            normalized_offsets = sampling_offsets / offset_normalizer

            # Prepare reference_points
            # Reshape/unsqueeze for broadcasting: [bs, Len_q, 1, num_levels, 1, 2]
            reference_points_unsqueezed = reference_points.reshape(bs, Len_q, 1, self.num_levels, 1, 2)

            # Calculate sampling locations: [..., 1, num_levels, 1, 2] + [..., num_heads, num_levels, num_points, 2]
            sampling_locations = reference_points_unsqueezed + normalized_offsets

            # Reshape back for core function: [bs, Len_q, num_heads, num_levels * num_points, 2]
            sampling_locations = sampling_locations.reshape(bs, Len_q, self.num_heads, self.num_levels * num_points_per_level, 2)

        elif reference_points.shape[-1] == 4:
            # reference_points [bs, Len_q, n_levels, 4] (x_center, y_center, w, h)
            # sampling_offsets [bs, Len_q, num_heads, sum(num_points_list), 2]

            num_points_per_level = self.num_points_list[0] # Assuming constant
            sampling_offsets = sampling_offsets.reshape(
                bs, Len_q, self.num_heads, self.num_levels, num_points_per_level, 2)
            # Reshaped offsets: [bs, Len_q, num_heads, num_levels, num_points_per_level, 2]

            ref_xy = reference_points[..., :2].to(dtype=query.dtype, device=query.device) # [bs, Len_q, num_levels, 2]
            ref_wh = reference_points[..., 2:].to(dtype=query.dtype, device=query.device) # [bs, Len_q, num_levels, 2]

            # num_points_scale was originally [sum(num_points_list)]
            num_points_scale = self.num_points_scale.to(dtype=query.dtype, device=query.device)
            # Reshape to [1, 1, 1, num_levels, num_points_per_level, 1]
            num_points_scale = num_points_scale.reshape(1, 1, 1, self.num_levels, num_points_per_level, 1)

            # Calculate scaled offset: offset = sampling_offset * scale * ref_wh
            # Shapes: [..., head, lvl, pt, 2] * [..., 1, lvl, pt, 1] * [..., 1, lvl, 1, 2]
            offset = sampling_offsets * num_points_scale * self.offset_scale * ref_wh.reshape(bs, Len_q, 1, self.num_levels, 1, 2)
            # Result shape: [bs, Len_q, num_heads, num_levels, num_points_per_level, 2]

            # Calculate sampling locations: sampling_locations = ref_xy + offset
            # Shapes: [..., 1, lvl, 1, 2] + [..., head, lvl, pt, 2]
            sampling_locations = ref_xy.reshape(bs, Len_q, 1, self.num_levels, 1, 2) + offset
            # Result shape: [bs, Len_q, num_heads, num_levels, num_points_per_level, 2]

            # Reshape back for core function: [bs, Len_q, num_heads, num_levels * num_points, 2]
            sampling_locations = sampling_locations.reshape(bs, Len_q, self.num_heads, self.num_levels * num_points_per_level, 2)

        else:
            raise ValueError(
                "Last dim of reference_points must be 2 or 4, but get {} instead.".
                format(reference_points.shape[-1]))

        # --- Prepare value for core function ---
        # Input value shape: [bs, value_length, C]
        # Core function expects value LIST: [level0_tensor, level1_tensor, ...]
        # where level_tensor shape is [bs, H_lvl * W_lvl, num_heads, head_dim] ?? No, core expects [bs*n_head, c, h, w] after reshape
        # Let's prepare the list first. Value needs splitting based on level_start_index.
        # head_dim = C // num_heads
        value = value.reshape(bs, Len_v, self.num_heads, self.head_dim) # [bs, value_length, num_heads, head_dim]
        # Need to permute to match core function's expectation after splitting?
        # Core uses: value_l = value[level].reshape(bs * n_head, c, h, w)
        # Let's create the list of [bs, H*W, num_heads, head_dim] first
        value_list = []
        value_level_shapes = [(h, w) for h, w in value_spatial_shapes] # Ensure it's list of tuples
        for lvl in range(self.num_levels):
            start_idx = level_start_index[lvl]
            end_idx = level_start_index[lvl+1] if lvl < self.num_levels - 1 else Len_v
            value_lvl = value[:, start_idx:end_idx, :, :] # Shape [bs, H*W, num_heads, head_dim]
            # Reshape and permute for core function: [bs * num_heads, head_dim, H, W]
            h, w = value_level_shapes[lvl]
            value_lvl_permuted = value_lvl.permute(0, 2, 3, 1).reshape(bs * self.num_heads, self.head_dim, h, w)
            value_list.append(value_lvl_permuted)


        # --- Call the core function ---
        # Note: Pass value_list to the core function
        output = self.ms_deformable_attn_core(
            value_list, # Pass the prepared list
            value_spatial_shapes,
            sampling_locations,
            attention_weights,
            self.num_points_list,
            value_shape='reshape') # Use 'reshape' as we manually prepared value_list


        # output shape from core: [bs, n_head * c, Len_q]
        # Reshape to [bs, Len_q, n_head * c] which is [bs, Len_q, C]
        return output # Should already be [bs, Len_q, C] if core returns correctly

In [17]:
class DeformableTransformerEncoderLayer(nn.Module):
    def __init__(self,
                 d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4):
        super().__init__()

        # self attention
        self.self_attn = MSDeformAttn(embed_dim=d_model, num_heads=n_heads, num_levels=n_levels, num_points=n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, src):
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm2(src)
        return src

    def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
        # self attention
        src2 = self.self_attn(query=self.with_pos_embed(src, pos),
                            reference_points=reference_points,
                            value=src,
                            value_spatial_shapes=spatial_shapes,
                            level_start_index=level_start_index) # <-- Pass level_start_index
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # ffn
        src = self.forward_ffn(src)
        return src
    

class DeformableTransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):

            ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                                          torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
        output = src
        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
        for _, layer in enumerate(self.layers):
            output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)

        return output

In [18]:
d_model = 256
normalize_before = True

test_encoder_layer = DeformableTransformerEncoderLayer(d_model=d_model, n_heads=8, d_ffn=1024, activation="relu")
test_encoder = DeformableTransformerEncoder(test_encoder_layer, num_layers=6)

print(test_encoder)

DeformableTransformerEncoder(
  (layers): ModuleList(
    (0-5): 6 x DeformableTransformerEncoderLayer(
      (self_attn): MSDeformAttn(
        (sampling_offsets): Linear(in_features=256, out_features=256, bias=True)
        (attention_weights): Linear(in_features=256, out_features=128, bias=True)
      )
      (dropout1): Dropout(p=0.1, inplace=False)
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (linear1): Linear(in_features=256, out_features=1024, bias=True)
      (dropout2): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=1024, out_features=256, bias=True)
      (dropout3): Dropout(p=0.1, inplace=False)
      (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
  )
)


### test encoder

In [19]:

num_feature_levels = 4
num_backbone_outs = len(backbone.strides)
input_proj_list = []


if num_feature_levels > 1:
    for _ in range(num_backbone_outs):
        in_channels = backbone.num_channels[_]
        input_proj_list.append(nn.Sequential(
                        nn.Conv2d(in_channels, d_model, kernel_size=1),
                        nn.GroupNorm(32, d_model),
                    ))
    for _ in range(num_feature_levels - num_backbone_outs):
        input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, d_model, kernel_size=3, stride=2, padding=1),
                    nn.GroupNorm(32, d_model),
                ))
        in_channels = d_model
    input_proj = nn.ModuleList(input_proj_list)

else:
    input_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(backbone.num_channels[0], d_model, kernel_size=1),
                    nn.GroupNorm(32, d_model),
                )])

for proj in input_proj:
    nn.init.xavier_uniform_(proj[0].weight, gain=1)
    nn.init.constant_(proj[0].bias, 0)


features, pos = test_joiner(samples)
srcs = []
masks = []

for l, feat in enumerate(features):
    src, mask = feat.decompose()
    srcs.append(input_proj[l](src))
    masks.append(mask)

if num_feature_levels > len(srcs):
    _len_srcs = len(srcs)
    for l in range(_len_srcs, num_feature_levels):
        if l == _len_srcs:
             src = input_proj[l](features[-1].tensors)
        else:
            src = input_proj[l](srcs[-1])
        m = samples.mask
        mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
        pos_l = test_joiner[1](NestedTensor(src, mask)).to(src.dtype)
        srcs.append(src)
        masks.append(mask)
        pos.append(pos_l)

# srcs, masks, pos, query_embeds

In [20]:
num_layers = 6
num_queries = 300
two_stage = True
with_box_refine = True
use_ms_detr = True

class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x



class_embed = nn.Linear(d_model, 91)
bbox_embed = MLP(d_model, d_model, 4, 3)



if not two_stage:
    query_embed = nn.Embedding(num_queries, d_model * 2)

query_embeds = None
if not two_stage:
    query_embeds = query_embed.weight
    
# if two-stage, the last class_embed and bbox_embed is for region proposal generation
num_pred = (num_layers + 1) if two_stage else num_layers

if with_box_refine:
    class_embed = _get_clones(class_embed, num_pred)
    bbox_embed = _get_clones(bbox_embed, num_pred)
else:
    class_embed = nn.ModuleList([class_embed for _ in range(num_pred)])
    bbox_embed = nn.ModuleList([bbox_embed for _ in range(num_pred)])

if use_ms_detr:
    # NOTE: in implementation (a), we share the box heads for o2o and o2m branches, but do not share the class heads
    class_embed_o2m = copy.deepcopy(class_embed[:-1])
    bbox_embed_o2m = bbox_embed[:-1]

In [21]:
# prepare input for encoder
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))


for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos)):
    bs, c, h, w = src.shape
    spatial_shape = (h, w)
    spatial_shapes.append(spatial_shape)
    src = src.flatten(2).transpose(1, 2)
    mask = mask.flatten(1)
    pos_embed = pos_embed.flatten(2).transpose(1, 2)
    
    lvl_pos_embed = pos_embed + level_embed[lvl].view(1, 1, -1)
    
    lvl_pos_embed_flatten.append(lvl_pos_embed)
    src_flatten.append(src)
    mask_flatten.append(mask)
    
def get_valid_ratio(mask):
    _, H, W = mask.shape
    valid_H = torch.sum(~mask[:, :, 0], 1)
    valid_W = torch.sum(~mask[:, 0, :], 1)
    valid_ratio_h = valid_H.float() / H
    valid_ratio_w = valid_W.float() / W
    valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
    return valid_ratio
    
src_flatten = torch.cat(src_flatten, 1)
mask_flatten = torch.cat(mask_flatten, 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack([get_valid_ratio(m) for m in masks], 1)

print(f'src_flatten shape : {src_flatten.shape} (H1 X W1 + H2 X W2 + H3 X W3 + H4 X W4)')
print(f'mask_flatten shape : {mask_flatten.shape}')
print(f'lvl_pos_embed_flatten shape : {lvl_pos_embed_flatten.shape}')
print(f'level_start_index : {level_start_index}')
print(f'valid_ratios shape: {valid_ratios.shape}')

src_flatten shape : torch.Size([2, 15073, 256]) (H1 X W1 + H2 X W2 + H3 X W3 + H4 X W4)
mask_flatten shape : torch.Size([2, 15073])
lvl_pos_embed_flatten shape : torch.Size([2, 15073, 256])
level_start_index : tensor([    0, 11316, 14168, 14881])
valid_ratios shape: torch.Size([2, 4, 2])


In [22]:
def get_reference_points(spatial_shapes, valid_ratios, device):
    reference_points_list = []
    for lvl, (H_, W_) in enumerate(spatial_shapes):

        ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                                        torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
        ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
        ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
        ref = torch.stack((ref_x, ref_y), -1)
        reference_points_list.append(ref)
    reference_points = torch.cat(reference_points_list, 1)
    reference_points = reference_points[:, :, None] * valid_ratios[:, None]
    return reference_points

device = "cuda" if torch.cuda.is_available() else "cpu"
reference_points = get_reference_points(spatial_shapes, valid_ratios, device)

encoder_layer = DeformableTransformerEncoderLayer()
encoder_layer_output = encoder_layer.forward(src_flatten, lvl_pos_embed_flatten, reference_points, spatial_shapes, level_start_index, padding_mask=None)
print(encoder_layer_output.shape)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


torch.Size([2, 15073, 256])


### Encoder Test

In [23]:
num_encoder_layers = 6
encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)

memory = encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)
print(memory.shape)

torch.Size([2, 15073, 256])


# DecoderLayer

In [24]:
mixed_selection = False

if two_stage:
    enc_output = nn.Linear(d_model, d_model)
    enc_output_norm = nn.LayerNorm(d_model)
    pos_trans = nn.Linear(d_model * 2, d_model * 2)
    pos_trans_norm = nn.LayerNorm(d_model * 2)
else:
    reference_points_model = nn.Linear(d_model, 2)

def gen_encoder_output_proposals(memory, memory_padding_mask, spatial_shapes):
        N_, S_, C_ = memory.shape
        base_scale = 4.0
        proposals = []
        _cur = 0
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
            valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
            valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

            grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
                                            torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
            grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

            scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
            grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
            wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
            proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
            proposals.append(proposal)
            _cur += (H_ * W_)
        output_proposals = torch.cat(proposals, 1)
        output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
        output_proposals = torch.log(output_proposals / (1 - output_proposals))
        output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
        output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))

        output_memory = memory
        output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
        output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
        output_memory = enc_output_norm(enc_output(output_memory))
        return output_memory, output_proposals

def get_proposal_pos_embed(proposals):
        num_pos_feats = 128
        temperature = 10000
        scale = 2 * math.pi

        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
        # N, L, 4
        proposals = proposals.sigmoid() * scale
        # N, L, 4, 128
        pos = proposals[:, :, :, None] / dim_t
        # N, L, 4, 64, 2
        pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
        return pos

In [25]:

# prepare input for decoder
bs, _, c = memory.shape
if two_stage:
    output_memory, output_proposals = gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)

    # hack implementation for two-stage Deformable DETR
    enc_outputs_class = class_embed[num_layers](output_memory)
    enc_outputs_coord_unact = bbox_embed[num_layers](output_memory) + output_proposals

    topk = 300
    topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
    topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
    topk_coords_unact = topk_coords_unact.detach()
    reference_points = topk_coords_unact.sigmoid()
    init_reference_out = reference_points
    pos_trans_out = pos_trans_norm(pos_trans(get_proposal_pos_embed(topk_coords_unact)))

    if not mixed_selection:
        query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
    else:
        # tgt: content embedding, query_embed here is the learnable content embedding
        tgt = query_embed.unsqueeze(0).expand(bs, -1, -1)
        # query_embed: position embedding, transformed from the topk proposals
        query_embed, _ = torch.split(pos_trans_out, c, dim=2)

else:
    query_embed, tgt = torch.split(query_embeds, c, dim=1)
    query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
    tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
    reference_points = reference_points_model(query_embed).sigmoid()
    init_reference_out = reference_points

print(query_embed.shape)
print(tgt.shape)
print(init_reference_out.shape)

torch.Size([2, 300, 256])
torch.Size([2, 300, 256])
torch.Size([2, 300, 4])


## DecoderLayer

In [26]:
class DeformableTransformerDecoderLayer(nn.Module):
    def __init__(self, d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4, use_ms_detr=False, use_aux_ffn=True):
        super().__init__()

        # cross attention
        self.cross_attn = MSDeformAttn(d_model, n_heads, n_levels, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

        self.use_ms_detr = use_ms_detr
        self.use_aux_ffn = use_aux_ffn
        
        # auxiliary ffn
        if self.use_ms_detr and self.use_aux_ffn:
            self.linear3 = nn.Linear(d_model, d_ffn)
            self.dropout5 = nn.Dropout(dropout)
            self.linear4 = nn.Linear(d_ffn, d_model)
            self.dropout6 = nn.Dropout(dropout)
            self.norm4 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_aux_ffn(self, tgt):
        tgt2 = self.linear4(self.dropout5(self.activation(self.linear3(tgt))))
        tgt = tgt + self.dropout6(tgt2)
        tgt = self.norm4(tgt)
        return tgt

    def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask=None):
        if self.use_ms_detr:
            # cross attention
            tgt2 = self.cross_attn.forward(self.with_pos_embed(tgt, query_pos),
                                   reference_points, src, src_spatial_shapes, level_start_index)
            tgt = tgt + self.dropout1(tgt2)
            tgt = self.norm1(tgt)

            if self.use_aux_ffn:
                # auxiliary ffn
                tgt_o2m = self.forward_aux_ffn(tgt)
            else:
                tgt_o2m = tgt

            # self attention
            q = k = self.with_pos_embed(tgt, query_pos)
            tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1)
            tgt = tgt + self.dropout2(tgt2)
            tgt = self.norm2(tgt)

            # ffn
            tgt_o2o = self.forward_ffn(tgt)

        else:
            # self attention
            q = k = self.with_pos_embed(tgt, query_pos)
            tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1)
            tgt = tgt + self.dropout2(tgt2)
            tgt = self.norm2(tgt)

            # cross attention
            tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos), 
                                   reference_points, src, src_spatial_shapes, level_start_index)
            tgt = tgt + self.dropout1(tgt2)
            tgt = self.norm1(tgt)

            # ffn
            tgt_o2o = tgt_o2m = self.forward_ffn(tgt)
        
        return tgt_o2o, tgt_o2m

In [27]:
decoder_layer = DeformableTransformerDecoderLayer(use_ms_detr=use_ms_detr, use_aux_ffn=True)

if reference_points.shape[-1] == 4:
    reference_points_input = reference_points[:, :, None] \
                                         * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
else:
    assert reference_points.shape[-1] == 2
    reference_points_input = reference_points[:, :, None] * valid_ratios[:, None]
    
print(reference_points_input.shape)

decoder_output, decoder_output_o2m = decoder_layer.forward(tgt, query_embeds, reference_points_input, src_flatten,
                                                    spatial_shapes, level_start_index, mask_flatten)

print(decoder_output.shape)
print(decoder_output_o2m.shape)

torch.Size([2, 300, 4, 4])
torch.Size([2, 300, 256])
torch.Size([2, 300, 256])


## decoder

In [28]:
def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1/x2)

class DeformableTransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, return_intermediate=False, look_forward_twice=False, use_ms_detr=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate
        # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
        self.bbox_embed = None
        self.class_embed = None
        self.look_forward_twice = look_forward_twice
        self.use_ms_detr = use_ms_detr

    def forward(self, tgt, reference_points, src, src_spatial_shapes, src_level_start_index, src_valid_ratios,
                query_pos=None, src_padding_mask=None, **kwargs):
        output = tgt

        intermediate = []
        intermediate_o2m = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            if reference_points.shape[-1] == 4:
                reference_points_input = reference_points[:, :, None] \
                                         * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]
            output, output_o2m = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask, **kwargs)

            # hack implementation for iterative bounding box refinement
            if self.bbox_embed is not None:
                tmp = self.bbox_embed[lid](output)
                if reference_points.shape[-1] == 4:
                    new_reference_points = tmp + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                else:
                    assert reference_points.shape[-1] == 2
                    new_reference_points = tmp
                    new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                reference_points = new_reference_points.detach()

            if self.return_intermediate:
                intermediate.append(output)
                intermediate_o2m.append(output_o2m)
                intermediate_reference_points.append(
                    new_reference_points
                    if self.look_forward_twice
                    else reference_points
                )

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(intermediate_o2m), torch.stack(intermediate_reference_points)

        return output, output_o2m, reference_points

In [29]:

decoder = DeformableTransformerDecoder(decoder_layer, num_layers, 
                    True, look_forward_twice=False, use_ms_detr=use_ms_detr)
hs_o2o, hs_o2m, inter_references = decoder.forward(tgt, reference_points, memory, spatial_shapes,
                                                   level_start_index, valid_ratios, query_embeds,
                                                   mask_flatten)

print(hs_o2o.shape)
print(hs_o2m.shape)
print(inter_references.shape)

torch.Size([6, 2, 300, 256])
torch.Size([6, 2, 300, 256])
torch.Size([6, 2, 300, 4])


# Transformer

In [30]:
from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_

class DeformableTransformer(nn.Module):
    def __init__(self, d_model=256, nhead=8,
                 num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
                 activation="relu", return_intermediate_dec=False,
                 num_feature_levels=4, dec_n_points=4,  enc_n_points=4,
                 two_stage=False, two_stage_num_proposals=300):
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead
        self.two_stage = two_stage
        self.two_stage_num_proposals = two_stage_num_proposals

        encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
                                                          dropout, activation,
                                                          num_feature_levels, nhead, enc_n_points)
        self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)

        decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
                                                          dropout, activation,
                                                          num_feature_levels, nhead, dec_n_points)
        self.decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec)

        self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))

        if two_stage:
            self.enc_output = nn.Linear(d_model, d_model)
            self.enc_output_norm = nn.LayerNorm(d_model)
            self.pos_trans = nn.Linear(d_model * 2, d_model * 2)
            self.pos_trans_norm = nn.LayerNorm(d_model * 2)
        else:
            self.reference_points = nn.Linear(d_model, 2)

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        if not self.two_stage:
            xavier_uniform_(self.reference_points.weight.data, gain=1.0)
            constant_(self.reference_points.bias.data, 0.)
        normal_(self.level_embed)

    def get_proposal_pos_embed(self, proposals):
        num_pos_feats = 128
        temperature = 10000
        scale = 2 * math.pi

        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
        # N, L, 4
        proposals = proposals.sigmoid() * scale
        # N, L, 4, 128
        pos = proposals[:, :, :, None] / dim_t
        # N, L, 4, 64, 2
        pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
        return pos

    def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
        N_, S_, C_ = memory.shape
        base_scale = 4.0
        proposals = []
        _cur = 0
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
            valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
            valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

            grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
                                            torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
            grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

            scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
            grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
            wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
            proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
            proposals.append(proposal)
            _cur += (H_ * W_)
        output_proposals = torch.cat(proposals, 1)
        output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
        output_proposals = torch.log(output_proposals / (1 - output_proposals))
        output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
        output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))

        output_memory = memory
        output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
        output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
        output_memory = self.enc_output_norm(self.enc_output(output_memory))
        return output_memory, output_proposals

    def get_valid_ratio(self, mask):
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def forward(self, srcs, masks, pos_embeds, query_embed=None):
        assert self.two_stage or query_embed is not None

        # prepare input for encoder
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            src = src.flatten(2).transpose(1, 2)
            mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

        # encoder
        memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)

        # prepare input for decoder
        bs, _, c = memory.shape
        if self.two_stage:
            output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)

            # hack implementation for two-stage Deformable DETR
            enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
            enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals

            topk = self.two_stage_num_proposals
            topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
            topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
            topk_coords_unact = topk_coords_unact.detach()
            reference_points = topk_coords_unact.sigmoid()
            init_reference_out = reference_points
            pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
            query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
        else:
            query_embed, tgt = torch.split(query_embed, c, dim=1)
            query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
            tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
            reference_points = self.reference_points(query_embed).sigmoid()
            init_reference_out = reference_points

        # decoder
        hs, inter_references = self.decoder(tgt, reference_points, memory,
                                            spatial_shapes, level_start_index, valid_ratios, query_embed, mask_flatten)

        inter_references_out = inter_references
        if self.two_stage:
            return hs, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact, output_proposals.sigmoid()
        return hs, init_reference_out, inter_references_out, None, None, output_proposals.sigmoid()


class DeformableTransformerEncoderLayer(nn.Module):
    def __init__(self,
                 d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4):
        super().__init__()

        # self attention
        self.self_attn = MSDeformAttn(d_model, n_heads, n_levels, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, src):
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm2(src)
        return src

    def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
        # self attention
        src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # ffn
        src = self.forward_ffn(src)

        return src


class DeformableTransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):

            ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                                          torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
        output = src
        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
        for _, layer in enumerate(self.layers):
            output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)

        return output


class DeformableTransformerDecoderLayer(nn.Module):
    def __init__(self, d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4):
        super().__init__()

        # cross attention
        self.cross_attn = MSDeformAttn(d_model, n_heads, n_levels, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask=None):
        # self attention
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        # cross attention
        tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos),
                               reference_points,
                               src, src_spatial_shapes, level_start_index)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # ffn
        tgt = self.forward_ffn(tgt)

        return tgt


class DeformableTransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate
        # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
        self.bbox_embed = None
        self.class_embed = None

    def forward(self, tgt, reference_points, src, src_spatial_shapes, src_level_start_index, src_valid_ratios,
                query_pos=None, src_padding_mask=None):
        output = tgt

        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            if reference_points.shape[-1] == 4:
                reference_points_input = reference_points[:, :, None] \
                                         * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]
            output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask)

            # hack implementation for iterative bounding box refinement
            if self.bbox_embed is not None:
                tmp = self.bbox_embed[lid](output)
                if reference_points.shape[-1] == 4:
                    new_reference_points = tmp + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                else:
                    assert reference_points.shape[-1] == 2
                    new_reference_points = tmp
                    new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                reference_points = new_reference_points.detach()

            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(intermediate_reference_points)

        return output, reference_points

In [31]:
def build_deforamble_transformer():
    return DeformableTransformer(
        d_model=d_model,
        nhead=8,
        num_encoder_layers=num_layers,
        num_decoder_layers=num_layers,
        dim_feedforward=1024,
        dropout=0.,
        activation="relu",
        return_intermediate_dec=True,
        num_feature_levels=num_feature_levels,
        dec_n_points=4,
        enc_n_points=4,
        two_stage=two_stage,
        two_stage_num_proposals=num_queries)

transformer_instance = build_deforamble_transformer()

In [32]:
# if two-stage, the last class_embed and bbox_embed is for region proposal generation
num_pred = (transformer_instance.decoder.num_layers + 1) if two_stage else transformer_instance.decoder.num_layers
if with_box_refine:
    # hack implementation for iterative bounding box refinement
    transformer_instance.decoder.bbox_embed = bbox_embed
else:
    transformer_instance.decoder.bbox_embed = None
if two_stage:
    # hack implementation for two-stage
    transformer_instance.decoder.class_embed = class_embed
    
hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact, anchors = transformer_instance(srcs, masks, pos, query_embeds)

print(f'hs shape: {hs.shape}')
print(f'init_reference shape: {init_reference.shape}')
print(f'inter_references shape: {inter_references.shape}')
print(f'enc_outputs_class shape: {enc_outputs_class.shape}')
print(f'enc_outputs_coord_unact shape: {enc_outputs_coord_unact.shape}')
print(f'anchors shape: {anchors.shape}')

hs shape: torch.Size([6, 2, 300, 256])
init_reference shape: torch.Size([2, 300, 4])
inter_references shape: torch.Size([6, 2, 300, 4])
enc_outputs_class shape: torch.Size([2, 15073, 91])
enc_outputs_coord_unact shape: torch.Size([2, 15073, 4])
anchors shape: torch.Size([2, 15073, 4])


In [33]:
aux_loss=True

outputs_classes = []
outputs_coords = []

@torch.jit.unused
def _set_aux_loss(outputs_class, outputs_coord):
    # this is a workaround to make torchscript happy, as torchscript
    # doesn't support dictionary with non-homogeneous values, such
    # as a dict having both a Tensor and a list.
    return [{'pred_logits': a, 'pred_boxes': b}
            for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]


for lvl in range(hs.shape[0]):
    if lvl == 0:
        reference = init_reference
    else:
        reference = inter_references[lvl - 1]
    reference = inverse_sigmoid(reference)
    outputs_class = class_embed[lvl](hs[lvl])
    tmp = bbox_embed[lvl](hs[lvl])
    if reference.shape[-1] == 4:
        tmp += reference
    else:
        assert reference.shape[-1] == 2
        tmp[..., :2] += reference
    outputs_coord = tmp.sigmoid()
    
    outputs_classes.append(outputs_class)
    outputs_coords.append(outputs_coord)

outputs_class = torch.stack(outputs_classes)
outputs_coord = torch.stack(outputs_coords)
print(outputs_class.shape)
print(outputs_coord.shape)

out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
if aux_loss:
    out['aux_outputs'] = _set_aux_loss(outputs_class, outputs_coord)

if use_ms_detr:
    outputs_classes_o2m = []
    outputs_coords_o2m = []
    for lvl in range(hs.shape[0]):
        if lvl == 0:
            reference = init_reference
        else:
            reference = inter_references[lvl - 1]
        reference = inverse_sigmoid(reference)
        outputs_class_o2m = class_embed_o2m[lvl](hs[lvl])
        tmp = bbox_embed_o2m[lvl](hs[lvl])
        if reference.shape[-1] == 4:
            tmp += reference
        else:
            assert reference.shape[-1] == 2
            tmp[..., :2] += reference
        outputs_coord_o2m = tmp.sigmoid()
        
        outputs_classes_o2m.append(outputs_class_o2m)
        outputs_coords_o2m.append(outputs_coord_o2m)
    
    outputs_class_o2m = torch.stack(outputs_classes_o2m)
    outputs_coord_o2m = torch.stack(outputs_coords_o2m)
    out['o2m_outputs'] = {'pred_logits': outputs_class_o2m[-1], 'pred_boxes': outputs_coord_o2m[-1]}
    if aux_loss:
                out['o2m_outputs']['aux_outputs'] = _set_aux_loss(outputs_class_o2m, outputs_coord_o2m)

if two_stage:
    enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
    out['enc_outputs'] = {'pred_logits': enc_outputs_class, 'pred_boxes': enc_outputs_coord, 'anchors': anchors}



for k, v in out.items():
    print(f"\nKey: '{k}'") 

    if isinstance(v, torch.Tensor):
        print(f"  Type: Tensor")
        print(f"  Shape: {v.shape}")

    elif isinstance(v, list):
        print(f"  Type: List")
        print(f"  Length: {len(v)}")
        if len(v) > 0 and isinstance(v[0], dict):
            print(f"  List contains dicts. Example element 0 keys: {list(v[0].keys())}")
            print("    Shapes inside first dict element:")
            for inner_key, inner_value in v[0].items():
                if isinstance(inner_value, torch.Tensor):
                    print(f"      '{inner_key}': {inner_value.shape}")
                else:
                    print(f"      '{inner_key}': Type {type(inner_value)}")
        elif len(v) > 0:
                print(f"  List contains elements of type: {type(v[0])}")


    elif isinstance(v, dict):
        print(f"  Type: Dictionary")
        print(f"  \nKeys: {list(v.keys())}")
        # print(f"  Content: {v}")

    else:
        # 其他类型
        print(f"  Type: {type(v)}")
        # print(f"  Value: {v}")

print('##########################################################################################')

for k, v in out['o2m_outputs'].items():

    print(f"\nKey: '{k}'") 

    if isinstance(v, torch.Tensor):
        print(f"  Type: Tensor")
        print(f"  Shape: {v.shape}")

    elif isinstance(v, list):
        print(f"  Type: List")
        print(f"  Length: {len(v)}")
        if len(v) > 0 and isinstance(v[0], dict):
            print(f"  List contains dicts. Example element 0 keys: {list(v[0].keys())}")
            print("    Shapes inside first dict element:")
            for inner_key, inner_value in v[0].items():
                if isinstance(inner_value, torch.Tensor):
                    print(f"      '{inner_key}': {inner_value.shape}")
                else:
                    print(f"      '{inner_key}': Type {type(inner_value)}")
        elif len(v) > 0:
                print(f"  List contains elements of type: {type(v[0])}")


    elif isinstance(v, dict):
        print(f"  Type: Dictionary")
        print(f"  \nKeys: {list(v.keys())}")
        # print(f"  Content: {v}")

    else:
        # 其他类型
        print(f"  Type: {type(v)}")
        # print(f"  Value: {v}")
        

print('##########################################################################################')

for k, v in out['enc_outputs'].items():
    print(f'{k}: {v.shape}')

torch.Size([6, 2, 300, 91])
torch.Size([6, 2, 300, 4])

Key: 'pred_logits'
  Type: Tensor
  Shape: torch.Size([2, 300, 91])

Key: 'pred_boxes'
  Type: Tensor
  Shape: torch.Size([2, 300, 4])

Key: 'aux_outputs'
  Type: List
  Length: 5
  List contains dicts. Example element 0 keys: ['pred_logits', 'pred_boxes']
    Shapes inside first dict element:
      'pred_logits': torch.Size([2, 300, 91])
      'pred_boxes': torch.Size([2, 300, 4])

Key: 'o2m_outputs'
  Type: Dictionary
  
Keys: ['pred_logits', 'pred_boxes', 'aux_outputs']

Key: 'enc_outputs'
  Type: Dictionary
  
Keys: ['pred_logits', 'pred_boxes', 'anchors']
##########################################################################################

Key: 'pred_logits'
  Type: Tensor
  Shape: torch.Size([2, 300, 91])

Key: 'pred_boxes'
  Type: Tensor
  Shape: torch.Size([2, 300, 4])

Key: 'aux_outputs'
  Type: List
  Length: 5
  List contains dicts. Example element 0 keys: ['pred_logits', 'pred_boxes']
    Shapes inside first dict

# DETR

In [34]:
class DeformableDETR(nn.Module):
    """ This is the Deformable DETR module that performs object detection """
    def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels,
                 aux_loss=True, with_box_refine=False, two_stage=False, use_ms_detr=False):
        """ Initializes the model.
        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            num_classes: number of object classes
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
            with_box_refine: iterative bounding box refinement
            two_stage: two-stage Deformable DETR
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, num_classes)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.num_feature_levels = num_feature_levels
        self.use_ms_detr = use_ms_detr

        if not two_stage:
            self.query_embed = nn.Embedding(num_queries, hidden_dim*2)
        
        if num_feature_levels > 1:
            num_backbone_outs = len(backbone.strides)
            input_proj_list = []
            for _ in range(num_backbone_outs):
                in_channels = backbone.num_channels[_]
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
                in_channels = hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            self.input_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                )])
        self.backbone = backbone
        self.aux_loss = aux_loss
        self.with_box_refine = with_box_refine
        self.two_stage = two_stage

        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        self.class_embed.bias.data = torch.ones(num_classes) * bias_value
        nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

        # if two-stage, the last class_embed and bbox_embed is for region proposal generation
        num_pred = (transformer.decoder.num_layers + 1) if two_stage else transformer.decoder.num_layers
        if with_box_refine:
            self.class_embed = _get_clones(self.class_embed, num_pred)
            self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
            nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
            # hack implementation for iterative bounding box refinement
            self.transformer.decoder.bbox_embed = self.bbox_embed
        else:
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
            self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)])
            self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
            self.transformer.decoder.bbox_embed = None
        if two_stage:
            # hack implementation for two-stage
            self.transformer.decoder.class_embed = self.class_embed
            for box_embed in self.bbox_embed:
                nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)
        
        if self.use_ms_detr:
            # NOTE: in implementation (a), we share the box heads for o2o and o2m branches, but do not share the class heads
            self.class_embed_o2m = copy.deepcopy(self.class_embed[:-1])
            self.bbox_embed_o2m = self.bbox_embed[:-1]

    def forward(self, samples: NestedTensor):
        """The forward expects a NestedTensor, which consists of:
               - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
               - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels

            It returns a dict with the following elements:
               - "pred_logits": the classification logits (including no-object) for all queries.
                                Shape= [batch_size x num_queries x (num_classes + 1)]
               - "pred_boxes": The normalized boxes coordinates for all queries, represented as
                               (center_x, center_y, height, width). These values are normalized in [0, 1],
                               relative to the size of each individual image (disregarding possible padding).
                               See PostProcess for information on how to retrieve the unnormalized bounding box.
               - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
                                dictionnaries containing the two above keys for each decoder layer.
        """
        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        srcs = []
        masks = []
        for l, feat in enumerate(features):
            src, mask = feat.decompose()
            srcs.append(self.input_proj[l](src))
            masks.append(mask)
            assert mask is not None
        if self.num_feature_levels > len(srcs):
            _len_srcs = len(srcs)
            for l in range(_len_srcs, self.num_feature_levels):
                if l == _len_srcs:
                    src = self.input_proj[l](features[-1].tensors)
                else:
                    src = self.input_proj[l](srcs[-1])
                m = samples.mask
                mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
                pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                pos.append(pos_l)

        query_embeds = None
        if not self.two_stage:
            query_embeds = self.query_embed.weight

        hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact, anchors = self.transformer(srcs, masks, pos, query_embeds)

        outputs_classes = []
        outputs_coords = []
        for lvl in range(hs.shape[0]):
            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[lvl - 1]
            reference = inverse_sigmoid(reference)
            outputs_class = self.class_embed[lvl](hs[lvl])
            tmp = self.bbox_embed[lvl](hs[lvl])
            if reference.shape[-1] == 4:
                tmp += reference
            else:
                assert reference.shape[-1] == 2
                tmp[..., :2] += reference
            outputs_coord = tmp.sigmoid()
            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)
        outputs_class = torch.stack(outputs_classes)
        outputs_coord = torch.stack(outputs_coords)

        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)

        if self.use_ms_detr:
            outputs_classes_o2m = []
            outputs_coords_o2m = []
            for lvl in range(hs.shape[0]):
                if lvl == 0:
                    reference = init_reference
                else:
                    reference = inter_references[lvl - 1]
                reference = inverse_sigmoid(reference)
                outputs_class_o2m = self.class_embed_o2m[lvl](hs[lvl])
                tmp = self.bbox_embed_o2m[lvl](hs[lvl])
                if reference.shape[-1] == 4:
                    tmp += reference
                else:
                    assert reference.shape[-1] == 2
                    tmp[..., :2] += reference
                outputs_coord_o2m = tmp.sigmoid()
                outputs_classes_o2m.append(outputs_class_o2m)
                outputs_coords_o2m.append(outputs_coord_o2m)
            outputs_class_o2m = torch.stack(outputs_classes_o2m)
            outputs_coord_o2m = torch.stack(outputs_coords_o2m)
            out['o2m_outputs'] = {'pred_logits': outputs_class_o2m[-1], 'pred_boxes': outputs_coord_o2m[-1]}
            if self.aux_loss:
                out['o2m_outputs']['aux_outputs'] = self._set_aux_loss(outputs_class_o2m, outputs_coord_o2m)

        if self.two_stage:
            enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
            out['enc_outputs'] = {'pred_logits': enc_outputs_class, 'pred_boxes': enc_outputs_coord, 'anchors': anchors}

        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_logits': a, 'pred_boxes': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

In [35]:
model = DeformableDETR(
    test_joiner,
    transformer_instance,
    num_classes=91,
    num_queries=num_queries,
    num_feature_levels=num_feature_levels,
    aux_loss=True,
    with_box_refine=with_box_refine,
    two_stage=two_stage,
    use_ms_detr=use_ms_detr
)

out_detr = model(samples)

for k, v in out.items():
    print(f"\nKey: '{k}'") 

    if isinstance(v, torch.Tensor):
        print(f"  Type: Tensor")
        print(f"  Shape: {v.shape}")

    elif isinstance(v, list):
        print(f"  Type: List")
        print(f"  Length: {len(v)}")
        if len(v) > 0 and isinstance(v[0], dict):
            print(f"  List contains dicts. Example element 0 keys: {list(v[0].keys())}")
            print("    Shapes inside first dict element:")
            for inner_key, inner_value in v[0].items():
                if isinstance(inner_value, torch.Tensor):
                    print(f"      '{inner_key}': {inner_value.shape}")
                else:
                    print(f"      '{inner_key}': Type {type(inner_value)}")
        elif len(v) > 0:
                print(f"  List contains elements of type: {type(v[0])}")


    elif isinstance(v, dict):
        print(f"  Type: Dictionary")
        print(f"  \nKeys: {list(v.keys())}")
        # print(f"  Content: {v}")

    else:
        # 其他类型
        print(f"  Type: {type(v)}")
        # print(f"  Value: {v}")


Key: 'pred_logits'
  Type: Tensor
  Shape: torch.Size([2, 300, 91])

Key: 'pred_boxes'
  Type: Tensor
  Shape: torch.Size([2, 300, 4])

Key: 'aux_outputs'
  Type: List
  Length: 5
  List contains dicts. Example element 0 keys: ['pred_logits', 'pred_boxes']
    Shapes inside first dict element:
      'pred_logits': torch.Size([2, 300, 91])
      'pred_boxes': torch.Size([2, 300, 4])

Key: 'o2m_outputs'
  Type: Dictionary
  
Keys: ['pred_logits', 'pred_boxes', 'aux_outputs']

Key: 'enc_outputs'
  Type: Dictionary
  
Keys: ['pred_logits', 'pred_boxes', 'anchors']


# matcher

In [36]:
from torchvision.ops.boxes import box_area

# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union

def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # degenerate boxes gives inf / nan results
    # so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    area = wh[:, :, 0] * wh[:, :, 1]

    return iou - (area - union) / area

def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)

class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self,
                 cost_class: float = 1,
                 cost_bbox: float = 1,
                 cost_giou: float = 1):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"

    def forward(self, outputs, targets):
        """ Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        with torch.no_grad():
            bs, num_queries = outputs["pred_logits"].shape[:2]

            # We flatten to compute the cost matrices in a batch
            out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
            out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

            # Also concat the target labels and boxes
            tgt_ids = torch.cat([v["labels"] for v in targets])
            tgt_bbox = torch.cat([v["boxes"] for v in targets])

            # Compute the classification cost.
            alpha = 0.25
            gamma = 2.0
            neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
            pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
            cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]

            # Compute the L1 cost between boxes
            cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

            # Compute the giou cost betwen boxes
            cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox),
                                             box_cxcywh_to_xyxy(tgt_bbox))

            # Final cost matrix
            C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
            C = C.view(bs, num_queries, -1).cpu()

            sizes = [len(v["boxes"]) for v in targets]
            indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
            return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

In [37]:
set_cost_class = 2
set_cost_bbox = 5
set_cost_giou = 2

def build_matcher():
    return HungarianMatcher(cost_class=set_cost_class,
                            cost_bbox=set_cost_bbox,
                            cost_giou=set_cost_giou)

matcher = build_matcher()

matcher_result = matcher.forward(out_detr, target)
print(matcher_result)

[(tensor([ 31, 238, 270]), tensor([1, 2, 0])), (tensor([49]), tensor([0]))]


## matcher_o2m

In [38]:
def nonzero_tuple(x):
    """
    A 'as_tuple=True' version of torch.nonzero to support torchscript.
    because of https://github.com/pytorch/pytorch/issues/38718
    返回给定tensor中非0元素的索引
    """
    if torch.jit.is_scripting():
        if x.dim() == 0:
            return x.unsqueeze(0).nonzero().unbind(1)
        return x.nonzero().unbind(1)
    else:
        return x.nonzero(as_tuple=True)


def subsample_labels(
    labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int
):
    """
    Return `num_samples` (or fewer, if not enough found)
    random samples from `labels` which is a mixture of positives & negatives.
    It will try to return as many positives as possible without
    exceeding `positive_fraction * num_samples`, and then try to
    fill the remaining slots with negatives.

    Args:
        labels (Tensor): (N, ) label vector with values:
            * -1: ignore
            * bg_label: background ("negative") class
            * otherwise: one or more foreground ("positive") classes
        num_samples (int): The total number of labels with value >= 0 to return.
            Values that are not sampled will be filled with -1 (ignore).
        positive_fraction (float): The number of subsampled labels with values > 0
            is `min(num_positives, int(positive_fraction * num_samples))`. The number
            of negatives sampled is `min(num_negatives, num_samples - num_positives_sampled)`.
            In order words, if there are not enough positives, the sample is filled with
            negatives. If there are also not enough negatives, then as many elements are
            sampled as is possible.
        bg_label (int): label index of background ("negative") class.

    Returns:
        pos_idx, neg_idx (Tensor):
            1D vector of indices. The total length of both is `num_samples` or fewer.
    """
    positive = nonzero_tuple((labels != -1) & (labels != bg_label))[0]
    negative = nonzero_tuple(labels == bg_label)[0]

    num_pos = int(num_samples * positive_fraction)
    # protect against not enough positive examples
    num_pos = min(positive.numel(), num_pos)
    num_neg = num_samples - num_pos
    # protect against not enough negative examples
    num_neg = min(negative.numel(), num_neg)

    # randomly select positive and negative examples
    perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
    perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]

    pos_idx = positive[perm1]
    neg_idx = negative[perm2]
    return pos_idx, neg_idx

def sample_topk_per_gt(pr_inds, gt_inds, cost_matrix, k):
    """
    pr_inds (tensor): tensor of shape (M,)
    gt_inds (tensor): tensor of shape (M,)
    cost_matrix (tensor): tensor of shape (num_targets, num_queries)
    """
    if len(gt_inds) == 0:
        return pr_inds, gt_inds
    # find topk matches for each gt
    gt_inds2, counts = gt_inds.unique(return_counts=True)
    scores, pr_inds2 = cost_matrix[gt_inds2].topk(k, dim=1)
    gt_inds2 = gt_inds2[:,None].repeat(1, k)

    # filter to as many matches that gt has
    pr_inds3 = torch.cat([pr[:c] for c, pr in zip(counts, pr_inds2)])
    gt_inds3 = torch.cat([gt[:c] for c, gt in zip(counts, gt_inds2)])
    scores = torch.cat([s[:c] for c, s in zip(counts, scores)])
    
    # assign query to gt with highest match score
    score_sorted_inds = scores.argsort(descending=False)
    pr_inds3 = pr_inds3[score_sorted_inds]
    gt_inds3 = gt_inds3[score_sorted_inds]

    return pr_inds3, gt_inds3

class Matcher(object):
    """
    This class assigns to each predicted "element" (e.g., a box) a ground-truth
    element. Each predicted element will have exactly zero or one matches; each
    ground-truth element may be matched to zero or more predicted elements.

    The matching is determined by the MxN match_quality_matrix, that characterizes
    how well each (ground-truth, prediction)-pair match each other. For example,
    if the elements are boxes, this matrix may contain box intersection-over-union
    overlap values.

    The matcher returns (a) a vector of length N containing the index of the
    ground-truth element m in [0, M) that matches to prediction n in [0, N).
    (b) a vector of length N containing the labels for each prediction.
    """

    def __init__(
        self, thresholds: List[float], labels: List[int], allow_low_quality_matches: bool = False
    ):
        """
        Args:
            thresholds (list): a list of thresholds used to stratify predictions
                into levels.
            labels (list): a list of values to label predictions belonging at
                each level. A label can be one of {-1, 0, 1} signifying
                {ignore, negative class, positive class}, respectively.
            allow_low_quality_matches (bool): if True, produce additional matches
                for predictions with maximum match quality lower than high_threshold.
                See set_low_quality_matches_ for more details.

            For example,
                thresholds = [0.3, 0.5]
                labels = [0, -1, 1]
                All predictions with iou < 0.3 will be marked with 0 and
                thus will be considered as false positives while training.
                All predictions with 0.3 <= iou < 0.5 will be marked with -1 and
                thus will be ignored.
                All predictions with 0.5 <= iou will be marked with 1 and
                thus will be considered as true positives.
        """
        # Add -inf and +inf to first and last position in thresholds
        thresholds = thresholds[:]
        # assert thresholds[0] > 0
        thresholds.insert(0, -float("inf"))
        thresholds.append(float("inf"))
        # Currently torchscript does not support all + generator
        assert all([low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:])]), thresholds
        assert all([l in [-1, 0, 1] for l in labels])
        assert len(labels) == len(thresholds) - 1
        self.thresholds = thresholds
        self.labels = labels
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix):
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
                pairwise quality between M ground-truth elements and N predicted
                elements. All elements must be >= 0 (due to the us of `torch.nonzero`
                for selecting indices in :meth:`set_low_quality_matches_`).

        Returns:
            matches (Tensor[int64]): a vector of length N, where matches[i] is a matched
                ground-truth index in [0, M)
            match_labels (Tensor[int8]): a vector of length N, where pred_labels[i] indicates
                whether a prediction is a true or false positive or ignored
        """
        assert match_quality_matrix.dim() == 2
        if match_quality_matrix.numel() == 0:
            default_matches = match_quality_matrix.new_full(
                (match_quality_matrix.size(1),), 0, dtype=torch.int64
            )
            # When no gt boxes exist, we define IOU = 0 and therefore set labels
            # to `self.labels[0]`, which usually defaults to background class 0
            # To choose to ignore instead, can make labels=[-1,0,-1,1] + set appropriate thresholds
            default_match_labels = match_quality_matrix.new_full(
                (match_quality_matrix.size(1),), self.labels[0], dtype=torch.int8
            )
            return default_matches, default_match_labels

        assert torch.all(match_quality_matrix >= 0)

        # match_quality_matrix is M (gt) x N (predicted)
        # Max over gt elements (dim 0) to find best gt candidate for each prediction
        matched_vals, matches = match_quality_matrix.max(dim=0)

        match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)

        for (l, low, high) in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
            low_high = (matched_vals >= low) & (matched_vals < high)
            match_labels[low_high] = l

        if self.allow_low_quality_matches:
            self.set_low_quality_matches_(match_labels, match_quality_matrix)

        return matches, match_labels

    def set_low_quality_matches_(self, match_labels, match_quality_matrix, k=1):
        """
        Produce additional matches for predictions that have only low-quality matches.
        Specifically, for each ground-truth G find the set of predictions that have
        maximum overlap with it (including ties); for each prediction in that set, if
        it is unmatched, then match it to the ground-truth G.

        This function implements the RPN assignment case (i) in Sec. 3.1.2 of
        :paper:`Faster R-CNN`.
        """
        highest_quality_foreach_gt_inds = match_quality_matrix.topk(k=k, dim=1)[1]
        match_labels[highest_quality_foreach_gt_inds.flatten()] = 1


# modified from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/roi_heads/roi_heads.py#L123
class Stage2Assigner(nn.Module):
    def __init__(self, threshold=0.4, k=6, coef_box=0.7, coef_cls=0.3):
        super().__init__()
        self.positive_fraction = 0.25
        self.bg_label = 400  # number > 91 to filter out later
        # self.batch_size_per_image = num_queries
        self.k = k
        self.coef_box = coef_box
        self.coef_cls = coef_cls

        self.proposal_matcher = Matcher(thresholds=[threshold], labels=[0, 1], allow_low_quality_matches=True)

    def _sample_proposals(
        self, matched_idxs: torch.Tensor, matched_labels: torch.Tensor, gt_classes: torch.Tensor, batch_size_per_image: int
    ):
        """
        Based on the matching between N proposals and M groundtruth,
        sample the proposals and set their classification labels.

        Args:
            matched_idxs (Tensor): a vector of length N, each is the best-matched
                gt index in [0, M) for each proposal.
            matched_labels (Tensor): a vector of length N, the matcher's label
                (one of cfg.MODEL.ROI_HEADS.IOU_LABELS) for each proposal.
            gt_classes (Tensor): a vector of length M.

        Returns:
            Tensor: a vector of indices of sampled proposals. Each is in [0, N).
            Tensor: a vector of the same length, the classification label for
                each sampled proposal. Each sample is labeled as either a category in
                [0, num_classes) or the background (num_classes).
        """
        has_gt = gt_classes.numel() > 0
        # Get the corresponding GT for each proposal
        if has_gt:
            gt_classes = gt_classes[matched_idxs]
            # Label unmatched proposals (0 label from matcher) as background (label=num_classes)
            gt_classes[matched_labels == 0] = self.bg_label
            # Label ignore proposals (-1 label)
            gt_classes[matched_labels == -1] = -1
        else:
            gt_classes = torch.zeros_like(matched_idxs) + self.bg_label
        
        sampled_fg_idxs, sampled_bg_idxs = subsample_labels(
            gt_classes, batch_size_per_image, self.positive_fraction, self.bg_label
        )
        
        sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0)
        return sampled_idxs, gt_classes[sampled_idxs]
    
    def _process_proposals(self, matched_idxs: torch.Tensor, matched_labels: torch.Tensor, gt_classes: torch.Tensor):
        has_gt = gt_classes.numel() > 0
        if has_gt:
            gt_classes = gt_classes[matched_idxs]
            gt_classes[matched_labels == 0] = self.bg_label
            gt_classes[matched_labels == -1] = -1
        else:
            gt_classes = torch.zeros_like(matched_idxs) + self.bg_label
        
        return gt_classes

    def _select_topk_per_gt(self, pr_inds, gt_inds, cost_matrix, max_k=6):
        if len(gt_inds) == 0:
            return pr_inds, gt_inds

        scores = cost_matrix[gt_inds, pr_inds]
        final_pr_inds = []
        final_gt_inds = []
        for gt_idx in gt_inds.unique():
            indices = torch.argwhere(gt_inds == gt_idx).flatten()
            selected_pr = pr_inds[indices]
            selected_scores = scores[indices]
            if len(indices) < max_k:
                final_pr_inds.append(selected_pr)
                final_gt_inds.append(gt_idx.repeat(len(indices)))
                continue
            topk_indices = selected_scores.topk(k=max_k)[1]
            selected_pr = selected_pr[topk_indices]
            final_pr_inds.append(selected_pr)
            final_gt_inds.append(gt_idx.repeat(max_k))

        return torch.cat(final_pr_inds), torch.cat(final_gt_inds)

    def postprocess_indices(self, pr_inds, gt_inds, iou, k):
        return sample_topk_per_gt(pr_inds, gt_inds, iou, k)

    @torch.no_grad()
    def get_cost_matrix(self, pred_logits, pred_boxes, gt_classes, gt_boxes):
        num_queries = len(pred_logits)
        out_prob = pred_logits.sigmoid()
        out_bbox = pred_boxes

        cost_box = box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(gt_boxes))[0]
        
        cost_class = out_prob[:, gt_classes]

        C = self.coef_box * cost_box + self.coef_cls * cost_class
        C = C.view(num_queries, -1)

        return C.T

    def forward(self, outputs, targets, return_cost_matrix=False):
        # COCO categories are from 1 to 90. They set num_classes=91 and apply sigmoid.
        bs, num_queries = outputs['pred_logits'].shape[:2]

        indices = []
        cost_matrices = []

        with torch.no_grad():
            for b in range(bs):
                pred_logits = outputs['pred_logits'][b].detach()
                pred_boxes = outputs['pred_boxes'][b]
                gt_boxes = targets[b]['boxes']
                gt_classes = targets[b]['labels']

                cost_matrix = self.get_cost_matrix(pred_logits, pred_boxes, gt_classes, gt_boxes)

                matched_idxs, matched_labels = self.proposal_matcher(cost_matrix)
                sampled_idxs, sampled_gt_classes = self._sample_proposals(
                    matched_idxs, matched_labels, targets[b]['labels'], batch_size_per_image=num_queries
                )
                pos_pr_inds = sampled_idxs[sampled_gt_classes != self.bg_label]
                pos_gt_inds = matched_idxs[pos_pr_inds]

                pos_pr_inds, pos_gt_inds = self.postprocess_indices(pos_pr_inds, pos_gt_inds, cost_matrix, self.k)
                indices.append((pos_pr_inds, pos_gt_inds))
                cost_matrices.append(cost_matrix)

        if return_cost_matrix:
            return indices, cost_matrices
        return indices

In [39]:
o2m_matcher_k = 6
o2m_matcher_threshold=0.4

matcher_o2m = Stage2Assigner(k=o2m_matcher_k, threshold=o2m_matcher_threshold)

matcher_o2m_result = matcher_o2m.forward(out_detr, target)

print(matcher_o2m_result)

[(tensor([31, 31]), tensor([1, 0])), (tensor([49]), tensor([0]))]


# labels

In [40]:
def _get_src_permutation_idx(indices):
    # permute predictions following indices
    batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
    src_idx = torch.cat([src for (src, _) in indices])
    return batch_idx, src_idx

def _get_tgt_permutation_idx(indices):
    # permute targets following indices
    batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
    tgt_idx = torch.cat([tgt for (_, tgt) in indices])
    return batch_idx, tgt_idx


batch_idx, src_idx = _get_src_permutation_idx(matcher_result)
print(batch_idx, src_idx)

for i, idx in enumerate(matcher_result):
        pred_indices = idx[0]
        target_indices = idx[1]
        original_gt_labels = target[i]['labels']
        
        print(f'target_{i}: {idx} | origin target: {original_gt_labels[target_indices]}')


print("-----------------------------------------------------------------------------------------------")

for i, idx in enumerate(matcher_o2m_result):
        pred_indices = idx[0]
        target_indices = idx[1]
        original_gt_labels = target[i]['labels']
        
        print(f'target_{i}: {idx} | origin target: {original_gt_labels[target_indices]}')


tensor([0, 0, 0, 1]) tensor([ 31, 238, 270,  49])
target_0: (tensor([ 31, 238, 270]), tensor([1, 2, 0])) | origin target: tensor([ 1, 17,  2])
target_1: (tensor([49]), tensor([0])) | origin target: tensor([70])
-----------------------------------------------------------------------------------------------
target_0: (tensor([31, 31]), tensor([1, 0])) | origin target: tensor([1, 2])
target_1: (tensor([49]), tensor([0])) | origin target: tensor([70])


# loss

In [41]:
outputs_without_aux = {k: v for k, v in out_detr.items() if k != 'aux_outputs' and k != 'enc_outputs'}

# store one-to-one indices for indices merge
o2o_indices_list = []
# Retrieve the matching between the outputs of the last layer and the targets
indices = matcher(outputs_without_aux, target)

o2o_indices_list.append(indices)

 # Compute the average number of target boxes accross all nodes, for normalization purposes
num_boxes = sum(len(t["labels"]) for t in target)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(out_detr.values())).device)

In [44]:
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.mean(1).sum() / num_boxes

@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


focal_alpha=0.25

def loss_labels(outputs, targets, indices, num_boxes, log=True):
    assert 'pred_logits' in outputs
    
    src_logits = outputs['pred_logits']
    
    idx = _get_src_permutation_idx(indices)
    target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
    target_classes = torch.full(src_logits.shape[:2], 91,
                                    dtype=torch.int64, device=src_logits.device)
    target_classes[idx] = target_classes_o

    target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
                                            dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
    target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)

    target_classes_onehot = target_classes_onehot[:, :, :-1]
    loss_ce = (sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=focal_alpha, gamma=2)
                   * src_logits.shape[1])
    losses = {'loss_ce': loss_ce}

    if log:
        # TODO this should probably be a separate loss, not hacked in this one here
        losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
    return losses


losses_ce = loss_labels(outputs_without_aux, target, indices, num_boxes)

print(losses_ce)

{'loss_ce': tensor([1.0837], grad_fn=<MulBackward0>), 'class_error': tensor(100.)}


In [45]:
@torch.no_grad()
def loss_cardinality(outputs, targets, indices, num_boxes):
    """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
    This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
    """
    pred_logits = outputs['pred_logits']
    device = pred_logits.device
    tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
    # Count the number of predictions that are NOT "no-object" (which is the last class)
    card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
    card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
    losses = {'cardinality_error': card_err}
    return losses

losses_cardinal = loss_cardinality(outputs_without_aux, target, indices, num_boxes)
print(losses_cardinal)

{'cardinality_error': tensor(298.)}


In [46]:
def loss_boxes(outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
           targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
           The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size.
        """
        assert 'pred_boxes' in outputs
        idx = _get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(generalized_box_iou(
            box_cxcywh_to_xyxy(src_boxes),
            box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses

losses_boxes = loss_boxes(outputs_without_aux, target, indices, num_boxes)
print(losses_boxes)

{'loss_bbox': tensor([1.1377], grad_fn=<DivBackward0>), 'loss_giou': tensor([1.1354], grad_fn=<DivBackward0>)}


In [None]:
def get_loss(loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': loss_labels,
            'cardinality': loss_cardinality,
            'boxes': loss_boxes,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)


"""
在处理同一个预测索引的冲突时，设定了 O2M 结果优先于 O2O 结果的规则
"""
@staticmethod
def indices_merge(num_queries, o2o_indices, o2m_indices):
    bs = len(o2o_indices)
    temp_indices = torch.zeros(bs, num_queries, dtype=torch.int64) - 1
    new_one2many_indices = []

    for i in range(bs):
        one2many_fg_inds = o2m_indices[i][0]
        one2many_gt_inds = o2m_indices[i][1]
        one2one_fg_inds = o2o_indices[i][0]
        one2one_gt_inds = o2o_indices[i][1]
        temp_indices[i][one2one_fg_inds] = one2one_gt_inds
        temp_indices[i][one2many_fg_inds] = one2many_gt_inds
        fg_inds = torch.nonzero(temp_indices[i] >= 0).squeeze(1)
        # fg_inds = torch.argwhere(temp_indices[i] >= 0).squeeze(1)
        gt_inds = temp_indices[i][fg_inds]
        new_one2many_indices.append((fg_inds, gt_inds))

    return new_one2many_indices

new_o2m_indices = indices_merge(num_queries, matcher_result, matcher_o2m_result)
print(matcher_result)
print(matcher_o2m_result)
print(new_o2m_indices)

[(tensor([ 31, 238, 270]), tensor([1, 2, 0])), (tensor([49]), tensor([0]))]
[(tensor([31, 31]), tensor([1, 0])), (tensor([49]), tensor([0]))]
[(tensor([ 31, 238, 270]), tensor([0, 2, 0])), (tensor([49]), tensor([0]))]


In [None]:
losses = ['labels', 'boxes', 'cardinality']
use_indices_merge = False
losses_dict = {}
for loss in losses:
    kwargs = {}
    losses_dict.update(get_loss(loss, out_detr, target, matcher_result, num_boxes, **kwargs))


# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
if 'aux_outputs' in out_detr:
    for i, aux_outputs in enumerate(out_detr['aux_outputs']):
        indices = matcher(aux_outputs, target)
        o2o_indices_list.append(target)
        for loss in losses:
            kwargs = {}
            if loss == 'labels':
                kwargs['log'] = False
            l_dict = get_loss(loss, aux_outputs, target, indices, num_boxes, **kwargs)
            l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
            losses_dict.update(l_dict)
            

# one-to-many losses
if 'o2m_outputs' in out_detr:
    o2m_outputs = out_detr['o2m_outputs']
    indices = matcher_o2m(o2m_outputs, target)

    if use_indices_merge:
        o2o_indices = o2o_indices_list.pop(0)
        indices = indices_merge(num_queries, o2o_indices, indices)

    for loss in losses:
        kwargs = {}
        l_dict = get_loss(loss, o2m_outputs, target, indices, num_boxes, **kwargs)
        l_dict = {k + '_o2m': v for k, v in l_dict.items()}
        losses_dict.update(l_dict)

    if "aux_outputs" in o2m_outputs:
        for i, aux_outputs in enumerate(o2m_outputs['aux_outputs']):
            indices = matcher_o2m(aux_outputs, target)

            if use_indices_merge:
                o2o_indices = o2o_indices_list[i]
                indices = indices_merge(num_queries, o2o_indices, indices)
            
            for loss in losses:
                if loss == 'masks':
                    # Intermediate masks losses are too costly to compute, we ignore them.
                    continue
                kwargs = {}
                if loss == 'labels':
                    # Logging is enabled only for the last layer
                    kwargs['log'] = False
                l_dict = get_loss(loss, aux_outputs, target, indices, num_boxes, **kwargs)
                l_dict = {k + f'_{i}_o2m': v for k, v in l_dict.items()}
                losses_dict.update(l_dict)


use_anchors_enc_match=True

if use_anchors_enc_match:
    enc_matcher = HungarianMatcher(cost_class=0, cost_bbox=matcher.cost_bbox, cost_giou=matcher.cost_giou)
else:
    enc_matcher = matcher

if 'enc_outputs' in out_detr:
    enc_outputs = out_detr['enc_outputs']
    bin_targets = copy.deepcopy(target)
    for bt in bin_targets:
        bt['labels'] = torch.zeros_like(bt['labels'])

    # NOTE: this is a hack to use anchors for encoder matching, after matching we need to restore pred_boxes for computing loss
    if use_anchors_enc_match:
        enc_outputs['pred_boxes'], enc_outputs['anchors'] = enc_outputs['anchors'], enc_outputs['pred_boxes']
        
    indices = enc_matcher(enc_outputs, bin_targets)
    
    if use_anchors_enc_match:
        enc_outputs['pred_boxes'], enc_outputs['anchors'] = enc_outputs['anchors'], enc_outputs['pred_boxes']

    for loss in losses:
        if loss == 'masks':
            # Intermediate masks losses are too costly to compute, we ignore them.
            continue
        kwargs = {}
        if loss == 'labels':
            # Logging is enabled only for the last layer
            kwargs['log'] = False
        l_dict = get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
        l_dict = {k + '_enc': v for k, v in l_dict.items()}
        losses_dict.update(l_dict)



for k, v in losses_dict.items():
    print(f'{k}: {v}')

done
loss_ce: tensor([1.0837], grad_fn=<MulBackward0>)
class_error: 100.0
loss_bbox: tensor([1.1377], grad_fn=<DivBackward0>)
loss_giou: tensor([1.1354], grad_fn=<DivBackward0>)
cardinality_error: 298.0
loss_ce_0: tensor([1.1568], grad_fn=<MulBackward0>)
loss_bbox_0: tensor([1.1377], grad_fn=<DivBackward0>)
loss_giou_0: tensor([1.1354], grad_fn=<DivBackward0>)
cardinality_error_0: 298.0
loss_ce_1: tensor([1.1611], grad_fn=<MulBackward0>)
loss_bbox_1: tensor([1.1377], grad_fn=<DivBackward0>)
loss_giou_1: tensor([1.1354], grad_fn=<DivBackward0>)
cardinality_error_1: 298.0
loss_ce_2: tensor([1.1435], grad_fn=<MulBackward0>)
loss_bbox_2: tensor([1.1377], grad_fn=<DivBackward0>)
loss_giou_2: tensor([1.1354], grad_fn=<DivBackward0>)
cardinality_error_2: 298.0
loss_ce_3: tensor([1.1635], grad_fn=<MulBackward0>)
loss_bbox_3: tensor([1.1377], grad_fn=<DivBackward0>)
loss_giou_3: tensor([1.1354], grad_fn=<DivBackward0>)
cardinality_error_3: 298.0
loss_ce_4: tensor([1.1735], grad_fn=<MulBackward0

# Criterion

In [72]:
class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    def __init__(self, num_classes, matcher, weight_dict, losses, num_queries, focal_alpha=0.25, o2m_matcher_threshold=0.4, o2m_matcher_k=6, use_indices_merge=False, use_anchors_enc_match=True):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            focal_alpha: alpha in Focal Loss
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_queries = num_queries
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.focal_alpha = focal_alpha
        self.use_indices_merge = use_indices_merge
        self.use_anchors_enc_match = use_anchors_enc_match
        if use_anchors_enc_match:
            self.enc_matcher = HungarianMatcher(cost_class=0, cost_bbox=matcher.cost_bbox, cost_giou=matcher.cost_giou)
        else:
            self.enc_matcher = matcher
        self.matcher_o2m = Stage2Assigner(k=o2m_matcher_k, threshold=o2m_matcher_threshold)


    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
                                            dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)

        target_classes_onehot = target_classes_onehot[:, :, :-1]
        loss_ce = (sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2)
                   * src_logits.shape[1])
        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
        """
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
           targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
           The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(generalized_box_iou(
            box_cxcywh_to_xyxy(src_boxes),
            box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses


    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    @staticmethod
    def indices_merge(num_queries, o2o_indices, o2m_indices):
        bs = len(o2o_indices)
        temp_indices = torch.zeros(bs, num_queries, dtype=torch.int64).cuda() - 1
        new_one2many_indices = []

        for i in range(bs):
            one2many_fg_inds = o2m_indices[i][0].cuda()
            one2many_gt_inds = o2m_indices[i][1].cuda()
            one2one_fg_inds = o2o_indices[i][0].cuda()
            one2one_gt_inds = o2o_indices[i][1].cuda()
            temp_indices[i][one2one_fg_inds] = one2one_gt_inds
            temp_indices[i][one2many_fg_inds] = one2many_gt_inds
            fg_inds = torch.nonzero(temp_indices[i] >= 0).squeeze(1)
            # fg_inds = torch.argwhere(temp_indices[i] >= 0).squeeze(1)
            gt_inds = temp_indices[i][fg_inds]
            new_one2many_indices.append((fg_inds, gt_inds))

        return new_one2many_indices

    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs' and k != 'enc_outputs'}
        # store one-to-one indices for indices merge
        o2o_indices_list = []

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        o2o_indices_list.append(indices)

        # Compute the average number of target boxes accross all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            kwargs = {}
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                o2o_indices_list.append(indices)
                for loss in self.losses:
                    if loss == 'masks':
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer
                        kwargs['log'] = False
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        # one-to-many losses
        if 'o2m_outputs' in outputs:
            o2m_outputs = outputs['o2m_outputs']
            indices = self.matcher_o2m(o2m_outputs, targets)

            if self.use_indices_merge:
                o2o_indices = o2o_indices_list.pop(0)
                indices = self.indices_merge(self.num_queries, o2o_indices, indices)

            for loss in self.losses:
                kwargs = {}
                l_dict = self.get_loss(loss, o2m_outputs, targets, indices, num_boxes, **kwargs)
                l_dict = {k + '_o2m': v for k, v in l_dict.items()}
                losses.update(l_dict)

            if "aux_outputs" in o2m_outputs:
                for i, aux_outputs in enumerate(o2m_outputs['aux_outputs']):
                    indices = self.matcher_o2m(aux_outputs, targets)

                    if self.use_indices_merge:
                        o2o_indices = o2o_indices_list[i]
                        indices = self.indices_merge(self.num_queries, o2o_indices, indices)
                    
                    for loss in self.losses:
                        if loss == 'masks':
                            # Intermediate masks losses are too costly to compute, we ignore them.
                            continue
                        kwargs = {}
                        if loss == 'labels':
                            # Logging is enabled only for the last layer
                            kwargs['log'] = False
                        l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                        l_dict = {k + f'_{i}_o2m': v for k, v in l_dict.items()}
                        losses.update(l_dict)

        if 'enc_outputs' in outputs:
            enc_outputs = outputs['enc_outputs']
            bin_targets = copy.deepcopy(targets)
            for bt in bin_targets:
                bt['labels'] = torch.zeros_like(bt['labels'])

            # NOTE: this is a hack to use anchors for encoder matching, after matching we need to restore pred_boxes for computing loss
            if self.use_anchors_enc_match:
                enc_outputs['pred_boxes'], enc_outputs['anchors'] = enc_outputs['anchors'], enc_outputs['pred_boxes']
            indices = self.enc_matcher(enc_outputs, bin_targets)
            if self.use_anchors_enc_match:
                enc_outputs['pred_boxes'], enc_outputs['anchors'] = enc_outputs['anchors'], enc_outputs['pred_boxes']

            for loss in self.losses:
                if loss == 'masks':
                    # Intermediate masks losses are too costly to compute, we ignore them.
                    continue
                kwargs = {}
                if loss == 'labels':
                    # Logging is enabled only for the last layer
                    kwargs['log'] = False
                l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
                l_dict = {k + '_enc': v for k, v in l_dict.items()}
                losses.update(l_dict)

        return losses

In [75]:
set_cost_class = 2
set_cost_bbox = 5
set_cost_giou = 2

cls_loss_coef = 2
bbox_loss_coef = 5
giou_loss_coef = 2

enc_cls_loss_coef = 2
enc_bbox_loss_coef = 5
enc_giou_loss_coef = 2

dec_layers = 6

o2m_cls_loss_coef = 2
o2m_bbox_loss_coef = 5
o2m_giou_loss_coef = 2


def build():
    num_classes =  91

    device = torch.device("cuda")

    backbone = build_backbone(hidden_dim=256)
    transformer = build_deforamble_transformer()

    model = DeformableDETR(
        backbone,
        transformer,
        num_classes=num_classes,
        num_queries=num_queries,
        num_feature_levels=num_feature_levels,
        aux_loss=aux_loss,
        with_box_refine=with_box_refine,
        two_stage=two_stage,
        use_ms_detr=use_ms_detr,
    )

    matcher = build_matcher()
    enc_matcher = HungarianMatcher(cost_class=0, cost_bbox=set_cost_bbox, cost_giou=set_cost_giou)
    weight_dict = {'loss_ce': cls_loss_coef, 'loss_bbox': bbox_loss_coef, 'loss_giou': giou_loss_coef}
    weight_dict.update(
        {'loss_ce_enc': enc_cls_loss_coef, 'loss_bbox_enc': enc_bbox_loss_coef, 'loss_giou_enc': enc_giou_loss_coef}
    )

    
    # TODO this is a hack for auxiliary loss weights
    if aux_loss:
        aux_weight_dict = {}
        for i in range(dec_layers - 1):
            aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)
    
    # NOTE: this is a hack to update the one-to-many loss weights
    o2m_weight_dict = {'loss_ce': o2m_cls_loss_coef, 'loss_bbox': o2m_bbox_loss_coef, 'loss_giou': o2m_giou_loss_coef}
    if aux_loss:
        aux_weight_dict = {}
        for i in range(dec_layers - 1):
            aux_weight_dict.update({k + f'_{i}': v for k, v in o2m_weight_dict.items()})
        o2m_weight_dict.update(aux_weight_dict)
    o2m_weight_dict = {k + '_o2m': v for k, v in o2m_weight_dict.items()}
    weight_dict.update(o2m_weight_dict)

    losses = ['labels', 'boxes', 'cardinality']

    # num_classes, matcher, weight_dict, losses, focal_alpha=0.25
    criterion = SetCriterion(
        num_classes, matcher, weight_dict, losses, focal_alpha=focal_alpha, num_queries=num_queries, use_anchors_enc_match=True)
    criterion.to(device)
    
    # post_process = PostProcess(topk=topk_eval) if nms_iou_threshold is None else NMSPostProcess(topk=topk_eval, nms_iou_threshold=nms_iou_threshold)
    # postprocessors = {'bbox': post_process}
    # if masks:
    #     postprocessors['segm'] = PostProcessSegm()
    #     if args.dataset_file == "coco_panoptic":
    #         is_thing_map = {i: i <= 90 for i in range(201)}
    #         postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85)

    return model, criterion

In [77]:
model, criterion = build()

outputs = model.forward(samples)
loss = criterion.forward(outputs, target)

for k, v in loss.items():
    print(f'{k}: {v}')

loss_ce: tensor([1.1741], grad_fn=<MulBackward0>)
class_error: 100.0
loss_bbox: tensor([0.9935], grad_fn=<DivBackward0>)
loss_giou: tensor([0.8360], grad_fn=<DivBackward0>)
cardinality_error: 298.0
loss_ce_0: tensor([1.1362], grad_fn=<MulBackward0>)
loss_bbox_0: tensor([0.9935], grad_fn=<DivBackward0>)
loss_giou_0: tensor([0.8360], grad_fn=<DivBackward0>)
cardinality_error_0: 298.0
loss_ce_1: tensor([1.1346], grad_fn=<MulBackward0>)
loss_bbox_1: tensor([0.9935], grad_fn=<DivBackward0>)
loss_giou_1: tensor([0.8360], grad_fn=<DivBackward0>)
cardinality_error_1: 298.0
loss_ce_2: tensor([1.2174], grad_fn=<MulBackward0>)
loss_bbox_2: tensor([0.9935], grad_fn=<DivBackward0>)
loss_giou_2: tensor([0.8360], grad_fn=<DivBackward0>)
cardinality_error_2: 298.0
loss_ce_3: tensor([1.1329], grad_fn=<MulBackward0>)
loss_bbox_3: tensor([0.9935], grad_fn=<DivBackward0>)
loss_giou_3: tensor([0.8360], grad_fn=<DivBackward0>)
cardinality_error_3: 298.0
loss_ce_4: tensor([1.1745], grad_fn=<MulBackward0>)
lo