In [1]:
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.6/41.6 KB[0m [31m594.4 kB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.0


In [2]:
import random
import numpy as np
import torch
import cv2
import math
import torch.nn.functional as F
from torch import nn
from skimage import transform
from einops import rearrange, repeat
from einops.layers.torch import Rearrange


# helpers
def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

In [3]:
class OurFE(nn.Module):
    """Convolutional feature extractor that keeps the channel count unchanged.

    Three 1x1 Conv-BatchNorm-ReLU stages are chained; the outputs of all three
    stages are concatenated along the channel axis and fused back to
    ``channel`` channels by a 3x3 Conv-BatchNorm-ReLU stage.

    Args:
        channel: number of input (and output) channels.
        dim: accepted for interface compatibility; not used by this module.
    """

    def __init__(self, channel, dim):
        super().__init__()
        self.conv1 = self._pointwise_stage(channel)
        self.conv2 = self._pointwise_stage(channel)
        self.conv3 = self._pointwise_stage(channel)
        self.out_conv = nn.Sequential(
            nn.Conv2d(3 * channel, channel, kernel_size=3, padding=1),
            nn.BatchNorm2d(channel),
            nn.ReLU(),
        )

    @staticmethod
    def _pointwise_stage(channel):
        """Build one 1x1 Conv -> BatchNorm -> ReLU stage on ``channel`` channels."""
        return nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=1),
            nn.BatchNorm2d(channel),
            nn.ReLU(),
        )

    def forward(self, x):
        """Run the three chained stages and fuse their concatenated outputs."""
        stage_outputs = []
        feat = x
        for stage in (self.conv1, self.conv2, self.conv3):
            feat = stage(feat)
            stage_outputs.append(feat)
        return self.out_conv(torch.cat(stage_outputs, dim=1))

In [4]:
class PreNorm(nn.Module):
    """Apply ``nn.LayerNorm(dim)`` to the input, then hand the result to ``fn``.

    Args:
        dim: size of the normalised (last) dimension.
        fn: callable invoked on the normalised tensor; its result is returned.
    """

    def __init__(self, dim, fn):
        super(PreNorm, self).__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        normalized = self.norm(x)
        return self.fn(normalized)

In [5]:
class FeedForward(nn.Module):
    """Convolutional feed-forward block operating on a token sequence.

    The ``(b, n, c)`` token sequence is reshaped to a square ``(b, c, w, w)``
    grid with ``w = int(sqrt(n))``, passed through a depthwise-separable conv,
    BatchNorm and two 1x1 convs with GELU (``dim -> 256 -> 512 -> dim``), then
    flattened back to tokens and added to the input.

    Args:
        dim: channel size of each token.
    """

    def __init__(self, dim):
        super(FeedForward, self).__init__()
        layers = [
            DEPTHWISECONV(dim, 256, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(in_channels=512, out_channels=dim, kernel_size=1),
            nn.GELU(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # x must be 3-D: (batch, tokens, channels).
        _, num_tokens, _ = x.shape
        side = int(math.sqrt(num_tokens))
        # rearrange rejects the input when side * side != num_tokens.
        grid = rearrange(x, 'b (w h) c -> b c w h', w=side, h=side)
        grid = self.net(grid)
        tokens = rearrange(grid, 'b c w h -> b (w h) c')
        # Residual connection around the convolutional stack.
        return x + tokens

In [6]:
class DEPTHWISECONV(nn.Module):
    """Depthwise convolution optionally followed by a 1x1 pointwise convolution.

    Args:
        in_ch: input channels; the depthwise conv uses ``groups=in_ch``.
        out_ch: output channels of the pointwise conv.
        kernel_size, padding, stride: settings of the depthwise conv.
        is_fe: when True, ``forward`` returns the depthwise output directly and
            the pointwise conv (still created) is not applied.
    """

    def __init__(self, in_ch, out_ch, kernel_size=1, padding=0, stride=1, is_fe=False):
        super().__init__()
        self.is_fe = is_fe
        # One filter group per input channel.
        self.depth_conv = nn.Conv2d(
            in_ch, in_ch, kernel_size,
            stride=stride, padding=padding, groups=in_ch,
        )
        # 1x1 conv mixing channels: in_ch -> out_ch.
        self.point_conv = nn.Conv2d(
            in_ch, out_ch, 1,
            stride=1, padding=0, groups=1,
        )

    def forward(self, input):
        depthwise = self.depth_conv(input)
        return depthwise if self.is_fe else self.point_conv(depthwise)

In [19]:
class Attention(nn.Module):
    """Sum of a multi-head token attention and a single-head feature attention.

    Spatial branch: standard multi-head attention over the ``n`` tokens; the
    softmax scores ``(b, heads, n, n)`` are passed through a 3x3 conv before
    weighting the values.

    Spectral branch: a single head whose score matrix is formed between
    feature channels (``q^T k``, shape ``(b, 1, dim, dim)``), passed through a
    3x3 conv, applied to ``v^T`` and transposed back to ``(b, n, dim)``.

    Both branches share the scale ``dim_head ** -0.5``; their outputs are added.

    Args:
        dim: channel size of each input token.
        heads: number of heads in the spatial branch.
        dim_head: per-head width in the spatial branch.
        dropout: dropout probability after the spatial output projection.
        num_patches: accepted but not used by this module.

    Note: ``spatial_norm`` and ``spectral_norm`` are created but not applied
    in ``forward``. The shape-tracing prints are part of the forward pass.
    """

    def __init__(self, dim, heads=4, dim_head=64, dropout=0., num_patches=10):
        super(Attention, self).__init__()
        inner_dim = heads * dim_head
        # The projection is skipped only for one head whose width equals dim.
        project_out = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        if project_out:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout),
            )
        else:
            self.to_out = nn.Identity()

        self.spatial_norm = nn.BatchNorm2d(heads)
        self.spatial_conv = nn.Conv2d(heads, heads, kernel_size=3, padding=1)

        self.spectral_norm = nn.BatchNorm2d(1)
        self.spectral_conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)
        self.to_qkv_spec = nn.Linear(dim, dim * 3, bias=False)
        self.attend_spec = nn.Softmax(dim=-1)

    def forward(self, x):
        # ---- spatial branch: attention between tokens ----
        print(".........spaAttention")
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )
        print(".........q.shape：{} | k.shape:{} | v.shape:{}".format(q.shape, k.shape, v.shape))
        spatial_scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        print(".........dots_spa.shape:", spatial_scores.shape)
        spatial_attn = self.spatial_conv(self.attend(spatial_scores))
        print(".........afetr conv2d attn_spa.shape:", spatial_attn.shape)
        spatial_out = torch.matmul(spatial_attn, v)
        print(".........out_spa.shape:", spatial_out.shape)
        spatial_out = rearrange(spatial_out, 'b h n d -> b n (h d)')
        print(".........after rearrange out.spa.shape:", spatial_out.shape)
        spatial_out = self.to_out(spatial_out)
        print(".........final out.spa.shape:", spatial_out.shape)

        # ---- spectral branch: attention between feature channels ----
        print(".........speAttention")
        print(".........speAttention inputdata x.shape:", x.shape)
        q_spec, k_spec, v_spec = (
            rearrange(part, 'b n (h d) -> b h n d', h=1)
            for part in self.to_qkv_spec(x).chunk(3, dim=-1)
        )

        print(".........q_spec.shape：{} | k_spec.shape:{} | v_spec.shape:{}".format(q_spec.shape, k_spec.shape, v_spec.shape))
        # (d x n) @ (n x d): scores relate feature channels, not tokens.
        spectral_scores = torch.matmul(q_spec.transpose(-1, -2), k_spec) * self.scale
        print(".........dots_spec.shape:", spectral_scores.shape)
        spectral_attn = self.attend_spec(spectral_scores)
        print(".........attend_spec.shape:", spectral_attn.shape)
        spectral_attn = self.spectral_conv(spectral_attn)
        print(".........afetr attend_spec.shape:", spectral_attn.shape)
        # Apply to v^T, drop the singleton head axis, and return to (b, n, d).
        spectral_out = (
            torch.matmul(spectral_attn, v_spec.transpose(-2, -1))
            .squeeze(dim=1)
            .transpose(-2, -1)
        )
        print(".........out_spec.shape:", spectral_out.shape)

        combined = spatial_out + spectral_out
        print(".........out_final.shape:", combined.shape)

        return combined

In [8]:
class Transformer(nn.Module):
    """Stack of ``depth`` encoder layers, each a pre-normed attention block
    followed by a pre-normed feed-forward block, both with a residual add.

    ``forward`` returns ``(x, outputs)``: the final tensor and a list holding
    the tensor produced after every layer (the last entry is ``x`` itself).
    """

    def __init__(self, dim, depth, heads, dim_head, dropout=0., num_patches=25):
        super(Transformer, self).__init__()
        self.layers = nn.ModuleList([])
        self.index = 0
        for _ in range(depth):
            attention = PreNorm(
                dim,
                Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout, num_patches=num_patches),
            )
            feed_forward = PreNorm(dim, FeedForward(dim))
            self.layers.append(nn.ModuleList([attention, feed_forward]))

    def forward(self, x):
        print("...attentionModule input x.shape:", x.shape)
        per_layer = []
        for attention, feed_forward in self.layers:
            x = attention(x) + x
            print("...attentionModule after attn x.shape:", x.shape)
            x = feed_forward(x) + x
            print("...attentionModule after FeedForward x.shape:", x.shape)
            per_layer.append(x)
            print("...the len of output is:", len(per_layer))

        # per_layer keeps the result of every transformer encoder layer.
        return x, per_layer

In [9]:
class SubNet(nn.Module):
    """Container for the parts of one branch: patch embedding, positional
    embedding, dropout and transformer.

    No ``forward`` is defined here; the parts are invoked individually by the
    caller. ``cls_token`` is created but not referenced in this class, and
    ``mlp_dim`` is accepted but unused.

    Args:
        patch_size: kernel size and stride of the depthwise patch-embedding conv.
        num_patches: number of tokens; the positional embedding holds
            ``num_patches + 1`` positions.
        dim: token channel size.
        emb_dropout: dropout probability for ``self.dropout``.
        depth, heads, dim_head, dropout: forwarded to ``Transformer``.
        mlp_dim: unused.
    """

    def __init__(self, patch_size, num_patches, dim, emb_dropout, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()
        # Depthwise-only conv (is_fe=True) with stride == kernel size: one token per patch.
        patch_conv = DEPTHWISECONV(
            in_ch=dim, out_ch=dim,
            kernel_size=patch_size, stride=patch_size,
            padding=0, is_fe=True,
        )
        self.to_patch_embedding = nn.Sequential(
            patch_conv,
            Rearrange('b c w h -> b (h w) c '),
        )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(
            dim, depth, heads, dim_head,
            dropout=dropout, num_patches=num_patches,
        )


In [10]:
def get_num_patches(ps, ks):
    """Return ``int((ps - ks) / ks) + 1``.

    With ``ps`` as the input side length and ``ks`` as both kernel size and
    stride, this is the number of kernel positions along one side.
    """
    steps_after_first = (ps - ks) / ks
    return 1 + int(steps_after_first)

In [11]:
class ViT(nn.Module):
    """Multi-branch transformer classifier with a shared convolutional front end.

    The input passes through ``OurFE``, a 3x3 average pool and a 1x1 conv that
    maps ``channels`` to ``dim``. For every entry of ``patch_size`` a
    ``SubNet`` branch tokenises the shared feature map with its own patch
    size, runs its transformer, and a per-branch head (LayerNorm + Linear on
    the flattened tokens) produces class scores. The branch scores are summed
    using softmax-normalised branch weights.

    Args (keyword-only):
        image_size: spatial side length of the (square) input.
        patch_size: iterable of sub-patch sizes, one branch per entry; a
            single int is also accepted and treated as one branch.
        num_classes: number of output classes.
        dim: token channel size.
        depth, heads, dim_head, dropout: forwarded to each branch's transformer.
        mlp_dim: forwarded to ``SubNet``.
        channels: number of input channels.
        emb_dropout: dropout applied to the embedded tokens.

    Notes:
        * ``weight`` is a non-persistent buffer initialised to ones: it moves
          with the module across devices, is not trained, and is not part of
          ``state_dict``. ``forward`` never modifies it.
        * The shape-tracing ``print`` calls are kept as part of the forward pass.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        super(ViT, self).__init__()
        # Accept a single sub-patch size as well as a sequence of them.
        if isinstance(patch_size, int):
            patch_size = (patch_size,)

        self.ournet = OurFE(channels, dim)
        self.pool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=channels, out_channels=dim, kernel_size=1)
        # Collection of sub-networks. They are only iterated (SubNet defines no
        # forward), so a ModuleList is used; parameter names stay "net.<i>...".
        self.net = nn.ModuleList()
        # Collection of classification heads, one per sub-network.
        self.mlp_head = nn.ModuleList()
        for ps in patch_size:  # sub-patch size
            # E.g. patch_size=[3, 5] creates two sub-networks, each with its
            # own token-embedding (MDTE) module.
            num_patches = get_num_patches(image_size, ps) ** 2
            patch_dim = dim * num_patches
            print("ps:{}, num_patches:{}, dim:{}, patch_dim:{}".format(ps,num_patches,dim,patch_dim))
            sub_net = SubNet(ps, num_patches, dim, emb_dropout, depth, heads, dim_head, mlp_dim, dropout)
            self.net.append(sub_net)
            self.mlp_head.append(nn.Sequential(
                nn.LayerNorm(patch_dim),
                nn.Linear(patch_dim, num_classes)
            ))
        # Raw (pre-softmax) branch weights; non-persistent so state_dict keys
        # are unchanged while the tensor still follows .to()/.cuda().
        self.register_buffer("weight", torch.ones(len(patch_size)), persistent=False)

    def forward(self, img):
        """Return class scores of shape ``(batch, num_classes)`` for ``img``."""
        print("input data img.shape:",img.shape)
        img = self.ournet(img)  # CNN feature extractor
        print("after FE img.shape:",img.shape)
        img = self.pool(img)
        print("after pool img.shape:",img.shape)
        img = self.conv4(img)
        print("after conv4 img.shape:",img.shape)
        print("****************************【FeatureExtraction finish】*******************************")
        all_branch = []
        for ith_layer, sub_branch in enumerate(self.net, start=1):
            # Token embedding (MDTE) module: this step produces the tokens.
            spatial = sub_branch.to_patch_embedding(img)
            print("after to_patch_embedding spatial.shape:",spatial.shape)
            b, n, c = spatial.shape
            print("b:{} | n:{} | c:{}".format(b,n,c))
            print("****************************【tokenlize finish】*******************************")
            spatial = spatial + sub_branch.pos_embedding[:, :n]
            print("(spatial + sub_branch.pos_embedding).shape:",spatial.shape)
            spatial = sub_branch.dropout(spatial)
            # The transformer returns (final tensor, list of per-layer outputs).
            # The final tensor is the last list entry, so take it directly;
            # this also works when depth == 0 and the list is empty.
            res, _ = sub_branch.transformer(spatial)
            print("res.shape",res.shape)
            all_branch.append(res)  # keep every sub-network's result
            print("第{}支子网络运行完毕！".format(ith_layer))

        print("****************************【TransformerEncoder finish】*******************************")
        print("all_branch:",len(all_branch))
        # Normalise into a local tensor; module state is not mutated in forward.
        branch_weights = F.softmax(self.weight, 0)
        res = 0
        for i, mlp_head in enumerate(self.mlp_head):
            out1 = all_branch[i].flatten(start_dim=1)
            print("压缩第{}个子网络的输出，压缩后的维度为：{}".format(int(i+1),str(out1.shape)))
            cls1 = mlp_head(out1)
            print("分类头处理第{}个子网络的压缩输出，处理后的维度为：{}".format(int(i+1),str(cls1.shape)))
            res = res + cls1 * branch_weights[i]
        print("****************************【classification finish】*******************************")
        return res

In [20]:
# Sub-patch sizes; ViT builds one sub-network per entry.
ps=[3, 5]
d_h=[4,2]# [depth, heads]: depth is four, i.e. the transformer encoder is repeated four times
# Disabled variant of the call below with channels=3:
#model = ViT(image_size=15, patch_size=ps, num_classes=16, dim=100, depth=d_h[0], heads=d_h[1],mlp_dim=2048, channels=3, dropout=0.2, emb_dropout=0.2)
# Random input, to check that the network structure runs end to end
x = torch.randn(1, 30, 15, 15)
net = ViT(image_size=15, patch_size=ps, num_classes=16, dim=100, depth=d_h[0], heads=d_h[1],mlp_dim=2048, channels=30, dropout=0.2, emb_dropout=0.2)
y = net(x)
print(y.shape)

ps:3, num_patches:25, dim:100, patch_dim:2500
ps:5, num_patches:9, dim:100, patch_dim:900
input data img.shape: torch.Size([1, 30, 15, 15])
after FE img.shape: torch.Size([1, 30, 15, 15])
after pool img.shape: torch.Size([1, 30, 15, 15])
after conv4 img.shape: torch.Size([1, 100, 15, 15])
****************************【FeatureExtraction finish】*******************************
after to_patch_embedding spatial.shape: torch.Size([1, 25, 100])
b:1 | n:25 | c:100
****************************【tokenlize finish】*******************************
(spatial + sub_branch.pos_embedding).shape: torch.Size([1, 25, 100])
...attentionModule input x.shape: torch.Size([1, 25, 100])
.........spaAttention
.........q.shape：torch.Size([1, 2, 25, 64]) | k.shape:torch.Size([1, 2, 25, 64]) | v.shape:torch.Size([1, 2, 25, 64])
.........dots_spa.shape: torch.Size([1, 2, 25, 25])
.........afetr conv2d attn_spa.shape: torch.Size([1, 2, 25, 25])
.........out_spa.shape: torch.Size([1, 2, 25, 64])
.........after rearrange o