In [None]:
import torch
import torch.nn as nn

## PatchEmbed

In [None]:
class PatchEmbed(nn.Module):
  """Split an image into non-overlapping patches and embed each patch.

  A convolution whose kernel size and stride both equal the patch size
  projects every patch to an ``embed_dim``-dimensional vector.

  Args:
    img_size: Expected input height/width. Either a single int (square
      image) or an ``(H, W)`` pair.
    patch_size: Patch height/width. Either a single int (square patch) or
      a ``(pH, pW)`` pair.
    in_c: Number of input channels.
    embed_dim: Number of output channels of the projection, i.e. the size
      of each patch embedding.
    norm_layer: Optional callable invoked as ``norm_layer(embed_dim)`` to
      build a normalisation module; ``nn.Identity`` is used when omitted.

  Attributes:
    img_size: Expected ``(H, W)`` of the input.
    patch_size: ``(pH, pW)`` of each patch.
    grid_size: Number of patches along each axis (floor division).
    num_patches: Total number of patches, e.g. 14 * 14 = 196 by default.
  """

  @staticmethod
  def _to_pair(value):
    """Return ``value`` as a 2-tuple, duplicating it when it is a scalar."""
    if isinstance(value, (tuple, list)):
      if len(value) != 2:
        raise ValueError(f"expected an int or a pair of ints, got {value!r}")
      return tuple(value)
    return (value, value)

  def __init__(self,img_size=224,patch_size=16,in_c=3,embed_dim=768,norm_layer=None):
    super().__init__()
    # Normalise both sizes to 2-tuples so ints and (h, w) pairs are accepted.
    img_size = self._to_pair(img_size)
    patch_size = self._to_pair(patch_size)
    self.img_size = img_size
    self.patch_size = patch_size
    # Patch grid; any remainder pixels are dropped by the floor division.
    self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
    self.num_patches = self.grid_size[0] * self.grid_size[1]

    # With the defaults: (B, 3, 224, 224) -> (B, 768, 14, 14).
    self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
    # Use the supplied normalisation layer if any, otherwise a no-op.
    self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

  def forward(self,x):
    """Embed a batch of images.

    Args:
      x: Tensor of shape ``(B, C, H, W)`` where ``(H, W)`` must equal
        ``self.img_size``.

    Returns:
      Tensor of shape ``(B, num_patches, embed_dim)``.

    Raises:
      AssertionError: If the spatial size of ``x`` differs from
        ``self.img_size``.
    """
    B, C, H, W = x.shape
    assert H == self.img_size[0] and W == self.img_size[1], (
      f"输入图像的大小{H}*{W}与模型期望大小{self.img_size[0]}*{self.img_size[1]}不匹配"
    )
    # (B, C, H, W) -> (B, D, gh, gw) -> (B, D, gh*gw) -> (B, gh*gw, D)
    x = self.proj(x).flatten(2).transpose(1, 2)
    x = self.norm(x)
    return x


## Attention

In [None]:
class Attention(nn.Module):
  """Multi-head self-attention over a sequence of tokens.

  Q, K and V are produced by one fused linear layer, attention is computed
  independently per head, and the concatenated head outputs are mapped
  back to ``dim`` by a final linear projection.

  Args:
    dim: Token embedding dimension (e.g. 768). Must be divisible by
      ``num_heads``.
    num_heads: Number of attention heads.
    qkv_bias: Whether the fused QKV linear layer has a bias term.
    qk_scale: Scale applied to the Q.K^T scores. When ``None`` (or any
      falsy value) ``head_dim ** -0.5`` is used.
    atte_drop_ration: Dropout probability applied to the attention weights.
    proj_drop_ration: Dropout probability applied after the output
      projection.

  Raises:
    ValueError: If ``dim`` is not divisible by ``num_heads``.
  """

  def __init__(self,
         dim,
         num_heads = 8,
         qkv_bias = False,
         qk_scale = None,
         atte_drop_ration = 0.,
         proj_drop_ration = 0.):
    super().__init__()
    if dim % num_heads != 0:
      raise ValueError(
        f"dim ({dim}) must be divisible by num_heads ({num_heads})")
    self.num_heads = num_heads
    # Legacy alias: earlier revisions stored the head count under this name.
    self.num_head = num_heads
    head_dim = dim // num_heads  # per-head feature size
    self.scale = qk_scale or head_dim ** -0.5
    # A single linear layer emits Q, K and V together.
    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    self.att_drop = nn.Dropout(atte_drop_ration)
    self.proj_drop = nn.Dropout(proj_drop_ration)
    # Maps the concatenated head outputs back to the embedding dimension.
    self.proj = nn.Linear(dim, dim)

  def forward(self,x):
    """Apply self-attention.

    Args:
      x: Tensor of shape ``(B, N, C)`` with ``C`` equal to the ``dim``
        given at construction.

    Returns:
      Tensor of shape ``(B, N, C)``.
    """
    B, N, C = x.shape
    head_dim = C // self.num_heads
    # (B, N, 3*C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
    # Each of q, k, v has shape (B, heads, N, head_dim).
    q, k, v = qkv[0], qkv[1], qkv[2]
    # Scaled dot-product scores: (B, heads, N, N).
    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)  # each row sums to 1
    attn = self.att_drop(attn)
    # Weighted sum of values, then merge heads:
    # (B, heads, N, head_dim) -> (B, N, heads, head_dim) -> (B, N, C)
    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)

    return x