In [1]:
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler

from tqdm import tqdm
import numpy as np
import pandas as pd
import os
import random
import math
import matplotlib.pyplot as plt
from PIL import Image
import os

In [2]:
# Request the "expandable_segments" option of PyTorch's CUDA caching allocator.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Device string used by every later cell: the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [5]:
# One checkpoint id shared by the model, processor and tokenizer so the three cannot drift apart.
CLIP_CHECKPOINT = 'openai/clip-vit-base-patch32'

clip_model = CLIPModel.from_pretrained(CLIP_CHECKPOINT).to(device)
# CLIP is only used as a fixed text encoder in this notebook (no optimizer is ever given its
# parameters), so freeze it: calls made outside torch.no_grad() then do not build autograd
# graphs or accumulate gradients on CLIP's weights.
clip_model.requires_grad_(False)
clip_model.eval()

clip_processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)
clip_tokenizer = CLIPTokenizer.from_pretrained(CLIP_CHECKPOINT)

In [6]:
images_path = '/kaggle/input/flickr-image-dataset/flickr30k_images/flickr30k_images/'
csv_path = '/kaggle/input/flickr-image-dataset/flickr30k_images/results.csv'

# results.csv is '|'-separated and its caption column header carries a leading space.
# rename() returns a new frame, so re-running this cell always starts from the file on disk.
captions = pd.read_csv(csv_path, sep='|').rename(columns={' comment': 'comment'})
# Drop rows that lack a caption or a file name, if any: a NaN caption would otherwise reach
# the CLIP tokenizer as a float and fail there, far from the cause.
captions = captions.dropna(subset=['comment', 'image_name'])

# (caption, absolute image path) pairs. Iterating the two columns directly avoids two
# DataFrame.iloc lookups per row; captions are stripped of surrounding whitespace.
dataset = [
    (str(comment).strip(), os.path.join(images_path, image_name))
    for comment, image_name in zip(captions['comment'], captions['image_name'])
]

# Sequential 60 / 20 / 20 split in file order; the test split takes whatever remains.
dataset_sz = len(dataset)
train_sz = int(dataset_sz * 0.6)
val_sz = int(dataset_sz * 0.2)

train_set = dataset[:train_sz]
val_set = dataset[train_sz:train_sz + val_sz]
test_set = dataset[train_sz + val_sz:]

In [7]:
class LDMDataset(Dataset):
    """Map-style dataset over ``(caption, image_path)`` pairs.

    Args:
        img_texts: Indexable collection of ``(caption, image_path)`` tuples.
        transform: Optional callable applied to the RGB ``PIL.Image`` before it is returned.
    """

    def __init__(self, img_texts, transform=None):
        super().__init__()
        self.img_texts = img_texts
        self.transform = transform

    def __len__(self):
        """Return the number of ``(caption, image_path)`` pairs."""
        return len(self.img_texts)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(image, caption)`` where ``image`` is the RGB image at the stored path, passed
            through ``transform`` when one was given.
        """
        caption, img_path = self.img_texts[index]
        # Open inside a context manager so the file handle is released as soon as the pixels
        # have been converted, instead of lingering until garbage collection.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, caption

In [8]:
# Preprocessing shared by every split: resize to 512x512, convert to a tensor, then
# normalise each channel with mean 0.5 and std 0.5.
_preprocess_steps = [
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_preprocess_steps)

# One dataset per split, all using the same transform.
train_dataset, val_dataset, test_dataset = (
    LDMDataset(split, transform) for split in (train_set, val_set, test_set)
)

# Loader settings common to all splits; only the training loader shuffles.
_loader_options = dict(batch_size=32, num_workers=3, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

## Model Architecture

### VAE

In [9]:
class VAEResBlock(nn.Module):
    """Residual block made of two GroupNorm -> SiLU -> 3x3 conv stages plus a skip path.

    Args:
        cin: Number of input channels.
        cout: Number of output channels.

    Both normalisation layers use 32 groups. The skip path is a 1x1 convolution when
    ``cin`` differs from ``cout`` and an identity otherwise.
    """

    def __init__(self, cin, cout):
        super().__init__()
        # First stage maps cin -> cout channels.
        self.norm1 = nn.GroupNorm(32, cin)
        self.act1 = nn.SiLU()
        self.conv1 = nn.Conv2d(cin, cout, kernel_size=3, padding=1)

        # Second stage keeps cout channels.
        self.norm2 = nn.GroupNorm(32, cout)
        self.act2 = nn.SiLU()
        self.conv2 = nn.Conv2d(cout, cout, kernel_size=3, padding=1)

        # Skip path only needs a projection when the channel count changes.
        self.shortcut = nn.Identity() if cin == cout else nn.Conv2d(cin, cout, kernel_size=1)

    def forward(self, x):
        """Return the two-stage branch output added to the (possibly projected) input."""
        hidden = self.norm1(x)
        hidden = self.act1(hidden)
        hidden = self.conv1(hidden)

        hidden = self.norm2(hidden)
        hidden = self.act2(hidden)
        hidden = self.conv2(hidden)

        return hidden + self.shortcut(x)


class VAEEncoderBlock(nn.Module):
    """Encoder stage: 3x3 conv to ``cout`` channels, a residual block, then a stride-2 3x3 conv.

    Args:
        cin: Channels of the incoming feature map.
        cout: Channels produced by every layer of this stage.
    """

    def __init__(self, cin, cout):
        super().__init__()
        self.conv = nn.Conv2d(cin, cout, kernel_size=3, padding=1)
        self.res = VAEResBlock(cout, cout)
        # The strided convolution is what reduces the spatial size.
        self.down = nn.Conv2d(cout, cout, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Run conv -> residual block -> strided conv and return the result."""
        widened = self.conv(x)
        refined = self.res(widened)
        return self.down(refined)


class VAEEncoder(nn.Module):
    """Convolutional encoder that maps an image to a sampled 4-channel latent.

    Three ``VAEEncoderBlock`` stages (3 -> 128 -> 256 -> 512 channels, each ending in a
    stride-2 convolution) are followed by GroupNorm, SiLU and a 3x3 convolution to
    8 channels, which are split into a mean and a log-variance of 4 channels each.
    """

    def __init__(self):
        super().__init__()
        self.block1 = VAEEncoderBlock(3, 128)
        self.block2 = VAEEncoderBlock(128, 256)
        self.block3 = VAEEncoderBlock(256, 512)

        self.norm = nn.GroupNorm(32, 512)
        self.act = nn.SiLU()
        self.conv = nn.Conv2d(512, 8, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Encode ``x`` and draw a latent sample.

        The shapes in the comments below are for a ``(B, 3, 512, 512)`` input.

        Returns:
            ``(z, mu, logvar)`` where ``z = mu + exp(0.5 * logvar) * eps`` with ``eps``
            drawn from a standard normal on every call.
        """
        feats = x
        # (B, 128, 256, 256) -> (B, 256, 128, 128) -> (B, 512, 64, 64)
        for stage in (self.block1, self.block2, self.block3):
            feats = stage(feats)

        moments = self.conv(self.act(self.norm(feats)))  # (B, 8, 64, 64)

        # First half of the channels is the mean, second half the log-variance.
        mu, logvar = moments.chunk(2, dim=1)
        std = (0.5 * logvar).exp()
        z = mu + std * torch.randn_like(std)
        return z, mu, logvar


class VAEDecoderBlock(nn.Module):
    """Decoder stage: 3x3 conv to ``cout`` channels, a residual block, then a transposed conv.

    Args:
        cin: Channels of the incoming feature map.
        cout: Channels produced by every layer of this stage.
    """

    def __init__(self, cin, cout):
        super().__init__()
        self.conv = nn.Conv2d(cin, cout, kernel_size=3, padding=1)
        self.res = VAEResBlock(cout, cout)
        # Transposed convolution with kernel 4, stride 2, padding 1 performs the upsampling.
        self.up = nn.ConvTranspose2d(cout, cout, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """Run conv -> residual block -> transposed conv and return the result."""
        widened = self.conv(x)
        refined = self.res(widened)
        return self.up(refined)


class VAEDecoder(nn.Module):
    """Convolutional decoder that maps a 4-channel latent to a 3-channel image in [-1, 1].

    A 3x3 convolution lifts the latent to 8 channels, three ``VAEDecoderBlock`` stages
    (8 -> 512 -> 256 -> 128 channels) upsample it, and a final 3x3 convolution followed by
    ``tanh`` produces the image.
    """

    def __init__(self):
        super().__init__()
        # Shape comments assume a (B, 4, 64, 64) latent.
        self.conv = nn.Conv2d(4, 8, kernel_size=3, padding=1)  # (B, 8, 64, 64)

        self.block1 = VAEDecoderBlock(8, 512)    # (B, 512, 128, 128)
        self.block2 = VAEDecoderBlock(512, 256)  # (B, 256, 256, 256)
        self.block3 = VAEDecoderBlock(256, 128)  # (B, 128, 512, 512)

        self.conv2 = nn.Conv2d(128, 3, kernel_size=3, padding=1)
        self.act = nn.Tanh()

    def forward(self, x):
        """Decode latent ``x`` into an image tensor bounded by ``tanh``."""
        feats = self.conv(x)
        for stage in (self.block1, self.block2, self.block3):
            feats = stage(feats)
        rgb = self.conv2(feats)
        return self.act(rgb)


class VAE(nn.Module):
    """Variational autoencoder pairing ``VAEEncoder`` with ``VAEDecoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = VAEEncoder()
        self.decoder = VAEDecoder()

    def forward(self, x):
        """Encode ``x``, decode the sampled latent, and return ``(recon, mu, logvar)``.

        The sampled latent itself is not returned; only its reconstruction and the
        posterior parameters are.
        """
        latent, mu, logvar = self.encoder(x)
        return self.decoder(latent), mu, logvar

### DDPM

In [10]:
class SinusoidalTimeEmbeddings(nn.Module):
    """Parameter-free sinusoidal embedding of diffusion timesteps.

    Args:
        cout: Width of the embedding. Even columns hold sines and odd columns hold cosines
            of the same angles. When ``cout`` is odd, the final column is left at zero.
    """

    def __init__(self, cout=1280):
        super().__init__()
        self.cout = cout

    def forward(self, timesteps):
        """Embed a 1-D tensor of timesteps.

        Args:
            timesteps: Tensor of shape ``(B,)``; it is cast to float before use.

        Returns:
            Float32 tensor of shape ``(B, cout)`` on the same device as ``timesteps``.
        """
        half_dim = self.cout // 2

        timesteps = timesteps.float()
        i = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        # Frequencies fall geometrically from 1 towards 1/10000 across the half_dim pairs.
        freqs = torch.exp(-math.log(10000) * i / half_dim)

        angles = timesteps[:, None] * freqs[None, :]  # (B, half_dim)
        emb = torch.zeros(timesteps.size(0), self.cout, device=timesteps.device)
        # Bound both slices by 2 * half_dim so each selects exactly half_dim columns;
        # an unbounded 0::2 selects half_dim + 1 columns when cout is odd and would fail.
        emb[:, 0:2 * half_dim:2] = torch.sin(angles)
        emb[:, 1:2 * half_dim:2] = torch.cos(angles)

        return emb


class CrossAttention(nn.Module):
    """Multi-head cross-attention from image features (queries) to text tokens (keys/values).

    Args:
        cin_img: Channels of the image feature map.
        dim: Total attention width and number of output channels; split evenly over heads.
        cin_txt: Feature size of each text token embedding.
        n_heads: Number of attention heads.

    Raises:
        ValueError: If ``dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, cin_img, dim, cin_txt=512, n_heads=8):
        super().__init__()
        if dim % n_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by n_heads ({n_heads})")

        self.cin_img = cin_img
        self.cin_txt = cin_txt

        self.w_k = nn.Linear(cin_txt, dim)
        self.w_v = nn.Linear(cin_txt, dim)
        self.w_q = nn.Linear(cin_img, dim)

        self.dim = dim
        self.n_heads = n_heads
        self.dim_head = dim // n_heads

        # Residual path: project the input only when its channel count differs from dim.
        if cin_img == dim:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(cin_img, dim, kernel_size=1)
        self.lin_proj = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x, y, y_mask=None):
        """Attend from every spatial position of ``x`` to the text tokens ``y``.

        Args:
            x: Image features, ``(B, C, H, W)``.
            y: Text token embeddings, ``(B, n_tokens, cin_txt)``.
            y_mask: Optional ``(B, n_tokens)`` mask; entries equal to 0 mark padding tokens
                that must receive no attention. ``None`` attends to every token.

        Returns:
            Tensor of shape ``(B, dim, H, W)``: the residual path plus the projected
            attention output.
        """
        batch_sz, cin, h, w = x.shape  # (B, C, H, W)
        l = h * w
        img_out = x.reshape(batch_sz, cin, l).transpose(1, 2)  # (B, H*W, C)

        q = self.w_q(img_out)  # (B, H*W, dim)
        k = self.w_k(y)        # (B, n_tokens, dim)
        v = self.w_v(y)        # (B, n_tokens, dim)

        q = q.reshape(batch_sz, l, self.n_heads, self.dim_head).transpose(1, 2)          # (B, n_h, H*W, d_h)
        k = k.reshape(batch_sz, k.size(1), self.n_heads, self.dim_head).transpose(1, 2)  # (B, n_h, n_tokens, d_h)
        v = v.reshape(batch_sz, v.size(1), self.n_heads, self.dim_head).transpose(1, 2)  # (B, n_h, n_tokens, d_h)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.dim_head)  # (B, n_h, H*W, n_tokens)

        if y_mask is not None:
            # Mask the logits BEFORE the softmax so padded tokens end up with zero weight.
            # The fill is out-of-place (softmax's backward needs its own output intact) and
            # uses the dtype's own minimum, which stays representable under fp16 autocast.
            padding = (y_mask == 0)[:, None, None, :]  # (B, 1, 1, n_tokens)
            scores = scores.masked_fill(padding, torch.finfo(scores.dtype).min)

        att = torch.softmax(scores, dim=-1)
        out = att @ v  # (B, n_h, H*W, d_h)
        out = out.transpose(1, 2).reshape(batch_sz, h, w, self.dim)  # (B, H, W, dim)
        out = out.permute(0, 3, 1, 2)  # (B, dim, H, W)

        return self.shortcut(x) + self.lin_proj(out)  # (B, dim, H, W)


class ResLayer(nn.Module):
    """Optional GroupNorm, optional SiLU, then a 2-D convolution.

    Args:
        cin: Input channels.
        cout: Output channels.
        kernel_size, stride, padding: Forwarded to ``nn.Conv2d``.
        has_norm: When false the normalisation is replaced by an identity.
        has_act: When false the activation is replaced by an identity.

    The normalisation uses ``min(32, cin)`` groups.
    """

    def __init__(self, cin, cout, kernel_size=3, stride=1, padding=1,
                 has_norm=True, has_act=True):
        super().__init__()
        if has_norm:
            self.norm = nn.GroupNorm(min(32, cin), cin)
        else:
            self.norm = nn.Identity()

        if has_act:
            self.act = nn.SiLU()
        else:
            self.act = nn.Identity()

        self.conv = nn.Conv2d(cin, cout, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        """Apply norm -> activation -> convolution."""
        normed = self.norm(x)
        activated = self.act(normed)
        return self.conv(activated)


class ResBlock(nn.Module):
    """Residual block that injects a time embedding between its two ``ResLayer`` stages.

    Args:
        cin: Input channels.
        cmid: Channels after the first stage; the time embedding is projected to this width.
        cout: Output channels.
        time_emb_dim: Width of the incoming time embedding.
        downsample: When true, the second stage and the skip path both use stride 2.
    """

    def __init__(self, cin, cmid, cout, time_emb_dim=1280, downsample=False):
        super().__init__()
        self.res_layer1 = ResLayer(cin, cmid)
        self.mlp = nn.Linear(time_emb_dim, cmid)

        stride = 1
        if downsample:
            stride = 2

        self.res_layer2 = ResLayer(cmid, cout, stride=stride)

        # The skip path can stay an identity only when neither channels nor resolution change.
        needs_projection = downsample or cin != cout
        if needs_projection:
            self.shortcut = nn.Conv2d(cin, cout, kernel_size=1, stride=stride)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x, timesteps):
        """Apply the block.

        Args:
            x: Feature map, ``(B, cin, H, W)``.
            timesteps: Time embedding, ``(B, time_emb_dim)``.
        """
        hidden = self.res_layer1(x)  # (B, cmid, H, W)
        # Project the embedding and broadcast it over the spatial dimensions.
        time_bias = self.mlp(timesteps)[:, :, None, None]  # (B, cmid, 1, 1)
        hidden = self.res_layer2(hidden + time_bias)
        return self.shortcut(x) + hidden


class DiffEncoderBlock(nn.Module):
    """U-Net encoder stage: two ``ResBlock``s, optional cross-attention, then a convolution.

    Args:
        cin: Channels of the input and of the skip tensor this stage emits.
        cout: Channels after the final convolution.
        has_attn: Whether to apply ``CrossAttention`` after the residual blocks.
        kernel_size, stride, padding: Settings of the final convolution (stride 2 by default).
    """

    def __init__(self, cin, cout, has_attn=False,
                 kernel_size=3, stride=2, padding=1):
        super().__init__()
        self.res1 = ResBlock(cin, cin, cin)
        self.res2 = ResBlock(cin, cin, cin)
        if has_attn:
            self.attn = CrossAttention(cin_img=cin, dim=cin)
        else:
            self.attn = self._skip_attention
        self.conv = nn.Conv2d(cin, cout, kernel_size=kernel_size,
                              stride=stride, padding=padding)

    @staticmethod
    def _skip_attention(x, y, y_mask):
        """Stand-in for stages without attention: returns ``x`` and ignores the text inputs."""
        return x

    def forward(self, x, y, y_mask, time_emb):
        """Process one stage.

        Args:
            x: Image features.
            y: Text token embeddings.
            y_mask: Text padding mask.
            time_emb: Time embedding passed to both residual blocks.

        Returns:
            ``(skip, out)``: the features before the final convolution and after it.
        """
        hidden = self.res1(x, time_emb)
        hidden = self.res2(hidden, time_emb)
        skip = self.attn(hidden, y, y_mask)
        return skip, self.conv(skip)


class DiffEncoder(nn.Module):
    """Downsampling half of the denoising U-Net.

    Stages are numbered by the skip they produce: ``block4`` runs first and ``block1`` last.
    Shape comments give each stage's output for a ``(B, 4, 64, 64)`` latent.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(4, 320, kernel_size=3, padding=1)  # (B, 320, 64, 64)
        # 1x1, stride-1 convolution: this stage keeps the resolution.
        self.block4 = DiffEncoderBlock(320, 320, kernel_size=1, stride=1, padding=0)  # (B, 320, 64, 64)
        self.block3 = DiffEncoderBlock(320, 640)  # (B, 640, 32, 32)
        self.block2 = DiffEncoderBlock(640, 960, has_attn=True)  # (B, 960, 16, 16)
        self.block1 = DiffEncoderBlock(960, 1280, has_attn=True)  # (B, 1280, 8, 8)

    def forward(self, x, y, y_mask, time_emb):
        """Run all stages.

        Returns:
            ``(pre_pool1, pre_pool2, pre_pool3, pre_pool4, out)``: the four skip tensors,
            deepest first, followed by the output of the last stage.
        """
        feats = self.conv(x)
        skips = []
        for stage in (self.block4, self.block3, self.block2, self.block1):
            skip, feats = stage(feats, y, y_mask, time_emb)
            skips.append(skip)
        pre_pool4, pre_pool3, pre_pool2, pre_pool1 = skips
        return pre_pool1, pre_pool2, pre_pool3, pre_pool4, feats


class DiffBottleneck(nn.Module):
    """Middle of the U-Net: ``ResBlock`` -> ``CrossAttention`` -> ``ResBlock`` at constant width.

    Args:
        cin: Channel count kept throughout the bottleneck.
    """

    def __init__(self, cin=1280):
        super().__init__()
        self.res1 = ResBlock(cin, cin, cin)
        self.res2 = ResBlock(cin, cin, cin)
        self.attn = CrossAttention(cin_img=cin, dim=cin)

    def forward(self, x, y, y_mask, time_emb):
        """Apply the first residual block, attend to the text, then apply the second block."""
        hidden = self.res1(x, time_emb)
        hidden = self.attn(hidden, y, y_mask)
        return self.res2(hidden, time_emb)


class DiffDecoderBlock(nn.Module):
    """U-Net decoder stage: transposed conv, skip concatenation, two ``ResBlock``s and optional attention.

    Args:
        cin: Channels of the incoming feature map (kept by the transposed convolution).
        cout: Channels of the skip tensor and of the stage output.
        has_attn: Whether to apply ``CrossAttention`` between the two residual blocks.
        kernel_size, stride, padding: Settings of the transposed convolution.
    """

    def __init__(self, cin, cout, has_attn=False,
                 kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.up = nn.ConvTranspose2d(cin, cin, kernel_size=kernel_size,
                                     stride=stride, padding=padding)
        # The first block consumes the upsampled features concatenated with the skip.
        merged = cin + cout
        self.res1 = ResBlock(merged, merged, cout)
        self.res2 = ResBlock(cout, cout, cout)
        if has_attn:
            self.attn = CrossAttention(cin_img=cout, dim=cout)
        else:
            self.attn = self._skip_attention

    @staticmethod
    def _skip_attention(x, y, y_mask):
        """Stand-in for stages without attention: returns ``x`` and ignores the text inputs."""
        return x

    def forward(self, x, y, y_mask, time_emb, skip):
        """Upsample ``x``, concatenate ``skip`` on the channel axis, and refine the result."""
        upsampled = self.up(x)
        hidden = torch.cat([upsampled, skip], dim=1)
        hidden = self.res1(hidden, time_emb)
        hidden = self.attn(hidden, y, y_mask)
        return self.res2(hidden, time_emb)


class DiffDecoder(nn.Module):
    """Upsampling half of the denoising U-Net, ending in a 4-channel prediction.

    ``block1`` to ``block3`` use a kernel-4, stride-2 transposed convolution; ``block4``
    keeps the default kernel-3, stride-1 one. The head is GroupNorm -> SiLU -> 3x3 conv.
    """

    def __init__(self):
        super().__init__()
        self.block1 = DiffDecoderBlock(1280, 960, has_attn=True, kernel_size=4, stride=2)
        self.block2 = DiffDecoderBlock(960, 640, has_attn=True, kernel_size=4, stride=2)
        self.block3 = DiffDecoderBlock(640, 320, kernel_size=4, stride=2)
        self.block4 = DiffDecoderBlock(320, 320)

        self.norm = nn.GroupNorm(32, 320)
        self.act = nn.SiLU()
        self.conv = nn.Conv2d(320, 4, kernel_size=3, padding=1)

    def forward(self, x, y, y_mask, time_emb, pre_pool1, pre_pool2, pre_pool3, pre_pool4):
        """Decode ``x`` using the skips ``pre_pool1`` .. ``pre_pool4`` in that order."""
        stages_and_skips = (
            (self.block1, pre_pool1),
            (self.block2, pre_pool2),
            (self.block3, pre_pool3),
            (self.block4, pre_pool4),
        )
        feats = x
        for stage, skip in stages_and_skips:
            feats = stage(feats, y, y_mask, time_emb, skip)

        feats = self.act(self.norm(feats))
        return self.conv(feats)


class LDMDiffusion(nn.Module):
    """Text-conditioned denoising U-Net: time embedding, encoder, bottleneck and decoder."""

    def __init__(self):
        super().__init__()
        self.time_emb = SinusoidalTimeEmbeddings()
        self.encoder = DiffEncoder()
        self.bottleneck = DiffBottleneck()
        self.decoder = DiffDecoder()

    def forward(self, x, y, y_mask, time):
        """Run the U-Net.

        Args:
            x: Noisy latent.
            y: Text token embeddings.
            y_mask: Text padding mask.
            time: Diffusion timesteps, one per batch element.

        Returns:
            The decoder output (4 channels).
        """
        time_emb = self.time_emb(time)
        *skips, deepest = self.encoder(x, y, y_mask, time_emb)
        middle = self.bottleneck(deepest, y, y_mask, time_emb)
        # skips is [pre_pool1, pre_pool2, pre_pool3, pre_pool4], the order the decoder expects.
        return self.decoder(middle, y, y_mask, time_emb, *skips)

## Training

### VAE training

In [11]:
def vae_loss_function(recon, x, mu, logvar, kl_weight=1e-6):
    """Reconstruction MSE plus a weighted KL divergence to a standard normal prior.

    The reconstruction term is averaged over every element, whereas the KL term is summed
    over all latent elements and divided only by the batch size. Added unweighted, the KL
    term is therefore orders of magnitude larger than the MSE and pushes the posterior onto
    the prior. ``kl_weight`` rescales it; pass ``kl_weight=1.0`` for the unweighted sum.

    Args:
        recon: Reconstructed batch, same shape as ``x``.
        x: Target batch; ``x.size(0)`` is taken as the batch size.
        mu: Posterior means.
        logvar: Posterior log-variances, same shape as ``mu``.
        kl_weight: Multiplier applied to the KL term.

    Returns:
        Scalar tensor ``mse + kl_weight * kl``.
    """
    recon_loss = F.mse_loss(recon, x, reduction='mean')
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    return recon_loss + kl_weight * kl_div

In [12]:
# Smoke test: push one random 256x256 image through an untrained VAE and check the shapes.
vae = VAE().to(device)
dummy = torch.randn(1, 3, 256, 256).to(device)
# No gradients are needed here; without no_grad the outputs would keep the whole autograd
# graph alive for as long as these variables exist.
with torch.no_grad():
    recon, mu, logvar = vae(dummy)

# The decoder mirrors the encoder's three 2x downsamplings, so the reconstruction matches
# the input, and the latent is 4 channels at 1/8 of the input resolution.
assert recon.shape == dummy.shape
assert mu.shape == logvar.shape == (1, 4, 32, 32)

In [None]:
# Where the best VAE weights (lowest validation loss) are written.
vae_model_path = '/kaggle/working/vae.pth'

vae = VAE().to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=2e-4)
scaler = GradScaler()

# Early stopping: give up after `patience` consecutive epochs without a new best.
num_epochs = 20
patience = 4
best_val_loss = float('inf')
num_cons_bad_epochs = 0

for epoch in range(num_epochs):
    torch.cuda.empty_cache()
    print(f"\nEpoch [{epoch+1}/{num_epochs}]")

    # ---- training pass ----
    vae.train()
    train_loss, train_total = 0.0, 0
    train_bar = tqdm(train_loader, desc="Training", leave=False)
    for imgs, _ in train_bar:
        imgs = imgs.to(device)
        n_imgs = imgs.size(0)

        optimizer.zero_grad()
        with autocast(device_type=device):
            recon, mu, logvar = vae(imgs)
            loss = vae_loss_function(recon, imgs, mu, logvar)

        # Mixed-precision step: scale the loss, step through the scaler, then update the scale.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        train_loss += batch_loss * n_imgs
        train_total += n_imgs
        train_bar.set_postfix(loss=batch_loss)

    # ---- validation pass ----
    vae.eval()
    val_loss, val_total = 0.0, 0
    val_bar = tqdm(val_loader, desc="Validation", leave=False)
    with torch.no_grad():
        for imgs, _ in val_bar:
            imgs = imgs.to(device)
            n_imgs = imgs.size(0)
            with autocast(device_type=device):
                recon, mu, logvar = vae(imgs)
                loss = vae_loss_function(recon, imgs, mu, logvar)

            batch_loss = loss.item()
            val_loss += batch_loss * n_imgs
            val_total += n_imgs
            val_bar.set_postfix(loss=batch_loss)

    # Per-sample averages (each batch loss was weighted by its batch size).
    avg_train = train_loss / train_total
    avg_val = val_loss / val_total
    print(f"Train Loss: {avg_train:.4f} | Val Loss: {avg_val:.4f}")

    if avg_val < best_val_loss:
        # New best: checkpoint and reset the early-stopping counter.
        best_val_loss = avg_val
        num_cons_bad_epochs = 0
        torch.save(vae.state_dict(), vae_model_path)
        print("Model saved.")
        continue

    num_cons_bad_epochs += 1
    if num_cons_bad_epochs == patience:
        print("Early stopping.")
        break


In [None]:
# Evaluate the best VAE checkpoint on the held-out test split.
vae = VAE().to(device)
# map_location lets a checkpoint written on a GPU be loaded when this kernel runs on the CPU.
vae.load_state_dict(torch.load(vae_model_path, map_location=device))
vae.eval()

test_loss = 0.0
test_total = 0
test_bar = tqdm(test_loader, desc="Test", leave=False)
with torch.no_grad():
    for imgs, _ in test_bar:
        imgs = imgs.to(device)
        with autocast(device_type=device):
            recon, mu, logvar = vae(imgs)
            loss = vae_loss_function(recon, imgs, mu, logvar)

        # Weight each batch loss by its size so the final figure is a per-sample average.
        test_loss += loss.item() * imgs.size(0)
        test_total += imgs.size(0)
        test_bar.set_postfix(loss=loss.item())

avg_test = test_loss / test_total
print(f"Test Loss: {avg_test:.4f}")


### DDPM training

In [None]:
def beta_linear_schedule(beta_start=1e-4, beta_end=0.02, steps=1000):
    """Return ``steps`` evenly spaced beta values from ``beta_start`` to ``beta_end`` inclusive.

    Example: ``beta_linear_schedule(0, 1, 5)`` gives ``tensor([0, 0.25, 0.5, 0.75, 1])``.
    """
    schedule = torch.linspace(start=beta_start, end=beta_end, steps=steps)
    return schedule

def get_alpha_cumprods(betas):
    """Return the running product of ``1 - betas`` along dimension 0 (the alpha-bar values)."""
    return (1. - betas).cumprod(dim=0)

def noisify_image(x0, t, alpha_bars):
    """Apply the forward diffusion step to a single image.

    Computes ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise`` with fresh standard
    normal noise, the same formula ``noisify_images`` uses for batches.

    Args:
        x0: Clean image tensor (any shape; the coefficients are scalars).
        t: Integer timestep used to index ``alpha_bars``.
        alpha_bars: 1-D tensor of cumulative alpha products.

    Returns:
        Noisy tensor with the same shape as ``x0``.
    """
    alpha_bar = alpha_bars[t]
    noise = torch.randn_like(x0)
    # The noise coefficient is sqrt(1 - alpha_bar), not 1 - sqrt(alpha_bar): the squared
    # coefficients must sum to one so unit-variance data stays unit-variance.
    return alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * noise

def noisify_images(x0, t, alpha_bars):
    """Apply the forward diffusion step to a batch, one timestep per sample.

    Args:
        x0: Clean batch, ``(B, C, H, W)``.
        t: Timesteps indexing ``alpha_bars``, one per batch element.
        alpha_bars: 1-D tensor of cumulative alpha products.

    Returns:
        ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise`` with fresh standard
        normal noise.
    """
    # Reshape to (B, 1, 1, 1) so each sample's coefficients broadcast over C, H, W.
    alpha_bar_t = alpha_bars[t].view(-1, 1, 1, 1)
    signal_scale = alpha_bar_t.sqrt()
    noise_scale = (1 - alpha_bar_t).sqrt()
    noise = torch.randn_like(x0)
    return signal_scale * x0 + noise_scale * noise

In [None]:
# Diffusion schedule: T timesteps with linearly increasing betas.
T = 1000
betas = beta_linear_schedule(0.0001, 0.02, T).to(device)
alpha_bars = get_alpha_cumprods(betas).to(device)

ddpm_model_path = '/kaggle/working/ddpm.pth'
vae_model_path = '/kaggle/working/vae.pth'

# The trained VAE is used only as a fixed image -> latent encoder here.
vae = VAE().to(device)
# map_location lets a checkpoint written on a GPU be loaded when this kernel runs on the CPU.
vae.load_state_dict(torch.load(vae_model_path, map_location=device))
vae.eval()

ddpm = LDMDiffusion().to(device)
optimizer = torch.optim.Adam(ddpm.parameters(), lr=2e-4)
scaler = GradScaler()

# Probability of replacing a training caption with "" so the model also learns the
# unconditional prediction used by classifier-free guidance at sampling time.
CAPTION_DROP_PROB = 0.1


def encode_captions(texts):
    """Tokenize ``texts`` and return ``(token_embeddings, attention_mask)`` from the CLIP text model.

    Runs under ``torch.no_grad()``: CLIP is never optimised in this notebook, so building
    an autograd graph through it would only cost memory.
    """
    tokens = clip_tokenizer(list(texts), padding=True, truncation=True,
                            return_tensors="pt", return_attention_mask=True)
    tokens = {k: v.to(device) for k, v in tokens.items()}
    with torch.no_grad():
        hidden = clip_model.text_model(**tokens).last_hidden_state
    return hidden, tokens["attention_mask"]


def ddpm_batch_loss(imgs, texts):
    """Return the noise-prediction MSE for one batch of images and captions.

    Encodes the images to latents with the VAE encoder, noises them at random timesteps
    drawn uniformly from ``[0, T)``, and compares the U-Net's prediction with the noise
    that was added.
    """
    imgs = imgs.to(device)
    with torch.no_grad():
        z, _, _ = vae.encoder(imgs)

    text_emb, text_att_mask = encode_captions(texts)

    timesteps = torch.randint(low=0, high=T, size=(imgs.size(0),), device=device).long()
    noise = torch.randn_like(z)
    alpha_t = alpha_bars[timesteps].view(-1, 1, 1, 1)
    noisy_z = alpha_t.sqrt() * z + (1 - alpha_t).sqrt() * noise

    with autocast(device_type=device):
        pred_noise = ddpm(noisy_z, text_emb, text_att_mask, timesteps)
        return F.mse_loss(pred_noise, noise)


# Early stopping: give up after `patience` consecutive epochs without a new best.
num_epochs = 20
best_val_loss = float('inf')
patience = 4
num_cons_bad_epochs = 0

for epoch in range(num_epochs):
    torch.cuda.empty_cache()

    train_loss = 0.0
    train_total = 0

    ddpm.train()
    # `batch_captions` (not `captions`) so the loop does not overwrite the captions DataFrame.
    for imgs, batch_captions in train_loader:
        batch_size = imgs.size(0)
        # Caption dropout is a training-time augmentation only.
        batch_captions = ["" if random.random() < CAPTION_DROP_PROB else cap
                          for cap in batch_captions]

        optimizer.zero_grad()
        loss = ddpm_batch_loss(imgs, batch_captions)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item() * batch_size
        train_total += batch_size

    val_loss = 0.0
    val_total = 0

    ddpm.eval()
    with torch.no_grad():
        for imgs, batch_captions in val_loader:
            batch_size = imgs.size(0)
            # Validation uses the full captions, like the test cell, so the number that
            # drives checkpointing is not perturbed by random caption dropout.
            loss = ddpm_batch_loss(imgs, batch_captions)

            val_loss += loss.item() * batch_size
            val_total += batch_size

    avg_train = train_loss / train_total
    avg_val   = val_loss / val_total
    print(f"Train Loss: {avg_train:.4f} | Val Loss: {avg_val:.4f}")

    if avg_val < best_val_loss:
        best_val_loss = avg_val
        num_cons_bad_epochs = 0
        torch.save(ddpm.state_dict(), ddpm_model_path)
        print("Model saved.")
    else:
        num_cons_bad_epochs += 1
        if num_cons_bad_epochs == patience:
            print("Early stopping.")
            break

In [None]:
# Evaluate the best diffusion checkpoint on the test split.
# Uses `T` and `alpha_bars` from the DDPM training cell.
ddpm = LDMDiffusion().to(device)
vae = VAE().to(device)
# map_location lets checkpoints written on a GPU be loaded when this kernel runs on the CPU.
ddpm.load_state_dict(torch.load(ddpm_model_path, map_location=device))
vae.load_state_dict(torch.load(vae_model_path, map_location=device))
ddpm.eval()
vae.eval()

test_loss = 0.0
test_total = 0

with torch.no_grad():
    # `batch_captions` (not `captions`) so the loop does not overwrite the captions DataFrame.
    for imgs, batch_captions in test_loader:
        batch_size = imgs.size(0)
        imgs = imgs.to(device)

        z, _, _ = vae.encoder(imgs)

        caption_inputs = clip_tokenizer(list(batch_captions), padding=True, truncation=True,
                                        return_tensors="pt", return_attention_mask=True)
        caption_inputs = {k: v.to(device) for k, v in caption_inputs.items()}
        # Only the final hidden states are used, so the per-layer outputs are not requested.
        text_emb = clip_model.text_model(**caption_inputs).last_hidden_state
        text_att_mask = caption_inputs["attention_mask"]

        # Noise each latent at a random timestep and score the noise prediction.
        timesteps = torch.randint(low=0, high=T, size=(batch_size,), device=device).long()
        noise = torch.randn_like(z)
        alpha_t = alpha_bars[timesteps].view(-1, 1, 1, 1)
        sqrt_alpha = alpha_t.sqrt()
        sqrt_one_minus = (1 - alpha_t).sqrt()
        noisy_z = sqrt_alpha * z + sqrt_one_minus * noise

        with autocast(device_type=device):
            pred_noise = ddpm(noisy_z, text_emb, text_att_mask, timesteps)
            loss = F.mse_loss(pred_noise, noise)

        # Weight each batch loss by its size so the final figure is a per-sample average.
        test_loss += loss.item() * batch_size
        test_total += batch_size

avg_test = test_loss / test_total
print(f"Test Loss: {avg_test:.4f}")

## Generate

### Classifier-Free Guidance (CFG)

In [None]:
ddpm_model_path = '/kaggle/working/ddpm.pth'
vae_model_path = '/kaggle/working/vae.pth'

vae = VAE().to(device)
ddpm = LDMDiffusion().to(device)
# map_location lets checkpoints written on a GPU be loaded when this kernel runs on the CPU.
vae.load_state_dict(torch.load(vae_model_path, map_location=device))
ddpm.load_state_dict(torch.load(ddpm_model_path, map_location=device))
vae.eval()
ddpm.eval()

# Rebuild the noise schedule here (same settings as training) so this cell does not depend
# on the training cell having been run in the current kernel. `alphas` is needed by the
# reverse step and was never defined anywhere else.
T = 1000
betas = beta_linear_schedule(0.0001, 0.02, T).to(device)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

# Guidance strength: 1.0 would be the purely conditional prediction.
cfg_scale = 3.0

batch_size = 1
caption = ["purple cat on a windowsill"]
latent_shape = (batch_size, 4, 64, 64)

# Everything below is inference only, so no autograd graph is needed for CLIP, the U-Net
# or the VAE decoder.
with torch.no_grad():
    # Start from pure Gaussian noise in latent space.
    curr_state = torch.randn(latent_shape).to(device)

    caption_inputs = clip_tokenizer(caption, padding=True, truncation=True, return_tensors="pt")
    caption_inputs = {k: v.to(device) for k, v in caption_inputs.items()}
    text_emb = clip_model.text_model(**caption_inputs).last_hidden_state
    text_att_mask = caption_inputs["attention_mask"]

    # Empty prompt = the "unconditional" branch, matching the "" captions used in training.
    null_inputs = clip_tokenizer([""], padding=True, truncation=True, return_tensors="pt")
    null_inputs = {k: v.to(device) for k, v in null_inputs.items()}
    null_text_emb = clip_model.text_model(**null_inputs).last_hidden_state
    null_text_att_mask = null_inputs["attention_mask"]

    for t in range(T - 1, -1, -1):
        timestep = torch.full((batch_size,), t, device=device, dtype=torch.long)

        alpha_bar_t = alpha_bars[t].view(1, 1, 1, 1)
        alpha_t = alphas[t].view(1, 1, 1, 1)
        beta_t = betas[t].view(1, 1, 1, 1)

        # Classifier-free guidance: extrapolate from the unconditional prediction towards
        # the conditional one.
        cond_noise = ddpm(curr_state, text_emb, text_att_mask, timestep)
        uncond_noise = ddpm(curr_state, null_text_emb, null_text_att_mask, timestep)
        pred_noise = uncond_noise + cfg_scale * (cond_noise - uncond_noise)

        # Mean of the reverse step given the predicted noise.
        coef = beta_t / torch.sqrt(1 - alpha_bar_t)
        mu_t = (1 / torch.sqrt(alpha_t)) * (curr_state - coef * pred_noise)

        if t > 0:
            eps = torch.randn_like(curr_state)
            std_t = torch.sqrt(beta_t)
            curr_state = mu_t + std_t * eps
        else:
            # The final step is deterministic: no noise is added at t == 0.
            curr_state = mu_t

    recon_img = vae.decoder(curr_state)

# The decoder ends in tanh and the training images were normalised with mean 0.5 / std 0.5,
# so its output lives in [-1, 1]; map it back to [0, 1] before converting to an image.
recon_img = ((recon_img + 1) / 2).clamp(0, 1)
img = recon_img.squeeze(0).detach().cpu()
# `T` is the integer timestep count in this notebook; the transform comes from `transforms`.
img = transforms.ToPILImage()(img)

plt.imshow(img)
plt.axis("off")
plt.title("Generated Image")
plt.show()