In [13]:
# References:
# https://blog.csdn.net/qq_40211493/article/details/106790238
# https://blog.csdn.net/qq_37280534/article/details/115374551

In [14]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import torch.nn.functional as F

In [16]:
class InceptionA(nn.Module):
    """GoogLeNet-style Inception block with four parallel, size-preserving branches.

    Branches (every one keeps H and W unchanged):
      * 1x1 conv                                  -> 16 channels
      * 1x1 conv (16) -> 5x5 conv, padding 2       -> 24 channels
      * 1x1 conv (16) -> 3x3 -> 3x3, padding 1     -> 24 channels
      * 3x3 avg-pool (stride 1, pad 1) -> 1x1 conv -> 24 channels

    The branch outputs are concatenated along dim 1, so the result has
    16 + 24 + 24 + 24 = 88 channels. ``forward`` prints each branch's shape
    as a debugging aid.

    Args:
        in_channels: number of channels in the input tensor.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Construction order is kept stable so parameter initialisation consumes
        # the RNG in the same sequence (reproducible weights under a fixed seed).
        self.branch1_1 = nn.Conv2d(in_channels, 16, kernel_size=1)

        self.branch5_5_1 = nn.Conv2d(in_channels, 16, kernel_size=1)
        self.branch5_5_2 = nn.Conv2d(16, 24, kernel_size=5, padding=2)

        self.branch3_3_1 = nn.Conv2d(in_channels, 16, kernel_size=1)
        self.branch3_3_2 = nn.Conv2d(16, 24, kernel_size=3, padding=1)
        self.branch3_3_3 = nn.Conv2d(24, 24, kernel_size=3, padding=1)

        self.branch_pool = nn.Conv2d(in_channels, 24, kernel_size=1)

    def forward(self, x):
        one = self.branch1_1(x)
        print('\n out_branch1_1: {}'.format(one.shape))

        five = self.branch5_5_2(self.branch5_5_1(x))
        print('\n out_branch5_5: {}'.format(five.shape))

        three = self.branch3_3_3(self.branch3_3_2(self.branch3_3_1(x)))
        print('\n out_branch3_3: {}'.format(three.shape))

        pooled = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
        print('\n branch_pool: {}'.format(pooled.shape))

        return torch.cat((one, five, three, pooled), dim=1)
    

In [17]:
# Single random image, NCHW layout: 1 sample, 3 channels, input_size x input_size.
input_size = 644
inputs = torch.rand((1, 3, input_size, input_size))


In [19]:
incept = InceptionA(in_channels=3)
# The forward pass prints each branch's shape, then returns their channel-wise concatenation.
incept_out = incept(inputs)
print('\n incept_out: {}'.format(incept_out.shape))


 out_branch1_1: torch.Size([1, 16, 644, 644])

 out_branch5_5: torch.Size([1, 24, 644, 644])

 out_branch3_3: torch.Size([1, 24, 644, 644])

 branch_pool: torch.Size([1, 24, 644, 644])

 incept_out: torch.Size([1, 88, 644, 644])


# MultiResUNet: inception-style MultiRes block

In [20]:
# Adapted (as 2-D layers) from: https://github.com/Cassieyy/MultiResUnet3D/blob/main/MultiResUnet3D.py

class conv_block(nn.Module):
    """Conv2d -> BatchNorm2d -> optional activation, stored as ``self.conv``.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``
            (defaults 3 / 1 / 1 keep H and W unchanged).
        act: ``'relu'`` (default, in-place ReLU), ``'sigmoid'``, or ``None``
            for no activation.

    Raises:
        ValueError: if ``act`` is not one of the supported values. Failing here
            gives a clear error at construction time instead of an
            AttributeError on ``self.conv`` the first time ``forward`` runs.
    """

    def __init__(self, ch_in, ch_out, kernel_size=3, stride=1, padding=1, act='relu'):
        super(conv_block, self).__init__()
        layers = [
            nn.Conv2d(ch_in, ch_out, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(ch_out),
        ]
        if act is None:
            pass  # linear output: conv + BN only
        elif act == 'relu':
            layers.append(nn.ReLU(inplace=True))
        elif act == 'sigmoid':
            layers.append(nn.Sigmoid())
        else:
            raise ValueError(
                "unsupported act {!r}; expected None, 'relu' or 'sigmoid'".format(act)
            )
        # Same Sequential layout as before (conv.0 = Conv2d, conv.1 = BatchNorm2d,
        # conv.2 = activation) so state_dict keys stay compatible.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
    


class res_block(nn.Module):
    """Residual unit computing ``BN(ReLU(shortcut(x) + main(x)))``.

    * ``res`` (shortcut): 1x1 conv + BatchNorm, no activation, ch_in -> ch_out.
    * ``main``: 3x3 conv_block (conv + BN + ReLU), ch_in -> ch_out.
    * ``bn``: BatchNorm over the ch_out-channel sum.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
    """

    def __init__(self, ch_in, ch_out):
        super(res_block, self).__init__()
        self.res = conv_block(ch_in, ch_out, 1, 1, 0, None)
        self.main = conv_block(ch_in, ch_out)
        # The summed branches have ch_out channels, so the final BN must be sized
        # with ch_out; sizing it with ch_in failed whenever ch_in != ch_out.
        self.bn = nn.BatchNorm2d(ch_out)

    def forward(self, x):
        out = self.res(x) + self.main(x)
        # In-place is safe: `out` is a fresh tensor and add's backward does not need it.
        out = F.relu(out, inplace=True)
        return self.bn(out)
    


class MultiResBlock(nn.Module):
    """Experimental MultiResUNet-style block (2-D adaptation).

    ``forward`` does the following:
      1. ``one_ch``: 3x3 conv_block that squeezes ``in_ch`` channels down to 1.
      2. ``residual_layer``: 3x3 conv_block, 1 -> W, applied to the squeezed map.
      3. A chain of three 3x3 conv_blocks (``conv3x3`` -> ``conv5x5`` -> ``conv7x7``).
         Stacking them gives receptive fields of 3, 5 and 7 on the squeezed map,
         which is where the names come from. Each chain output is max-pooled
         (2x2, stride 2) before it is collected.

    Returns:
        list ``[res, pool(c3), pool(c5), pool(c7)]``, each with W channels.

    Debug prints show the input shape, W, the residual shape, and the pre-pool
    shape of each chain stage.

    NOTE(review):
      * ``alpha`` and ``branch`` are accepted but unused; W == U because the alpha
        scaling is commented out, although the debug message still reads
        "W=alpha*U".
      * The residual keeps the input H x W, while the three chain outputs are
        pooled to H//2 x W//2. The returned tensors therefore cannot be
        concatenated along channels as-is — confirm this is intended.
      * ``relu``, ``batchnorm_1`` and ``batchnorm_2`` are built but not used in
        ``forward``. They are left over from the disabled concat/residual-add
        path and are kept so the parameter set is unchanged.
    """

    def __init__(self, in_ch, U, branch=1, alpha=1.67):
        super().__init__()
        width = U  # W = alpha * U is disabled; width equals U
        self.W = width
        # Submodules are created in a fixed order so initialisation consumes the
        # RNG in the same sequence.
        self.one_ch = conv_block(in_ch, 1)
        self.residual_layer = conv_block(1, width)
        self.conv3x3 = conv_block(1, int(width))
        self.conv5x5 = conv_block(int(width), int(width))
        self.conv7x7 = conv_block(int(width), width)
        self.maxpool = nn.MaxPool2d(2, stride=2)
        self.relu = nn.ReLU(inplace=True)
        self.batchnorm_1 = nn.BatchNorm2d(width)
        self.batchnorm_2 = nn.BatchNorm2d(width)

    def forward(self, x):
        print(x.shape)  # e.g. torch.Size([1, 64, 7, 7]) in the demo below
        print("\n W=alpha*U :{}\n".format(self.W))

        squeezed = self.one_ch(x)
        residual = self.residual_layer(squeezed)
        print("\n res:{}\n".format(residual.shape))

        collected = [residual]
        stages = (
            (self.conv3x3, "\n out_3*3:{}\n"),
            (self.conv5x5, "\n out_5*5:{}\n"),
            (self.conv7x7, "\n out_7*7:{}\n"),
        )
        feature = squeezed
        for conv, message in stages:
            feature = conv(feature)
            print(message.format(feature.shape))  # shape before pooling
            collected.append(self.maxpool(feature))
        return collected

In [21]:
print("\n\n----branch 2: res_inception------")
in_ch, out_ch = 64, 128
inputs = torch.randn(1, in_ch, 7, 7)  # BCHW: batch, channels, height, width
print(inputs.shape)



----branch 2: res_inception------
torch.Size([1, 64, 7, 7])


In [22]:
# Use the configured out_ch (rather than a duplicated literal 128) as the block width U.
model = MultiResBlock(in_ch, out_ch)
outputs = model(inputs)
print(len(outputs))

torch.Size([1, 64, 7, 7])

 W=alpha*U :128


 res:torch.Size([1, 128, 7, 7])


 out_3*3:torch.Size([1, 128, 7, 7])


 out_5*5:torch.Size([1, 128, 7, 7])


 out_7*7:torch.Size([1, 128, 7, 7])

4
