In [None]:
# @title Notes
# https://miro.com/app/board/uXjVGaELrhQ=/

In [2]:
# @title Imports
import torch
import torch.nn as nn

In [3]:
# @title Global variables

In [41]:
# @title Efficient self attention
class EfficientSelfAttention(nn.Module):
  """Multi-head self-attention with spatially reduced keys/values (SegFormer-style).

  Queries are computed for every token. Keys and values are computed from the
  token grid after a strided conv (kernel_size = stride = ``reduction_ratio``)
  has downsampled it, so each query attends to floor(H/r) * floor(W/r) tokens
  instead of H * W.

  Args:
    num_heads: number of attention heads; must divide ``dim``.
    reduction_ratio: spatial reduction factor r for keys/values. Note that the
      reduction conv and LayerNorm are applied even when r == 1 (as a 1x1 conv).
    dim: token embedding width.
    proj_drop: dropout probability applied after the output projection
      (default 0.1, the value that was previously hard-coded).

  Raises:
    ValueError: if ``dim`` is not divisible by ``num_heads``.
  """
  def __init__(self,num_heads,reduction_ratio,dim,proj_drop=0.1):
    super().__init__()
    # Fail fast with a clear message instead of an opaque reshape error in forward().
    if dim % num_heads != 0:
      raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
    self.num_heads=num_heads
    self.reduction_ratio=reduction_ratio
    # Creation order is kept identical to preserve seeded weight initialisation.
    self.q_proj=nn.Linear(dim,dim)
    self.kv_proj=nn.Linear(dim,dim*2)
    self.out_proj=nn.Linear(dim,dim)
    self.norm=nn.LayerNorm(dim)
    self.sr=nn.Conv2d(dim,dim,kernel_size=reduction_ratio,stride=reduction_ratio)
    self.scale=(dim//num_heads)**-0.5
    self.out_drop=nn.Dropout(proj_drop)

  def forward(self,x,H,W):
    """Attend over the token grid.

    Args:
      x: tokens of shape (B, N, C) where N == H * W (required by the reshape
        onto an H x W grid below).
      H, W: spatial size of the token grid.

    Returns:
      Tensor of shape (B, N, C).
    """
    B, N, C = x.shape
    head_dim=C//self.num_heads
    # (B, N, C) -> (B, heads, N, head_dim)
    q=self.q_proj(x).reshape(B,N,self.num_heads,head_dim).permute(0,2,1,3)

    # Downsample the token grid to build a shorter key/value sequence.
    grid=x.permute(0,2,1).reshape(B, C, H, W)
    reduced=self.sr(grid).reshape(B,C,-1).permute(0,2,1)  # (B, N_reduced, C)
    reduced=self.norm(reduced)
    # (B, N_reduced, 2C) -> (2, B, heads, N_reduced, head_dim)
    kv=self.kv_proj(reduced).reshape(B,-1,2,self.num_heads,head_dim).permute(2,0,3,1,4)
    k,v=kv[0],kv[1]

    attn=(q@k.transpose(-1,-2))*self.scale
    attn=attn.softmax(dim=-1)
    out=(attn@v).transpose(1,2).reshape(B,N,C)
    out=self.out_proj(out)
    return self.out_drop(out)

In [42]:
# @title Feed Forward
class FeedForward(nn.Module):
  """Mix-FFN: Linear -> 3x3 conv over the token grid -> GELU -> Linear -> Dropout.

  The output does NOT include a residual connection; the caller adds the skip
  connection (``EncoderBlock`` does ``x = x + self.ffn(self.norm2(x), H, W)``).
  Adding the input here as well would double-count the residual.

  Args:
    n_embd: token embedding width (input and output channels).
    hidden_features: width of the hidden expansion.
    depthwise: if True the 3x3 conv is depthwise (groups=hidden_features), as in
      the SegFormer paper's Mix-FFN. Defaults to False to keep the existing dense
      conv and its parameter shapes.
    dropout: dropout probability applied to the output (default 0.1).
  """
  def __init__(self, n_embd,hidden_features,depthwise=False,dropout=0.1):
    super().__init__()
    # Creation order is kept identical to preserve seeded weight initialisation.
    self.fc1=nn.Linear(n_embd,hidden_features)
    self.fc2=nn.Linear(hidden_features,n_embd)
    self.drop=nn.Dropout(dropout)
    self.act=nn.GELU()
    groups=hidden_features if depthwise else 1
    self.conv=nn.Conv2d(hidden_features,hidden_features,kernel_size=3,stride=1,padding=1,groups=groups)

  def forward(self,x_in,H,W):
    """Apply the FFN to tokens laid out on an H x W grid.

    Args:
      x_in: tokens of shape (B, N, n_embd) with N == H * W.
      H, W: spatial size of the token grid.

    Returns:
      Tensor of shape (B, N, n_embd) -- the FFN output only (no residual).

    Raises:
      ValueError: if N != H * W.
    """
    B,N,_=x_in.shape
    if N != H*W:
      raise ValueError(f"token count {N} does not match grid {H}x{W}")
    x=self.fc1(x_in)
    # Tokens -> feature map so the 3x3 conv can mix spatial neighbours.
    x=x.transpose(1,2).reshape(B,-1,H,W)
    x=self.conv(x)
    x=x.flatten(2).transpose(1,2)
    x=self.act(x)
    x=self.fc2(x)
    return self.drop(x)

In [43]:
# @title Encoder Block
class EncoderBlock(nn.Module):
    """Pre-norm transformer block: efficient self-attention then FFN, each added residually.

    Args:
        dim: token embedding width.
        input_channels: unused; kept so existing callers (e.g. ``Encoder``) that
            pass it keep working.
        r: spatial reduction ratio forwarded to ``EfficientSelfAttention``.
        num_heads: number of attention heads (default 8, the previously
            hard-coded value). Must divide ``dim``.
        mlp_ratio: FFN hidden width as a multiple of ``dim`` (default 4, the
            previously hard-coded value).
    """
    def __init__(self, dim,input_channels,r,num_heads=8,mlp_ratio=4):
        super().__init__()
        self.attention=EfficientSelfAttention(num_heads=num_heads,reduction_ratio=r,dim=dim)
        self.ffn=FeedForward(dim,int(dim*mlp_ratio))
        self.norm1=nn.LayerNorm(dim)
        self.norm2=nn.LayerNorm(dim)

    def forward(self,x,H,W):
        """Run the block on tokens laid out on an H x W grid.

        Args:
            x: tokens, presumably (B, H*W, dim) -- the attention sub-module
                reshapes them onto an H x W grid.
            H, W: spatial size of the token grid.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x=x+self.attention(self.norm1(x),H,W)
        x=x+self.ffn(self.norm2(x),H,W)
        return x

In [44]:
# @title Overlap patch Embeddings
class OverlapPatchEmbeddings(nn.Module):
  """Overlapping patch embedding: a strided conv followed by LayerNorm over tokens.

  Because kernel_size may exceed stride, neighbouring patches overlap.

  Args:
    input_channels: channels of the incoming feature map / image.
    embed_dim: channels (token width) produced by the conv.
    kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
  """
  def __init__(self,input_channels,embed_dim,kernel_size,stride,padding):
    super().__init__()
    self.norm1=nn.LayerNorm(embed_dim)
    self.conv=nn.Conv2d(
        in_channels=input_channels,
        out_channels=embed_dim,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )

  def forward(self,x):
    """Embed a (B, C_in, H_in, W_in) map into tokens.

    Returns:
      (tokens, H, W) where tokens has shape (B, H*W, embed_dim) and H, W are the
      spatial dimensions of the conv output.
    """
    feat=self.conv(x)
    height,width=feat.shape[-2:]
    # (B, E, H, W) -> (B, H*W, E), then normalise each token.
    tokens=self.norm1(feat.flatten(start_dim=2).transpose(1,2))
    return tokens,height,width

In [69]:
# @title Encoder
class Encoder(nn.Module):
  """Four-stage hierarchical transformer encoder.

  Each stage i applies an overlapping patch embedding (stage 1: kernel 7,
  stride 4, padding 3; stages 2-4: kernel 3, stride 2, padding 1), then
  ``depths[i]`` EncoderBlocks, then a LayerNorm. The tokens are then reshaped
  back into a (B, embed_dims[i], H_i, W_i) feature map that feeds the next stage.

  Args:
    embed_dims: channel width of each of the 4 stages.
    in_chans: channels of the input image.
    sr: spatial reduction ratio of the attention in each stage.
    depths: number of EncoderBlocks in each stage.

  Raises:
    ValueError: if embed_dims, sr or depths do not have exactly 4 entries.
  """
  def __init__(self,embed_dims=(64, 128, 256, 512),in_chans=3,sr=(8, 4, 2, 1),depths=(3, 4, 6, 3)):
    super().__init__()
    # Tuple defaults avoid shared mutable default arguments.
    for name, values in (("embed_dims", embed_dims), ("sr", sr), ("depths", depths)):
      if len(values) != 4:
        raise ValueError(f"{name} must have 4 entries (one per stage), got {len(values)}")

    # Attribute names and creation order are unchanged so state_dict keys and
    # seeded initialisation stay compatible.
    self.patch_embed1= OverlapPatchEmbeddings(kernel_size=7, stride=4,padding=3, input_channels=in_chans,
                                              embed_dim=embed_dims[0])
    self.patch_embed2= OverlapPatchEmbeddings(kernel_size=3, stride=2,padding=1, input_channels=embed_dims[0],
                                              embed_dim=embed_dims[1])
    self.patch_embed3= OverlapPatchEmbeddings(kernel_size=3, stride=2,padding=1, input_channels=embed_dims[1],
                                              embed_dim=embed_dims[2])
    self.patch_embed4= OverlapPatchEmbeddings(kernel_size=3, stride=2,padding=1, input_channels=embed_dims[2],
                                              embed_dim=embed_dims[3])

    self.block1=nn.ModuleList([EncoderBlock(dim=embed_dims[0],input_channels=in_chans,r=sr[0]) for _ in range(depths[0])])
    self.block2=nn.ModuleList([EncoderBlock(dim=embed_dims[1],input_channels=embed_dims[0],r=sr[1]) for _ in range(depths[1])])
    self.block3=nn.ModuleList([EncoderBlock(dim=embed_dims[2],input_channels=embed_dims[1],r=sr[2]) for _ in range(depths[2])])
    self.block4=nn.ModuleList([EncoderBlock(dim=embed_dims[3],input_channels=embed_dims[2],r=sr[3]) for _ in range(depths[3])])
    self.norm1=nn.LayerNorm(embed_dims[0])
    self.norm2=nn.LayerNorm(embed_dims[1])
    self.norm3=nn.LayerNorm(embed_dims[2])
    self.norm4=nn.LayerNorm(embed_dims[3])
    # Record the depths actually used (previously hard-coded to [3, 4, 6, 3]).
    self.depths=list(depths)

  def forward(self,x):
    """Run all four stages.

    Args:
      x: input image batch of shape (B, in_chans, H, W).

    Returns:
      List of 4 feature maps; entry i has shape (B, embed_dims[i], H_i, W_i).
    """
    B=x.shape[0]
    outs=[]
    stages=(
        (self.patch_embed1,self.block1,self.norm1),
        (self.patch_embed2,self.block2,self.norm2),
        (self.patch_embed3,self.block3,self.norm3),
        (self.patch_embed4,self.block4,self.norm4),
    )
    for patch_embed,blocks,norm in stages:
      x,H,W=patch_embed(x)
      for blk in blocks:
        x=blk(x,H,W)
      x=norm(x)
      # Tokens (B, H*W, C) -> feature map (B, C, H, W) for the next stage / decoder.
      x=x.reshape(B,H,W,-1).permute(0,3,1,2)
      outs.append(x)
    return outs

In [70]:
# @title Decoder
class Decoder(nn.Module):
    """Lightweight decoder head: project, upsample, concatenate, fuse, classify.

    Each encoder feature map is projected to ``decoder_dim`` channels with a 1x1
    conv and bilinearly resized to ``out_size``. The results are concatenated
    along channels, fused back to ``decoder_dim`` with a 1x1 conv, and mapped to
    ``num_classes`` logits with a final 1x1 conv.

    Args:
        embed_dims: channel count of each incoming feature map (one per stage).
        decoder_dim: common channel width used inside the decoder.
        num_classes: number of output logit channels.
        out_size: (H, W) spatial size of the output logits.
    """
    def __init__(self, embed_dims, decoder_dim, num_classes, out_size):
        super().__init__()
        # Derive the stage count from embed_dims instead of hard-coding 4.
        self.num_stages = len(embed_dims)

        self.proj = nn.ModuleList([
            nn.Conv2d(in_dim, decoder_dim, kernel_size=1)
            for in_dim in embed_dims
        ])

        self.fuse = nn.Conv2d(decoder_dim * self.num_stages, decoder_dim, kernel_size=1)
        self.cls = nn.Conv2d(decoder_dim, num_classes, kernel_size=1)
        self.upsample = nn.Upsample(size=out_size, mode='bilinear', align_corners=False)

    def forward(self, feats):
        """Decode a list of feature maps into per-pixel class logits.

        Args:
            feats: sequence of tensors, one per stage; feats[i] has
                embed_dims[i] channels, presumably shaped (B, C_i, H_i, W_i).

        Returns:
            Logits of shape (B, num_classes, *out_size).

        Raises:
            ValueError: if the number of feature maps differs from len(embed_dims).
        """
        if len(feats) != len(self.proj):
            raise ValueError(f"expected {len(self.proj)} feature maps, got {len(feats)}")
        upsampled = [self.upsample(proj(feat)) for proj, feat in zip(self.proj, feats)]

        x = torch.cat(upsampled, dim=1)
        x = self.fuse(x)
        x = self.cls(x)
        return x

In [71]:
# @title SegFormer
class SegFormer(nn.Module):
    """Encoder + decoder segmentation model.

    Args:
        embed_dims: per-stage channel widths. These are shared by the encoder and
            the decoder so their channel counts always agree.
        decoder_dim: channel width inside the decoder.
        num_classes: number of output logit channels.
        out_size: (H, W) size of the output logits. This is whatever the caller
            passes, not necessarily the input resolution.
        in_chans: channels of the input image (default 3).
        sr: per-stage attention spatial-reduction ratios (encoder default).
        depths: per-stage number of encoder blocks (encoder default).
    """
    def __init__(self, embed_dims, decoder_dim, num_classes, out_size,
                 in_chans=3, sr=(8, 4, 2, 1), depths=(3, 4, 6, 3)):
        super().__init__()
        # Pass embed_dims through: building Encoder() with its defaults made any
        # non-default embed_dims mismatch the decoder's 1x1 projections.
        self.encoder = Encoder(embed_dims=embed_dims, in_chans=in_chans, sr=sr, depths=depths)
        self.decoder = Decoder(embed_dims, decoder_dim, num_classes, out_size)

    def forward(self, x):
        """Map an image batch to class logits of spatial size ``out_size``."""
        feats = self.encoder(x)
        return self.decoder(feats)

In [73]:
# Smoke test: run one random image through the model and check the logit shape.
torch.manual_seed(0)  # reproducible weights and input
IMG_SIZE = 512
NUM_CLASSES = 19
OUT_SIZE = (IMG_SIZE // 4, IMG_SIZE // 4)  # logits at 1/4 of the input resolution

demo_image = torch.randn(1, 3, IMG_SIZE, IMG_SIZE)
model = SegFormer(embed_dims=[64, 128, 256, 512], decoder_dim=256,
                  num_classes=NUM_CLASSES, out_size=OUT_SIZE)
model.eval()  # no_grad() alone does not disable dropout
with torch.no_grad():
  logits = model(demo_image)
assert logits.shape == (1, NUM_CLASSES, *OUT_SIZE)
print(logits.shape)

yes
torch.Size([1, 19, 128, 128])
