# Text-to-Image GAN Demo (Low Compute 64×64)
A lightweight runnable notebook implementing a small-scale text-to-image GAN pipeline.
This demo uses a synthetic shapes dataset (colored geometric shapes) to train a conditional GAN.
You can run this on a single GPU or even CPU to see basic results.

In [None]:
# Cell 1: Imports
import os, random, math
from pathlib import Path
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid, save_image
from tqdm import tqdm


In [None]:
# Cell 2: Configuration
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Sample grids written during training land in this folder.
OUT = Path('ttig_demo_outputs')
OUT.mkdir(exist_ok=True)

# Image side length, batch size, noise-vector width, caption-embedding width.
IMG_SIZE = 64
BATCH = 64
Z_DIM = 100
TEXT_DIM = 32

# Optimizer learning rate, number of training epochs, and the shared RNG seed.
LR = 2e-4
EPOCHS = 30
SEED = 42

# Seed every RNG the notebook draws from (Python, NumPy, PyTorch).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [None]:
# Cell 3: Synthetic shapes dataset
# Vocabulary of the synthetic captions: every caption is "<color> <shape>".
SHAPES = [
    'circle',
    'square',
    'triangle',
]
COLORS = [
    'red',
    'green',
    'blue',
    'yellow',
    'magenta',
    'cyan',
]

def draw_shape(shape, color, size=64):
    """Render one filled shape on a white square canvas.

    Args:
        shape: 'circle', 'square' or 'triangle'. Any other value leaves the
            canvas blank.
        color: Key of the palette defined below; an unknown key raises KeyError.
        size: Side length of the canvas in pixels.

    Returns:
        A PIL RGB image of ``size`` x ``size`` pixels.
    """
    canvas = Image.new('RGB', (size, size), (255, 255, 255))
    pen = ImageDraw.Draw(canvas)

    # Keep a 15% margin on every side so the shape never touches the border.
    margin = int(size * 0.15)
    left, top = margin, margin
    right, bottom = size - margin, size - margin

    palette = {
        'red': (230, 25, 75),
        'green': (60, 180, 75),
        'blue': (0, 130, 200),
        'yellow': (255, 225, 25),
        'magenta': (240, 50, 230),
        'cyan': (70, 240, 240),
    }
    fill = palette[color]

    if shape == 'circle':
        pen.ellipse([left, top, right, bottom], fill=fill)
    elif shape == 'square':
        pen.rectangle([left, top, right, bottom], fill=fill)
    elif shape == 'triangle':
        # Apex at the top centre, base along the bottom edge of the box.
        apex = (size / 2, top)
        pen.polygon([apex, (right, bottom), (left, bottom)], fill=fill)
    return canvas

class ShapesDataset(Dataset):
    """Synthetic dataset of (image, caption) pairs such as "red circle".

    Only the (shape, color, caption) records are stored; the image itself is
    rendered with ``draw_shape`` each time an item is requested.
    """

    def __init__(self, n_images=2000, img_size=64):
        self.records = []
        for _ in range(n_images):
            # Shape is drawn before color so the RNG stream is consumed in a
            # fixed order for a given seed.
            shape_name = random.choice(SHAPES)
            color_name = random.choice(COLORS)
            self.records.append({
                'shape': shape_name,
                'color': color_name,
                'caption': f"{color_name} {shape_name}",
            })
        self.img_size = img_size

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        """Return ``(image, caption)``: a float32 CHW tensor in [-1, 1] and its text."""
        entry = self.records[idx]
        picture = draw_shape(entry['shape'], entry['color'], self.img_size)
        # HWC uint8 -> CHW, then map [0, 255] onto [-1, 1].
        chw = np.array(picture).transpose(2, 0, 1)
        scaled = chw / 127.5 - 1.0
        return torch.tensor(scaled, dtype=torch.float32), entry['caption']


In [None]:
# Cell 4: Text embedding + models
class SimpleTextEmbed(nn.Module):
    """Bag-of-words caption encoder: the mean of learned word embeddings.

    Args:
        vocab: Sequence of distinct words; a word's position is its embedding index.
        text_dim: Width of the embedding vectors. When omitted, the
            module-level ``TEXT_DIM`` is used.
    """

    def __init__(self, vocab, text_dim=None):
        super().__init__()
        if text_dim is None:
            text_dim = TEXT_DIM
        self.word_to_idx = {w: i for i, w in enumerate(vocab)}
        self.emb = nn.Embedding(len(vocab), text_dim)

    def forward(self, captions):
        """Embed a batch of whitespace-separated captions.

        Args:
            captions: Iterable of strings. Every caption must contain the same
                number of words, because the ids are stacked into one
                rectangular tensor.

        Returns:
            Tensor of shape (len(captions), text_dim), on the same device as
            the embedding table.

        Raises:
            KeyError: If a caption contains a word that is not in the vocabulary.
        """
        ids = [[self.word_to_idx[tok] for tok in cap.split()] for cap in captions]
        # The index tensor must live on the same device as the embedding
        # weights; a CPU index tensor fails once the module is moved to CUDA.
        ids = torch.tensor(ids, dtype=torch.long, device=self.emb.weight.device)
        return self.emb(ids).mean(1)

class Generator(nn.Module):
    """Conditional DCGAN-style generator producing 3x64x64 images in [-1, 1].

    The noise vector and the caption embedding are concatenated, projected to
    a 512x4x4 feature map, and upsampled by four stride-2 transposed
    convolutions (4 -> 8 -> 16 -> 32 -> 64).

    Args:
        z_dim: Width of the noise vector. When omitted, the module-level
            ``Z_DIM`` is used.
        text_dim: Width of the caption embedding. When omitted, the
            module-level ``TEXT_DIM`` is used.
    """

    def __init__(self, z_dim=None, text_dim=None):
        super().__init__()
        z_dim = Z_DIM if z_dim is None else z_dim
        text_dim = TEXT_DIM if text_dim is None else text_dim
        self.fc = nn.Sequential(nn.Linear(z_dim + text_dim, 512 * 4 * 4), nn.ReLU(True))
        self.net = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.ReLU(True),  # 4 -> 8
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(True),  # 8 -> 16
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(True),    # 16 -> 32
            # The last stage must upsample too: a stride-1 conv here would stop
            # at 32x32, while the discriminator's linear layers assume the
            # 4x4 feature map that only a 64x64 input produces.
            nn.ConvTranspose2d(64, 3, 4, 2, 1), nn.Tanh())                               # 32 -> 64

    def forward(self, z, txt):
        """Map noise ``z`` (B, z_dim) and text ``txt`` (B, text_dim) to images (B, 3, 64, 64)."""
        x = torch.cat([z, txt], 1)
        x = self.fc(x).view(-1, 512, 4, 4)
        return self.net(x)

class Discriminator(nn.Module):
    """Text-conditional critic: an image score plus an image/text projection term.

    Four stride-2 convolutions halve the spatial size each time; the linear
    layers are sized for a 512x4x4 feature map, i.e. a 64x64 input image.
    """

    def __init__(self):
        super().__init__()
        # First stage has no batch norm; the remaining three share one pattern.
        stages = [nn.Conv2d(3, 64, 4, 2, 1), nn.LeakyReLU(0.2, True)]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            stages += [
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, True),
            ]
        self.conv = nn.Sequential(*stages)
        feat_dim = 512 * 4 * 4
        self.fc_img = nn.Linear(feat_dim, 1)
        self.fc_txt = nn.Linear(TEXT_DIM, feat_dim)

    def forward(self, img, txt):
        """Return one unbounded score per sample, shape (B, 1)."""
        feats = self.conv(img).view(img.size(0), -1)
        # Inner product between image features and the projected caption.
        text_match = (feats * self.fc_txt(txt)).sum(1, keepdim=True)
        return self.fc_img(feats) + text_match


In [None]:
# Cell 5: Training loop
dataset = ShapesDataset(1000, IMG_SIZE)
# Build the vocabulary from the stored captions instead of iterating the
# dataset, which would render every image just to read its text. COLORS and
# SHAPES are unioned in so that any caption sampled during training is
# guaranteed to be embeddable, even if a word never occurred in the dataset.
vocab = sorted(
    {word for rec in dataset.records for word in rec['caption'].split()}
    | set(COLORS)
    | set(SHAPES)
)
text_embed = SimpleTextEmbed(vocab).to(DEVICE)
G, D = Generator().to(DEVICE), Discriminator().to(DEVICE)
# beta1 = 0.5 for both Adam optimizers.
optG = torch.optim.Adam(G.parameters(), lr=LR, betas=(0.5, 0.999))
optD = torch.optim.Adam(D.parameters(), lr=LR, betas=(0.5, 0.999))
loader = DataLoader(dataset, batch_size=BATCH, shuffle=True)

def d_loss(real_pred, fake_pred):
    """Discriminator hinge loss: mean(relu(1 - real)) + mean(relu(1 + fake))."""
    real_term = F.relu(1.0 - real_pred).mean()
    fake_term = F.relu(1.0 + fake_pred).mean()
    return real_term + fake_term

def g_loss(fake_pred):
    """Generator hinge loss: the negated mean discriminator score on fakes."""
    mean_score = fake_pred.mean()
    return -mean_score

for epoch in range(EPOCHS):
    for imgs, caps in tqdm(loader):
        imgs = imgs.to(DEVICE)
        # Neither optimizer updates text_embed, so its output is treated as a
        # constant. Keeping it attached would make both lossD and lossG share
        # the embedding's autograd graph: lossD.backward() frees that graph
        # and lossG.backward() would then fail with "Trying to backward
        # through the graph a second time".
        with torch.no_grad():
            txt = text_embed(caps).to(DEVICE)
        z = torch.randn(imgs.size(0), Z_DIM, device=DEVICE)
        fake = G(z, txt)

        # --- Discriminator step: fakes are detached so G receives no gradient.
        real_pred = D(imgs, txt)
        fake_pred = D(fake.detach(), txt)
        lossD = d_loss(real_pred, fake_pred)
        optD.zero_grad(); lossD.backward(); optD.step()

        # --- Generator step: re-score the fakes with the updated D.
        fake_pred_g = D(fake, txt)
        lossG = g_loss(fake_pred_g)
        optG.zero_grad(); lossG.backward(); optG.step()

    # Losses reported are those of the last batch of the epoch.
    print(f"Epoch {epoch+1}/{EPOCHS}  LossD={lossD.item():.3f}  LossG={lossG.item():.3f}")
    with torch.no_grad():
        sample_z = torch.randn(16, Z_DIM, device=DEVICE)
        sample_txt = text_embed([random.choice(COLORS)+' '+random.choice(SHAPES) for _ in range(16)]).to(DEVICE)
        samples = G(sample_z, sample_txt)
        # Generator output is in [-1, 1]; map to [0, 1] before saving a 4x4 grid.
        save_image((samples+1)/2, OUT/f'sample_{epoch:03d}.png', nrow=4)
print('Training complete! Check ttig_demo_outputs for images.')


# Improvements: Text preprocessing, BERT embeddings, and Evaluation
This appended section extends the notebook with the following additions:

1. **Text preprocessing**: normalization, punctuation removal, optional stopword removal, and preparing data for BERT tokenization.

2. **BERT embeddings**: example code using Hugging Face `transformers` to get sentence / token embeddings (pooled output or average of token vectors).

3. **Evaluation metrics**: functions to compute Inception Score (IS) and Fréchet Inception Distance (FID) using torchvision's InceptionV3 activations. These are implemented as reusable functions that take lists of image file paths (for example, the files in a folder of real images and in a folder of generated images).

4. **Explanations and inline comments**: each code block includes comments and rationale so you can understand and modify as required.

> Notes: running the cells requires internet access the first time to download pretrained models (BERT / Inception) and may require installing packages. An installation cell is included.

In [None]:

# Install necessary packages. Run this cell once.
# If you're running on Google Colab, add an exclamation mark (!) before pip, e.g. !pip install ...
import sys
import subprocess
def pip_install(pkg):
    """Install ``pkg`` with pip, using the interpreter that runs this notebook.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    command = [sys.executable, "-m", "pip", "install", pkg]
    subprocess.check_call(command)

# Recommended packages
import importlib.util
import re

pkgs = [
    "transformers>=4.0.0",     # Hugging Face transformers for BERT embeddings
    "torch",                   # PyTorch (may already be installed in your environment)
    "torchvision",             # For Inception model and image utilities
    "tqdm",
    "numpy",
    "scipy",
    "pillow",
    "nltk"
]

# pip distribution name -> importable module name, for packages where they differ.
_IMPORT_NAMES = {"pillow": "PIL"}

for p in pkgs:
    # Strip any version specifier / extras / markers ("pkg>=1.0", "pkg[extra]", ...)
    # to recover the bare distribution name.
    dist_name = re.split(r"[<>=!~;\[\s]", p, maxsplit=1)[0]
    module_name = _IMPORT_NAMES.get(dist_name.lower(), dist_name)
    # find_spec checks availability without actually importing heavy packages.
    if importlib.util.find_spec(module_name) is None:
        print("Installing", p)
        pip_install(p)

# Download NLTK data for preprocessing
import nltk
# 'punkt' serves older NLTK releases; recent releases load the 'punkt_tab'
# resource for word_tokenize instead, so both are requested. On a release
# that does not know one of the names, nltk.download is expected to report
# the problem and return False rather than raise — verify on your version.
for _resource in ('punkt', 'punkt_tab', 'stopwords'):
    nltk.download(_resource)
print('Setup done — restart the kernel if necessary.')


In [None]:

# Text preprocessing and BERT tokenization example.
# This cell demonstrates:
# - Normalization (lowercasing, unicode normalization)
# - Punctuation removal
# - Optional stopword removal using NLTK
# - Conversion to token ids / attention masks via Hugging Face tokenizer
from transformers import BertTokenizer, BertModel
import torch
import re
import unicodedata
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

# English stopword list from NLTK, held as a set so the membership test in
# clean_text is a constant-time lookup.
stop_words = set(stopwords.words('english'))

def normalize_text(text):
    """Return ``text`` NFKC-normalized, lowercased, and stripped of surrounding whitespace."""
    return unicodedata.normalize('NFKC', text).lower().strip()

def clean_text(text, remove_stopwords=False):
    """Normalize ``text``, replace punctuation with spaces, and collapse whitespace.

    Args:
        text: Raw caption.
        remove_stopwords: When True, the cleaned text is tokenized with NLTK's
            ``word_tokenize`` and tokens found in ``stop_words`` are dropped.

    Returns:
        The cleaned text as a single space-separated string.
    """
    normalized = normalize_text(text)
    # Anything that is neither a word character nor whitespace becomes a space.
    without_punct = re.sub(r'[^\w\s]', ' ', normalized)
    collapsed = re.sub(r'\s+', ' ', without_punct).strip()
    if not remove_stopwords:
        return collapsed
    kept = (tok for tok in word_tokenize(collapsed) if tok not in stop_words)
    return ' '.join(kept)

# Example usage: clean two captions and show them next to the originals.
examples = [
    "A small red bird sitting on a branch, looking to the left.",
    "An astronaut riding a horse on Mars!",
]
cleaned = [clean_text(sentence, remove_stopwords=True) for sentence in examples]
print('Original -> Cleaned:')
for raw, tidy in zip(examples, cleaned):
    print('-', raw, '->', tidy)

# Tokenize with BERT tokenizer: pad to the longest caption in the batch,
# truncate at 64 tokens, and return PyTorch tensors.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
enc = tokenizer(
    cleaned,
    padding=True,
    truncation=True,
    max_length=64,
    return_tensors='pt',
)
print('\nTokenized keys:', enc.keys())
print('input_ids shape:', enc['input_ids'].shape)


In [None]:

# Obtain BERT embeddings (two common strategies):
# 1) Use pooled_output (CLS token) as a sentence embedding.
# 2) Average token embeddings (excluding padding) for a mean-pooled embedding.
from transformers import BertModel
import torch
# Load the pretrained encoder and switch it to inference mode in one step
# (Module.eval() returns the module itself).
model = BertModel.from_pretrained('bert-base-uncased').eval()

def get_bert_embeddings(input_ids, attention_mask, pooling='cls'):
    '''
    Compute one sentence embedding per input row with the module-level BERT ``model``.

    input_ids, attention_mask: torch tensors (batch, seq_len)
    pooling: 'cls'  -> the model's pooler_output
             'mean' -> average of last-layer token vectors over positions
                       where attention_mask is non-zero
    Returns: (batch, hidden_size) tensor
    Raises: ValueError if ``pooling`` is neither 'cls' nor 'mean'
    '''
    # Validate before the forward pass so a typo does not cost a full BERT run.
    if pooling not in ('cls', 'mean'):
        raise ValueError('Unknown pooling type')
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)
        if pooling == 'cls':
            return out.pooler_output  # (B, H)
        last_hidden = out.last_hidden_state  # (B, L, H)
        # Broadcast the mask over the hidden dimension so padded positions
        # contribute nothing to the sum.
        mask = attention_mask.unsqueeze(-1).expand_as(last_hidden).float()
        summed = (last_hidden * mask).sum(1)
        # Clamp guards against division by zero for an all-padding row.
        counts = mask.sum(1).clamp(min=1e-9)
        return summed / counts

# Example: mean-pooled sentence embeddings for the cleaned example captions.
embs = get_bert_embeddings(
    input_ids=enc['input_ids'],
    attention_mask=enc['attention_mask'],
    pooling='mean',
)
print('BERT embeddings shape (mean-pooled):', embs.shape)


In [None]:

# Functions to compute Inception activations, Frechet Inception Distance (FID), and Inception Score (IS).
# Implementation notes:
# - We use torchvision's pretrained InceptionV3 (transform input to 299x299 and apply required preprocessing)
# - FID computes statistics (mean, covariance) of activations and compares real vs generated sets.
# - IS computes KL divergence between conditional label distribution and marginal label distribution using softmax outputs from Inception.
import torch
from torchvision import transforms, models
from PIL import Image
import numpy as np
from scipy import linalg
import os
from tqdm import tqdm

# Device used by all evaluation helpers in this cell.
_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Image preprocessing for the Inception network: resize to 299x299, convert
# to a tensor, then normalize each channel with the mean/std given below.
_inception_transform = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def load_image_tensor(path):
    """Load the image at ``path`` as RGB and return it preprocessed with a leading batch axis."""
    rgb = Image.open(path).convert('RGB')
    preprocessed = _inception_transform(rgb)
    return preprocessed.unsqueeze(0)  # (1,3,299,299)

# Load inception model
# transform_input=True makes the network re-map inputs that were normalized
# with the ImageNet mean/std (as _inception_transform does) to the scaling the
# pretrained weights expect; with False the two preprocessing steps disagree.
_inception_weights = getattr(models, 'Inception_V3_Weights', None)
if _inception_weights is not None:
    # Newer torchvision: the `weights=` API replaces the deprecated `pretrained=`.
    _inception = models.inception_v3(
        weights=_inception_weights.IMAGENET1K_V1, transform_input=True
    ).to(_device)
else:
    _inception = models.inception_v3(pretrained=True, transform_input=True).to(_device)
_inception.eval()

# We will use the pool3 features (2048-d) for FID. For torchvision's inception, we can access 'Mixed_7c' output via forward hook.
# Simpler approach: run inception and take the last pooling layer 'fc' input. We'll use the model's forward to obtain features
from torch import nn
class InceptionActivations(nn.Module):
    """Wrap a classifier and return the features fed to its final ``fc`` layer.

    The wrapped model is run normally; a temporary forward pre-hook on
    ``model.fc`` records that layer's input, so the result does not depend on
    how the model's forward pass is structured internally.

    NOTE(review): for torchvision's InceptionV3 the ``fc`` input is expected
    to be the 2048-d pooled feature vector; put the model in ``eval()`` mode
    first so stochastic layers do not perturb it — confirm against the
    torchvision version in use.

    Args:
        inception_model: Any ``nn.Module`` that has an ``fc`` submodule which
            is called during its forward pass.
    """

    def __init__(self, inception_model):
        super().__init__()
        self.model = inception_model

    def forward(self, x):
        """Run ``x`` through the wrapped model and return the input of its ``fc`` layer.

        Raises:
            RuntimeError: If the wrapped model's forward pass never calls ``fc``.
        """
        captured = []

        def _grab_fc_input(module, inputs):
            captured.append(inputs[0])

        handle = self.model.fc.register_forward_pre_hook(_grab_fc_input)
        try:
            self.model(x)
        finally:
            # Always detach the hook so repeated calls do not stack hooks.
            handle.remove()
        if not captured:
            raise RuntimeError('Wrapped model did not call its `fc` layer.')
        return captured[0]

def get_activations_from_files(file_list, batch_size=8):
    """Compute Inception feature vectors for a list of image files.

    The features are the input of the network's final ``fc`` layer, recorded
    with a forward pre-hook during a single ordinary forward pass. Walking the
    model's children by hand is avoided on purpose: it stops at the auxiliary
    classifier, which is registered before the last Mixed_7* blocks, and it
    bypasses anything the model does only inside ``forward``.

    NOTE(review): assumes ``_inception`` is in eval mode (set where it is
    loaded) and that its ``fc`` input is the 2048-d pooled feature — confirm
    for the torchvision version in use.

    Args:
        file_list: Sequence of image file paths.
        batch_size: Number of images per forward pass.

    Returns:
        numpy array of shape (len(file_list), feature_dim).

    Raises:
        ValueError: If ``file_list`` is empty.
    """
    if len(file_list) == 0:
        raise ValueError('file_list is empty; need at least one image to compute activations.')

    _inception.to(_device)
    acts = []
    captured = []

    def _grab_fc_input(module, inputs):
        captured.append(inputs[0])

    handle = _inception.fc.register_forward_pre_hook(_grab_fc_input)
    try:
        with torch.no_grad():
            for i in range(0, len(file_list), batch_size):
                batch_paths = file_list[i:i + batch_size]
                batch = torch.cat([load_image_tensor(p) for p in batch_paths], dim=0).to(_device)
                captured.clear()
                _inception(batch)
                feat = captured[0].reshape(batch.size(0), -1)
                acts.append(feat.cpu().numpy())
    finally:
        # Detach the hook even if loading or inference fails midway.
        handle.remove()
    return np.concatenate(acts, axis=0)

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Frechet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2).

    Computes ``||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))``.

    Args:
        mu1, mu2: Mean vectors of equal length.
        sigma1, sigma2: Covariance matrices of equal shape.
        eps: Value added to both diagonals when the matrix square root of
            ``sigma1 @ sigma2`` contains non-finite entries.

    Returns:
        The distance as a scalar.

    Raises:
        ValueError: If the means or the covariances differ in shape.
    """
    import warnings  # local import keeps this cell self-contained

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)
    if mu1.shape != mu2.shape:
        raise ValueError(f'Mean vectors have different shapes: {mu1.shape} vs {mu2.shape}')
    if sigma1.shape != sigma2.shape:
        raise ValueError(f'Covariances have different shapes: {sigma1.shape} vs {sigma2.shape}')

    diff = mu1 - mu2
    # Called without the `disp` keyword so only the matrix is returned.
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # Product is (near-)singular: regularize both diagonals and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            # A large imaginary part means the result is unreliable (typical
            # when there are far fewer samples than feature dimensions).
            warnings.warn(
                'sqrtm returned a significant imaginary component '
                f'(max {np.max(np.abs(covmean.imag)):.3g}); the distance may be inaccurate.'
            )
        covmean = covmean.real
    fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2*covmean)
    return fid

def compute_fid(real_files, gen_files, batch_size=8):
    """Return the Frechet distance between Inception features of two image sets.

    Args:
        real_files: Paths of the reference images.
        gen_files: Paths of the generated images.
        batch_size: Batch size passed to ``get_activations_from_files``.

    Returns:
        The distance as a Python float.
    """
    stats = []
    for paths in (real_files, gen_files):
        acts = get_activations_from_files(paths, batch_size=batch_size)
        # Per-set feature mean and covariance (rows are samples).
        stats.append((acts.mean(axis=0), np.cov(acts, rowvar=False)))
    (mu_real, cov_real), (mu_gen, cov_gen) = stats
    return float(calculate_frechet_distance(mu_real, cov_real, mu_gen, cov_gen))

def inception_score_from_probs(preds, eps=1e-16):
    """Inception Score of a set of class-probability rows.

    Args:
        preds: numpy array (N, num_classes) of softmax scores.
        eps: Small constant added inside the logarithms to avoid log(0).

    Returns:
        exp of the mean, over rows, of KL(row || column-wise mean of preds).
    """
    marginal = np.mean(preds, axis=0)
    log_ratio = np.log(preds + eps) - np.log(marginal + eps)
    per_image_kl = np.sum(preds * log_ratio, axis=1)
    return float(np.exp(np.mean(per_image_kl)))

def compute_inception_score_from_files(file_list, splits=10, batch_size=8):
    """Compute the Inception Score (mean and std over splits) for image files.

    Args:
        file_list: Sequence of image file paths.
        splits: Desired number of chunks the predictions are divided into. It
            is capped at the number of images so no chunk is ever empty.
        batch_size: Number of images per forward pass.

    Returns:
        ``(mean, std)`` of the per-chunk scores, as Python floats.

    Raises:
        ValueError: If ``file_list`` is empty or ``splits`` is less than 1.
    """
    if len(file_list) == 0:
        raise ValueError('file_list is empty; need at least one image to compute the Inception Score.')
    if splits < 1:
        raise ValueError('splits must be at least 1.')

    preds = []
    with torch.no_grad():
        for i in range(0, len(file_list), batch_size):
            batch_paths = file_list[i:i+batch_size]
            batch = torch.cat([load_image_tensor(p) for p in batch_paths], dim=0).to(_device)
            logits = _inception(batch)
            preds.append(torch.nn.functional.softmax(logits, dim=1).cpu().numpy())
    preds = np.concatenate(preds, axis=0)

    # array_split keeps every image (a remainder is spread over the first
    # chunks) and, with the cap below, never yields an empty chunk — an empty
    # chunk would score NaN and poison the mean.
    n_chunks = min(splits, preds.shape[0])
    split_scores = [
        inception_score_from_probs(part) for part in np.array_split(preds, n_chunks)
    ]
    return float(np.mean(split_scores)), float(np.std(split_scores))


In [None]:

# Usage examples for evaluation:
# 1) Save generated images (PIL) to a folder 'generated/' and have a folder 'real/' with real images.
# 2) Run compute_fid(real_files, gen_files) and compute_inception_score_from_files(gen_files).
import glob
# Example paths (replace with your actual folders)
real_folder = 'real_images/'
gen_folder = 'generated_images/'
# For each folder: all PNG files (sorted) followed by all JPG files (sorted).
real_files, gen_files = (
    sorted(glob.glob(folder + '*.png')) + sorted(glob.glob(folder + '*.jpg'))
    for folder in (real_folder, gen_folder)
)

if real_files and gen_files:
    fid_val = compute_fid(real_files, gen_files, batch_size=8)
    print('FID:', fid_val)
    is_mean, is_std = compute_inception_score_from_files(gen_files, splits=10, batch_size=8)
    print('Inception Score (mean, std):', is_mean, is_std)
else:
    print('No images found in example folders. Save generated images to', gen_folder, 'and real images to', real_folder)



# Recommended next steps after these improvements:
# - Replace dummy / small generator with your full text-to-image GAN architecture.
# - Use BERT token embeddings: you may want to fine-tune BERT or use static embeddings depending on dataset size.
# - For cross-attention: pass token-level embeddings (B, T, H) into your CrossAttention block so each token can attend to spatial features.
# - Compute FID between a held-out validation set of real images and generated images (>= 1000 images gives more stable FID, but you can start with 100-500).
# - Consider additional metrics: CLIP-based similarity (semantic similarity between prompt and image), precision/recall for generative models, and diversity metrics.
# - Add more visualizations: image grids, t-SNE of embeddings, and attention maps to inspect what tokens attend to what regions.
