In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### Double Convolutional Layers
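The contracting path consists of the repeated application of <span style="color:red">two 3 $\times$ 3 convolutions</span> (unpadded in the original U-Net paper; this implementation uses `padding=1`, so each convolution preserves the spatial size), <br>
each followed by a <span style="color:red">ReLU</span> (with batch normalization inserted before each ReLU here), and then <span style="color:red">a 2 $\times$ 2 max pooling operation with stride 2</span> for downsampling.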

In [5]:
class TwoConv(nn.Module):
    """Double convolution block: (3x3 conv -> BatchNorm -> ReLU) applied twice.

    Both convolutions use ``kernel_size=3, padding=1`` with the default stride
    of 1, so the spatial size (H, W) of the input is preserved and only the
    channel count changes.

    Args:
        in_channels: number of channels in the input feature map.
        out_channels: number of channels produced by the block.
        mid_channels: number of channels between the two convolutions.
            Defaults to ``out_channels`` (the classic U-Net double conv).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if mid_channels is None:
            mid_channels = out_channels
        self.two_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            # The second conv consumes the first conv's output, so its input
            # channel count must be mid_channels, not in_channels.
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both conv-BN-ReLU stages.

        ``x`` must be a 4D (N, in_channels, H, W) tensor (BatchNorm2d requires
        4D input); the result has shape (N, out_channels, H, W).
        """
        return self.two_conv(x)

### Downsampling Layers (Contracting Path)
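In the contracting path, each downsampling step applies a <span style="color:red">2 $\times$ 2 max pooling operation with stride 2</span>, which halves the spatial resolution, followed by the double convolution above.
At each downsampling step we double the number of feature channels.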

In [6]:
class Downsampling(nn.Module):
    """One contracting-path step: 2x2 max pooling (stride 2), then ``TwoConv``.

    ``nn.MaxPool2d(2)`` uses a 2x2 window; its stride defaults to the kernel
    size, i.e. 2. This reduces H and W by a factor of 2 (rounding down for odd
    sizes) before the double convolution changes the channel count.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the ``TwoConv`` after pooling.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        double_conv = TwoConv(in_channels, out_channels)
        # Keep a single Sequential under the same attribute name so the
        # parameter/state_dict layout is unchanged.
        self.down_sampling = nn.Sequential(pool, double_conv)

    def forward(self, x):
        """Pool ``x`` and then run the double convolution on the result."""
        pooled_and_convolved = self.down_sampling(x)
        return pooled_and_convolved