In [94]:
import torch
import torch.nn as nn
from segformer import *
from EfficientMISSFormer import *
from typing import Tuple
from einops import rearrange
from einops.layers.torch import Rearrange

from torch.nn import functional as F

i_stage = 1
b = 4 # batch size

if i_stage == 0:
    inputs = torch.rand(b, 1, 224, 224)  
    if inputs.size()[1] == 1:
        inputs = inputs.repeat(1,3,1,1)
    input_dim = 3
    
elif i_stage == 1:
    inputs = torch.rand(b, 64, 56, 56) 
    input_dim = 64
elif i_stage == 2:
    inputs = torch.rand(b, 128, 28, 28)
    input_dim = 128
else:
    inputs = torch.rand(b, 320, 14, 14)
    input_dim = 320


print("This version we add the inception in the 3 stages, the first stage we stay the same.")
print("Test the input of size {}".format(inputs.shape))
print("Now we are checking the process in stage:{}".format(i_stage))

This version we add the inception in the 3 stages, the first stage we stay the same.
Test the input of size torch.Size([4, 64, 56, 56])
Now we are checking the process in stage:1


In [95]:
class OverlapPatchEmbeddings_inception(nn.Module):
    def __init__(self, img_size=224, patch_size=7, stride=4, padding=1, in_ch=3, dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_ch, dim, patch_size, stride, padding)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        px = self.proj(x)
        _, _, H, W = px.shape
        fx = px.flatten(2).transpose(1, 2)
        nfx = self.norm(fx)
        return nfx, H, W

In [96]:
Hs=[56, 28, 14, 7]
Ws=[56, 28, 14, 7]

dil_conv = False
padding_sizes = [3, 1, 1, 1]
if dil_conv:

    patch_sizes1 = [7, 3, 3, 3]
    dil_padding_sizes1 = [1, 0, 0, 0]    
    patch_sizes2 = [1, 1, 1, 1]
    dil_padding_sizes2 = [0, 0, 0, 0]    


else:
    patch_sizes1 = [7, 3, 3, 3]
    dil_padding_sizes1 = [3, 1, 1, 1]
    patch_sizes2 = [5, 1, 1, 1]
    # dil_padding_sizes2 = [3, 0, 0, 0]
    dil_padding_sizes2 = [1, 0, 0, 0]
    
    
strides = [4, 2, 2, 2]
image_size = 224
in_dim = [64, 128, 320, 512]
layers = [2, 2, 2, 2]
head_count=1
token_mlp='mix_skip'
key_dim = value_dim = in_dim
B = inputs.shape[0]
print("config:\ndil_conv:{}\npatch_sizes1{} padding_size1:{} \npatch_size2:{} padding_size2{} \n".format(dil_conv,patch_sizes1,dil_padding_sizes1,patch_sizes2,dil_padding_sizes2))

config:
dil_conv:False
patch_sizes1[7, 3, 3, 3] padding_size1:[3, 1, 1, 1] 
patch_size2:[5, 1, 1, 1] padding_size2[1, 0, 0, 0] 



In [97]:
 #self.patch_embed1 = OverlapPatchEmbeddings(image_size, patch_sizes[0], strides[0], padding_sizes[0], 3, in_dim[0])
# patch_embed1 = OverlapPatchEmbeddings_inception(image_size, patch_sizes[0], strides[0], padding_sizes[0], 3, in_dim[0])
# x, H, W = patch_embed1(inputs)
# print("check the output-x:{}, H:{}, W:{}".format(x, H, W))

**See the pytorch tutorial about nn.Conv2d**
https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d
![image.png](attachment:image.png)

In [98]:
proj = nn.Conv2d(input_dim,in_dim[i_stage],kernel_size = patch_sizes1[i_stage], stride = strides[i_stage],padding=padding_sizes[i_stage], dilation=1)
if dil_conv:
    proj_dilation1 = nn.Conv2d(input_dim,in_dim[i_stage],kernel_size = patch_sizes1[i_stage], stride = strides[i_stage],padding=dil_padding_sizes1[i_stage], dilation=2)
#     proj_dilation2 = nn.Conv2d(input_dim,in_dim[i_stage],kernel_size = patch_sizes2[i_stage], stride = strides[i_stage],padding=dil_padding_sizes2[i_stage], dilation=2)
else:
    proj_dilation1 = nn.Conv2d(input_dim,in_dim[i_stage],kernel_size = patch_sizes1[i_stage], stride = strides[i_stage],padding=dil_padding_sizes1[i_stage], dilation=1)
#     proj_dilation2 = nn.Conv2d(input_dim,in_dim[i_stage],kernel_size = patch_sizes2[i_stage], stride = strides[i_stage],padding=dil_padding_sizes2[i_stage], dilation=1)

    
norm = nn.LayerNorm(in_dim[i_stage])

The kernel size can be set to 1 3 5 respectively. But because the patch sizes for different stages are [7 3 3 3], so I should consider what the kernel size will be. (Apparently it won't be larger than 3 itself)

In [99]:
print("\n\n----with no dilation----")
px = proj(inputs)
print("The result after conv:{}".format(px.shape))
_, _, H, W = px.shape   
fx = px.flatten(2)
print("The result after flatten:{}".format(fx.shape))
fx = fx.transpose(1,2)
print("The result after transpose(1,2):{}".format(fx.shape))
nfx = norm(fx)
print("check the output-x:{}, H:{}, W:{}".format(nfx.shape, H, W))



----with no dilation----
The result after conv:torch.Size([4, 128, 28, 28])
The result after flatten:torch.Size([4, 128, 784])
The result after transpose(1,2):torch.Size([4, 784, 128])
check the output-x:torch.Size([4, 784, 128]), H:28, W:28


# start 2 branches

In [100]:
print("\n\n----branch 1: normal patch merging------")
px1 = proj_dilation1(inputs)
print("The result after dilation_conv:{}".format(px1.shape))
_, _, H1, W1 = px1.shape   
fx1 = px1.flatten(2)
print("The result after flatten:{}".format(fx1.shape))
fx1 = fx1.transpose(1,2)
print("The result after transpose(1,2):{}".format(fx1.shape))
nfx1 = norm(fx1)
print("check the output-x:{}, H:{}, W:{}".format(nfx1.shape, H1, W1))
_, nfx1_len, _ = nfx1.shape





----branch 1: normal patch merging------
The result after dilation_conv:torch.Size([4, 128, 28, 28])
The result after flatten:torch.Size([4, 128, 784])
The result after transpose(1,2):torch.Size([4, 784, 128])
check the output-x:torch.Size([4, 784, 128]), H:28, W:28


In [155]:
# # proj_dilation1 = nn.Conv2d(input_dim,in_dim[i_stage],kernel_size = patch_sizes1[i_stage], stride = strides[i_stage],padding=dil_padding_sizes1[i_stage], dilation=2)
# conv3_3_1 =  nn.Conv2d(input_dim, in_dim[i_stage], kernel_size=3, padding =1)
# conv3_3_2 =  nn.Conv2d(input_dim, in_dim[i_stage], kernel_size=3, padding =1)
#https://github.com/Cassieyy/MultiResUnet3D/blob/main/MultiResUnet3D.py

class conv_block(nn.Module):
    def __init__(self, ch_in, ch_out, kernel_size=3, stride=1, padding=1, act='relu'):
        # print(ch_out)
        super(conv_block,self).__init__()
        if act == None:
            self.conv = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=kernel_size,stride=stride,padding=padding),
                nn.BatchNorm2d(ch_out)
            )
        elif act == 'relu':
            self.conv = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=kernel_size,stride=stride,padding=padding),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True)
            )
        elif act == 'sigmoid':
            self.conv = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=kernel_size,stride=stride,padding=padding),
                nn.BatchNorm2d(ch_out),
                nn.Sigmoid()
            )

    def forward(self,x):
        x = self.conv(x)
        return x
    


class res_block(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(res_block,self).__init__()
        self.res = conv_block(ch_in,ch_out,1,1,0,None)
        self.main = conv_block(ch_in,ch_out)
        self.bn = nn.BatchNorm2d(ch_in)
    def forward(self,x):
        res_x = self.res(x)

        main_x = self.main(x)
        out = res_x.add(main_x)
        out = nn.ReLU(inplace=True)(out)
        # print(out.shape[1], type(out.shape[1]))
        # assert 1>3
        out = self.bn(out)
        return out
    


class MultiResBlock_1357(nn.Module):
    def __init__(self,in_ch,U,branch=1,downsample=2, alpha=1):
        super(MultiResBlock_1357,self).__init__()
        self.W = alpha * U
        self.one_ch = conv_block(in_ch, 1)
#         self.residual_layer = conv_block(1, self.W, 1, 1, 0, act=None)
        self.residual_layer = conv_block(1, self.W)
#         self.conv3x3 = conv_block(1, int(self.W*0.167))
#         self.conv5x5 = conv_block(int(self.W*0.167), int(self.W*0.333))
#         self.conv7x7 = conv_block(int(self.W*0.333), self.W-int(self.W*0.167)-int(self.W*0.333))
        self.conv3x3 = conv_block(1, int(self.W))
        self.conv5x5 = conv_block(int(self.W), int(self.W))
        self.conv7x7 = conv_block(int(self.W), self.W)
        self.maxpool = nn.MaxPool2d(downsample, stride=downsample)
        self.relu = nn.ReLU(inplace=True)
#         self.batchnorm_1 = nn.BatchNorm2d(int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5))
#         self.batchnorm_2 = nn.BatchNorm2d(int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5))
#         self.batchnorm_1 = nn.BatchNorm2d(self.W)
#         self.batchnorm_2 = nn.BatchNorm2d(self.W)
        self.norm = nn.LayerNorm(self.W)
        
    def forward(self, x):
        out = []
        print(x.shape) 
        print("\n W=alpha*U :{}\n".format(self.W))
        x = self.one_ch(x) 
        res = self.residual_layer(x)
        res_out = self.maxpool(res)
        res_out = (res_out.flatten(2)).transpose(1,2)
        out.append(res_out)
        print("\n res:{}\n".format(res_out.shape))
        
        sbs = self.conv3x3(x)
        sbs_out = self.maxpool(sbs)
        sbs_out = (sbs_out.flatten(2)).transpose(1,2)
        print("\n out_3*3:{}\n".format(sbs_out.shape))
        out.append(sbs_out)
        
        obo = self.conv5x5(sbs)
        obo_out = self.maxpool(obo)
        obo_out = (obo_out.flatten(2)).transpose(1,2)
        out.append(obo_out)
        print("\n out_5*5:{}\n".format(obo_out.shape))
        
        cbc = self.conv7x7(obo)
        cbc_out = self.maxpool(cbc)
        cbc_out = (cbc_out.flatten(2)).transpose(1,2)
        print("\n out_7*7:{}\n".format(cbc_out.shape))
        out.append(cbc_out)
        all_t = torch.cat((out[0], out[1], out[2],out[3]), 1)
        all_t = self.norm(all_t)
        print("\n cat_together:{}\n".format(all_t.shape))
#         all_t_b = self.batchnorm_1(all_t)
#         out = all_t_b.add(res)
#         out = self.relu(out)
#         out = self.batchnorm_2(out)      
        
        return all_t
    
    
class MultiResBlock_135(nn.Module):
    def __init__(self,in_ch,U,branch=1,downsample=2, alpha=1):
        super(MultiResBlock_135,self).__init__()
        self.W = alpha * U
        self.one_ch = conv_block(in_ch, 1)
#         self.residual_layer = conv_block(1, self.W, 1, 1, 0, act=None)
        self.residual_layer = conv_block(1, self.W)
#         self.conv3x3 = conv_block(1, int(self.W*0.167))
#         self.conv5x5 = conv_block(int(self.W*0.167), int(self.W*0.333))
#         self.conv7x7 = conv_block(int(self.W*0.333), self.W-int(self.W*0.167)-int(self.W*0.333))
        self.conv3x3 = conv_block(1, int(self.W))
        self.conv5x5 = conv_block(int(self.W), int(self.W))
        self.conv7x7 = conv_block(int(self.W), self.W)
        self.maxpool = nn.MaxPool2d(downsample, stride=downsample)
        self.relu = nn.ReLU(inplace=True)
#         self.batchnorm_1 = nn.BatchNorm2d(int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5))
#         self.batchnorm_2 = nn.BatchNorm2d(int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5))
#         self.batchnorm_1 = nn.BatchNorm2d(self.W)
#         self.batchnorm_2 = nn.BatchNorm2d(self.W)
        self.norm = nn.LayerNorm(self.W)
        
    def forward(self, x):
        out = []
        print(x.shape) 
        print("\n W=alpha*U :{}\n".format(self.W))
        x = self.one_ch(x) 
        res = self.residual_layer(x)
        res_out = self.maxpool(res)
        res_out = (res_out.flatten(2)).transpose(1,2)
        out.append(res_out)
        print("\n res:{}\n".format(res_out.shape))
        
        sbs = self.conv3x3(x)
        sbs_out = self.maxpool(sbs)
        sbs_out = (sbs_out.flatten(2)).transpose(1,2)
        print("\n out_3*3:{}\n".format(sbs_out.shape))
        out.append(sbs_out)
        
        obo = self.conv5x5(sbs)
        obo_out = self.maxpool(obo)
        obo_out = (obo_out.flatten(2)).transpose(1,2)
        out.append(obo_out)
        print("\n out_5*5:{}\n".format(obo_out.shape))
        
#         cbc = self.conv7x7(obo)
#         cbc_out = self.maxpool(cbc)
#         cbc_out = (cbc_out.flatten(2)).transpose(1,2)
#         out.append(cbc_out)
#         print("\n out_7*7:{}\n".format(cbc_out.shape))
        
        
        all_t = torch.cat(out, 1)
        all_t = self.norm(all_t)
        print("\n cat_together:{}\n".format(all_t.shape))
#         all_t_b = self.batchnorm_1(all_t)
#         out = all_t_b.add(res)
#         out = self.relu(out)
#         out = self.batchnorm_2(out)      
        
        return all_t

class MultiResBlock_157(nn.Module):
    def __init__(self,in_ch,U,branch=1,downsample=2, alpha=1):
        super(MultiResBlock_157,self).__init__()
        self.W = alpha * U
        self.one_ch = conv_block(in_ch, 1)
#         self.residual_layer = conv_block(1, self.W, 1, 1, 0, act=None)
        self.residual_layer = conv_block(1, self.W)
#         self.conv3x3 = conv_block(1, int(self.W*0.167))
#         self.conv5x5 = conv_block(int(self.W*0.167), int(self.W*0.333))
#         self.conv7x7 = conv_block(int(self.W*0.333), self.W-int(self.W*0.167)-int(self.W*0.333))
        self.conv3x3 = conv_block(1, int(self.W))
        self.conv5x5 = conv_block(int(self.W), int(self.W))
        self.conv7x7 = conv_block(int(self.W), self.W)
        self.maxpool = nn.MaxPool2d(downsample, stride=downsample)
        self.relu = nn.ReLU(inplace=True)
#         self.batchnorm_1 = nn.BatchNorm2d(int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5))
#         self.batchnorm_2 = nn.BatchNorm2d(int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5))
#         self.batchnorm_1 = nn.BatchNorm2d(self.W)
#         self.batchnorm_2 = nn.BatchNorm2d(self.W)
        self.norm = nn.LayerNorm(self.W)
        
    def forward(self, x):
        out = []
        print(x.shape) 
        print("\n W=alpha*U :{}\n".format(self.W))
        x = self.one_ch(x) 
        res = self.residual_layer(x)
        res_out = self.maxpool(res)
        res_out = (res_out.flatten(2)).transpose(1,2)
        out.append(res_out)
        print("\n res:{}\n".format(res_out.shape))
        
        sbs = self.conv3x3(x)
#         sbs_out = self.maxpool(sbs)
#         sbs_out = (sbs_out.flatten(2)).transpose(1,2)
#         print("\n out_3*3:{}\n".format(sbs_out.shape))
#         out.append(sbs_out)
        
        obo = self.conv5x5(sbs)
        obo_out = self.maxpool(obo)
        obo_out = (obo_out.flatten(2)).transpose(1,2)
        out.append(obo_out)
        print("\n out_5*5:{}\n".format(obo_out.shape))
        
        cbc = self.conv7x7(obo)
        cbc_out = self.maxpool(cbc)
        cbc_out = (cbc_out.flatten(2)).transpose(1,2)
        print("\n out_7*7:{}\n".format(cbc_out.shape))
        out.append(cbc_out)
        all_t = torch.cat(out, 1)
        all_t = self.norm(all_t)
        print("\n cat_together:{}\n".format(all_t.shape))
#         all_t_b = self.batchnorm_1(all_t)
#         out = all_t_b.add(res)
#         out = self.relu(out)
#         out = self.batchnorm_2(out)      
        
        return all_t
    

class MultiResBlock_15(nn.Module):
    def __init__(self,in_ch,U,branch=1,downsample=2, alpha=1):
        super(MultiResBlock_15,self).__init__()
        self.W = alpha * U
        self.one_ch = conv_block(in_ch, 1)
#         self.residual_layer = conv_block(1, self.W, 1, 1, 0, act=None)
        self.residual_layer = conv_block(1, self.W)
#         self.conv3x3 = conv_block(1, int(self.W*0.167))
#         self.conv5x5 = conv_block(int(self.W*0.167), int(self.W*0.333))
#         self.conv7x7 = conv_block(int(self.W*0.333), self.W-int(self.W*0.167)-int(self.W*0.333))
        self.conv3x3 = conv_block(1, int(self.W))
        self.conv5x5 = conv_block(int(self.W), int(self.W))
        self.conv7x7 = conv_block(int(self.W), self.W)
        self.maxpool = nn.MaxPool2d(downsample, stride=downsample)
        self.relu = nn.ReLU(inplace=True)
#         self.batchnorm_1 = nn.BatchNorm2d(int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5))
#         self.batchnorm_2 = nn.BatchNorm2d(int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5))
#         self.batchnorm_1 = nn.BatchNorm2d(self.W)
#         self.batchnorm_2 = nn.BatchNorm2d(self.W)
        self.norm = nn.LayerNorm(self.W)
        
    def forward(self, x):
        out = []
        print(x.shape) 
        print("\n W=alpha*U :{}\n".format(self.W))
        x = self.one_ch(x) 
        res = self.residual_layer(x)
        res_out = self.maxpool(res)
        res_out = (res_out.flatten(2)).transpose(1,2)
        out.append(res_out)
        print("\n res:{}\n".format(res_out.shape))
        
        sbs = self.conv3x3(x)
#         sbs_out = self.maxpool(sbs)
#         sbs_out = (sbs_out.flatten(2)).transpose(1,2)
#         print("\n out_3*3:{}\n".format(sbs_out.shape))
#         out.append(sbs_out)
        
        obo = self.conv5x5(sbs)
        obo_out = self.maxpool(obo)
        obo_out = (obo_out.flatten(2)).transpose(1,2)
        out.append(obo_out)
        print("\n out_5*5:{}\n".format(obo_out.shape))
        
#         cbc = self.conv7x7(obo)
#         cbc_out = self.maxpool(cbc)
#         cbc_out = (cbc_out.flatten(2)).transpose(1,2)
#         print("\n out_7*7:{}\n".format(cbc_out.shape))
#         out.append(cbc_out)
        all_t = torch.cat(out, 1)
        all_t = self.norm(all_t)
        print("\n cat_together:{}\n".format(all_t.shape))
#         all_t_b = self.batchnorm_1(all_t)
#         out = all_t_b.add(res)
#         out = self.relu(out)
#         out = self.batchnorm_2(out)      
        
        return all_t

    
class MultiResBlock_13(nn.Module):
    def __init__(self,in_ch,U,branch=1,downsample=2, alpha=1):
        super(MultiResBlock_13,self).__init__()
        self.W = alpha * U
        self.one_ch = conv_block(in_ch, 1)
#         self.residual_layer = conv_block(1, self.W, 1, 1, 0, act=None)
        self.residual_layer = conv_block(1, self.W)
#         self.conv3x3 = conv_block(1, int(self.W*0.167))
#         self.conv5x5 = conv_block(int(self.W*0.167), int(self.W*0.333))
#         self.conv7x7 = conv_block(int(self.W*0.333), self.W-int(self.W*0.167)-int(self.W*0.333))
        self.conv3x3 = conv_block(1, int(self.W))
        self.conv5x5 = conv_block(int(self.W), int(self.W))
        self.conv7x7 = conv_block(int(self.W), self.W)
        self.maxpool = nn.MaxPool2d(downsample, stride=downsample)
        self.relu = nn.ReLU(inplace=True)
#         self.batchnorm_1 = nn.BatchNorm2d(int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5))
#         self.batchnorm_2 = nn.BatchNorm2d(int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5))
#         self.batchnorm_1 = nn.BatchNorm2d(self.W)
#         self.batchnorm_2 = nn.BatchNorm2d(self.W)
        self.norm = nn.LayerNorm(self.W)
        
    def forward(self, x):
        out = []
        print(x.shape) 
        print("\n W=alpha*U :{}\n".format(self.W))
        x = self.one_ch(x) 
        res = self.residual_layer(x)
        res_out = self.maxpool(res)
        res_out = (res_out.flatten(2)).transpose(1,2)
        out.append(res_out)
        print("\n res:{}\n".format(res_out.shape))
        
        sbs = self.conv3x3(x)
        sbs_out = self.maxpool(sbs)
        sbs_out = (sbs_out.flatten(2)).transpose(1,2)
        print("\n out_3*3:{}\n".format(sbs_out.shape))
        out.append(sbs_out)
        
#         obo = self.conv5x5(sbs)
#         obo_out = self.maxpool(obo)
#         obo_out = (obo_out.flatten(2)).transpose(1,2)
#         out.append(obo_out)
#         print("\n out_5*5:{}\n".format(obo_out.shape))
        
#         cbc = self.conv7x7(obo)
#         cbc_out = self.maxpool(cbc)
#         cbc_out = (cbc_out.flatten(2)).transpose(1,2)
#         print("\n out_7*7:{}\n".format(cbc_out.shape))
#         out.append(cbc_out)
        all_t = torch.cat(out, 1)
        all_t = self.norm(all_t)
        print("\n cat_together:{}\n".format(all_t.shape))
#         all_t_b = self.batchnorm_1(all_t)
#         out = all_t_b.add(res)
#         out = self.relu(out)
#         out = self.batchnorm_2(out)      
        
        return all_t
    
    
    
class MultiResBlock_1(nn.Module):
    def __init__(self,in_ch,U,branch=1,downsample=2, alpha=1):
        super(MultiResBlock_1,self).__init__()
        self.W = alpha * U
        self.one_ch = conv_block(in_ch, 1)
#         self.residual_layer = conv_block(1, self.W, 1, 1, 0, act=None)
        self.residual_layer = conv_block(1, self.W)
#         self.conv3x3 = conv_block(1, int(self.W*0.167))
#         self.conv5x5 = conv_block(int(self.W*0.167), int(self.W*0.333))
#         self.conv7x7 = conv_block(int(self.W*0.333), self.W-int(self.W*0.167)-int(self.W*0.333))
        self.conv3x3 = conv_block(1, int(self.W))
        self.conv5x5 = conv_block(int(self.W), int(self.W))
        self.conv7x7 = conv_block(int(self.W), self.W)
        self.maxpool = nn.MaxPool2d(downsample, stride=downsample)
        self.relu = nn.ReLU(inplace=True)
#         self.batchnorm_1 = nn.BatchNorm2d(int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5))
#         self.batchnorm_2 = nn.BatchNorm2d(int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5))
#         self.batchnorm_1 = nn.BatchNorm2d(self.W)
#         self.batchnorm_2 = nn.BatchNorm2d(self.W)
        self.norm = nn.LayerNorm(self.W)
        
    def forward(self, x):
        out = []
        print(x.shape) 
        print("\n W=alpha*U :{}\n".format(self.W))
        x = self.one_ch(x) 
        res = self.residual_layer(x)
        res_out = self.maxpool(res)
        res_out = (res_out.flatten(2)).transpose(1,2)
        out.append(res_out)
        print("\n res:{}\n".format(res_out.shape))
        
#         sbs = self.conv3x3(x)
#         sbs_out = self.maxpool(sbs)
#         sbs_out = (sbs_out.flatten(2)).transpose(1,2)
#         print("\n out_3*3:{}\n".format(sbs_out.shape))
#         out.append(sbs_out)
        
#         obo = self.conv5x5(sbs)
#         obo_out = self.maxpool(obo)
#         obo_out = (obo_out.flatten(2)).transpose(1,2)
#         out.append(obo_out)
#         print("\n out_5*5:{}\n".format(obo_out.shape))
        
#         cbc = self.conv7x7(obo)
#         cbc_out = self.maxpool(cbc)
#         cbc_out = (cbc_out.flatten(2)).transpose(1,2)
#         print("\n out_7*7:{}\n".format(cbc_out.shape))
#         out.append(cbc_out)
        all_t = torch.cat(out, 1)
        all_t = self.norm(all_t)
        print("\n cat_together:{}\n".format(all_t.shape))
#         all_t_b = self.batchnorm_1(all_t)
#         out = all_t_b.add(res)
#         out = self.relu(out)
#         out = self.batchnorm_2(out)      
        
        return all_t
    
class MultiResBlock_3(nn.Module):
    def __init__(self,in_ch,U,branch=1,downsample=2, alpha=1):
        super(MultiResBlock_3,self).__init__()
        self.W = alpha * U
        self.one_ch = conv_block(in_ch, 1)
#         self.residual_layer = conv_block(1, self.W, 1, 1, 0, act=None)
        self.residual_layer = conv_block(1, self.W)
#         self.conv3x3 = conv_block(1, int(self.W*0.167))
#         self.conv5x5 = conv_block(int(self.W*0.167), int(self.W*0.333))
#         self.conv7x7 = conv_block(int(self.W*0.333), self.W-int(self.W*0.167)-int(self.W*0.333))
        self.conv3x3 = conv_block(1, int(self.W))
        self.conv5x5 = conv_block(int(self.W), int(self.W))
        self.conv7x7 = conv_block(int(self.W), self.W)
        self.maxpool = nn.MaxPool2d(downsample, stride=downsample)
        self.relu = nn.ReLU(inplace=True)
#         self.batchnorm_1 = nn.BatchNorm2d(int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5))
#         self.batchnorm_2 = nn.BatchNorm2d(int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5))
#         self.batchnorm_1 = nn.BatchNorm2d(self.W)
#         self.batchnorm_2 = nn.BatchNorm2d(self.W)
        self.norm = nn.LayerNorm(self.W)
        
    def forward(self, x):
        out = []
        print(x.shape) 
        print("\n W=alpha*U :{}\n".format(self.W))
        x = self.one_ch(x) 
#         res = self.residual_layer(x)
#         res_out = self.maxpool(res)
#         res_out = (res_out.flatten(2)).transpose(1,2)
#         out.append(res_out)
#         print("\n res:{}\n".format(res_out.shape))
        
        sbs = self.conv3x3(x)
        sbs_out = self.maxpool(sbs)
        sbs_out = (sbs_out.flatten(2)).transpose(1,2)
        print("\n out_3*3:{}\n".format(sbs_out.shape))
        out.append(sbs_out)
        
#         obo = self.conv5x5(sbs)
#         obo_out = self.maxpool(obo)
#         obo_out = (obo_out.flatten(2)).transpose(1,2)
#         out.append(obo_out)
#         print("\n out_5*5:{}\n".format(obo_out.shape))
        
#         cbc = self.conv7x7(obo)
#         cbc_out = self.maxpool(cbc)
#         cbc_out = (cbc_out.flatten(2)).transpose(1,2)
#         print("\n out_7*7:{}\n".format(cbc_out.shape))
#         out.append(cbc_out)
        all_t = torch.cat(out, 1)
        all_t = self.norm(all_t)
        print("\n cat_together:{}\n".format(all_t.shape))
#         all_t_b = self.batchnorm_1(all_t)
#         out = all_t_b.add(res)
#         out = self.relu(out)
#         out = self.batchnorm_2(out)      
        
        return all_t
    
class MultiResBlock_5(nn.Module):
    def __init__(self,in_ch,U,branch=1,downsample=2, alpha=1):
        super(MultiResBlock_5,self).__init__()
        self.W = alpha * U
        self.one_ch = conv_block(in_ch, 1)
#         self.residual_layer = conv_block(1, self.W, 1, 1, 0, act=None)
        self.residual_layer = conv_block(1, self.W)
#         self.conv3x3 = conv_block(1, int(self.W*0.167))
#         self.conv5x5 = conv_block(int(self.W*0.167), int(self.W*0.333))
#         self.conv7x7 = conv_block(int(self.W*0.333), self.W-int(self.W*0.167)-int(self.W*0.333))
        self.conv3x3 = conv_block(1, int(self.W))
        self.conv5x5 = conv_block(int(self.W), int(self.W))
        self.conv7x7 = conv_block(int(self.W), self.W)
        self.maxpool = nn.MaxPool2d(downsample, stride=downsample)
        self.relu = nn.ReLU(inplace=True)
#         self.batchnorm_1 = nn.BatchNorm2d(int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5))
#         self.batchnorm_2 = nn.BatchNorm2d(int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5))
#         self.batchnorm_1 = nn.BatchNorm2d(self.W)
#         self.batchnorm_2 = nn.BatchNorm2d(self.W)
        self.norm = nn.LayerNorm(self.W)
        
    def forward(self, x):
        out = []
        print(x.shape) 
        print("\n W=alpha*U :{}\n".format(self.W))
        x = self.one_ch(x) 
#         res = self.residual_layer(x)
#         res_out = self.maxpool(res)
#         res_out = (res_out.flatten(2)).transpose(1,2)
#         out.append(res_out)
#         print("\n res:{}\n".format(res_out.shape))
        
        sbs = self.conv3x3(x)
#         sbs_out = self.maxpool(sbs)
#         sbs_out = (sbs_out.flatten(2)).transpose(1,2)
#         print("\n out_3*3:{}\n".format(sbs_out.shape))
#         out.append(sbs_out)
        
        obo = self.conv5x5(sbs)
        obo_out = self.maxpool(obo)
        obo_out = (obo_out.flatten(2)).transpose(1,2)
        out.append(obo_out)
        print("\n out_5*5:{}\n".format(obo_out.shape))
        
#         cbc = self.conv7x7(obo)
#         cbc_out = self.maxpool(cbc)
#         cbc_out = (cbc_out.flatten(2)).transpose(1,2)
#         print("\n out_7*7:{}\n".format(cbc_out.shape))
#         out.append(cbc_out)
        
        all_t = torch.cat(out, 1)
        all_t = self.norm(all_t)
        print("\n cat_together:{}\n".format(all_t.shape))
#         all_t_b = self.batchnorm_1(all_t)
#         out = all_t_b.add(res)
#         out = self.relu(out)
#         out = self.batchnorm_2(out)      
        
        return all_t
    
    
    
class MultiResBlock_7(nn.Module):
    def __init__(self,in_ch,U,branch=1,downsample=2, alpha=1):
        super(MultiResBlock_7,self).__init__()
        self.W = alpha * U
        self.one_ch = conv_block(in_ch, 1)
#         self.residual_layer = conv_block(1, self.W, 1, 1, 0, act=None)
        self.residual_layer = conv_block(1, self.W)
#         self.conv3x3 = conv_block(1, int(self.W*0.167))
#         self.conv5x5 = conv_block(int(self.W*0.167), int(self.W*0.333))
#         self.conv7x7 = conv_block(int(self.W*0.333), self.W-int(self.W*0.167)-int(self.W*0.333))
        self.conv3x3 = conv_block(1, int(self.W))
        self.conv5x5 = conv_block(int(self.W), int(self.W))
        self.conv7x7 = conv_block(int(self.W), self.W)
        self.maxpool = nn.MaxPool2d(downsample, stride=downsample)
        self.relu = nn.ReLU(inplace=True)
#         self.batchnorm_1 = nn.BatchNorm2d(int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5))
#         self.batchnorm_2 = nn.BatchNorm2d(int(self.W*0.167)+int(self.W*0.333)+int(self.W*0.5))
#         self.batchnorm_1 = nn.BatchNorm2d(self.W)
#         self.batchnorm_2 = nn.BatchNorm2d(self.W)
        self.norm = nn.LayerNorm(self.W)
        
    def forward(self, x):
        out = []
        print(x.shape) 
        print("\n W=alpha*U :{}\n".format(self.W))
        x = self.one_ch(x) 
#         res = self.residual_layer(x)
#         res_out = self.maxpool(res)
#         res_out = (res_out.flatten(2)).transpose(1,2)
#         out.append(res_out)
#         print("\n res:{}\n".format(res_out.shape))
        
        sbs = self.conv3x3(x)
#         sbs_out = self.maxpool(sbs)
#         sbs_out = (sbs_out.flatten(2)).transpose(1,2)
#         print("\n out_3*3:{}\n".format(sbs_out.shape))
#         out.append(sbs_out)
        
        obo = self.conv5x5(sbs)
#         obo_out = self.maxpool(obo)
#         obo_out = (obo_out.flatten(2)).transpose(1,2)
#         out.append(obo_out)
#         print("\n out_5*5:{}\n".format(obo_out.shape))
        
        cbc = self.conv7x7(obo)
        cbc_out = self.maxpool(cbc)
        cbc_out = (cbc_out.flatten(2)).transpose(1,2)
        print("\n out_7*7:{}\n".format(cbc_out.shape))
        out.append(cbc_out)
        all_t = torch.cat(out, 1)
        all_t = self.norm(all_t)
        print("\n cat_together:{}\n".format(all_t.shape))
#         all_t_b = self.batchnorm_1(all_t)
#         out = all_t_b.add(res)
#         out = self.relu(out)
#         out = self.batchnorm_2(out)      
        
        return all_t

In [156]:
print("\n\n----branch 2: inception module------")



----branch 2: inception module------


In [157]:
print(inputs.shape)
print(input_dim)
print(in_dim[i_stage])
# _, _, H2, _= inputs.shape
# resInception = MultiResBlock_1357(input_dim,in_dim[i_stage],branch=1,downsample=strides[i_stage],alpha=1)
resInception = MultiResBlock_135(input_dim,in_dim[i_stage],branch=1,downsample=strides[i_stage],alpha=1)
nfx2 = resInception(inputs)
print(nfx2.shape)
# print(outputs[1].shape)
# print(outputs[2].shape)
# print(outputs[3].shape)



torch.Size([4, 64, 56, 56])
64
128
torch.Size([4, 64, 56, 56])

 W=alpha*U :128


 res:torch.Size([4, 784, 128])


 out_3*3:torch.Size([4, 784, 128])


 out_5*5:torch.Size([4, 784, 128])



NameError: name 'cbc_out' is not defined

We can see that the dilation will change the spatial size of the output.
To make the size stay the same as that in the original paper, I want to use new **padding or stride**.
I set padding[0] from 3 to 5.


![image.png](attachment:image.png)

So far we have seen how the patch merging patching the inputs[1,3, 224, 224] to the output [1, 3136, 64] 
where 3136 = 56\*56 and C=64

# Concatenate the two sequences

In [137]:
nfx_cat = torch.cat((nfx1,nfx2),1)
print("The shape of nfx_cat is {}".format(nfx_cat.shape))

The shape of nfx_cat is torch.Size([4, 2352, 128])


# Input of Transformer

In [138]:
from EfficientMISSFormer import *

In [139]:
# From MISSFormer.py class BridgeLayer_4
# From Transception.py line83 forward part
# FromEfficientAttention to FuseEfficientAttention
class FuseEfficientAttention(nn.Module):
    """
        input  -> x:[B, N, D]
        output ->   [B, N, D]
    
        in_channels:    int -> Embedding Dimension  d
        key_channels:   int -> Key Embedding Dimension,   Best: (in_channels)
        value_channels: int -> Value Embedding Dimension, Best: (in_channels or in_channels//2) 
        head_count:     int -> It divides the embedding dimension by the head_count and process each part individually
        
        Conv2D # of Params:  ((k_h * k_w * C_in) + 1) * C_out)
    """
    
    def __init__(self, in_channels, key_channels, value_channels, head_count=1):
        super().__init__()
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.head_count = head_count
        self.value_channels = value_channels

        self.keys = nn.Linear(in_channels, key_channels, bias=True) 
        self.queries = nn.Linear(in_channels, key_channels, bias=True)
        self.values = nn.Linear(in_channels, value_channels,bias=True)
        self.reprojection = nn.Linear(value_channels, in_channels)

        
    def forward(self, input_):
        b, n, _ = input_.size()
        print("\nb is {}, n is {}\n".format(b, n))
        print("\nin_channels is {}, key_channels is {}\n".format( self.in_channels, self.key_channels))
        # n, _,  = input_.size()
        keys = self.keys(input_)#B N D
        keys = keys.reshape((b, self.key_channels, n))#B dk N
        queries = self.queries(input_).reshape(b, self.key_channels, n)#b dk n
        values = self.values(input_).reshape((b, self.value_channels, n))# b dv n
        
        head_key_channels = self.key_channels // self.head_count
        head_value_channels = self.value_channels // self.head_count
        
        attended_values = []
        for i in range(self.head_count):
            key = F.softmax(keys[
                :,
                i * head_key_channels: (i + 1) * head_key_channels,
                :
            ], dim=2)
            
            query = F.softmax(queries[
                :,
                i * head_key_channels: (i + 1) * head_key_channels,
                :
            ], dim=1)
                        
            value = values[
                :,
                i * head_value_channels: (i + 1) * head_value_channels,
                :
            ]            
            
            context = key @ value.transpose(1, 2) # dk*dv B dk N* B N dv=B dk dv
            attended_value = (context.transpose(1, 2) @ query).reshape(b, head_value_channels, n) 
            # (b dv dk @ b dk n)->b dv n         
            attended_values.append(attended_value)
                
        aggregated_values = torch.cat(attended_values, dim=1).permute(0,2,1)
        #b n h*dv
        attention = self.reprojection(aggregated_values)# b n d

        return attention

![image.png](attachment:image.png)

In [140]:
class MixFFN_skip_fuse(nn.Module):
    def __init__(self, c1, c2):
        super().__init__()
        self.fc1 = nn.Linear(c1, c2)
        self.dwconv = DWConv(c2)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(c2, c1)
        self.norm1 = nn.LayerNorm(c2)
        self.norm2 = nn.LayerNorm(c2)
        self.norm3 = nn.LayerNorm(c2)
    def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
        ax = self.act(self.norm1(self.dwconv(self.fc1(x), H, W)+self.fc1(x)))
        out = self.fc2(ax)
        return out
    

    
#functions
norm1 = nn.LayerNorm(in_dim[i_stage])
norm2 = nn.LayerNorm(in_dim[i_stage])
attn = FuseEfficientAttention(in_channels=in_dim[i_stage], key_channels=key_dim[i_stage],value_channels=value_dim[0], head_count=1)
if token_mlp=='mix':
    mlp = MixFFN(in_dim[i_stage], int(in_dim[i_stage]*4))  
elif token_mlp=='mix_skip':
    mlp = MixFFN_skip_fuse(in_dim[i_stage], int(in_dim[i_stage]*4)) 
else:
    mlp = MLP_FFN(in_dim, int(in_dim[i_stage]*4))
    
    
#forward

# norm_1 = Rearrange('b n d -> b d n')(norm_1)#Or view??
norm_1 = norm1(nfx_cat)
atten = attn(norm_1)
atten = norm2(atten)
print(atten.size())

tx = nfx_cat + atten
tx = norm2(tx)
print(tx.size())
#This is the output of attention


b is 4, n is 2352


in_channels is 128, key_channels is 128

torch.Size([4, 2352, 128])
torch.Size([4, 2352, 128])


Now we could consider how to decompose the two part. 

In [141]:
_,tx_len,_ = tx.shape
z_total = []
for nz in range(int(tx_len/nfx1_len)):
    z = tx[:, nz*nfx1_len:(nz+1)*nfx1_len, :]
    z_total.append(z)
    print( z.shape)

# z2 = z2.view(1,-1,in_dim[i_stage])
print("check z_total: {}".format(len(z_total)))


torch.Size([4, 784, 128])
torch.Size([4, 784, 128])
torch.Size([4, 784, 128])
check z_total: 3


In [142]:
# print(H1, H2)

In [143]:
#pass the z1 and z2 through the next conv layer
b, _, _ = z_total[0].shape
# mx_total = []
map_mx_total = []
for nz in range(int(tx_len/nfx1_len)):
    mx = z_total[nz]+ mlp(z_total[nz], H1, W1)
#     mx_total.append(mx)
# mx1 = z1 + mlp1(z1, H2, W1)
# mx2 = z2 + mlp2(z2, H2, W2)
    print("check mx: {}".format(mx.shape))
    map_mx = mx.view(b,H1,W1,-1)
    
    
# print("check mx2: {}".format(z2.shape))
#Map them to the spatial dimension

# map_mx1 = mx1.view(b,H1,W1,-1)
# map_mx2 = mx2.view(b,H2,W2,-1)
# print("check map_mx1: {}".format(map_mx1.shape))
# print("check map_mx2: {}".format(map_mx2.shape))
    map_mx = map_mx.permute(0,3,1,2)
    map_mx_total.append(map_mx)
    print("check map_mx: {}".format(map_mx.shape))

check mx: torch.Size([4, 784, 128])
check map_mx: torch.Size([4, 128, 28, 28])
check mx: torch.Size([4, 784, 128])
check map_mx: torch.Size([4, 128, 28, 28])
check mx: torch.Size([4, 784, 128])
check map_mx: torch.Size([4, 128, 28, 28])


use 1by1 conv to combine the two maps to [1, 56, 56,c]???
Now the question turns to how to concatenate the layers.

https://pytorch.org/docs/stable/generated/torch.cat.html
torch.cat((x, x, x), 1)

In [144]:
# Concatenate together and conv1_1
cat_maps = torch.cat(map_mx_total,1)
print(cat_maps.shape)      
_, in_channel,_, _ = cat_maps.shape
conv1_1 = nn.Conv2d(in_channel, in_dim[i_stage], 1)
cat_maps = conv1_1(cat_maps)
print(cat_maps.shape)



torch.Size([4, 384, 28, 28])
torch.Size([4, 128, 28, 28])
