<a href="https://colab.research.google.com/github/tirtthshah/text-to-image-pipeline/blob/main/Task_6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install matplotlib numpy torch torchvision

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw

In [None]:
def generate_shape(label, size=28):
    """Render a filled white shape centred on a black square canvas.

    Args:
        label: 0 draws a filled circle, 1 draws a filled square. Any other
            value leaves the canvas completely black.
        size: Side length of the canvas in pixels. The shape's bounding box
            scales with it.

    Returns:
        The grayscale image converted by ``transforms.ToTensor()`` (per the
        torchvision docs: a float tensor of shape ``(1, size, size)`` with
        values in ``[0, 1]``).
    """
    img = Image.new("L", (size, size), 0)
    draw = ImageDraw.Draw(img)
    # Derive the bounding box from `size` instead of hard-coding it, so the
    # shape stays centred for any canvas. At the default size of 28 this is
    # exactly the original box (6, 6, 22, 22).
    margin = size * 3 // 14
    box = (margin, margin, size - margin, size - margin)
    if label == 0:
        draw.ellipse(box, fill=255)
    elif label == 1:
        draw.rectangle(box, fill=255)
    return transforms.ToTensor()(img)

def create_dataset(num_samples=1000):
    """Build a randomly labelled dataset of shape images.

    Every sample gets a label drawn with ``np.random.randint(0, 2)`` and the
    image that ``generate_shape`` renders for that label.

    Args:
        num_samples: Number of (image, label) pairs to create.

    Returns:
        A pair ``(images, labels)``: the images combined with ``torch.stack``
        and the labels wrapped in ``torch.tensor``.
    """
    # Labels are drawn first, one RNG call per sample, in sample order.
    sample_labels = [np.random.randint(0, 2) for _ in range(num_samples)]
    sample_images = [generate_shape(lbl) for lbl in sample_labels]
    return torch.stack(sample_images), torch.tensor(sample_labels)

# Build the dataset once with the default sample count; the training loop
# below slices these module-level tensors into mini-batches.
images, labels = create_dataset()

In [None]:
class Generator(nn.Module):
    """Conditional MLP generator: (noise, label vector) -> 1-channel image.

    Args:
        latent_dim: Width of the noise vector ``z``.
        label_dim: Width of the label vector concatenated to ``z``.
        img_size: Side length of the square output image. Defaults to 28,
            which reproduces the original 784-unit output layer.
    """

    def __init__(self, latent_dim, label_dim, img_size=28):
        super().__init__()
        self.img_size = img_size
        self.model = nn.Sequential(
            nn.Linear(latent_dim + label_dim, 128),
            nn.ReLU(),
            nn.Linear(128, img_size * img_size),
            nn.Tanh(),  # bounds every output pixel to [-1, 1]
        )

    def forward(self, z, labels):
        """Generate one image per row of ``z``.

        Args:
            z: Noise of shape ``(batch, latent_dim)``.
            labels: Label vectors of shape ``(batch, label_dim)``.

        Returns:
            Tensor of shape ``(batch, 1, img_size, img_size)``.
        """
        # Named `gen_input` rather than `input` to avoid shadowing the builtin.
        gen_input = torch.cat((z, labels), dim=1)
        out = self.model(gen_input)
        return out.view(-1, 1, self.img_size, self.img_size)

class Discriminator(nn.Module):
    """Conditional MLP discriminator: (image, label vector) -> probability.

    Args:
        label_dim: Width of the label vector concatenated to the flattened
            image.
        img_size: Side length of the square single-channel input image.
            Defaults to 28, which reproduces the original 784-unit input.
    """

    def __init__(self, label_dim, img_size=28):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(img_size * img_size + label_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # squashes the score into (0, 1), as nn.BCELoss expects
        )

    def forward(self, img, labels):
        """Score each image/label pair.

        Args:
            img: Image batch; flattened per sample before the first layer.
            labels: Label vectors of shape ``(batch, label_dim)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in (0, 1).
        """
        flat_img = img.view(img.size(0), -1)
        # Named `disc_input` rather than `input` to avoid shadowing the builtin.
        disc_input = torch.cat((flat_img, labels), dim=1)
        return self.model(disc_input)

In [None]:
# Hyper-parameters: noise width and number of shape classes (one-hot width).
latent_dim, label_dim = 100, 2
# Prefer the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# G is constructed before D, so parameter initialisation consumes the RNG in
# the same order on every run.
G = Generator(latent_dim, label_dim).to(device)
D = Discriminator(label_dim).to(device)

# Binary cross-entropy on D's sigmoid output; one Adam optimizer per network.
criterion = nn.BCELoss()
optimizer_G = torch.optim.Adam(params=G.parameters(), lr=2e-4)
optimizer_D = torch.optim.Adam(params=D.parameters(), lr=2e-4)

In [None]:
# Conditional-GAN training with binary cross-entropy.
# Each mini-batch: build a fake batch, update D on real + fake, then update G
# against the freshly updated D.
epochs = 100
batch_size = 64

for epoch in range(epochs):
    # Walk the dataset in fixed order; the final slice is shorter whenever
    # len(images) is not a multiple of batch_size.
    for i in range(0, len(images), batch_size):
        real_imgs = images[i:i+batch_size].to(device)
        real_labels = labels[i:i+batch_size]
        real_onehot = F.one_hot(real_labels, num_classes=2).float().to(device)

        # Fake batch sized to match the (possibly shorter) real batch.
        z = torch.randn(real_imgs.size(0), latent_dim).to(device)
        fake_labels = torch.randint(0, 2, (real_imgs.size(0),)).to(device)
        fake_onehot = F.one_hot(fake_labels, num_classes=2).float()
        fake_imgs = G(z, fake_onehot)

        # --- Discriminator step ---
        real_validity = D(real_imgs, real_onehot)
        # detach() keeps the discriminator loss from back-propagating into G.
        fake_validity = D(fake_imgs.detach(), fake_onehot)

        # Targets: 1 for real samples, 0 for generated ones.
        d_loss = criterion(real_validity, torch.ones_like(real_validity)) + \
                 criterion(fake_validity, torch.zeros_like(fake_validity))

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # --- Generator step ---
        # Re-score the same fake batch (not detached) with the updated D;
        # G is trained to have its samples classified as real (target 1).
        gen_validity = D(fake_imgs, fake_onehot)
        g_loss = criterion(gen_validity, torch.ones_like(gen_validity))

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

    # NOTE: real_validity / fake_validity (and d_loss / g_loss below) are the
    # values from the LAST mini-batch of the epoch, taken before that batch's
    # D update - the printed numbers are not epoch averages.
    with torch.no_grad():
        real_pred = (real_validity > 0.5).float()
        fake_pred = (fake_validity < 0.5).float()
        real_acc = real_pred.mean().item()
        fake_acc = fake_pred.mean().item()
        d_acc = (real_acc + fake_acc) / 2

    print(f"Epoch {epoch+1}/{epochs} | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f} | D Acc: {d_acc*100:.2f}%")

In [None]:
def generate_and_plot(label):
    """Sample one image from ``G`` for the given class and display it.

    Args:
        label: Class index, one-hot encoded with ``num_classes=2``. The plot
            is titled "Circle" for 0 and "Square" for anything else.
    """
    noise = torch.randn(1, latent_dim).to(device)
    class_index = torch.tensor([label])
    condition = F.one_hot(class_index, num_classes=2).float().to(device)
    generated = G(noise, condition)
    # Drop the graph, move to host memory and remove the size-1 dimensions
    # so matplotlib receives a plain 2-D array.
    img = generated.detach().cpu().squeeze()

    if label == 0:
        title = "Circle"
    else:
        title = "Square"

    plt.imshow(img, cmap="gray")
    plt.title(title)
    plt.axis("off")
    plt.show()

# Show one generated sample per class: 0 is titled "Circle", 1 "Square".
generate_and_plot(0)
generate_and_plot(1)

In [None]:
import torch, torch.nn as nn, torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw
from torchvision.utils import save_image

# Prefer the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Noise width and number of shape classes (one-hot width).
latent_dim, label_dim = 100, 2
# Training schedule.
epochs, batch_size = 100, 64

def generate_shape(label, size=128):
    """Render a filled white shape centred on a black square canvas.

    Args:
        label: 0 draws a filled circle, 1 draws a filled square. Any other
            value leaves the canvas completely black.
        size: Side length of the canvas in pixels. The shape's bounding box
            scales with it.

    Returns:
        The ``transforms.ToTensor()`` output rescaled with ``* 2 - 1`` (per
        the torchvision docs ToTensor yields values in ``[0, 1]``, so the
        result lies in ``[-1, 1]``, the range of a Tanh output layer).
    """
    img = Image.new("L", (size, size), 0)
    draw = ImageDraw.Draw(img)
    # Derive the bounding box from `size` instead of hard-coding it, so the
    # shape stays centred for any canvas. At the default size of 128 this is
    # exactly the original box (32, 32, 96, 96).
    margin = size // 4
    box = (margin, margin, size - margin, size - margin)
    if label == 0:
        draw.ellipse(box, fill=255)
    elif label == 1:
        draw.rectangle(box, fill=255)
    return transforms.ToTensor()(img) * 2 - 1

def create_dataset(n=1000):
    """Build a randomly labelled dataset of shape images.

    Every sample gets a label drawn with ``np.random.randint(0, 2)`` and the
    image that ``generate_shape`` renders for that label.

    Args:
        n: Number of (image, label) pairs to create.

    Returns:
        A pair ``(images, labels)``: the images combined with ``torch.stack``
        and the labels wrapped in ``torch.tensor``.
    """
    # Labels are drawn first, one RNG call per sample, in sample order.
    drawn_labels = [np.random.randint(0, 2) for _ in range(n)]
    rendered = [generate_shape(lbl) for lbl in drawn_labels]
    return torch.stack(rendered), torch.tensor(drawn_labels)

# Rebuild the dataset with this cell's generate_shape/create_dataset; the
# training loop below slices these module-level tensors into mini-batches.
images, labels = create_dataset()

class Generator(nn.Module):
    """Conditional MLP generator producing 1x128x128 images.

    Reads the module-level ``latent_dim`` and ``label_dim`` to size its first
    layer; the final layer emits 16384 (= 128 * 128) Tanh-bounded values.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(latent_dim + label_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2048),
            nn.ReLU(),
            nn.Linear(2048, 16384),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z, y):
        """Map noise ``z`` and label vectors ``y`` to a batch of images.

        Returns:
            Tensor of shape ``(batch, 1, 128, 128)``.
        """
        conditioned = torch.cat((z, y), dim=1)
        flat_pixels = self.model(conditioned)
        return flat_pixels.view(-1, 1, 128, 128)

class Discriminator(nn.Module):
    """Conditional MLP critic for flattened 16384-pixel images.

    Reads the module-level ``label_dim`` to size its first layer. There is no
    final activation, so ``forward`` returns an unbounded score per sample.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(16384 + label_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        """Score each image ``x`` paired with its label vector ``y``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding raw scores.
        """
        flat_x = x.view(x.size(0), -1)
        joined = torch.cat((flat_x, y), dim=1)
        return self.model(joined)

# G is constructed before D, so parameter initialisation consumes the RNG in
# the same order on every run.
G = Generator().to(device)
D = Discriminator().to(device)
# One Adam optimizer per network, both at the same learning rate.
opt_G = torch.optim.Adam(params=G.parameters(), lr=2e-4)
opt_D = torch.optim.Adam(params=D.parameters(), lr=2e-4)

# Conditional-GAN training with a hinge objective on D's raw scores.
# Each mini-batch: build a fake batch, update D on real + fake, then update G
# against the freshly updated D. Every 10th epoch one sample per class is
# plotted and saved to disk.
for epoch in range(epochs):
    # Walk the dataset in fixed order; the final slice is shorter whenever
    # len(images) is not a multiple of batch_size.
    for i in range(0, len(images), batch_size):
        real_imgs = images[i:i+batch_size].to(device)
        real_lbls = labels[i:i+batch_size]
        real_onehot = F.one_hot(real_lbls, num_classes=2).float().to(device)

        # Fake batch sized to match the (possibly shorter) real batch.
        z = torch.randn(real_imgs.size(0), latent_dim).to(device)
        fake_lbls = torch.randint(0, 2, (real_imgs.size(0),)).to(device)
        fake_onehot = F.one_hot(fake_lbls, num_classes=2).float().to(device)
        fake_imgs = G(z, fake_onehot)

        # --- Discriminator step ---
        real_validity = D(real_imgs, real_onehot)
        # detach() keeps the discriminator loss from back-propagating into G.
        fake_validity = D(fake_imgs.detach(), fake_onehot)
        # Hinge loss: penalise real scores below +1 and fake scores above -1.
        d_loss = torch.mean(F.relu(1. - real_validity)) + torch.mean(F.relu(1. + fake_validity))

        opt_D.zero_grad()
        d_loss.backward()
        opt_D.step()

        # --- Generator step ---
        # Re-score the same fake batch (not detached) with the updated D;
        # G is trained to raise D's mean score on its samples.
        gen_validity = D(fake_imgs, fake_onehot)
        g_loss = -torch.mean(gen_validity)

        opt_G.zero_grad()
        g_loss.backward()
        opt_G.step()

    # NOTE: real_validity / fake_validity (and d_loss / g_loss below) are the
    # values from the LAST mini-batch of the epoch, taken before that batch's
    # D update - the printed numbers are not epoch averages. Scores are raw
    # (D has no final activation), so the decision threshold is 0.
    with torch.no_grad():
        real_acc = (real_validity > 0).float().mean().item()
        fake_acc = (fake_validity < 0).float().mean().item()
        d_acc = (real_acc + fake_acc) / 2

    print(f"Epoch {epoch+1}/{epochs} | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f} | D Acc: {d_acc*100:.2f}%")

    # Periodic visual check: every 10th epoch, draw one fresh sample per class.
    if (epoch + 1) % 10 == 0:
        titles = ["Circle", "Square"]
        for label in range(2):
            z = torch.randn(1, latent_dim).to(device)
            onehot = F.one_hot(torch.tensor([label]), num_classes=2).float().to(device)
            # squeeze() removes the size-1 dimensions, leaving a 2-D image.
            img = G(z, onehot).detach().cpu().squeeze()
            plt.figure(figsize=(5,5))
            plt.imshow(img, cmap="gray")
            plt.title(titles[label])
            plt.axis("off")
            plt.show()
            # normalize=True lets save_image rescale the sample's value range
            # before writing the PNG into the current working directory.
            save_image(img, f"shape_{titles[label].lower()}_epoch{epoch+1}.png", normalize=True)

In [None]:
from ipywidgets import interact
import ipywidgets as widgets

def generate_and_plot(label_name):
    """Sample one image from ``G`` for the named shape and display it.

    Args:
        label_name: Either "circle" or "square"; any other value raises
            ``KeyError`` from the name-to-index lookup.
    """
    label = {"circle": 0, "square": 1}[label_name]
    noise = torch.randn(1, latent_dim).to(device)
    condition = F.one_hot(torch.tensor([label]), num_classes=2).float().to(device)
    sample = G(noise, condition)
    # Drop the graph, move to host memory and remove the size-1 dimensions
    # so matplotlib receives a plain 2-D array.
    img = sample.detach().cpu().squeeze()

    plt.imshow(img, cmap="gray")
    plt.title(label_name.capitalize())
    plt.axis("off")
    plt.show()

# Wire a dropdown to generate_and_plot; the two options match the keys of the
# function's shape_map.
interact(generate_and_plot, label_name=widgets.Dropdown(options=["circle", "square"], description="Shape:"))