In [2]:
from collections import OrderedDict
from typing import Union, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

In [29]:
# Scratch cell: build a stand-alone multi-head attention layer to experiment with.
# 768 channels split over 12 heads gives 64 channels per head.
d_model, n_head = 768, 12
model = nn.MultiheadAttention(d_model, n_head)

In [3]:
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1 convolutions) with a residual shortcut.

    Every convolution runs with stride 1. When ``stride > 1`` the spatial
    reduction is performed by an average pool placed after the second
    convolution, and by a matching average pool at the start of the shortcut.

    Args:
        inplanes: number of input channels.
        planes: number of channels of the two inner convolutions; the block
            outputs ``planes * expansion`` channels.
        stride: spatial down-sampling factor, applied through average pooling.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        # Down-sampling is done by pooling rather than by a strided convolution.
        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # The shortcut must produce the same shape as the main branch:
            # pool to the same resolution, then project to the same channel count.
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        """Run the three convolutions, add the shortcut and apply the final ReLU.

        Args:
            x: input of shape (N, inplanes, H, W).

        Returns:
            Tensor of shape (N, planes * expansion, H // stride, W // stride).
        """
        identity = x

        # Each stage consumes the previous stage's output. The original code fed
        # ``x`` into conv2 and conv3, which threw away the earlier layers and
        # failed with a channel mismatch whenever inplanes != planes.
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            # Bring the untouched input to the shape of the three-conv output.
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out
        

In [41]:
class AttentionPool2d(nn.Module):
    """Pool a 2-D feature map into one vector per image with multi-head attention.

    The mean over all spatial positions is prepended as an extra token; that
    token is the only query and attends over every token (itself included).

    Args:
        spacial_dim: side length of the (square) input feature map.
        embed_dim: number of input channels, used as the attention width.
        num_heads: number of attention heads.
        output_dim: size of the pooled output; falls back to ``embed_dim``
            when ``None`` (or 0).
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # One learned position per spatial location, plus one for the mean token.
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x: torch.Tensor):
        """Pool ``x`` of shape (N, C, H, W) into a tensor of shape (N, output_dim)."""
        # Flatten from the third dimension on: NCHW -> NC(HW) -> (HW)NC.
        x = x.flatten(start_dim=2).permute(2, 0, 1)
        # Prepend the mean token -> (HW+1)NC.
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)
        # Insert a singleton batch axis so the embedding broadcasts over N.
        x = x + self.positional_embedding[:, None, :].to(x.dtype)
        x, _ = F.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
            embed_dim_to_check=x.shape[-1],  # the channel axis is the embedding axis
            num_heads=self.num_heads,
            # The functional API expects weight tensors here, not nn.Linear modules.
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            # Biases are still passed packed as [q, k, v], i.e. 3 * embed_dim values.
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0.0,
            # Output projection applied to the attention result; it maps
            # embed_dim -> output_dim.
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            # Use the three separate q/k/v weights above instead of a single
            # packed in_proj_weight.
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        # Drop the length-1 query axis: (1, N, output_dim) -> (N, output_dim).
        return x.squeeze(0)
        
        

In [40]:
# Scratch cell: check the shape obtained by concatenating the bias vectors of
# two 64-unit linear layers (64 + 64 -> 128).
q = nn.Linear(in_features=64, out_features=64)
k = nn.Linear(in_features=64, out_features=64)
bias = torch.cat((q.bias, k.bias), dim=0)
bias.shape

torch.Size([128])

In [None]:
class ModifiedResNet(nn.Module):
    """ResNet-style image encoder ending in an attention pool.

    Differences from a plain ResNet that are visible in this class:

    * the stem is three 3x3 convolutions followed by a 2x2 average pool;
    * the output is produced by ``AttentionPool2d`` on the last feature map.

    The stem convolution (stride 2), the stem pool (2) and the three stride-2
    stages reduce the resolution by 32 overall, which is why the attention
    pool is sized with ``input_resolution // 32``.

    Args:
        layers: number of ``Bottleneck`` blocks in each of the four stages.
        output_dim: size of the returned embedding.
        heads: number of heads of the attention pool.
        input_resolution: side length of the (square) input image.
        width: channel count produced by the stem.
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # --- stem: 3 -> width/2 -> width/2 -> width channels, then 2x2 pooling ---
        half_width = width // 2
        self.conv1 = nn.Conv2d(3, half_width, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(half_width)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(half_width, half_width, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(half_width)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(half_width, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # --- residual stages; _inplanes tracks the running channel count ---
        self._inplanes = width
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        # layer4 emits (width * 8) * Bottleneck.expansion = width * 32 channels.
        self.attnpool = AttentionPool2d(input_resolution // 32, width * 32, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        """Build one stage: a (possibly strided) first block, then ``blocks - 1`` plain ones."""
        first_block = Bottleneck(self._inplanes, planes, stride)
        # Every later block, and the next stage, sees the expanded channel count.
        self._inplanes = planes * Bottleneck.expansion
        remaining = [Bottleneck(self._inplanes, planes) for _ in range(1, blocks)]
        return nn.Sequential(first_block, *remaining)

    def _stem(self, x):
        """Apply the three conv/bn/relu triples and the stem average pool."""
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        return self.avgpool(x)

    def forward(self, x):
        """Encode a batch of images; returns whatever the attention pool produces."""
        # Match the input dtype to the weights' dtype.
        x = x.type(self.conv1.weight.dtype)
        x = self._stem(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return self.attnpool(x)

In [12]:
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that normalises in float32 and returns the caller's dtype."""

    def forward(self, x: torch.Tensor):
        """Normalise ``x`` in float32, then cast the result back to ``x``'s dtype."""
        input_dtype = x.dtype
        normalised = super().forward(x.float())
        return normalised.to(input_dtype)

In [None]:
class QuickGELU(nn.Module):
    """Activation computing ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        # A sigmoid gate stands in for the exact GELU formula (a cheaper approximation).
        gate = torch.sigmoid(1.702 * x)
        return x * gate

In [None]:
class ResidualAttentionBlock(nn.Module):
    """Transformer block with the layer norms placed before each sub-layer.

    Computes ``x + attn(ln_1(x))`` followed by ``x + mlp(ln_2(x))``. This
    differs from the usual attention -> add -> norm -> feed-forward -> add ->
    norm ordering, where the norm comes after each residual addition.

    Args:
        d_model: width of the token embeddings.
        n_head: number of attention heads.
        attn_mask: optional mask forwarded to the attention layer.
    """

    def __init__(self, d_model, n_head, attn_mask=None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)

        # Feed-forward network with a 4x wider hidden layer.
        hidden_dim = d_model * 4
        feed_forward = OrderedDict()
        feed_forward["c_fc"] = nn.Linear(d_model, hidden_dim)
        feed_forward["gelu"] = QuickGELU()
        feed_forward["c_proj"] = nn.Linear(hidden_dim, d_model)
        self.mlp = nn.Sequential(feed_forward)

        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Self-attention over ``x`` (used as query, key and value)."""
        if self.attn_mask is not None:
            # Keep the stored mask on the same dtype/device as the activations.
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        attended, _ = self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)
        return attended

    def forward(self, x: torch.Tensor):
        """Apply the attention and feed-forward sub-layers, each with a residual connection."""
        x = x + self.attention(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
    
class Transformer(nn.Module):
    """A stack of ``layers`` identical ``ResidualAttentionBlock`` modules.

    Args:
        d_model: embedding width (called ``d_model`` in NLP code and ``width``
            in vision code).
        layers: number of residual attention blocks.
        heads: number of attention heads in each block.
        attn_mask: optional attention mask handed to every block.
    """

    def __init__(self, d_model: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.d_model = d_model
        self.layers = layers
        blocks = [ResidualAttentionBlock(d_model, heads, attn_mask) for _ in range(layers)]
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        """Pass ``x`` through every block in order."""
        return self.resblocks(x)

In [11]:
class VisionTransformer(nn.Module):
    """ViT-style image encoder returning one embedding per image.

    The image is cut into non-overlapping patches by a strided convolution, a
    learned class token is prepended, and after the transformer only that
    class token is normalised and projected to ``output_dim``.

    Args:
        input_resolution: side length of the (square) input image.
        patch_size: side length of one patch (kernel size and stride of ``conv1``).
        width: token embedding width.
        layers: number of transformer blocks.
        heads: number of attention heads per block.
        output_dim: size of the returned embedding.
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int,
                 heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # The kernel covers exactly one patch and there are `width` kernels, so
        # each patch becomes one `width`-dimensional token. For example, a
        # 512x512 image with 128x128 patches gives a b x width x 4 x 4 tensor.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5  # 1/sqrt(width)
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        """Encode images of shape (N, 3, H, W) into embeddings of shape (N, output_dim)."""
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid**2]
        x = x.permute(0, 2, 1)  # shape = [*, grid**2, width]; width is the per-patch feature size

        # Adding the (width,) class embedding to zeros of shape [*, 1, width]
        # broadcasts it into one class token per image.
        class_token = self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([class_token, x], dim=1)  # shape = [*, grid**2 + 1, width]

        # The [grid**2 + 1, width] embedding broadcasts over the leading batch axis.
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        # nn.MultiheadAttention is built with its default sequence-first layout,
        # so move the token axis to the front: [grid**2 + 1, *, width].
        x = x.permute(1, 0, 2)
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # back to [*, grid**2 + 1, width]

        # Keep only the class token as the image representation, so this encoder
        # returns one vector per image just like ModifiedResNet does.
        x = self.ln_post(x[:, 0, :])
        if self.proj is not None:
            x = x @ self.proj

        return x
        
        

In [None]:
class CLIP(nn.Module):
    """Image encoder, text transformer and the parameters shared by both towers.

    The visual tower is a ``ModifiedResNet`` when ``vision_layers`` is a
    tuple/list (blocks per stage) and a ``VisionTransformer`` when it is a
    single integer (number of transformer layers).

    Args:
        embed_dim: size of the joint embedding space.
        image_resolution: side length of the (square) input images.
        vision_layers: per-stage block counts (ResNet) or layer count (ViT).
        vision_width: base width of the visual tower.
        vision_patch_size: patch size, only used by the ViT tower.
        context_length: number of text tokens per sequence.
        vocab_size: size of the token vocabulary.
        transformer_width: embedding width of the text transformer.
        transformer_heads: attention heads of the text transformer.
        transformer_layers: number of blocks of the text transformer.
    """

    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            # ModifiedResNet feeds its attention pool with width * 32 channels;
            # dividing by 64 gives heads of 64 channels each.
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            # Same 64-channels-per-head rule, applied to the ViT width directly.
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        # Transformer's first parameter is named d_model (not width), and
        # attn_mask must be passed as a keyword argument (=, not ==).
        self.transformer = Transformer(
            d_model=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # Do the scaling inside nn.Parameter: multiplying a Parameter yields a
        # plain tensor that would not be registered on the module.
        self.logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())

        # positional_embedding and text_projection are allocated with
        # torch.empty, so they hold arbitrary memory until initialised.
        self.initialize_parameters()

    def initialize_parameters(self):
        """(Re-)initialise the embeddings, the attention-pool projections and the text projection."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                attnpool = self.visual.attnpool
                std = attnpool.c_proj.in_features ** -0.5
                # normal_ works on tensors, so initialise each projection weight.
                for proj in (attnpool.q_proj, attnpool.k_proj, attnpool.v_proj, attnpool.c_proj):
                    nn.init.normal_(proj.weight, std=std)

        # Rows of text_projection correspond to transformer_width.
        nn.init.normal_(self.text_projection, std=self.text_projection.shape[0] ** -0.5)

    def build_attention_mask(self):
        """Return a (context_length, context_length) additive causal mask.

        Entries strictly above the diagonal are ``-inf``; the diagonal and
        everything below it are 0.
        """
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # keep only the part above the diagonal, zero the rest
        return mask

    @property
    def dtype(self):
        """dtype of the visual tower's first convolution weights."""
        # Both ModifiedResNet and VisionTransformer expose a ``conv1`` layer.
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        """Cast ``image`` to the model dtype and run it through the visual tower."""
        return self.visual(image.type(self.dtype))

In [17]:
# Scratch cell: check the shape of a ViT-style positional embedding.
input_resolution, patch_size, width = 512, 128, 768
scale = 128 ** -0.5

# One row per patch of the (input_resolution // patch_size) ** 2 grid, plus one extra row.
positional_embedding = nn.Parameter(
    torch.randn((input_resolution // patch_size) ** 2 + 1, width) * scale
)
positional_embedding.shape


torch.Size([17, 768])

In [19]:
# Scratch cell: try adding the (17, 768) positional embedding to a batched tensor.
# The last dimension used here (168) differs from the embedding's 768, so the
# addition cannot broadcast and raises the RuntimeError recorded in the output.
# NOTE(review): 168 may be a typo for 768 - with matching last dimensions the
# embedding would broadcast over the leading batch axis; confirm the intent.
b = torch.zeros((8,17,168))
b + positional_embedding

RuntimeError: The size of tensor a (168) must match the size of tensor b (768) at non-singleton dimension 2

In [9]:
# Scratch cell: triu_() zeroes, in place, everything below the main diagonal.
a = torch.tensor(
    [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]],
    dtype=torch.int32,
)
a.triu_()

tensor([[1, 2, 3],
        [0, 5, 6],
        [0, 0, 9]], dtype=torch.int32)