In [64]:
import torch
import torch.nn as nn
# from networks.segformer import *
# For jupyter notebook below
from Transception import *
from EffSegformer import *
from typing import Tuple
from einops import rearrange
from einops.layers.torch import Rearrange
from torch.nn import functional as F

In [65]:
class MiT(nn.Module):
    """Dual-branch Mix-Transformer encoder producing a 4-level feature pyramid.

    Each of the four stages:
      1. embeds its input twice with two ``OverlapPatchEmbeddings_fuse`` layers
         that use different kernel sizes ("branch 1" and "branch 2"),
      2. concatenates both token sequences and runs them through a stack of
         ``EfficientTransformerBlockFuse`` blocks followed by a LayerNorm,
      3. splits the tokens back per branch, resizes the branch-1 map to the
         branch-2 grid, concatenates the two maps on the channel axis and fuses
         them with a 1x1 convolution (2*C -> C).
    The fused map is both a pyramid output and the next stage's input.

    Args:
        image_size: side length of the (square) input image; 224 in this notebook.
        in_dim: per-stage channel widths (4 entries).
        key_dim: per-stage key widths of the transformer blocks (4 entries).
        value_dim: per-stage value widths of the transformer blocks (4 entries).
        layers: number of transformer blocks per stage (4 entries).
        head_count: attention heads, forwarded to every transformer block.
        token_mlp: token-MLP mode string, forwarded to every transformer block.
    """

    def __init__(self, image_size, in_dim, key_dim, value_dim, layers, head_count=1, token_mlp='mix_skip'):
        super().__init__()

        # Nominal per-stage feature-map side lengths (overall strides 4/8/16/32),
        # i.e. [56, 28, 14, 7] for image_size=224. Kept as attributes for callers;
        # forward() resizes to the grid the second branch actually produces,
        # because that is the size the channel concatenation requires.
        self.Hs = [image_size // 4, image_size // 8, image_size // 16, image_size // 32]
        self.Ws = [image_size // 4, image_size // 8, image_size // 16, image_size // 32]

        # Per-stage hyper-parameters of the two patch-embedding branches.
        patch_sizes1 = [7, 3, 3, 3]
        patch_sizes2 = [5, 1, 1, 1]
        strides = [4, 2, 2, 2]
        dil_padding_sizes1 = [3, 0, 0, 0]
        dil_padding_sizes2 = [3, 0, 0, 0]

        # 1x1 convolutions fusing the two concatenated branches (2*C -> C).
        self.conv1_1_s1 = nn.Conv2d(2*in_dim[0], in_dim[0], 1)
        self.conv1_1_s2 = nn.Conv2d(2*in_dim[1], in_dim[1], 1)
        self.conv1_1_s3 = nn.Conv2d(2*in_dim[2], in_dim[2], 1)
        self.conv1_1_s4 = nn.Conv2d(2*in_dim[3], in_dim[3], 1)

        # Patch embeddings: <stage>_1 is branch 1, <stage>_2 is branch 2.
        # Stage 1 consumes the 3-channel image; later stages consume the
        # previous stage's fused map.
        self.patch_embed1_1 = OverlapPatchEmbeddings_fuse(image_size, patch_sizes1[0], strides[0], dil_padding_sizes1[0], 3, in_dim[0])
        self.patch_embed1_2 = OverlapPatchEmbeddings_fuse(image_size, patch_sizes2[0], strides[0], dil_padding_sizes2[0], 3, in_dim[0])

        self.patch_embed2_1 = OverlapPatchEmbeddings_fuse(image_size//4, patch_sizes1[1], strides[1], dil_padding_sizes1[1], in_dim[0], in_dim[1])
        self.patch_embed2_2 = OverlapPatchEmbeddings_fuse(image_size//4, patch_sizes2[1], strides[1], dil_padding_sizes2[1], in_dim[0], in_dim[1])

        self.patch_embed3_1 = OverlapPatchEmbeddings_fuse(image_size//8, patch_sizes1[2], strides[2], dil_padding_sizes1[2], in_dim[1], in_dim[2])
        self.patch_embed3_2 = OverlapPatchEmbeddings_fuse(image_size//8, patch_sizes2[2], strides[2], dil_padding_sizes2[2], in_dim[1], in_dim[2])

        self.patch_embed4_1 = OverlapPatchEmbeddings_fuse(image_size//16, patch_sizes1[3], strides[3], dil_padding_sizes1[3], in_dim[2], in_dim[3])
        self.patch_embed4_2 = OverlapPatchEmbeddings_fuse(image_size//16, patch_sizes2[3], strides[3], dil_padding_sizes2[3], in_dim[2], in_dim[3])

        # Transformer encoder stacks, each followed by a LayerNorm.
        self.block1 = nn.ModuleList([
            EfficientTransformerBlockFuse(in_dim[0], key_dim[0], value_dim[0], head_count, token_mlp)
            for _ in range(layers[0])])
        self.norm1 = nn.LayerNorm(in_dim[0])

        self.block2 = nn.ModuleList([
            EfficientTransformerBlockFuse(in_dim[1], key_dim[1], value_dim[1], head_count, token_mlp)
            for _ in range(layers[1])])
        self.norm2 = nn.LayerNorm(in_dim[1])

        self.block3 = nn.ModuleList([
            EfficientTransformerBlockFuse(in_dim[2], key_dim[2], value_dim[2], head_count, token_mlp)
            for _ in range(layers[2])])
        self.norm3 = nn.LayerNorm(in_dim[2])

        self.block4 = nn.ModuleList([
            EfficientTransformerBlockFuse(in_dim[3], key_dim[3], value_dim[3], head_count, token_mlp)
            for _ in range(layers[3])])
        self.norm4 = nn.LayerNorm(in_dim[3])

    @staticmethod
    def _fuse_stage(x, embed_a, embed_b, blocks, norm, fuse_conv):
        """Run one dual-branch encoder stage and return its fused feature map.

        Args:
            x: stage input, passed unchanged to both patch embeddings.
            embed_a / embed_b: branch embeddings returning ``(tokens, H, W)``
                with tokens laid out as (B, H*W, C).
            blocks: iterable of fuse blocks called as
                ``blk(tokens, len_a, len_b, Ha, Wa, Hb, Wb)``.
            norm: normalisation applied to the concatenated tokens.
            fuse_conv: 1x1 conv mapping the concatenated 2*C channels to C.

        Returns:
            (B, C_out, Hb, Wb) tensor on branch b's grid.
        """
        batch = x.shape[0]
        tokens_a, Ha, Wa = embed_a(x)
        tokens_b, Hb, Wb = embed_b(x)
        len_a = tokens_a.shape[1]
        len_b = tokens_b.shape[1]

        tokens = torch.cat((tokens_a, tokens_b), dim=1)
        for blk in blocks:
            tokens = blk(tokens, len_a, len_b, Ha, Wa, Hb, Wb)
        tokens = norm(tokens)

        # Split back into branches and restore a (B, C, H, W) layout per branch.
        # Use the real batch size: a hard-coded leading 1 folds the batch into
        # the channel axis whenever B > 1.
        map_a = tokens[:, :len_a, :].reshape(batch, Ha, Wa, -1).permute(0, 3, 1, 2)
        map_b = tokens[:, len_a:, :].reshape(batch, Hb, Wb, -1).permute(0, 3, 1, 2)

        # Branch grids can differ by a few pixels (different kernel sizes);
        # bring branch a onto branch b's grid so they can be concatenated.
        map_a = F.interpolate(map_a, size=tuple(map_b.shape[-2:]))
        return fuse_conv(torch.cat((map_a, map_b), dim=1))

    def forward(self, x: torch.Tensor) -> "list[torch.Tensor]":
        """Encode ``x`` into a 4-level pyramid.

        Args:
            x: input batch; stage-1 embeddings are built for 3 input channels.

        Returns:
            List of 4 tensors, one per stage, each shaped (B, in_dim[i], H_i, W_i).
        """
        stages = (
            (self.patch_embed1_1, self.patch_embed1_2, self.block1, self.norm1, self.conv1_1_s1),
            (self.patch_embed2_1, self.patch_embed2_2, self.block2, self.norm2, self.conv1_1_s2),
            (self.patch_embed3_1, self.patch_embed3_2, self.block3, self.norm3, self.conv1_1_s3),
            (self.patch_embed4_1, self.patch_embed4_2, self.block4, self.norm4, self.conv1_1_s4),
        )
        outs = []
        for embed_a, embed_b, blocks, norm, fuse_conv in stages:
            x = self._fuse_stage(x, embed_a, embed_b, blocks, norm, fuse_conv)
            outs.append(x)
        return outs
    

In [66]:
class MyDecoderLayer(nn.Module):
    """One decoder stage: fuse a skip connection, refine it, then expand.

    Args:
        input_size: resolution tuple passed as ``input_resolution`` to the
            patch-expansion layer.
        in_out_chan: ``[dims, out_dim, key_dim, value_dim]``.
        head_count: attention heads for the two transformer blocks.
        token_mlp_mode: token-MLP mode string for the two transformer blocks.
        n_class: number of output classes (used only when ``is_last``).
        norm_layer: normalisation layer class passed to the patch expansion.
        is_last: if True, expand with ``FinalPatchExpand_X4`` (``dim_scale=4``)
            and project to ``n_class`` channels with a 1x1 convolution;
            otherwise expand with ``PatchExpand`` (``dim_scale=2``).
    """

    def __init__(self, input_size, in_out_chan, head_count, token_mlp_mode, n_class=9,
                 norm_layer=nn.LayerNorm, is_last=False):
        super().__init__()
        dims = in_out_chan[0]
        out_dim = in_out_chan[1]
        key_dim = in_out_chan[2]
        value_dim = in_out_chan[3]
        if not is_last:
            # Input width is the channel-wise concat of previous-stage tokens and
            # skip tokens performed in forward().
            self.concat_linear = nn.Linear(dims*2, out_dim)
            self.layer_up = PatchExpand(input_resolution=input_size, dim=out_dim, dim_scale=2, norm_layer=norm_layer)
            self.last_layer = None
        else:
            self.concat_linear = nn.Linear(dims*4, out_dim)
            self.layer_up = FinalPatchExpand_X4(input_resolution=input_size, dim=out_dim, dim_scale=4, norm_layer=norm_layer)
            # 1x1 conv producing per-class logits from the expanded map.
            self.last_layer = nn.Conv2d(out_dim, n_class, 1)

        self.layer_former_1 = EfficientTransformerBlock(out_dim, key_dim, value_dim, head_count, token_mlp_mode)
        self.layer_former_2 = EfficientTransformerBlock(out_dim, key_dim, value_dim, head_count, token_mlp_mode)

        # Re-initialise every submodule (including the ones built above).
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform for Linear/Conv2d weights, zero biases, unit LayerNorm."""
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                # LayerNorm without affine parameters has weight/bias set to None.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x1, x2=None):
        """Decode one stage.

        Args:
            x1: tokens from the previous stage; presumably (B, H*W, C1) so it can
                be concatenated with the flattened skip — TODO confirm at callers.
            x2: optional channels-last skip feature map (B, H, W, C).

        Returns:
            Without ``x2``: ``layer_up(x1)``. With ``x2`` on a non-last stage:
            the expanded tokens. On the last stage: (B, n_class, 4H, 4W) logits.
        """
        if x2 is None:
            # No skip connection: only expand.
            return self.layer_up(x1)

        b, h, w, c = x2.shape
        # reshape (not view): skip tensors arrive permuted and may not be
        # view-compatible.
        skip = x2.reshape(b, -1, c)
        fused = self.concat_linear(torch.cat([x1, skip], dim=-1))
        fused = self.layer_former_1(fused, h, w)
        fused = self.layer_former_2(fused, h, w)
        expanded = self.layer_up(fused)

        if self.last_layer is None:
            return expanded
        # Last stage: tokens -> (B, C, 4H, 4W) map -> per-class logits.
        return self.last_layer(expanded.reshape(b, 4*h, 4*w, -1).permute(0, 3, 1, 2))
    

In [67]:
class Transception(nn.Module):
    """Encoder-decoder segmentation network: ``MiT`` backbone + 4 decoder stages.

    Args:
        num_classes: number of output classes produced by the final decoder.
        head_count: attention heads used in the encoder and every decoder.
        token_mlp_mode: token-MLP mode string used in the encoder and decoders.
    """

    def __init__(self, num_classes=9, head_count=1, token_mlp_mode="mix_skip"):
        super().__init__()

        # ---- Encoder: 4 stages of widths 64/128/320/512, 2 blocks each ----
        widths = [64, 128, 320, 512]
        self.backbone = MiT(
            image_size=224,
            in_dim=list(widths),
            key_dim=list(widths),
            value_dim=list(widths),
            layers=[2, 2, 2, 2],
            head_count=head_count,
            token_mlp=token_mlp_mode,
        )

        # ---- Decoder ----
        base_side = 7  # 16 for 512 input size, and 7 for 224
        # One entry per level: [dim, out_dim, key_dim, value_dim]
        level_channels = [
            [32, 64, 64, 64],
            [144, 128, 128, 128],
            [288, 320, 320, 320],
            [512, 512, 512, 512],
        ]
        # Built deepest-first (decoder_3 -> decoder_0); the resolution doubles
        # at each shallower level and only level 0 is the final stage.
        for level in (3, 2, 1, 0):
            side = base_side * 2 ** (3 - level)
            setattr(self, f"decoder_{level}", MyDecoderLayer(
                (side, side),
                level_channels[level],
                head_count,
                token_mlp_mode,
                n_class=num_classes,
                is_last=(level == 0),
            ))

    def forward(self, x):
        """Segment ``x``; single-channel inputs are replicated to 3 channels."""
        if x.size(1) == 1:
            x = x.repeat(1, 3, 1, 1)

        features = self.backbone(x)

        deepest = features[3]
        b, c, _, _ = deepest.shape
        # Deepest map enters the decoder as a flattened channels-last sequence.
        decoded = self.decoder_3(deepest.permute(0, 2, 3, 1).view(b, -1, c))
        # Shallower levels fuse their encoder map (channels-last) as a skip.
        for level in (2, 1, 0):
            stage = getattr(self, f"decoder_{level}")
            decoded = stage(decoded, features[level].permute(0, 2, 3, 1))
        return decoded
    

In [70]:
# Build the full network with its default 9-class configuration and print the
# module tree as a structural sanity check.
model = Transception(
    num_classes=9,
    head_count=1,
    token_mlp_mode="mix_skip",
)
print(model)

Transception(
  (backbone): MiT(
    (conv1_1_s1): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
    (conv1_1_s2): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
    (conv1_1_s3): Conv2d(640, 320, kernel_size=(1, 1), stride=(1, 1))
    (conv1_1_s4): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))
    (patch_embed1_1): OverlapPatchEmbeddings_fuse(
      (proj): Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3), dilation=(2, 2))
      (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed1_2): OverlapPatchEmbeddings_fuse(
      (proj): Conv2d(3, 64, kernel_size=(5, 5), stride=(4, 4), padding=(3, 3), dilation=(2, 2))
      (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed2_1): OverlapPatchEmbeddings_fuse(
      (proj): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), dilation=(2, 2))
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed2_2): OverlapPatchEmbedd