<a href="https://colab.research.google.com/github/tirtthshah/text-to-image-pipeline/blob/main/Task_2(CA).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Cross Attention

In [None]:
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

In [None]:
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
z_dim = 100              # size of the latent noise fed to the generator
text_features_dim = 512  # size of the text-feature vector fed to the generator
# NOTE(review): 512 presumably matches a CLIP text-embedding width (CLIP is
# pip-installed at the top of this notebook but not used here) — confirm.
batch_size = 8
epochs = 2
lr = 2e-4                # learning rate shared by both Adam optimizers below

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ---------------------------------------------------------------------------
# Data: MNIST training split. ToTensor yields pixels in [0, 1];
# Normalize then computes (x - 0.5) / 0.5, mapping them to [-1, 1].
# ---------------------------------------------------------------------------
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)
dataset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    download=True,  # downloaded into ./data on first run
    transform=transform,
)
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

# ---------------------------------------------------------------------------
# Models, optimizers and loss
# ---------------------------------------------------------------------------
# NOTE(review): Generator and Discriminator are not defined in any cell above
# this one — they must be defined (or imported) before this cell runs.
gen = Generator(z_dim=z_dim, text_features_dim=text_features_dim).to(device)
disc = Discriminator().to(device)

# Adam with lr = 2e-4 and beta1 = 0.5, the settings popularised by DCGAN.
opt_gen = torch.optim.Adam(params=gen.parameters(), lr=lr, betas=(0.5, 0.999))
opt_disc = torch.optim.Adam(params=disc.parameters(), lr=lr, betas=(0.5, 0.999))

# BCELoss expects probabilities, so this assumes Discriminator outputs values
# in [0, 1] (e.g. ends in a sigmoid) — TODO confirm against its definition.
criterion = nn.BCELoss()

In [None]:
# ---------------------------------------------------------------------------
# Training loop (smoke test: only the first `max_batches` batches per epoch)
# ---------------------------------------------------------------------------
max_batches = 10  # per-epoch cap keeping this demo fast; raise it for real training
num_batches = min(max_batches, len(loader))  # batches actually processed per epoch


def _make_generator_inputs(n):
    """Build one batch of generator inputs for ``n`` samples.

    Returns ``(combined_input, text_features)``:
      * ``combined_input`` — latent noise ``(n, z_dim, 1, 1)`` concatenated along
        dim 1 with random stand-in text features ``(n, text_features_dim, 1, 1)``;
      * ``text_features`` — the same features flattened to ``(n, text_features_dim)``,
        passed as the generator's second argument.
    """
    noise = torch.randn(n, z_dim, 1, 1, device=device)
    # Random vectors stand in for real text embeddings in this demo.
    text_features = torch.randn(n, text_features_dim, 1, 1, device=device)
    combined_input = torch.cat((noise, text_features), dim=1)
    # flatten(1) keeps the batch dimension even when n == 1 (squeeze() would drop it).
    return combined_input, text_features.flatten(1)


# The sampling section below switches `gen` to eval mode; switch both models back
# so that re-running this cell trains with BatchNorm/Dropout (if any) in train mode.
gen.train()
disc.train()

for epoch in range(epochs):
    for batch_idx, (real, _) in enumerate(loader):
        if batch_idx >= max_batches:
            break

        print(f"Epoch [{epoch+1}/{epochs}] Batch [{batch_idx+1}/{num_batches}]")

        real = real.to(device)
        # Size the generator inputs from the actual batch: the last DataLoader
        # batch can be smaller than `batch_size`.
        combined_input, text_features = _make_generator_inputs(real.size(0))
        fake = gen(combined_input, text_features)

        # Discriminator step: push D(real) -> 1 and D(fake) -> 0. `fake` is detached
        # so this step does not backpropagate into the generator.
        # NOTE(review): the discriminator never receives the text features, so it is
        # unconditional — confirm this is intended.
        disc_real = disc(real).view(-1)
        loss_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).view(-1)
        loss_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_real + loss_fake) / 2
        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # Generator step (non-saturating loss): push the just-updated D(fake) -> 1.
        output = disc(fake).view(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

    # Losses reported are those of the last processed batch in this epoch.
    print(f"Epoch [{epoch+1}/{epochs}] Loss D: {loss_disc.item():.4f}, Loss G: {loss_gen.item():.4f}")

# ---------------------------------------------------------------------------
# Sampling: generate 64 images and show them as an 8x8 grid
# ---------------------------------------------------------------------------
num_samples = 64
gen.eval()
with torch.no_grad():
    combined_input, text_features = _make_generator_inputs(num_samples)
    fake_images = gen(combined_input, text_features).cpu()

# nrow=8 gives a square 8x8 grid matching the square figure; normalize=True
# min-max rescales the images into [0, 1] for display.
grid = torchvision.utils.make_grid(fake_images, nrow=8, normalize=True)
plt.figure(figsize=(5, 5))
plt.axis("off")
# make_grid returns (C, H, W); imshow expects (H, W, C).
plt.imshow(grid.permute(1, 2, 0))
plt.show()