<a href="https://colab.research.google.com/github/tirtthshah/text-to-image-pipeline/blob/main/Task_2(SA).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Self Attention

In [None]:
!pip install torch torchvision matplotlib

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

In [None]:
class SelfAttention(nn.Module):
    """Self-attention over the spatial positions of a 4-D feature map.

    Every spatial location attends to every other location. Queries and keys
    come from 1x1 convolutions that reduce the channel count to
    ``in_dim // 8`` (at least 1); values keep all ``in_dim`` channels. The
    attended features are scaled by the learnable scalar ``gamma`` and added
    back to the input.

    Args:
        in_dim: Number of channels of the incoming feature map.
    """

    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        # in_dim // 8 evaluates to 0 for in_dim < 8, which would leave the
        # query/key projections without any channels; clamp to one channel.
        # For in_dim >= 8 the layer shapes are exactly as before.
        attn_dim = max(in_dim // 8, 1)
        self.query = nn.Conv2d(in_dim, attn_dim, 1)
        self.key = nn.Conv2d(in_dim, attn_dim, 1)
        self.value = nn.Conv2d(in_dim, in_dim, 1)
        # gamma starts at zero, so a freshly constructed layer returns its
        # input unchanged and attention is blended in as gamma is learned.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply attention to ``x`` of shape (B, C, H, W); output has the same shape."""
        B, C, H, W = x.size()
        num_positions = H * W

        # (B, N, C') and (B, C', N), with N = H * W.
        proj_query = self.query(x).reshape(B, -1, num_positions).permute(0, 2, 1)
        proj_key = self.key(x).reshape(B, -1, num_positions)

        # energy[b, i, j]: similarity of position i (query) and position j (key).
        energy = torch.bmm(proj_query, proj_key)
        # Normalise over the key axis so each query's weights sum to 1.
        attention = torch.softmax(energy, dim=-1)

        proj_value = self.value(x).reshape(B, -1, num_positions)  # (B, C, N)

        # out[:, :, i] = sum_j value[:, :, j] * attention[:, i, j]
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(B, C, H, W)
        return self.gamma * out + x

class Generator(nn.Module):
    """Transposed-convolution generator with one self-attention stage.

    The input is expected to have ``z_dim + text_features_dim`` channels.
    Given a 1x1 spatial input, the stack upsamples 1 -> 4 -> 8 -> 16 -> 32 and
    ends in ``Tanh``, so outputs lie in [-1, 1].

    Args:
        z_dim: Channels contributed by the noise vector.
        img_channels: Channels of the generated image.
        feature_g: Base width; hidden layers use 8x, 4x and 2x this value.
        text_features_dim: Channels contributed by the conditioning vector.
    """

    def __init__(self, z_dim=100, img_channels=1, feature_g=64, text_features_dim=512):
        super().__init__()
        latent_channels = z_dim + text_features_dim
        wide = feature_g * 8
        mid = feature_g * 4
        narrow = feature_g * 2

        stages = [
            # Project the latent vector to a 4x4 feature map.
            nn.ConvTranspose2d(latent_channels, wide, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(wide),
            nn.ReLU(inplace=True),
            # 4x4 -> 8x8, followed by attention across the 64 positions.
            nn.ConvTranspose2d(wide, mid, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(mid),
            nn.ReLU(inplace=True),
            SelfAttention(mid),
            # 8x8 -> 16x16.
            nn.ConvTranspose2d(mid, narrow, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(narrow),
            nn.ReLU(inplace=True),
            # 16x16 -> 32x32 image.
            nn.ConvTranspose2d(narrow, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run the latent tensor ``x`` through the layer stack and return the image batch."""
        generated = self.net(x)
        return generated

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator with one self-attention stage.

    Two stride-2 convolutions each halve the spatial size; a final 4x4
    convolution with a sigmoid then produces a single-channel map of
    probabilities. The spatial size of that map depends on the input size.

    Args:
        img_channels: Channels of the input image.
        feature_d: Base width; the second convolution uses 2x this value.
    """

    def __init__(self, img_channels=1, feature_d=64):
        super().__init__()
        base = feature_d
        doubled = feature_d * 2

        stages = [
            # First downsampling stage (no batch norm on the input layer).
            nn.Conv2d(img_channels, base, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # Second downsampling stage, followed by attention.
            nn.Conv2d(base, doubled, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(doubled),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            SelfAttention(doubled),
            # Score each 4x4 window of the attended features.
            nn.Conv2d(doubled, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the probability map produced for the image batch ``x``."""
        scores = self.net(x)
        return scores

In [None]:
# --- Hyperparameters ---------------------------------------------------------
z_dim = 100                 # channels of the noise vector fed to the generator
lr = 2e-4
batch_size = 32
epochs = 2
device = "cuda" if torch.cuda.is_available() else "cpu"
text_features_dim = 512     # channels of the text-conditioning vector

# --- Data --------------------------------------------------------------------
# The generator upsamples a 1x1 input to 32x32 (1 -> 4 -> 8 -> 16 -> 32), while
# raw MNIST digits are 28x28. Resize the real images so the discriminator is
# shown real and fake samples at the same resolution.
# Normalize maps pixel values from [0, 1] to [-1, 1], the range of the
# generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
dataset = torchvision.datasets.MNIST(root="./data", train=True, transform=transform, download=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# --- Models, optimizers, loss ------------------------------------------------
gen = Generator(z_dim=z_dim, text_features_dim=text_features_dim).to(device)
disc = Discriminator().to(device)

opt_gen = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
opt_disc = torch.optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))

# The discriminator ends in a sigmoid, so plain binary cross-entropy applies.
criterion = nn.BCELoss()

# --- Training ----------------------------------------------------------------
for epoch in range(epochs):
    for real, _ in loader:
        real = real.to(device)
        # Size the fake batch from the real batch so a short final batch
        # (when the dataset size is not a multiple of batch_size) still
        # pairs equal numbers of real and fake samples.
        cur_batch = real.size(0)
        noise = torch.randn(cur_batch, z_dim, 1, 1, device=device)

        # Placeholder conditioning: random features stand in for real text
        # embeddings and are concatenated with the noise along channels.
        dummy_text_features = torch.randn(cur_batch, text_features_dim, 1, 1, device=device)
        combined_input = torch.cat((noise, dummy_text_features), dim=1)

        fake = gen(combined_input)

        # Discriminator step: real -> 1, fake -> 0. The fake batch is detached
        # so this backward pass does not reach the generator.
        disc_real = disc(real).view(-1)
        loss_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).view(-1)
        loss_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_real + loss_fake) / 2
        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # Generator step: try to make the updated discriminator output 1.
        output = disc(fake).view(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

    # Losses reported are those of the last batch of the epoch.
    print(f"Epoch [{epoch+1}/{epochs}] Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}")

# --- Sampling ----------------------------------------------------------------
# Switch to eval mode only once training has finished; doing it inside the
# epoch loop would leave the generator in eval mode for every later epoch.
gen.eval()
with torch.no_grad():
    noise = torch.randn(64, z_dim, 1, 1, device=device)

    dummy_text_features = torch.randn(64, text_features_dim, 1, 1, device=device)
    combined_input = torch.cat((noise, dummy_text_features), dim=1)

    fake_images = gen(combined_input).cpu()
    grid = torchvision.utils.make_grid(fake_images, nrow=4, normalize=True)
    plt.figure(figsize=(5,5))
    plt.axis("off")
    # make_grid returns a channels-first image; imshow expects channels last.
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.show()