ViT 的第一种实现方法（也是比赛中实际使用的方法）：直接从 timm 库中调用现成的模型。

In [117]:
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F

In [118]:
class Encoder(nn.Module):
    """ViT backbone built with timm that returns token features, not logits.

    The classifier head of the timm model is swapped for ``nn.Identity`` and
    ``forward`` walks through the backbone's stages explicitly, returning the
    normalised token sequence (class token first, patch tokens after it).

    Args:
        model_name: Name handed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model`` as its second argument.
    """

    def __init__(self, model_name='vit_base_patch16_224', pretrained=False):
        super().__init__()
        self.cnn = timm.create_model(model_name, pretrained)
        # Keep the input width of the original head before it is removed.
        self.n_features = self.cnn.head.in_features
        self.cnn.head = nn.Identity()

    def forward(self, x):
        """Return the backbone's normalised tokens for the image batch ``x``."""
        backbone = self.cnn
        batch_size = x.shape[0]
        tokens = backbone.patch_embed(x)
        # This ViT has a class token: broadcast it over the batch and put it
        # in front of the patch tokens.
        cls_tokens = backbone.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat([cls_tokens, tokens], dim=1)
        tokens = backbone.pos_drop(tokens + backbone.pos_embed)
        for block in backbone.blocks:
            tokens = block(tokens)
        return backbone.norm(tokens)
        

In [119]:
# Smoke test: run a random batch of four 224x224 RGB images through the
# timm-based encoder and look at the shape of the returned token tensor.
model = Encoder()
a = torch.rand(4,3,224,224)
model(a).size()

torch.Size([4, 197, 768])

In [120]:
# To list every parameter of the encoder together with its shape:
# for name, param in model.named_parameters():
#     print(name, param.size())
# Shape after the patch embedding alone, i.e. before the class token is
# prepended inside Encoder.forward.
model.cnn.patch_embed(a).size()

torch.Size([4, 196, 768])

#### vit 实现

In [133]:
# Hyper-parameters used as defaults by the hand-written ViT below.
image_size = 224  # side length of the (square) input image
patch_size = 16   # side length of each square patch
d_model = 768     # token embedding width
n_head = 8        # number of attention heads
# In ViT the per-head q/k/v width equals d_model / n_head; derive it so the
# two values cannot drift apart when d_model or n_head is changed.
qkv_dim = d_model // n_head
n_layers = 3      # number of (attention, feed-forward) pairs
ffn_dim = 768     # hidden width of the feed-forward sub-layer
num_classes = 2   # size of the classification output

### 没有 mask 的 Transformer 结构

In [122]:
class Attention(nn.Module):
    """Multi-head self-attention without masking, followed by add & LayerNorm.

    Args:
        image_dim: Width of every input and output token.
        head: Number of attention heads.
        dim_head: Width of the query/key/value vectors inside one head.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, image_dim=d_model, head=n_head, dim_head=qkv_dim, dropout=0):
        super().__init__()
        inner_dim = dim_head * head
        self.image_dim = image_dim
        self.head = head
        # 1 / sqrt(dim_head): attention scores are MULTIPLIED by this factor.
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(image_dim, inner_dim*3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, image_dim),
            nn.Dropout(dropout)
        )
        # Register the LayerNorm once so its weights are trained, saved in the
        # state dict and moved together with the module by .to()/.cuda().
        self.norm = nn.LayerNorm(image_dim)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x`` of shape (batch, tokens, image_dim).

        ``mask`` is accepted for interface compatibility only; this
        implementation has no masking and ignores it.

        Returns:
            Tensor with the same shape as ``x``: LayerNorm(x + attention(x)).
        """
        b, n, _ = x.shape
        head = self.head
        # Each of q, k, v: (batch, tokens, head * dim_head).
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # Split the heads: (batch, tokens, head, dim_head) -> (batch, head, tokens, dim_head).
        q = q.reshape(b, n, head, -1).transpose(1, 2)
        k = k.reshape(b, n, head, -1).transpose(1, 2)
        v = v.reshape(b, n, head, -1).transpose(1, 2)
        # Scaled dot-product: scale the logits DOWN by sqrt(dim_head).
        attn_score = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn_score = attn_score.softmax(dim=-1)

        context = torch.matmul(attn_score, v)
        # Merge the heads back: (batch, head, tokens, dim_head) -> (batch, tokens, head * dim_head).
        context = context.transpose(1, 2).reshape(b, n, -1)
        output = self.to_out(context)
        return self.norm(x + output)

In [123]:
class FeedForward(nn.Module):
    """Position-wise two-layer MLP with GELU, followed by add & LayerNorm.

    Args:
        dim: Width of every input and output token.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability applied after each linear layer.
    """

    def __init__(self, dim=d_model, hidden_dim=ffn_dim, dropout=0):
        super().__init__()
        self.dim = dim
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
        # Register the LayerNorm once so its weights are trained, saved in the
        # state dict and moved together with the module by .to()/.cuda().
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Return LayerNorm(x + MLP(x)); the output has the same shape as ``x``."""
        return self.norm(x + self.net(x))

In [124]:
class Transfomer(nn.Module):
    """Stack of ``depth`` (Attention, FeedForward) pairs applied in sequence.

    Args:
        dim: Token width handed to every Attention and FeedForward layer.
        depth: Number of (Attention, FeedForward) pairs.
    """

    def __init__(self, dim=d_model, depth=n_layers):
        super().__init__()
        # Plain list kept as an attribute for callers that inspect it; the
        # nn.Sequential below is what registers the layers as sub-modules.
        self.layers = []
        for _ in range(depth):
            # Pass dim through so a non-default width is actually honoured.
            self.layers.append(Attention(image_dim=dim))
            self.layers.append(FeedForward(dim=dim))
        self.net = nn.Sequential(*self.layers)

    def forward(self, x):
        """Run ``x`` through every layer; the shape of ``x`` is preserved."""
        return self.net(x)

In [125]:
# Smoke test for the encoder stack. The input is defined here so the cell
# does not depend on an `x` left over from another notebook cell:
# 4 sequences of (number of patches + 1 class token) tokens of width d_model.
x = torch.rand(4, (image_size // patch_size) ** 2 + 1, d_model)
model = Transfomer()
model(x).size()

torch.Size([4, 197, 768])

In [126]:
from einops import rearrange, repeat

In [177]:
class ViT(nn.Module):
    """Minimal Vision Transformer classifier.

    The image is cut into non-overlapping patches by a strided convolution, a
    learnable class token is prepended, learnable position embeddings are
    added, the tokens go through ``Transfomer`` and the pooled token is mapped
    to ``num_classes`` logits.

    Args:
        image_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        num_classes: Number of output logits.
        dim: Token embedding width.
        depth: Number of (Attention, FeedForward) pairs in the encoder.
        heads: Accepted for interface compatibility; not forwarded because
            ``Transfomer`` has no such parameter.
        mlp_dim: Accepted for interface compatibility; not forwarded.
        pool: ``'mean'`` averages all tokens; any other value uses the class token.
        channels: Number of input image channels.
        dim_head: Accepted for interface compatibility; not forwarded.
        dropout: Accepted for interface compatibility; not forwarded.
        emb_dropout: Dropout probability applied to the embedded tokens.
    """

    def __init__(self, image_size=image_size, patch_size=patch_size, num_classes=num_classes, dim=d_model, depth=n_layers, heads=n_head, mlp_dim=ffn_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        self.num_patches = (image_size//patch_size)**2
        # kernel == stride and no padding: every patch is seen exactly once and
        # no border pixels are dropped or replaced by zeros.
        self.conv = nn.Conv2d(channels, dim, patch_size, stride=patch_size)
        self.patch_size = patch_size
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches+1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transfomer = Transfomer(dim=dim, depth=depth)
        self.pool = pool
        self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def forward(self, img):
        """Return logits of shape (batch, num_classes) for the image batch ``img``."""
        # (batch, dim, h, w) -> (batch, h * w, dim); the token count comes from
        # the tensor itself rather than from a cached value.
        x = self.conv(img).flatten(2).transpose(1, 2)
        cls_token = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        # Raises a shape error if the image yields a different number of
        # patches than the position embedding was built for.
        x = self.dropout(x + self.pos_embed)
        x = self.transfomer(x)
        x = x.mean(dim=1) if self.pool=='mean' else x[:, 0]
        return self.mlp_head(x)

In [178]:
# Patch-embedding convolution on its own: kernel == stride and no padding, so
# the image is tiled by non-overlapping patches without losing border pixels.
conv = nn.Conv2d(3, d_model, patch_size, stride=patch_size)

In [179]:
# Check the patch-embedding shape: move channels last, then merge the two
# spatial axes into one token axis of (image_size // patch_size) ** 2 entries.
img = torch.rand(4,3,224,224)
conv(img).permute(0,2,3,1).contiguous().reshape(4, (image_size//patch_size)**2, -1).size()

torch.Size([4, 196, 768])

In [180]:
# Smoke test: run a random batch of four 224x224 RGB images through the
# hand-written ViT (all default hyper-parameters) and look at the output shape.
img = torch.rand(4,3,224,224)
model = ViT()
model(img).size()

torch.Size([4, 2])

In [173]:
# Shape of the token tensor `x`.
# NOTE(review): `x` is not defined by this cell; it appears to come from an
# earlier notebook execution (this cell's counter is lower than its
# neighbours') - confirm where it is created.
x.size()

torch.Size([4, 197, 768])