<a href="https://colab.research.google.com/github/veydantkatyal/image-generator-gan/blob/main/image_generator_gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Setup and Imports**

In [1]:
!pip install torch torchvision matplotlib nltk tqdm gradio pycocotools --quiet
import os, json, random
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import nltk
nltk.download('punkt')
from nltk.tokenize import word_tokenize

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


# **Auto Download & Setup Datasets**
We'll work with the Oxford 102 Flowers dataset only.



In [2]:
import os
os.makedirs("data/oxford/images", exist_ok=True)

# Only download flower images, no captions
!wget -q -O data/oxford/102flowers.tgz https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz
!tar -xzf data/oxford/102flowers.tgz -C data/oxford
!mv data/oxford/jpg/* data/oxford/images/
!rm -r data/oxford/jpg

# **Tokenizer**

In [3]:
import nltk
nltk.download('punkt_tab')
nltk.download('punkt')
from nltk.tokenize import word_tokenize

class Tokenizer:
    """Word-level tokenizer that maps lower-cased NLTK tokens to integer ids.

    Ids 0 and 1 are reserved for the '<PAD>' and '<UNK>' tokens; every other
    word receives the next free id the first time `build_vocab` sees it.
    """

    def __init__(self):
        self.word2idx = {'<PAD>': 0, '<UNK>': 1}
        self.idx2word = ['<PAD>', '<UNK>']

    def build_vocab(self, captions):
        """Add every not-yet-known word found in `captions` to the vocabulary.

        Args:
            captions: iterable of caption strings.
        """
        for caption in captions:
            for word in word_tokenize(caption.lower()):
                if word not in self.word2idx:
                    self.word2idx[word] = len(self.word2idx)
                    self.idx2word.append(word)

    def tokenize(self, sentence, max_len=20):
        """Convert `sentence` into a fixed-length tensor of token ids.

        Unknown words map to the '<UNK>' id. The sequence is truncated to
        `max_len` tokens and right-padded with the '<PAD>' id.

        Args:
            sentence: text to encode.
            max_len: length of the returned sequence.

        Returns:
            1-D `torch.long` tensor of token ids.
        """
        pad_id = self.word2idx['<PAD>']
        unk_id = self.word2idx['<UNK>']
        words = word_tokenize(sentence.lower())
        ids = [self.word2idx.get(w, unk_id) for w in words][:max_len]
        ids += [pad_id] * (max_len - len(ids))
        # Explicit dtype: an empty id list would otherwise become a float tensor,
        # which nn.Embedding cannot index with.
        return torch.tensor(ids, dtype=torch.long)

# Seed the vocabulary from a handful of hand-written flower descriptions.
tokenizer = Tokenizer()

synthetic_captions = [
    "a red flower with round petals",
    "a yellow flower blooming in sunlight",
    "a white flower with a purple center",
    "a vibrant blue flower in a garden",
]

tokenizer.build_vocab(synthetic_captions)


[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


# **Dataset with Synthetic Captions**

In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os, random

# Dataset Class
class OxfordFlowersDataset(Dataset):
    """Images from a folder, each paired with a randomly generated caption.

    NOTE: the caption is drawn at random on every access (random colour and
    random petal shape); it is not derived from the image content.

    Args:
        image_dir: directory containing the image files.
        transform: callable applied to the RGB PIL image.
        tokenizer: object exposing `tokenize(caption, max_len)`.
        max_len: token sequence length passed to the tokenizer.
    """

    # Only files with these extensions are treated as images, so stray entries
    # (e.g. notebook checkpoint folders) do not break loading.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, image_dir, transform, tokenizer, max_len=20):
        self.image_dir = image_dir
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.image_files = sorted(
            name for name in os.listdir(image_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.colors = ["red", "yellow", "blue", "white", "purple", "pink"]
        self.shapes = ["round", "oval", "spiky", "star-shaped", "long", "wide"]

    def __len__(self):
        """Number of image files found in `image_dir`."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return `(transformed image, caption token ids)` for item `idx`."""
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        # Context manager closes the underlying file as soon as the pixels are decoded.
        with Image.open(img_path) as img:
            image = self.transform(img.convert("RGB"))

        color = random.choice(self.colors)
        shape = random.choice(self.shapes)
        caption = f"a {color} flower with {shape} petals"
        tokens = self.tokenizer.tokenize(caption, self.max_len)

        return image, tokens

# Transforms: resize to 64x64 and scale pixel values to [-1, 1]
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# Dataset & Dataloader
dataset = OxfordFlowersDataset("data/oxford/images", transform, tokenizer)

# Register every word the dataset's caption template can produce. Several of its
# colour/shape words ("pink", "oval", "spiky", "star-shaped", "long", "wide") do
# not occur in the seed captions and would otherwise all collapse to <UNK>.
# This must run before the text encoder is created, since its embedding table
# is sized from the vocabulary.
tokenizer.build_vocab(
    f"a {color} flower with {shape} petals"
    for color in dataset.colors
    for shape in dataset.shapes
)

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


# **Text Encoder + U-Net Model**

In [5]:
import torch
import torch.nn as nn

class TextEncoder(nn.Module):
    """Embeds token ids and summarises them with a single-layer GRU.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: size of each token embedding.
        hidden_dim: size of the GRU hidden state (and of the returned vector).
    """

    def __init__(self, vocab_size, embed_dim=256, hidden_dim=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)

    def forward(self, tokens):
        """Return the GRU's final hidden state with its leading layer axis squeezed out."""
        embedded = self.embedding(tokens)
        final_hidden = self.rnn(embedded)[1]
        return final_hidden.squeeze(0)

class CondUNet(nn.Module):
    """Small convolutional noise predictor conditioned on a text embedding.

    The text embedding is linearly projected to one `image_size x image_size`
    map and concatenated to the 3-channel input as a fourth channel. All
    convolutions use 3x3 kernels with padding 1, so the spatial size is preserved.

    Args:
        text_dim: dimensionality of the incoming text embedding.
        image_size: height/width of the (square) input images. Defaults to 64,
            the size previously hard-coded.
    """

    def __init__(self, text_dim=256, image_size=64):
        super().__init__()
        self.image_size = image_size
        self.text_proj = nn.Linear(text_dim, image_size * image_size)

        self.encoder = nn.Sequential(
            nn.Conv2d(4, 64, kernel_size=3, padding=1),  # 3 image channels + 1 text channel
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, padding=1)
        )

    def forward(self, x, text_embed):
        """Predict a 3-channel output for `x` given `text_embed`.

        Args:
            x: tensor of shape (batch, 3, image_size, image_size).
            text_embed: tensor of shape (batch, text_dim).

        Returns:
            Tensor with the same shape as `x`.
        """
        b = x.size(0)
        # project text -> one spatial map per sample
        tmap = self.text_proj(text_embed).view(b, 1, self.image_size, self.image_size)
        x = torch.cat([x, tmap], dim=1)  # concat along channel dimension
        x = self.encoder(x)
        x = self.decoder(x)
        return x


# **Diffusion Scheduler + Training**

In [6]:
from tqdm import tqdm

# Diffusion Schedule
def get_noise_schedule(T=300, device='cpu'):
    """Build a linear beta schedule with `T` steps on `device`.

    Returns:
        Tuple `(beta, alpha, alpha_bar)` of 1-D tensors of length `T`, where
        beta runs linearly from 1e-4 to 0.02, alpha = 1 - beta and alpha_bar
        is the cumulative product of alpha.
    """
    betas = torch.linspace(1e-4, 0.02, T, device=device)
    alphas = 1 - betas
    return betas, alphas, torch.cumprod(alphas, dim=0)

T = 300  # number of diffusion steps
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pass T explicitly so the schedule length always matches the range that
# timesteps are sampled from (previously it silently relied on the default).
beta, alpha, alpha_bar = get_noise_schedule(T, device=device)

# Forward Diffusion Step
def forward_diffusion(x0, t, noise=None):
    """Noise `x0` to timestep `t` using the module-level `alpha_bar` schedule.

    Computes sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * noise.

    Args:
        x0: clean input tensor with the batch on dimension 0.
        t: index (or tensor of per-sample indices) into `alpha_bar`.
        noise: optional noise tensor shaped like `x0`; sampled from a standard
            normal when omitted.

    Returns:
        Tuple `(x_t, noise)` with the noised tensor and the noise that was used.
    """
    if noise is None:
        noise = torch.randn_like(x0)

    # Shape the per-sample coefficients as (B, 1, ..., 1) so they broadcast over
    # inputs of any rank, not just 4-D image batches; also accepts a scalar `t`.
    bshape = (-1,) + (1,) * (x0.dim() - 1)
    sqrt_alpha_bar = torch.sqrt(alpha_bar[t]).float().reshape(bshape)
    sqrt_one_minus = torch.sqrt(1 - alpha_bar[t]).float().reshape(bshape)

    return sqrt_alpha_bar * x0 + sqrt_one_minus * noise, noise

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Networks: the noise predictor and the caption encoder.
model = CondUNet().to(device)
text_encoder = TextEncoder(len(tokenizer.word2idx)).to(device)

# A single Adam optimizer updates both networks jointly.
optimizer = torch.optim.Adam(
    [*model.parameters(), *text_encoder.parameters()],
    lr=1e-4,
)

# The predicted noise is regressed onto the true noise with mean-squared error.
criterion = nn.MSELoss()

# Training: 100 passes over the dataloader, one optimizer step per batch.
for epoch in range(100):
    pbar = tqdm(dataloader)
    for images, captions in pbar:
        images, captions = images.to(device), captions.to(device)

        # One uniformly sampled timestep per image in the batch.
        t = torch.randint(0, T, (images.size(0),), device=device)

        # Sample Gaussian noise and mix it into the clean images at step t.
        noise = torch.randn_like(images)
        x_noisy, noise = forward_diffusion(images, t, noise)

        # Predict the noise from the noisy images and the caption embedding.
        text_embed = text_encoder(captions)
        pred = model(x_noisy, text_embed)
        loss = criterion(pred, noise)

        # Clear stale gradients, backpropagate, update both networks.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The bar shows the loss of the most recent batch.
        pbar.set_description(f"Epoch {epoch+1} | Loss: {loss.item():.4f}")



Epoch 1 | Loss: 0.1570: 100%|██████████| 256/256 [00:34<00:00,  7.46it/s]
Epoch 2 | Loss: 0.1415: 100%|██████████| 256/256 [00:33<00:00,  7.61it/s]
Epoch 3 | Loss: 0.1414: 100%|██████████| 256/256 [00:33<00:00,  7.60it/s]
Epoch 4 | Loss: 0.0665: 100%|██████████| 256/256 [00:33<00:00,  7.57it/s]
Epoch 5 | Loss: 0.1122: 100%|██████████| 256/256 [00:33<00:00,  7.54it/s]
Epoch 6 | Loss: 0.1272: 100%|██████████| 256/256 [00:34<00:00,  7.50it/s]
Epoch 7 | Loss: 0.1484: 100%|██████████| 256/256 [00:33<00:00,  7.54it/s]
Epoch 8 | Loss: 0.1355: 100%|██████████| 256/256 [00:33<00:00,  7.58it/s]
Epoch 9 | Loss: 0.1144: 100%|██████████| 256/256 [00:33<00:00,  7.59it/s]
Epoch 10 | Loss: 0.0986: 100%|██████████| 256/256 [00:33<00:00,  7.57it/s]
Epoch 11 | Loss: 0.0988: 100%|██████████| 256/256 [00:33<00:00,  7.62it/s]
Epoch 12 | Loss: 0.0858: 100%|██████████| 256/256 [00:33<00:00,  7.59it/s]
Epoch 13 | Loss: 0.0773: 100%|██████████| 256/256 [00:34<00:00,  7.45it/s]
Epoch 14 | Loss: 0.1035: 100%|████

# **Gradio App for Prompt-Based Generation**

In [10]:
%%writefile app.py
import gradio as gr
import torchvision.transforms.functional as TF
from datetime import datetime
import os

# Folder to store temporary outputs
os.makedirs("outputs", exist_ok=True)

@torch.no_grad()
def generate_image_and_download(prompt):
    """Sample one 64x64 image for `prompt` and save it as a PNG.

    Starts from Gaussian noise and applies the reverse diffusion update for
    t = T-1 ... 0, using the module-level `model`, `text_encoder`, `tokenizer`
    and noise schedule (`alpha`, `beta`, `alpha_bar`).

    Args:
        prompt: text description of the flower.

    Returns:
        Tuple `(PIL image, path of the saved PNG)` for display and download.
    """
    model.eval()
    text_encoder.eval()

    # Tokenize and embed
    tokens = tokenizer.tokenize(prompt, max_len=20).unsqueeze(0).to(device)
    text_embed = text_encoder(tokens)

    # Start with noise
    x = torch.randn((1, 3, 64, 64), device=device)

    for t in reversed(range(T)):
        # The model takes no timestep input; `t` only selects the schedule coefficients.
        pred_noise = model(x, text_embed)
        x = (1 / alpha[t].sqrt()) * (x - (beta[t] / (1 - alpha_bar[t]).sqrt()) * pred_noise)
        if t > 0:
            # No fresh noise is added on the final step.
            x += beta[t].sqrt() * torch.randn_like(x)

    # Convert to image: clamp to [-1, 1], then map to [0, 1]
    image = x.squeeze(0).cpu().clamp(-1, 1) * 0.5 + 0.5
    image = TF.to_pil_image(image)

    # Save to file with timestamp; microseconds keep two requests that arrive
    # within the same second from overwriting each other's file.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    file_path = f"outputs/generated_{timestamp}.png"
    image.save(file_path)

    return image, file_path  # Display + Download

# Build the UI: one text box in, the generated image plus a downloadable file out.
demo = gr.Interface(
    fn=generate_image_and_download,
    inputs=gr.Textbox(
        lines=1,
        label="🌸 Describe Your Flower",
        placeholder="e.g. a red flower with round petals",
    ),
    outputs=[
        gr.Image(type="pil", label="🖼️ Generated Image"),
        gr.File(label="📥 Download Image"),
    ],
    title="🧠 Oxford Flowers Text-to-Image Diffusion",
    description="Describe a flower and generate a realistic image using a custom diffusion model. You can also download the result.",
    examples=[
        ["a red flower with round petals"],
        ["a yellow flower in sunlight"],
        ["a purple flower with star-shaped petals"],
    ],
    theme="default",
)

# Start the web server.
demo.launch()


Writing app.py


In [14]:
%%writefile requirements.txt
torch
torchvision
gradio
# The tokenizer used by app.py calls nltk's word_tokenize.
nltk


Writing requirements.txt
