# Diffusion Transformer

## Libraries

In [None]:
import os
import gc
import math
import numpy as np
from pathlib import Path
from itertools import cycle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from accelerate import Accelerator
from tqdm.auto import tqdm
from einops import rearrange

import matplotlib.pyplot as plt

: 

## Data preprocessing

In [None]:
class CustomDataset(Dataset):
    """Dataset of sprite images loaded from a single NumPy array file.

    Every item is augmented (random horizontal flip, mild random resized
    crop, colour jitter), converted to a float tensor and normalized with
    mean 0.5 / std 0.5, i.e. mapped from [0, 1] to [-1, 1].

    NOTE(review): the output is already in [-1, 1]. ``GaussianDiffusion``
    built with ``auto_normalize=True`` applies ``img * 2 - 1`` again in its
    ``forward``; confirm that only one of the two normalizations is active
    when this dataset feeds that model.

    Args:
        filename: Path passed to ``np.load``.
        image_size: Output side length used by ``RandomResizedCrop``.
    """

    def __init__(self, filename, image_size):
        self.sprites = np.load(filename)
        print(f"sprite shape: {self.sprites.shape}")
        self.transform = transforms.Compose([
            # The flip / crop / jitter transforms below accept PIL images or
            # tensors, not raw ndarrays, so convert each sprite first.
            # assumes each sprite is an HxWxC uint8 array — TODO confirm
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomResizedCrop(image_size, scale=(0.9, 1.0)),
            transforms.ColorJitter(0.1, 0.1, 0.1, 0.05),
            transforms.ToTensor(),  # from [0, 255] to range [0.0, 1.0]
            transforms.Normalize((0.5,), (0.5,)),  # range [-1, 1]
        ])
        self.image_size = image_size
        self.sprites_shape = self.sprites.shape

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.sprites)

    def __getitem__(self, idx):
        """Return the (augmented) image at ``idx``; there is no label."""
        image = self.sprites[idx]
        # Fall back to the raw sprite instead of hitting an unbound local
        # when the transform has been disabled (set to None) by the caller.
        if self.transform:
            image = self.transform(image)
        return image

## Noise Generation

In [None]:
def linear_beta_schedule(timesteps):
    """
    Linear beta schedule, as proposed in the original DDPM paper.

    The reference endpoints (1e-4 and 2e-2, defined for 1000 steps) are
    rescaled by ``1000 / timesteps``. Returns a float64 tensor holding
    ``timesteps`` evenly spaced values.
    """
    rescale = 1000 / timesteps
    first_beta, last_beta = rescale * 0.0001, rescale * 0.02
    return torch.linspace(first_beta, last_beta, steps=timesteps, dtype=torch.float64)

# Cosine schedule from Nichol & Dhariwal's "Improved DDPM".
def cosine_beta_schedule(timesteps, s=0.008):
    """
    Cosine beta schedule.

    The cumulative alpha product follows a squared-cosine curve and the
    betas are recovered from the ratio of consecutive products.

    Args:
        timesteps: Number of diffusion steps.
        s: Small offset added inside the cosine.

    Returns:
        float64 tensor of ``timesteps`` betas clamped to [0, 0.999].
    """
    steps = timesteps + 1
    # float64 keeps the ratio of neighbouring cumulative products accurate
    # and matches the dtype returned by linear_beta_schedule.
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    # Clamp on both sides: the upper bound avoids a beta of 1 at the last
    # step, the lower bound guards against tiny negative rounding errors.
    return betas.clamp(min=0.0, max=0.999)

def extract(a, t, x_shape):
    """Look up ``a`` at the indices ``t`` and shape the result for broadcasting.

    The gathered values are reshaped to ``(len(t), 1, ..., 1)`` with as many
    trailing singleton axes as ``x_shape`` has non-batch dimensions.
    """
    batch, *_rest = t.shape
    picked = a.gather(-1, t)
    singleton_axes = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch, *singleton_axes)

## Network modules

In [None]:
class PatchEmbedding(nn.Module):
    """Turn an image into a sequence of patch tokens with learned positions.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects each non-overlapping patch to ``emb_dim`` channels; a learned
    positional embedding with ``(img_size // patch_size) ** 2`` rows is
    then added to the flattened token sequence.
    """

    def __init__(self, in_channels=3, patch_size=8, emb_dim=512, img_size=64):
        super().__init__()
        self.patch_size, self.emb_dim, self.img_size = patch_size, emb_dim, img_size
        self.proj = nn.Conv2d(
            in_channels,
            emb_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        num_patches = (img_size // patch_size) ** 2
        self.pos_embedding = nn.Parameter(torch.randn(num_patches, emb_dim))

    def forward(self, x):
        feature_map = self.proj(x)  # [B, emb_dim, H/patch, W/patch]
        # Collapse the spatial grid into one token axis: [B, (h w), emb_dim].
        tokens = feature_map.flatten(2).transpose(1, 2)
        return tokens + self.pos_embedding

class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of a batch of scalar timesteps.

    Uses ``dim // 2`` geometrically spaced frequencies; the sine features
    come first and the cosine features second along the last axis.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half_dim = self.dim // 2
        # Frequencies decay from 1 down to 1/10000 across the half_dim slots.
        log_step = -(math.log(10000) / (half_dim - 1))
        freqs = (torch.arange(half_dim, device=t.device) * log_step).exp()
        angles = t[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)  # [B, dim]

class TransformerBackbone(nn.Module):
    """Transformer encoder that maps a noisy image and timestep to an image.

    The image is split into patch tokens, a timestep embedding is added to
    every token, the tokens pass through a ``nn.TransformerEncoder`` and a
    linear head turns each token back into a ``patch_size x patch_size``
    pixel patch. The patches are finally re-assembled into an image with
    the same channel count and ``img_size`` as configured.
    """

    def __init__(self, in_channels=3, img_size=64, patch_size=8, emb_dim=512, depth=6, num_heads=8, time_emb_dim=128):
        super().__init__()
        # NOTE: submodules are created in this order on purpose; changing it
        # would change the random initialisation drawn for each of them.
        self.embed = PatchEmbedding(in_channels, patch_size, emb_dim, img_size)

        # Timestep conditioning: sinusoidal features followed by a small MLP.
        self.time_embedding = SinusoidalPosEmb(time_emb_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        )

        layer = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=num_heads,
            dim_feedforward=1024,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)

        # One token -> one flattened pixel patch.
        pixels_per_patch = patch_size * patch_size * in_channels
        self.output_proj = nn.Sequential(
            nn.Linear(emb_dim, pixels_per_patch),
        )

        self.channels = in_channels
        self.img_size = img_size
        self.patch_size = patch_size

    def forward(self, x, t):
        batch = x.size(0)
        tokens = self.embed(x)  # [B, N, D]

        # [B, time_emb_dim] -> [B, D] -> [B, 1, D] so it broadcasts over tokens.
        time_cond = self.time_mlp(self.time_embedding(t)).unsqueeze(1)
        tokens = tokens + time_cond

        hidden = self.encoder(tokens)  # [B, N, D]
        flat_patches = self.output_proj(hidden)  # [B, N, patch*patch*channels]

        # Un-patchify: split each token into (c, ph, pw) and tile the grid.
        grid = self.img_size // self.patch_size
        patches = flat_patches.view(batch, -1, self.channels, self.patch_size, self.patch_size)
        return rearrange(patches, 'b (h w) c ph pw -> b c (h ph) (w pw)', h=grid, w=grid)

## Diffusion process

In [None]:
# Normalization helpers
def normalize_to_neg_one_to_one(img):
    """Affinely map values from [0, 1] to [-1, 1]."""
    doubled = img * 2
    return doubled - 1

def unnormalize_to_zero_to_one(t):
    """Affinely map values from [-1, 1] back to [0, 1]."""
    shifted = t + 1
    return shifted * 0.5

class GaussianDiffusion(nn.Module):
    """Gaussian diffusion process wrapped around a noise-prediction network.

    Precomputes every schedule-dependent coefficient as a buffer, provides
    the forward noising step (``q_sample``), the training loss
    (``p_losses`` / ``forward``) and two samplers: the full ancestral
    reverse chain (``p_sample_loop``) and a strided DDIM sampler
    (``ddim_sample_loop``).

    Args:
        model: Network called as ``model(x_t, t)`` that returns the predicted
            noise; it must expose a ``channels`` attribute.
        image_size: Side length of the (square) images.
        timesteps: Number of diffusion steps.
        beta_schedule: ``'linear'`` or ``'cosine'``.
        auto_normalize: When True, training inputs are rescaled with
            ``img * 2 - 1`` and samples are mapped back with
            ``(x + 1) * 0.5``. Inputs are then expected in [0, 1].
        lambda_l1: Weight of an optional L1 reconstruction term. The term is
            currently disabled in ``p_losses``; the value is only stored.

    Raises:
        ValueError: If ``beta_schedule`` is not a known schedule name.
    """

    def __init__(
        self,
        model: nn.Module,
        image_size: int, 
        timesteps: int,
        beta_schedule: str = 'linear',
        auto_normalize: bool = True,
        lambda_l1: float = 0.1,
    ):
        super().__init__()

        self.model = model
        self.channels = model.channels
        self.image_size = image_size
        self.num_timesteps = timesteps
        self.lambda_l1 = lambda_l1

        # 1) build beta schedule
        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f'Unknown beta schedule: {beta_schedule}')

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Shifted by one step, with alpha_bar_{-1} defined as 1.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        # 2) register all buffers (stored as float32 so they follow .to(device))
        self.register_buffer('betas', betas.float())
        self.register_buffer('alphas_cumprod', alphas_cumprod.float())
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev.float())
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod).float())
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - alphas_cumprod).float())
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1.0 / alphas_cumprod).float())
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1.0 / alphas_cumprod - 1.0).float())

        # posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance.float())
        # The variance is 0 at t = 0, so clamp before taking the log.
        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)).float())
        self.register_buffer('posterior_mean_coef1', (betas * torch.sqrt(alphas_cumprod_prev) / (1 - alphas_cumprod)).float())
        self.register_buffer('posterior_mean_coef2', ((1 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1 - alphas_cumprod)).float())

        # 3) Normalization functions
        self.normalize = normalize_to_neg_one_to_one if auto_normalize else (lambda x: x)
        self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else (lambda x: x)

    def dynamic_threshold(self, x0, percentile=0.995):
        """Clip x0 per sample to avoid over/under-exposed outputs.

        The clipping bound is the given quantile of ``|x0|`` for each sample,
        but never below 1; the clipped tensor is divided by that bound.
        """
        s = torch.quantile(x0.abs().flatten(1), percentile, dim=1)
        s = torch.maximum(s, torch.ones_like(s))[:, None, None, None]
        return x0.clamp(-s, s) / s

    def predict_start_from_noise(self, x_t, t, noise):
        """Estimate x_0 from x_t and the predicted noise."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        var = extract(self.posterior_variance, t, x_t.shape)
        log_var = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return mean, var, log_var

    @torch.no_grad()
    def p_sample(self, x, t: int, clip_denoised: bool = True):
        """Run one ancestral reverse step from x_t to x_{t-1}."""
        b = x.shape[0]
        t_batch = torch.full((b,), t, device=x.device, dtype=torch.long)
        # model predicts noise
        pred_noise = self.model(x, t_batch)
        # recover x0; optionally threshold it
        x0 = self.predict_start_from_noise(x, t_batch, pred_noise)
        if clip_denoised:
            x0 = self.dynamic_threshold(x0)

        # posterior mean & variance
        mean, _, log_var = self.q_posterior(x0, x, t_batch)
        # No noise is added on the final step (t == 0).
        noise = torch.randn_like(x) if t > 0 else 0.0
        return mean + torch.exp(0.5 * log_var) * noise

    @torch.no_grad()
    def p_sample_loop(self, shape, return_all_timesteps: bool = False):
        """Starting from pure noise, run the full reverse chain.

        Returns the final images, or, when ``return_all_timesteps`` is True,
        every intermediate state stacked along dim 1 (initial noise first).
        The result is passed through ``self.unnormalize``.
        """
        img = torch.randn(shape, device=self.betas.device)
        # Only keep the trajectory when it was asked for; holding on to every
        # intermediate image otherwise wastes memory for nothing.
        trajectory = [img] if return_all_timesteps else None
        for t in tqdm(reversed(range(self.num_timesteps)), desc='sampling', total=self.num_timesteps):
            img = self.p_sample(img, t)
            if trajectory is not None:
                trajectory.append(img)
        out = torch.stack(trajectory, dim=1) if trajectory is not None else img
        return self.unnormalize(out)

    @torch.no_grad()
    def p_sample_ddim(self, x, t, t_prev, eta=0.0, clip_denoised=True):
        """Run one DDIM step from x_t to x_{t_prev}.

        ``t_prev < 0`` marks the final step: the previous cumulative alpha is
        taken as 1 and no noise is added. Returns ``(x_prev, x0)``.
        """
        b = x.shape[0]
        t_batch = torch.full((b, ), t, device=x.device, dtype=torch.long)
        # predict noise
        pred_noise = self.model(x, t_batch)
        # pred x0
        x0 = self.predict_start_from_noise(x, t_batch, pred_noise)
        if clip_denoised:
            x0 = self.dynamic_threshold(x0)

        alpha_t = self.alphas_cumprod[t]
        alpha_prev = self.alphas_cumprod[t_prev] if t_prev >= 0 else torch.tensor(1.0, device=x.device)
        sqrt_alpha_prev = alpha_prev.sqrt()
        # eta = 0 gives the deterministic DDIM update.
        sigma_t = eta * ((1 - alpha_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_prev)).sqrt()
        pred_dir = (1 - alpha_prev - sigma_t ** 2).sqrt() * pred_noise
        noise = sigma_t * torch.randn_like(x) if t_prev >= 0 else 0.0
        x_prev = sqrt_alpha_prev * x0 + pred_dir + noise

        return x_prev, x0

    @torch.no_grad()
    def ddim_sample_loop(self, shape, num_ddim_steps=50, eta=0.0, return_all_timesteps=False):
        """Run the full DDIM sampling loop over an evenly spaced sub-schedule.

        Returns the final images, or, when ``return_all_timesteps`` is True,
        every intermediate state stacked along dim 1 (initial noise first).
        The result is passed through ``self.unnormalize``.
        """
        device = self.betas.device
        img = torch.randn(shape, device=device)
        # Only keep the trajectory when it was asked for (see p_sample_loop).
        trajectory = [img] if return_all_timesteps else None

        # Evenly spaced timesteps between 0 and T - 1 (inclusive).
        ddim_timesteps = np.linspace(0, self.num_timesteps - 1, num_ddim_steps, dtype=int)

        for i in tqdm(range(num_ddim_steps - 1, -1, -1), desc='DDIM sampling'):
            # Plain Python ints keep the tensor indexing / torch.full calls
            # independent of NumPy scalar types.
            t = int(ddim_timesteps[i])
            t_prev = int(ddim_timesteps[i - 1]) if i > 0 else -1

            img, _ = self.p_sample_ddim(img, t, t_prev, eta=eta)
            if trajectory is not None:
                trajectory.append(img)

        out = torch.stack(trajectory, dim=1) if trajectory is not None else img
        return self.unnormalize(out)

    def sample(self, batch_size=16, use_ddim=False, num_ddim_steps=50, eta=0.0, return_all_timesteps=False):
        """Generate ``batch_size`` images with either the DDIM or the full sampler."""
        shape = (batch_size, self.channels, self.image_size, self.image_size)
        if use_ddim:
            return self.ddim_sample_loop(shape, num_ddim_steps=num_ddim_steps, eta=eta, return_all_timesteps=return_all_timesteps)
        return self.p_sample_loop(shape, return_all_timesteps=return_all_timesteps)

    def q_sample(self, x_start, t, noise=None):
        """Forward noising process: draw x_t given x_0 (and optional noise)."""
        if noise is None:
            noise = torch.randn_like(x_start)
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, noise=None):
        """Return the MSE between the true noise and the model's prediction.

        The error is averaged over all non-batch dimensions per sample and
        then over the batch. An additional L1 reconstruction term weighted by
        ``self.lambda_l1`` is currently disabled.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start, t, noise)
        pred_noise = self.model(x_noisy, t)
        mse_loss = F.mse_loss(pred_noise, noise, reduction='none')
        mse_loss = mse_loss.mean(dim=list(range(1, mse_loss.ndim)))  # mean over c, h, w
        return mse_loss.mean()

    def forward(self, img):
        """Training entrypoint: sample random timesteps and return the loss.

        ``img`` must be ``[B, C, image_size, image_size]``. It is passed
        through ``self.normalize`` first, so with ``auto_normalize=True`` it
        is expected to be in [0, 1].
        """
        b, c, h, w = img.shape
        assert h == self.image_size and w == self.image_size
        t = torch.randint(0, self.num_timesteps, (b, ), device=img.device)
        img = self.normalize(img)
        return self.p_losses(img, t)

## Trainer

In [None]:
class Trainer:
    """Training / sampling driver for a diffusion model.

    Wraps the model in an fp16 ``Accelerator``, trains it with Adam and a
    cosine-annealed learning rate, keeps an EMA copy of the weights, and
    periodically writes sample grids and checkpoints to ``results_folder``.

    Args:
        diffusion_model: Module whose forward returns the training loss and
            which exposes ``image_size`` and ``sample(...)``.
        data_folder: First argument passed to the dataset constructor.
        batch_size: Batch size for training and sampling.
        lr: Adam learning rate.
        num_steps: Total number of optimizer steps.
        grad_accum_steps: Batches accumulated per optimizer step.
        ema_decay: ``beta`` passed to ``EMA``.
        save_interval: Sample and checkpoint every this many steps.
        num_samples: Number of images in each sample grid.
        results_folder: Output directory (created if missing).
        use_ddim, num_ddim_steps, eta: Forwarded to the model's ``sample``.
    """

    def __init__(
        self,
        diffusion_model: nn.Module,
        data_folder: str,
        batch_size: int = 16,
        lr: float = 1e-4,
        num_steps: int = 100000,
        grad_accum_steps: int = 1,
        ema_decay: float = 0.995,
        save_interval: int = 1000,
        num_samples: int = 25,
        results_folder: str = './results',
        use_ddim: bool = False,
        num_ddim_steps: int = 50,
        eta: float = 0.0
    ):
        # Accelerator
        self.accelerator = Accelerator(mixed_precision='fp16')
        self.device = self.accelerator.device

        # Training State
        self.batch_size       = batch_size
        self.grad_accum_steps = grad_accum_steps
        self.num_steps        = num_steps
        self.save_interval    = save_interval
        self.num_samples      = num_samples

        # model, optimizer, LR schedule
        self.model = diffusion_model.to(self.device)
        self.optimizer = Adam(self.model.parameters(), lr=lr)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=self.num_steps, eta_min=1e-6)
        self.model, self.optimizer = self.accelerator.prepare(
            self.model, self.optimizer
        )

        # Use EMA on the raw (unwrapped) model
        self.ema = EMA(self.accelerator.unwrap_model(self.model), beta=ema_decay)

        # Data
        # NOTE(review): `Dataset` is called with (path, image_size); confirm
        # which dataset class this name is bound to in the notebook.
        ds = Dataset(data_folder, diffusion_model.image_size)
        # os.cpu_count() may return None; fall back to loading in-process.
        dl = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=os.cpu_count() or 0)
        # itertools.cycle would cache every batch of the first epoch and then
        # replay those exact batches forever (no reshuffle, no fresh
        # augmentation, ever-growing memory). Re-iterate the loader instead.
        self.dl = self._endless(self.accelerator.prepare(dl))

        # checkpoints & samples
        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(parents=True, exist_ok=True)
        self.step = 0
        self.use_ddim = use_ddim
        self.num_ddim_steps = num_ddim_steps
        self.eta = eta

    @staticmethod
    def _endless(loader):
        """Yield batches forever, starting a fresh pass whenever one ends."""
        while True:
            produced = False
            for batch in loader:
                produced = True
                yield batch
            if not produced:
                # Avoid spinning forever on an empty dataset.
                raise ValueError('DataLoader produced no batches')

    @staticmethod
    def _split_into_batches(total: int, batch_size: int):
        """Split ``total`` into full batches plus one remainder batch, if any."""
        full, remainder = divmod(total, batch_size)
        sizes = [batch_size] * full
        if remainder > 0:
            sizes.append(remainder)
        return sizes

    def save(self, milestone: int):
        """Save model, optimizer, scheduler, EMA and step counter."""
        if not self.accelerator.is_main_process:
            return
        ckpt = self.results_folder / f'model-{milestone}.pt'
        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.model),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'ema':       self.ema.state_dict(),
        }
        torch.save(data, ckpt)

    def load(self, ckpt_path: str):
        """Load all state (model, optimizer, scheduler, EMA, step)."""
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        data = torch.load(ckpt_path, map_location=self.device)
        raw_model = self.accelerator.unwrap_model(self.model)
        raw_model.load_state_dict(data['model'])
        self.optimizer.load_state_dict(data['optimizer'])
        # Checkpoints written before the scheduler was saved lack this key.
        if 'scheduler' in data:
            self.scheduler.load_state_dict(data['scheduler'])
        self.ema.load_state_dict(data['ema'])
        self.step = data['step']

    def _sample_and_save(self, milestone: int):
        """Generate `num_samples` via EMA model and save grid."""
        self.ema.ema_model.eval()
        batches = self._split_into_batches(self.num_samples, self.batch_size)
        imgs = torch.cat([
            self.ema.ema_model.sample(batch_size=n, use_ddim=self.use_ddim, num_ddim_steps=self.num_ddim_steps, eta=self.eta) for n in batches
        ], dim=0)
        path = self.results_folder / f'sample-{milestone}.png'
        vutils.save_image(imgs, path, nrow=int(math.sqrt(self.num_samples)))

    def train(self):
        """Run the training loop with gradient accumulation, EMA updates, and periodic sampling."""
        pbar = tqdm(total=self.num_steps, initial=self.step, disable=not self.accelerator.is_main_process)
        while self.step < self.num_steps:
            total_loss = 0.0

            # gradient accumulation
            for _ in range(self.grad_accum_steps):
                batch = next(self.dl).to(self.device)
                with self.accelerator.autocast():
                    # Scale so the accumulated gradient is an average.
                    loss = self.model(batch) / self.grad_accum_steps
                total_loss += loss.item()
                self.accelerator.backward(loss)

            # optimizer step
            self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()
            self.step += 1

            # EMA & sampling
            if self.accelerator.is_main_process:
                self.ema.update()
                if self.step % self.save_interval == 0:
                    milestone = self.step // self.save_interval
                    self._sample_and_save(milestone)
                    self.save(milestone)

            pbar.set_description(f'loss: {total_loss:.4f}')
            pbar.update(1)

        pbar.close()
        if self.accelerator.is_main_process:
            print('Training complete.')

    def inference(self, total: int = 1000, output_path: str = './submission'):
        """
        Generate `total` images with the EMA model, `self.batch_size` at a
        time, and save them as `{output_path}/1.jpg`, `2.jpg`, ...
        """
        Path(output_path).mkdir(parents=True, exist_ok=True)
        count = 0
        batch_size = self.batch_size

        # Same as _sample_and_save: sample with the EMA model in eval mode.
        self.ema.ema_model.eval()

        with torch.no_grad():
            while count < total:
                n = min(batch_size, total - count)

                # clear any leftover GPU memory
                if torch.cuda.is_available():
                    gc.collect()
                    torch.cuda.empty_cache()

                # sample n images
                imgs = self.ema.ema_model.sample(batch_size=n, use_ddim=self.use_ddim, num_ddim_steps=self.num_ddim_steps, eta=self.eta)

                # move to CPU and save
                imgs = imgs.cpu()
                for img in imgs:
                    count += 1
                    vutils.save_image(img, f"{output_path}/{count}.jpg")

                # clean up
                del imgs
                if torch.cuda.is_available():
                    gc.collect()
                    torch.cuda.empty_cache()

        print("Inference complete.")

## Usage

In [None]:
# Free unreferenced Python objects first, then release PyTorch's cached,
# currently unused GPU memory before building the model.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Build the denoising network, wrap it in the diffusion process and create
# the trainer. The backbone's img_size and the diffusion image_size must agree.
model = TransformerBackbone(
    in_channels=3, img_size=64, patch_size=8,
    emb_dim=512, depth=6, num_heads=8, time_emb_dim=128
)
# auto_normalize=True rescales training inputs with img * 2 - 1.
# NOTE(review): confirm the dataset yields images in [0, 1], otherwise the
# inputs are normalized twice.
diffusion = GaussianDiffusion(
    model=model, image_size=64, timesteps=1000,
    beta_schedule='cosine', auto_normalize=True
)
# Samples written every `save_interval` steps use 50-step deterministic DDIM
# (use_ddim=True, eta=0.0).
trainer = Trainer(
    diffusion_model=diffusion,
    data_folder = './',
    batch_size=16, lr=1e-4, num_steps=70000,
    grad_accum_steps=1, ema_decay=0.995,
    save_interval=1000, num_samples=25,
    results_folder='./results',
    use_ddim=True, num_ddim_steps=50, eta=0.0
)


In [None]:
# Run training until `num_steps` optimizer steps have been taken; sample
# grids and checkpoints are written to the trainer's results folder.
trainer.train()

## Inference

In [None]:
# Restore a saved checkpoint and generate images with the EMA model.
# NOTE(review): this path points at './results_56', not the './results'
# folder configured for the trainer above — presumably a checkpoint from an
# earlier run; verify that it exists.
ckpt = './results_56/model-20.pt'
trainer.load(ckpt)
trainer.inference()

In [None]:
# Visualize the reverse-diffusion trajectory of a single sample.
diffusion = trainer.accelerator.unwrap_model(trainer.model)
diffusion.eval()  # turn off dropout in the transformer while sampling
with torch.no_grad():
    all_steps = diffusion.sample(batch_size=1, return_all_timesteps=True)
# shape [1, T+1, C, H, W]; index 0 is the initial noise, index T the result

# `sample` already applies the model's `unnormalize` (the model above was
# built with auto_normalize=True), so the values must not be rescaled a
# second time. Intermediate noisy states can still fall outside [0, 1];
# clip them so imshow receives a valid float image.
imgs = all_steps.squeeze(0).permute(0, 2, 3, 1).cpu().numpy()
imgs = imgs.clip(0.0, 1.0)

timesteps = [0, diffusion.num_timesteps // 4, diffusion.num_timesteps // 2,
             3 * diffusion.num_timesteps // 4, diffusion.num_timesteps]
sel = [imgs[t] for t in timesteps]

# Plot
fig, axes = plt.subplots(1, 5, figsize=(15, 5))
for ax, im, step in zip(axes, sel, timesteps):
    ax.imshow(im)
    ax.axis('off')
    ax.set_title(f"Step {step}")
plt.tight_layout()
plt.show()