# Vision Transformer (ViT)

### 1、patch_embed

In [1]:
"""
Author: xiao qiang
Time: 2023/2/26 14:50 
Version: env==torch py==3.9
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmengine.model import BaseModule
from mmengine.utils import to_2tuple
from mmcv.cnn import build_conv_layer, build_activation_layer, build_norm_layer


class AdaptivePadding(nn.Module):
    """Zero-pad the input so the (strided, dilated) kernel covers it completely.

    Two modes are supported:

    - ``'same'``: like TensorFlow's ``SAME`` padding. The padding is split
      between both sides; when it is odd, the extra pixel goes bottom/right.
    - ``'corner'``: all padding is appended to the bottom and right.

    Args:
        kernel_size (int | tuple): Size of the kernel. Default: 1.
        stride (int | tuple): Stride of the kernel. Default: 1.
        dilation (int | tuple): Spacing between kernel elements. Default: 1.
        padding (str): ``'same'`` or ``'corner'``. Default: ``'corner'``.
    """

    def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
        super(AdaptivePadding, self).__init__()
        assert padding in ('same', 'corner')
        self.kernel_size, self.stride, self.dilation = map(to_2tuple, (kernel_size, stride, dilation))
        self.padding = padding

    def get_pad_shape(self, input_shape):
        """Compute the zero padding needed along each spatial axis.

        Args:
            input_shape: Spatial size of the input, arranged as (H, W).

        Returns:
            Tuple[int]: ``(pad_h, pad_w)``, each non-negative.
        """
        in_h, in_w = input_shape
        k_h, k_w = self.kernel_size
        s_h, s_w = self.stride

        def _needed(size, kernel, step, dil):
            # Extent touched by ceil(size / step) kernel placements, minus what exists.
            span = (math.ceil(size / step) - 1) * step + (kernel - 1) * dil + 1
            return max(span - size, 0)

        return (_needed(in_h, k_h, s_h, self.dilation[0]),
                _needed(in_w, k_w, s_w, self.dilation[1]))

    def forward(self, x):
        pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
        if pad_h <= 0 and pad_w <= 0:
            # Already fully covered: hand back the very same tensor.
            return x
        if self.padding == 'corner':
            # F.pad order for the last two dims is [left, right, top, bottom].
            return F.pad(x, [0, pad_w, 0, pad_h])
        if self.padding == 'same':
            top, left = pad_h // 2, pad_w // 2
            return F.pad(x, [left, pad_w - left, top, pad_h - top])
        return x


class PatchEmbed(BaseModule):
    """Image to patch embedding.

    A single convolution projects every patch of the input image to an
    ``embed_dims``-dimensional token; the tokens are returned flattened as
    ``[N, num_patches, embed_dims]`` together with the patch-grid size.

    Args:
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (int): Dimension of the embedding. Default: 768.
        conv_type (str): Type of the convolution used to build the patch
            embedding (passed to ``build_conv_layer``). Default: 'Conv2d'.
        kernel_size (int): Kernel size of the embedding conv. Default: 16.
        stride (int | None): Stride of the embedding conv. If None, it is set
            to ``kernel_size``. Default: 16.
        padding (int | tuple | str): Padding of the embedding conv. A string
            selects the adaptive padding mode ('same' or 'corner') applied
            before the conv; the conv itself then uses no padding.
            Default: 'corner'.
        dilation (int): Dilation of the embedding conv. Default: 1.
        bias (bool): Whether the embedding conv has a bias. Default: True.
        norm_cfg (dict, optional): Config of the normalization layer applied
            to the tokens. Default: None.
        input_size (int | tuple | None): Expected input size. When given, the
            output patch-grid size is precomputed into ``init_out_size``.
            Default: None.
        init_cfg (dict, optional): Initialization config. Default: None.
    """
    def __init__(self, in_channels=3, embed_dims=768, conv_type='Conv2d', kernel_size=16,
                 stride=16, padding='corner', dilation=1, bias=True, norm_cfg=None, input_size=None, init_cfg=None):
        super(PatchEmbed, self).__init__(init_cfg=init_cfg)
        self.embed_dims = embed_dims
        if stride is None:
            stride = kernel_size
        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)
        if isinstance(padding, str):
            self.adaptive_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # The input is padded adaptively, so the conv must not pad again.
            padding = 0
        else:
            self.adaptive_padding = None
        # Normalize to (pad_h, pad_w) in both branches: the output-size
        # computation below indexes into it.
        padding = to_2tuple(padding)
        self.projection = build_conv_layer(dict(type=conv_type),
                                           in_channels=in_channels,
                                           out_channels=embed_dims,
                                           kernel_size=kernel_size,
                                           stride=stride,
                                           padding=padding,
                                           dilation=dilation,
                                           bias=bias)
        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None
        if input_size:
            input_size = to_2tuple(input_size)
            self.init_input_size = input_size
            if self.adaptive_padding is not None:
                pad_h, pad_w = self.adaptive_padding.get_pad_shape(input_size)
                input_size = (input_size[0] + pad_h, input_size[1] + pad_w)
            # Conv output size: out = (in + 2*padding - kernel) // stride + 1.
            # With dilation the effective kernel extent is
            # dilation * (kernel_size - 1) + 1, which gives the formula below.
            h_out = (input_size[0] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1
            w_out = (input_size[1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1
            self.init_out_size = (h_out, w_out)
        else:
            self.init_input_size = None
            self.init_out_size = None

    def forward(self, x):
        """Embed an image batch into patch tokens.

        Returns:
            tuple: ``(tokens, (h, w))`` where tokens is ``[n, h*w, embed_dims]``
            and ``(h, w)`` is the patch-grid size after the projection conv.
        """
        if self.adaptive_padding is not None:
            x = self.adaptive_padding(x)
        x = self.projection(x)
        out_size = (x.shape[2], x.shape[3])
        # [n, embed_dims, h, w] -> [n, embed_dims, h*w] -> [n, h*w, embed_dims];
        # h*w is the number of tokens.
        x = x.flatten(2).transpose(1, 2)
        if self.norm is not None:
            x = self.norm(x)
        return x, out_size

### 2、attention

In [2]:
"""
Author: xiao qiang
Time: 2023/2/22 08:31 
Version: env==torch py==3.9
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmengine.model import BaseModule
from mmcv.cnn.bricks.drop import build_dropout
from mmcls.models.utils.layer_scale import LayerScale


class MultiheadAttention(BaseModule):
    """Multi-head self-attention module.

    Supports an input dimension that differs from the embedding dimension, and
    an optional shortcut from the value tensor to the output, which is useful
    when the input dims differ from the embed dims.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Number of parallel attention heads.
        input_dims (int, optional): The input dimension. If None, ``embed_dims``
            is used. Defaults to None.
        attn_drop (float): Dropout rate applied to the attention weights
            (softmax of the query-key products). Defaults to 0.
        proj_drop (float): Dropout rate after the output projection.
            Defaults to 0.
        dropout_layer (dict): Config of the dropout applied to the output
            before adding the value shortcut.
            Defaults to dict(type='Dropout', drop_prob=0.).
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        qk_scale (float, optional): Overrides the default qk scale of
            ``head_dims ** -0.5`` when set. Defaults to None.
        proj_bias (bool): If True, add a learnable bias to the output
            projection. Defaults to True.
        v_shortcut (bool): Add a shortcut from value to output. ``v.squeeze(1)``
            only drops the head axis when ``num_heads == 1``, so this option
            needs a single head. Defaults to False.
        use_layer_scale (bool): If True, scale the output with ``LayerScale``;
            otherwise use identity. Defaults to False.
        init_cfg (dict, optional): The config for initialization.
            Defaults to None.
    """
    def __init__(self, embed_dims, num_heads, input_dims=None, attn_drop=0., proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.), qkv_bias=True, qk_scale=None, proj_bias=True,
                 v_shortcut=False, use_layer_scale=False, init_cfg=None):
        super(MultiheadAttention, self).__init__(init_cfg=init_cfg)
        self.input_dims = input_dims or embed_dims
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.v_shortcut = v_shortcut
        self.head_dims = embed_dims // num_heads
        self.scale = qk_scale or self.head_dims**-0.5
        # Use the resolved input dims so that input_dims=None falls back to embed_dims.
        self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.out_drop = build_dropout(dropout_layer)
        if use_layer_scale:
            self.gamma1 = LayerScale(embed_dims)
        else:
            self.gamma1 = nn.Identity()

    def forward(self, x):
        # x: [B, N, input_dims]
        B, N, _ = x.shape
        # qkv(x): [B, N, 3*embed_dims]
        # -> [B, N, 3, num_heads, head_dims] -> [3, B, num_heads, N, head_dims]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dims).permute(2, 0, 3, 1, 4)
        # q, k, v: [B, num_heads, N, head_dims]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # [B, num_heads, N, head_dims] @ [B, num_heads, head_dims, N] -> [B, num_heads, N, N]
        # Scaled dot-product attention per head, then softmax and dropout.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # [B, num_heads, N, N] @ [B, num_heads, N, head_dims] -> [B, num_heads, N, head_dims]
        # -> [B, N, num_heads, head_dims] -> [B, N, embed_dims] (heads concatenated)
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.embed_dims)
        # Linear projection of the concatenated heads.
        x = self.proj(x)
        # Projection dropout, optional layer scale, then the configured output dropout.
        x = self.out_drop(self.gamma1(self.proj_drop(x)))
        if self.v_shortcut:
            # Only shape-compatible when num_heads == 1 (squeeze drops the head axis).
            x = v.squeeze(1) + x
        return x


def resize_pos_embed(pos_embed, src_shape, dst_shape, mode='bicubic', num_extra_tokens=1):
    """Resize position embedding weights to a new patch grid.

    Args:
        pos_embed (torch.Tensor): Position embedding weights with shape
            [1, L, C], where ``L == src_h * src_w + num_extra_tokens``.
        src_shape (tuple): Patch grid (H, W) that ``pos_embed`` was built for.
        dst_shape (tuple): Target patch grid (H, W).
        mode (str): Interpolation algorithm passed to ``F.interpolate``. The
            spatial weights are interpolated as a 4-D tensor, so use a 2-D
            mode such as 'nearest', 'bilinear' or 'bicubic'.
            Defaults to 'bicubic'.
        num_extra_tokens (int): Number of leading non-spatial tokens (such as
            the cls token), which are copied unchanged. Defaults to 1.

    Returns:
        torch.Tensor: Weights with shape [1, dst_h * dst_w + num_extra_tokens, C].
        When ``src_shape`` equals ``dst_shape``, ``pos_embed`` itself is returned.
    """
    if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
        return pos_embed
    assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
    _, L, C = pos_embed.shape
    src_h, src_w = src_shape
    assert L == src_h * src_w + num_extra_tokens, (
        f'the length of "pos_embed" ({L}) should equal src_h*src_w+extra_tokens '
        f'({src_h}*{src_w}+{num_extra_tokens})')
    extra_tokens = pos_embed[:, :num_extra_tokens]
    src_weight = pos_embed[:, num_extra_tokens:]
    # [1, src_h*src_w, C] -> [1, src_h, src_w, C] -> [1, C, src_h, src_w]
    src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)
    # F.interpolate raises if align_corners is set for non-interpolating
    # modes such as 'nearest', so only pass it where it is meaningful.
    align_corners = False if mode in ('linear', 'bilinear', 'bicubic', 'trilinear') else None
    dst_weight = F.interpolate(src_weight, size=dst_shape, align_corners=align_corners, mode=mode)
    # [1, C, dst_h, dst_w] -> [1, C, dst_h*dst_w] -> [1, dst_h*dst_w, C]
    dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
    return torch.cat((extra_tokens, dst_weight), dim=1)

### 3、FFN

In [3]:
"""
Author: xiao qiang
Time: 2023/2/25 22:31 
Version: env==torch py==3.9
"""
import torch
import torch.nn as nn

from mmengine.model import BaseModule
from mmcv.cnn import build_conv_layer, build_activation_layer, build_norm_layer, Linear
from mmcv.cnn.bricks.drop import build_dropout


class FFN(BaseModule):
    """Feed-forward network with an optional identity (residual) connection.

    The network stacks ``num_fcs`` linear layers. Every layer except the last
    is followed by the activation and a dropout. The last layer projects back
    to ``embed_dims`` and is followed by a dropout.

    Args:
        embed_dims (int): The feature dimension, same as in
            MultiheadAttention. Defaults to 256.
        feedforward_channels (int): Hidden dimension of the FFN.
            Defaults to 1024.
        num_fcs (int, optional): Number of fully-connected layers; must be at
            least 2. Defaults to 2.
        act_cfg (dict, optional): Activation config. Defaults to
            dict(type='ReLU').
        ffn_drop (float, optional): Probability of an element being zeroed
            inside the FFN. Defaults to 0.0.
        dropout_layer (obj:ConfigDict): Config of the dropout applied to the
            FFN output before the identity is added. If falsy, identity is
            used. Defaults to None.
        add_identity (bool, optional): Whether to add the identity
            connection. Defaults to True.
        init_cfg (obj:ConfigDict): Initialization config. Defaults to None.
    """
    def __init__(self, embed_dims=256, feedforward_channels=1024, num_fcs=2, act_cfg=dict(type='ReLU'),
                 ffn_drop=0., dropout_layer=None, add_identity=True, init_cfg=None, **kwargs):
        super(FFN, self).__init__(init_cfg=init_cfg)
        assert num_fcs >= 2, f'num_fcs should be no less than 2, but got {num_fcs}'
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)

        # The first hidden block maps embed_dims -> hidden; any further hidden
        # blocks keep the hidden width.
        blocks = [
            nn.Sequential(
                Linear(embed_dims if idx == 0 else feedforward_channels, feedforward_channels),
                self.activate,
                nn.Dropout(ffn_drop))
            for idx in range(num_fcs - 1)
        ]
        # Output projection back to embed_dims.
        blocks += [Linear(feedforward_channels, embed_dims), nn.Dropout(ffn_drop)]
        self.layers = nn.Sequential(*blocks)
        self.dropout_layer = torch.nn.Identity() if not dropout_layer else build_dropout(dropout_layer)
        self.add_identity = add_identity

    def forward(self, x, identity=None):
        """Apply the FFN; add ``identity`` (or ``x`` when None) if enabled."""
        out = self.dropout_layer(self.layers(x))
        if self.add_identity:
            return (x if identity is None else identity) + out
        return out


### 4、vision transformer

In [5]:
"""
Author: xiao qiang
Time: 2023/2/21 22:37 
Version: env==torch py==3.9
"""
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn

from mmengine.model import BaseModule, ModuleList
from mmcv.cnn import build_norm_layer
from mmcls.models.utils import MultiheadAttention, to_2tuple, resize_pos_embed
from mmcv.cnn.bricks.transformer import FFN
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmengine.model.weight_init import trunc_normal_


class TransformerEncoderLayer(BaseModule):
    """One pre-norm encoder layer of the Vision Transformer.

    Computes ``x = x + attn(norm1(x))`` followed by ``x = x + ffn(norm2(x))``.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Number of parallel attention heads.
        feedforward_channels (int): The hidden dimension of the FFN.
        drop_rate (float): Probability of an element being zeroed after the
            attention projection and inside the FFN. Defaults to 0.
        attn_drop_rate (float): Dropout rate for the attention weights.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): Number of fully-connected layers in the FFN.
            Defaults to 2.
        qkv_bias (bool): Enable bias for qkv if True. Defaults to True.
        act_cfg (dict): Activation config of the FFN.
            Defaults to dict(type='GELU').
        norm_cfg (dict): Config of the normalization layers.
            Defaults to dict(type='LN').
        init_cfg (dict): Initialization config. Defaults to None.
    """
    def __init__(self, embed_dims, num_heads, feedforward_channels, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., num_fcs=2, qkv_bias=True, act_cfg=dict(type='GELU'), norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(TransformerEncoderLayer, self).__init__(init_cfg=init_cfg)
        self.embed_dims = embed_dims
        self.norm1_name, norm1 = build_norm_layer(norm_cfg, self.embed_dims, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.attn = MultiheadAttention(embed_dims=embed_dims, num_heads=num_heads, attn_drop=attn_drop_rate,
                                       proj_drop=drop_rate, dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
                                       qkv_bias=qkv_bias)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, self.embed_dims, postfix=2)
        self.add_module(self.norm2_name, norm2)
        # act_cfg configures the FFN activation; it is not a DropPath argument.
        self.ffn = FFN(embed_dims=embed_dims, feedforward_channels=feedforward_channels, num_fcs=num_fcs,
                       ffn_drop=drop_rate, dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
                       act_cfg=act_cfg)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    def init_weights(self):
        """Run the base initialization, then re-initialize the FFN linear layers.

        Weights use Xavier-uniform and biases a near-zero normal (std=1e-6).
        """
        super(TransformerEncoderLayer, self).init_weights()
        for m in self.ffn.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        # The FFN adds `identity` itself, giving the second residual connection.
        x = self.ffn(self.norm2(x), identity=x)
        return x


class VisionTransformer(BaseBackbone):
    """Vision Transformer.

    A PyTorch implementation of `An Image is Worth 16x16 Words: Transformers
    for Image Recognition at Scale`.

    Args:
        arch (str | dict): Vision Transformer architecture.
            If a string, choose from 's'/'small', 'b'/'base', 'l'/'large',
            'h'/'huge', 'deit-t'/'deit-tiny', 'deit-s'/'deit-small' and
            'deit-b'/'deit-base'. If a dict, it must have the keys below:

            - **embed_dims** (int): The dimension of embedding.
            - **num_layers** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.
            - **feedforward_channels** (int): The hidden dimension in
              feedforward modules.

            Defaults to 'base'.
        img_size (int | tuple): The expected input image shape. Dynamic input
            shapes are supported, so set this to the most common input shape.
            Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 16.
        in_channels (int): The number of input channels. Defaults to 3.
        out_indices (Sequence | int): Indices of the layers whose outputs are
            returned; negative values count from the end. Defaults to -1,
            meaning the last layer.
        drop_rate (float): Probability of an element being zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate, increased linearly
            from 0 across layers. Defaults to 0.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        norm_cfg (dict): Config of the normalization layers.
            Defaults to dict(type='LN').
        final_norm (bool): Whether to normalize the output of the last layer.
            Defaults to True.
        with_cls_token (bool): Whether to concatenate the class token to the
            image tokens as transformer input. Defaults to True.
        avg_token (bool): Whether to output the normalized mean of the patch
            tokens instead of the patch-token map. Defaults to False.
        frozen_stages (int): Number of stages to freeze; -1 means nothing is
            frozen. Defaults to -1.
        output_cls_token (bool): Whether to output the cls token. If True,
            ``with_cls_token`` must be True. Defaults to True.
        beit_style (bool): Reserved for BEiT-style layers, which this
            implementation does not provide; True raises
            ``NotImplementedError``. Defaults to False.
        layer_scale_init_value (float): Currently unused by this
            implementation. Defaults to 0.1.
        interpolate_mode (str): Interpolation mode used to resize the
            position embedding. Defaults to 'bicubic'.
        patch_cfg (dict): Extra configs of the patch embedding, overriding
            the defaults. Defaults to an empty dict.
        layer_cfgs (Sequence | dict): Extra configs of each transformer
            layer; a dict is shared by all layers. Defaults to an empty dict.
        init_cfg (dict, optional): Initialization config. Defaults to None.
    """
    arch_zoo = {
        **dict.fromkeys(['s', 'small'], {'embed_dims': 768, 'num_layers': 8, 'num_heads': 8, 'feedforward_channels':768*3}),
        **dict.fromkeys(['b', 'base'], {'embed_dims': 768, 'num_layers': 12, 'num_heads': 12, 'feedforward_channels': 3072}),
        **dict.fromkeys(['l', 'large'], {'embed_dims': 1024, 'num_layers': 24, 'num_heads': 16, 'feedforward_channels': 4096}),
        **dict.fromkeys(['h', 'huge'], {'embed_dims': 1280, 'num_layers': 32, 'num_heads': 16, 'feedforward_channels': 5120}),
        **dict.fromkeys(['deit-t', 'deit-tiny'], {'embed_dims': 192, 'num_layers': 12, 'num_heads': 3, 'feedforward_channels': 192*4}),
        **dict.fromkeys(['deit-s', 'deit-small'], {'embed_dims': 384, 'num_layers': 12, 'num_heads': 6, 'feedforward_channels': 384*4}),
        **dict.fromkeys(['deit-b', 'deit-base'], {'embed_dims': 768, 'num_layers': 12, 'num_heads': 12, 'feedforward_channels': 768*4})
    }
    # Some structures have multiple extra tokens, like DeiT.
    num_extra_tokens = 1  # cls_token

    def __init__(self, arch='base', img_size=224, patch_size=16, in_channels=3, out_indices=-1, drop_rate=0.,
                 drop_path_rate=0., qkv_bias=True, norm_cfg=dict(type='LN'), final_norm=True, with_cls_token=True,
                 avg_token=False, frozen_stages=-1, output_cls_token=True, beit_style=False, layer_scale_init_value=0.1,
                 interpolate_mode='bicubic', patch_cfg=dict(), layer_cfgs=dict(), init_cfg=None):
        super(VisionTransformer, self).__init__(init_cfg=init_cfg)
        if isinstance(arch, str):
            arch = arch.lower()
            # set(dict) gives the set of its keys
            assert arch in set(self.arch_zoo), f'arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'}
            assert isinstance(arch, dict) and essential_keys <= set(arch), f'custom arch needs a dict with keys ' \
                                                                           f'{essential_keys}'
            self.arch_settings = arch
        self.embed_dims = self.arch_settings['embed_dims']
        self.num_layers = self.arch_settings['num_layers']
        # img_size=224 -> (224, 224)
        self.img_size = to_2tuple(img_size)

        # Patch embedding; entries in patch_cfg override these defaults.
        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size
        num_patches = self.patch_resolution[0] * self.patch_resolution[1]

        # Class token.
        if output_cls_token:
            assert with_cls_token is True, 'with_cls_token must be True if set output_cls_token to True'
        self.with_cls_token = with_cls_token
        self.output_cls_token = output_cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

        # Position embedding, resized on checkpoint load when shapes differ.
        self.interpolate_mode = interpolate_mode
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_extra_tokens, self.embed_dims))
        self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
        self.drop_after_pos = nn.Dropout(p=drop_rate)

        # Output indices. Copy into a new list so tuples are accepted and the
        # caller's sequence is not modified when negative indices are resolved.
        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), f'out_indices must be sequence or int, but got {type(out_indices)}'
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_layers + index
            assert 0 <= out_indices[i] <= self.num_layers
        self.out_indices = out_indices

        # Stochastic depth decay rule: drop-path rate grows linearly with depth.
        dpr = np.linspace(0, drop_path_rate, self.num_layers)
        if beit_style:
            # Previously this silently built a model with no encoder layers.
            raise NotImplementedError('beit_style layers are not implemented in this VisionTransformer')
        self.layers = ModuleList()
        if isinstance(layer_cfgs, dict):
            layer_cfgs = [layer_cfgs] * self.num_layers
        for i in range(self.num_layers):
            _layer_cfg = dict(
                embed_dims=self.embed_dims,
                num_heads=self.arch_settings['num_heads'],
                feedforward_channels=self.arch_settings['feedforward_channels'],
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                qkv_bias=qkv_bias,
                norm_cfg=norm_cfg)
            _layer_cfg.update(layer_cfgs[i])
            self.layers.append(TransformerEncoderLayer(**_layer_cfg))
        self.frozen_stages = frozen_stages
        self.final_norm = final_norm
        if final_norm:
            self.norm1_name, norm1 = build_norm_layer(norm_cfg, self.embed_dims, postfix=1)
            self.add_module(self.norm1_name, norm1)
        self.avg_token = avg_token
        if avg_token:
            self.norm2_name, norm2 = build_norm_layer(norm_cfg, self.embed_dims, postfix=2)
            self.add_module(self.norm2_name, norm2)
        if self.frozen_stages > 0:
            self._freeze_stages()

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    def init_weights(self):
        super(VisionTransformer, self).init_weights()
        # Keep pretrained position embeddings; otherwise draw them from a truncated normal.
        if not (isinstance(self.init_cfg, dict) and self.init_cfg['type'] == 'Pretrained'):
            trunc_normal_(self.pos_embed, std=0.02)

    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
        """Load-state-dict pre-hook: resize a checkpoint's pos_embed to this model's patch grid.

        The checkpoint grid is assumed to be square.
        """
        name = prefix + 'pos_embed'
        if name not in state_dict.keys():
            return
        ckpt_pos_embed_shape = state_dict[name].shape
        if self.pos_embed.shape != ckpt_pos_embed_shape:
            from mmengine.logging import MMLogger
            logger = MMLogger.get_current_instance()
            logger.info(f'resize the pos_embed shape from {ckpt_pos_embed_shape} to {self.pos_embed.shape}')
            # [1, L, C]: L - num_extra_tokens spatial tokens arranged on a square grid.
            ckpt_pos_embed_shape = to_2tuple(int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
            pos_embed_shape = self.patch_embed.init_out_size
            state_dict[name] = resize_pos_embed(state_dict[name], ckpt_pos_embed_shape, pos_embed_shape,
                                                self.interpolate_mode, self.num_extra_tokens)

    def _freeze_stages(self):
        """Freeze, in order: pos_embed (and its dropout), patch_embed, cls_token,
        the first ``frozen_stages`` layers, and the final norm when every layer is frozen.

        Modules containing dropout/norm layers are also put into eval() mode.
        """
        # freeze position embedding
        self.pos_embed.requires_grad = False
        # set dropout to eval mode
        self.drop_after_pos.eval()
        # freeze patch embedding
        self.patch_embed.eval()
        for param in self.patch_embed.parameters():
            param.requires_grad = False
        # freeze cls token
        self.cls_token.requires_grad = False
        # freeze layers
        for i in range(1, self.frozen_stages + 1):
            m = self.layers[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False
        # freeze the last layer norm
        if self.frozen_stages == len(self.layers) and self.final_norm:
            self.norm1.eval()
            for param in self.norm1.parameters():
                param.requires_grad = False

    def forward(self, x):
        B = x.shape[0]
        x, patch_resolution = self.patch_embed(x)
        # cls_token: [1, 1, n_dim] -> [B, 1, n_dim]
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        # Add the position embedding, resized to the actual patch grid.
        x = x + resize_pos_embed(self.pos_embed, self.patch_resolution, patch_resolution, mode=self.interpolate_mode,
                                 num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)
        if not self.with_cls_token:
            x = x[:, 1:]
        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i == len(self.layers) - 1 and self.final_norm:
                x = self.norm1(x)
            if i in self.out_indices:
                B, _, C = x.shape
                if self.with_cls_token:
                    # patch_token: [B, h, w, C] -> [B, C, h, w]
                    patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = x[:, 0]
                else:
                    patch_token = x.reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = None
                if self.avg_token:
                    # [B, C, h, w] -> [B, h, w, C] -> [B, h*w, C] -> mean -> [B, C]
                    patch_token = patch_token.permute(0, 2, 3, 1)
                    patch_token = patch_token.reshape(B, patch_resolution[0] * patch_resolution[1], C).mean(dim=1)
                    patch_token = self.norm2(patch_token)
                if self.output_cls_token:
                    out = [patch_token, cls_token]
                else:
                    out = patch_token
                outs.append(out)
        return tuple(outs)

    @staticmethod
    def resize_pos_embed(*args, **kwargs):
        return resize_pos_embed(*args, **kwargs)
