## Implementing Swin Transformer

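### 1. How do we generate patch embeddings from an image?
**Method 1**
- Split the image into patches with PyTorch's `unfold` API. This mimics a convolution: setting kernel_size = stride = patch_size yields non-overlapping patches.
- The result is a tensor of shape [bs, num_patch, patch_depth].
- Multiply it by a weight matrix of shape [patch_depth, model_dim_C] to obtain the patch embedding of shape [bs, num_patch, model_dim_C].

**Method 2**
- patch_depth equals input_channel * patch_size * patch_size.
- model_dim_C plays the role of the output channel count of a 2D convolution.
- Reshape the [patch_depth, model_dim_C] weight matrix into a convolution kernel of shape [model_dim_C, input_channel, patch_size, patch_size].
- Call PyTorch's `conv2d` API with stride = patch_size; the output tensor has shape [bs, output_channel, height, width].
- Flatten the spatial dimensions and transpose to [bs, num_patch, model_dim_C], which is the patch embedding.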

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
def image2emb_naive(image, patch_size, weight):
    """Patch embedding the direct way: unfold into patches, then multiply by a weight matrix."""
    # image shape: [bs, channel, h, w]
    patch = F.unfold(image, kernel_size=(patch_size, patch_size),
                     stride=(patch_size, patch_size)).transpose(-1, -2)  # [bs, num_patch, patch_depth]
    patch_embedding = patch @ weight  # [bs, num_patch, model_dim_c]
    return patch_embedding


def image2emb_conv(image, kernel, stride):
    """Patch embedding via 2D convolution; the embedding dimension is the number of output channels."""
    conv_output = F.conv2d(image, kernel, stride=stride)  # bs * oc * oh * ow
    bs, oc, oh, ow = conv_output.shape
    patch_embedding = conv_output.reshape((bs, oc, oh * ow)).transpose(-1, -2)  # [bs, num_patch, model_dim_c]
    return patch_embedding
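
The equivalence of the two methods can be checked numerically. The sketch below is self-contained and inlines both computations rather than calling the functions above; the toy sizes and the variable names (`emb_unfold`, `emb_conv`, `max_diff`) are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, ic, h, w = 2, 3, 8, 8            # toy sizes chosen for illustration
patch_size, model_dim = 4, 8
patch_depth = ic * patch_size * patch_size

image = torch.randn(bs, ic, h, w)
weight = torch.randn(patch_depth, model_dim)

# Method 1: unfold into patches, then one matrix multiplication.
patches = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
emb_unfold = patches @ weight                            # [bs, num_patch, model_dim]

# Method 2: reinterpret the same weight as a convolution kernel.
kernel = weight.transpose(0, 1).reshape(model_dim, ic, patch_size, patch_size)
conv_out = F.conv2d(image, kernel, stride=patch_size)    # [bs, model_dim, oh, ow]
emb_conv = conv_out.flatten(2).transpose(-1, -2)         # [bs, num_patch, model_dim]

max_diff = (emb_unfold - emb_conv).abs().max().item()
print(emb_unfold.shape, max_diff)
```

The reshape works because `unfold` orders each patch as (channel, kernel row, kernel column), which is the same order a convolution kernel uses for its last three dimensions. The two results should agree up to floating-point rounding.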

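### 2. How do we build MHSA and compute its complexity?
- Apply three projections to the input x to obtain q, k, and v.
    - This step costs $3LC^2$, where L is the sequence length and C is the feature dimension.
- Split q, k, and v into multiple heads. The heads are computed independently of one another, so the head dimension can be folded into the bs dimension.
- Compute $qk^T$ and apply a mask if needed, i.e. set the energy between invalid pairs of positions to negative infinity. The mask is required for shifted-window MHSA but not for window MHSA.
    - This step costs $L^2C$.
- Multiply the attention probabilities by v.
    - This step costs $L^2C$.
- Project the output once more.
    - This step costs $LC^2$.
- The overall complexity is $4LC^2 + 2L^2C$.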
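
The per-step costs above can be tallied with a few lines of plain Python. This is only an illustrative sketch: `mhsa_cost_breakdown` is a hypothetical helper, and the example values (4096 patches, feature dimension 8) are assumptions chosen to match a 64 x 64 grid of patches.

```python
def mhsa_cost_breakdown(seq_len, dim):
    """Multiplication counts per step, following the list above."""
    return {
        "qkv_projection": 3 * seq_len * dim ** 2,
        "q_times_k": seq_len ** 2 * dim,
        "prob_times_v": seq_len ** 2 * dim,
        "output_projection": seq_len * dim ** 2,
    }


cost = mhsa_cost_breakdown(seq_len=4096, dim=8)
total = sum(cost.values())   # equals 4*L*C^2 + 2*L^2*C
print(cost)
print(total)
```

For these values the two quadratic terms dominate the total, which is what motivates restricting attention to windows.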

In [3]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, model_dim, num_head):
        super().__init__()
        self.num_head = num_head
        self.model_dim = model_dim
        self.proj_linear_layer = nn.Linear(model_dim, 3 * model_dim)
        self.final_linear_layer = nn.Linear(model_dim, model_dim)

    def forward(self, input, additive_mask=None):
        bs, seq_len, model_dim = input.shape
        num_head = self.num_head
        head_dim = model_dim // num_head

        proj_output = self.proj_linear_layer(input)  # [bs, seq_len, 3 * model_dim]
        q, k, v = proj_output.chunk(3, dim=-1)  # 3 * [bs, seq_len, model_dim]

        q = q.reshape(bs, seq_len, num_head, head_dim).transpose(1, 2)  # [bs, num_head, seq_len, head_dim]
        q = q.reshape(bs * num_head, seq_len, head_dim)

        k = k.reshape(bs, seq_len, num_head, head_dim).transpose(1, 2)  # [bs, num_head, seq_len, head_dim]
        k = k.reshape(bs * num_head, seq_len, head_dim)

        v = v.reshape(bs, seq_len, num_head, head_dim).transpose(1, 2)  # [bs, num_head, seq_len, head_dim]
        v = v.reshape(bs * num_head, seq_len, head_dim)

        if additive_mask is None:
            attn_prob = F.softmax(torch.bmm(q, k.transpose(-2, -1)) / math.sqrt(head_dim), dim=-1)
        else:
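            # q, k, v are folded as [bs * num_head, ...] with the head index varying fastest,
            # so each sample's mask has to be repeated num_head times consecutively.
            additive_mask = additive_mask.repeat_interleave(num_head, dim=0)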
            attn_prob = F.softmax(torch.bmm(q, k.transpose(-2, -1)) / math.sqrt(head_dim) + additive_mask, dim=-1)

        output = torch.bmm(attn_prob, v)  # [bs * num_head, seq_len, head_dim]
        output = output.reshape(bs, num_head, seq_len, head_dim).transpose(1, 2)  # [bs, seq_len, num_head, head_dim]
        output = output.reshape(bs, seq_len, model_dim)

        output = self.final_linear_layer(output)
        return attn_prob, output
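
The head split used in `forward` can be examined in isolation. The sketch below, with toy sizes chosen for illustration, folds the heads into the batch dimension, checks which row of the folded tensor holds which head, and then inverts the operation.

```python
import torch

bs, seq_len, num_head, head_dim = 2, 5, 4, 3
x = torch.randn(bs, seq_len, num_head * head_dim)

# Split: [bs, seq_len, model_dim] -> [bs * num_head, seq_len, head_dim]
heads = x.reshape(bs, seq_len, num_head, head_dim).transpose(1, 2)
heads = heads.reshape(bs * num_head, seq_len, head_dim)

# Row b * num_head + h of the folded tensor holds head h of sample b.
b, h = 1, 2
same_slice = torch.equal(heads[b * num_head + h],
                         x[b, :, h * head_dim:(h + 1) * head_dim])

# Merge: undo the two steps above.
merged = heads.reshape(bs, num_head, seq_len, head_dim).transpose(1, 2)
merged = merged.reshape(bs, seq_len, num_head * head_dim)
round_trip = torch.equal(merged, x)
print(same_slice, round_trip)
```

Because the folded batch is ordered sample by sample with the heads of one sample adjacent, any per-sample tensor that accompanies it (a mask, for instance) has to follow the same ordering.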

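### 3. How do we build Window MHSA and compute its complexity?
- Group the image of patches into larger windows.
    - First convert the 3D patch embedding into image format.
    - Use unfold to partition the patches into windows.
- Compute MHSA inside each window.
    - The window count can be folded into batch_size, because windows do not interact with each other.
    - On computational complexity:
        - Let the window side length be W (in patches). The cost per window is $4W^2C^2 + 2W^4C$.
        - Let the total number of patches be L. The number of windows is then $L / W^2$.
        - So the overall complexity of W-MHSA is $4LC^2 + 2LW^2C$.
    - No mask is needed here.
    - Convert the result into a 4D tensor with an explicit window dimension.
- Complexity comparison
    - **MHSA**: $4LC^2 + 2L^2C$
    - **W-MHSA**: $4LC^2 + 2LW^2C$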
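
To make the comparison concrete, the two formulas can be evaluated for one setting. The function names and the values L = 4096, C = 8, W = 4 are assumptions chosen for illustration.

```python
def mhsa_cost(num_patch, dim):
    """Global MHSA: 4*L*C^2 + 2*L^2*C."""
    return 4 * num_patch * dim ** 2 + 2 * num_patch ** 2 * dim


def window_mhsa_cost(num_patch, dim, window_size):
    """W-MHSA: run MHSA on W*W patches, once per window."""
    num_window = num_patch // window_size ** 2
    return num_window * mhsa_cost(window_size ** 2, dim)


L, C, W = 4096, 8, 4
global_cost = mhsa_cost(L, C)
window_cost = window_mhsa_cost(L, C, W)
print(global_cost, window_cost)
```

Summing the per-window cost over all $L / W^2$ windows reproduces $4LC^2 + 2LW^2C$. The projection term is unchanged, while the attention term shrinks by a factor of $L / W^2$ (256 in this setting).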

In [4]:
def window_multi_head_self_attention(patch_embedding, mhsa, window_size=4, num_head=2):
    num_patch_in_window = window_size * window_size
    bs, num_patch, model_dim = patch_embedding.shape
    image_height = image_width = int(math.sqrt(num_patch))

    patch_embedding = patch_embedding.transpose(-1, -2)
    patch = patch_embedding.reshape(bs, model_dim, image_height, image_width)
    window = F.unfold(patch, kernel_size=(window_size, window_size),
                      stride=(window_size, window_size)).transpose(-1, -2)  # [bs, num_window, window_depth]

    bs, num_window, model_dim_times_num_patch_in_window = window.shape
    # [bs * num_window, num_patch, model_dim]
    window = window.reshape(bs * num_window, model_dim, num_patch_in_window).transpose(-1, -2)

    attn_prob, output = mhsa(window)  # [bs * num_window, num_patch_in_window, model_dim]

    output = output.reshape(bs, num_window, num_patch_in_window, model_dim)
    return output
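
The ordering produced by a non-overlapping window partition is easier to see on patch indices than on feature tensors. The following plain-Python sketch (a hypothetical helper, not part of the notebook) lists, for every window, the raster indices of the patches it contains.

```python
def partition_into_windows(height, width, window_size):
    """Return, for every window, the raster indices of its patches."""
    windows = []
    for top in range(0, height, window_size):
        for left in range(0, width, window_size):
            windows.append([
                row * width + col
                for row in range(top, top + window_size)
                for col in range(left, left + window_size)
            ])
    return windows


windows = partition_into_windows(height=4, width=4, window_size=2)
print(windows)
# [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]
```

Windows are enumerated in raster order and patches inside a window in row-major order, which is the ordering `F.unfold` yields with kernel_size = stride = window_size. Note that window order differs from the raster order of a plain patch embedding.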

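### 4. How do we build Shifted Window MHSA and its mask?
- Convert the W-MHSA result from the previous step into image format.
- Suppose a new window partition, offset by half a window, has been applied; this step is called shift-window.
- To keep the number of windows unchanged and the computation efficient, cyclically shift the image's patches left and up by half a window size each. The shifted-window class each patch belongs to stays unchanged.
- Convert the image patches back into the window data format.
- After the cyclic shift every window has a regular shape, but some windows contain patches that did not originally belong to the same window, so a mask is required.
- How do we generate the mask?
    - First build a matrix that records, for each patch, the class (id) of the shifted window it belongs to.
    - Apply the same shift to this matrix: left and up by half a window size each.
    - Use unfold to obtain a class matrix of shape [bs, num_window, num_patch_in_window].
    - Expand it to [bs, num_window, num_patch_in_window, 1].
    - Subtract its transpose from it to get a same-class relation matrix (a 0 means the two patches belong to the same class; anything else means different classes).
    - Fill the non-zero positions with negative infinity (a large negative number in practice) and the zero positions with 0. This is the mask that MHSA needs.
    - The mask has shape [bs, num_window, num_patch_in_window, num_patch_in_window].
- Convert the windows to 3D format, [bs*num_window, num_patch_in_window, patch_depth].
- Feed the 3D features together with the mask into MHSA to get the attention output.
- Convert the attention output to the image patch format, [bs, num_window, num_patch_in_window, patch_depth].
- To restore the positions, shift the image's patches right and down by half a window size each. This completes SW-MHSA.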
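
The cyclic shift and its inverse can be tried on a small grid of patch indices with `torch.roll`. The 4 x 4 grid and the toy window size of 2 (so a shift of 1) are assumptions chosen to keep the output readable.

```python
import torch

grid = torch.arange(16).reshape(4, 4)   # a 4 x 4 grid of patch indices
shift = 1                                # half of a toy window size of 2

# Shift up and left, then undo it by shifting down and right.
shifted = torch.roll(grid, shifts=(-shift, -shift), dims=(0, 1))
restored = torch.roll(shifted, shifts=(shift, shift), dims=(0, 1))

print(shifted)
print(torch.equal(restored, grid))
```

In `shifted`, the bottom-right 2 x 2 window holds patches 15, 12, 3 and 0, which come from the four corners of the original grid. Such patches were not neighbours before the shift, which is why the mask described above is needed.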

In [5]:
# Generate a class matrix
a = torch.randint(10, size=(4, 1))
a

tensor([[8],
        [2],
        [7],
        [0]])

In [6]:
print(a - a.T)
(a - a.T) == 0

tensor([[ 0,  6,  1,  8],
        [-6,  0, -5,  2],
        [-1,  5,  0,  7],
        [-8, -2, -7,  0]])


tensor([[ True, False, False, False],
        [False,  True, False, False],
        [False, False,  True, False],
        [False, False, False,  True]])
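
Continuing from the same-class test above, the sketch below turns such a boolean matrix into an additive mask and shows its effect on softmax. The window ids and the all-zero scores are made-up values for illustration only.

```python
import torch
import torch.nn.functional as F

window_id = torch.tensor([0, 0, 1, 1])   # hypothetical class ids of 4 patches
same_window = window_id.unsqueeze(-1) == window_id.unsqueeze(0)   # [4, 4] boolean
additive_mask = (~same_window).to(torch.float32) * (-1e9)

scores = torch.zeros(4, 4)                # uniform raw attention scores
attn_prob = F.softmax(scores + additive_mask, dim=-1)
print(attn_prob)
```

Each row should place probability 0.5 on the two patches that share its class and 0 on the others, so patches from different classes no longer attend to each other.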

In [7]:
# Helper function window2image: converts the output of a transformer block into image format
def window2image(msa_output):
    bs, num_window, num_patch_in_window, patch_depth = msa_output.shape
    window_size = int(math.sqrt(num_patch_in_window))
    image_height = int(math.sqrt(num_window)) * window_size
    image_width = image_height

    msa_output = msa_output.reshape(bs, int(math.sqrt(num_window)), int(math.sqrt(num_window)),
                                    window_size, window_size, patch_depth)
    msa_output = msa_output.transpose(2, 3)
    image = msa_output.reshape(bs, image_height * image_width, patch_depth)
    image = image.transpose(-1, -2).reshape(bs, patch_depth, image_height, image_width)  # same layout as a conv feature map
    return image

In [8]:
# Build the additive attention mask for shifted-window multi-head attention
def build_mask_for_shifted_wsma(batch_size, image_height, image_width, window_size):
    index_matrix = torch.zeros(image_height, image_width)

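    # A shift by half a window produces (size // window_size + 1) regions per axis,
    # so the row stride is one more than the number of windows per row;
    # this gives every region a distinct id.
    num_col_region = image_width // window_size + 1
    for i in range(image_height):
        for j in range(image_width):
            row_times = (i + window_size // 2) // window_size
            col_times = (j + window_size // 2) // window_size
            index_matrix[i, j] = row_times * num_col_region + col_times + 1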
    rolled_index_matrix = torch.roll(index_matrix, shifts=(-window_size // 2, -window_size // 2), dims=(0, 1))
    rolled_index_matrix = rolled_index_matrix.unsqueeze(0).unsqueeze(0)  # [bs, ch, h, w]

    c = F.unfold(rolled_index_matrix, kernel_size=(window_size, window_size),
                 stride=(window_size, window_size)).transpose(-1, -2)
    c = c.tile(batch_size, 1, 1)  # [bs, num_window, num_patch_in_window]

    bs, num_window, num_patch_in_window = c.shape

    c1 = c.unsqueeze(-1)  # [bs, num_window, num_patch_in_window, 1]
    c2 = (c1 - c1.transpose(-1, -2)) == 0  # [bs, num_window, num_patch_in_window, num_patch_in_window]
    valid_matrix = c2.to(torch.float32)
    additive_mask = (1 - valid_matrix) * (-1e9)  # [bs, num_window, num_patch_in_window, num_patch_in_window]

    additive_mask = additive_mask.reshape(bs * num_window, num_patch_in_window, num_patch_in_window)

    return additive_mask

In [9]:
# Helper function shift_window, used to compute SW-MHSA efficiently
def shift_window(w_msa_output, window_size, shift_size, generate_mask=False):
    bs, num_window, num_patch_in_window, patch_depth = w_msa_output.shape

    w_msa_output = window2image(w_msa_output)  # [bs, depth, h, w]
    bs, patch_depth, image_height, image_width = w_msa_output.shape

    rolled_w_msa_output = torch.roll(w_msa_output, shifts=(shift_size, shift_size), dims=(2, 3))
    shifted_w_msa_input = rolled_w_msa_output.reshape(bs, patch_depth, int(math.sqrt(num_window)), window_size,
                                                       int(math.sqrt(num_window)), window_size)

    shifted_w_msa_input = shifted_w_msa_input.transpose(3, 4)
    shifted_w_msa_input = shifted_w_msa_input.reshape(bs, patch_depth, num_window * num_patch_in_window)
    shifted_w_msa_input = shifted_w_msa_input.transpose(-1, -2)  # [bs, num_window * num_patch_in_window, patch_depth]

    if generate_mask:
        additive_mask = build_mask_for_shifted_wsma(bs, image_height, image_width, window_size)
    else:
        additive_mask = None

    return shifted_w_msa_input, additive_mask

In [10]:
def shift_window_multi_head_self_attention(w_msa_output, mhsa, window_size=4, num_head=2):
    bs, num_window, num_patch_in_window, patch_depth = w_msa_output.shape
    # w_msa_output: [bs, num_window, num_patch_in_window, patch_depth]
    # Cyclically shift up and left by half a window, regroup into windows, and build the mask.
    shift_w_mas_input, additive_mask = shift_window(w_msa_output, window_size,
                                                    shift_size=-window_size//2, generate_mask=True)

    shift_w_mas_input = shift_w_mas_input.reshape(bs * num_window, num_patch_in_window, patch_depth)

    attn_prob, output = mhsa(shift_w_mas_input, additive_mask=additive_mask)

    output = output.reshape(bs, num_window, num_patch_in_window, patch_depth)

    # Shift back (down and right); shift_window returns [bs, num_window * num_patch_in_window, patch_depth]
    output, _ = shift_window(output, window_size, shift_size=window_size//2, generate_mask=False)
    return output

### 5. How do we build Patch Merging?
- Convert the window-format features into image patch format.
- Use unfold with a merge_size * merge_size kernel to obtain new patches of shape
  [bs, num_patch_new, merge_size * merge_size * patch_depth_old].
- Use a fully connected layer to reduce the depth to 0.5x, i.e.
  map merge_size * merge_size * patch_depth_old to 0.5 * merge_size * merge_size * patch_depth_old.
- The output has the patch embedding shape, [bs, num_patch, patch_depth].
- Example: with merge_size=2, patch merging reduces the number of patches to $\frac{1}{4}$ of the original, while the depth grows to 2x the original rather than 4x.

In [11]:
class PatchMerging(nn.Module):
    def __init__(self, model_dim, merge_size, output_depth_scale=0.5):
        super().__init__()
        self.merge_size = merge_size
        self.proj_layer = nn.Linear(
            model_dim * merge_size * merge_size,
            int(model_dim * merge_size * merge_size * output_depth_scale)
        )

    def forward(self, input):
        bs, num_window, num_patch_in_window, patch_depth = input.shape
        window_size = int(math.sqrt(num_patch_in_window))

        input = window2image(input)  # [bs, patch_depth, image_h, image_w]

        merged_window = F.unfold(input, kernel_size=(self.merge_size, self.merge_size),
                                 stride=(self.merge_size, self.merge_size)).transpose(-1, -2)
        merged_window = self.proj_layer(merged_window)  # [bs, num_patch, new_patch_depth]
        return merged_window
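
The shape arithmetic of patch merging can be traced on a toy feature map. This sketch is self-contained and does not use the `PatchMerging` class above; the sizes are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

bs, depth, side, merge_size = 1, 8, 4, 2   # toy sizes for illustration
feature_map = torch.randn(bs, depth, side, side)

# Group every merge_size x merge_size neighbourhood into one new patch.
merged = F.unfold(feature_map, kernel_size=merge_size,
                  stride=merge_size).transpose(-1, -2)

# Halve the concatenated depth with a linear projection.
proj = nn.Linear(depth * merge_size ** 2, depth * merge_size ** 2 // 2)
out = proj(merged)
print(merged.shape, out.shape)
```

The 16 patches of depth 8 become 4 patches of depth 32 after unfolding and 4 patches of depth 16 after the projection: a quarter of the patches at twice the depth.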

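### 6. How do we build SwinTransformerBlock?
- Each block contains LayerNorm, W-MHSA, MLP, SW-MHSA, and residual connections.
- The input is in patch embedding format.
- Each MLP has two layers, of size 4 * model_dim and model_dim.
- The output is in window format, [bs, num_window, num_patch_in_window, patch_depth].
- Pay attention to what residual connections require of the data: both operands must have the same shape and the same patch order.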
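
The way LayerNorm, a sub-layer, and a residual connection fit together can be sketched generically. `PreNormResidual` is a hypothetical helper, not part of the notebook, and the GELU between the two MLP layers is an assumption taken from the Swin Transformer paper rather than from the list above.

```python
import torch
import torch.nn as nn


class PreNormResidual(nn.Module):
    """Hypothetical helper computing y = x + sublayer(LayerNorm(x))."""

    def __init__(self, model_dim, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(model_dim)
        self.sublayer = sublayer

    def forward(self, x):
        # The residual adds the un-normalized input, so shapes must match.
        return x + self.sublayer(self.norm(x))


model_dim = 8
mlp = nn.Sequential(
    nn.Linear(model_dim, 4 * model_dim),
    nn.GELU(),
    nn.Linear(4 * model_dim, model_dim),
)
layer = PreNormResidual(model_dim, mlp)
x = torch.randn(2, 16, model_dim)
y = layer(x)
print(y.shape)
```

A block then chains four such sub-layers: W-MHSA, MLP, SW-MHSA, MLP.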

In [12]:
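def image2window(patch_embedding, window_size):
    """Reorder raster-order patches [bs, num_patch, depth] into window order,
    [bs, num_window, num_patch_in_window, depth], matching the unfold-based partition."""
    bs, num_patch, depth = patch_embedding.shape
    num_window_per_side = int(math.sqrt(num_patch)) // window_size
    x = patch_embedding.reshape(bs, num_window_per_side, window_size,
                                num_window_per_side, window_size, depth)
    return x.transpose(2, 3).reshape(bs, -1, window_size * window_size, depth)


class SwinTransformerBlock(nn.Module):

    def __init__(self, model_dim, window_size, num_head):
        super().__init__()
        self.window_size = window_size
        self.num_head = num_head

        self.layer_norm1 = nn.LayerNorm(model_dim)
        self.layer_norm2 = nn.LayerNorm(model_dim)
        self.layer_norm3 = nn.LayerNorm(model_dim)
        self.layer_norm4 = nn.LayerNorm(model_dim)

        self.wsma_mlp1 = nn.Linear(model_dim, 4 * model_dim)
        self.wsma_mlp2 = nn.Linear(4 * model_dim, model_dim)
        self.swsma_mlp1 = nn.Linear(model_dim, 4 * model_dim)
        self.swsma_mlp2 = nn.Linear(4 * model_dim, model_dim)

        self.mhsa1 = MultiHeadSelfAttention(model_dim, num_head)
        self.mhsa2 = MultiHeadSelfAttention(model_dim, num_head)

    def forward(self, input):
        bs, num_patch, patch_depth = input.shape

        # W-MHSA sub-layer: normalize, attend within windows, add the residual.
        input1 = self.layer_norm1(input)
        w_msa_output = window_multi_head_self_attention(input1, self.mhsa1, window_size=self.window_size,
                                                        num_head=self.num_head)
        bs, num_window, num_patch_in_window, patch_depth = w_msa_output.shape
        # The attention output is in window order, so the shortcut is reordered to match.
        shortcut = image2window(input, self.window_size).reshape(bs, num_patch, patch_depth)
        w_msa_output = shortcut + w_msa_output.reshape(bs, num_patch, patch_depth)
        # MLP sub-layer; the GELU between the two layers follows the Swin Transformer paper.
        output1 = self.wsma_mlp2(F.gelu(self.wsma_mlp1(self.layer_norm2(w_msa_output))))
        output1 = output1 + w_msa_output

        # SW-MHSA sub-layer; from here on everything stays in window order.
        input2 = self.layer_norm3(output1)
        input2 = input2.reshape(bs, num_window, num_patch_in_window, patch_depth)
        sw_msa_output = shift_window_multi_head_self_attention(input2, self.mhsa2, window_size=self.window_size,
                                                               num_head=self.num_head)
        sw_msa_output = output1 + sw_msa_output.reshape(bs, num_patch, patch_depth)
        output2 = self.swsma_mlp2(F.gelu(self.swsma_mlp1(self.layer_norm4(sw_msa_output))))
        output2 = output2 + sw_msa_output

        output2 = output2.reshape(bs, num_window, num_patch_in_window, patch_depth)

        return output2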

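### 7. How do we build SwinTransformerModel?
- The input is an image.
- First split the image into patches and compute the patch embedding.
- Pass it through the first stage.
- Apply patch merging, then the second stage.
- And so on...
- Convert the output of the last block to patch embedding format, [bs, num_patch, patch_depth].
- Average-pool the patch embedding over the patch (sequence) dimension, then project it through the classification layer to get the classification logits.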
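
The number of patches and the depth at each stage follow directly from the patch merging rule. The sketch below is a hypothetical helper that tracks both, assuming four stages, merge_size=2, and depth halving after each merge.

```python
def stage_shapes(image_size, patch_size, model_dim, num_stage=4, merge_size=2):
    """Return (num_patch, depth) at the output of each stage."""
    side = image_size // patch_size
    depth = model_dim
    shapes = []
    for stage in range(num_stage):
        if stage > 0:
            # Patch merging: fewer patches, concatenated depth halved.
            side //= merge_size
            depth = depth * merge_size ** 2 // 2
        shapes.append((side * side, depth))
    return shapes


shapes = stage_shapes(image_size=256, patch_size=4, model_dim=8)
print(shapes)
# [(4096, 8), (1024, 16), (256, 32), (64, 64)]
```

With a window of 4 x 4 = 16 patches, these counts correspond to 256, 64, 16 and 4 windows per image across the four stages.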

In [13]:
class SwinTransformerModel(nn.Module):

    def __init__(self, input_image_channel=3, patch_size=4, model_dim_C=8,
                 num_classes=10, window_size=4, num_head=2, merge_size=2):
        super(SwinTransformerModel, self).__init__()
        patch_depth = patch_size * patch_size * input_image_channel
        self.patch_size = patch_size
        self.model_dim_C = model_dim_C
        self.num_classes = num_classes

        self.patch_embedding_weight = nn.Parameter(torch.randn(patch_depth, model_dim_C))
        self.block1 = SwinTransformerBlock(model_dim_C, window_size, num_head)
        self.block2 = SwinTransformerBlock(2 * model_dim_C, window_size, num_head)
        self.block3 = SwinTransformerBlock(4 * model_dim_C, window_size, num_head)
        self.block4 = SwinTransformerBlock(8 * model_dim_C, window_size, num_head)

        self.patch_merging1 = PatchMerging(model_dim_C, merge_size)
        self.patch_merging2 = PatchMerging(model_dim_C * 2, merge_size)
        self.patch_merging3 = PatchMerging(model_dim_C * 4, merge_size)

        # The classification layer maps the pooled feature to num_classes logits.
        self.final_layer = nn.Linear(model_dim_C * 8, num_classes)

    def forward(self, image):
        print('self.patch_size', self.patch_size)
        patch_embedding_naive = image2emb_naive(image, self.patch_size, self.patch_embedding_weight)

        patch_embedding = patch_embedding_naive
        print(patch_embedding.shape)

        sw_msa_output = self.block1(patch_embedding_naive)
        print('block1_output:', sw_msa_output.shape)

        merged_patch1 = self.patch_merging1(sw_msa_output)
        sw_msa_output_1 = self.block2(merged_patch1)
        print('block2_output:', sw_msa_output_1.shape)

        merged_patch2 = self.patch_merging2(sw_msa_output_1)
        sw_msa_output_2 = self.block3(merged_patch2)
        print('block3_output:', sw_msa_output_2.shape)

        merged_patch3 = self.patch_merging3(sw_msa_output_2)
        sw_msa_output_3 = self.block4(merged_patch3)
        print('block4_output:', sw_msa_output_3.shape)

        bs, num_window, num_patch_in_window, patch_depth = sw_msa_output_3.shape
        sw_msa_output_3 = sw_msa_output_3.reshape(bs, -1, patch_depth)

        # Average over all patches, then classify.
        pool_output = torch.mean(sw_msa_output_3, dim=1)
        logits = self.final_layer(pool_output)
        print('logits:', logits.shape)
        return logits

### 8. Model test code

In [14]:
if __name__ == '__main__':
    bs, ic, image_h, image_w = 4, 3, 256, 256
    patch_size = 4
    model_dim_C = 8  # initial patch embedding size
    max_num_token = 16
    num_classes = 10
    window_size = 4
    num_head = 2
    merge_size = 2

    patch_depth = patch_size * patch_size * ic
    image = torch.randn(bs, ic, image_h, image_w)
    model = SwinTransformerModel(ic, patch_size, model_dim_C, num_classes, window_size, num_head, merge_size)
    model(image)

self.patch_size 4
torch.Size([4, 4096, 8])
block1_output: torch.Size([4, 256, 16, 8])
block2_output: torch.Size([4, 64, 16, 16])
block3_output: torch.Size([4, 16, 16, 32])
block4_output: torch.Size([4, 4, 16, 64])
logits: torch.Size([4, 10])
