In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
print(torch.__version__)
from torchsummary import summary

1.0.1.post2


In [5]:
def conv_2d(inp, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    '''
    Apply a freshly created Conv2d followed by BatchNorm2d to ``inp``.

    Arguments:
        inp {torch.Tensor} -- input batch whose channel dim equals ``in_channels``
            (NCHW, as required by torch.nn.Conv2d)
        in_channels {int} -- number of channels in the input
        out_channels {int} -- number of channels produced by the convolution
        kernel_size {int or tuple} -- size of the convolving kernel
        stride {int or tuple} -- stride of the convolution (default 1)
        padding {int or tuple} -- zero-padding added to both sides (default 0)
        dilation {int or tuple} -- spacing between kernel elements (default 1)
        groups {int} -- number of blocked input->output connections (default 1)
        bias {bool} -- add a learnable bias to the convolution (default True)

    Returns:
        [torch.Tensor] -- batch-normalised convolution output with ``out_channels`` channels

    Note:
        Both layers are constructed inside this call, so their parameters are not
        registered on any nn.Module and are re-created on every call.
    '''
    conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
                           padding=padding, dilation=dilation, groups=groups, bias=bias)
    # BatchNorm2d takes the channel count C of an (N, C, H, W) input, not a tensor.
    bn = torch.nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=True,
                              track_running_stats=True)
    # New layers live on the CPU in float32; follow the input so CUDA inputs work.
    conv = conv.to(device=inp.device, dtype=inp.dtype)
    bn = bn.to(device=inp.device, dtype=inp.dtype)
    return bn(conv(inp))

In [6]:
def trans_conv_2dtorch(inp, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1):
    '''
    Apply a freshly created ConvTranspose2d followed by BatchNorm2d to ``inp``.

    Arguments:
        inp {torch.Tensor} -- input batch whose channel dim equals ``in_channels`` (NCHW)
        in_channels {int} -- number of channels in the input
        out_channels {int} -- number of channels produced by the transposed conv
        kernel_size {int or tuple} -- size of the kernel
        stride {int or tuple} -- stride (default 1); kernel 2 / stride 2 doubles H and W
        padding {int or tuple} -- implicit padding removed from the output (default 0)
        output_padding {int or tuple} -- extra size added to one side of the output (default 0)
        groups {int} -- number of blocked input->output connections (default 1)
        bias {bool} -- add a learnable bias (default True)
        dilation {int or tuple} -- spacing between kernel elements (default 1)

    Returns:
        [torch.Tensor] -- batch-normalised output with ``out_channels`` channels

    Note:
        Both layers are constructed inside this call, so their parameters are not
        registered on any nn.Module and are re-created on every call.
    '''
    # Forward the caller's arguments (previously every one was hard-coded).
    deconv = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride,
                                      padding=padding, output_padding=output_padding,
                                      groups=groups, bias=bias, dilation=dilation)
    # BatchNorm2d takes the channel count, not a tensor.
    bn = torch.nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=True,
                              track_running_stats=True)
    deconv = deconv.to(device=inp.device, dtype=inp.dtype)
    bn = bn.to(device=inp.device, dtype=inp.dtype)
    return bn(deconv(inp))

In [7]:
def inblock(inp, inchannel):
    '''
    Stem block: a 1x1 projection followed by a chain of three 3x3 convolutions.

    Arguments:
        inp {torch.Tensor} -- input batch (NCHW) with ``inchannel`` channels
        inchannel {int} -- number of channels in ``inp``

    Returns:
        [torch.Tensor] -- 102 channels (51 from the 1x1 shortcut + 8 + 17 + 26 from
        the 3x3 chain) at the same H and W as ``inp``
    '''
    shortcut = conv_2d(inp, inchannel, 51, (1, 1))
    # The 3x3 convs are chained: the first one consumes the 51-channel shortcut.
    # padding=1 keeps H and W so every branch can be concatenated.
    conv3x3 = conv_2d(shortcut, 51, 8, (3, 3), padding=1)
    conv5x5 = conv_2d(conv3x3, 8, 17, (3, 3), padding=1)   # two stacked 3x3 ~ 5x5 field
    conv7x7 = conv_2d(conv5x5, 17, 26, (3, 3), padding=1)  # three stacked 3x3 ~ 7x7 field

    # dim=1 is the channel axis of an NCHW tensor.
    out = torch.cat((conv3x3, conv5x5, conv7x7), dim=1)
    # Normalise with the batch's own statistics (BatchNorm2d(tensor) was invalid).
    out = torch.nn.functional.batch_norm(out, None, None, training=True)

    out = torch.cat((shortcut, out), dim=1)
    out = torch.nn.functional.relu(out, inplace=False)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    return out

In [8]:
def resblock_A1(inp, filter_size):
    '''
    Three-branch residual-style block (encoder stage A1).

    Branches on ReLU(inp): 1x1 | 1x1 -> 3x3 | 1x1 -> 3x3 (48) -> 3x3 (64); they are
    concatenated, merged by a 1x1 conv to 384 channels and concatenated with ReLU(inp).

    Arguments:
        inp {torch.Tensor} -- input batch (NCHW); 102 channels when fed by inblock
        filter_size {int} -- width of the 1x1 reductions and of branch 2's 3x3

    Returns:
        [torch.Tensor] -- C_in + 384 channels (486 for a 102-channel input), same H and W
    '''
    inp = torch.nn.functional.relu(inp, inplace=False)
    in_ch = inp.shape[1]

    B1 = conv_2d(inp, in_ch, filter_size, (1, 1))

    B2 = conv_2d(inp, in_ch, filter_size, (1, 1))
    B2 = conv_2d(B2, filter_size, filter_size, (3, 3), padding=1)

    B3 = conv_2d(inp, in_ch, filter_size, (1, 1))
    B3 = conv_2d(B3, filter_size, 48, (3, 3), padding=1)
    B3 = conv_2d(B3, 48, 64, (3, 3), padding=1)

    # Channel-axis concat: 2 * filter_size + 64 channels.
    out = torch.cat((B1, B2, B3), dim=1)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    out = conv_2d(out, out.shape[1], 384, (1, 1))

    out = torch.cat((inp, out), dim=1)
    out = torch.nn.functional.relu(out, inplace=False)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    return out

In [9]:
def Path_1(inp):
    '''
    Skip path taken from the first encoder stage; reduces it to 32 channels.

    Two residual stages: (1x1 shortcut + two 3x3 convs), then (identity + two 3x3
    convs). Each stage is followed by ReLU and batch normalisation.

    Arguments:
        inp {torch.Tensor} -- input batch (NCHW)

    Returns:
        [torch.Tensor] -- 32 channels, same H and W as ``inp``
    '''
    in_ch = inp.shape[1]

    shortcut = conv_2d(inp, in_ch, 32, (1, 1))
    shortcut = torch.nn.functional.relu(shortcut, inplace=False)

    # padding=1 keeps H and W equal to the shortcut's so the add is valid.
    out = conv_2d(inp, in_ch, 32, (3, 3), padding=1)
    out = torch.nn.functional.relu(out, inplace=False)
    out = conv_2d(out, 32, 32, (3, 3), padding=1)
    out = torch.nn.functional.relu(out, inplace=False)

    out = torch.add(shortcut, out)
    out = torch.nn.functional.relu(out, inplace=False)
    branch = torch.nn.functional.batch_norm(out, None, None, training=True)

    out = conv_2d(branch, 32, 32, (3, 3), padding=1)
    out = torch.nn.functional.relu(out, inplace=False)
    out = conv_2d(out, 32, 32, (3, 3), padding=1)
    out = torch.nn.functional.relu(out, inplace=False)

    out = torch.add(branch, out)
    out = torch.nn.functional.relu(out, inplace=False)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    return out

In [11]:
def reduction_A(inp):
    '''
    Reduction block A: downsamples with three parallel stride-2 branches.

    Branches: 3x3/s2 max-pool | 3x3/s2 conv (384) | 1x1 (192) -> 3x3 (224) -> 3x3/s2 (256).

    Arguments:
        inp {torch.Tensor} -- input batch (NCHW)

    Returns:
        [torch.Tensor] -- C_in + 640 channels (1126 for a 486-channel input), spatial
        size floor((H - 3) / 2) + 1
    '''
    in_ch = inp.shape[1]
    m = torch.nn.MaxPool2d(3, stride=2)
    pooling = m(inp)

    # conv_2d's keyword is `stride` (`strides` raised a TypeError).
    B1 = conv_2d(inp, in_ch, 384, (3, 3), stride=2)

    B2 = conv_2d(inp, in_ch, 192, (1, 1))
    # padding=1 keeps H and W so the stride-2 conv below lines up with the pool branch.
    B2 = conv_2d(B2, 192, 224, (3, 3), padding=1)
    B2 = conv_2d(B2, 224, 256, (3, 3), stride=2)

    # dim=1 is the channel axis of an NCHW tensor.
    out = torch.cat((B1, B2, pooling), dim=1)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    return out

In [None]:
def resblock_A2(inp, filter_size):
    '''
    Three-branch residual-style block (encoder stage A2).

    Branches on ReLU(inp): 1x1 | 1x1 -> 3x3 | 1x1 -> 3x3 (48) -> 3x3 (128); they are
    concatenated, merged by a 1x1 conv to 384 channels and concatenated with ReLU(inp).

    Arguments:
        inp {torch.Tensor} -- input batch (NCHW); 486 channels when fed by resblock_A1
        filter_size {int} -- width of the 1x1 reductions and of branch 2's 3x3

    Returns:
        [torch.Tensor] -- C_in + 384 channels (870 for a 486-channel input), same H and W
    '''
    inp = torch.nn.functional.relu(inp, inplace=False)
    in_ch = inp.shape[1]

    B1 = conv_2d(inp, in_ch, filter_size, (1, 1))

    B2 = conv_2d(inp, in_ch, filter_size, (1, 1))
    B2 = conv_2d(B2, filter_size, filter_size, (3, 3), padding=1)

    B3 = conv_2d(inp, in_ch, filter_size, (1, 1))
    B3 = conv_2d(B3, filter_size, 48, (3, 3), padding=1)
    B3 = conv_2d(B3, 48, 128, (3, 3), padding=1)

    # Channel-axis concat: 2 * filter_size + 128 channels.
    out = torch.cat((B1, B2, B3), dim=1)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    out = conv_2d(out, out.shape[1], 384, (1, 1))

    out = torch.cat((inp, out), dim=1)
    out = torch.nn.functional.relu(out, inplace=False)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    return out

In [None]:
def reduction_B(inp):
    '''
    Reduction block B: downsamples with four parallel stride-2 branches.

    Branches: 3x3/s2 max-pool | 1x1 (256) -> 3x3/s2 (384) | 1x1 (256) -> 3x3/s2 (288) |
    1x1 (256) -> 1x1 (288) -> 3x3/s2 (320).

    Arguments:
        inp {torch.Tensor} -- input batch (NCHW)

    Returns:
        [torch.Tensor] -- C_in + 992 channels, spatial size floor((H - 3) / 2) + 1
    '''
    in_ch = inp.shape[1]
    m = torch.nn.MaxPool2d(3, stride=2)
    pooling = m(inp)

    B1 = conv_2d(inp, in_ch, 256, (1, 1))
    B1 = conv_2d(B1, 256, 384, (3, 3), stride=2)

    # The stride-2 convs use a 3x3 kernel (was 1x1): a 1x1/s2 conv yields
    # floor((H - 1) / 2) + 1 rows, which never matches the 3x3/s2 pool branch.
    B2 = conv_2d(inp, in_ch, 256, (1, 1))
    B2 = conv_2d(B2, 256, 288, (3, 3), stride=2)

    B3 = conv_2d(inp, in_ch, 256, (1, 1))
    B3 = conv_2d(B3, 256, 288, (1, 1))
    B3 = conv_2d(B3, 288, 320, (3, 3), stride=2)

    # dim=1 is the channel axis of an NCHW tensor.
    out = torch.cat((pooling, B1, B2, B3), dim=1)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    return out

In [None]:
def max_pool(inp, filtersize, stride):
    '''
    Max-pool ``inp`` with a window of ``filtersize`` moved by ``stride``.

    No padding, dilation 1, output size rounded down (the MaxPool2d defaults).

    Arguments:
        inp {torch.Tensor} -- input batch (NCHW)
        filtersize {int or tuple} -- pooling window size
        stride {int or tuple} -- step between windows

    Returns:
        [torch.Tensor] -- pooled tensor
    '''
    return torch.nn.functional.max_pool2d(inp, filtersize, stride=stride)

In [12]:
def resblock_B1(inp):
    '''
    Two-branch residual-style block with factorised 7x7 (encoder stage B1).

    Branches on ReLU(inp): 1x1 (192) | 1x1 (128) -> 1x7 (160) -> 7x1 (192); they are
    concatenated (384), merged by a 1x1 conv to 1154 and concatenated with ReLU(inp).

    Arguments:
        inp {torch.Tensor} -- input batch (NCHW); 870 channels when fed by resblock_A2

    Returns:
        [torch.Tensor] -- C_in + 1154 channels (2024 for an 870-channel input), same H and W
    '''
    inp = torch.nn.functional.relu(inp, inplace=False)
    in_ch = inp.shape[1]

    B1 = conv_2d(inp, in_ch, 192, (1, 1))

    B2 = conv_2d(inp, in_ch, 128, (1, 1))
    # Asymmetric padding keeps H and W for the 1x7 / 7x1 kernels.
    B2 = conv_2d(B2, 128, 160, (1, 7), padding=(0, 3))
    B2 = conv_2d(B2, 160, 192, (7, 1), padding=(3, 0))

    out = torch.cat((B1, B2), dim=1)  # 384 channels
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    out = conv_2d(out, 384, 1154, (1, 1))

    out = torch.cat((inp, out), dim=1)
    out = torch.nn.functional.relu(out, inplace=False)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    return out
    

In [None]:
def resblock_B2(inp):
    '''
    Two-branch residual-style block with factorised 7x7 (encoder stage B2).

    Branches on ReLU(inp): 1x1 (192) | 1x1 (128) -> 1x7 (160) -> 7x1 (192); they are
    concatenated (384), merged by a 1x1 conv to 1154 and concatenated with ReLU(inp).

    Arguments:
        inp {torch.Tensor} -- input batch (NCHW); 2024 channels when fed by resblock_B1

    Returns:
        [torch.Tensor] -- C_in + 1154 channels (3178 for a 2024-channel input), same H and W
    '''
    inp = torch.nn.functional.relu(inp, inplace=False)
    in_ch = inp.shape[1]

    B1 = conv_2d(inp, in_ch, 192, (1, 1))

    B2 = conv_2d(inp, in_ch, 128, (1, 1))
    # Asymmetric padding keeps H and W for the 1x7 / 7x1 kernels.
    B2 = conv_2d(B2, 128, 160, (1, 7), padding=(0, 3))
    B2 = conv_2d(B2, 160, 192, (7, 1), padding=(3, 0))

    out = torch.cat((B1, B2), dim=1)  # 384 channels
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    out = conv_2d(out, 384, 1154, (1, 1))

    out = torch.cat((inp, out), dim=1)
    out = torch.nn.functional.relu(out, inplace=False)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    return out

In [13]:
def resblock_C(inp):
    '''
    Two-branch residual-style block with factorised 3x3 (encoder bottleneck).

    Branches on ReLU(inp): 1x1 (192) | 1x1 (192) -> 1x3 (224) -> 3x1 (256); they are
    concatenated (448), merged by a 1x1 conv to 2048 and concatenated with ReLU(inp).

    Arguments:
        inp {torch.Tensor} -- input batch (NCHW); 3178 channels when fed by resblock_B2

    Returns:
        [torch.Tensor] -- C_in + 2048 channels (5226 for a 3178-channel input), same H and W
    '''
    inp = torch.nn.functional.relu(inp, inplace=False)
    in_ch = inp.shape[1]

    B1 = conv_2d(inp, in_ch, 192, (1, 1))

    B2 = conv_2d(inp, in_ch, 192, (1, 1))
    B2 = conv_2d(B2, 192, 224, (1, 3), padding=(0, 1))
    # The previous conv outputs 224 channels (was declared as 192).
    B2 = conv_2d(B2, 224, 256, (3, 1), padding=(1, 0))

    out = torch.cat((B1, B2), dim=1)  # 448 channels
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    out = conv_2d(out, 448, 2048, (1, 1))

    out = torch.cat((inp, out), dim=1)
    out = torch.nn.functional.relu(out, inplace=False)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    return out

In [15]:
def Path_2(inp):
    '''
    Skip path taken from the second encoder stage; reduces it to 32 channels.

    One residual stage (1x1 shortcut + two 3x3 convs), then one more 3x3 conv. Each
    step is followed by ReLU and batch normalisation.

    Arguments:
        inp {torch.Tensor} -- input batch (NCHW)

    Returns:
        [torch.Tensor] -- 32 channels, same H and W as ``inp``
    '''
    # Derive the input width from the tensor (the hard-coded 1510 did not match
    # resblock_A2's output).
    in_ch = inp.shape[1]

    shortcut = conv_2d(inp, in_ch, 32, (1, 1))
    shortcut = torch.nn.functional.relu(shortcut, inplace=False)

    out = conv_2d(inp, in_ch, 32, (3, 3), padding=1)
    out = torch.nn.functional.relu(out, inplace=False)
    out = conv_2d(out, 32, 32, (3, 3), padding=1)
    out = torch.nn.functional.relu(out, inplace=False)

    out = torch.add(shortcut, out)
    out = torch.nn.functional.relu(out, inplace=False)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)

    out = conv_2d(out, 32, 32, (3, 3), padding=1)
    out = torch.nn.functional.relu(out, inplace=False)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    return out

In [16]:
def Path_3(inp):
    '''
    Skip path taken from the third encoder stage; reduces it to 32 channels.

    One residual stage (1x1 shortcut + two 3x3 convs) followed by ReLU and batch
    normalisation.

    Arguments:
        inp {torch.Tensor} -- input batch (NCHW)

    Returns:
        [torch.Tensor] -- 32 channels, same H and W as ``inp``
    '''
    # Derive the input width from the tensor (the hard-coded 3674 did not match
    # resblock_B1's output).
    in_ch = inp.shape[1]

    shortcut = conv_2d(inp, in_ch, 32, (1, 1))
    shortcut = torch.nn.functional.relu(shortcut, inplace=False)

    out = conv_2d(inp, in_ch, 32, (3, 3), padding=1)
    out = torch.nn.functional.relu(out, inplace=False)
    out = conv_2d(out, 32, 32, (3, 3), padding=1)
    out = torch.nn.functional.relu(out, inplace=False)

    out = torch.add(shortcut, out)
    out = torch.nn.functional.relu(out, inplace=False)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    return out

In [17]:
def Path_4(inp):
    '''
    Skip path taken from the fourth encoder stage; a single 3x3 conv to 32 channels.

    Arguments:
        inp {torch.Tensor} -- input batch (NCHW)

    Returns:
        [torch.Tensor] -- 32 channels, same H and W as ``inp``
    '''
    # Derive the input width from the tensor (the hard-coded 4828 did not match
    # resblock_B2's output); padding=1 keeps H and W for the decoder concat.
    out = conv_2d(inp, inp.shape[1], 32, (3, 3), padding=1)
    out = torch.nn.functional.relu(out, inplace=False)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    return out

In [None]:
def resblock_B3(path4, inp):
    '''
    First decoder stage: upsample ``inp`` 2x, merge the skip ``path4``, then a
    factorised-7x7 residual-style block.

    After ReLU, the merged tensor goes through the branches 1x1 (192) |
    1x1 (128) -> 1x7 (160) -> 7x1 (192). These are concatenated (384), merged by a
    1x1 conv to 1154 and concatenated with the merged input.

    Arguments:
        path4 {torch.Tensor} -- skip tensor; its H and W must be twice those of ``inp``
        inp {torch.Tensor} -- deeper feature map to upsample (NCHW)

    Returns:
        [torch.Tensor] -- C_path4 + 256 + 1154 channels (1442 for a 32-channel path4),
        at the spatial size of ``path4``
    '''
    # Kernel 2 / stride 2 doubles H and W. trans_conv_2dtorch is the notebook's
    # transposed-conv helper (the previous name trans_conv_2d was undefined).
    up = trans_conv_2dtorch(inp, inp.shape[1], 256, 2, stride=2)
    out = torch.cat((path4, up), dim=1)
    inp = torch.nn.functional.relu(out, inplace=False)
    in_ch = inp.shape[1]

    B1 = conv_2d(inp, in_ch, 192, (1, 1))

    B2 = conv_2d(inp, in_ch, 128, (1, 1))
    B2 = conv_2d(B2, 128, 160, (1, 7), padding=(0, 3))
    B2 = conv_2d(B2, 160, 192, (7, 1), padding=(3, 0))

    out = torch.cat((B1, B2), dim=1)  # 384 channels
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    out = conv_2d(out, 384, 1154, (1, 1))

    out = torch.cat((inp, out), dim=1)
    out = torch.nn.functional.relu(out, inplace=False)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    return out


In [None]:
def resblock_B4(path3, inp):
    '''
    Second decoder stage: upsample ``inp`` 2x, merge the skip ``path3``, then a
    factorised-7x7 residual-style block.

    After ReLU, the merged tensor goes through the branches 1x1 (192) |
    1x1 (128) -> 1x7 (160) -> 7x1 (192). These are concatenated (384), merged by a
    1x1 conv to 1154 and concatenated with the merged input.

    Arguments:
        path3 {torch.Tensor} -- skip tensor; its H and W must be twice those of ``inp``
        inp {torch.Tensor} -- deeper feature map to upsample (NCHW)

    Returns:
        [torch.Tensor] -- C_path3 + 128 + 1154 channels (1314 for a 32-channel path3),
        at the spatial size of ``path3``
    '''
    # Kernel 2 / stride 2 doubles H and W; in_channels comes from the tensor
    # (the hard-coded 1410 did not match resblock_B3's output).
    up = trans_conv_2dtorch(inp, inp.shape[1], 128, 2, stride=2)
    out = torch.cat((path3, up), dim=1)
    inp = torch.nn.functional.relu(out, inplace=False)
    in_ch = inp.shape[1]

    B1 = conv_2d(inp, in_ch, 192, (1, 1))

    B2 = conv_2d(inp, in_ch, 128, (1, 1))
    B2 = conv_2d(B2, 128, 160, (1, 7), padding=(0, 3))
    B2 = conv_2d(B2, 160, 192, (7, 1), padding=(3, 0))

    out = torch.cat((B1, B2), dim=1)  # 384 channels
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    out = conv_2d(out, 384, 1154, (1, 1))

    out = torch.cat((inp, out), dim=1)
    out = torch.nn.functional.relu(out, inplace=False)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    return out


In [None]:
def resblock_A3(path2, filter_size, inp):
    '''
    Third decoder stage: upsample ``inp`` 2x, merge the skip ``path2``, then a
    three-branch residual-style block.

    After ReLU, the merged tensor goes through the branches 1x1 | 1x1 -> 3x3 |
    1x1 -> 3x3 (48) -> 3x3 (64). These are concatenated, merged by a 1x1 conv to 384
    and concatenated with the merged input.

    Arguments:
        path2 {torch.Tensor} -- skip tensor; its H and W must be twice those of ``inp``
        filter_size {int} -- width of the 1x1 reductions and of branch 2's 3x3
        inp {torch.Tensor} -- deeper feature map to upsample (NCHW)

    Returns:
        [torch.Tensor] -- C_path2 + 64 + 384 channels (480 for a 32-channel path2),
        at the spatial size of ``path2``
    '''
    # Kernel 2 / stride 2 doubles H and W; in_channels comes from the tensor.
    up = trans_conv_2dtorch(inp, inp.shape[1], 64, 2, stride=2)
    out = torch.cat((path2, up), dim=1)
    inp = torch.nn.functional.relu(out, inplace=False)
    in_ch = inp.shape[1]

    B1 = conv_2d(inp, in_ch, filter_size, (1, 1))

    B2 = conv_2d(inp, in_ch, filter_size, (1, 1))
    B2 = conv_2d(B2, filter_size, filter_size, (3, 3), padding=1)

    B3 = conv_2d(inp, in_ch, filter_size, (1, 1))
    B3 = conv_2d(B3, filter_size, 48, (3, 3), padding=1)
    B3 = conv_2d(B3, 48, 64, (3, 3), padding=1)

    # Channel-axis concat: 2 * filter_size + 64 channels.
    out = torch.cat((B1, B2, B3), dim=1)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    out = conv_2d(out, out.shape[1], 384, (1, 1))

    out = torch.cat((inp, out), dim=1)
    out = torch.nn.functional.relu(out, inplace=False)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    return out

In [None]:
def resblock_A4(path1, filter_size, inp):
    '''
    Final decoder stage: upsample ``inp`` 2x, merge the skip ``path1``, then a
    three-branch residual-style block.

    After ReLU, the merged tensor goes through the branches 1x1 | 1x1 -> 3x3 |
    1x1 -> 3x3 (48) -> 3x3 (64). These are concatenated, merged by a 1x1 conv to 384
    and concatenated with the merged input.

    Arguments:
        path1 {torch.Tensor} -- skip tensor; its H and W must be twice those of ``inp``
        filter_size {int} -- width of the 1x1 reductions and of branch 2's 3x3
        inp {torch.Tensor} -- deeper feature map to upsample (NCHW)

    Returns:
        [torch.Tensor] -- C_path1 + 32 + 384 channels (448 for a 32-channel path1),
        at the spatial size of ``path1``
    '''
    # Kernel 2 / stride 2 doubles H and W; in_channels comes from the tensor.
    up = trans_conv_2dtorch(inp, inp.shape[1], 32, 2, stride=2)
    out = torch.cat((path1, up), dim=1)
    inp = torch.nn.functional.relu(out, inplace=False)
    in_ch = inp.shape[1]

    B1 = conv_2d(inp, in_ch, filter_size, (1, 1))

    B2 = conv_2d(inp, in_ch, filter_size, (1, 1))
    B2 = conv_2d(B2, filter_size, filter_size, (3, 3), padding=1)

    B3 = conv_2d(inp, in_ch, filter_size, (1, 1))
    B3 = conv_2d(B3, filter_size, 48, (3, 3), padding=1)
    B3 = conv_2d(B3, 48, 64, (3, 3), padding=1)

    # Channel-axis concat: 2 * filter_size + 64 channels.
    out = torch.cat((B1, B2, B3), dim=1)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    out = conv_2d(out, out.shape[1], 384, (1, 1))

    out = torch.cat((inp, out), dim=1)
    out = torch.nn.functional.relu(out, inplace=False)
    out = torch.nn.functional.batch_norm(out, None, None, training=True)
    return out

In [None]:
class AD_net(nn.Module):
    '''
    Encoder/decoder network assembled from the notebook's functional blocks.

    Encoder: inblock -> resblock_A1 -> resblock_A2 -> resblock_B1 -> resblock_B2 ->
    resblock_C, with a 2x2/stride-2 max-pool before each stage after A1.
    Skip paths Path_1..Path_4 are taken from A1, A2, B1 and B2 before pooling.
    Decoder: resblock_B3 -> resblock_B4 -> resblock_A3 -> resblock_A4, each
    upsampling 2x and merging the matching skip path, then a 1x1 conv to 1 channel.

    NOTE(review): the blocks build their layers inside each call, so this module
    registers no parameters and an optimizer has nothing to train.
    '''

    def __init__(self):
        super(AD_net, self).__init__()
        # (channels, height, width) of a single input image.
        self.input_size = (3, 256, 256)

    def forward(self, x):
        '''
        Run the network on a batch ``x`` of shape (N, C, H, W).

        H and W must be divisible by 16 so that the four pool / upsample pairs line up.
        Returns a 1-channel map at the input's H and W.
        '''
        # Encoder (each pooling halves H and W).
        stem = inblock(x, x.shape[1])
        res_a1 = resblock_A1(stem, 32)
        res_a2 = resblock_A2(max_pool(res_a1, 2, 2), 64)
        res_b1 = resblock_B1(max_pool(res_a2, 2, 2))
        res_b2 = resblock_B2(max_pool(res_b1, 2, 2))
        res_c = resblock_C(max_pool(res_b2, 2, 2))

        # Skip paths from the pre-pooling encoder outputs.
        path1 = Path_1(res_a1)
        path2 = Path_2(res_a2)
        path3 = Path_3(res_b1)
        path4 = Path_4(res_b2)

        # Decoder (each stage doubles H and W and merges one skip path).
        res_b3 = resblock_B3(path4, res_c)
        res_b4 = resblock_B4(path3, res_b3)
        res_a3 = resblock_A3(path2, 64, res_b4)
        res_a4 = resblock_A4(path1, 32, res_a3)

        return conv_2d(res_a4, res_a4.shape[1], 1, (1, 1))

    def out_(self, x=None):
        '''
        Return the network output for ``x``.

        When ``x`` is omitted, a single random image of ``self.input_size`` is used
        (a smoke test).
        '''
        if x is None:
            x = torch.randn(1, *self.input_size)
        return self.forward(x)

    

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = AD_net()
model = model.to(device)

# torchsummary requires the per-sample input shape (C, H, W) and a device string
# ("cuda" / "cpu") that matches where the model was moved.
# NOTE(review): the blocks create their layers inside forward, so they are not
# submodules of `model`; expect the summary to list no layers or parameters — confirm.
summary(model, model.input_size, device=device.type)