<a href="https://colab.research.google.com/github/travislatchman/Single-Image-Deraining/blob/main/UNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Create the building blocks of the architecture, which are used by the encoder, the decoder, and the skip connections.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class ConvBlock(nn.Module):
    """Two convolutions, each followed by ReLU and optional batch norm, then dropout.

    Layer order per convolution is Conv2d -> ReLU -> (BatchNorm2d), and a single
    Dropout is applied to the block output. For odd ``kernel_size`` the spatial
    size (H, W) of the input is preserved.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.
        kernel_size: Convolution kernel size; odd values keep H and W unchanged.
        dropout_rate: Dropout probability applied to the block output.
        batch_norm: If True, apply BatchNorm2d after each ReLU.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dropout_rate=0.5, batch_norm=True):
        super().__init__()
        self.batch_norm = batch_norm
        # "Same" padding for odd kernels, so encoder outputs line up with the
        # decoder features they are concatenated with.
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # The BN layers exist even when batch_norm=False; forward simply skips them.
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Apply conv -> ReLU -> [BN] twice, then dropout, to ``x``."""
        x = F.relu(self.conv1(x))
        if self.batch_norm:
            x = self.bn1(x)
        x = F.relu(self.conv2(x))
        if self.batch_norm:
            x = self.bn2(x)
        return self.dropout(x)



### Define a U-Net architecture for single-image deraining. You can adjust `input_channels`, `num_filters_list`, and other parameters as needed.

In [None]:
class UNet(nn.Module):
    """U-Net built from ``ConvBlock`` stages with bilinear upsampling in the decoder.

    ``num_filters_list`` gives the channel width of each resolution level from
    shallowest to deepest; its last entry is the bottleneck width. For example,
    with ``[64, 128, 256, 512]`` the encoder has three stages (64, 128, 256
    channels), the bottleneck has 512, and the decoder mirrors the encoder back
    down to 64 channels before the 1x1 output projection.

    Args:
        in_channels: Channels of the input image.
        num_filters_list: Per-level channel widths; needs at least two entries
            (one encoder level plus the bottleneck).
        out_channels: Channels of the output map (default 1).

    Raises:
        ValueError: If ``num_filters_list`` has fewer than two entries.
    """

    def __init__(self, in_channels, num_filters_list, out_channels=1):
        super().__init__()
        if len(num_filters_list) < 2:
            raise ValueError(
                "num_filters_list needs at least one encoder width and a bottleneck width"
            )

        encoder_widths = list(num_filters_list[:-1])
        bottleneck_width = num_filters_list[-1]

        # Encoder: one ConvBlock per level; each output is kept as a skip connection.
        self.encoder_blocks = nn.ModuleList()
        prev = in_channels
        for width in encoder_widths:
            self.encoder_blocks.append(ConvBlock(prev, width))
            prev = width

        # Bottleneck at the lowest resolution.
        self.bottleneck = ConvBlock(prev, bottleneck_width)

        # Decoder: each stage receives the upsampled features concatenated with
        # the matching skip connection, so its input width is prev + skip width.
        self.decoder_blocks = nn.ModuleList()
        prev = bottleneck_width
        for width in reversed(encoder_widths):
            self.decoder_blocks.append(ConvBlock(prev + width, width))
            prev = width

        # 1x1 projection from the shallowest decoder width to the output channels.
        self.output_layer = nn.Conv2d(encoder_widths[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: Input batch of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        skip_connections = []

        # Encoder
        for enc_block in self.encoder_blocks:
            x = enc_block(x)
            skip_connections.append(x)
            x = F.max_pool2d(x, kernel_size=2, stride=2)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder: deepest skip first.
        for dec_block, skip in zip(self.decoder_blocks, reversed(skip_connections)):
            # Resize to the skip's exact size rather than a fixed 2x, because
            # max-pooling floors odd sizes and a plain 2x upsample would not match.
            x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=True)
            x = torch.cat((skip, x), dim=1)
            x = dec_block(x)

        # Output layer
        return self.output_layer(x)



In [None]:
# Hyper-parameters: a single-channel input and the per-level channel widths
# (the last width is the bottleneck's output width).
input_channels = 1
num_filters_list = [64, 128, 256, 512]

# Instantiate the ConvBlock-based U-Net and print its module tree.
unet_model = UNet(in_channels=input_channels, num_filters_list=num_filters_list)
print(unet_model)

### U-Net architecture built from helper functions (functional style) that return `nn.Sequential` blocks

In [None]:
def conv_block(in_channels, out_channels, kernel_size=3, dropout_rate=0.5, batch_norm=True):
    """Build Conv -> [BN] -> ReLU -> Conv -> [BN] -> ReLU -> [Dropout] as a Sequential.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both convolutions.
        kernel_size: Convolution kernel size; odd values keep H and W unchanged.
        dropout_rate: Dropout probability; a falsy value (0 or None) omits the
            Dropout layer entirely.
        batch_norm: If True, insert BatchNorm2d after each convolution.

    Returns:
        An ``nn.Sequential`` containing the layers listed above.
    """
    # "Same" padding for odd kernels so skip connections keep matching sizes.
    padding = kernel_size // 2
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)]
    if batch_norm:
        layers.append(nn.BatchNorm2d(out_channels))
    layers.append(nn.ReLU(inplace=True))
    layers.append(nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding))
    if batch_norm:
        layers.append(nn.BatchNorm2d(out_channels))
    layers.append(nn.ReLU(inplace=True))
    if dropout_rate:
        layers.append(nn.Dropout(dropout_rate))
    return nn.Sequential(*layers)

In [None]:
def upconv_block(in_channels, out_channels, kernel_size=2):
    """Return a one-layer Sequential holding a stride-2 transposed convolution.

    With the default ``kernel_size=2`` (and no padding) the output height and
    width are exactly twice the input's.
    """
    transpose = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=2)
    return nn.Sequential(transpose)

In [None]:
class UNetFunctional(nn.Module):
    """U-Net assembled from the ``conv_block`` and ``upconv_block`` helpers.

    Every entry of ``num_filters_list`` is an encoder level (shallowest first).
    The bottleneck doubles the deepest width. Each decoder stage then:

    1. upsamples with a transposed convolution to the matching skip's width,
    2. concatenates the skip connection, and
    3. fuses both with a ``conv_block`` back down to the skip's width.

    Args:
        in_channels: Channels of the input image.
        num_filters_list: Encoder widths from shallowest to deepest (non-empty).
        out_channels: Channels of the output map (default 1).

    Raises:
        ValueError: If ``num_filters_list`` is empty.
    """

    def __init__(self, in_channels, num_filters_list, out_channels=1):
        super().__init__()
        if not num_filters_list:
            raise ValueError("num_filters_list must contain at least one width")
        widths = list(num_filters_list)
        bottleneck_width = widths[-1] * 2

        # Encoder blocks
        self.enc_blocks = nn.ModuleList([
            conv_block(in_ch, out_ch) for in_ch, out_ch in zip([in_channels, *widths[:-1]], widths)
        ])

        # Bottleneck
        self.bottleneck = conv_block(widths[-1], bottleneck_width)

        # Decoder: one stage per encoder level, deepest first. The first upconv
        # consumes the bottleneck output; later ones consume the previous stage.
        decoder_widths = widths[::-1]
        upconv_inputs = [bottleneck_width, *decoder_widths[:-1]]
        self.upconvs = nn.ModuleList([
            upconv_block(in_ch, out_ch) for in_ch, out_ch in zip(upconv_inputs, decoder_widths)
        ])
        # Input is upsampled features (out_ch) concatenated with the skip (out_ch).
        self.dec_blocks = nn.ModuleList([
            conv_block(out_ch * 2, out_ch) for out_ch in decoder_widths
        ])

        # Output layer: 1x1 projection from the shallowest width.
        self.output_layer = nn.Conv2d(widths[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: Input batch of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        encoder_outputs = []
        for enc_block in self.enc_blocks:
            x = enc_block(x)
            encoder_outputs.append(x)
            x = F.max_pool2d(x, kernel_size=2, stride=2)

        x = self.bottleneck(x)

        for upconv, dec_block, skip in zip(self.upconvs, self.dec_blocks, reversed(encoder_outputs)):
            x = upconv(x)
            # The stride-2 upconv exactly doubles H/W, but max-pooling floored odd
            # sizes; pad right/bottom so x matches the skip before concatenation.
            pad_h = skip.shape[-2] - x.shape[-2]
            pad_w = skip.shape[-1] - x.shape[-1]
            if pad_h or pad_w:
                x = F.pad(x, [0, pad_w, 0, pad_h])
            x = torch.cat((x, skip), dim=1)
            x = dec_block(x)

        return self.output_layer(x)

In [None]:
# Model parameters for the functional variant
num_filters_list = [64, 128, 256, 512]  # encoder widths, shallowest level first
input_channels = 1  # single-channel input


In [None]:

# Instantiate the helper-function-based U-Net and print its module tree.
unet_model_functional = UNetFunctional(
    in_channels=input_channels,
    num_filters_list=num_filters_list,
)
print(unet_model_functional)