In [1]:
# 03_generator_train.ipynb
# Conditional diffusion over wireframe images, conditioned on UI embeddings

import os, math, json, random, numpy as np, pandas as pd
from pathlib import Path
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils as vutils

# device
device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:", device)

# paths
SEM_DIR = Path("data/rico_semantic_annotations")      # wireframe-like images live here (your semantic .jpg/.jpeg)
EMB_PATHS = [
    Path("embeddings/semantic_vit_gnn.pkl"),          # prefer pkl from 02 notebook (fast, robust)
    Path("embeddings/semantic_vit_gnn.parquet"),      # if parquet worked for you
]

# image & train config (keep small to start)
IMG_SIZE = 256
BATCH_SIZE = 16
NUM_WORKERS = 0
EPOCHS = 3           # sanity-run; bump later
LR = 1e-4
SAVE_DIR = Path("runs/generator")
SAVE_DIR.mkdir(parents=True, exist_ok=True)

Running on: mps


In [2]:
# load embeddings (pkl recommended; fall back to parquet)
if EMB_PATHS[0].exists():
    df_emb = pd.read_pickle(EMB_PATHS[0])
elif EMB_PATHS[1].exists():
    df_emb = pd.read_parquet(EMB_PATHS[1], engine="pyarrow")
else:
    raise FileNotFoundError("No embeddings file found.")

# ensure arrays
def to_np(x):
    if isinstance(x, list): return np.array(x, dtype=np.float32)
    if isinstance(x, np.ndarray): return x.astype(np.float32)
    return np.array(x, dtype=np.float32)

df_emb["vision_emb"] = df_emb["vision_emb"].apply(to_np)
df_emb["graph_emb"]  = df_emb["graph_emb"].apply(to_np)

# build {id -> conditioning vector}
id_to_cond = {}
for _, row in df_emb.iterrows():
    emb = np.concatenate([row["vision_emb"], row["graph_emb"]], axis=0)
    id_to_cond[str(row["id"])] = emb

cond_dim = next(iter(id_to_cond.values())).shape[0]
print("Conditioning dim:", cond_dim)

# build list of (id, img_path) that exist on disk
records = []
for sid in id_to_cond.keys():
    p = SEM_DIR / f"{sid}.png"
    if not p.exists():
        p = SEM_DIR / f"{sid}.jpg"
    if p.exists():
        records.append((sid, str(p)))
print(f"Usable pairs: {len(records)}")

Conditioning dim: 448
Usable pairs: 66261


In [3]:
class WireframeDataset(Dataset):
    def __init__(self, records, id_to_cond, img_size=256):
        self.recs = records
        self.id2c = id_to_cond
        self.tf = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),                 # [0,1]
        ])
    def __len__(self):
        return len(self.recs)
    def __getitem__(self, idx):
        sid, img_path = self.recs[idx]
        img = Image.open(img_path).convert("RGB")
        x = self.tf(img)                           # (3,H,W), range [0,1]
        # scale to [-1,1] for diffusion
        x = x * 2. - 1.
        cond = torch.from_numpy(self.id2c[sid]).float()
        return x, cond

# split
random.shuffle(records)
n = len(records)
train_recs = records[: int(0.9*n)]
val_recs   = records[int(0.9*n):]

ds_train = WireframeDataset(train_recs, id_to_cond, IMG_SIZE)
ds_val   = WireframeDataset(val_recs, id_to_cond, IMG_SIZE)
dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, drop_last=True)
dl_val   = DataLoader(ds_val,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, drop_last=False)

len(ds_train), len(ds_val)

(59634, 6627)

In [4]:
class ConvBlock(nn.Module):
    """
    A residual convolutional block with FiLM-style conditioning.

    Each block performs:
        x -> GroupNorm -> Conv -> FiLM (scale + shift)
           -> GroupNorm -> Conv -> FiLM
           -> skip connection (residual)
           -> SiLU activation

    FiLM conditioning:
        For each conditioning vector c (the fused UI embedding),
        we learn a linear projection that outputs gamma/beta (scale/shift)
        parameters to modulate the normalization layers.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        cond_dim (int): dimension of conditioning vector
    """
    def __init__(self, in_ch, out_ch, cond_dim):
        super().__init__()
        # Normalization layers
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)

        # Convolutions
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)

        # Optional 1×1 conv for residual alignment if channels differ
        self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else nn.Identity()

        # FiLM modulation: linear layers generate gamma/beta for each norm
        self.film1 = nn.Linear(cond_dim, in_ch * 2)   # -> (gamma1, beta1)
        self.film2 = nn.Linear(cond_dim, out_ch * 2)  # -> (gamma2, beta2)

    def forward(self, x, c):
        """
        Forward pass with FiLM modulation.

        Args:
            x (Tensor): [B, C_in, H, W] input feature map
            c (Tensor): [B, cond_dim] conditioning vector

        Returns:
            Tensor of shape [B, C_out, H, W]
        """
        # --- First conv + FiLM ---
        g1, b1 = self.film1(c).chunk(2, dim=-1)  # split into gamma, beta
        g1, b1 = g1.unsqueeze(-1).unsqueeze(-1), b1.unsqueeze(-1).unsqueeze(-1)

        h = self.norm1(x) * (1 + g1) + b1        # apply FiLM modulation
        h = F.silu(self.conv1(h))                # SiLU activation (swish)

        # --- Second conv + FiLM ---
        g2, b2 = self.film2(c).chunk(2, dim=-1)
        g2, b2 = g2.unsqueeze(-1).unsqueeze(-1), b2.unsqueeze(-1).unsqueeze(-1)

        h = self.norm2(h) * (1 + g2) + b2
        h = self.conv2(h)

        # --- Residual connection ---
        out = F.silu(h + self.skip(x))
        return out

class TinyUNet(nn.Module):
    """
    A minimal U-Net architecture conditioned on UI embeddings.

    The U-Net takes an input image x_t (a noisy wireframe)
    and predicts the noise to denoise it at each timestep,
    conditioned on the latent UI code c.

    Architecture outline:
        Input (3×H×W)
          ↓
        [Down path]
            Conv(3→base)
            ↓
            ConvBlock(base→2×base)
            ↓
            AvgPool (downsample 2×)
            ↓
            ConvBlock(2×base→4×base)
            ↓
            AvgPool (downsample 2×)
          ↓
        [Bottleneck]
            ConvBlock(4×base→4×base)
          ↓
        [Up path]
            ConvTranspose(4×base→2×base)
            Concatenate skip (from 2×base)
            ConvBlock((2+4)×base→2×base)   # after concat: 6×base in
            ↓
            ConvTranspose(2×base→base)
            Concatenate skip (from base)
            ConvBlock((1+2)×base→base)      # after concat: 3×base in
          ↓
        Output
            Conv(base→3)
            → final RGB prediction

    Args:
        cond_dim (int): conditioning vector dimension (vision+graph embedding)
        base (int): number of base feature channels (default: 32)
    """
    def __init__(self, cond_dim, base=32):
        super().__init__()

        # --- Encoder (down path) ---
        self.inp = nn.Conv2d(3, base, kernel_size=3, padding=1)
        self.down1 = ConvBlock(base, base * 2, cond_dim)  # 32→64
        self.pool1 = nn.AvgPool2d(2)                      # 256→128
        self.down2 = ConvBlock(base * 2, base * 4, cond_dim)  # 64→128
        self.pool2 = nn.AvgPool2d(2)                      # 128→64

        # --- Bottleneck ---
        self.mid = ConvBlock(base * 4, base * 4, cond_dim)  # 128→128

        # --- Decoder (up path) ---
        self.up1 = nn.ConvTranspose2d(base * 4, base * 2, 2, stride=2)
        # FIX: input to upblk1 = upsampled(2×base) + skip(4×base)
        self.upblk1 = ConvBlock(base * 2 + base * 4, base * 2, cond_dim)

        self.up2 = nn.ConvTranspose2d(base * 2, base, 2, stride=2)
        # FIX: input to upblk2 = upsampled(base) + skip(2×base)
        self.upblk2 = ConvBlock(base + base * 2, base, cond_dim)

        # --- Final output ---
        self.outp = nn.Conv2d(base, 3, kernel_size=3, padding=1)

    def forward(self, x, c):
        """
        Forward pass through the conditional U-Net.

        Args:
            x (Tensor): [B, 3, H, W] input image (noisy wireframe)
            c (Tensor): [B, cond_dim] conditioning embedding

        Returns:
            [B, 3, H, W] predicted noise (same shape as input)
        """
        # --- Downsampling path ---
        x1 = F.silu(self.inp(x))       # shape: [B, base, H, W]
        x2 = self.down1(x1, c)         # [B, 2×base, H, W]
        p1 = self.pool1(x2)            # ↓ H/2
        x3 = self.down2(p1, c)         # [B, 4×base, H/2, W/2]
        p2 = self.pool2(x3)            # ↓ H/4

        # --- Bottleneck ---
        m = self.mid(p2, c)            # [B, 4×base, H/4, W/4]

        # --- Upsampling path ---
        u1 = self.up1(m)               # [B, 2×base, H/2, W/2]
        u1 = torch.cat([u1, x3], dim=1)  # concat skip: [B, (2+4)×base, H/2, W/2]
        u1 = self.upblk1(u1, c)

        u2 = self.up2(u1)              # [B, base, H, W]
        u2 = torch.cat([u2, x2], dim=1)  # concat skip: [B, (1+2)×base, H, W]
        u2 = self.upblk2(u2, c)

        # --- Output ---
        out = self.outp(u2)
        return out

In [5]:
unet = TinyUNet(cond_dim=cond_dim, base=32)
x = torch.randn(2, 3, 256, 256)
c = torch.randn(2, cond_dim)
y = unet(x, c)
print("Output:", y.shape)

Output: torch.Size([2, 3, 256, 256])


In [6]:
class DDPM(nn.Module):
    """
    Denoising Diffusion Probabilistic Model (DDPM)
    ------------------------------------------------
    Wraps the conditional U-Net denoiser into a diffusion process
    that learns to reverse a fixed Gaussian noise schedule.

    At a high level:
        q(x_t | x_{t-1}) = N(√α_t * x_{t-1}, β_t * I)
        p_θ(x_{t-1} | x_t, c) = N(μ_θ(x_t, t, c), σ_t^2 * I)

    The model learns to predict the noise ε added at each step t.

    Args:
        model (nn.Module): the denoiser network (TinyUNet)
        timesteps (int): number of diffusion steps T
        beta_start, beta_end (float): noise schedule parameters
    """
    def __init__(self, model, timesteps=300, beta_start=1e-4, beta_end=0.02):
        super().__init__()
        self.model = model
        self.T = timesteps

        # --- Define linear beta schedule (variance schedule) ---
        # βₜ controls how much noise is added at each timestep
        betas = torch.linspace(beta_start, beta_end, timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)  # ᾱₜ = ∏ₛ₌₁ᵗ αₛ

        # --- Register buffers so they are saved on device with model ---
        self.register_buffer("betas", betas)
        self.register_buffer("alphas_cumprod", alphas_cumprod)
        self.register_buffer("alphas_cumprod_prev",
            torch.cat([torch.tensor([1.], device=betas.device), alphas_cumprod[:-1]])
        )

    # ------------------------------------------------------------------
    # Forward diffusion (q) — add noise to clean image
    # ------------------------------------------------------------------
    def q_sample(self, x0, t, noise=None):
        """
        Diffuse the data (add Gaussian noise) for a given timestep t.

        Args:
            x0 (Tensor): clean image, shape [B, 3, H, W]
            t (Tensor): timestep indices, shape [B]
            noise (Tensor): optional external noise, same shape as x0

        Returns:
            x_t (Tensor): noisy image at timestep t
            noise (Tensor): the exact noise used (for supervision)
        """
        if noise is None:
            noise = torch.randn_like(x0)

        # Gather ᾱₜ for each sample in batch
        ac = self.alphas_cumprod[t].view(-1, 1, 1, 1)
        x_t = torch.sqrt(ac) * x0 + torch.sqrt(1 - ac) * noise
        return x_t, noise

    # ------------------------------------------------------------------
    # Training objective — predict the noise ε
    # ------------------------------------------------------------------
    def p_losses(self, x0, cond, t):
        """
        Compute the DDPM training loss for a given batch.

        Steps:
            1. Sample random noise ε ~ N(0, I)
            2. Add it to x0 using q_sample to get x_t
            3. Predict ε̂ = model(x_t, cond)
            4. Loss = MSE(ε̂, ε)

        Args:
            x0 (Tensor): clean input images
            cond (Tensor): conditioning embeddings
            t (Tensor): random timestep per sample
        """
        x_t, noise = self.q_sample(x0, t)
        noise_pred = self.model(x_t, cond)
        return F.mse_loss(noise_pred, noise)

    # ------------------------------------------------------------------
    # Reverse diffusion step — one iteration of sampling
    # ------------------------------------------------------------------
    @torch.no_grad()
    def p_sample(self, x, cond, t):
        """
        Sample x_{t-1} given x_t using the model’s predicted noise.

        Args:
            x (Tensor): current noisy image [B, 3, H, W]
            cond (Tensor): conditioning embeddings
            t (Tensor): current timestep index

        Returns:
            x_prev (Tensor): denoised image for t-1
        """
        betat = self.betas[t].view(-1, 1, 1, 1)
        ac = self.alphas_cumprod[t].view(-1, 1, 1, 1)
        ac_prev = self.alphas_cumprod_prev[t].view(-1, 1, 1, 1)

        # Predict noise ε̂ = model(x_t, cond)
        eps = self.model(x, cond)

        # Compute posterior mean μ_θ(x_t, t, c)
        mean = (1 / torch.sqrt(1 - betat)) * (x - betat / torch.sqrt(1 - ac) * eps)

        if (t == 0).all():
            # last step — no noise added
            return mean

        # Compute variance σₜ
        z = torch.randn_like(x)
        sigma = torch.sqrt((1 - ac_prev) / (1 - ac) * betat)
        return mean + sigma * z

    # ------------------------------------------------------------------
    # Full sampling loop — generate an image from pure noise
    # ------------------------------------------------------------------
    @torch.no_grad()
    def sample(self, cond, shape):
        """
        Generate samples x₀ from pure Gaussian noise.

        Algorithm:
            x_T ~ N(0, I)
            for t = T-1 ... 0:
                x_t = p_sample(x_{t+1}, t+1)
            return x_0

        Args:
            cond (Tensor): [B, cond_dim] conditioning vectors
            shape (tuple): image shape (C, H, W)

        Returns:
            Tensor of shape [B, C, H, W] (denoised samples)
        """
        b = cond.size(0)
        x = torch.randn(b, *shape, device=cond.device)
        for t in reversed(range(self.T)):
            tt = torch.full((b,), t, device=cond.device, dtype=torch.long)
            x = self.p_sample(x, cond, tt)
        return x

In [7]:
# -------------------------------------------------------------
# Training utilities for DDPM model
# -------------------------------------------------------------
# -------------------------------------------------------------
# 1. Utility: save a grid of generated images
# -------------------------------------------------------------
def save_grid(tensors, path, nrow=8, normalize=True, value_range=(-1, 1)):
    """
    Save a batch of tensors as an image grid.
    Args:
        tensors: [B, 3, H, W] in range [-1, 1]
        path: destination Path for PNG
        nrow: number of images per row
    """
    grid = vutils.make_grid(
        tensors, nrow=nrow, normalize=normalize, value_range=value_range
    )
    path.parent.mkdir(exist_ok=True, parents=True)
    vutils.save_image(grid, path)


# -------------------------------------------------------------
# 2. Utility: one training step
# -------------------------------------------------------------
def train_one_batch(ddpm, optimizer, x, c, device):
    """
    Perform one gradient step on a batch.
    Returns:
        scalar loss value
    """
    x = x.to(device)
    c = c.to(device)
    t = torch.randint(0, ddpm.T, (x.size(0),), device=device, dtype=torch.long)
    loss = ddpm.p_losses(x, c, t)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    return loss.item()


# -------------------------------------------------------------
# 3. Utility: sample images for qualitative monitoring
# -------------------------------------------------------------
@torch.no_grad()
def sample_grid(ddpm, cond_batch, img_size, step_name, save_dir, nrow=4):
    """
    Generate a small batch of images for visual progress monitoring.
    """
    ddpm.eval()
    cond_batch = cond_batch.to(next(ddpm.parameters()).device)
    imgs = ddpm.sample(cond_batch, shape=(3, img_size, img_size))
    out_path = Path(save_dir) / f"samples_{step_name}.png"
    save_grid(imgs, out_path, nrow=nrow)
    ddpm.train()
    return out_path


# -------------------------------------------------------------
# 4. Utility: validate on held-out data
# -------------------------------------------------------------
@torch.no_grad()
def validate(ddpm, dl_val, img_size, save_dir, epoch):
    """
    Generate validation samples at the end of an epoch.
    """
    ddpm.eval()
    batch = next(iter(dl_val))
    x_val, c_val = batch[0].to(next(ddpm.parameters()).device), batch[1].to(
        next(ddpm.parameters()).device
    )
    x_gen = ddpm.sample(c_val[:8], shape=(3, img_size, img_size))
    out_path = Path(save_dir) / f"samples_epoch{epoch}.png"
    save_grid(x_gen, out_path, nrow=4)
    print(f"✅ Saved validation samples for epoch {epoch} → {out_path}")
    ddpm.train()


In [8]:
# -------------------------------------------------------------
# 5. Main training function
# -------------------------------------------------------------
def train_ddpm(
    ddpm,
    dataloaders,
    optimizer,
    device,
    save_dir,
    epochs=10,
    img_size=256,
    log_every=400,
):
    """
    Full training loop for DDPM.

    Args:
        ddpm: diffusion model (with U-Net inside)
        dataloaders: (dl_train, dl_val)
        optimizer: optimizer (AdamW)
        device: "mps" / "cuda" / "cpu"
        save_dir: where to store checkpoints & samples
        epochs: number of epochs
        img_size: image resolution
        log_every: steps between sample generations
    """
    dl_train, dl_val = dataloaders
    global_step = 0
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)

    for epoch in range(1, epochs + 1):
        ddpm.train()
        pbar = tqdm(dl_train, desc=f"Epoch {epoch}/{epochs}")
        epoch_loss = []

        for x, c in pbar:
            loss = train_one_batch(ddpm, optimizer, x, c, device)
            epoch_loss.append(loss)
            global_step += 1
            pbar.set_postfix(loss=f"{loss:.5f}")

            # periodic qualitative sampling
            if global_step % log_every == 0:
                out_path = sample_grid(ddpm, c[:8], img_size, f"step{global_step}", save_dir)
                print(f"🖼️ Sample saved: {out_path}")

        mean_loss = np.mean(epoch_loss)
        print(f"\n📉 Epoch {epoch} mean loss: {mean_loss:.6f}")

        # validation image + checkpoint
        validate(ddpm, dl_val, img_size, save_dir, epoch)
        ckpt_path = save_dir / f"ddpm_epoch{epoch}.pt"
        torch.save(ddpm.state_dict(), ckpt_path)
        print(f"💾 Saved model checkpoint → {ckpt_path}")

In [9]:
unet = TinyUNet(cond_dim=cond_dim, base=32).to(device)
ddpm = DDPM(unet, timesteps=300).to(device)
opt = torch.optim.AdamW(ddpm.parameters(), lr=LR)

train_ddpm(
    ddpm=ddpm,
    dataloaders=(dl_train, dl_val),
    optimizer=opt,
    device=device,
    save_dir=SAVE_DIR,
    epochs=EPOCHS,
    img_size=IMG_SIZE,
    log_every=400,
)

Epoch 1/3:   3%|█▌                                                       | 103/3727 [01:13<43:05,  1.40it/s, loss=0.19841]


KeyboardInterrupt: 