<a href="https://colab.research.google.com/github/travislatchman/Single-Image-Deraining/blob/main/BasicRestormer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### The original Restormer model is quite complex and contains many components. In this basic version, I've simplified the architecture by using fewer Transformer blocks and eliminating the downsampling and upsampling layers. This should make the model easier to train and understand, while still providing reasonable performance for single-image deraining tasks.

In [None]:
class BasicRestormer(nn.Module):
    """Single-scale Restormer variant: conv embedding -> transformer stack -> conv head.

    The network predicts a 3-channel residual that is added back onto the
    input, so the input is expected to be a 3-channel image batch.

    Args:
        num_blocks: Number of ``TransformerBlock`` modules chained together.
        num_heads: Forwarded to every ``TransformerBlock``.
        channels: Width of the feature maps between the two convolutions.
        expansion_factor: Forwarded to every ``TransformerBlock``.
    """

    def __init__(self, num_blocks=4, num_heads=4, channels=48, expansion_factor=2):
        super().__init__()

        # 3x3 conv with padding 1 keeps the spatial size while lifting the
        # 3 input channels to `channels` feature maps.
        self.embed_conv = nn.Conv2d(
            in_channels=3,
            out_channels=channels,
            kernel_size=3,
            padding=1,
            bias=False,
        )

        # Build the stack explicitly; every block shares the same configuration.
        stages = []
        for _ in range(num_blocks):
            stages.append(TransformerBlock(channels, num_heads, expansion_factor))
        self.transformer_blocks = nn.Sequential(*stages)

        # Project the features back to 3 channels (again size-preserving).
        self.output = nn.Conv2d(
            in_channels=channels,
            out_channels=3,
            kernel_size=3,
            padding=1,
            bias=False,
        )

    def forward(self, x):
        """Return ``x`` plus the residual predicted by the network."""
        features = self.transformer_blocks(self.embed_conv(x))
        residual = self.output(features)
        # Global skip connection: the model only has to learn the correction.
        return residual + x

### Enhance the Restormer model with residual connections

In [None]:
class ResidualTransformerBlock(nn.Module):
    """Pre-norm transformer block with residual connections.

    Computes ``x = x + attn(norm1(x))`` followed by ``x = x + ffn(norm2(x))``.

    ``nn.LayerNorm(channels)`` normalises over the *last* dimension of its
    input. Feature maps produced by ``nn.Conv2d`` are channel-first
    ``(B, C, H, W)``, so applying the norm directly would either fail (when
    ``W != channels``) or normalise over the width axis. For 4-D inputs the
    channel axis is therefore moved to the end for the normalisation and moved
    back afterwards. Inputs of any other rank are passed to the norm unchanged
    and must already be channel-last.

    NOTE(review): this assumes ``MDTA`` and ``GDFN`` (defined elsewhere) accept
    and return channel-first 4-D feature maps of unchanged shape -- confirm
    against their definitions.

    Args:
        channels: Number of feature channels; size of the normalised axis.
        num_heads: Forwarded to ``MDTA``.
        expansion_factor: Forwarded to ``GDFN``.
    """

    def __init__(self, channels, num_heads, expansion_factor):
        super(ResidualTransformerBlock, self).__init__()

        self.norm1 = nn.LayerNorm(channels)
        self.attn = MDTA(channels, num_heads)
        self.norm2 = nn.LayerNorm(channels)
        self.ffn = GDFN(channels, expansion_factor)

    @staticmethod
    def _channel_norm(norm, x):
        """Apply ``norm`` over the channel axis of ``x``.

        4-D tensors are treated as ``(B, C, H, W)``: channels are permuted to
        the last axis for the norm and restored afterwards. Other ranks are
        normalised over their last axis as-is.
        """
        if x.dim() == 4:
            # (B, C, H, W) -> (B, H, W, C) -> norm -> (B, C, H, W)
            return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()
        return norm(x)

    def forward(self, x):
        """Run the attention and feed-forward sub-layers, each with a skip connection."""
        attn_out = self.attn(self._channel_norm(self.norm1, x))
        x = x + attn_out
        ffn_out = self.ffn(self._channel_norm(self.norm2, x))
        x = x + ffn_out
        return x