# import

In [1]:
import os
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW, lr_scheduler
import transforms as T
from torchvision import datasets, ops
from torchvision.models.feature_extraction import create_feature_extractor
from einops import rearrange
from pycocotools import mask as coco_mask

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import linear_sum_assignment

# build dataset

In [2]:
from pathlib import Path

import torch
import torch.utils.data
from pycocotools import mask as coco_mask

from torchvision_datasets import CocoDetection as TvCocoDetection
from util.misc import get_local_rank, get_local_size
import transforms as T


class CocoDetection(TvCocoDetection):
    """COCO detection dataset that returns ``(image, target)`` pairs.

    Each raw sample from the parent dataset is wrapped with its image id,
    converted by :class:`ConvertCocoPolysToMask`, and then passed through the
    optional joint image/target ``transforms``.
    """

    def __init__(self, img_folder, ann_file, transforms, return_masks, cache_mode=False, local_rank=0, local_size=1):
        # NOTE(review): `cache_mode` is accepted but not forwarded to the
        # parent constructor -- confirm whether that is intentional.
        super().__init__(img_folder, ann_file, local_rank=local_rank, local_size=local_size)
        self.prepare = ConvertCocoPolysToMask(return_masks)
        self._transforms = transforms

    def __getitem__(self, idx):
        """Return the converted (and optionally transformed) sample at ``idx``."""
        image, annotations = super().__getitem__(idx)
        raw_target = {'image_id': self.ids[idx], 'annotations': annotations}
        image, converted = self.prepare(image, raw_target)
        if self._transforms is None:
            return image, converted
        image, converted = self._transforms(image, converted)
        return image, converted


def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterise per-object COCO polygon lists into a stack of masks.

    Args:
        segmentations: iterable with one polygon collection per object, in the
            format accepted by ``pycocotools.mask.frPyObjects``.
        height: mask height passed to the RLE encoder.
        width: mask width passed to the RLE encoder.

    Returns:
        The per-object masks stacked along dim 0. When ``segmentations`` is
        empty, an empty ``uint8`` tensor of shape ``(0, height, width)``.
    """

    def _rasterise(polygons):
        # Encode the polygons as RLE, decode to a dense array, then merge the
        # per-polygon channels (last axis) into a single object mask.
        decoded = coco_mask.decode(coco_mask.frPyObjects(polygons, height, width))
        if decoded.ndim < 3:
            decoded = decoded[..., None]
        return torch.as_tensor(decoded, dtype=torch.uint8).any(dim=2)

    per_object = [_rasterise(polygons) for polygons in segmentations]
    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)


class ConvertCocoPolysToMask(object):
    """Turn a raw COCO sample into a dictionary of tensors.

    The callable takes an image exposing ``.size`` as ``(width, height)`` and a
    dict with ``"image_id"`` and ``"annotations"``. It returns the image
    unchanged together with a new target dict holding ``boxes`` (xyxy, clamped
    to the image), ``labels``, optionally ``masks`` and ``keypoints``,
    ``image_id``, ``area``, ``iscrowd``, ``orig_size`` and ``size``.
    """

    def __init__(self, return_masks=False):
        self.return_masks = return_masks

    def __call__(self, image, target):
        width, height = image.size

        image_id = torch.tensor([target["image_id"]])

        # Crowd annotations are dropped before anything else is computed.
        objects = [obj for obj in target["annotations"] if obj.get('iscrowd', 0) == 0]

        # reshape(-1, 4) keeps a well-formed (0, 4) tensor when no object is left
        boxes = torch.as_tensor([obj["bbox"] for obj in objects], dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]  # (x, y, w, h) -> (x1, y1, x2, y2)
        boxes[:, 0::2].clamp_(min=0, max=width)
        boxes[:, 1::2].clamp_(min=0, max=height)

        labels = torch.tensor([obj["category_id"] for obj in objects], dtype=torch.int64)

        masks = None
        if self.return_masks:
            masks = convert_coco_poly_to_mask(
                [obj["segmentation"] for obj in objects], height, width)

        keypoints = None
        if objects and "keypoints" in objects[0]:
            keypoints = torch.as_tensor([obj["keypoints"] for obj in objects], dtype=torch.float32)
            count = keypoints.shape[0]
            if count:
                # group the flat keypoint list into triples
                keypoints = keypoints.view(count, -1, 3)

        # Only boxes with strictly positive width and height survive.
        valid = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in objects])
        iscrowd = torch.tensor([obj.get("iscrowd", 0) for obj in objects])

        converted = {"boxes": boxes[valid], "labels": labels[valid]}
        if masks is not None:
            converted["masks"] = masks[valid]
        converted["image_id"] = image_id
        if keypoints is not None:
            converted["keypoints"] = keypoints[valid]
        converted["area"] = area[valid]
        converted["iscrowd"] = iscrowd[valid]
        converted["orig_size"] = torch.as_tensor([int(height), int(width)])
        converted["size"] = torch.as_tensor([int(height), int(width)])

        return image, converted


def make_coco_transforms(image_set):
    """Build the joint image/target transform pipeline for a dataset split.

    Args:
        image_set: ``'train'`` or ``'val'``.

    Returns:
        A ``T.Compose`` pipeline. Training uses a random horizontal flip plus
        a random choice between a plain multi-scale resize and a
        resize/crop/resize chain; validation uses a single fixed resize. Both
        end with tensor conversion and normalisation.

    Raises:
        ValueError: if ``image_set`` is neither ``'train'`` nor ``'val'``.
    """
    # Channel mean/std below are presumably the ImageNet statistics -- TODO confirm.
    to_normalized_tensor = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    if image_set == 'val':
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            to_normalized_tensor,
        ])

    if image_set != 'train':
        raise ValueError(f'unknown {image_set}')

    # 480, 512, ..., 800 in steps of 32
    multi_scale = list(range(480, 801, 32))

    plain_resize = T.RandomResize(multi_scale, max_size=1333)
    crop_then_resize = T.Compose([
        T.RandomResize([400, 500, 600]),
        T.RandomSizeCrop(384, 600),
        T.RandomResize(multi_scale, max_size=1333),
    ])
    return T.Compose([
        T.RandomHorizontalFlip(),
        T.RandomSelect(plain_resize, crop_then_resize),
        to_normalized_tensor,
    ])


def build(image_set):
    """Create the ``CocoDetection`` dataset for one split of ``tiny_coco``.

    Args:
        image_set: ``'train'`` or ``'val'``; any other key raises ``KeyError``.

    Returns:
        A ``CocoDetection`` dataset (box targets only, ``return_masks=False``)
        using the transforms from ``make_coco_transforms(image_set)``.
    """
    mode = 'instances'
    root = 'tiny_coco'
    split_folders = {"train": "train2017", "val": "val2017"}

    # split -> (image folder, annotation json)
    PATHS = {
        split: (
            os.path.join(root, folder),
            os.path.join(root, "annotations", f'{mode}_{folder}.json'),
        )
        for split, folder in split_folders.items()
    }

    img_folder, ann_file = PATHS[image_set]
    return CocoDetection(
        img_folder,
        ann_file,
        transforms=make_coco_transforms(image_set),
        return_masks=False,
        local_rank=get_local_rank(),
        local_size=get_local_size(),
    )


# build DataLoader

In [3]:
import util.misc as utils

# One dataset per split.
dataset_train = build(image_set='train')
dataset_val = build(image_set='val')

# Both loaders use batches of 2, load in the main process (num_workers=0) and
# collate with utils.collate_fn.
# NOTE(review): the train loader does not pass shuffle=True -- confirm this is
# intended for training.
data_loader_train = DataLoader(
    dataset_train,
    2,
    collate_fn=utils.collate_fn,
    num_workers=0,
    pin_memory=True,
)
data_loader_val = DataLoader(
    dataset_val,
    2,
    drop_last=False,
    collate_fn=utils.collate_fn,
    num_workers=0,
    pin_memory=True,
)

# NOTE(review): len() of a DataLoader is the number of batches, not the number
# of samples, so the labels printed below are off by the batch size.
for message in (
    f'\nNumber of train samples: {len(data_loader_train)}',
    f'\nNumber of val samples: {len(data_loader_val)}',
):
    print(message)

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!

Number of train samples: 25

Number of val samples: 25


## Test dataloader

In [4]:
# Smoke test: pull one batch from the training loader.
# `samples` exposes `.tensors` and `.mask` (see the prints below); `target`
# holds one annotation dict per image in the batch.
samples, target = next(iter(data_loader_train))

print(samples.tensors.shape)
print(samples.mask.shape)

print(target[0])
print(target[1])

torch.Size([2, 3, 810, 768])
torch.Size([2, 810, 768])
{'boxes': tensor([[0.9218, 0.7209, 0.0278, 0.0915],
        [0.9893, 0.7768, 0.0174, 0.1130],
        [0.9748, 0.7756, 0.0214, 0.0764],
        [0.9395, 0.7376, 0.0222, 0.0933],
        [0.9800, 0.9051, 0.0209, 0.0974],
        [0.9607, 0.9080, 0.0153, 0.0434],
        [0.2526, 0.5632, 0.5052, 0.8467],
        [0.7689, 0.6562, 0.1585, 0.5806],
        [0.0717, 0.4550, 0.0216, 0.0167],
        [0.1572, 0.5468, 0.0210, 0.0152],
        [0.0448, 0.5400, 0.0896, 0.0906],
        [0.9630, 0.7527, 0.0271, 0.1001],
        [0.1845, 0.4892, 0.0164, 0.0139],
        [0.9983, 0.6217, 0.0035, 0.0152],
        [0.2225, 0.4890, 0.0162, 0.0109],
        [0.1717, 0.5665, 0.0478, 0.0210],
        [0.1250, 0.5698, 0.0377, 0.0183],
        [0.9869, 0.6088, 0.0262, 0.0286],
        [0.0460, 0.4536, 0.0238, 0.0166],
        [0.1965, 0.4904, 0.0137, 0.0143],
        [0.0722, 0.4553, 0.0186, 0.0157]]), 'labels': tensor([44, 44, 44, 44, 44, 44,  1,  1, 5

# model

## position encoding

In [5]:
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn

from util.misc import NestedTensor


class PositionEmbeddingSine(nn.Module):
    """
    Sinusoidal 2-D position embedding for image feature maps, in the spirit of
    the encoding from "Attention is all you need".

    The output has ``2 * num_pos_feats`` channels: the row (y) encoding
    followed by the column (x) encoding.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        # default scale maps normalised coordinates onto one full period
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, tensor_list: NestedTensor):
        """Return the position encoding for ``tensor_list``.

        Only ``tensor_list.mask`` (must not be None) and the device of
        ``tensor_list.tensors`` are used. Positions are cumulative counts of
        un-masked entries along dim 1 (rows) and dim 2 (columns) of the mask.
        """
        features = tensor_list.tensors
        padding_mask = tensor_list.mask
        assert padding_mask is not None
        valid = ~padding_mask
        row_pos = valid.cumsum(1, dtype=torch.float32)
        col_pos = valid.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            # divide by the last cumulative count so positions fall in (0, scale)
            row_pos = (row_pos - 0.5) / (row_pos[:, -1:, :] + eps) * self.scale
            col_pos = (col_pos - 0.5) / (col_pos[:, :, -1:] + eps) * self.scale

        channel = torch.arange(self.num_pos_feats, dtype=torch.float32, device=features.device)
        # consecutive channel pairs share one frequency
        frequency = self.temperature ** (2 * (channel // 2) / self.num_pos_feats)

        def _encode(coord):
            # sin on even channels, cos on odd channels, interleaved back together
            angles = coord[:, :, :, None] / frequency
            return torch.stack(
                (angles[:, :, :, 0::2].sin(), angles[:, :, :, 1::2].cos()), dim=4
            ).flatten(3)

        encoded_cols = _encode(col_pos)
        encoded_rows = _encode(row_pos)
        return torch.cat((encoded_rows, encoded_cols), dim=3).permute(0, 3, 1, 2)


class PositionEmbeddingLearned(nn.Module):
    """
    Learned absolute position embedding.

    Row and column indices each look up a ``num_pos_feats``-dimensional vector;
    the output concatenates the column embedding and the row embedding, giving
    ``2 * num_pos_feats`` channels.

    NOTE: both tables have 50 entries, so the lookup fails for feature maps
    taller or wider than 50.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both embedding tables from a uniform distribution."""
        for table in (self.row_embed, self.col_embed):
            nn.init.uniform_(table.weight)

    def forward(self, tensor_list: NestedTensor):
        """Return the embedding grid repeated for every item in the batch."""
        features = tensor_list.tensors
        height, width = features.shape[-2:]
        col_vectors = self.col_embed(torch.arange(width, device=features.device))
        row_vectors = self.row_embed(torch.arange(height, device=features.device))

        # broadcast each 1-D table over the other spatial axis
        col_grid = col_vectors.unsqueeze(0).repeat(height, 1, 1)
        row_grid = row_vectors.unsqueeze(1).repeat(1, width, 1)

        grid = torch.cat([col_grid, row_grid], dim=-1).permute(2, 0, 1)
        return grid.unsqueeze(0).repeat(features.shape[0], 1, 1, 1)


def build_position_encoding(hidden_dim=256, position_embedding = 'sine'):
    """Create the position-encoding module selected by ``position_embedding``.

    Args:
        hidden_dim: total channel count of the encoding; each spatial axis
            gets ``hidden_dim // 2`` features.
        position_embedding: ``'sine'``/``'v2'`` for the normalised sinusoidal
            encoding, ``'learned'``/``'v3'`` for the learned one.

    Raises:
        ValueError: for any other ``position_embedding`` value.
    """
    features_per_axis = hidden_dim // 2
    if position_embedding in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        return PositionEmbeddingSine(features_per_axis, normalize=True)
    if position_embedding in ('v3', 'learned'):
        return PositionEmbeddingLearned(features_per_axis)
    raise ValueError(f"not supported {position_embedding}")

### Test position encoding

In [6]:
# Smoke test: default (sine) position encoding applied to the raw image batch.
test_position_encoding_model = build_position_encoding()
# `samples` comes from the dataloader test cell; its spatial size depends on
# the random resize (the recorded run below shows [2, 3, 810, 768]).
test_position_encoding = test_position_encoding_model(samples)
print(f'train samples shape: {samples.tensors.shape}')
print(f'position encoding shape: {test_position_encoding.shape}')

train samples shape: torch.Size([2, 3, 810, 768])
position encoding shape: torch.Size([2, 256, 810, 768])


## Backbone

In [7]:
"""
Backbone modules.
"""
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List

from util.misc import NestedTensor, is_main_process

class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d whose statistics and affine parameters are fixed.

    All four tensors (weight, bias, running_mean, running_var) are registered
    as buffers, so none of them is a trainable parameter. ``eps`` is added to
    the variance before the rsqrt.
    """

    def __init__(self, n, eps=1e-5):
        super().__init__()
        # identity affine transform and unit statistics until a checkpoint is loaded
        for buffer_name, make in (
            ("weight", torch.ones),
            ("bias", torch.zeros),
            ("running_mean", torch.zeros),
            ("running_var", torch.ones),
        ):
            self.register_buffer(buffer_name, make(n))
        self.eps = eps

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # This module has no `num_batches_tracked` buffer, so drop that key
        # before the default loader sees it.
        state_dict.pop(prefix + 'num_batches_tracked', None)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        """Apply the frozen normalisation as a per-channel ``x * scale + shift``."""

        def _per_channel(buffer):
            # reshape up front so the arithmetic below broadcasts over dim 1 of `x`
            return buffer.reshape(1, -1, 1, 1)

        scale = _per_channel(self.weight) * (_per_channel(self.running_var) + self.eps).rsqrt()
        shift = _per_channel(self.bias) - _per_channel(self.running_mean) * scale
        return x * scale + shift


class BackboneBase(nn.Module):
    """Wrap a backbone so it returns selected stage outputs with resized masks.

    Parameters outside ``layer2``/``layer3``/``layer4`` are always frozen; all
    parameters are frozen when ``train_backbone`` is False. ``strides`` and
    ``num_channels`` are hard-coded for the selected stages.
    """

    def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool):
        super().__init__()
        trainable_stages = ('layer2', 'layer3', 'layer4')
        for name, parameter in backbone.named_parameters():
            in_trainable_stage = any(stage in name for stage in trainable_stages)
            if not (train_backbone and in_trainable_stage):
                parameter.requires_grad_(False)

        if return_interm_layers:
            # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
            return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
            self.strides, self.num_channels = [8, 16, 32], [512, 1024, 2048]
        else:
            return_layers = {'layer4': "0"}
            self.strides, self.num_channels = [32], [2048]
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)

    def forward(self, tensor_list: NestedTensor):
        """Return ``{level name: NestedTensor(features, mask)}`` for each stage."""
        feature_maps = self.body(tensor_list.tensors)
        result: Dict[str, NestedTensor] = {}
        for level_name, features in feature_maps.items():
            padding_mask = tensor_list.mask
            assert padding_mask is not None
            # resize the mask to this level's spatial size
            resized = F.interpolate(padding_mask[None].float(), size=features.shape[-2:])
            result[level_name] = NestedTensor(features, resized.to(torch.bool)[0])
        return result
    
class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm.

    Args:
        name: one of ``'resnet18'``, ``'resnet34'``, ``'resnet50'``,
            ``'resnet101'``. The two small variants are rejected by an
            assertion because ``BackboneBase`` hard-codes the channel counts.
        train_backbone: forwarded to ``BackboneBase``.
        return_interm_layers: forwarded to ``BackboneBase``.
        dilation: replace the stride of the last stage with dilation; the last
            entry of ``strides`` is halved accordingly.

    Raises:
        ValueError: if ``name`` is not a known ResNet variant.
    """
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        models = torchvision.models
        # variant name -> (constructor, weights enum)
        variants = {
            'resnet18': (models.resnet18, models.ResNet18_Weights),
            'resnet34': (models.resnet34, models.ResNet34_Weights),
            'resnet50': (models.resnet50, models.ResNet50_Weights),
            'resnet101': (models.resnet101, models.ResNet101_Weights),
        }
        if name not in variants:
            raise ValueError(f"Unsupported backbone: {name}")

        # Check before building the model so an unsupported variant fails
        # without first constructing it and loading its pretrained weights.
        assert name not in ('resnet18', 'resnet34'), "number of channels are hard coded"

        constructor, weights_enum = variants[name]
        backbone = constructor(
            weights=weights_enum.DEFAULT,
            replace_stride_with_dilation=[False, False, dilation],
            norm_layer=FrozenBatchNorm2d,
        )

        super().__init__(backbone, train_backbone, return_interm_layers)
        if dilation:
            self.strides[-1] = self.strides[-1] // 2

### test backbone

In [8]:
# Smoke test: frozen ResNet-50 returning the three intermediate stages.
backbone = Backbone(name='resnet50', train_backbone=False, return_interm_layers=True, dilation=False)

# Dict mapping level name ('0', '1', '2') to a NestedTensor of features + mask
backbone_out_nestedTensor = backbone(samples)

print(f'train samples shape: {samples.tensors.shape}')
for k, v in backbone_out_nestedTensor.items():
    print(f"{k} tensor shape : {v.tensors.shape}, mask shape: {v.mask.shape}")

train samples shape: torch.Size([2, 3, 810, 768])
0 tensor shape : torch.Size([2, 512, 102, 96]), mask shape: torch.Size([2, 102, 96])
1 tensor shape : torch.Size([2, 1024, 51, 48]), mask shape: torch.Size([2, 51, 48])
2 tensor shape : torch.Size([2, 2048, 26, 24]), mask shape: torch.Size([2, 26, 24])


### test position with backbone_feature

In [9]:
# Position encoding computed for the last backbone level (key '2').
backbone_feature_position_enc = test_position_encoding_model(backbone_out_nestedTensor['2'])
print(backbone_feature_position_enc.shape)

torch.Size([2, 256, 26, 24])


## Joiner position & backbone

In [10]:
class Joiner(nn.Sequential):
    """Pair a backbone with a position-encoding module.

    Index 0 of the sequence is the backbone and index 1 the position
    embedding. ``strides`` and ``num_channels`` are copied from the backbone.
    """

    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)
        self.strides = backbone.strides
        self.num_channels = backbone.num_channels

    def forward(self, tensor_list: NestedTensor):
        """Return ``(features, encodings)`` as two parallel lists.

        Features are ordered by their level name; each encoding is cast to the
        dtype of its feature tensor.
        """
        by_level = self[0](tensor_list)
        features: List[NestedTensor] = [feat for _, feat in sorted(by_level.items())]

        # position encoding
        encodings = [self[1](feat).to(feat.tensors.dtype) for feat in features]

        return features, encodings

### test joiner

In [11]:
# Smoke test: the joiner returns parallel lists of features and position encodings.
test_joiner = Joiner(backbone, test_position_encoding_model)

test_joiner_out, test_joiner_pos = test_joiner(samples)

print(f"test_joiner_out has length: {len(test_joiner_out)}")
for i, level_features in enumerate(test_joiner_out):
    print(f'test_joiner_out layer_{i+1} has shape: {level_features.tensors.shape}')


print(f"\ntest_joiner_pos has length: {len(test_joiner_pos)}")
for i, level_encoding in enumerate(test_joiner_pos):
    print(f'test_joiner_pos layer_{i+1} has shape: {level_encoding.shape}')


test_joiner_out has length: 3
test_joiner_out layer_1 has shape: torch.Size([2, 512, 102, 96])
test_joiner_out layer_2 has shape: torch.Size([2, 1024, 51, 48])
test_joiner_out layer_3 has shape: torch.Size([2, 2048, 26, 24])

test_joiner_pos has length: 3
test_joiner_pos layer_1 has shape: torch.Size([2, 256, 102, 96])
test_joiner_pos layer_2 has shape: torch.Size([2, 256, 51, 48])
test_joiner_pos layer_3 has shape: torch.Size([2, 256, 26, 24])


## build backbone (joiner)

In [12]:
def build_backbone(
    hidden_dim: int,
    position_embedding: str = 'sine',
    backbone: str = 'resnet50',
    return_interm_layers: bool = False,
    dilation: bool = False
):
    """Build the backbone + position-encoding ``Joiner``.

    Args:
        hidden_dim: channel count of the position encoding.
        position_embedding: encoding type understood by
            ``build_position_encoding`` (``'sine'`` or ``'learned'``).
        backbone: ResNet variant name passed to ``Backbone``.
        return_interm_layers: return three stages instead of only the last.
        dilation: use dilation in the last stage.

    Returns:
        A ``Joiner`` wrapping the frozen backbone and the position encoding.
    """
    # Honour the caller's arguments; they were previously ignored in favour of
    # a hard-coded hidden_dim=256 / 'sine'.
    position_encoder = build_position_encoding(
        hidden_dim=hidden_dim, position_embedding=position_embedding)
    # The backbone is always frozen here.
    train_backbone = False
    feature_extractor = Backbone(backbone, train_backbone, return_interm_layers, dilation)
    return Joiner(feature_extractor, position_encoder)

# DeformableTransformer

## Encoder

In [13]:
import copy

def _get_activation_fn(activation):
    """Return the functional activation named by ``activation``.

    Args:
        activation: ``"relu"``, ``"gelu"`` or ``"glu"``.

    Raises:
        RuntimeError: if ``activation`` is none of the supported names.
    """
    supported = (("relu", F.relu), ("gelu", F.gelu), ("glu", F.glu))
    for name, fn in supported:
        if activation == name:
            return fn
    # The message now lists every supported name (it used to omit "glu").
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones

### MSDeformAttn

In [14]:
import functools
import torch.nn.init as init

def deformable_attention_core_func_v2(
    value: List[torch.Tensor],
    value_spatial_shapes,
    sampling_locations: torch.Tensor,
    attention_weights: torch.Tensor,
    num_points_list: List[int],
    method='default',
    value_shape='irrelevant',
    ):
    """Sample multi-level value maps and combine them with attention weights.

    Args:
        value (List[Tensor]): one tensor per level, each
            ``[bs * n_head, c, H_l, W_l]``.
        value_spatial_shapes: iterable of ``(H_l, W_l)`` pairs, one per level.
        sampling_locations (Tensor):
            ``[bs, Len_q, n_head, sum(num_points_list), 2]`` with ``(x, y)``
            in the last dim. For ``method='default'`` they are mapped from
            ``[0, 1]`` to the ``[-1, 1]`` range used by ``F.grid_sample``; for
            ``method='discrete'`` they are multiplied by ``(W_l, H_l)`` and
            rounded to integer pixel indices.
        attention_weights (Tensor): ``[bs, Len_q, n_head, sum(num_points_list)]``.
        num_points_list (List[int]): number of sampling points per level.
        method: ``'default'`` (bilinear ``grid_sample`` with zero padding) or
            ``'discrete'`` (nearest-pixel gather, clamped to the map).
        value_shape: unused; kept so existing callers keep working.

    Returns:
        Tensor of shape ``[bs, Len_q, n_head * c]``.

    Raises:
        ValueError: if ``method`` is neither ``'default'`` nor ``'discrete'``.
    """
    if method not in ('default', 'discrete'):
        # An unknown method used to fail later with an UnboundLocalError.
        raise ValueError(f"method should be 'default' or 'discrete', not {method!r}.")

    _, c, _, _ = value[0].shape
    bs, Len_q, n_head, _, _ = sampling_locations.shape
    total_points = sum(num_points_list)

    sampling_grids = 2 * sampling_locations - 1 if method == 'default' else sampling_locations
    # [bs, Len_q, n_head, P, 2] -> [bs * n_head, Len_q, P, 2]
    sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1)
    # one [bs * n_head, Len_q, n_points_level, 2] grid per level
    grids_per_level = sampling_grids.split(num_points_list, dim=-2)

    sampled_per_level = []
    for level, (h, w) in enumerate(value_spatial_shapes):
        value_l = value[level]
        grid_l = grids_per_level[level]
        n_points = num_points_list[level]

        if method == 'default':
            # grid_sample treats (Len_q, n_points) as the output height/width,
            # giving [bs * n_head, c, Len_q, n_points].
            sampled_l = F.grid_sample(
                value_l,
                grid_l,
                mode='bilinear',
                padding_mode='zeros',
                align_corners=False)
        else:
            coord = (grid_l * torch.tensor([[w, h]], device=value_l.device) + 0.5).to(torch.int64)
            # clamp x against the width and y against the height so that
            # rectangular maps are handled correctly
            coord[..., 0] = coord[..., 0].clamp(0, w - 1)
            coord[..., 1] = coord[..., 1].clamp(0, h - 1)
            coord = coord.reshape(bs * n_head, Len_q * n_points, 2)

            batch_idx = torch.arange(coord.shape[0], device=value_l.device).unsqueeze(-1).repeat(1, coord.shape[1])
            # gather value_l[n, :, y, x] -> [bs * n_head, Len_q * n_points, c]
            sampled_l = value_l[batch_idx, :, coord[..., 1], coord[..., 0]]
            sampled_l = sampled_l.permute(0, 2, 1).reshape(bs * n_head, c, Len_q, n_points)

        sampled_per_level.append(sampled_l)

    # [bs * n_head, c, Len_q, total_points]
    sampled = torch.cat(sampled_per_level, dim=-1)
    # [bs, Len_q, n_head, P] -> [bs * n_head, 1, Len_q, P], broadcast over channels
    weights = attention_weights.permute(0, 2, 1, 3).reshape(bs * n_head, 1, Len_q, total_points)

    # weighted sum over the sampling points -> [bs * n_head, c, Len_q]
    output = (sampled * weights).sum(dim=-1)

    # merge the heads back into the channel dim and put the query dim second
    return output.reshape(bs, n_head * c, Len_q).permute(0, 2, 1)

class MSDeformAttn(nn.Module):
    def __init__(
        self,
        embed_dim=256,
        num_heads=8,
        num_levels=4,
        num_points=4,
        method='default',
        offset_scale=0.5,
    ):
        """Multi-Scale Deformable Attention

        Args:
            embed_dim: query channel count; must be divisible by ``num_heads``.
            num_heads: number of attention heads.
            num_levels: number of feature levels.
            num_points: sampling points per level, either one int used for
                every level or a list with one entry per level.
            method: sampling method handed to the core function
                (``'default'`` or ``'discrete'``).
            offset_scale: stored multiplier for the sampling offsets.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_levels = num_levels
        self.offset_scale = offset_scale

        if not isinstance(num_points, list):
            points_per_level = [num_points] * num_levels
        else:
            assert len(num_points) == num_levels, ''
            points_per_level = num_points
        self.num_points_list = points_per_level

        # one 1/n entry per sampling point, where n is that level's point count
        per_point_scale = []
        for n in points_per_level:
            per_point_scale.extend([1/n] * n)
        self.register_buffer('num_points_scale', torch.tensor(per_point_scale, dtype=torch.float32))

        self.total_points = num_heads * sum(points_per_level)
        self.method = method

        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        # the query predicts an (x, y) offset and one weight per head and point
        self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2)
        self.attention_weights = nn.Linear(embed_dim, self.total_points)

        self.ms_deformable_attn_core = functools.partial(deformable_attention_core_func_v2, method=self.method)

        self._reset_parameters()

        if method == 'discrete':
            # offsets stay at their initial values for discrete sampling
            self.sampling_offsets.requires_grad_(False)

    def _reset_parameters(self):
        # sampling_offsets
        init.constant_(self.sampling_offsets.weight, 0)
        thetas = torch.arange(self.num_heads, dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values
        grid_init = grid_init.reshape(self.num_heads, 1, 2).tile([1, sum(self.num_points_list), 1])
        scaling = torch.concat([torch.arange(1, n + 1) for n in self.num_points_list]).reshape(1, -1, 1)
        grid_init *= scaling
        self.sampling_offsets.bias.data[...] = grid_init.flatten()

        # attention_weights
        init.constant_(self.attention_weights.weight, 0)
        init.constant_(self.attention_weights.bias, 0)


    def forward(self,
                query: torch.Tensor,
                reference_points: torch.Tensor,
                value: torch.Tensor,            # value is likely flattened: [bs, value_length, C]
                value_spatial_shapes: List[int],
                level_start_index: torch.Tensor): # <-- 新增 level_start_index
        """
        Args:
            query (Tensor): [bs, query_length, C]
            reference_points (Tensor): [bs, query_length, n_levels, 2 or 4]
            value (Tensor): [bs, value_length, C]
            value_spatial_shapes (List): [n_levels, 2]
            level_start_index (Tensor): [n_levels]  <-- 新增
        Returns:
            output (Tensor): [bs, query_length, C]
        """
        bs, Len_q = query.shape[:2]
        _, Len_v = value.shape[:2] # 获取 value length

        # --- 计算 sampling offsets 和 attention weights ---
        sampling_offsets: torch.Tensor = self.sampling_offsets(query)
        # sampling_offsets shape: [bs, Len_q, num_heads * sum(num_points_list) * 2]
        # Reshape to: [bs, Len_q, num_heads, sum(num_points_list), 2]
        sampling_offsets = sampling_offsets.reshape(bs, Len_q, self.num_heads, sum(self.num_points_list), 2)

        attention_weights = self.attention_weights(query).reshape(bs, Len_q, self.num_heads, sum(self.num_points_list))
        attention_weights = F.softmax(attention_weights, dim=-1)
        # attention_weights shape: [bs, Len_q, num_heads, sum(num_points_list)]

        # --- 计算 sampling locations ---
        if reference_points.shape[-1] == 2:
            # Reshape sampling_offsets to separate levels and points
            # Assumes num_points is constant per level (self.num_points_list = [N, N, ..., N])
            # If num_points can vary, this needs adjustment using num_points_list
            num_points_per_level = self.num_points_list[0] # Assuming constant
            sampling_offsets = sampling_offsets.reshape(
                bs, Len_q, self.num_heads, self.num_levels, num_points_per_level, 2)
            # New sampling_offsets shape: [bs, Len_q, num_heads, num_levels, num_points_per_level, 2]

            # Prepare offset_normalizer
            # Flip H, W to W, H for normalization
            offset_normalizer = torch.as_tensor(value_spatial_shapes, dtype=query.dtype, device=query.device).flip([1])
            # Reshape for broadcasting: [1, 1, 1, num_levels, 1, 2]
            offset_normalizer = offset_normalizer.reshape(1, 1, 1, self.num_levels, 1, 2)

            # Normalize offsets: [..., num_heads, num_levels, num_points, 2] / [..., 1, num_levels, 1, 2]
            normalized_offsets = sampling_offsets / offset_normalizer

            # Prepare reference_points
            # Reshape/unsqueeze for broadcasting: [bs, Len_q, 1, num_levels, 1, 2]
            reference_points_unsqueezed = reference_points.reshape(bs, Len_q, 1, self.num_levels, 1, 2)

            # Calculate sampling locations: [..., 1, num_levels, 1, 2] + [..., num_heads, num_levels, num_points, 2]
            sampling_locations = reference_points_unsqueezed + normalized_offsets

            # Reshape back for core function: [bs, Len_q, num_heads, num_levels * num_points, 2]
            sampling_locations = sampling_locations.reshape(bs, Len_q, self.num_heads, self.num_levels * num_points_per_level, 2)

        elif reference_points.shape[-1] == 4:
            # reference_points [bs, Len_q, n_levels, 4] (x_center, y_center, w, h)
            # sampling_offsets [bs, Len_q, num_heads, sum(num_points_list), 2]

            num_points_per_level = self.num_points_list[0] # Assuming constant
            sampling_offsets = sampling_offsets.reshape(
                bs, Len_q, self.num_heads, self.num_levels, num_points_per_level, 2)
            # Reshaped offsets: [bs, Len_q, num_heads, num_levels, num_points_per_level, 2]

            ref_xy = reference_points[..., :2].to(dtype=query.dtype, device=query.device) # [bs, Len_q, num_levels, 2]
            ref_wh = reference_points[..., 2:].to(dtype=query.dtype, device=query.device) # [bs, Len_q, num_levels, 2]

            # num_points_scale was originally [sum(num_points_list)]
            num_points_scale = self.num_points_scale.to(dtype=query.dtype, device=query.device)
            # Reshape to [1, 1, 1, num_levels, num_points_per_level, 1]
            num_points_scale = num_points_scale.reshape(1, 1, 1, self.num_levels, num_points_per_level, 1)

            # Calculate scaled offset: offset = sampling_offset * scale * ref_wh
            # Shapes: [..., head, lvl, pt, 2] * [..., 1, lvl, pt, 1] * [..., 1, lvl, 1, 2]
            offset = sampling_offsets * num_points_scale * self.offset_scale * ref_wh.reshape(bs, Len_q, 1, self.num_levels, 1, 2)
            # Result shape: [bs, Len_q, num_heads, num_levels, num_points_per_level, 2]

            # Calculate sampling locations: sampling_locations = ref_xy + offset
            # Shapes: [..., 1, lvl, 1, 2] + [..., head, lvl, pt, 2]
            sampling_locations = ref_xy.reshape(bs, Len_q, 1, self.num_levels, 1, 2) + offset
            # Result shape: [bs, Len_q, num_heads, num_levels, num_points_per_level, 2]

            # Reshape back for core function: [bs, Len_q, num_heads, num_levels * num_points, 2]
            sampling_locations = sampling_locations.reshape(bs, Len_q, self.num_heads, self.num_levels * num_points_per_level, 2)

        else:
            raise ValueError(
                "Last dim of reference_points must be 2 or 4, but get {} instead.".
                format(reference_points.shape[-1]))

        # --- Prepare value for core function ---
        # Input value shape: [bs, value_length, C]
        # Core function expects value LIST: [level0_tensor, level1_tensor, ...]
        # where level_tensor shape is [bs, H_lvl * W_lvl, num_heads, head_dim] ?? No, core expects [bs*n_head, c, h, w] after reshape
        # Let's prepare the list first. Value needs splitting based on level_start_index.
        # head_dim = C // num_heads
        value = value.reshape(bs, Len_v, self.num_heads, self.head_dim) # [bs, value_length, num_heads, head_dim]
        # Need to permute to match core function's expectation after splitting?
        # Core uses: value_l = value[level].reshape(bs * n_head, c, h, w)
        # Let's create the list of [bs, H*W, num_heads, head_dim] first
        value_list = []
        value_level_shapes = [(h, w) for h, w in value_spatial_shapes] # Ensure it's list of tuples
        for lvl in range(self.num_levels):
            start_idx = level_start_index[lvl]
            end_idx = level_start_index[lvl+1] if lvl < self.num_levels - 1 else Len_v
            value_lvl = value[:, start_idx:end_idx, :, :] # Shape [bs, H*W, num_heads, head_dim]
            # Reshape and permute for core function: [bs * num_heads, head_dim, H, W]
            h, w = value_level_shapes[lvl]
            value_lvl_permuted = value_lvl.permute(0, 2, 3, 1).reshape(bs * self.num_heads, self.head_dim, h, w)
            value_list.append(value_lvl_permuted)


        # --- Call the core function ---
        # Note: Pass value_list to the core function
        output = self.ms_deformable_attn_core(
            value_list, # Pass the prepared list
            value_spatial_shapes,
            sampling_locations,
            attention_weights,
            self.num_points_list,
            value_shape='reshape') # Use 'reshape' as we manually prepared value_list


        # output shape from core: [bs, n_head * c, Len_q]
        # Reshape to [bs, Len_q, n_head * c] which is [bs, Len_q, C]
        return output # Should already be [bs, Len_q, C] if core returns correctly

In [15]:
class DeformableTransformerEncoderLayer(nn.Module):
    """Single encoder layer: deformable self-attention, then a feed-forward net.

    Both sub-layers follow the same recipe: the sub-layer output goes through
    dropout, is added back onto its input, and the sum is layer-normalised.

    Args:
        d_model: feature width used by the attention, the FFN in/out and the norms.
        d_ffn: hidden width of the feed-forward network.
        dropout: probability shared by the three dropout modules.
        activation: name handed to ``_get_activation_fn``.
        n_levels: forwarded to ``MSDeformAttn`` as ``num_levels``.
        n_heads: forwarded to ``MSDeformAttn`` as ``num_heads``.
        n_points: forwarded to ``MSDeformAttn`` as ``num_points``.
    """

    def __init__(self,
                 d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4):
        super().__init__()

        # --- attention sub-layer -------------------------------------------
        self.self_attn = MSDeformAttn(
            embed_dim=d_model,
            num_heads=n_heads,
            num_levels=n_levels,
            num_points=n_points,
        )
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # --- feed-forward sub-layer ----------------------------------------
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_ffn(self, src):
        """Run the feed-forward sub-layer with its residual connection and norm."""
        hidden = self.activation(self.linear1(src))
        ffn_out = self.linear2(self.dropout2(hidden))
        return self.norm2(src + self.dropout3(ffn_out))

    def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
        """Apply the attention sub-layer followed by the feed-forward sub-layer.

        Args:
            src: input features; used as the attention value and residual input.
            pos: positional embedding added to ``src`` to form the query
                (may be None).
            reference_points: passed through to the attention module.
            spatial_shapes: passed to the attention module as
                ``value_spatial_shapes``.
            level_start_index: passed through to the attention module.
            padding_mask: accepted but not read by this method.
                NOTE(review): confirm whether the attention module should
                receive it.

        Returns:
            The output of ``forward_ffn`` applied to the normalised attention result.
        """
        attn_query = self.with_pos_embed(src, pos)
        attn_out = self.self_attn(
            query=attn_query,
            reference_points=reference_points,
            value=src,
            value_spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
        )
        src = self.norm1(src + self.dropout1(attn_out))
        return self.forward_ffn(src)
    

class DeformableTransformerEncoder(nn.Module):
    """Stack of deformable encoder layers applied one after another.

    Args:
        encoder_layer: prototype layer handed to ``_get_clones``.
        num_layers: number of clones requested from ``_get_clones``.
    """

    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Build cell-centre reference points for every level.

        Args:
            spatial_shapes: iterable of ``(H, W)`` pairs, one per level.
            valid_ratios: 3-D tensor indexed as ``[batch, level, (w, h)]``.
            device: device on which the coordinate grids are created.

        Returns:
            Tensor of shape ``[batch, sum(H * W), num_levels, 2]`` holding
            ``(x, y)`` per position, scaled by each level's valid ratio.
        """
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            # indexing="ij" is the row-major layout this code has always relied
            # on; naming it explicitly avoids torch's deprecation warning for
            # meshgrid calls that omit the argument.
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
                indexing="ij",
            )
            # Divide centre coordinates by the valid (unpadded) extent of this level.
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        # Broadcast each point against the valid ratios of every level.
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
        """Feed ``src`` through every layer in order and return the final output.

        The reference points are computed once from ``spatial_shapes`` and
        ``valid_ratios`` on ``src.device`` and shared by all layers.
        """
        output = src
        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
        for layer in self.layers:
            output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)

        return output

In [16]:
# Smoke test: clone one prototype layer into a six-layer encoder and show its structure.
d_model = 256
normalize_before = True  # not read anywhere in this cell

test_encoder_layer = DeformableTransformerEncoderLayer(
    d_model=d_model,
    d_ffn=1024,
    activation="relu",
    n_heads=8,
)
test_encoder = DeformableTransformerEncoder(
    test_encoder_layer,
    num_layers=6,
)

print(test_encoder)

DeformableTransformerEncoder(
  (layers): ModuleList(
    (0-5): 6 x DeformableTransformerEncoderLayer(
      (self_attn): MSDeformAttn(
        (sampling_offsets): Linear(in_features=256, out_features=256, bias=True)
        (attention_weights): Linear(in_features=256, out_features=128, bias=True)
      )
      (dropout1): Dropout(p=0.1, inplace=False)
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (linear1): Linear(in_features=256, out_features=1024, bias=True)
      (dropout2): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=1024, out_features=256, bias=True)
      (dropout3): Dropout(p=0.1, inplace=False)
      (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
  )
)


### test encoder

In [17]:
# Project every feature level to d_model channels before it enters the encoder.
num_feature_levels = 4
num_backbone_outs = len(backbone.strides)
input_proj_list = []


if num_feature_levels <= 1:
    # Single-level setup: only the first backbone output gets a projection.
    input_proj = nn.ModuleList([
        nn.Sequential(
            nn.Conv2d(backbone.num_channels[0], d_model, kernel_size=1),
            nn.GroupNorm(32, d_model),
        ),
    ])
else:
    # One 1x1 conv + GroupNorm per feature map returned by the backbone.
    for stage in range(num_backbone_outs):
        in_channels = backbone.num_channels[stage]
        backbone_proj = nn.Sequential(
            nn.Conv2d(in_channels, d_model, kernel_size=1),
            nn.GroupNorm(32, d_model),
        )
        input_proj_list.append(backbone_proj)
    # Any remaining levels come from stride-2 3x3 convs; the first one reads the
    # last backbone width, the following ones read d_model channels.
    for _ in range(num_feature_levels - num_backbone_outs):
        extra_proj = nn.Sequential(
            nn.Conv2d(in_channels, d_model, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(32, d_model),
        )
        input_proj_list.append(extra_proj)
        in_channels = d_model
    input_proj = nn.ModuleList(input_proj_list)

# Xavier-uniform weights and zero bias for the conv of every projection.
for proj in input_proj:
    conv = proj[0]
    nn.init.xavier_uniform_(conv.weight, gain=1)
    nn.init.constant_(conv.bias, 0)


features, pos = test_joiner(samples)
srcs = []
masks = []

# Levels that come straight from the backbone.
for level, feat in enumerate(features):
    src, mask = feat.decompose()
    srcs.append(input_proj[level](src))
    masks.append(mask)

# Levels beyond what the backbone returned.
_len_srcs = len(srcs)
if num_feature_levels > _len_srcs:
    m = samples.mask
    for level in range(_len_srcs, num_feature_levels):
        # The first extra level reads the raw last backbone feature; later ones
        # read the most recently produced projected level.
        extra_input = features[-1].tensors if level == _len_srcs else srcs[-1]
        src = input_proj[level](extra_input)
        # Resize the sample mask to the spatial size of the new level.
        mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
        pos_l = test_joiner[1](NestedTensor(src, mask)).to(src.dtype)
        srcs.append(src)
        masks.append(mask)
        pos.append(pos_l)

num_queries = 300
query_embed = nn.Embedding(num_queries, d_model)

query_embeds = query_embed.weight

# srcs, masks, pos, query_embeds

In [18]:
# prepare input for encoder
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
# torch.empty only allocates storage, so the embedding is filled explicitly;
# otherwise whatever happens to be in memory (possibly NaN/inf) would be added
# to every positional embedding below.
level_embed = nn.Parameter(torch.empty(num_feature_levels, d_model))
nn.init.normal_(level_embed)


for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos)):
    bs, c, h, w = src.shape
    spatial_shape = (h, w)
    spatial_shapes.append(spatial_shape)
    # [bs, c, h, w] -> [bs, h*w, c]
    src = src.flatten(2).transpose(1, 2)
    # [bs, h, w] -> [bs, h*w]
    mask = mask.flatten(1)
    pos_embed = pos_embed.flatten(2).transpose(1, 2)

    # Add this level's embedding row to the positional embedding of the level.
    lvl_pos_embed = pos_embed + level_embed[lvl].view(1, 1, -1)

    lvl_pos_embed_flatten.append(lvl_pos_embed)
    src_flatten.append(src)
    mask_flatten.append(mask)
    
def get_valid_ratio(mask):
    """Return the unmasked fraction of each sample as ``(width, height)`` ratios.

    Args:
        mask: 3-D mask ``[batch, H, W]``; entries are inverted with ``~`` so
            that truthy positions count as padding.

    Returns:
        Float tensor ``[batch, 2]`` with the width ratio first and the height
        ratio second. Height is counted along the first column and width
        along the first row.
    """
    _, height, width = mask.shape
    unpadded = ~mask
    rows_kept = unpadded[:, :, 0].sum(1)
    cols_kept = unpadded[:, 0, :].sum(1)
    ratio_h = rows_kept.float() / height
    ratio_w = cols_kept.float() / width
    return torch.stack([ratio_w, ratio_h], -1)
    
# Join the per-level lists along dim 1 so each becomes a single tensor.
src_flatten = torch.cat(src_flatten, dim=1)
mask_flatten = torch.cat(mask_flatten, dim=1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, dim=1)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)

# Start offset of each level inside the flattened sequence: a running sum of
# H*W shifted right by one level, with 0 for the first level.
tokens_per_level = spatial_shapes.prod(1)
level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), tokens_per_level.cumsum(0)[:-1]))
valid_ratios = torch.stack(list(map(get_valid_ratio, masks)), 1)

print(f'src_flatten shape : {src_flatten.shape} (H1 X W1 + H2 X W2 + H3 X W3 + H4 X W4)')
print(f'mask_flatten shape : {mask_flatten.shape}')
print(f'lvl_pos_embed_flatten shape : {lvl_pos_embed_flatten.shape}')
print(f'level_start_index : {level_start_index}')
print(f'valid_ratios shape: {valid_ratios.shape}')

src_flatten shape : torch.Size([2, 13020, 256]) (H1 X W1 + H2 X W2 + H3 X W3 + H4 X W4)
mask_flatten shape : torch.Size([2, 13020])
lvl_pos_embed_flatten shape : torch.Size([2, 13020, 256])
level_start_index : tensor([    0,  9792, 12240, 12864])
valid_ratios shape: torch.Size([2, 4, 2])


In [20]:
def get_reference_points(spatial_shapes, valid_ratios, device):
    """Build cell-centre reference points for every level.

    Args:
        spatial_shapes: iterable of ``(H, W)`` pairs, one per level.
        valid_ratios: 3-D tensor indexed as ``[batch, level, (w, h)]``.
        device: device on which the coordinate grids are created.

    Returns:
        Tensor of shape ``[batch, sum(H * W), num_levels, 2]`` holding
        ``(x, y)`` per position, scaled by each level's valid ratio.
    """
    reference_points_list = []
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        # indexing="ij" is the row-major layout this code has always relied on;
        # naming it explicitly avoids torch's deprecation warning for meshgrid
        # calls that omit the argument.
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
            torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
            indexing="ij",
        )
        # Divide centre coordinates by the valid (unpadded) extent of this level.
        ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
        ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
        ref = torch.stack((ref_x, ref_y), -1)
        reference_points_list.append(ref)
    reference_points = torch.cat(reference_points_list, 1)
    # Broadcast each point against the valid ratios of every level.
    reference_points = reference_points[:, :, None] * valid_ratios[:, None]
    return reference_points

device = "cuda" if torch.cuda.is_available() else "cpu"
# Build the reference grid on the device the encoder inputs already live on.
# Using `device` here breaks on CUDA machines: the grid would sit on the GPU
# while valid_ratios/src_flatten were never moved there, so the division inside
# get_reference_points raises a device-mismatch error.
reference_points = get_reference_points(spatial_shapes, valid_ratios, valid_ratios.device)

encoder_layer = DeformableTransformerEncoderLayer()
# Call the module rather than .forward() so registered hooks are honoured.
encoder_layer_output = encoder_layer(src_flatten, lvl_pos_embed_flatten, reference_points, spatial_shapes, level_start_index, padding_mask=None)
print(encoder_layer_output.shape)

torch.Size([2, 13020, 256])
