# Text-to-Image GAN Demo (Low Compute 64×64)
A lightweight runnable notebook implementing a small-scale text-to-image GAN pipeline.
This demo uses a synthetic shapes dataset (colored geometric shapes) to train a conditional GAN.
You can run this on a single GPU or even CPU to see basic results.

In [None]:
# Cell 1: Imports
import os, random, math
from pathlib import Path
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid, save_image
from tqdm import tqdm


In [None]:
# Cell 2: Configuration
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Output directory; created on first run, reused afterwards.
OUT = Path('ttig_demo_outputs')
OUT.mkdir(exist_ok=True)

# Image side length, batch size, noise-vector length, text-embedding length.
IMG_SIZE = 64
BATCH = 64
Z_DIM = 100
TEXT_DIM = 32

# Learning rate, number of training epochs, RNG seed.
LR = 2e-4
EPOCHS = 30
SEED = 42

# Seed the Python, NumPy and PyTorch generators with the same value.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [None]:
# Cell 3: Synthetic shapes dataset
# Shape names understood by draw_shape.
SHAPES = [
    'circle',
    'square',
    'triangle',
]
# Color names; each one is a key of the palette inside draw_shape.
COLORS = [
    'red',
    'green',
    'blue',
    'yellow',
    'magenta',
    'cyan',
]

def draw_shape(shape, color, size=64):
    """Render one filled shape on a white square canvas.

    Args:
        shape: 'circle', 'square' or 'triangle'. Any other value leaves the
            canvas blank.
        color: Name of a palette entry below; an unknown name raises KeyError.
        size: Side length of the canvas in pixels.

    Returns:
        An RGB PIL image of ``size`` x ``size`` pixels.
    """
    palette = {
        'red': (230, 25, 75),
        'green': (60, 180, 75),
        'blue': (0, 130, 200),
        'yellow': (255, 225, 25),
        'magenta': (240, 50, 230),
        'cyan': (70, 240, 240),
    }
    canvas = Image.new('RGB', (size, size), (255, 255, 255))
    pen = ImageDraw.Draw(canvas)

    # Leave a 15% margin on every side of the shape.
    margin = int(size * 0.15)
    box = [margin, margin, size - margin, size - margin]
    fill = palette[color]

    if shape == 'triangle':
        left, top, right, bottom = box
        # Apex at the top centre, base along the bottom edge of the box.
        corners = [(size / 2, top), (right, bottom), (left, bottom)]
        pen.polygon(corners, fill=fill)
    elif shape == 'square':
        pen.rectangle(box, fill=fill)
    elif shape == 'circle':
        pen.ellipse(box, fill=fill)
    return canvas

class ShapesDataset(Dataset):
    """Randomly chosen colored shapes paired with "<color> <shape>" captions.

    Only the shape/color/caption records are stored; the image itself is
    drawn again every time an item is requested.
    """

    def __init__(self, n_images=2000, img_size=64):
        """Pick ``n_images`` shape/color pairs using the ``random`` module.

        Args:
            n_images: Number of records to generate.
            img_size: Side length passed to ``draw_shape`` when rendering.
        """
        self.img_size = img_size
        self.records = []
        for _ in range(n_images):
            # Shape is drawn before color, so the RNG sequence stays the same.
            picked_shape = random.choice(SHAPES)
            picked_color = random.choice(COLORS)
            self.records.append({
                'shape': picked_shape,
                'color': picked_color,
                'caption': f"{picked_color} {picked_shape}",
            })

    def __len__(self):
        """Return the number of stored records."""
        return len(self.records)

    def __getitem__(self, idx):
        """Render record ``idx`` and return ``(image_tensor, caption)``.

        The image is a float32 tensor in channel-first layout, scaled from
        [0, 255] to [-1, 1].
        """
        entry = self.records[idx]
        picture = draw_shape(entry['shape'], entry['color'], self.img_size)
        # HWC -> CHW, then map 0..255 onto -1..1.
        channels_first = np.array(picture).transpose(2, 0, 1)
        scaled = channels_first / 127.5 - 1.0
        return torch.tensor(scaled, dtype=torch.float32), entry['caption']


In [None]:
# Cell 4: Text embedding + models
class SimpleTextEmbed(nn.Module):
    """Bag-of-words caption encoder: the mean of per-word embeddings.

    Args:
        vocab: Iterable of distinct words; each word's position becomes its
            embedding index.
        text_dim: Size of each embedding vector. Defaults to the module-level
            ``TEXT_DIM`` when omitted.
    """

    def __init__(self, vocab, text_dim=None):
        super().__init__()
        self.word_to_idx = {w: i for i, w in enumerate(vocab)}
        dim = TEXT_DIM if text_dim is None else text_dim
        self.emb = nn.Embedding(len(vocab), dim)

    def forward(self, captions):
        """Embed a batch of whitespace-separated captions.

        Args:
            captions: Sequence of strings. Every caption must split into the
                same number of words so the indices form a rectangular
                tensor, and every word must be in the vocabulary (an unknown
                word raises KeyError).

        Returns:
            Tensor of shape (len(captions), embedding_dim) on the same device
            as the embedding weights.
        """
        ids = [[self.word_to_idx[tok] for tok in cap.split()] for cap in captions]
        # Build the indices on the embedding's own device; a CPU index tensor
        # fails the lookup once the module has been moved to the GPU.
        index = torch.tensor(ids, dtype=torch.long, device=self.emb.weight.device)
        return self.emb(index).mean(1)

class Generator(nn.Module):
    """Conditional generator mapping (noise, text embedding) to a 64x64 RGB image.

    Args:
        z_dim: Length of the noise vector. Defaults to the module-level
            ``Z_DIM`` when omitted.
        text_dim: Length of the text embedding. Defaults to the module-level
            ``TEXT_DIM`` when omitted.
    """

    def __init__(self, z_dim=None, text_dim=None):
        super().__init__()
        z_dim = Z_DIM if z_dim is None else z_dim
        text_dim = TEXT_DIM if text_dim is None else text_dim
        # Project the concatenated input onto a 512-channel 4x4 feature map.
        self.fc = nn.Sequential(nn.Linear(z_dim + text_dim, 512 * 4 * 4), nn.ReLU(True))
        # Each ConvTranspose2d(kernel=4, stride=2, padding=1) doubles the
        # spatial size: 4 -> 8 -> 16 -> 32 -> 64. Four stages are needed to
        # reach 64x64, the size the Discriminator's 512*4*4 linear layers
        # imply after its four stride-2 convolutions.
        self.net = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.BatchNorm2d(32), nn.ReLU(True),
            # Size-preserving 3x3 convolution to RGB; tanh bounds it to [-1, 1].
            nn.Conv2d(32, 3, 3, 1, 1), nn.Tanh())

    def forward(self, z, txt):
        """Generate images.

        Args:
            z: Noise tensor of shape (batch, z_dim).
            txt: Text embedding of shape (batch, text_dim).

        Returns:
            Tensor of shape (batch, 3, 64, 64) with values in [-1, 1].
        """
        x = torch.cat([z, txt], 1)
        x = self.fc(x).view(-1, 512, 4, 4)
        return self.net(x)

class Discriminator(nn.Module):
    """Text-conditioned critic: an image score plus an image/text inner product."""

    def __init__(self):
        super().__init__()
        # Stem without batch norm, then three conv + batch-norm stages; every
        # convolution is 4x4 with stride 2 and padding 1.
        layers = [nn.Conv2d(3, 64, 4, 2, 1), nn.LeakyReLU(0.2, True)]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, True),
            ]
        self.conv = nn.Sequential(*layers)

        n_features = 512 * 4 * 4
        # Unconditional score from the flattened image features.
        self.fc_img = nn.Linear(n_features, 1)
        # Maps the text embedding into the same space as the image features.
        self.fc_txt = nn.Linear(TEXT_DIM, n_features)

    def forward(self, img, txt):
        """Score a batch of images against their text embeddings.

        Returns a (batch, 1) tensor: the linear image score plus the dot
        product between the image features and the projected text.
        """
        features = self.conv(img).view(img.shape[0], -1)
        text_proj = self.fc_txt(txt)
        match = (features * text_proj).sum(1, keepdim=True)
        return self.fc_img(features) + match


In [None]:
# Cell 5: Training loop
dataset = ShapesDataset(1000, IMG_SIZE)
# Read the caption words from the stored records instead of iterating the
# dataset, which would render every image only to throw it away. COLORS and
# SHAPES are unioned in so any "<color> <shape>" caption is in the vocabulary,
# even for a word the random sample happened to miss.
caption_words = {tok for rec in dataset.records for tok in rec['caption'].split()}
vocab = sorted(caption_words | set(COLORS) | set(SHAPES))
text_embed = SimpleTextEmbed(vocab).to(DEVICE)
G, D = Generator().to(DEVICE), Discriminator().to(DEVICE)
# Only G and D are optimised; text_embed's parameters are in neither
# optimizer, so its embeddings keep their initial values.
optG = torch.optim.Adam(G.parameters(), lr=LR, betas=(0.5,0.999))
optD = torch.optim.Adam(D.parameters(), lr=LR, betas=(0.5,0.999))
loader = DataLoader(dataset, batch_size=BATCH, shuffle=True)

def d_loss(real_pred, fake_pred):
    """Hinge loss for the discriminator.

    Penalises real scores below +1 and fake scores above -1; each term is
    averaged over its own tensor before the two are added.
    """
    real_term = F.relu(1.0 - real_pred).mean()
    fake_term = F.relu(1.0 + fake_pred).mean()
    return real_term + fake_term

def g_loss(fake_pred):
    """Generator loss: the negated mean discriminator score on fake images."""
    mean_score = fake_pred.mean()
    return -mean_score

for epoch in range(EPOCHS):
    for imgs, caps in tqdm(loader):
        imgs = imgs.to(DEVICE)
        # text_embed is in neither optimizer, so its output is used as a
        # constant. Computing it without a graph matters: the same `txt`
        # feeds both backward passes below, and a shared embedding graph
        # would already be freed by lossD.backward() when lossG.backward()
        # reaches it ("backward through the graph a second time").
        with torch.no_grad():
            txt = text_embed(caps).to(DEVICE)
        z = torch.randn(imgs.size(0), Z_DIM, device=DEVICE)
        fake = G(z, txt)

        # Discriminator step: fakes are detached so no gradient reaches G.
        real_pred = D(imgs, txt)
        fake_pred = D(fake.detach(), txt)
        lossD = d_loss(real_pred, fake_pred)
        optD.zero_grad(); lossD.backward(); optD.step()

        # Generator step: score the same fakes with the just-updated D.
        fake_pred_g = D(fake, txt)
        lossG = g_loss(fake_pred_g)
        optG.zero_grad(); lossG.backward(); optG.step()

    # Losses reported are those of the last batch of the epoch.
    print(f"Epoch {epoch+1}/{EPOCHS}  LossD={lossD.item():.3f}  LossG={lossG.item():.3f}")
    with torch.no_grad():
        sample_z = torch.randn(16, Z_DIM, device=DEVICE)
        sample_txt = text_embed([random.choice(COLORS)+' '+random.choice(SHAPES) for _ in range(16)]).to(DEVICE)
        samples = G(sample_z, sample_txt)
        # Map [-1, 1] back to [0, 1] before writing the 4x4 grid.
        save_image((samples+1)/2, OUT/f'sample_{epoch:03d}.png', nrow=4)
print('Training complete! Check ttig_demo_outputs for images.')
