<a href="https://colab.research.google.com/github/travislatchman/Single-Image-Deraining/blob/main/Self_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# (THESE WERE NOT USED) - INITIAL IDEAS FOR SELF-ATTENTION

In [None]:
import torch.nn.functional as F
import torch
import torch.nn as nn

# U-Net + Self-Attention

Add a self-attention layer after each convolution (+ReLU) in the encoder, bottleneck, and decoder of the U-Net. Self-attention lets the model capture long-range dependencies and focus on important regions of the input, which may improve deraining performance. The drawback is added model complexity and computational cost: each attention layer builds an (H·W)×(H·W) attention map, so its memory use grows quadratically with the number of pixels at that layer's resolution.

In [None]:
class SelfAttention(nn.Module):
    """Spatial self-attention over all H*W positions with a learnable residual gate.

    Every spatial position attends to every other position of the same feature
    map. The attended features are scaled by ``gamma`` and added back to the
    input. ``gamma`` is initialised to 0, so the layer starts out as the identity.

    The attention map has shape (B, H*W, H*W), so memory and compute grow
    quadratically with the number of pixels of the input feature map.

    Args:
        in_channels: number of channels C of the input feature map.
        reduction: factor by which the query/key channel count is reduced
            relative to ``in_channels``. At least one channel is always kept.
    """

    def __init__(self, in_channels, reduction=8):
        super(SelfAttention, self).__init__()
        # Clamp to >= 1: for in_channels < reduction, ``in_channels // reduction`` is 0.
        # That would create empty query/key projections and a uniform attention map.
        qk_channels = max(1, in_channels // reduction)
        self.query_conv = nn.Conv2d(in_channels, qk_channels, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, qk_channels, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        # Zero-initialised gate: the block is an identity map before training.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, C, H, W); returns the same shape."""
        batch_size, _, height, width = x.size()
        n_positions = height * width
        query = self.query_conv(x).view(batch_size, -1, n_positions).permute(0, 2, 1)  # (B, N, C')
        key = self.key_conv(x).view(batch_size, -1, n_positions)  # (B, C', N)
        energy = torch.bmm(query, key)  # (B, N, N): similarity of position i to position j
        attention = F.softmax(energy, dim=-1)  # normalise over the attended positions j
        value = self.value_conv(x).view(batch_size, -1, n_positions)  # (B, C, N)
        out = torch.bmm(value, attention.permute(0, 2, 1))  # (B, C, N)
        out = out.view(batch_size, -1, height, width)
        return self.gamma * out + x

In [None]:
class UNet(nn.Module):
    """Four-level U-Net with a self-attention layer after every convolution.

    Encoder stages have 64/128/256/512 channels and the bottleneck has 1024.
    Each ``conv_block`` is (Conv3x3 -> ReLU -> SelfAttention) applied twice.
    Each decoder stage upsamples, concatenates the matching encoder feature map
    (skip connection) and runs a ``conv_block``.

    Inputs whose height/width are not multiples of 16 (four 2x poolings) are
    replicate-padded on the bottom/right, and the prediction is cropped back.
    The output therefore always has the input's spatial size. Inputs that are
    already multiples of 16 are processed unchanged.

    NOTE: every SelfAttention builds an (H*W) x (H*W) attention map at its
    stage's resolution. At full resolution (enc1/dec1) this is very
    memory-hungry; for example, a 256x256 input gives a 65536x65536 map.

    Args:
        in_channels: channels of the input image (3 for RGB).
        out_channels: channels of the prediction.
    """

    # Total spatial downsampling of the encoder: four MaxPool2d(2) stages.
    _SIZE_MULTIPLE = 16

    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()

        def conv_block(in_channels, out_channels):
            # (Conv3x3 -> ReLU -> SelfAttention) twice; padding=1 keeps H and W.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                SelfAttention(out_channels),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                SelfAttention(out_channels),
            )

        def up_block(in_channels, out_channels):
            # Learned 2x upsampling followed by ReLU.
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
                nn.ReLU(inplace=True),
            )

        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)
        self.enc4 = conv_block(256, 512)

        self.pool = nn.MaxPool2d(2)

        self.middle = conv_block(512, 1024)

        # Decoder conv_blocks take twice the channels because of the skip concatenation.
        self.up4 = up_block(1024, 512)
        self.dec4 = conv_block(1024, 512)
        self.up3 = up_block(512, 256)
        self.dec3 = conv_block(512, 256)
        self.up2 = up_block(256, 128)
        self.dec2 = conv_block(256, 128)

        # NOTE(review): unlike up2-up4, up1 is a bare ConvTranspose2d with no ReLU.
        # Kept as-is so outputs and state_dict keys ("up1.weight") are unchanged;
        # confirm whether this asymmetry is intended.
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = conv_block(128, 64)

        self.output = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x``, assumed (B, in_channels, H, W).

        Returns:
            Tensor of shape (B, out_channels, H, W).
        """
        height, width = x.shape[-2:]
        # Pad so that every pooling stage halves the size evenly. Otherwise the
        # upsampled decoder maps end up smaller than their skip connections and
        # torch.cat fails.
        pad_h = -height % self._SIZE_MULTIPLE
        pad_w = -width % self._SIZE_MULTIPLE
        padded = bool(pad_h or pad_w)
        if padded:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode="replicate")

        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))

        middle = self.middle(self.pool(enc4))

        up4 = self.up4(middle)
        merge4 = torch.cat([enc4, up4], dim=1)
        dec4 = self.dec4(merge4)

        up3 = self.up3(dec4)
        merge3 = torch.cat([enc3, up3], dim=1)
        dec3 = self.dec3(merge3)

        up2 = self.up2(dec3)
        merge2 = torch.cat([enc2, up2], dim=1)
        dec2 = self.dec2(merge2)

        up1 = self.up1(dec2)
        merge1 = torch.cat([enc1, up1], dim=1)
        dec1 = self.dec1(merge1)

        output = self.output(dec1)
        if padded:
            # Drop the padded rows/columns so the output matches the input size.
            output = output[..., :height, :width]
        return output

# Three input and three output channels, i.e. RGB in and RGB out.
unet = UNet(in_channels=3, out_channels=3)

# U-Net + Self-Attention + Residual

Integrate both residual blocks and self-attention into a U-Net architecture. The ResidualBlock class includes a self-attention layer, which is applied after the second convolution and before the residual (skip) connection is added. The encoder and decoder blocks are then replaced with these modified residual blocks.

In [None]:
class SelfAttention(nn.Module):
    """Spatial self-attention over all H*W positions with a learnable residual gate.

    The attended features are scaled by ``gamma`` and added back to the input.
    ``gamma`` is initialised to 0, so the layer starts out as the identity.
    The attention map has shape (B, H*W, H*W), so cost grows quadratically
    with the number of pixels.

    Args:
        in_channels: number of channels C of the input feature map.
        reduction: factor by which the query/key channel count is reduced
            relative to ``in_channels``. At least one channel is always kept.
    """

    def __init__(self, in_channels, reduction=8):
        super(SelfAttention, self).__init__()
        # Clamp to >= 1: for in_channels < reduction, ``in_channels // reduction`` is 0.
        # That would create empty query/key projections and a uniform attention map.
        qk_channels = max(1, in_channels // reduction)
        self.query = nn.Conv2d(in_channels, qk_channels, kernel_size=1)
        self.key = nn.Conv2d(in_channels, qk_channels, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        # Zero-initialised gate: the block is an identity map before training.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, C, H, W); returns the same shape."""
        batch_size, channels, height, width = x.size()
        n_positions = height * width
        query = self.query(x).view(batch_size, -1, n_positions).permute(0, 2, 1)  # (B, N, C')
        key = self.key(x).view(batch_size, -1, n_positions)  # (B, C', N)
        # (B, N, N), normalised over the attended positions.
        attention = torch.softmax(torch.bmm(query, key), dim=-1)
        value = self.value(x).view(batch_size, -1, n_positions)  # (B, C, N)
        out = torch.bmm(value, attention.permute(0, 2, 1)).view(batch_size, channels, height, width)
        return self.gamma * out + x

class ResidualBlock(nn.Module):
    """Conv3x3 -> ReLU -> Conv3x3 -> SelfAttention, wrapped in an identity skip connection.

    The channel count is preserved, so the output has the same shape as the input.

    Args:
        in_channels: channels of the input (and output) feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Registration order (conv1, relu, conv2, attention) is kept so parameter
        # initialisation and state_dict layout stay the same.
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.attention = SelfAttention(in_channels)

    def forward(self, x):
        """Return ``attention(conv2(relu(conv1(x)))) + x``."""
        features = self.conv2(self.relu(self.conv1(x)))
        return self.attention(features) + x


class UNetWithResidualAndAttention(nn.Module):
    """Four-level U-Net whose stages are attention-augmented ``ResidualBlock``s.

    ``ResidualBlock`` keeps its channel count, so each encoder stage first
    changes the number of channels with a Conv3x3 + ReLU "projection" and then
    refines with the residual block. Each decoder stage upsamples (2x transposed
    conv + ReLU), concatenates the matching encoder map, fuses the concatenation
    back down with a Conv3x3 + ReLU, and refines with a residual block.
    A final 1x1 convolution produces the prediction.

    Inputs whose height/width are not multiples of 16 (four 2x poolings) are
    replicate-padded on the bottom/right, and the prediction is cropped back.

    Args:
        in_channels: channels of the input image (3 for RGB).
        out_channels: channels of the prediction.
    """

    # Total spatial downsampling of the encoder: four MaxPool2d(2) stages.
    _SIZE_MULTIPLE = 16

    def __init__(self, in_channels, out_channels):
        super(UNetWithResidualAndAttention, self).__init__()

        def project(cin, cout):
            # Channel-changing Conv3x3 + ReLU; ResidualBlock itself cannot change channels.
            return nn.Sequential(
                nn.Conv2d(cin, cout, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

        def up_block(cin, cout):
            # Learned 2x upsampling followed by ReLU.
            return nn.Sequential(
                nn.ConvTranspose2d(cin, cout, kernel_size=2, stride=2),
                nn.ReLU(inplace=True),
            )

        self.pool = nn.MaxPool2d(2)

        # Encoder: project to the stage width, then refine.
        self.proj1 = project(in_channels, 64)
        self.enc1 = ResidualBlock(64)
        self.proj2 = project(64, 128)
        self.enc2 = ResidualBlock(128)
        self.proj3 = project(128, 256)
        self.enc3 = ResidualBlock(256)
        self.proj4 = project(256, 512)
        self.enc4 = ResidualBlock(512)

        # Bottleneck.
        self.proj_middle = project(512, 1024)
        self.middle = ResidualBlock(1024)

        # Decoder: upsample, concatenate skip (doubling channels), fuse back down, refine.
        self.up4 = up_block(1024, 512)
        self.fuse4 = project(1024, 512)
        self.dec4 = ResidualBlock(512)
        self.up3 = up_block(512, 256)
        self.fuse3 = project(512, 256)
        self.dec3 = ResidualBlock(256)
        self.up2 = up_block(256, 128)
        self.fuse2 = project(256, 128)
        self.dec2 = ResidualBlock(128)
        self.up1 = up_block(128, 64)
        self.fuse1 = project(128, 64)
        self.dec1 = ResidualBlock(64)

        self.output = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x``, assumed (B, in_channels, H, W).

        Returns:
            Tensor of shape (B, out_channels, H, W).
        """
        height, width = x.shape[-2:]
        # Pad so every pooling stage halves evenly and the skip/upsampled sizes match.
        pad_h = -height % self._SIZE_MULTIPLE
        pad_w = -width % self._SIZE_MULTIPLE
        padded = bool(pad_h or pad_w)
        if padded:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode="replicate")

        enc1 = self.enc1(self.proj1(x))
        enc2 = self.enc2(self.proj2(self.pool(enc1)))
        enc3 = self.enc3(self.proj3(self.pool(enc2)))
        enc4 = self.enc4(self.proj4(self.pool(enc3)))

        middle = self.middle(self.proj_middle(self.pool(enc4)))

        dec4 = self.dec4(self.fuse4(torch.cat([enc4, self.up4(middle)], dim=1)))
        dec3 = self.dec3(self.fuse3(torch.cat([enc3, self.up3(dec4)], dim=1)))
        dec2 = self.dec2(self.fuse2(torch.cat([enc2, self.up2(dec3)], dim=1)))
        dec1 = self.dec1(self.fuse1(torch.cat([enc1, self.up1(dec2)], dim=1)))

        output = self.output(dec1)
        if padded:
            # Drop the padded rows/columns so the output matches the input size.
            output = output[..., :height, :width]
        return output

# Three input and three output channels, i.e. RGB in and RGB out.
unet = UNetWithResidualAndAttention(in_channels=3, out_channels=3)

