In [3]:
import torch
from torch import nn
import copy
from collections import OrderedDict

# Seed the global torch RNG so the random stand-in features below, and the
# randomly initialised modules built in later cells, are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Random stand-ins for three backbone feature maps laid out as
# [batch, channels, height, width]; each level halves the spatial size.
S3 = torch.rand(size=(2, 384, 80, 80))
S4 = torch.rand(size=(2, 768, 40, 40))
S5 = torch.rand(size=(2, 1536, 20, 20))

feats = [S3, S4, S5]

for feat in feats:
    print(feat.shape)

torch.Size([2, 384, 80, 80])
torch.Size([2, 768, 40, 40])
torch.Size([2, 1536, 20, 20])


In [4]:
hidden_dim = 256
input_proj = nn.ModuleList()

# One 1x1 conv + BatchNorm per level, mapping that level's channel count
# (dim 1 of the feature map) to the shared hidden_dim.
for input_feat in feats:
    in_channel = input_feat.shape[1]
    proj = nn.Sequential(OrderedDict(
        conv=nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False),
        norm=nn.BatchNorm2d(hidden_dim),
    ))
    input_proj.append(proj)

# Each level goes through its own projection.
proj_feats = [proj_layer(feat) for proj_layer, feat in zip(input_proj, feats)]

for projected in proj_feats:
    print(projected.shape)

torch.Size([2, 256, 80, 80])
torch.Size([2, 256, 40, 40])
torch.Size([2, 256, 20, 20])


In [13]:
def get_activation(act: str, inpace: bool=True):
    """Build an activation module from its name, or pass a module through.

    Args:
        act: Case-insensitive activation name: 'silu' (alias 'swish'), 'relu',
            'leaky_relu', 'gelu' or 'hardsigmoid'. ``None`` yields
            ``nn.Identity()``; an ``nn.Module`` instance is returned unchanged.
        inpace: Value assigned to the module's ``inplace`` attribute when the
            module has one (the misspelt name is kept so keyword callers still work).

    Returns:
        nn.Module: the activation module.

    Raises:
        RuntimeError: if ``act`` is a string naming no supported activation.
    """
    if act is None:
        return nn.Identity()
    if isinstance(act, nn.Module):
        return act

    # Name -> constructor; each name appears exactly once.
    factories = {
        'silu': nn.SiLU,
        'swish': nn.SiLU,
        'relu': nn.ReLU,
        'leaky_relu': nn.LeakyReLU,
        'gelu': nn.GELU,
        'hardsigmoid': nn.Hardsigmoid,
    }
    key = act.lower()
    if key not in factories:
        raise RuntimeError(
            f"Unsupported activation {act!r}; expected one of {sorted(factories)}")
    m = factories[key]()

    # Not every activation has an `inplace` flag (e.g. GELU does not).
    if hasattr(m, 'inplace'):
        m.inplace = inpace

    return m

def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.):
    """Fixed 2D sine-cosine position embedding for a flattened ``h x w`` grid.

    Token ``k`` is taken to come from flattening a ``[..., h, w]`` map in
    row-major order (as ``Tensor.flatten(2)`` does on ``[B, C, H, W]``), i.e.
    it sits at row ``k // w`` and column ``k % w``.

    Channel layout: ``[sin(row), cos(row), sin(col), cos(col)]``, each block
    ``embed_dim // 4`` wide, with frequencies
    ``1 / temperature ** (i / (embed_dim // 4))``.

    Args:
        w: Grid width (number of columns).
        h: Grid height (number of rows).
        embed_dim: Embedding size; must be divisible by 4.
        temperature: Frequency base.

    Returns:
        float32 tensor of shape ``[1, w * h, embed_dim]``.

    Raises:
        AssertionError: if ``embed_dim`` is not divisible by 4.
    """
    assert embed_dim % 4 == 0, \
        'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
    w, h = int(w), int(h)
    pos_dim = embed_dim // 4
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1. / (temperature ** omega)

    # Build the grid as (row, col) = (h, w) so that flattening it follows the
    # same row-major order as the feature tokens. A (w, h) grid lines up with
    # the tokens only when w == h; for square grids the result is identical.
    rows, cols = torch.meshgrid(
        torch.arange(h, dtype=torch.float32),
        torch.arange(w, dtype=torch.float32),
        indexing='ij')

    out_rows = rows.flatten()[..., None] @ omega[None]
    out_cols = cols.flatten()[..., None] @ omega[None]

    return torch.concat(
        [out_rows.sin(), out_rows.cos(), out_cols.sin(), out_cols.cos()], dim=1)[None, :, :]

class TransformerEncoderLayer(nn.Module):
    """One transformer encoder layer: self-attention, then a feed-forward network.

    With ``normalize_before=False`` (default) LayerNorm follows each residual
    add (post-norm); with ``True`` it precedes each sub-layer (pre-norm).
    The attention module is built with ``batch_first=True``, so inputs are
    ``[batch, tokens, d_model]``.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation="relu",
                 normalize_before=False):
        super().__init__()
        self.normalize_before = normalize_before

        # Attention sub-layer.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True)

        # Feed-forward sub-layer: d_model -> dim_feedforward -> d_model.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Norms and residual-branch dropouts, one pair per sub-layer.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = get_activation(activation)

    @staticmethod
    def with_pos_embed(tensor, pos_embed):
        """Return ``tensor + pos_embed``, or ``tensor`` itself when ``pos_embed`` is None."""
        if pos_embed is None:
            return tensor
        return tensor + pos_embed

    def _attention_block(self, x, src_mask, pos_embed):
        # Position embedding enters queries and keys only; values stay position-free.
        query_key = self.with_pos_embed(x, pos_embed)
        attended, _ = self.self_attn(query_key, query_key, value=x, attn_mask=src_mask)
        return self.dropout1(attended)

    def _feedforward_block(self, x):
        hidden = self.dropout(self.activation(self.linear1(x)))
        return self.dropout2(self.linear2(hidden))

    def forward(self, src, src_mask=None, pos_embed=None) -> torch.Tensor:
        pre_norm = self.normalize_before

        normed = self.norm1(src) if pre_norm else src
        src = src + self._attention_block(normed, src_mask, pos_embed)
        if not pre_norm:
            src = self.norm1(src)

        normed = self.norm2(src) if pre_norm else src
        src = src + self._feedforward_block(normed)
        if not pre_norm:
            src = self.norm2(src)
        return src
    
class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` independent deep copies of ``encoder_layer``.

    Every layer receives the same ``src_mask`` and ``pos_embed``; ``norm``,
    if given, is applied once to the output of the last layer.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        # deepcopy so each layer owns its own parameters instead of sharing one module.
        self.layers = nn.ModuleList(copy.deepcopy(encoder_layer) for _ in range(num_layers))
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, src_mask=None, pos_embed=None) -> torch.Tensor:
        hidden = src
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, src_mask=src_mask, pos_embed=pos_embed)
        return hidden if self.norm is None else self.norm(hidden)



encoder_layer = TransformerEncoderLayer(
    hidden_dim,
    nhead=8,
    dim_feedforward=1024,
    dropout=0.0,
    activation='silu'
)

# Indices into proj_feats of the levels that go through the transformer
# encoder (here only index 2, the 20x20 level). One encoder is built per
# entry, so this list must exist before the ModuleList below.
use_encoder_idx = [2]

encoder = nn.ModuleList([
    TransformerEncoder(copy.deepcopy(encoder_layer), 1) for _ in use_encoder_idx
])

for i, enc_ind in enumerate(use_encoder_idx):
    h, w = proj_feats[enc_ind].shape[2:]
    # [B, C, H, W] -> [B, H*W, C]: flatten the spatial grid into a token sequence.
    src_flatten = proj_feats[enc_ind].flatten(2).permute(0, 2, 1)

    pos_embed = build_2d_sincos_position_embedding(w, h, hidden_dim).to(src_flatten.device)

    memory = encoder[i](src_flatten, pos_embed=pos_embed)
    # [B, H*W, C] -> [B, C, H, W]. This overwrites proj_feats in place, so
    # re-running the cell encodes an already-encoded level again.
    proj_feats[enc_ind] = memory.permute(0, 2, 1).reshape(-1, hidden_dim, h, w).contiguous()

for feat in proj_feats:
    print(feat.shape)

torch.Size([2, 256, 80, 80])
torch.Size([2, 256, 40, 40])
torch.Size([2, 256, 20, 20])


## Broadcasting and fusion: top-down (FPN) and bottom-up (PAN) paths

In [None]:
class ConvNormLayer(nn.Module):
    """Conv2d -> BatchNorm2d -> optional activation.

    ``padding`` defaults to ``(kernel_size - 1) // 2``; ``g`` is the conv's
    ``groups``; ``act=None`` means no activation (``nn.Identity``).
    """

    def __init__(self, ch_in, ch_out, kernel_size, stride, g=1, padding=None, bias=False, act=None):
        super().__init__()
        if padding is None:
            padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(ch_in, ch_out, kernel_size, stride,
                              padding=padding, groups=g, bias=bias)
        self.norm = nn.BatchNorm2d(ch_out)
        if act is None:
            self.act = nn.Identity()
        else:
            self.act = get_activation(act)

    def forward(self, x):
        y = self.conv(x)
        y = self.norm(y)
        return self.act(y)

class ConvNormLayer_fuse(nn.Module):
    """Conv2d -> BatchNorm2d -> optional activation, foldable into one conv.

    ``convert_to_deploy()`` replaces ``conv`` + ``norm`` with a single biased
    ``nn.Conv2d`` (``conv_bn_fused``) built from the BatchNorm running
    statistics, so the fused layer reproduces the eval-mode output.
    """

    def __init__(self, ch_in, ch_out, kernel_size, stride, g=1, padding=None, bias=False, act=None):
        super().__init__()
        padding = (kernel_size-1)//2 if padding is None else padding
        self.conv = nn.Conv2d(
            ch_in,
            ch_out,
            kernel_size,
            stride,
            groups=g,
            padding=padding,
            bias=bias)
        self.norm = nn.BatchNorm2d(ch_out)
        self.act = nn.Identity() if act is None else get_activation(act)
        # Conv hyper-parameters, kept so convert_to_deploy can build an equivalent conv.
        self.ch_in, self.ch_out, self.kernel_size, self.stride, self.g, self.padding, self.bias = \
            ch_in, ch_out, kernel_size, stride, g, padding, bias

    def forward(self, x):
        if hasattr(self, 'conv_bn_fused'):
            y = self.conv_bn_fused(x)
        else:
            y = self.norm(self.conv(x))
        return self.act(y)

    def convert_to_deploy(self):
        """Fold BatchNorm into the conv (and delete both); safe to call repeatedly."""
        if not (hasattr(self, 'conv') and hasattr(self, 'norm')):
            # Already converted: a previous call deleted conv/norm.
            return

        kernel, bias = self.get_equivalent_kernel_bias()
        if not hasattr(self, 'conv_bn_fused'):
            self.conv_bn_fused = nn.Conv2d(
                self.ch_in,
                self.ch_out,
                self.kernel_size,
                self.stride,
                groups=self.g,
                padding=self.padding,
                bias=True)
        # detach: the fused parameters are fresh leaf data, not part of the old graph.
        self.conv_bn_fused.weight.data = kernel.detach()
        self.conv_bn_fused.bias.data = bias.detach()
        del self.conv
        del self.norm

    def get_equivalent_kernel_bias(self):
        """Return ``(kernel, bias)`` of one conv equivalent to conv + eval-mode BatchNorm."""
        kernel, bias = self._fuse_bn_tensor()
        return kernel, bias

    def _fuse_bn_tensor(self):
        # BN(conv(x)) = gamma * (W*x + b - mean) / std + beta
        #             = (W * gamma/std) * x + (beta - mean * gamma/std + b * gamma/std)
        kernel = self.conv.weight
        running_mean = self.norm.running_mean
        running_var = self.norm.running_var
        gamma = self.norm.weight
        beta = self.norm.bias
        eps = self.norm.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        fused_bias = beta - running_mean * gamma / std
        # With bias=True the conv's own bias passes through BN and must be scaled too.
        if self.conv.bias is not None:
            fused_bias = fused_bias + self.conv.bias * gamma / std
        return kernel * t, fused_bias

In [28]:
import torch.nn.functional as F

class VGGBlock(nn.Module):
    """Re-parameterisable block: parallel 3x3 and 1x1 conv+BN branches, summed, then activated.

    ``convert_to_deploy()`` merges both branches into a single 3x3 conv
    (``self.conv``) using the BatchNorm running statistics.
    """

    def __init__(self, ch_in, ch_out, act='relu'):
        super().__init__()
        self.ch_in = ch_in
        self.ch_out = ch_out
        self.conv1 = ConvNormLayer(ch_in, ch_out, 3, 1, padding=1, act=None)
        self.conv2 = ConvNormLayer(ch_in, ch_out, 1, 1, padding=0, act=None)
        self.act = nn.Identity() if act is None else get_activation(act)

    def forward(self, x):
        if hasattr(self, 'conv'):
            y = self.conv(x)
        else:
            y = self.conv1(x) + self.conv2(x)

        return self.act(y)

    def convert_to_deploy(self):
        """Fuse both branches into ``self.conv`` and delete them; safe to call repeatedly."""
        if not (hasattr(self, 'conv1') and hasattr(self, 'conv2')):
            # Already converted: a previous call deleted the branches.
            return

        kernel, bias = self.get_equivalent_kernel_bias()
        if not hasattr(self, 'conv'):
            self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1)
        # detach: the fused parameters are fresh leaf data, not part of the old graph.
        self.conv.weight.data = kernel.detach()
        self.conv.bias.data = bias.detach()
        del self.conv1
        del self.conv2

    def get_equivalent_kernel_bias(self):
        """Return ``(kernel, bias)`` of one 3x3 conv equal to the sum of both branches."""
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)

        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        # Zero-pad a 1x1 kernel to 3x3 so it lines up with the padding=1 3x3 branch.
        if kernel1x1 is None:
            return 0
        else:
            return F.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch: 'ConvNormLayer'):
        # Fold a branch's conv + eval-mode BatchNorm into (kernel, bias);
        # a missing branch contributes nothing.
        if branch is None:
            return 0, 0
        kernel = branch.conv.weight
        running_mean = branch.norm.running_mean
        running_var = branch.norm.running_var
        gamma = branch.norm.weight
        beta = branch.norm.bias
        eps = branch.norm.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        fused_bias = beta - running_mean * gamma / std
        # A biased branch conv's bias passes through BN and is scaled as well.
        if branch.conv.bias is not None:
            fused_bias = fused_bias + branch.conv.bias * gamma / std
        return kernel * t, fused_bias

class CSPLayer(nn.Module):
    """CSP block: two 1x1 projections of the input, one refined by
    ``num_blocks`` stacked ``bottletype`` blocks, summed, then projected to
    ``out_channels`` when the hidden width differs.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_blocks=3,
                 expansion=1.0,
                 bias=False,
                 act="silu",
                 bottletype=VGGBlock):
        super(CSPLayer, self).__init__()
        hidden_channels = int(out_channels * expansion)

        def pointwise(c_in, c_out):
            return ConvNormLayer_fuse(c_in, c_out, 1, 1, bias=bias, act=act)

        self.conv1 = pointwise(in_channels, hidden_channels)
        self.conv2 = pointwise(in_channels, hidden_channels)
        self.bottlenecks = nn.Sequential(
            *(bottletype(hidden_channels, hidden_channels, act=act) for _ in range(num_blocks)))
        # Only project back when the hidden width differs from the requested output width.
        self.conv3 = (pointwise(hidden_channels, out_channels)
                      if hidden_channels != out_channels else nn.Identity())

    def forward(self, x):
        shortcut = self.conv2(x)
        main = self.bottlenecks(self.conv1(x))
        return self.conv3(main + shortcut)

class RepNCSPELAN4(nn.Module):
    """CSP-ELAN block.

    ``cv1`` projects the input to ``c3`` channels, which are split into two
    halves. The second half goes through ``cv2`` and that result through
    ``cv3``. All four tensors are concatenated (``c3 + 2 * c4`` channels) and
    fused by ``cv4`` into ``c2`` channels. ``c3`` is assumed even (each half
    is ``c3 // 2`` wide).
    """

    def __init__(self, c1, c2, c3, c4, n=3,
                 bias=False,
                 act="silu"):
        super().__init__()
        half = c3 // 2
        self.c = half
        self.cv1 = ConvNormLayer_fuse(c1, c3, 1, 1, bias=bias, act=act)
        self.cv2 = nn.Sequential(
            CSPLayer(half, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock),
            ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act),
        )
        self.cv3 = nn.Sequential(
            CSPLayer(c4, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock),
            ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act),
        )
        self.cv4 = ConvNormLayer_fuse(c3 + 2 * c4, c2, 1, 1, bias=bias, act=act)

    def _aggregate(self, parts):
        # Each stage consumes the newest tensor, so cv3 runs on cv2's output.
        for stage in (self.cv2, self.cv3):
            parts.append(stage(parts[-1]))
        return self.cv4(torch.cat(parts, 1))

    def forward_chunk(self, x):
        return self._aggregate(list(self.cv1(x).chunk(2, 1)))

    def forward(self, x):
        return self._aggregate(list(self.cv1(x).split((self.c, self.c), 1)))

In [32]:
import torch.nn.functional as F

lateral_convs = nn.ModuleList()
fpn_blocks = nn.ModuleList()

# One lateral 1x1 conv and one RepNCSPELAN4 fusion block per top-down step
# (len(feats) - 1 steps).
elan_hidden = hidden_dim // 2   # c4 argument of RepNCSPELAN4
num_csp_blocks = 3 * 2          # n argument of RepNCSPELAN4
for _ in range(len(feats) - 1, 0, -1):
    lateral_convs.append(ConvNormLayer_fuse(hidden_dim, hidden_dim, 1, 1))
    fpn_blocks.append(
        RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2,
                     elan_hidden,
                     num_csp_blocks, act='silu')
    )

# First top-down step, run by hand: start from the coarsest level.
idx = len(feats) - 1           # level being upsampled (2 for three levels)
step = len(feats) - 1 - idx    # module index used for this step (0)

inner_outs = [proj_feats[-1]]
print(inner_outs[0].shape)

feat_high = inner_outs[0]
feat_low = proj_feats[idx - 1]

feat_high = lateral_convs[step](feat_high)
inner_outs[0] = feat_high
print(inner_outs[0].shape)

upsample_feat = F.interpolate(feat_high, scale_factor=2., mode='nearest')
print(upsample_feat.shape)

inner_out = fpn_blocks[step](torch.concat([upsample_feat, feat_low], dim=1))
print(inner_out.shape)
inner_outs.insert(0, inner_out)


torch.Size([2, 256, 20, 20])
torch.Size([2, 256, 20, 20])
torch.Size([2, 256, 40, 40])
torch.Size([2, 256, 40, 40])


In [33]:
class SCDown(nn.Module):
    """Spatial-channel decoupled downsampling.

    A pointwise (1x1) conv maps ``c1 -> c2`` channels, then a depthwise
    (``groups=c2``) ``k x k`` conv with stride ``s`` reduces the spatial size.

    Args:
        c1: Input channels.
        c2: Output channels.
        k: Depthwise kernel size.
        s: Depthwise stride.
        act: Activation applied after the pointwise conv only; ``None``
            (default) means no activation anywhere.
    """

    def __init__(self, c1, c2, k, s, act=None):
        super().__init__()
        # `act` goes on the pointwise conv only, as in YOLOv10's SCDown;
        # the depthwise conv stays linear.
        self.cv1 = ConvNormLayer_fuse(c1, c2, 1, 1, act=act)
        self.cv2 = ConvNormLayer_fuse(c2, c2, k, s, c2)

    def forward(self, x):
        return self.cv2(self.cv1(x))

downsample_convs = nn.ModuleList()
pan_blocks = nn.ModuleList()

# One downsampling module and one RepNCSPELAN4 fusion block per bottom-up step.
for _ in range(len(feats) - 1):
    downsample_convs.append(nn.Sequential(SCDown(hidden_dim, hidden_dim, 3, 2, act='silu')))
    pan_blocks.append(RepNCSPELAN4(
        hidden_dim * 2,
        hidden_dim,
        hidden_dim * 2,
        hidden_dim // 2,
        3 * 2,
        act='silu',
    ))

In [36]:
# First bottom-up step: downsample the finest top-down output and fuse it
# with the next coarser one.
outs = [inner_outs[0]]
feat_low = outs[-1]
feat_high = inner_outs[1]
print(f'feat_high: {feat_high.shape}')

downsample_feat = downsample_convs[0](feat_low)
print(f'downsample_feat: feat_low from : {feat_low.shape} to : {downsample_feat.shape}')

fused = torch.concat([downsample_feat, feat_high], dim=1)
out = pan_blocks[0](fused)
print(out.shape)

feat_high: torch.Size([2, 256, 20, 20])
downsample_feat: feat_low from : torch.Size([2, 256, 40, 40]) to : torch.Size([2, 256, 20, 20])
torch.Size([2, 256, 20, 20])
