In [2]:
import torch
from torch import nn

In [14]:
class Generator_ContractingBlock(nn.Module):
    """Encoder stage: two 3x3 convolutions followed by a 2x2 max pool.

    The first convolution doubles the channel count (``in_channels`` ->
    ``2 * in_channels``) and the second keeps it. Both use ``padding=1``.
    After each convolution the block applies, in this order: optional batch
    norm, optional dropout, then ``LeakyReLU(0.2)``. The result is finally
    pooled with kernel size 2 and stride 2.

    Note: a single ``BatchNorm2d`` instance (``self.batchnorm``) is applied
    after *both* convolutions, so the two sites share affine parameters and
    running statistics.

    Args:
        in_channels: number of channels of the incoming feature map.
        use_dropout: when truthy, an ``nn.Dropout()`` (library-default
            probability) is applied after each convolution.
        use_batchNorm: when truthy, batch normalisation is applied after
            each convolution.
    """

    def __init__(self, in_channels, use_dropout=False, use_batchNorm=True):
        super().__init__()
        doubled = in_channels * 2

        self.conv1 = nn.Conv2d(in_channels, doubled, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(doubled, doubled, kernel_size=3, padding=1)
        self.activation = nn.LeakyReLU(0.2)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # The optional sub-modules only exist when their flag is set.
        self.use_batchNorm = use_batchNorm
        self.use_dropout = use_dropout
        if self.use_batchNorm:
            self.batchnorm = nn.BatchNorm2d(doubled)
        if self.use_dropout:
            self.dropout = nn.Dropout()

    def _post_conv(self, features):
        """Apply the optional batch norm and dropout, then the activation."""
        if self.use_batchNorm:
            features = self.batchnorm(features)
        if self.use_dropout:
            features = self.dropout(features)
        return self.activation(features)

    def forward(self, x):
        """Run conv -> (bn) -> (dropout) -> act twice, then max-pool.

        Args:
            x: input tensor whose channel dimension (dim 1) has
                ``in_channels`` entries.

        Returns:
            The pooled feature map with ``2 * in_channels`` channels.
        """
        first = self._post_conv(self.conv1(x))
        second = self._post_conv(self.conv2(first))
        return self.maxpool(second)

In [16]:
# Smoke-test fixtures: one random 3-channel 256x256 input and a contracting
# block built with its default flags (batch norm on, dropout off).
img = torch.randn(1, 3, 256, 256)
block = Generator_ContractingBlock(3)


In [17]:
# Forward pass of the contracting block on the random input built above.
preds = block(img)

In [19]:
# Input vs. output shape, one per line: the block doubles the channels
# (3 -> 6) and halves each spatial dimension (256 -> 128).
print(img.shape, preds.shape, sep="\n")

torch.Size([1, 3, 256, 256])
torch.Size([1, 6, 128, 128])


In [34]:
class Generator_ExpandingBlock(nn.Module):
    """Decoder stage: upsample, merge with a skip connection, refine.

    Pipeline of ``forward``:

    1. bilinear 2x upsampling (``align_corners=True``);
    2. ``conv1`` (kernel 2, no padding) halves the channels and shrinks each
       spatial dimension by one pixel;
    3. the skip tensor is centre-cropped to that size and concatenated on the
       channel dimension, restoring ``in_channels`` channels;
    4. ``conv2`` (kernel 3, padding 1) halves the channels again;
    5. ``conv3`` (kernel 2, padding 1) keeps the channels and grows each
       spatial dimension by one pixel, back to twice the input resolution.

    ``conv2`` and ``conv3`` are each followed by optional batch norm, optional
    dropout and ``ReLU``. Note that one ``BatchNorm2d`` instance
    (``self.batchnorm``) is applied at both sites, so they share parameters
    and running statistics.

    Args:
        in_channels: channels of the tensor coming from the previous decoder
            stage. The skip tensor must carry ``in_channels // 2`` channels
            so that the concatenation matches ``conv2``'s expected input.
        use_dropout: when truthy, ``nn.Dropout()`` (library-default
            probability) is applied after ``conv2`` and ``conv3``.
        use_batchNorm: when truthy, batch norm is applied after ``conv2`` and
            ``conv3``.
    """

    def __init__(self, in_channels, use_dropout = False, use_batchNorm = True):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv1 = nn.Conv2d(in_channels, in_channels // 2, kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels // 2, in_channels // 2, kernel_size=2, padding=1)
        if use_batchNorm:
            self.batchnorm = nn.BatchNorm2d(in_channels // 2)
        self.use_batchNorm = use_batchNorm
        self.activation = nn.ReLU()
        if use_dropout:
            self.dropout = nn.Dropout()
        self.use_dropout = use_dropout

    @staticmethod
    def _center_crop(skip, height, width):
        """Centre-crop ``skip`` (N, C, H, W) to ``height`` x ``width``.

        The tensor is returned untouched when it already has the requested
        size, or when it is smaller in either dimension; in the latter case
        the subsequent ``torch.cat`` reports the mismatch as before.
        """
        skip_h, skip_w = skip.shape[2], skip.shape[3]
        if skip_h < height or skip_w < width:
            return skip
        if skip_h == height and skip_w == width:
            return skip
        top = (skip_h - height) // 2
        left = (skip_w - width) // 2
        return skip[:, :, top:top + height, left:left + width]

    def forward(self, x, skip_con_x):
        """Upsample ``x``, fuse it with ``skip_con_x`` and refine the result.

        Args:
            x: feature map from the previous stage, ``in_channels`` channels.
            skip_con_x: encoder feature map with ``in_channels // 2``
                channels. Its spatial size may exceed that of the upsampled
                and convolved ``x``; it is centre-cropped to match.

        Returns:
            Tensor with ``in_channels // 2`` channels and twice the spatial
            size of ``x``.
        """
        x = self.upsample(x)
        x = self.conv1(x)
        # conv1 has kernel 2 and no padding, so x is one pixel smaller than
        # the upsampled map (and hence than a same-resolution encoder skip).
        # Without this crop torch.cat fails with a size mismatch.
        skip_con_x = self._center_crop(skip_con_x, x.shape[2], x.shape[3])
        x = torch.cat([x, skip_con_x], dim=1)
        x = self.conv2(x)
        if self.use_batchNorm:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        x = self.conv3(x)
        if self.use_batchNorm:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        return x

In [42]:
class Unet(nn.Module):
    def __init__(self, in_channels, out_channels, hidden_channels =32):
        """Build the U-Net layers.

        Args:
            in_channels: channels of the network input.
            out_channels: channels the ``downfeature`` convolution produces.
            hidden_channels: channel width fed into the first contracting
                block; each contracting block doubles it and each expanding
                block halves it.
        """
        super(Unet, self).__init__()
        # 1x1 projection from the input channels to the hidden width. Its
        # output is passed straight to contract1, whose first convolution
        # expects `hidden_channels` inputs, so it must emit hidden_channels
        # (it previously emitted out_channels, which only worked when the two
        # happened to be equal).
        self.upfeature = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)

        # Encoder: hidden -> 2x -> 4x -> ... -> 64x hidden channels.
        # Dropout is enabled on the first three stages only.
        self.contract1 = Generator_ContractingBlock(hidden_channels, use_dropout=True)
        self.contract2 = Generator_ContractingBlock(hidden_channels*2, use_dropout=True)
        self.contract3 = Generator_ContractingBlock(hidden_channels*4, use_dropout=True)
        self.contract4 = Generator_ContractingBlock(hidden_channels*8)
        self.contract5 = Generator_ContractingBlock(hidden_channels*16)
        self.contract6 = Generator_ContractingBlock(hidden_channels*32)

        # Decoder: each stage halves its input channel count
        # (64x hidden -> 32x -> ... -> 1x hidden after expand5).
        self.expand0 = Generator_ExpandingBlock(hidden_channels*64)
        self.expand1 = Generator_ExpandingBlock(hidden_channels*32)
        self.expand2 = Generator_ExpandingBlock(hidden_channels*16)
        self.expand3 = Generator_ExpandingBlock(hidden_channels*8)
        self.expand4 = Generator_ExpandingBlock(hidden_channels*4)
        self.expand5 = Generator_ExpandingBlock(hidden_channels*2)

        # NOTE(review): this 1x1 convolution expects `in_channels` inputs,
        # whereas expand5 emits `hidden_channels` channels. The part of
        # forward() that would call it is not visible here, so it is left
        # unchanged; confirm against its call site (it presumably should be
        # nn.Conv2d(hidden_channels, out_channels, kernel_size=1)).
        self.downfeature = nn.Conv2d(in_channels, out_channels, kernel_size =1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        x0 = self.upfeature(x)
        x1 = self. contract1(x0)
        x2 = self.contract2(x1)
        x3 = self. contract3(x2)
        x4 = self.contract4(x3)
        x5 = self. contract5(x4)
        x6 = self.contract6(x5)
   
   