In [None]:
from collections import OrderedDict
from typing import Union, Tuple

import torch
import torch.nn as nn

In [None]:
d_model = 768
n_head = 12
model = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_head)

In [None]:
class Bottleneck(nn.Module):
    expansion = 4
    
    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)
        
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)
        
        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
        
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)
        
        self.downsample = None
        self.stride = stride
        
        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))
            
    def forward(self, x: torch.Tensor):
        identity = x
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(x)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(x))
        
        if self.downsample is not None:
            # 这里主要是为了原始的x与三层卷积后的x维度相匹配
            identity = self.downsample(x)
        
        out += identity
        out = self.relu3(out)
        return out
        

In [None]:
class AttentionPool2d(nn.Module):
    pass

In [None]:
class ModifiedResNet(nn.Module):
    
    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution
        
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.relu1 = nn.ReLU(inplace=True)
        
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.relu2 = nn.ReLU(inplace=True)
        
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)
        
    
    def _make_layer(self, planes, blocks, stride=1):
        

In [12]:
class LayerNorm(nn.LayerNorm):
    
    def forward(self, x:torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)

In [None]:
class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        # 这里是利用了sigmoid对GELU的计算进行了简化？
        return x * torch.sigmoid(1.702 * x)

In [None]:
class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model, n_head, attn_mask=None):
        super(ResidualAttentionBlock, self).__init__()
        
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
        
    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] # 三个x分别为q，k，v
    
    def forward(self, x: torch.Tensor):
        # openai的clip中使用的transformer模型是将正则化放在输入数据之前
        # 与常规的atten->add->norm->feedforward->add->norm写法有所不同
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
    
class Transformer(nn.Module):
    # 在NLP里叫d_model，在CV里叫width
    def __init__(self, d_model: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.d_model = d_model
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(d_model, heads, attn_mask) for _ in range(layers)])
        
    def forward(self, x: torch.Tensor):
        return self.resblocks(x)

In [None]:
class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: 
                 int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # 卷积核大小为一个patch的大小，卷积核个数与输入attention层中的d_model维度相同
        # 例如512*512大小的图像，使用width个大小为128的卷积核，得到的数据就是b*width*4*4
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
        
        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)
        
        self.transformer = Transformer(width, layers, heads)
        
        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
        
    def forward(self, x:torch.Tensor):
        x = self.conv1(x) # shape=[*,width,grid,grid]
        x = x.reshape(x.shape[0], x.shape[1], -1) # shape=[*,width,grid**2]
        x = x.permute(0, 2, 1) # shape=[*,grid**2,width] width表示这个patch的特征维度
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)
        # shape = [*, grid**2+1, width]
        x = x + self.positional_embedding.to(x.dtype) #这里维度不一样相加可能有问题？
        x = self.ln_pre(x)
        
        x = x.permute(1, 0, 2) #这里为啥要转换一下维度？
        x = self.transformer(x)
        x = x.permute(1, 0, 2)
        
        x = self.ln_post(x)
        if self.proj is not None:
            x = x @ self.proj

        return x
        
        

In [None]:
class CLIP(nn.Module):
    def __init__(self, 
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int):
        super().__init__()
        self.context_length = context_length

In [17]:
input_resolution = 512
patch_size = 128
width = 768
scale = 128 ** -0.5

positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
positional_embedding.shape


torch.Size([17, 768])

In [19]:
b = torch.zeros((8,17,168))
b + positional_embedding

RuntimeError: The size of tensor a (168) must match the size of tensor b (768) at non-singleton dimension 2