<a href="https://colab.research.google.com/github/travislatchman/Single-Image-Deraining/blob/main/BasicRestormer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### The original Restormer model is quite complex and contains many components. In this basic version, I've simplified the architecture by using fewer Transformer blocks and eliminating the downsampling and upsampling layers. This should make the model easier to train and understand, while still providing reasonable performance for single-image deraining tasks.

In [None]:
class BasicRestormer(nn.Module):
    """Single-scale restoration network with a global residual connection.

    Pipeline: 3x3 embedding convolution -> ``num_blocks`` ``TransformerBlock``
    modules applied in sequence -> 3x3 output convolution whose result is
    added to the network input.

    Args:
        num_blocks: Number of ``TransformerBlock`` modules in the stack.
        num_heads: Forwarded to every ``TransformerBlock``.
        channels: Number of feature maps produced by the embedding convolution
            and consumed by the output convolution.
        expansion_factor: Forwarded to every ``TransformerBlock``.
    """

    def __init__(self, num_blocks=4, num_heads=4, channels=48, expansion_factor=2):
        super().__init__()

        # Lift the 3-channel input to `channels` feature maps (size-preserving).
        self.embed_conv = nn.Conv2d(
            in_channels=3,
            out_channels=channels,
            kernel_size=3,
            padding=1,
            bias=False,
        )

        # Build the blocks one by one, then chain them.
        stack = []
        for _ in range(num_blocks):
            stack.append(TransformerBlock(channels, num_heads, expansion_factor))
        self.transformer_blocks = nn.Sequential(*stack)

        # Project the features back to 3 channels (size-preserving).
        self.output = nn.Conv2d(
            in_channels=channels,
            out_channels=3,
            kernel_size=3,
            padding=1,
            bias=False,
        )

    def forward(self, x):
        """Return the projected features added to the input ``x``."""
        features = self.transformer_blocks(self.embed_conv(x))
        # Global skip: the convolutional branch is summed with the input.
        return self.output(features) + x

### Enhance the Restormer model with residual connections

In [None]:
class ResidualTransformerBlock(nn.Module):
    """Pre-norm Transformer block with residual connections.

    Computes ``x = x + attn(norm1(x))`` followed by ``x = x + ffn(norm2(x))``.

    ``nn.LayerNorm(channels)`` normalizes over the *last* axis of its input.
    Within this file the block is fed the output of ``nn.Conv2d`` layers, i.e.
    channels-first 4-D tensors, so the normalization is applied on a
    channels-last view and the original layout is restored afterwards.

    Args:
        channels: Size of the channel axis; used for both layer norms and
            forwarded to ``MDTA`` and ``GDFN``.
        num_heads: Forwarded to ``MDTA``.
        expansion_factor: Forwarded to ``GDFN``.
    """

    def __init__(self, channels, num_heads, expansion_factor):
        super(ResidualTransformerBlock, self).__init__()

        self.norm1 = nn.LayerNorm(channels)
        self.attn = MDTA(channels, num_heads)
        self.norm2 = nn.LayerNorm(channels)
        self.ffn = GDFN(channels, expansion_factor)

    @staticmethod
    def _channel_norm(norm, x):
        """Apply ``norm`` across the channel axis of ``x``.

        4-D inputs are treated as channels-first: the channel axis is moved
        last for the norm and then moved back. Inputs of any other rank are
        passed to ``norm`` unchanged (assumed to already be channels-last).
        """
        if x.dim() == 4:
            normed = norm(x.permute(0, 2, 3, 1))
            # contiguous() so downstream view/reshape calls see a dense tensor.
            return normed.permute(0, 3, 1, 2).contiguous()
        return norm(x)

    def forward(self, x):
        """Run attention and feed-forward sub-layers, each with a skip connection."""
        attn_out = self.attn(self._channel_norm(self.norm1, x))
        x = x + attn_out
        ffn_out = self.ffn(self._channel_norm(self.norm2, x))
        x = x + ffn_out
        return x

### Add downsampling and upsampling layers: Implement a U-Net-like architecture by incorporating downsampling and upsampling layers between the Transformer blocks. This can help the model to capture hierarchical features at different scales, which could improve its performance on the deraining task.

In [None]:
class Downsample(nn.Module):
    """Strided 3x3 convolution that reduces spatial size and remaps channels.

    Args:
        in_channels: Number of input feature maps.
        out_channels: Number of output feature maps.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # stride=2 with padding=1 halves even spatial dimensions; no bias term.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False,
        )

    def forward(self, x):
        """Return the strided convolution of ``x``."""
        reduced = self.conv(x)
        return reduced


class Upsample(nn.Module):
    """Sub-pixel upsampling: 3x3 convolution followed by a 2x pixel shuffle.

    Args:
        in_channels: Number of input feature maps.
        out_channels: Number of feature maps after the pixel shuffle.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # PixelShuffle(2) folds groups of 4 channels into a 2x2 spatial patch,
        # so the convolution must emit 4x the desired output channels.
        expanded_channels = out_channels * 4
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=expanded_channels,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.pixel_shuffle = nn.PixelShuffle(2)

    def forward(self, x):
        """Return ``x`` convolved and rearranged to twice the spatial size."""
        expanded = self.conv(x)
        return self.pixel_shuffle(expanded)

In [None]:
class UNetRestormer(nn.Module):
    """Two-scale encoder/bottleneck/decoder network with one skip connection.

    Pipeline: embedding conv -> encoder blocks -> ``Downsample`` -> bottleneck
    blocks (``2 * channels`` wide) -> ``Upsample`` -> concatenate with the
    encoder output -> 1x1 fusion conv -> decoder blocks -> output conv, whose
    result is added to the network input.

    Args:
        num_blocks: Number of ``ResidualTransformerBlock`` modules in each of
            the encoder, bottleneck and decoder stacks.
        num_heads: Forwarded to every ``ResidualTransformerBlock``.
        channels: Feature width at full resolution; the bottleneck uses
            ``channels * 2``.
        expansion_factor: Forwarded to every ``ResidualTransformerBlock``.
    """

    def __init__(self, num_blocks=4, num_heads=4, channels=48, expansion_factor=2):
        super(UNetRestormer, self).__init__()

        self.embed_conv = nn.Conv2d(3, channels, kernel_size=3, padding=1, bias=False)

        # Encoder layers
        self.enc_transformer_blocks = nn.Sequential(*[ResidualTransformerBlock(
            channels, num_heads, expansion_factor) for _ in range(num_blocks)])

        self.downsample = Downsample(channels, channels * 2)

        # Bottleneck layers
        self.bottleneck_transformer_blocks = nn.Sequential(*[ResidualTransformerBlock(
            channels * 2, num_heads, expansion_factor) for _ in range(num_blocks)])

        self.upsample = Upsample(channels * 2, channels)

        # Concatenating the upsampled features with the encoder skip doubles
        # the channel count; this 1x1 conv brings it back to `channels`, which
        # is the width the decoder blocks below are built for.
        self.fuse_skip = nn.Conv2d(channels * 2, channels, kernel_size=1, bias=False)

        # Decoder layers
        self.dec_transformer_blocks = nn.Sequential(*[ResidualTransformerBlock(
            channels, num_heads, expansion_factor) for _ in range(num_blocks)])

        self.output = nn.Conv2d(channels, 3, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        """Return the decoded features projected to 3 channels, added to ``x``."""
        # Encoder
        fo = self.embed_conv(x)
        enc_out = self.enc_transformer_blocks(fo)
        enc_down_out = self.downsample(enc_out)

        # Bottleneck
        bottleneck_out = self.bottleneck_transformer_blocks(enc_down_out)

        # Decoder
        upsampled_out = self.upsample(bottleneck_out)
        # For odd input sizes the stride-2 conv rounds up, so the upsampled map
        # is one pixel larger than the encoder output; crop it so the two can
        # be concatenated. This is a no-op for even sizes.
        upsampled_out = upsampled_out[..., :enc_out.shape[-2], :enc_out.shape[-1]]
        dec_in = torch.cat([upsampled_out, enc_out], dim=1)  # Skip connection
        dec_in = self.fuse_skip(dec_in)  # 2 * channels -> channels
        dec_out = self.dec_transformer_blocks(dec_in)

        # Output
        out = self.output(dec_out) + x
        return out