In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [6]:
class DoubleConv(nn.Module):
    """Two 3x3 conv -> BatchNorm stages with an optional identity shortcut.

    Layout::

        conv3x3 -> BN -> ReLU -> conv3x3 -> BN -> (+ input, if skip_connection) -> ReLU

    Both convolutions use ``kernel_size=3, padding=1`` with the default stride
    of 1, so the spatial size of the input is preserved.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the block.
        skip_connection: if truthy, the block input is added to the output of
            the second BatchNorm before the final ReLU. The shortcut is a plain
            identity (no projection), so this requires
            ``in_channels == out_channels``.
        mid_channels: channels between the two convolutions; a falsy value
            (``None`` or 0) falls back to ``out_channels``.

    Raises:
        ValueError: if ``skip_connection`` is set and ``in_channels != out_channels``.
    """

    def __init__(self, in_channels, out_channels, skip_connection=False, mid_channels=None):
        super(DoubleConv, self).__init__()
        if skip_connection and in_channels != out_channels:
            # The identity shortcut cannot be summed with a tensor of a different
            # channel count; fail here instead of with a broadcast error in forward().
            raise ValueError(
                "skip_connection=True requires in_channels == out_channels "
                f"(got in_channels={in_channels}, out_channels={out_channels})"
            )
        self.skip_connection = bool(skip_connection)
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            # In-place is safe here: it overwrites the BN output, never the block input.
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
        )
        # Final activation; deliberately not in-place so it never touches ``x``.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run the two conv stages, optionally add the input, then apply ReLU."""
        out = self.double_conv(x)
        if self.skip_connection:
            # Residual add happens before the last ReLU (ResNet basic-block style).
            out = out + x
        return self.relu(out)

In [7]:
class OutConv(nn.Module):
    """Pointwise (1x1) convolution head.

    Projects ``in_channels`` to ``out_channels`` independently at every
    spatial position, using a bias (``nn.Conv2d`` default). Height and width
    are left unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute name ``conv`` is part of the state_dict layout (conv.weight / conv.bias).
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
        )

    def forward(self, x):
        """Apply the 1x1 projection to ``x``."""
        projected = self.conv(x)
        return projected
    

In [8]:
class RNnet(nn.Module):
    """Small fully-convolutional residual network producing a 1-channel map in (0, 1).

    Pipeline:
        layer1: DoubleConv, n_channels -> 64 (no shortcut)
        layer2: 4 residual DoubleConv blocks, 64 -> 64
        layer3: DoubleConv, 64 -> 128 (no shortcut)
        layer4: 4 residual DoubleConv blocks, 128 -> 128
        layer5: 1x1 OutConv, 128 -> 1
        sigmoid

    Args:
        n_channels: number of channels of the input image tensor.

    NOTE(review): the sigmoid is applied inside the model, so training should
    use a probability loss such as ``nn.BCELoss``; if a logits loss
    (``nn.BCEWithLogitsLoss``) is intended, drop the sigmoid — confirm with
    the training code.
    """

    def __init__(self, n_channels):
        super(RNnet, self).__init__()
        self.n_channels = n_channels

        self.layer1 = DoubleConv(n_channels, 64, skip_connection=False, mid_channels=None)
        self.layer2 = self._make_layer(blocks=4, in_channels=64, out_channels=64, skip_connection=True, mid_channels=None)
        self.layer3 = DoubleConv(64, 128, skip_connection=False, mid_channels=None)
        self.layer4 = self._make_layer(blocks=4, in_channels=128, out_channels=128, skip_connection=True, mid_channels=None)
        self.layer5 = OutConv(128, 1)
        self.sigmoid = nn.Sigmoid()

    def _make_layer(self, blocks, in_channels, out_channels, skip_connection, mid_channels):
        """Stack exactly ``blocks`` DoubleConv modules into an ``nn.Sequential``.

        The first block maps ``in_channels -> out_channels``; every following
        block maps ``out_channels -> out_channels`` so consecutive blocks agree
        on channel count. ``blocks <= 0`` yields an empty Sequential, which
        returns its input unchanged.
        """
        layers = []
        channels = in_channels
        for _ in range(blocks):
            layers.append(
                DoubleConv(channels, out_channels, skip_connection=skip_connection, mid_channels=mid_channels)
            )
            # After the first block, the running channel count is out_channels.
            channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run all stages in order and return the sigmoid-activated 1-channel output."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.sigmoid(x)
        return x