# Diffusion Transformer

## Libraries

In [1]:
import os
import gc
import math
import numpy as np
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils
from torchvision import transforms
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset
from accelerate import Accelerator
from tqdm.auto import tqdm
from einops import rearrange

import matplotlib.pyplot as plt
from ema_pytorch import EMA
from PIL import Image
from torchvision.transforms import ToPILImage

## Data preprocess

In [2]:
class CustomDataset(Dataset):
    def __init__(self, images_npy: str, labels_npy: str, image_size: int):
        # load images and labels
        self.sprites = np.load(images_npy)     # shape: [N, H, W, C]
        raw_labels   = np.load(labels_npy)     # shape: [N, 5] (one-hot) or [N] (ints)

        # if one-hot, convert to class indices:
        if raw_labels.ndim == 2:
            self.labels = raw_labels.argmax(axis=1)
        else:
            self.labels = raw_labels

        self.image_size = image_size
        self.transform = transforms.Compose([
            ToPILImage(),
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomResizedCrop(image_size, scale=(0.9,1.0)),
            transforms.ColorJitter(0.1,0.1,0.1,0.05),
            transforms.ToTensor(),
            transforms.Normalize((0.5,)*3, (0.5,)*3),
        ])

    def __len__(self):
        return len(self.sprites)

    def __getitem__(self, idx):
        img_np = self.sprites[idx]                # H×W×C
        img = self.transform(img_np)              # [3, H, W]
        label = int(self.labels[idx])             # scalar in {0,…,4}
        return img, label

## Diffusion process

In [3]:
# Helper: extract time-indexed values from buffer
def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t).float()
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))

# Normalization functions
def normalize_to_neg_one_to_one(img):
    return img * 2 - 1

def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5

class DiTDiffusion(nn.Module):
    def __init__(
        self,
        model: nn.Module,
        image_size: int, 
        timesteps: int,
        beta_schedule: str = 'linear',
        auto_normalize: bool = True,
        lambda_l1: float = 0.1,
    ):
        super().__init__()
        self.model = model
        self.channels = model.channels
        self.image_size = image_size
        self.num_timesteps = timesteps
        self.lambda_l1 = lambda_l1
        self.num_classes = model.num_classes

        if beta_schedule == 'linear':
            betas = torch.linspace(1e-4, 0.02, timesteps)
        elif beta_schedule == 'cosine':
            steps = timesteps + 1
            x = torch.linspace(0, timesteps, steps)
            alphas_cumprod = torch.cos(((x / timesteps) + 0.008) / 1.008 * torch.pi * 0.5) ** 2
            alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
            betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        else:
            raise ValueError(f'Unknown beta schedule: {beta_schedule}')

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        self.register_buffer('betas', betas.float())
        self.register_buffer('alphas_cumprod', alphas_cumprod.float())
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev.float())
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod).float())
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - alphas_cumprod).float())
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1.0 / alphas_cumprod).float())
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1.0 / alphas_cumprod - 1.0).float())
        self.register_buffer('posterior_variance', betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod))
        self.register_buffer('posterior_log_variance_clipped', torch.log(self.posterior_variance.clamp(min=1e-20)).float())
        self.register_buffer('posterior_mean_coef1', (betas * torch.sqrt(alphas_cumprod_prev) / (1 - alphas_cumprod)).float())
        self.register_buffer('posterior_mean_coef2', ((1 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1 - alphas_cumprod)).float())

        self.normalize = normalize_to_neg_one_to_one if auto_normalize else (lambda x: x)
        self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else (lambda x: x)

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    @torch.no_grad()
    def p_sample(self, x, t: int, clip_denoised: bool = True):
        """ One diffusion reverse step """
        b = x.shape[0]
        t_batch = torch.full((b,), t, device=x.device, dtype=torch.long)
        # model predicts noise
        pred_noise = self.model(x, t_batch)
        # recover x0; optionally clamp it
        x0 = self.predict_start_from_noise(x, t_batch, pred_noise)
        x0 = self.dynamic_threshold(x0)

        # posterior mean & variance
        mean, _, log_var = self.q_posterior(x0, x, t_batch)
        noise = torch.randn_like(x) if t > 0 else 0.0
        return mean + torch.exp(0.5 * log_var) * noise

    @torch.no_grad()
    def p_sample_loop(self, shape, return_all_timesteps: bool = False):
        img = torch.randn(shape, device=self.betas.device)
        if return_all_timesteps:
            all_imgs = [img]
            for t in tqdm(reversed(range(self.num_timesteps)), desc='sampling'):
                img = self.p_sample(img, t)
                all_imgs.append(img)
            out = torch.stack(all_imgs, dim=1)
        else:
            for t in tqdm(reversed(range(self.num_timesteps)), desc='sampling'):
                img = self.p_sample(img, t)
            out = img
        return self.unnormalize(out)
    
    @torch.no_grad()
    def p_sample_cfg(self, x, t, labels, guidance_scale):
        """
        x:       [B, C, H, W]
        t:       integer timestep
        labels:  LongTensor [B] with the class IDs
        """
        b = x.shape[0]
        device = x.device

        # 1) duplicate inputs for unconditional + conditional
        x_in = torch.cat([x, x], dim=0)
        t_in = torch.full((2*b,), t, device=device, dtype=torch.long)
        # first half: uncond (we pass dummy label, e.g. zeros)
        # second half: real labels
        lbl_uncond = torch.full((b,), self.num_classes, device=device, dtype=torch.long)
        lbl_cond   = labels
        labels_in  = torch.cat([lbl_uncond, lbl_cond], dim=0)

        # 2) predict noise for both branches
        eps_all = self.model(x_in, t_in, labels_in)  # model must accept labels!
        eps_uncond, eps_cond = eps_all.chunk(2, dim=0)

        # 3) fuse via CFG
        eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

        # 4) standard posterior
        x0 = self.predict_start_from_noise(x, t_in[:b], eps)
        x0 = self.dynamic_threshold(x0)
        mean, _, log_var = self.q_posterior(x0, x, t_in[:b])
        noise = torch.randn_like(x) if t > 0 else 0.0
        return mean + torch.exp(0.5 * log_var) * noise

    @torch.no_grad()
    def p_sample_cfg_loop(self, shape, labels, guidance_scale=0.0):
        img = torch.randn(shape, device=self.betas.device)
        for t in reversed(range(self.num_timesteps)):
            if guidance_scale > 0:
                img = self.p_sample_cfg(img, t, labels, guidance_scale)
            else:
                img = self.p_sample(img, t)
        return self.unnormalize(img)
    
    @torch.no_grad()
    def p_sample_ddim(self, x, t, t_prev, eta=0.0, clip_denoised=True):
        """ One DDIM sampling step from x_t to x_{t_prev}"""
        b = x.shape[0]
        t_batch = torch.full((b, ), t, device=x.device, dtype=torch.long)
        # predict noise
        pred_noise = self.model(x, t_batch)
        # pred x0
        x0 = self.predict_start_from_noise(x, t_batch, pred_noise)
        x0 = self.dynamic_threshold(x0)

        alpha_t = self.alphas_cumprod[t]
        alpha_prev = self.alphas_cumprod[t_prev] if t_prev >= 0 else torch.tensor(1.0, device=x.device)
        sqrt_alpha_t = alpha_t.sqrt()
        sqrt_alpha_prev = alpha_prev.sqrt()
        sigma_t = eta * ((1 - alpha_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_prev)).sqrt()
        pred_dir = (1 - alpha_prev - sigma_t ** 2).sqrt() * pred_noise
        noise = sigma_t * torch.randn_like(x) if t_prev >= 0 else 0.0
        x_prev = sqrt_alpha_prev * x0 + pred_dir + noise

        return x_prev, x0
    


    @torch.no_grad()
    def ddim_sample_loop(self, shape, num_ddim_steps=50, eta=0.0, return_all_timesteps=False):
        """ Run the full DDIM sampling loop """
        device = self.betas.device
        img = torch.randn(shape, device=device)
        all_imgs = [img]

        # Create a custom timestep schedule
        ddim_timesteps = np.linspace(0, self.num_timesteps - 1, num_ddim_steps, dtype=int)

        for i in tqdm(range(num_ddim_steps - 1, -1, -1), desc='DDIM sampling'):
            t = ddim_timesteps[i]
            t_prev = ddim_timesteps[i - 1] if i > 0 else -1

            img, _ = self.p_sample_ddim(img, t, t_prev, eta=eta)
            all_imgs.append(img)
        
        out = torch.stack(all_imgs, dim=1) if return_all_timesteps else img
        return self.unnormalize(out)

    @torch.no_grad()
    def sample(self, batch_size=16, use_ddim=False, use_cfg=False, num_ddim_steps=50, eta=0.0, return_all_timesteps=False, labels=None, guidance_scale=0.0):
        shape = (batch_size, self.channels, self.image_size, self.image_size)

        if use_ddim:
            return self.ddim_sample_loop(
                shape,
                num_ddim_steps=num_ddim_steps,
                eta=eta,
                return_all_timesteps=return_all_timesteps
            )
        elif use_cfg:
            if labels is None:
                labels = torch.zeros(batch_size, dtype=torch.long, device=self.betas.device)
            return self.p_sample_cfg_loop(shape, labels, guidance_scale)
        else:
            return self.p_sample_loop(shape, return_all_timesteps=return_all_timesteps)

    def q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        var = extract(self.posterior_variance, t, x_t.shape)
        log_var = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return mean, var, log_var

    def p_losses(self, x_start, t, noise=None):
        """ MSE + L1 loss between true noise and model's prediction """
        if noise is None:
            noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start, t, noise)
        pred_noise = self.model(x_noisy, t)
        mse_loss = F.mse_loss(pred_noise, noise, reduction='none')
        mse_loss = mse_loss.mean(dim=list(range(1, mse_loss.ndim))) # mean over c, h, w

        x_recon = self.predict_start_from_noise(x_noisy, t, pred_noise)
        x_recon = x_recon.clamp(-1., 1.)
        l1_loss = F.l1_loss(x_recon, x_start, reduction='none')
        l1_loss = l1_loss.mean(dim=list(range(1, l1_loss.ndim)))

        loss = mse_loss + self.lambda_l1 * l1_loss
        # loss = mse_loss
        return loss.mean()
    
    def forward(self, img, labels=None):
        """ Training entrypoint: sample random timesteps & return loss """
        b, c, h, w = img.shape
        assert h == self.image_size and w == self.image_size
        t = torch.randint(0, self.num_timesteps, (b, ), device=img.device)
        img = self.normalize(img)
        return self.p_losses(img, t)

## Trainer

In [4]:
def cycle(dl):
    while True:
        for data in dl:
            yield data

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

class DiTTrainer:
    def __init__(
        self,
        diffusion_model,
        data_folder: str,
        label_folder: str,
        batch_size: int = 16,
        lr: float = 1e-4,
        num_steps: int = 100000,
        grad_accum_steps: int = 1,
        ema_decay: float = 0.995,
        save_interval: int = 1000,
        num_samples: int = 25,
        results_folder: str = './results_dit',
        use_ddim=False,
        num_ddim_steps=50,
        eta=0.0
    ):
        self.accelerator = Accelerator(mixed_precision='fp16')
        self.device = self.accelerator.device

        self.batch_size = batch_size
        self.grad_accum_steps = grad_accum_steps
        self.num_steps = num_steps
        self.save_interval = save_interval
        self.num_samples = num_samples
        self.use_ddim = use_ddim
        self.num_ddim_steps = num_ddim_steps
        self.eta = eta

        self.model = diffusion_model.to(self.device)
        self.optimizer = Adam(self.model.parameters(), lr=lr)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=num_steps, eta_min=1e-6)
        self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
        self.ema = EMA(self.accelerator.unwrap_model(self.model), beta=ema_decay)

        ds = CustomDataset(data_folder, label_folder, diffusion_model.image_size)
        dl = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=os.cpu_count())
        self.dl = cycle(self.accelerator.prepare(dl))

        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(parents=True, exist_ok=True)
        self.step = 0

    def save(self, milestone: int):
        if not self.accelerator.is_main_process:
            return
        ckpt = self.results_folder / f'dit_model-{milestone}.pt'
        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.model),
            'optimizer': self.optimizer.state_dict(),
            'ema': self.ema.state_dict(),
        }
        torch.save(data, ckpt)

    def load(self, ckpt_path: str):
        data = torch.load(ckpt_path, map_location=self.device)
        raw_model = self.accelerator.unwrap_model(self.model)
        raw_model.load_state_dict(data['model'])
        self.optimizer.load_state_dict(data['optimizer'])
        self.ema.load_state_dict(data['ema'])
        self.step = data['step']

    def _sample_and_save(self, milestone: int):
        self.ema.ema_model.eval()
        batches = num_to_groups(self.num_samples, self.batch_size)
        imgs = torch.cat([
            self.ema.ema_model.sample(batch_size=n, use_ddim=self.use_ddim, num_ddim_steps=self.num_ddim_steps, eta=self.eta) for n in batches
        ], dim=0)
        path = self.results_folder / f'dit_sample-{milestone}.png'
        vutils.save_image(imgs, path, nrow=int(math.sqrt(self.num_samples)))

    def train(self):
        pbar = tqdm(total=self.num_steps, initial=self.step, disable=not self.accelerator.is_main_process)
        while self.step < self.num_steps:
            total_loss = 0.0
            for _ in range(self.grad_accum_steps):
                imgs, labels = next(self.dl)
                imgs = imgs.to(self.device)
                with self.accelerator.autocast():
                    loss = self.model(imgs) / self.grad_accum_steps
                total_loss += loss.item()
                self.accelerator.backward(loss)

            self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()
            self.step += 1

            if self.accelerator.is_main_process:
                self.ema.update()
                if self.step % self.save_interval == 0:
                    milestone = self.step // self.save_interval
                    self._sample_and_save(milestone)
                    self.save(milestone)

            # if self.step % 10 == 0 and self.accelerator.is_main_process:
            #     print(f"[Step {self.step}] loss: {total_loss:.4f}")

            pbar.set_description(f'loss: {total_loss:.4f}')
            pbar.update(1)

        if self.accelerator.is_main_process:
            print('Training complete.')

    def inference(self, total: int = 1000, output_path: str = './submission_dit'):
        Path(output_path).mkdir(parents=True, exist_ok=True)
        count = 0
        batch_size = self.batch_size

        with torch.no_grad():
            while count < total:
                n = min(batch_size, total - count)
                if torch.cuda.is_available():
                    gc.collect()
                    torch.cuda.empty_cache()
                imgs = self.ema.ema_model.sample(batch_size=n, use_ddim=self.use_ddim, num_ddim_steps=self.num_ddim_steps, eta=self.eta)
                imgs = imgs.cpu()
                for img in imgs:
                    count += 1
                    vutils.save_image(img, f"{output_path}/{count}.jpg")
                del imgs
                if torch.cuda.is_available():
                    gc.collect()
                    torch.cuda.empty_cache()
        print("Inference complete.")

## Network

In [5]:
class PatchEmbed(nn.Module):
    def __init__(self, in_channels, embed_dim, patch_size):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2) # (B, N_patches, embed_dim)
        return x

class ConditionEmbed(nn.Module):
    def __init__(self, num_classes:int, embed_dim:int):
        super().__init__()
        self.label_embed = nn.Embedding(num_classes + 1, embed_dim)
        nn.init.normal_(self.label_embed.weight, std=0.02)

    def forward(self, labels: torch.LongTensor):
        return self.label_embed(labels)

class AdaLNBlock(nn.Module):
    def __init__(self, dim, num_heads, cond_dim):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.feedforward = nn.Sequential(
            nn.Linear(dim, dim*4),
            nn.GELU(),
            nn.Linear(dim*4, dim)
        )
        self.mlp_gamma = nn.Linear(cond_dim, dim)
        self.mlp_beta = nn.Linear(cond_dim, dim)

    def forward(self, x, cond):
        gamma1, beta1 = self.mlp_gamma(cond), self.mlp_beta(cond)
        x_norm = self.norm1(x)
        x = x + self.self_attn(x_norm, x_norm, x_norm)[0] * (1+gamma1.unsqueeze(1)) + beta1.unsqueeze(1)

        gamma2, beta2 = self.mlp_gamma(cond), self.mlp_beta(cond)
        x_norm = self.norm2(x)
        x = x + self.feedforward(x_norm) * (1 + gamma2.unsqueeze(1)) + beta2.unsqueeze(1)

        return x
    
class CrossAttentionBlock(nn.Module):
    def __init__(self, dim, num_heads, cond_dim):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

        self.feedforward = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, x, cond_tokens):
        # Self-Attention
        x = x + self.self_attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
        # Cross-Attention
        x = x + self.cross_attn(self.norm2(x), cond_tokens, cond_tokens)[0]
        # Feedforward
        x = x + self.feedforward(self.norm3(x))
        return x
    
class InContextConditionBlock(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.feedforward = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, x, cond_tokens):
        # Concatenate condition tokens
        x_cat = torch.cat([x, cond_tokens], dim=1)  # (B, N+M, D)
        # Apply transformer attention over both
        attn_out = self.attn(self.norm1(x_cat), self.norm1(x_cat), self.norm1(x_cat))[0]
        x_cat = x_cat + attn_out
        x_cat = x_cat + self.feedforward(self.norm2(x_cat))
        # Return only the original input portion (not condition)
        return x_cat[:, :x.size(1)]
    
class DiTDiffusionModel(nn.Module):
    def __init__(self, patch_size=4, in_channels=3, embed_dim=512, depth=12, num_heads=8, cond_dim=128, num_classes=5):
        super().__init__()
        self.channels = in_channels
        self.num_classes = num_classes
        self.patch_embed = PatchEmbed(in_channels, embed_dim, patch_size)
        self.condition_embed = ConditionEmbed(num_classes=self.num_classes, embed_dim=embed_dim)
        self.blocks = nn.ModuleList([
            AdaLNBlock(embed_dim, num_heads, cond_dim) for _ in range(depth)
        ])
        self.to_output = nn.Linear(embed_dim, in_channels * patch_size * patch_size)

    def forward(self, x_latent, t=None, labels=None):
        B = x_latent.shape[0]
        if labels is None:
            labels = torch.full((B,), self.num_classes, device=x_latent.device, dtype=torch.long)
        cond_emb = self.condition_embed(labels)
        x = self.patch_embed(x_latent)
        for blk in self.blocks:
            x = blk(x, cond_emb)
        return self.to_output(x).transpose(1, 2).view_as(x_latent)

## Usage

In [6]:
model = DiTDiffusionModel(
    patch_size=2,
    in_channels=3,
    embed_dim=256,
    depth=12,
    num_heads=4,
    cond_dim=128
)

diffusion_model = DiTDiffusion(
    model=model,
    image_size=16,
    timesteps=1000,
    beta_schedule='linear',
    lambda_l1=0.1,
    auto_normalize=True
)

trainer = DiTTrainer(
    diffusion_model=diffusion_model,
    data_folder='./sprites.npy',
    label_folder='./sprites_labels.npy',
    batch_size=32,
    lr=1e-4,
    num_steps=100000,
    save_interval=500,
    num_samples=36,
    results_folder='./results_dit',
    use_ddim=False,
    num_ddim_steps=50,
    eta=0.0
)

In [7]:
trainer.train()

  0%|          | 0/100000 [00:00<?, ?it/s]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x256 and 128x256)

In [None]:
trainer.load('./results_dit/dit_model-20.pt')
trainer.inference(total=1000, output_path='./samples')
