In [15]:
import torch
import torch.nn as nn
# from networks.segformer import *
# For jupyter notebook below
from EffSegformer import *
from Transception_cnn import *

from typing import Tuple
from einops import rearrange
from einops.layers.torch import Rearrange
from torch.nn import functional as F



# From MISSFormer.py class BridgeLayer_4
# From Transception.py line83 forward part
# FromEfficientAttention to FuseEfficientAttention

In [19]:
class ConvBNReLU(nn.Module):
    def __init__(self, c_in, c_out, kernel_size, stride=1, padding=1, activation=True):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(
            c_in, c_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=False
        )
        self.bn = nn.BatchNorm2d(c_out)
        self.relu = nn.ReLU()
        self.activation = activation

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.activation:
            x = self.relu(x)
        return x

class DoubleConv(nn.Module):

    def __init__(self, cin, cout):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            ConvBNReLU(cin, cout, 3, 1, padding=1),
            ConvBNReLU(cout, cout, 3, stride=1, padding=1, activation=False)
        )
        self.conv1 = nn.Conv2d(cout, cout, 1)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(cout)

    def forward(self, x):
        x = self.conv(x)
        h = x
        x = self.conv1(x)
        x = self.bn(x)
        x = h + x
        x = self.relu(x)
        return x
    
class MiT_3_ResInception_cnn1(nn.Module):
    def __init__(self, image_size, in_dim, key_dim, value_dim, layers, head_count=1, dil_conv=1, token_mlp='mix_skip', inception="1"):
        super().__init__()

        self.Hs=[56, 28, 14, 7]
        self.Ws=[56, 28, 14, 7]
        patch_sizes = [7, 3, 3, 3]
        strides = [2, 2, 2, 2]
        padding_sizes = [3, 1, 1, 1]
        if dil_conv:  
            dilation = 2  
            patch_sizes1 = [7, 3, 3, 3]
            dil_padding_sizes1 = [3, 2, 2, 2]    
          
        else:
            dilation = 1
            patch_sizes1 = [7, 3, 3, 3]
            dil_padding_sizes1 = [3, 1, 1, 1]
           



        # 1 by 1 convolution to alter the dimension
      
        self.conv1_1_s1 = nn.Conv2d((len(inception)+1)*in_dim[0], in_dim[0], 1)
        self.conv1_1_s2 = nn.Conv2d((len(inception)+1)*in_dim[1], in_dim[1], 1)
        self.conv1_1_s3 = nn.Conv2d((len(inception)+1)*in_dim[2], in_dim[2], 1)
        self.conv1_1_s4 = nn.Conv2d((len(inception)+1)*in_dim[3], in_dim[3], 1)

        # patch_embed
        # layers = [2, 2, 2, 2] dims = [64, 128, 320, 512]
        self.patch_embed1 = OverlapPatchEmbeddings(image_size//2, patch_sizes[0], strides[0], padding_sizes[0], in_dim[0], in_dim[0])
        
        self.patch_embed2_1 = OverlapPatchEmbeddings_fuse(image_size//4, patch_sizes1[1], strides[1], dil_padding_sizes1[1],dilation, in_dim[0], in_dim[1])
        # self.patch_embed2_2 = OverlapPatchEmbeddings_fuse(image_size//4, patch_sizes2[1], strides[1], dil_padding_sizes2[1],dilation, in_dim[0], in_dim[1])
        
        
        self.patch_embed3_1 = OverlapPatchEmbeddings_fuse(image_size//8, patch_sizes1[2], strides[2], dil_padding_sizes1[2],dilation, in_dim[1], in_dim[2])
        # self.patch_embed3_2 = OverlapPatchEmbeddings_fuse(image_size//8, patch_sizes2[2], strides[2], dil_padding_sizes2[2],dilation, in_dim[1], in_dim[2])

        self.patch_embed4_1 = OverlapPatchEmbeddings_fuse(image_size//16, patch_sizes1[3], strides[3], dil_padding_sizes1[3],dilation, in_dim[2], in_dim[3])
        # self.patch_embed4_2 = OverlapPatchEmbeddings_fuse(image_size//16, patch_sizes2[3], strides[3], dil_padding_sizes2[3],dilation, in_dim[2], in_dim[3])
        
        # inception branch
        multiResBlock = {
                        '15': MultiResBlock_15,
                        '13': MultiResBlock_13,
                        '1': MultiResBlock_1,
                        '3': MultiResBlock_3,
                        '5': MultiResBlock_5,
                        }
        
       
        self.resInception2_2 = multiResBlock[inception](in_dim[0],in_dim[1],branch=1,downsample=strides[1],alpha=1)
        self.resInception3_2 = multiResBlock[inception](in_dim[1],in_dim[2],branch=1,downsample=strides[2],alpha=1)
        self.resInception4_2 = multiResBlock[inception](in_dim[2],in_dim[3],branch=1,downsample=strides[3],alpha=1)
        
        # CNN Layer
        self.block0 = DoubleConv(3, in_dim[0])
        self.pool0 = nn.MaxPool2d(2)
        
        # transformer encoder
        self.block1 = nn.ModuleList([ 
            EfficientTransformerBlock(in_dim[0], key_dim[0], value_dim[0], head_count, token_mlp)
        for _ in range(layers[0])])
        self.norm1 = nn.LayerNorm(in_dim[0])

        self.block2 = nn.ModuleList([
            EfficientTransformerBlockFuse_res(in_dim[1], key_dim[1], value_dim[1], head_count, token_mlp)
        for _ in range(layers[1])])
        self.norm2 = nn.LayerNorm(in_dim[1])

        self.block3 = nn.ModuleList([
            EfficientTransformerBlockFuse_res(in_dim[2], key_dim[2], value_dim[2], head_count, token_mlp)
        for _ in range(layers[2])])
        self.norm3 = nn.LayerNorm(in_dim[2])

        self.block4 = nn.ModuleList([
            EfficientTransformerBlockFuse_res(in_dim[3], key_dim[3], value_dim[3], head_count, token_mlp)
        for _ in range(layers[3])])
        self.norm4 = nn.LayerNorm(in_dim[3])
        

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B = x.shape[0]
        outs = []

        # stage 0: CNN
        x = self.block0(x)# 224 224 64
        x = self.pool0(x)#112 112 64
        outs.append(x)
    
        # stage 1
        x, H, W = self.patch_embed1(x)# 56 56 128
        for blk in self.block1:
            x = blk(x, H, W)
        x = self.norm1(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

      

        # merge 2
        # print("-------EN: Stage 2------\n\n")
        x1, H1, W1 = self.patch_embed2_1(x)#28 28 256
        H2 = H1
        W2 = W1
        # print("\n S2: H1:{}, H2:{}".format(H1,H2))
        _, nfx1_len, _ = x1.shape
        x2 = self.resInception2_2(x)
        _, nfx2_len, _ = x2.shape
        # print("\n x2 shape:", x2.shape)
        nfx_cat = torch.cat((x1,x2),1)

        # stage 2

        for blk in self.block2:
            nfx_cat = blk(nfx_cat, nfx1_len, nfx2_len, H1, W1, H2, W2)
        tx = self.norm2(nfx_cat)
        # The mlp has been passed in blk, so next just split the sequence and 
        # reshape to spatial dimension
        b,tx_len,_ = tx.shape
        # z_total = []
        map_mx_total = []
        for nz in range(int(tx_len/nfx1_len)):
            z = tx[:, nz*nfx1_len:(nz+1)*nfx1_len, :]
            # z_total.append(z)
            # print( z.shape)
            map_mx = z.reshape(b,H1,W1,-1)
            map_mx = map_mx.permute(0,3,1,2)
            # print( "\nmap_mx: ",map_mx.shape)
            map_mx_total.append(map_mx)

        cat_maps = torch.cat(map_mx_total,1)
        x = self.conv1_1_s2(cat_maps)
        outs.append(x)

        
       # merge 3
        x1, H1, W1 = self.patch_embed3_1(x) # 14 14 512
        H2 = H1
        W2 = W1
        # print("\n S3: H1:{}, H2:{}".format(H1,H2))
        _, nfx1_len, _ = x1.shape
        x2 = self.resInception3_2(x)
        _, nfx2_len, _ = x2.shape
        nfx_cat = torch.cat((x1,x2),1)

        # stage 3
        for blk in self.block3:
            nfx_cat = blk(nfx_cat, nfx1_len, nfx2_len, H1, W1, H2, W2)
        tx = self.norm3(nfx_cat)

        b,tx_len,_ = tx.shape
        # z_total = []
        map_mx_total = []
        for nz in range(int(tx_len/nfx1_len)):
            z = tx[:, nz*nfx1_len:(nz+1)*nfx1_len, :]
            # z_total.append(z)
            # print( z.shape)
            map_mx = z.reshape(b,H1,W1,-1)
            map_mx = map_mx.permute(0,3,1,2)
            # print( "\nmap_mx: ",map_mx.shape)
            map_mx_total.append(map_mx)

        cat_maps = torch.cat(map_mx_total,1)
        x = self.conv1_1_s3(cat_maps)
        outs.append(x)


        # merge 4
      
        x1, H1, W1 = self.patch_embed4_1(x)#7 7 1024
        H2 = H1
        W2 = W1
        # print("\n S4: H1:{}, H2:{}".format(H1,H2))
        _, nfx1_len, _ = x1.shape
        x2 = self.resInception4_2(x)
        _, nfx2_len, _ = x2.shape
        nfx_cat = torch.cat((x1,x2),1)

        # stage 4
        for blk in self.block4:
            nfx_cat = blk(nfx_cat, nfx1_len, nfx2_len, H1, W1, H2, W2)
        tx = self.norm4(nfx_cat)
        b,tx_len,_ = tx.shape
        # z_total = []
        map_mx_total = []
        for nz in range(int(tx_len/nfx1_len)):
            z = tx[:, nz*nfx1_len:(nz+1)*nfx1_len, :]
            # z_total.append(z)
            # print( z.shape)
            map_mx = z.reshape(b,H1,W1,-1)
            map_mx = map_mx.permute(0,3,1,2)
            # print( "\nmap_mx: ",map_mx.shape)
            map_mx_total.append(map_mx)

        cat_maps = torch.cat(map_mx_total,1)
        x = self.conv1_1_s4(cat_maps)
        outs.append(x)


        return outs# len=5
    


In [20]:
class Transception(nn.Module):
    def __init__(self, num_classes=9, head_count=1, dil_conv=1, token_mlp_mode="mix_skip", inception="135"):
        super().__init__()
    
        # Encoder
        dims, key_dim, value_dim, layers = [[64, 128, 320, 512], [64, 128, 320, 512], [64, 128, 320, 512], [2, 2, 2, 2]]        
        self.backbone = MiT_3_ResInception_cnn1(image_size=224, in_dim=dims, key_dim=key_dim, value_dim=value_dim, layers=layers,
                            head_count=head_count, dil_conv=dil_conv, token_mlp=token_mlp_mode, inception=inception)
        # self.backbone = MiT_3inception_padding(image_size=224, in_dim=dims, key_dim=key_dim, value_dim=value_dim, layers=layers,
        #                     head_count=head_count, dil_conv=dil_conv, token_mlp=token_mlp_mode)

        # Here options:(1) MiT_3inception->3 stages;(2) MiT->4 stages; 
        # (3)MiT_3inception_padding: padding before transformer after patch embedding (follow depthconcat)
        # Decoder
        d_base_feat_size = 7 #16 for 512 input size, and 7 for 224
        in_out_chan = [[32, 64, 64, 64],[144, 128, 128, 128],[288, 320, 320, 320],[512, 512, 512, 512]]  # [dim, out_dim, key_dim, value_dim]

        self.decoder_3 = MyDecoderLayer((d_base_feat_size, d_base_feat_size), in_out_chan[3], head_count, 
                                        token_mlp_mode, n_class=num_classes)
        self.decoder_2 = MyDecoderLayer((d_base_feat_size*2, d_base_feat_size*2), in_out_chan[2], head_count,
                                        token_mlp_mode, n_class=num_classes)
        self.decoder_1 = MyDecoderLayer((d_base_feat_size*4, d_base_feat_size*4), in_out_chan[1], head_count, 
                                        token_mlp_mode, n_class=num_classes) 
        self.decoder_0 = MyDecoderLayer((d_base_feat_size*4, d_base_feat_size*4), in_out_chan[1], head_count, 
                                        token_mlp_mode, n_class=num_classes) 
        self.decoder_cnn = MyDecoderLayer((d_base_feat_size*8, d_base_feat_size*8), in_out_chan[0], head_count,
                                        token_mlp_mode, n_class=num_classes, is_last=True)

        
    def forward(self, x):
        #---------------Encoder-------------------------
        if x.size()[1] == 1:
            x = x.repeat(1,3,1,1)

        output_enc = self.backbone(x)

        b,c,_,_ = output_enc[3].shape

        #---------------Decoder-------------------------     
        tmp_3 = self.decoder_3(output_enc[4].permute(0,2,3,1).view(b,-1,c))
        tmp_2 = self.decoder_2(tmp_3, output_enc[3].permute(0,2,3,1))
        tmp_1 = self.decoder_1(tmp_2, output_enc[2].permute(0,2,3,1))
        tmp_0 = self.decoder_0(tmp_1, output_enc[1].permute(0,2,3,1))
        final_pred = self.decoder_cnn(tmp_0, output_enc[0].permute(0,2,3,1))

        return tmp_0
    

    

In [21]:
if __name__ == "__main__":
    model = Transception(num_classes=9, head_count=1, dil_conv = 1, token_mlp_mode="mix_skip", inception='13')
    print(model(torch.rand(1, 3, 224, 224)).shape)

RuntimeError: shape '[1, -1, 320]' is invalid for input of size 25088