PyTorch models are built on the `nn.Module` base class. Besides writing a subclass by hand, you can assemble layers with three container modules: `nn.Sequential`, `nn.ModuleList`, and `nn.ModuleDict`.

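`nn.Sequential` is a container that runs its layers in the order they are given. It is well suited for quickly prototyping and sanity-checking a model.    
`nn.ModuleList` is a container for a list of modules. It registers their parameters but defines no `forward`, so you write the computation yourself.    
`nn.ModuleDict` is a container for a dictionary of named modules. Like `nn.ModuleList`, it has no `forward` of its own.    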

When a layer needs the output of an earlier layer (e.g. for a residual connection), `nn.ModuleList` and `nn.ModuleDict` are more flexible than `nn.Sequential`, because you write the forward pass yourself and can keep intermediate results.

### Sequential

In [None]:
from collections import OrderedDict

import torch.nn as nn

class MySequential(nn.Module):
    '''
    A minimal re-implementation of PyTorch's nn.Sequential.

    Submodules are registered in the order given, and ``forward`` feeds the
    output of each one into the next. Passing a single OrderedDict names the
    submodules by its keys; otherwise they are named "0", "1", ... by
    position.

    Submodules can be looked up by integer position (negative indices
    included, as with a list) or by their registered name (a string key).
    '''

    def __init__(self, *args):
        '''
        Register ``args`` as submodules.

        If exactly one argument is given and it is an OrderedDict, each
        ``(key, module)`` pair is registered under ``key`` in the dict's
        order. Otherwise every positional argument is registered under its
        position converted to a string.
        '''
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                # add_module stores the module in self._modules, which is what
                # makes its parameters visible to parameters()/state_dict().
                self.add_module(str(idx), module)

    def __getitem__(self, idx):
        '''
        Return a registered submodule.

        Args:
            idx: an int position (negative values count from the end) or
                the str name the submodule was registered under.

        Raises:
            IndexError: if an int position is out of range.
            KeyError: if no submodule is registered under a str name.
            TypeError: if ``idx`` is neither an int nor a str.
        '''
        if isinstance(idx, str):
            return self._modules[idx]
        if isinstance(idx, int):
            return list(self._modules.values())[idx]
        raise TypeError(
            f'MySequential indices must be int or str, not {type(idx).__name__}'
        )

    def forward(self, input):
        '''
        Apply every registered submodule to ``input`` in registration order
        and return the final output.
        '''
        for module in self._modules.values():
            input = module(input)
        return input

In [None]:
import torch.nn as nn

# A three-layer network built with nn.Sequential:
# - Linear: 784 input features (e.g. a flattened 28x28 image) -> 256 features
# - ReLU activation, applied element-wise (still 256 features)
# - Linear: 256 features -> 10 outputs, one per label
# nn.Sequential accepts only modules as arguments, so this description is a
# comment; passing a string as an argument would raise a TypeError.
net = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)

print(net)

### ModuleList

In [None]:
# Start from an empty ModuleList and grow it with list-style operations.
net = nn.ModuleList()
net.extend([nn.Linear(784, 256), nn.ReLU()])  # like list.extend
net.append(nn.Linear(256, 10))                # like list.append
print(net[-1])                                # like list indexing
print(net)

class model(nn.Module):
    '''
    Apply the layers held in an nn.ModuleList one after another.

    Unlike nn.Sequential, the order of execution is spelled out by the
    explicit loop in ``forward``, which is where custom logic (for example
    keeping earlier outputs for a residual connection) could be added.

    Args:
        layers: optional iterable of modules. They are wrapped in an
            nn.ModuleList so their parameters are registered with this
            model. With no layers, ``forward`` returns its input unchanged.
    '''

    def __init__(self, layers=None):
        super().__init__()
        # nn.ModuleList(None) creates an empty list.
        self.modulelist = nn.ModuleList(layers)

    # use for loop to assign the layer order
    def forward(self, x):
        for layer in self.modulelist:
            x = layer(x)
        return x

### ModuleDict

In [None]:
# Build the dict from (name, module) pairs; insertion order is preserved.
net = nn.ModuleDict([
    ('linear', nn.Linear(784, 256)),
    ('act', nn.ReLU()),
])
net['output'] = nn.Linear(256, 10)  # register one more layer under a new key
# A layer can be fetched by key or as an attribute.
print(net['linear'], net.output, sep='\n')
print(net)

---

### U-Net
![UNet Structure](https://datawhalechina.github.io/thorough-pytorch/_images/5.2.1unet.png)

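The U-Net model consists of the following parts:
1. Two successive convolutions (a double convolution) in each block
2. Max pooling for downsampling on the left (encoder) side
3. Upsampling on the right (decoder) side, followed by concatenation with the matching encoder feature map (skip connection)
4. A 1×1 output convolution that produces the per-class logits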

In [None]:
# 1. Define module blocks

import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Each convolution uses kernel_size=3 with padding=1 and stride 1, so the
    spatial size is unchanged; only the channel count changes.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two convolutions. Any falsy
            value (``None`` by default) means "use ``out_channels``".
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels

        stages = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages += [
                # bias is redundant here: the following BatchNorm has its own shift
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.double_conv(x)
        return out

class Down(nn.Module):
    """Encoder step: 2x2 max pooling followed by a DoubleConv.

    MaxPool2d(2) uses a stride equal to its kernel size, so height and width
    are halved (rounded down for odd sizes) before the convolutions map
    ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # One Sequential registers them as maxpool_conv.0 and maxpool_conv.1.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        out = self.maxpool_conv(x)
        return out

class Up(nn.Module):
    """Decoder step: upsample by 2, pad to the skip tensor's size, concatenate, then DoubleConv.

    Args:
        in_channels: channel count DoubleConv receives after concatenation.
            In the transposed-convolution path this is also the channel
            count of ``x1``, which ConvTranspose2d halves before the concat.
        out_channels: channels produced by this block.
        bilinear: if True, upsample with parameter-free bilinear
            interpolation (channel count unchanged) and let DoubleConv
            reduce the channels via ``mid_channels = in_channels // 2``;
            otherwise use a learned 2x2 transposed convolution with stride 2.
    """

    def __init__(self, in_channels, out_channels, bilinear=False):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the spatial size of the skip tensor ``x2``,
        concatenate ``[x2, x1]`` along channels and apply DoubleConv.

        Both inputs must be 4-D (N, C, H, W) tensors: dims 2 and 3 are read
        as height and width, and dim 1 is the concatenation axis.
        """
        x1 = self.up(x1)
        # Inputs are NCHW: dim 2 is height, dim 3 is width.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)

        # F.pad lists padding from the last dimension backwards:
        # (left, right, top, bottom). Splitting each difference in two centres
        # x1 on x2 when their sizes disagree (e.g. after pooling an odd size).
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to ``out_channels`` outputs.

    A 1x1 kernel mixes channels at each pixel independently, so the spatial
    size is unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

In [None]:
# 2. Assemble the model

class UNet(nn.Module):
    """U-Net: a four-level encoder/decoder joined by skip connections.

    Args:
        n_channels: channels of the input image.
        n_classes: channels of the output logits (one map per class).
        bilinear: use bilinear upsampling in the decoder instead of
            transposed convolutions.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Bilinear upsampling keeps the channel count, whereas Up's
        # ConvTranspose2d halves it. In bilinear mode the bottleneck and the
        # first three decoder outputs are therefore built at half width, so
        # that upsampled + skip channels still equal the next Up's in_channels.
        shrink = 2 if bilinear else 1

        # Encoder (contracting path)
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // shrink)

        # Decoder (expanding path)
        self.up1 = Up(1024, 512 // shrink, bilinear)
        self.up2 = Up(512, 256 // shrink, bilinear)
        self.up3 = Up(256, 128 // shrink, bilinear)
        self.up4 = Up(128, 64, bilinear)

        # 1x1 projection to per-class logits
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep every stage's output for the skip connections.
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        out = self.down4(skip4)

        # Decoder: each Up fuses the running tensor with its matching skip,
        # deepest skip first.
        stages = (
            (self.up1, skip4),
            (self.up2, skip3),
            (self.up3, skip2),
            (self.up4, skip1),
        )
        for up, skip in stages:
            out = up(out, skip)

        return self.outc(out)