# Diffusion Transformer

## Libraries

In [1]:
import os
import gc
import math
import numpy as np
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils
from torchvision import transforms
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset
from accelerate import Accelerator
from tqdm.auto import tqdm
from einops import rearrange

import matplotlib.pyplot as plt
from ema_pytorch import EMA
from PIL import Image
from torchvision.transforms import ToPILImage

## Data preprocessing

In [None]:
# Load the sprite stack and its label matrix (presumably one-hot: the class id
# is taken as the argmax over axis 1 and the class count as the column count).
images = np.load('sprites.npy')
labels = np.load('sprites_labels.npy')

# Human-readable name for each class id; used to build the output file names.
classes = ['human', 'non_human', 'food', 'spell', 'sideface']
num_classes = labels.shape[1]
class_ids = labels.argmax(axis=1)

# Write one .npy file per class containing only that class's images.
for cls in range(num_classes):
    cls_mask = np.equal(class_ids, cls)
    class_imgs = images[cls_mask]
    out_name = f'images_class_{classes[cls]}.npy'
    np.save(out_name, class_imgs)
    print(f"Class {cls}: {len(class_imgs)} images saved to {out_name}")

In [None]:

def show_npy_image(file_path):
    """
    Display the first entry of the array stored in a .npy file.

    Only element ``[0]`` of the loaded array is shown. That entry may be:
      • Grayscale: shape [H, W] or [1, H, W]
      • RGB:       shape [H, W, 3] or [3, H, W]
    """
    first = np.load(file_path)[0]

    # Channel-first entries (C×H×W with C in {1, 3}) are moved to H×W×C,
    # which is the layout imshow expects.
    channel_first = first.ndim == 3 and first.shape[0] in (1, 3)
    if channel_first:
        first = first.transpose(1, 2, 0)

    single_channel = first.ndim == 2 or (first.ndim == 3 and first.shape[2] == 1)

    plt.figure(figsize=(4, 4))
    if single_channel:
        # Drop the singleton channel axis and render with a gray colormap.
        plt.imshow(first.squeeze(), cmap='gray')
    else:
        plt.imshow(first)
    plt.axis('off')
    plt.title(file_path)
    plt.show()

# Preview the first image of each of the five per-class files. A missing file
# is reported and skipped so the remaining classes are still shown.
for cls in range(5):
    path = f'images_class_{classes[cls]}.npy'  # prepend a folder here if the files live elsewhere
    try:
        show_npy_image(path)
    except FileNotFoundError:
        # Only a missing file is tolerated; any other error still propagates.
        print(f"⚠️  File not found: {path}")

In [2]:
class CustomDataset(Dataset):
    """Image/label dataset backed by two .npy files.

    Each item is an augmented image tensor together with its integer class id.

    Args:
        images_npy: path of the image array; items are taken along axis 0
            (the original comments describe it as [N, H, W, C]).
        labels_npy: path of the label array. A 2-D array is reduced with
            argmax over axis 1; any other rank is used as-is.
        image_size: output side length given to ``RandomResizedCrop``.
        normalize: when True (default, the original behaviour) the tensor is
            mapped from [0, 1] to [-1, 1] with ``Normalize(0.5, 0.5)`` on
            three channels. Pass False when the consumer rescales [0, 1]
            inputs itself — ``GaussianDiffusion`` with ``auto_normalize=True``
            applies ``img * 2 - 1`` in ``forward``, so leaving both enabled
            rescales the data twice.

    Raises:
        ValueError: if the image and label arrays have different lengths.
    """

    def __init__(self, images_npy: str, labels_npy: str, image_size: int, normalize: bool = True):
        self.sprites = np.load(images_npy)
        raw_labels = np.load(labels_npy)

        # One-hot (2-D) labels become class indices; 1-D labels pass through.
        if raw_labels.ndim == 2:
            self.labels = raw_labels.argmax(axis=1)
        else:
            self.labels = raw_labels

        # A length mismatch would otherwise surface as an IndexError (or a
        # silently wrong pairing) deep inside a DataLoader worker.
        if len(self.labels) != len(self.sprites):
            raise ValueError(
                f'images ({len(self.sprites)}) and labels ({len(self.labels)}) differ in length'
            )

        self.image_size = image_size
        steps = [
            ToPILImage(),
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomResizedCrop(image_size, scale=(0.9, 1.0)),
            transforms.ColorJitter(0.1, 0.1, 0.1, 0.05),
            transforms.ToTensor(),
        ]
        if normalize:
            steps.append(transforms.Normalize((0.5,) * 3, (0.5,) * 3))
        self.transform = transforms.Compose(steps)

    def __len__(self):
        return len(self.sprites)

    def __getitem__(self, idx):
        img = self.transform(self.sprites[idx])
        label = int(self.labels[idx])
        return img, label


## Noise Generation

In [3]:
def linear_beta_schedule(timesteps):
    """
    Linear beta schedule (the one proposed in the original paper).

    The endpoints 0.0001 and 0.02 are multiplied by ``1000 / timesteps``;
    the result is a float64 tensor of ``timesteps`` evenly spaced values.
    """
    rescale = 1000 / timesteps
    first_beta, last_beta = rescale * 0.0001, rescale * 0.02
    return torch.linspace(first_beta, last_beta, timesteps, dtype=torch.float64)

# e.g. from Nichol & Dhariwal’s “improved DDPM”:
def cosine_beta_schedule(timesteps, s=0.008):
    """
    Cosine beta schedule.

    Builds the cumulative alpha product as a squared cosine over
    ``timesteps + 1`` points, normalises it by its first value and derives
    each beta from the ratio of consecutive products.

    The computation is done in float64, like ``linear_beta_schedule``, so
    both schedules hand the same precision to whatever derives further
    coefficients from them. Betas are clamped to [0, 0.999]: the upper bound
    keeps the last step (where the product approaches zero) below 1, the
    lower bound guards against a tiny negative value from rounding.

    Args:
        timesteps: number of diffusion steps.
        s: small offset added inside the cosine.

    Returns:
        float64 tensor of shape [timesteps].
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return betas.clamp(min=0, max=0.999)

def extract(a, t, x_shape):
    """
    Look up ``a`` at the indices ``t`` and make the result broadcastable.

    The gathered values are reshaped to ``[b, 1, ..., 1]`` where ``b`` is the
    first dimension of ``t`` and the number of trailing ones is
    ``len(x_shape) - 1``.
    """
    batch, *_unused = t.shape
    picked = a.gather(-1, t)
    trailing_ones = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch, *trailing_ones)

## Network modules (test)

In [None]:
class TransformerBackbone(nn.Module):
    """Token-based Transformer encoder that maps an image back to an image.

    The input is tokenised by ``ConvPatchEmbedding``, a learned timestep
    vector is added to every token, the sequence runs through a stack of
    ``nn.TransformerEncoderLayer`` and each token is projected to a
    ``patch_size × patch_size × in_channels`` patch. The patches are then
    tiled on an ``(img_size // patch_size)``-square grid, so the token count
    must equal that grid's area.
    """

    def __init__(self, in_channels=3, img_size=64, patch_size=2, stride=1, emb_dim=512, depth=6, num_heads=8, time_emb_dim=128):
        super().__init__()
        # Plain settings reused when the tokens are folded back into an image.
        self.channels = in_channels
        self.patch_size = patch_size
        self.img_size = img_size

        # Image -> token sequence.
        self.embed = ConvPatchEmbedding(
            in_channels,
            emb_dim,
            img_size,
            patch_size=patch_size,
            stride=stride,
        )

        # Timestep -> vector of width emb_dim.
        self.time_embedding = SinusoidalPosEmb(time_emb_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        )

        layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=1024, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)

        # Kept as a one-element Sequential so parameter names stay unchanged.
        patch_dim = patch_size * patch_size * in_channels
        self.output_proj = nn.Sequential(
            nn.Linear(emb_dim, patch_dim),
        )

    def forward(self, x, t):
        """Return an image-shaped tensor computed from image ``x`` and timesteps ``t``."""
        batch = x.size(0)
        tokens = self.embed(x)

        # One timestep vector per sample, broadcast over the token axis.
        time_cond = self.time_mlp(self.time_embedding(t)).unsqueeze(1)
        tokens = tokens + time_cond

        hidden = self.encoder(tokens)
        flat_patches = self.output_proj(hidden)
        patches = flat_patches.view(batch, -1, self.channels, self.patch_size, self.patch_size)

        # Tile the per-token patches back onto a square grid.
        grid = self.img_size // self.patch_size
        return rearrange(patches, 'b (h w) c ph pw -> b c (h ph) (w pw)', h=grid, w=grid)
    
class ConvSkipViT(nn.Module):
    """``TransformerBackbone`` sandwiched between two 3×3 convolutions.

    A first convolution lifts the input from ``in_channels`` to ``emb_dim``
    channels, the backbone processes that feature map, and a second
    convolution projects back to ``in_channels``. The original input is added
    to the result as a residual.
    """

    def __init__(
        self,
        in_channels=3,
        img_size=64,
        patch_size=2,
        stride=1,
        emb_dim=512,
        depth=6,
        num_heads=8,
        time_emb_dim=128,
    ):
        super().__init__()
        # Attributes read by the diffusion wrapper.
        self.channels = in_channels
        self.image_size = img_size

        # in_channels -> emb_dim; padding=1 keeps H×W for a 3×3 kernel.
        self.conv_in = nn.Conv2d(in_channels, emb_dim, kernel_size=3, padding=1)

        # The backbone works on the lifted feature map, so its own
        # "in_channels" is emb_dim.
        backbone_cfg = dict(
            in_channels=emb_dim,
            img_size=img_size,
            patch_size=patch_size,
            stride=stride,
            emb_dim=emb_dim,
            depth=depth,
            num_heads=num_heads,
            time_emb_dim=time_emb_dim,
        )
        self.vit = TransformerBackbone(**backbone_cfg)

        # emb_dim -> in_channels.
        self.conv_out = nn.Conv2d(emb_dim, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        """
        x: image batch, t: timesteps (documented upstream as [B, 3, H, W] and [B]).
        Returns ``conv_out(vit(conv_in(x), t)) + x``.
        """
        lifted = self.conv_in(x)
        refined = self.vit(lifted, t)
        return self.conv_out(refined) + x

## Diffusion process

In [4]:
# normalization functions
def normalize_to_neg_one_to_one(img):
    """Return ``img * 2 - 1`` (maps the interval [0, 1] onto [-1, 1])."""
    doubled = img * 2
    return doubled - 1

def unnormalize_to_zero_to_one(t):
    """Return ``(t + 1) * 0.5`` (maps the interval [-1, 1] onto [0, 1])."""
    shifted = t + 1
    return shifted * 0.5

class GaussianDiffusion(nn.Module):
    def __init__(
        self,
        model: nn.Module,
        image_size: int, 
        timesteps: int,
        beta_schedule: str = 'linear',
        auto_normalize: bool = True,
        lambda_l1: float = 0.1,
    ):
        super().__init__()

        self.model = model
        self.channels = model.channels
        self.image_size = image_size
        self.num_timesteps = timesteps
        self.lambda_l1 = lambda_l1
        self.num_classes = model.num_classes

        # 1) build beta schedule
        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f'Unknown beta schedule: {beta_schedule}')
        
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        # 2) register all buffers
        self.register_buffer('betas', betas.float())
        self.register_buffer('alphas_cumprod', alphas_cumprod.float())
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev.float())
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod).float())
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - alphas_cumprod).float())
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1.0 / alphas_cumprod).float())
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1.0 / alphas_cumprod - 1.0).float())

        # posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance.float())
        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)).float())
        self.register_buffer('posterior_mean_coef1', (betas * torch.sqrt(alphas_cumprod_prev) / (1 - alphas_cumprod)).float())
        self.register_buffer('posterior_mean_coef2', ((1 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1 - alphas_cumprod)).float())

        # 3) Normalization functions
        self.normalize = normalize_to_neg_one_to_one if auto_normalize else (lambda x: x)
        self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else (lambda x: x)

    def dynamic_threshold(self, x0, percentile=0.995):
        """ Clips x0 dynamically to avoid over/under-exposed outputs """
        s = torch.quantile(x0.abs().flatten(1), percentile, dim=1)
        s = torch.maximum(s, torch.ones_like(s))[:, None, None, None]
        return x0.clamp(-s, s) / s

    def predict_start_from_noise(self, x_t, t, noise):
        """Estimate x_0 from x_t and predicted noise"""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )
    
    def q_posterior(self, x_start, x_t, t):
        """ Compute mean & variance of q(x_{t-1}) | x_t, x_0 """
        mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        var = extract(self.posterior_variance, t, x_t.shape)
        log_var = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return mean, var, log_var
    
    @torch.no_grad()
    def p_sample(self, x, t: int, clip_denoised: bool = True):
        """ One diffusion reverse step """
        b = x.shape[0]
        t_batch = torch.full((b,), t, device=x.device, dtype=torch.long)
        # model predicts noise
        pred_noise = self.model(x, t_batch)
        # recover x0; optionally clamp it
        x0 = self.predict_start_from_noise(x, t_batch, pred_noise)
        x0 = self.dynamic_threshold(x0)

        # posterior mean & variance
        mean, _, log_var = self.q_posterior(x0, x, t_batch)
        noise = torch.randn_like(x) if t > 0 else 0.0
        return mean + torch.exp(0.5 * log_var) * noise

    @torch.no_grad()
    def p_sample_loop(self, shape, return_all_timesteps: bool = False):
        img = torch.randn(shape, device=self.betas.device)
        if return_all_timesteps:
            all_imgs = [img]
            for t in tqdm(reversed(range(self.num_timesteps)), desc='sampling'):
                img = self.p_sample(img, t)
                all_imgs.append(img)
            out = torch.stack(all_imgs, dim=1)
        else:
            for t in tqdm(reversed(range(self.num_timesteps)), desc='sampling'):
                img = self.p_sample(img, t)
            out = img
        return self.unnormalize(out)
    
    @torch.no_grad()
    def p_sample_cfg(self, x, t, labels, guidance_scale):
        """
        x:       [B, C, H, W]
        t:       integer timestep
        labels:  LongTensor [B] with the class IDs
        """
        b = x.shape[0]
        device = x.device

        # 1) duplicate inputs for unconditional + conditional
        x_in = torch.cat([x, x], dim=0)
        t_in = torch.full((2*b,), t, device=device, dtype=torch.long)
        # first half: uncond (we pass dummy label, e.g. zeros)
        # second half: real labels
        lbl_uncond = torch.full((b,), self.num_classes, device=device, dtype=torch.long)
        lbl_cond   = labels
        labels_in  = torch.cat([lbl_uncond, lbl_cond], dim=0)

        # 2) predict noise for both branches
        eps_all = self.model(x_in, t_in, labels_in)  # model must accept labels!
        eps_uncond, eps_cond = eps_all.chunk(2, dim=0)

        # 3) fuse via CFG
        eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

        # 4) standard posterior
        x0 = self.predict_start_from_noise(x, t_in[:b], eps)
        x0 = self.dynamic_threshold(x0)
        mean, _, log_var = self.q_posterior(x0, x, t_in[:b])
        noise = torch.randn_like(x) if t > 0 else 0.0
        return mean + torch.exp(0.5 * log_var) * noise

    @torch.no_grad()
    def p_sample_cfg_loop(self, shape, labels, guidance_scale=0.0):
        img = torch.randn(shape, device=self.betas.device)
        for t in reversed(range(self.num_timesteps)):
            if guidance_scale > 0:
                img = self.p_sample_cfg(img, t, labels, guidance_scale)
            else:
                img = self.p_sample(img, t)
        return self.unnormalize(img)
    
    @torch.no_grad()
    def p_sample_ddim(self, x, t, t_prev, eta=0.0, clip_denoised=True):
        """ One DDIM sampling step from x_t to x_{t_prev}"""
        b = x.shape[0]
        t_batch = torch.full((b, ), t, device=x.device, dtype=torch.long)
        # predict noise
        pred_noise = self.model(x, t_batch)
        # pred x0
        x0 = self.predict_start_from_noise(x, t_batch, pred_noise)
        x0 = self.dynamic_threshold(x0)

        alpha_t = self.alphas_cumprod[t]
        alpha_prev = self.alphas_cumprod[t_prev] if t_prev >= 0 else torch.tensor(1.0, device=x.device)
        sqrt_alpha_t = alpha_t.sqrt()
        sqrt_alpha_prev = alpha_prev.sqrt()
        sigma_t = eta * ((1 - alpha_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_prev)).sqrt()
        pred_dir = (1 - alpha_prev - sigma_t ** 2).sqrt() * pred_noise
        noise = sigma_t * torch.randn_like(x) if t_prev >= 0 else 0.0
        x_prev = sqrt_alpha_prev * x0 + pred_dir + noise

        return x_prev, x0
    


    @torch.no_grad()
    def ddim_sample_loop(self, shape, num_ddim_steps=50, eta=0.0, return_all_timesteps=False):
        """ Run the full DDIM sampling loop """
        device = self.betas.device
        img = torch.randn(shape, device=device)
        all_imgs = [img]

        # Create a custom timestep schedule
        ddim_timesteps = np.linspace(0, self.num_timesteps - 1, num_ddim_steps, dtype=int)

        for i in tqdm(range(num_ddim_steps - 1, -1, -1), desc='DDIM sampling'):
            t = ddim_timesteps[i]
            t_prev = ddim_timesteps[i - 1] if i > 0 else -1

            img, _ = self.p_sample_ddim(img, t, t_prev, eta=eta)
            all_imgs.append(img)
        
        out = torch.stack(all_imgs, dim=1) if return_all_timesteps else img
        return self.unnormalize(out)

    def sample(self, batch_size=16, use_ddim=False, use_cfg=False, num_ddim_steps=50, eta=0.0, return_all_timesteps=False, labels=None, guidance_scale=0.0):
        shape = (batch_size, self.channels, self.image_size, self.image_size)
        if use_ddim:
            return self.ddim_sample_loop(shape, num_ddim_steps=num_ddim_steps, eta=eta, return_all_timesteps=return_all_timesteps)
        elif use_cfg:
            if labels is None:
                labels = torch.zeros(batch_size, dtype=torch.long, device=self.betas.device)
            return self.p_sample_cfg_loop(shape, labels, guidance_scale)
        else:
            return self.p_sample_loop(shape, return_all_timesteps=return_all_timesteps)
    
    def q_sample(self, x_start, t, noise=None):
        """ Forward noising process """
        if noise is None:
            noise = torch.randn_like(x_start)
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )
    
    def p_losses(self, x_start, t, noise=None):
        """ MSE + L1 loss between true noise and model's prediction """
        if noise is None:
            noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start, t, noise)
        pred_noise = self.model(x_noisy, t)
        mse_loss = F.mse_loss(pred_noise, noise, reduction='none')
        mse_loss = mse_loss.mean(dim=list(range(1, mse_loss.ndim))) # mean over c, h, w

        x_recon = self.predict_start_from_noise(x_noisy, t, pred_noise)
        x_recon = x_recon.clamp(-1., 1.)
        l1_loss = F.l1_loss(x_recon, x_start, reduction='none')
        l1_loss = l1_loss.mean(dim=list(range(1, l1_loss.ndim)))

        loss = mse_loss + self.lambda_l1 * l1_loss
        # loss = mse_loss
        return loss.mean()
    
    def forward(self, img, labels=None):
        """ Training entrypoint: sample random timesteps & return loss """
        b, c, h, w = img.shape
        assert h == self.image_size and w == self.image_size
        t = torch.randint(0, self.num_timesteps, (b, ), device=img.device)
        img = self.normalize(img)
        return self.p_losses(img, t)

## Trainer

In [5]:
def num_to_groups(num, divisor):
    """Split ``num`` into chunks of size ``divisor`` plus a final remainder.

    Example: ``num_to_groups(25, 16)`` gives ``[16, 9]``.
    """
    full_groups, leftover = divmod(num, divisor)
    sizes = [divisor for _ in range(full_groups)]
    if leftover > 0:
        sizes.append(leftover)
    return sizes

def cycle(dl):
    """Yield items from ``dl`` forever, re-iterating it each time it is exhausted."""
    while True:
        yield from dl

class Trainer:
    """Training, checkpointing and sampling driver for a diffusion model.

    Wraps the model and optimizer with ``Accelerator`` (fp16 mixed precision),
    keeps an EMA copy of the unwrapped model for sampling, and cycles
    endlessly over a ``CustomDataset`` loader.

    Args:
        diffusion_model: module whose forward returns the training loss and
            which exposes ``image_size`` and ``sample(...)``.
        data_folder: path of the images .npy file given to ``CustomDataset``.
        label_folder: path of the labels .npy file given to ``CustomDataset``.
        batch_size: loader batch size; also the chunk size when sampling.
        lr: Adam learning rate.
        num_steps: number of optimizer steps; also the cosine schedule length.
        grad_accum_steps: micro-batches accumulated per optimizer step.
        ema_decay: EMA ``beta``.
        save_interval: sample and checkpoint every this many steps.
        num_samples: images generated at each interval.
        results_folder: directory for checkpoints and sample grids.
        use_ddim, num_ddim_steps, eta: forwarded to ``sample``.
        use_cfg, guidance_scale: classifier-free-guidance settings used for
            periodic sampling and for ``inference``.
    """

    def __init__(
        self,
        diffusion_model: nn.Module,
        data_folder: str,
        label_folder: str,
        batch_size: int = 16,
        lr: float = 1e-4,
        num_steps: int = 100000,
        grad_accum_steps: int = 1,
        ema_decay: float = 0.995,
        save_interval: int = 1000,
        num_samples: int = 25,
        results_folder: str = './results',
        use_ddim = False,
        num_ddim_steps=50,
        eta = 0.0,
        use_cfg = False,
        guidance_scale=2.0
    ):
        # Accelerator
        self.accelerator = Accelerator(mixed_precision='fp16')
        self.device = self.accelerator.device

        # Training State
        self.batch_size       = batch_size
        self.grad_accum_steps = grad_accum_steps
        self.num_steps        = num_steps
        self.save_interval    = save_interval
        self.num_samples      = num_samples

        # model, optimizer, EMA
        self.model = diffusion_model.to(self.device)
        self.optimizer = Adam(self.model.parameters(), lr=lr)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=self.num_steps, eta_min=1e-6)
        self.model, self.optimizer = self.accelerator.prepare(
            self.model, self.optimizer
        )

        # Use EMA on the raw model
        self.ema = EMA(self.accelerator.unwrap_model(self.model), beta=ema_decay)

        # Data
        ds = CustomDataset(data_folder, label_folder, diffusion_model.image_size)
        # os.cpu_count() may return None; fall back to in-process loading.
        dl = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=os.cpu_count() or 0)
        self.dl = cycle(self.accelerator.prepare(dl))

        # checkpoints & samples
        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(parents=True, exist_ok=True)
        self.step = 0
        self.use_ddim = use_ddim
        self.num_ddim_steps = num_ddim_steps
        self.eta = eta
        self.use_cfg = use_cfg
        self.guidance_scale = guidance_scale

    def save(self, milestone: int):
        """Save model, optimizer, LR scheduler, EMA and step counter (main process only)."""
        if not self.accelerator.is_main_process:
            return
        ckpt = self.results_folder / f'model-{milestone}.pt'
        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.model),
            'optimizer': self.optimizer.state_dict(),
            # Without this a resumed run restarts the cosine schedule.
            'scheduler': self.scheduler.state_dict(),
            'ema':       self.ema.state_dict(),
        }
        torch.save(data, ckpt)

    def load(self, ckpt_path: str):
        """Load all state (model, optimizer, LR scheduler, EMA, step).

        Checkpoints written before the scheduler was stored have no
        ``'scheduler'`` entry; they load as before and leave the scheduler
        untouched.
        """
        data = torch.load(ckpt_path, map_location=self.device)
        raw_model = self.accelerator.unwrap_model(self.model)
        raw_model.load_state_dict(data['model'])
        self.optimizer.load_state_dict(data['optimizer'])
        if 'scheduler' in data:
            self.scheduler.load_state_dict(data['scheduler'])
        self.ema.load_state_dict(data['ema'])
        self.step = data['step']

    def _sampler_kwargs(self):
        """Sampler options shared by every call to ``ema_model.sample``."""
        return dict(use_ddim=self.use_ddim, num_ddim_steps=self.num_ddim_steps, eta=self.eta)

    def _sample_and_save(self, milestone: int, guidance_scale: float = 2.0, use_cfg=False):
        """Generate ``num_samples`` images with the EMA model and save them as a grid.

        With ``use_cfg`` all images share one class that rotates with the
        milestone; the class and guidance scale are part of the file name.
        """
        self.ema.ema_model.eval()
        batches = num_to_groups(self.num_samples, self.batch_size)
        if use_cfg:
            # Rotate over the model's classes; 5 is the fallback used when
            # the model does not report a class count.
            n_classes = getattr(self.ema.ema_model, 'num_classes', None) or 5
            cls = milestone % n_classes
            labels = torch.full((self.num_samples,), cls, dtype=torch.long, device=self.device)
            imgs = []
            offset = 0
            for n in batches:
                imgs.append(
                    self.ema.ema_model.sample(
                        batch_size=n,
                        use_cfg=True,
                        labels=labels[offset:offset+n],
                        guidance_scale=guidance_scale,
                        **self._sampler_kwargs()
                    )
                )
                offset += n
            imgs = torch.cat(imgs, dim=0)
            path = self.results_folder / f'sample-{milestone}-cls{cls}-gs{guidance_scale}.png'
        else:
            imgs = torch.cat([
                self.ema.ema_model.sample(batch_size=n, **self._sampler_kwargs()) for n in batches
            ], dim=0)
            path = self.results_folder / f'sample-{milestone}.png'
        vutils.save_image(imgs, path, nrow=int(math.sqrt(self.num_samples)))

    def train(self):
        """Run the training loop with gradient accumulation, EMA updates, and periodic sampling."""
        pbar = tqdm(total=self.num_steps, initial=self.step, disable=True)
        while self.step < self.num_steps:
            total_loss = 0.0

            # gradient accumulation: each micro-batch loss is pre-divided so
            # the accumulated gradient matches one large batch
            for _ in range(self.grad_accum_steps):
                imgs, labels = next(self.dl)
                imgs, labels = imgs.to(self.device), labels.to(self.device)
                with self.accelerator.autocast():
                    loss = self.model(imgs, labels) / self.grad_accum_steps
                total_loss += loss.item()
                self.accelerator.backward(loss)

            # optimizer step
            self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()
            self.step += 1

            # EMA & sampling
            if self.accelerator.is_main_process:
                self.ema.update()
                if self.step % self.save_interval == 0:
                    milestone = self.step // self.save_interval
                    self._sample_and_save(milestone, self.guidance_scale, self.use_cfg)
                    self.save(milestone)
            if self.step % 10 == 0 and self.accelerator.is_main_process:
                print(f"[Step {self.step}] loss: {total_loss:.4f}")

            pbar.set_description(f'loss: {total_loss:.4f}')
            pbar.update(1)

        pbar.close()
        if self.accelerator.is_main_process:
            print('Training complete.')

    def inference(self, total: int = 1000, target_class: int=0, guidance_scale: float = 2.0, output_path: str = './submission'):
        """
        Generate `total` images with the EMA model in chunks of
        `self.batch_size` and save them as `<output_path>/<index>.jpg`
        (indices start at 1).

        When the trainer was built with `use_cfg=True`, every image is
        requested for `target_class` with `guidance_scale`; otherwise both
        arguments are unused and sampling is unconditional. Note that
        `sample` gives DDIM precedence, so with `use_ddim=True` the class
        settings have no effect.
        """
        Path(output_path).mkdir(parents=True, exist_ok=True)
        count = 0
        batch_size = self.batch_size

        with torch.no_grad():
            while count < total:
                n = min(batch_size, total - count)

                # clear any leftover GPU memory
                if torch.cuda.is_available():
                    gc.collect()
                    torch.cuda.empty_cache()

                labels = None
                if self.use_cfg:
                    labels = torch.full((n,), target_class, dtype=torch.long, device=self.device)

                # sample n images
                imgs = self.ema.ema_model.sample(
                    batch_size=n,
                    use_cfg=self.use_cfg,
                    labels=labels,
                    guidance_scale=guidance_scale,
                    **self._sampler_kwargs()
                )

                # move to CPU and save
                imgs = imgs.cpu()
                for img in imgs:
                    count += 1
                    vutils.save_image(img, f"{output_path}/{count}.jpg")

                # clean up
                del imgs
                if torch.cuda.is_available():
                    gc.collect()
                    torch.cuda.empty_cache()

        print("Inference complete.")

## Network modules 

## UNetTransformer

In [6]:
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D batch of timesteps.

    Produces ``dim // 2`` sine features followed by ``dim // 2`` cosine
    features, using frequencies that decay geometrically from 1 to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half_dim = self.dim // 2
        positions = torch.arange(half_dim, device=t.device)
        # Log-spaced step between consecutive frequencies.
        log_step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(positions * -log_step)
        # Outer product: one row per timestep, one column per frequency.
        angles = t[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)  # [B, 2 * half_dim]

class TransformerBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        emb_dim: int,
        num_heads: int,
        dim_feedforward: int | None = None,
        dropout: float = 0.0
    ):
        super().__init__()
        self.input_proj = (
            nn.Conv2d(in_channels, emb_dim, kernel_size=1)
            if in_channels != emb_dim else nn.Identity()
        )

        self.self_attn = nn.MultiheadAttention(
            embed_dim=emb_dim, num_heads=num_heads,
            dropout=dropout, batch_first=True
        )

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=emb_dim, num_heads=num_heads,
            dropout=dropout, batch_first=True
        )

        self.norm1 = nn.LayerNorm(emb_dim)
        self.norm2 = nn.LayerNorm(emb_dim)
        self.norm3 = nn.LayerNorm(emb_dim)

        ff_dim = dim_feedforward or (emb_dim * 4)
        self.ff = nn.Sequential(
            nn.Linear(emb_dim, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, emb_dim),
        )

        self.output_proj = (
            nn.Conv2d(emb_dim, in_channels, kernel_size=1)
            if in_channels != emb_dim else nn.Identity()
        )

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor, label_emb: torch.Tensor | None = None):
        B, C, H, W = x.shape
        x0 = self.input_proj(x)                   # [B, emb_dim, H, W]
        t = t_emb[:, :, None, None]
        x0 = x0 + t

        x_flat = rearrange(x0, 'b c h w -> b (h w) c')  # [B, N, C]
        x1, _ = self.self_attn(x_flat, x_flat, x_flat, need_weights=False)
        x1 = self.norm1(x_flat + x1)

        if label_emb is not None:
            label_emb = label_emb.unsqueeze(1)         # [B, 1, emb_dim]
            x2, _ = self.cross_attn(x1, label_emb, label_emb, need_weights=False)
            x2 = self.norm2(x1 + x2)
        else:
            x2 = x1

        ff_out = self.ff(x2)
        x3 = self.norm3(x2 + ff_out)

        x3 = rearrange(x3, 'b (h w) c -> b c h w', h=H, w=W)
        return self.output_proj(x3)

class CBAM(nn.Module):  # Convolutional Block Attention Module
    """Channel gate followed by a spatial gate.

    The channel gate squeezes the map with global average pooling and passes
    it through a two-layer 1×1-conv bottleneck ending in a sigmoid. The
    spatial gate applies a ``kernel_size`` convolution and a sigmoid to the
    concatenated per-pixel mean and max over channels. Each gate multiplies
    the features it was computed from.

    Args:
        channels: number of input (and output) channels.
        reduction: bottleneck ratio of the channel gate. The bottleneck width
            is ``channels // reduction`` but never less than 1, so inputs
            with fewer than ``reduction`` channels still get a working gate
            instead of a zero-width layer.
        kernel_size: kernel size of the spatial-gate convolution.
    """

    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        hidden = max(1, channels // reduction)
        self.channel_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, 1, bias=False),
            nn.Sigmoid()
        )
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Channel gate: one weight per channel.
        ca = self.channel_attn(x)
        x = x * ca

        # Spatial gate: one weight per position, from channel-wise mean and max.
        avg_pool = torch.mean(x, dim=1, keepdim=True)
        max_pool, _ = torch.max(x, dim=1, keepdim=True)
        sa = self.spatial_attn(torch.cat([avg_pool, max_pool], dim=1))
        return x * sa

class DoubleConv(nn.Module):
    """Two successive 3×3 convs, each followed by GELU, then a CBAM gate."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # padding=1 keeps the spatial size unchanged for 3×3 kernels.
        layers = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.GELU(),
        ]
        self.conv = nn.Sequential(*layers)
        self.cbam = CBAM(out_ch)

    def forward(self, x):
        return self.cbam(self.conv(x))


class UNetTransformer(nn.Module):
    """U-Net style network with a TransformerBlock after every conv stage.

    The encoder has four DoubleConv stages (widths C, 2C, 4C, 8C with
    C = ``base_channels``); stages 2-4 are each followed by a stride-2 conv.
    The decoder mirrors this with three transposed-conv upsampling stages,
    each concatenated with the pre-downsampling DoubleConv output of the
    matching encoder stage. Every TransformerBlock receives the time embedding
    and the label embedding.

    Args (keyword-only):
        in_channels: channels of the input image; also the output channels.
        base_channels: width C of the first stage.
        emb_dim: size of the time and label embeddings.
        num_heads: head count passed to every TransformerBlock.
        time_emb_dim: size of the sinusoidal timestep encoding fed to the MLP.
        num_classes: number of real classes; one extra embedding row is
            reserved and used when ``labels`` is None.
    """
    def __init__(
        self,
        *,
        in_channels=3,
        base_channels=64,
        emb_dim=256,
        num_heads=8,
        time_emb_dim=128,
        num_classes:int = 5
    ):
        super().__init__()
        self.channels = in_channels
        self.num_classes = num_classes

        # Channel widths of the four resolution levels.
        c1 = base_channels
        c2 = base_channels * 2
        c4 = base_channels * 4
        c8 = base_channels * 8

        # Timestep -> sinusoidal encoding -> two-layer MLP of width emb_dim.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_emb_dim),
            nn.Linear(time_emb_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        )
        # One row per class plus a trailing row used when no label is given.
        self.label_emb = nn.Embedding(num_classes + 1, emb_dim)

        # ---- Encoder ----
        # Stage 1: in_channels -> C (no downsampling)
        self.enc1 = DoubleConv(in_channels, c1)
        self.trans1 = TransformerBlock(c1, emb_dim, num_heads)

        # Stage 2: C -> 2C, then halve the resolution
        self.enc2 = DoubleConv(c1, c2)
        self.down2 = nn.Conv2d(c2, c2, kernel_size=4, stride=2, padding=1)
        self.trans2 = TransformerBlock(c2, emb_dim, num_heads)

        # Stage 3: 2C -> 4C, then halve the resolution
        self.enc3 = DoubleConv(c2, c4)
        self.down3 = nn.Conv2d(c4, c4, kernel_size=4, stride=2, padding=1)
        self.trans3 = TransformerBlock(c4, emb_dim, num_heads)

        # Stage 4: 4C -> 8C, then halve the resolution (bottleneck)
        self.enc4 = DoubleConv(c4, c8)
        self.down4 = nn.Conv2d(c8, c8, kernel_size=4, stride=2, padding=1)
        self.trans4 = TransformerBlock(c8, emb_dim, num_heads)

        # ---- Decoder ----
        # Each stage: upsample, concat the encoder skip, DoubleConv, Transformer.
        # 8C -> 4C upsampled, concat with the 8C skip from enc4 -> 12C -> 4C
        self.up4 = nn.ConvTranspose2d(c8, c4, kernel_size=4, stride=2, padding=1)
        self.dec4 = DoubleConv(c4 + c8, c4)
        self.trans_up3 = TransformerBlock(c4, emb_dim, num_heads)

        # 4C -> 2C upsampled, concat with the 4C skip from enc3 -> 6C -> 2C
        self.up3 = nn.ConvTranspose2d(c4, c2, kernel_size=4, stride=2, padding=1)
        self.dec3 = DoubleConv(c2 + c4, c2)
        self.trans_up2 = TransformerBlock(c2, emb_dim, num_heads)

        # 2C -> C upsampled, concat with the 2C skip from enc2 -> 3C -> C
        self.up2 = nn.ConvTranspose2d(c2, c1, kernel_size=4, stride=2, padding=1)
        self.dec2 = DoubleConv(c1 + c2, c1)
        self.trans_up1 = TransformerBlock(c1, emb_dim, num_heads)

        # Final 1×1 conv back to the image channel count
        self.out_conv = nn.Conv2d(c1, in_channels, kernel_size=1)

    def forward(self, x, t, labels=None):
        """Run the network on images ``x`` at timesteps ``t``.

        When ``labels`` is None every sample uses the last row of the label
        embedding table. Shape comments below assume base_channels=64 and a
        16×16 input.
        """
        # Conditioning is computed once and shared by all transformer blocks.
        t_emb = self.time_mlp(t)  # [B, emb_dim]
        if labels is None:
            null_index = self.label_emb.num_embeddings - 1
            labels = torch.full((x.size(0),), null_index, device=x.device, dtype=torch.long)
        label_emb = self.label_emb(labels)
        cond = (t_emb, label_emb)

        # —— Encode ——
        stage1 = self.trans1(self.enc1(x), *cond)             # [B,  64, 16, 16]

        skip2 = self.enc2(stage1)                             # [B, 128, 16, 16]
        stage2 = self.trans2(self.down2(skip2), *cond)        # [B, 128,  8,  8]

        skip3 = self.enc3(stage2)                             # [B, 256,  8,  8]
        stage3 = self.trans3(self.down3(skip3), *cond)        # [B, 256,  4,  4]

        skip4 = self.enc4(stage3)                             # [B, 512,  4,  4]
        bottleneck = self.trans4(self.down4(skip4), *cond)    # [B, 512,  2,  2]

        # —— Decode ——
        up = self.up4(bottleneck)                             # [B, 256,  4,  4]
        up = self.dec4(torch.cat((up, skip4), dim=1))         # 768 -> 256
        up = self.trans_up3(up, *cond)

        up = self.up3(up)                                     # [B, 128,  8,  8]
        up = self.dec3(torch.cat((up, skip3), dim=1))         # 384 -> 128
        up = self.trans_up2(up, *cond)

        up = self.up2(up)                                     # [B,  64, 16, 16]
        up = self.dec2(torch.cat((up, skip2), dim=1))         # 192 -> 64
        up = self.trans_up1(up, *cond)

        return self.out_conv(up)                              # [B,   3, 16, 16]
    
# Instantiate the conv + transformer U-Net. All values equal the constructor
# defaults except num_heads, which is lowered from 8 to 4.
# NOTE(review): a later cell rebinds `model` to a UVitPixelArt instance, so
# whichever of the two cells ran last determines what is trained.
model = UNetTransformer(
    in_channels=3,
    base_channels=64,
    emb_dim=256,
    num_heads=4,
    time_emb_dim=128,
    num_classes=5
)

## UVitPixelArt

In [None]:
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal encoding of a batch of scalar timesteps.

    Frequencies are spaced geometrically from 1 down to 1/10000. The output is
    the concatenation ``[sin(t * f), cos(t * f)]`` over ``dim // 2``
    frequencies; for an odd ``dim`` a single zero column is appended so the
    result always has exactly ``dim`` columns.

    Args:
        dim: size of the returned embedding.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Encode ``t`` of shape [B] into a tensor of shape [B, dim]."""
        half_dim = self.dim // 2
        # With a single frequency (dim 2 or 3) `half_dim - 1` is 0; clamp the
        # denominator so the exponent is simply 0 instead of dividing by zero.
        log_step = -(math.log(10000) / max(half_dim - 1, 1))
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * log_step)
        angles = t[:, None] * freqs[None, :]
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves only cover dim - 1 columns; pad to the promised width.
            emb = torch.cat((emb, emb.new_zeros(emb.shape[0], 1)), dim=-1)
        return emb  # [B, dim]

class TransformerBlock(nn.Module):
    """Post-norm transformer block: self-attention and a 4×-wide GELU MLP.

    Each sub-layer is wrapped as ``LayerNorm(x + sublayer(x))``. Tokens are
    batch-first, i.e. ``x`` is [B, N, dim].

    Args:
        dim: token width; must be divisible by ``heads``.
        heads: number of attention heads.
    """
    def __init__(self, dim, heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, t_emb=None):
        """Transform tokens ``x``; if given, ``t_emb`` [B, dim] is added to every token first."""
        if t_emb is not None:
            x = x + t_emb.unsqueeze(1)
        # The attention map is never used, so skip computing and averaging it;
        # need_weights=False also lets PyTorch take its fused attention path.
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.norm1(x + attn_out)
        x = self.norm2(x + self.ff(x))
        return x

class UVitPixelArt(nn.Module):
    """U-shaped ViT denoiser for 16×16 images.

    A conv stem and two stride-2 convs produce feature maps at 16×16, 8×8 and
    4×4. Transformer blocks run on the 4×4 tokens (bottleneck), then on 8×8
    and 16×16 tokens in the decoder, with additive skip connections from the
    conv features of the same resolution. The learned positional embeddings
    have fixed lengths (4*4, 8*8, 16*16), so the input must be 16×16.

    Args:
        in_channels: channels of the input image; also the output channels.
        base_channels: width C of the conv stem.
        emb_dim: token width used by all transformer blocks.
        enc8_blocks: number of blocks in ``self.enc8`` (see the NOTE in
            ``forward``: these blocks are currently not part of the data path).
        bottleneck_blocks: number of blocks at 4×4 resolution.
        dec8_blocks: number of decoder blocks at 8×8 resolution.
        dec16_blocks: number of decoder blocks at 16×16 resolution.
        heads: attention heads per block.
        time_dim: size of the sinusoidal timestep encoding fed to the MLP.
        num_classes: if not None, enables label conditioning with
            ``num_classes + 1`` embedding rows; the extra row is used when
            ``labels`` is None.
    """
    def __init__(
        self,
        in_channels=3,
        base_channels=64,
        emb_dim=256,
        enc8_blocks=2,
        bottleneck_blocks=4,
        dec8_blocks=2,
        dec16_blocks=1,
        heads=8,
        time_dim=128,
        num_classes = None
    ):
        super().__init__()
        self.channels = in_channels
        self.num_classes = num_classes
        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes+1, emb_dim)
        else:
            self.label_emb=None
        # time embed
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_dim), nn.Linear(time_dim, emb_dim), nn.GELU(), nn.Linear(emb_dim, emb_dim)
        )
        # stem 16×16
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, 3, padding=1), nn.GELU(),
            nn.Conv2d(base_channels, base_channels, 3, padding=1), nn.GELU()
        )
        # encoder 16→8
        self.down8 = nn.Conv2d(base_channels, base_channels*2, 2, 2)
        # patch8_enc / pos8_enc / enc8 stay registered (same creation order) so
        # existing checkpoints and optimizer states keep loading; see forward().
        self.patch8_enc = nn.Conv2d(base_channels*2, emb_dim, 3, 2, 1)
        # positional embeddings for the 4×4 tokens produced by patch8_enc
        self.pos8_enc = nn.Parameter(torch.randn(1, 4*4, emb_dim))
        # positional embeddings for 8×8 tokens in decoder
        self.pos8_dec = nn.Parameter(torch.randn(1, 8*8, emb_dim))
        self.enc8 = nn.ModuleList([TransformerBlock(emb_dim, heads) for _ in range(enc8_blocks)])
        # encoder 8→4
        self.down4 = nn.Conv2d(base_channels*2, base_channels*4, 2, 2)
        self.patch4 = nn.Conv2d(base_channels*4, emb_dim, 3, 1, 1)
        self.pos4 = nn.Parameter(torch.randn(1, 4*4, emb_dim))
        self.bottleneck = nn.ModuleList([TransformerBlock(emb_dim, heads) for _ in range(bottleneck_blocks)])
        # decoder 4→8
        self.up4 = nn.ConvTranspose2d(emb_dim, base_channels*4, 2, 2)
        self.skip8_proj = nn.Conv2d(base_channels*2, base_channels*4, 1)
        self.patch8_dec = nn.Conv2d(base_channels*4, emb_dim, 3, 1, 1)
        self.dec8 = nn.ModuleList([TransformerBlock(emb_dim, heads) for _ in range(dec8_blocks)])
        # decoder 8→16
        self.up8 = nn.ConvTranspose2d(emb_dim, base_channels*2, 2, 2)
        self.skip16_proj = nn.Conv2d(base_channels, base_channels*2, 1)
        self.patch16 = nn.Conv2d(base_channels*2, emb_dim, 3, 1, 1)
        self.pos16 = nn.Parameter(torch.randn(1, 16*16, emb_dim))
        self.dec16 = nn.ModuleList([TransformerBlock(emb_dim, heads) for _ in range(dec16_blocks)])
        # final conv
        self.final_conv = nn.Sequential(
            nn.Conv2d(emb_dim, base_channels, 3, padding=1), nn.GELU(),
            nn.Conv2d(base_channels, in_channels, 3, padding=1)
        )

    def forward(self, x, t, labels=None):
        """Run the network on 16×16 images ``x`` at timesteps ``t``.

        With label conditioning enabled, the label embedding is summed into
        the time embedding; ``labels=None`` selects the extra embedding row
        (index ``num_classes``). The combined embedding is added to the tokens
        when they are formed and again at the input of every block.
        """
        t_emb = self.time_mlp(t)
        if self.label_emb is not None:
            if labels is None:
                labels = torch.full((x.size(0),), self.num_classes, dtype=torch.long, device=x.device)
            label_emb = self.label_emb(labels)
            t_emb = t_emb + label_emb
        cond = t_emb.unsqueeze(1)               # [B,1,emb], broadcast over tokens

        # 16×16 stem
        x16 = self.stem(x)
        # 16→8 encode
        f8 = self.down8(x16)
        # NOTE(review): the former `tok8` branch (patch8_enc -> pos8_enc ->
        # enc8 blocks) was computed here but its result was never read by any
        # later stage, so it was pure wasted work and is no longer evaluated.
        # The output is unchanged. If those blocks were meant to feed the
        # bottleneck or the 8×8 skip, that wiring still has to be added.
        # 8→4 encode
        f4 = self.down4(f8)
        tok4 = self.patch4(f4)                  # [B,emb,4,4]
        tok4 = rearrange(tok4, 'b c h w -> b (h w) c') + self.pos4 + cond
        for blk in self.bottleneck:
            tok4 = blk(tok4, t_emb)
        # 4→8 decode
        dec4 = rearrange(tok4, 'b (h w) c -> b c h w', h=4, w=4)
        up8 = self.up4(dec4)                    # [B,4C,8,8]
        skip8 = self.skip8_proj(f8)             # [B,4C,8,8]
        x8 = up8 + skip8                        # [B,4C,8,8]
        tok8_d = self.patch8_dec(x8)            # [B,emb,8,8]
        tok8_d = rearrange(tok8_d, 'b c h w -> b (h w) c') + self.pos8_dec + cond
        for blk in self.dec8:
            tok8_d = blk(tok8_d, t_emb)
        # 8→16 decode
        dec8 = rearrange(tok8_d, 'b (h w) c -> b c h w', h=8, w=8)
        up16 = self.up8(dec8)                   # [B,2C,16,16]
        skip16 = self.skip16_proj(x16)          # [B,2C,16,16]
        x16d = up16 + skip16                    # [B,2C,16,16]
        tok16 = self.patch16(x16d)              # [B,emb,16,16]
        tok16 = rearrange(tok16, 'b c h w -> b (h w) c') + self.pos16 + cond
        for blk in self.dec16:
            tok16 = blk(tok16, t_emb)
        # reconstruct and final conv
        dec16 = rearrange(tok16, 'b (h w) c -> b c h w', h=16, w=16)
        return self.final_conv(dec16)

# Instantiate the U-shaped ViT. Every value equals the constructor default
# except num_classes=5, which turns on the label embedding (5 classes plus
# one extra row used when no label is supplied).
model = UVitPixelArt(
    in_channels=3,
    base_channels=64,
    emb_dim=256,
    enc8_blocks=2,
    bottleneck_blocks=4,
    dec8_blocks=2,
    dec16_blocks=1,
    heads=8,
    time_dim=128,
    num_classes=5
)

## Usage

In [7]:
# Collect unreachable Python objects first so their tensors are freed, then
# ask PyTorch to release its unused cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# model = TransformerBackbone(
#     in_channels=3, img_size=16, patch_size=1,
#     emb_dim=1024, depth=8, num_heads=8, time_emb_dim=128
# )


In [8]:
# NOTE(review): GaussianDiffusion and Trainer are defined elsewhere in this
# notebook; what each keyword does is inferred from its name only — confirm
# against those definitions before changing values.

# Wrap whichever `model` is currently bound in the diffusion process
# (16×16 images, 1000 timesteps, 'cosine' beta schedule).
diffusion = GaussianDiffusion(
    model=model, image_size=16, timesteps=1000,
    beta_schedule='cosine', auto_normalize=True,
    lambda_l1=0.0
)

# Training driver. `data_folder` / `label_folder` point at .npy files here,
# not directories. The captured log below shows a "DDIM sampling" progress bar
# right before steps 1000, 2000, ..., which is consistent with sampling every
# `save_interval` steps — presumably; verify in Trainer.
trainer = Trainer(
    diffusion_model=diffusion,
    data_folder = './sprites.npy',
    label_folder = './sprites_labels.npy',
    batch_size=128, lr=1e-4, num_steps=100000,
    grad_accum_steps=1, ema_decay=0.999,
    save_interval=1000, num_samples=25,
    results_folder='./results',
    use_ddim=True, num_ddim_steps=100, eta=0.0,
    use_cfg=True, guidance_scale=2.0
)

In [9]:
# Sum the element counts of every parameter tensor registered on `model`
# (no requires_grad filter) and print the total in millions.
total_params = sum(p.numel() for p in model.parameters())
print("Params: ", total_params/1e6, "M")

Params:  24.032241 M


In [10]:
# Start training; the lines below are the captured output of this call.
trainer.train()

[Step 10] loss: 1.0097
[Step 20] loss: 1.0041
[Step 30] loss: 1.0029
[Step 40] loss: 1.0054
[Step 50] loss: 0.9968
[Step 60] loss: 0.9021
[Step 70] loss: 0.8073
[Step 80] loss: 0.7477
[Step 90] loss: 0.6716
[Step 100] loss: 0.6232
[Step 110] loss: 0.6287
[Step 120] loss: 0.5492
[Step 130] loss: 0.5477
[Step 140] loss: 0.5221
[Step 150] loss: 0.5076
[Step 160] loss: 0.5034
[Step 170] loss: 0.4463
[Step 180] loss: 0.4113
[Step 190] loss: 0.3837
[Step 200] loss: 0.4074
[Step 210] loss: 0.3843
[Step 220] loss: 0.3535
[Step 230] loss: 0.3606
[Step 240] loss: 0.3433
[Step 250] loss: 0.3507
[Step 260] loss: 0.3040
[Step 270] loss: 0.3516
[Step 280] loss: 0.3101
[Step 290] loss: 0.3406
[Step 300] loss: 0.3709
[Step 310] loss: 0.2901
[Step 320] loss: 0.3505
[Step 330] loss: 0.2807
[Step 340] loss: 0.2986
[Step 350] loss: 0.3130
[Step 360] loss: 0.2920
[Step 370] loss: 0.2945
[Step 380] loss: 0.2776
[Step 390] loss: 0.2877
[Step 400] loss: 0.2774
[Step 410] loss: 0.3103
[Step 420] loss: 0.3085
[

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 1000] loss: 0.2208
[Step 1010] loss: 0.2240
[Step 1020] loss: 0.2216
[Step 1030] loss: 0.2264
[Step 1040] loss: 0.2523
[Step 1050] loss: 0.2191
[Step 1060] loss: 0.2176
[Step 1070] loss: 0.1981
[Step 1080] loss: 0.2202
[Step 1090] loss: 0.2301
[Step 1100] loss: 0.2178
[Step 1110] loss: 0.2010
[Step 1120] loss: 0.2157
[Step 1130] loss: 0.2260
[Step 1140] loss: 0.2085
[Step 1150] loss: 0.2116
[Step 1160] loss: 0.2204
[Step 1170] loss: 0.1828
[Step 1180] loss: 0.2092
[Step 1190] loss: 0.1997
[Step 1200] loss: 0.1879
[Step 1210] loss: 0.2156
[Step 1220] loss: 0.1915
[Step 1230] loss: 0.2224
[Step 1240] loss: 0.2352
[Step 1250] loss: 0.1948
[Step 1260] loss: 0.2276
[Step 1270] loss: 0.2099
[Step 1280] loss: 0.1897
[Step 1290] loss: 0.2093
[Step 1300] loss: 0.2180
[Step 1310] loss: 0.2215
[Step 1320] loss: 0.2171
[Step 1330] loss: 0.1858
[Step 1340] loss: 0.2262
[Step 1350] loss: 0.1957
[Step 1360] loss: 0.2262
[Step 1370] loss: 0.2250
[Step 1380] loss: 0.2032
[Step 1390] loss: 0.2021


DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 2000] loss: 0.1974
[Step 2010] loss: 0.1715
[Step 2020] loss: 0.1920
[Step 2030] loss: 0.1834
[Step 2040] loss: 0.1695
[Step 2050] loss: 0.2006
[Step 2060] loss: 0.1806
[Step 2070] loss: 0.2077
[Step 2080] loss: 0.1793
[Step 2090] loss: 0.1847
[Step 2100] loss: 0.1689
[Step 2110] loss: 0.1784
[Step 2120] loss: 0.1939
[Step 2130] loss: 0.1887
[Step 2140] loss: 0.1763
[Step 2150] loss: 0.2099
[Step 2160] loss: 0.1795
[Step 2170] loss: 0.1978
[Step 2180] loss: 0.1659
[Step 2190] loss: 0.1669
[Step 2200] loss: 0.1971
[Step 2210] loss: 0.1812
[Step 2220] loss: 0.1955
[Step 2230] loss: 0.2075
[Step 2240] loss: 0.1921
[Step 2250] loss: 0.1846
[Step 2260] loss: 0.1778
[Step 2270] loss: 0.1903
[Step 2280] loss: 0.1695
[Step 2290] loss: 0.1877
[Step 2300] loss: 0.2047
[Step 2310] loss: 0.1781
[Step 2320] loss: 0.1739
[Step 2330] loss: 0.1917
[Step 2340] loss: 0.1719
[Step 2350] loss: 0.1897
[Step 2360] loss: 0.1682
[Step 2370] loss: 0.1800
[Step 2380] loss: 0.1944
[Step 2390] loss: 0.1628


DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 3000] loss: 0.1833
[Step 3010] loss: 0.1556
[Step 3020] loss: 0.1660
[Step 3030] loss: 0.1593
[Step 3040] loss: 0.1705
[Step 3050] loss: 0.1821
[Step 3060] loss: 0.1748
[Step 3070] loss: 0.1886
[Step 3080] loss: 0.1833
[Step 3090] loss: 0.1708
[Step 3100] loss: 0.1630
[Step 3110] loss: 0.1543
[Step 3120] loss: 0.1879
[Step 3130] loss: 0.1827
[Step 3140] loss: 0.1661
[Step 3150] loss: 0.1618
[Step 3160] loss: 0.2103
[Step 3170] loss: 0.1504
[Step 3180] loss: 0.1944
[Step 3190] loss: 0.1675
[Step 3200] loss: 0.1672
[Step 3210] loss: 0.1693
[Step 3220] loss: 0.1663
[Step 3230] loss: 0.1999
[Step 3240] loss: 0.1665
[Step 3250] loss: 0.1839
[Step 3260] loss: 0.1890
[Step 3270] loss: 0.1675
[Step 3280] loss: 0.1887
[Step 3290] loss: 0.1602
[Step 3300] loss: 0.1870
[Step 3310] loss: 0.1726
[Step 3320] loss: 0.1665
[Step 3330] loss: 0.1701
[Step 3340] loss: 0.1762
[Step 3350] loss: 0.1868
[Step 3360] loss: 0.1754
[Step 3370] loss: 0.1834
[Step 3380] loss: 0.1655
[Step 3390] loss: 0.1649


DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 4000] loss: 0.1739
[Step 4010] loss: 0.1535
[Step 4020] loss: 0.1661
[Step 4030] loss: 0.1697
[Step 4040] loss: 0.1696
[Step 4050] loss: 0.1653
[Step 4060] loss: 0.1642
[Step 4070] loss: 0.1682
[Step 4080] loss: 0.1826
[Step 4090] loss: 0.1635
[Step 4100] loss: 0.1695
[Step 4110] loss: 0.1446
[Step 4120] loss: 0.1760
[Step 4130] loss: 0.1567
[Step 4140] loss: 0.1676
[Step 4150] loss: 0.1569
[Step 4160] loss: 0.1463
[Step 4170] loss: 0.1497
[Step 4180] loss: 0.1512
[Step 4190] loss: 0.1458
[Step 4200] loss: 0.1623
[Step 4210] loss: 0.1635
[Step 4220] loss: 0.1571
[Step 4230] loss: 0.1637
[Step 4240] loss: 0.1547
[Step 4250] loss: 0.1543
[Step 4260] loss: 0.1554
[Step 4270] loss: 0.1784
[Step 4280] loss: 0.1702
[Step 4290] loss: 0.1559
[Step 4300] loss: 0.1765
[Step 4310] loss: 0.1671
[Step 4320] loss: 0.1517
[Step 4330] loss: 0.1855
[Step 4340] loss: 0.1685
[Step 4350] loss: 0.1620
[Step 4360] loss: 0.1440
[Step 4370] loss: 0.1333
[Step 4380] loss: 0.1835
[Step 4390] loss: 0.1551


DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 5000] loss: 0.1465
[Step 5010] loss: 0.1561
[Step 5020] loss: 0.1775
[Step 5030] loss: 0.1577
[Step 5040] loss: 0.1646
[Step 5050] loss: 0.1541
[Step 5060] loss: 0.1360
[Step 5070] loss: 0.1835
[Step 5080] loss: 0.1571
[Step 5090] loss: 0.1599
[Step 5100] loss: 0.1266
[Step 5110] loss: 0.1564
[Step 5120] loss: 0.1578
[Step 5130] loss: 0.1587
[Step 5140] loss: 0.1661
[Step 5150] loss: 0.1502
[Step 5160] loss: 0.1802
[Step 5170] loss: 0.1380
[Step 5180] loss: 0.1596
[Step 5190] loss: 0.1467
[Step 5200] loss: 0.1682
[Step 5210] loss: 0.1501
[Step 5220] loss: 0.1475
[Step 5230] loss: 0.1408
[Step 5240] loss: 0.1264
[Step 5250] loss: 0.1430
[Step 5260] loss: 0.1376
[Step 5270] loss: 0.1504
[Step 5280] loss: 0.1588
[Step 5290] loss: 0.1632
[Step 5300] loss: 0.1779
[Step 5310] loss: 0.1410
[Step 5320] loss: 0.1718
[Step 5330] loss: 0.1671
[Step 5340] loss: 0.1646
[Step 5350] loss: 0.1427
[Step 5360] loss: 0.1492
[Step 5370] loss: 0.1539
[Step 5380] loss: 0.1368
[Step 5390] loss: 0.1611


DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 6000] loss: 0.1601
[Step 6010] loss: 0.1449
[Step 6020] loss: 0.1458
[Step 6030] loss: 0.1353
[Step 6040] loss: 0.1107
[Step 6050] loss: 0.1476
[Step 6060] loss: 0.1435
[Step 6070] loss: 0.1499
[Step 6080] loss: 0.1452
[Step 6090] loss: 0.1418
[Step 6100] loss: 0.1445
[Step 6110] loss: 0.1316
[Step 6120] loss: 0.1331
[Step 6130] loss: 0.1374
[Step 6140] loss: 0.1341
[Step 6150] loss: 0.1546
[Step 6160] loss: 0.1532
[Step 6170] loss: 0.1349
[Step 6180] loss: 0.1353
[Step 6190] loss: 0.1383
[Step 6200] loss: 0.1325
[Step 6210] loss: 0.1382
[Step 6220] loss: 0.1548
[Step 6230] loss: 0.1407
[Step 6240] loss: 0.1514
[Step 6250] loss: 0.1291
[Step 6260] loss: 0.1638
[Step 6270] loss: 0.1315
[Step 6280] loss: 0.1235
[Step 6290] loss: 0.1487
[Step 6300] loss: 0.1429
[Step 6310] loss: 0.1435
[Step 6320] loss: 0.1494
[Step 6330] loss: 0.1553
[Step 6340] loss: 0.1452
[Step 6350] loss: 0.1593
[Step 6360] loss: 0.1490
[Step 6370] loss: 0.1324
[Step 6380] loss: 0.1142
[Step 6390] loss: 0.1518


DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 7000] loss: 0.1547
[Step 7010] loss: 0.1223
[Step 7020] loss: 0.1282
[Step 7030] loss: 0.1432
[Step 7040] loss: 0.1363
[Step 7050] loss: 0.1302
[Step 7060] loss: 0.1425
[Step 7070] loss: 0.1467
[Step 7080] loss: 0.1477
[Step 7090] loss: 0.1500
[Step 7100] loss: 0.1145
[Step 7110] loss: 0.1419
[Step 7120] loss: 0.1318
[Step 7130] loss: 0.1317
[Step 7140] loss: 0.1410
[Step 7150] loss: 0.1389
[Step 7160] loss: 0.1136
[Step 7170] loss: 0.1233
[Step 7180] loss: 0.1372
[Step 7190] loss: 0.1394
[Step 7200] loss: 0.1431
[Step 7210] loss: 0.1248
[Step 7220] loss: 0.1390
[Step 7230] loss: 0.1347
[Step 7240] loss: 0.1239
[Step 7250] loss: 0.1389
[Step 7260] loss: 0.1378
[Step 7270] loss: 0.1434
[Step 7280] loss: 0.1082
[Step 7290] loss: 0.1229
[Step 7300] loss: 0.1240
[Step 7310] loss: 0.1506
[Step 7320] loss: 0.1214
[Step 7330] loss: 0.1319
[Step 7340] loss: 0.1256
[Step 7350] loss: 0.1380
[Step 7360] loss: 0.1333
[Step 7370] loss: 0.1437
[Step 7380] loss: 0.1176
[Step 7390] loss: 0.1290


DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 8000] loss: 0.1336
[Step 8010] loss: 0.1297
[Step 8020] loss: 0.1230
[Step 8030] loss: 0.1337
[Step 8040] loss: 0.1279
[Step 8050] loss: 0.1322
[Step 8060] loss: 0.1173
[Step 8070] loss: 0.1275
[Step 8080] loss: 0.1310
[Step 8090] loss: 0.1440
[Step 8100] loss: 0.1263
[Step 8110] loss: 0.1239
[Step 8120] loss: 0.1173
[Step 8130] loss: 0.1439
[Step 8140] loss: 0.1235
[Step 8150] loss: 0.1139
[Step 8160] loss: 0.1424
[Step 8170] loss: 0.1176
[Step 8180] loss: 0.1298
[Step 8190] loss: 0.1268
[Step 8200] loss: 0.1146
[Step 8210] loss: 0.1278
[Step 8220] loss: 0.1348
[Step 8230] loss: 0.1149
[Step 8240] loss: 0.1175
[Step 8250] loss: 0.1327
[Step 8260] loss: 0.1338
[Step 8270] loss: 0.1302
[Step 8280] loss: 0.1270
[Step 8290] loss: 0.1622
[Step 8300] loss: 0.1205
[Step 8310] loss: 0.1155
[Step 8320] loss: 0.1301
[Step 8330] loss: 0.1332
[Step 8340] loss: 0.1248
[Step 8350] loss: 0.1193
[Step 8360] loss: 0.1202
[Step 8370] loss: 0.1208
[Step 8380] loss: 0.1324
[Step 8390] loss: 0.1323


DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 9000] loss: 0.1352
[Step 9010] loss: 0.1186
[Step 9020] loss: 0.1158
[Step 9030] loss: 0.0991
[Step 9040] loss: 0.1245
[Step 9050] loss: 0.1211
[Step 9060] loss: 0.1115
[Step 9070] loss: 0.1197
[Step 9080] loss: 0.1207
[Step 9090] loss: 0.1316
[Step 9100] loss: 0.1159
[Step 9110] loss: 0.1116
[Step 9120] loss: 0.1180
[Step 9130] loss: 0.1102
[Step 9140] loss: 0.1279
[Step 9150] loss: 0.1243
[Step 9160] loss: 0.1211
[Step 9170] loss: 0.1311
[Step 9180] loss: 0.1185
[Step 9190] loss: 0.1197
[Step 9200] loss: 0.1251
[Step 9210] loss: 0.1086
[Step 9220] loss: 0.1371
[Step 9230] loss: 0.1223
[Step 9240] loss: 0.1158
[Step 9250] loss: 0.1132
[Step 9260] loss: 0.1313
[Step 9270] loss: 0.1170
[Step 9280] loss: 0.1249
[Step 9290] loss: 0.1237
[Step 9300] loss: 0.1243
[Step 9310] loss: 0.1294
[Step 9320] loss: 0.1054
[Step 9330] loss: 0.1327
[Step 9340] loss: 0.1067
[Step 9350] loss: 0.1117
[Step 9360] loss: 0.1160
[Step 9370] loss: 0.1116
[Step 9380] loss: 0.1188
[Step 9390] loss: 0.1084


DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 10000] loss: 0.1069
[Step 10010] loss: 0.1079
[Step 10020] loss: 0.1048
[Step 10030] loss: 0.1078
[Step 10040] loss: 0.1190
[Step 10050] loss: 0.1047
[Step 10060] loss: 0.1140
[Step 10070] loss: 0.1247
[Step 10080] loss: 0.0997
[Step 10090] loss: 0.1071
[Step 10100] loss: 0.1282
[Step 10110] loss: 0.1164
[Step 10120] loss: 0.1171
[Step 10130] loss: 0.0935
[Step 10140] loss: 0.1190
[Step 10150] loss: 0.1079
[Step 10160] loss: 0.1114
[Step 10170] loss: 0.1123
[Step 10180] loss: 0.1057
[Step 10190] loss: 0.1128
[Step 10200] loss: 0.1191
[Step 10210] loss: 0.1165
[Step 10220] loss: 0.1303
[Step 10230] loss: 0.1157
[Step 10240] loss: 0.1234
[Step 10250] loss: 0.1126
[Step 10260] loss: 0.1282
[Step 10270] loss: 0.1344
[Step 10280] loss: 0.1130
[Step 10290] loss: 0.1268
[Step 10300] loss: 0.1260
[Step 10310] loss: 0.1305
[Step 10320] loss: 0.1050
[Step 10330] loss: 0.1184
[Step 10340] loss: 0.1075
[Step 10350] loss: 0.1276
[Step 10360] loss: 0.1033
[Step 10370] loss: 0.1354
[Step 10380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 11000] loss: 0.1143
[Step 11010] loss: 0.1089
[Step 11020] loss: 0.1036
[Step 11030] loss: 0.1178
[Step 11040] loss: 0.1048
[Step 11050] loss: 0.1210
[Step 11060] loss: 0.1153
[Step 11070] loss: 0.1113
[Step 11080] loss: 0.0919
[Step 11090] loss: 0.1042
[Step 11100] loss: 0.1014
[Step 11110] loss: 0.1081
[Step 11120] loss: 0.1099
[Step 11130] loss: 0.1160
[Step 11140] loss: 0.1138
[Step 11150] loss: 0.1133
[Step 11160] loss: 0.1002
[Step 11170] loss: 0.1217
[Step 11180] loss: 0.1168
[Step 11190] loss: 0.1370
[Step 11200] loss: 0.1208
[Step 11210] loss: 0.0972
[Step 11220] loss: 0.1227
[Step 11230] loss: 0.1049
[Step 11240] loss: 0.1108
[Step 11250] loss: 0.1025
[Step 11260] loss: 0.1040
[Step 11270] loss: 0.1199
[Step 11280] loss: 0.1043
[Step 11290] loss: 0.1136
[Step 11300] loss: 0.1158
[Step 11310] loss: 0.1038
[Step 11320] loss: 0.1269
[Step 11330] loss: 0.1074
[Step 11340] loss: 0.1047
[Step 11350] loss: 0.1242
[Step 11360] loss: 0.1217
[Step 11370] loss: 0.0960
[Step 11380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 12000] loss: 0.1041
[Step 12010] loss: 0.1005
[Step 12020] loss: 0.1002
[Step 12030] loss: 0.1149
[Step 12040] loss: 0.1020
[Step 12050] loss: 0.0899
[Step 12060] loss: 0.1217
[Step 12070] loss: 0.1104
[Step 12080] loss: 0.1085
[Step 12090] loss: 0.1141
[Step 12100] loss: 0.0940
[Step 12110] loss: 0.1154
[Step 12120] loss: 0.1122
[Step 12130] loss: 0.1040
[Step 12140] loss: 0.1027
[Step 12150] loss: 0.1139
[Step 12160] loss: 0.1073
[Step 12170] loss: 0.1138
[Step 12180] loss: 0.0932
[Step 12190] loss: 0.1001
[Step 12200] loss: 0.1217
[Step 12210] loss: 0.1046
[Step 12220] loss: 0.1223
[Step 12230] loss: 0.0994
[Step 12240] loss: 0.1006
[Step 12250] loss: 0.1005
[Step 12260] loss: 0.1052
[Step 12270] loss: 0.1139
[Step 12280] loss: 0.0952
[Step 12290] loss: 0.1249
[Step 12300] loss: 0.1094
[Step 12310] loss: 0.1070
[Step 12320] loss: 0.1017
[Step 12330] loss: 0.1061
[Step 12340] loss: 0.1094
[Step 12350] loss: 0.1062
[Step 12360] loss: 0.0960
[Step 12370] loss: 0.1005
[Step 12380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 13000] loss: 0.1084
[Step 13010] loss: 0.1098
[Step 13020] loss: 0.1285
[Step 13030] loss: 0.0972
[Step 13040] loss: 0.1062
[Step 13050] loss: 0.0913
[Step 13060] loss: 0.1103
[Step 13070] loss: 0.0906
[Step 13080] loss: 0.0977
[Step 13090] loss: 0.0898
[Step 13100] loss: 0.0995
[Step 13110] loss: 0.1061
[Step 13120] loss: 0.1090
[Step 13130] loss: 0.0790
[Step 13140] loss: 0.0983
[Step 13150] loss: 0.1182
[Step 13160] loss: 0.1044
[Step 13170] loss: 0.0878
[Step 13180] loss: 0.1127
[Step 13190] loss: 0.1134
[Step 13200] loss: 0.1050
[Step 13210] loss: 0.0870
[Step 13220] loss: 0.1024
[Step 13230] loss: 0.0860
[Step 13240] loss: 0.1109
[Step 13250] loss: 0.1110
[Step 13260] loss: 0.1230
[Step 13270] loss: 0.0957
[Step 13280] loss: 0.1077
[Step 13290] loss: 0.1162
[Step 13300] loss: 0.0916
[Step 13310] loss: 0.0905
[Step 13320] loss: 0.1105
[Step 13330] loss: 0.1156
[Step 13340] loss: 0.0989
[Step 13350] loss: 0.1122
[Step 13360] loss: 0.1057
[Step 13370] loss: 0.0892
[Step 13380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 14000] loss: 0.1052
[Step 14010] loss: 0.0969
[Step 14020] loss: 0.1095
[Step 14030] loss: 0.0895
[Step 14040] loss: 0.0996
[Step 14050] loss: 0.0885
[Step 14060] loss: 0.1015
[Step 14070] loss: 0.1020
[Step 14080] loss: 0.1068
[Step 14090] loss: 0.1005
[Step 14100] loss: 0.0832
[Step 14110] loss: 0.1168
[Step 14120] loss: 0.0874
[Step 14130] loss: 0.0964
[Step 14140] loss: 0.0831
[Step 14150] loss: 0.0858
[Step 14160] loss: 0.0860
[Step 14170] loss: 0.1003
[Step 14180] loss: 0.1070
[Step 14190] loss: 0.0876
[Step 14200] loss: 0.0852
[Step 14210] loss: 0.0954
[Step 14220] loss: 0.0965
[Step 14230] loss: 0.1128
[Step 14240] loss: 0.1037
[Step 14250] loss: 0.1064
[Step 14260] loss: 0.1150
[Step 14270] loss: 0.0869
[Step 14280] loss: 0.0992
[Step 14290] loss: 0.1105
[Step 14300] loss: 0.1053
[Step 14310] loss: 0.1005
[Step 14320] loss: 0.1156
[Step 14330] loss: 0.0968
[Step 14340] loss: 0.1019
[Step 14350] loss: 0.1050
[Step 14360] loss: 0.0984
[Step 14370] loss: 0.0980
[Step 14380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 15000] loss: 0.0991
[Step 15010] loss: 0.1104
[Step 15020] loss: 0.1000
[Step 15030] loss: 0.0896
[Step 15040] loss: 0.1048
[Step 15050] loss: 0.0932
[Step 15060] loss: 0.1022
[Step 15070] loss: 0.1131
[Step 15080] loss: 0.0792
[Step 15090] loss: 0.0894
[Step 15100] loss: 0.0819
[Step 15110] loss: 0.0934
[Step 15120] loss: 0.1155
[Step 15130] loss: 0.0983
[Step 15140] loss: 0.0875
[Step 15150] loss: 0.1119
[Step 15160] loss: 0.0780
[Step 15170] loss: 0.0897
[Step 15180] loss: 0.0930
[Step 15190] loss: 0.0895
[Step 15200] loss: 0.0906
[Step 15210] loss: 0.0973
[Step 15220] loss: 0.0945
[Step 15230] loss: 0.0786
[Step 15240] loss: 0.0978
[Step 15250] loss: 0.0884
[Step 15260] loss: 0.0997
[Step 15270] loss: 0.0981
[Step 15280] loss: 0.0942
[Step 15290] loss: 0.0995
[Step 15300] loss: 0.1008
[Step 15310] loss: 0.0989
[Step 15320] loss: 0.1024
[Step 15330] loss: 0.1074
[Step 15340] loss: 0.0927
[Step 15350] loss: 0.1068
[Step 15360] loss: 0.1012
[Step 15370] loss: 0.0836
[Step 15380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 16000] loss: 0.0862
[Step 16010] loss: 0.0814
[Step 16020] loss: 0.0838
[Step 16030] loss: 0.0983
[Step 16040] loss: 0.0939
[Step 16050] loss: 0.0947
[Step 16060] loss: 0.0926
[Step 16070] loss: 0.1108
[Step 16080] loss: 0.0931
[Step 16090] loss: 0.0887
[Step 16100] loss: 0.0950
[Step 16110] loss: 0.1063
[Step 16120] loss: 0.0998
[Step 16130] loss: 0.0913
[Step 16140] loss: 0.0829
[Step 16150] loss: 0.0969
[Step 16160] loss: 0.1068
[Step 16170] loss: 0.1002
[Step 16180] loss: 0.0887
[Step 16190] loss: 0.0778
[Step 16200] loss: 0.0927
[Step 16210] loss: 0.1077
[Step 16220] loss: 0.0882
[Step 16230] loss: 0.0932
[Step 16240] loss: 0.0938
[Step 16250] loss: 0.0885
[Step 16260] loss: 0.0919
[Step 16270] loss: 0.0944
[Step 16280] loss: 0.0980
[Step 16290] loss: 0.0857
[Step 16300] loss: 0.1023
[Step 16310] loss: 0.0778
[Step 16320] loss: 0.0945
[Step 16330] loss: 0.0750
[Step 16340] loss: 0.0732
[Step 16350] loss: 0.0945
[Step 16360] loss: 0.0927
[Step 16370] loss: 0.0974
[Step 16380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 17000] loss: 0.0910
[Step 17010] loss: 0.1138
[Step 17020] loss: 0.1061
[Step 17030] loss: 0.1112
[Step 17040] loss: 0.0852
[Step 17050] loss: 0.0895
[Step 17060] loss: 0.1042
[Step 17070] loss: 0.0866
[Step 17080] loss: 0.0933
[Step 17090] loss: 0.1035
[Step 17100] loss: 0.0938
[Step 17110] loss: 0.0826
[Step 17120] loss: 0.1068
[Step 17130] loss: 0.0807
[Step 17140] loss: 0.0907
[Step 17150] loss: 0.0987
[Step 17160] loss: 0.0993
[Step 17170] loss: 0.0947
[Step 17180] loss: 0.0984
[Step 17190] loss: 0.1069
[Step 17200] loss: 0.0844
[Step 17210] loss: 0.0805
[Step 17220] loss: 0.0879
[Step 17230] loss: 0.0908
[Step 17240] loss: 0.0929
[Step 17250] loss: 0.0953
[Step 17260] loss: 0.0879
[Step 17270] loss: 0.0896
[Step 17280] loss: 0.0771
[Step 17290] loss: 0.0876
[Step 17300] loss: 0.0891
[Step 17310] loss: 0.0859
[Step 17320] loss: 0.0899
[Step 17330] loss: 0.0999
[Step 17340] loss: 0.0895
[Step 17350] loss: 0.1007
[Step 17360] loss: 0.0822
[Step 17370] loss: 0.0882
[Step 17380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 18000] loss: 0.0916
[Step 18010] loss: 0.0752
[Step 18020] loss: 0.0924
[Step 18030] loss: 0.0973
[Step 18040] loss: 0.0898
[Step 18050] loss: 0.0871
[Step 18060] loss: 0.0895
[Step 18070] loss: 0.0783
[Step 18080] loss: 0.1003
[Step 18090] loss: 0.0857
[Step 18100] loss: 0.0872
[Step 18110] loss: 0.0957
[Step 18120] loss: 0.0775
[Step 18130] loss: 0.1001
[Step 18140] loss: 0.0781
[Step 18150] loss: 0.0855
[Step 18160] loss: 0.0764
[Step 18170] loss: 0.0889
[Step 18180] loss: 0.0897
[Step 18190] loss: 0.0855
[Step 18200] loss: 0.0938
[Step 18210] loss: 0.1015
[Step 18220] loss: 0.0884
[Step 18230] loss: 0.1029
[Step 18240] loss: 0.0855
[Step 18250] loss: 0.0734
[Step 18260] loss: 0.0981
[Step 18270] loss: 0.0923
[Step 18280] loss: 0.0858
[Step 18290] loss: 0.0778
[Step 18300] loss: 0.0923
[Step 18310] loss: 0.0814
[Step 18320] loss: 0.0976
[Step 18330] loss: 0.0833
[Step 18340] loss: 0.0905
[Step 18350] loss: 0.0835
[Step 18360] loss: 0.0818
[Step 18370] loss: 0.0810
[Step 18380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 19000] loss: 0.0862
[Step 19010] loss: 0.0932
[Step 19020] loss: 0.0806
[Step 19030] loss: 0.0935
[Step 19040] loss: 0.1028
[Step 19050] loss: 0.0823
[Step 19060] loss: 0.0843
[Step 19070] loss: 0.0929
[Step 19080] loss: 0.0848
[Step 19090] loss: 0.0836
[Step 19100] loss: 0.0918
[Step 19110] loss: 0.0959
[Step 19120] loss: 0.0824
[Step 19130] loss: 0.0831
[Step 19140] loss: 0.0799
[Step 19150] loss: 0.0815
[Step 19160] loss: 0.0770
[Step 19170] loss: 0.0858
[Step 19180] loss: 0.0768
[Step 19190] loss: 0.0823
[Step 19200] loss: 0.0905
[Step 19210] loss: 0.0843
[Step 19220] loss: 0.0848
[Step 19230] loss: 0.0922
[Step 19240] loss: 0.0950
[Step 19250] loss: 0.0851
[Step 19260] loss: 0.0864
[Step 19270] loss: 0.0762
[Step 19280] loss: 0.0909
[Step 19290] loss: 0.0806
[Step 19300] loss: 0.0739
[Step 19310] loss: 0.0837
[Step 19320] loss: 0.0836
[Step 19330] loss: 0.0710
[Step 19340] loss: 0.0944
[Step 19350] loss: 0.0937
[Step 19360] loss: 0.1082
[Step 19370] loss: 0.0836
[Step 19380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 20000] loss: 0.1023
[Step 20010] loss: 0.0726
[Step 20020] loss: 0.0739
[Step 20030] loss: 0.0828
[Step 20040] loss: 0.0795
[Step 20050] loss: 0.0786
[Step 20060] loss: 0.0944
[Step 20070] loss: 0.0940
[Step 20080] loss: 0.0960
[Step 20090] loss: 0.0913
[Step 20100] loss: 0.0940
[Step 20110] loss: 0.0873
[Step 20120] loss: 0.0728
[Step 20130] loss: 0.0875
[Step 20140] loss: 0.0797
[Step 20150] loss: 0.0772
[Step 20160] loss: 0.0907
[Step 20170] loss: 0.0892
[Step 20180] loss: 0.0892
[Step 20190] loss: 0.0723
[Step 20200] loss: 0.0959
[Step 20210] loss: 0.0827
[Step 20220] loss: 0.1048
[Step 20230] loss: 0.0963
[Step 20240] loss: 0.0935
[Step 20250] loss: 0.0833
[Step 20260] loss: 0.0796
[Step 20270] loss: 0.0907
[Step 20280] loss: 0.0875
[Step 20290] loss: 0.0886
[Step 20300] loss: 0.0980
[Step 20310] loss: 0.0900
[Step 20320] loss: 0.0902
[Step 20330] loss: 0.0866
[Step 20340] loss: 0.0896
[Step 20350] loss: 0.0770
[Step 20360] loss: 0.0831
[Step 20370] loss: 0.0892
[Step 20380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 21000] loss: 0.0770
[Step 21010] loss: 0.0781
[Step 21020] loss: 0.0825
[Step 21030] loss: 0.0787
[Step 21040] loss: 0.0800
[Step 21050] loss: 0.0898
[Step 21060] loss: 0.0724
[Step 21070] loss: 0.1076
[Step 21080] loss: 0.0816
[Step 21090] loss: 0.1025
[Step 21100] loss: 0.0859
[Step 21110] loss: 0.0923
[Step 21120] loss: 0.0867
[Step 21130] loss: 0.1004
[Step 21140] loss: 0.0824
[Step 21150] loss: 0.0927
[Step 21160] loss: 0.0740
[Step 21170] loss: 0.0866
[Step 21180] loss: 0.0902
[Step 21190] loss: 0.0995
[Step 21200] loss: 0.0862
[Step 21210] loss: 0.0841
[Step 21220] loss: 0.0929
[Step 21230] loss: 0.1151
[Step 21240] loss: 0.0913
[Step 21250] loss: 0.0959
[Step 21260] loss: 0.0795
[Step 21270] loss: 0.0812
[Step 21280] loss: 0.0877
[Step 21290] loss: 0.0993
[Step 21300] loss: 0.0958
[Step 21310] loss: 0.0953
[Step 21320] loss: 0.0725
[Step 21330] loss: 0.0932
[Step 21340] loss: 0.0745
[Step 21350] loss: 0.0740
[Step 21360] loss: 0.0913
[Step 21370] loss: 0.0815
[Step 21380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 22000] loss: 0.0846
[Step 22010] loss: 0.0766
[Step 22020] loss: 0.0831
[Step 22030] loss: 0.0784
[Step 22040] loss: 0.0763
[Step 22050] loss: 0.0684
[Step 22060] loss: 0.0779
[Step 22070] loss: 0.0665
[Step 22080] loss: 0.0859
[Step 22090] loss: 0.0775
[Step 22100] loss: 0.0723
[Step 22110] loss: 0.0711
[Step 22120] loss: 0.0890
[Step 22130] loss: 0.0896
[Step 22140] loss: 0.0833
[Step 22150] loss: 0.0821
[Step 22160] loss: 0.0790
[Step 22170] loss: 0.0840
[Step 22180] loss: 0.0919
[Step 22190] loss: 0.0800
[Step 22200] loss: 0.0949
[Step 22210] loss: 0.0772
[Step 22220] loss: 0.0734
[Step 22230] loss: 0.0878
[Step 22240] loss: 0.0988
[Step 22250] loss: 0.0754
[Step 22260] loss: 0.0968
[Step 22270] loss: 0.0767
[Step 22280] loss: 0.0944
[Step 22290] loss: 0.0816
[Step 22300] loss: 0.0731
[Step 22310] loss: 0.0862
[Step 22320] loss: 0.0893
[Step 22330] loss: 0.0798
[Step 22340] loss: 0.0875
[Step 22350] loss: 0.0801
[Step 22360] loss: 0.0901
[Step 22370] loss: 0.0703
[Step 22380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 23000] loss: 0.0774
[Step 23010] loss: 0.0782
[Step 23020] loss: 0.0841
[Step 23030] loss: 0.0932
[Step 23040] loss: 0.0840
[Step 23050] loss: 0.0854
[Step 23060] loss: 0.0826
[Step 23070] loss: 0.0762
[Step 23080] loss: 0.0691
[Step 23090] loss: 0.0782
[Step 23100] loss: 0.0715
[Step 23110] loss: 0.0954
[Step 23120] loss: 0.0818
[Step 23130] loss: 0.0759
[Step 23140] loss: 0.0756
[Step 23150] loss: 0.0826
[Step 23160] loss: 0.0706
[Step 23170] loss: 0.0811
[Step 23180] loss: 0.0877
[Step 23190] loss: 0.0820
[Step 23200] loss: 0.0810
[Step 23210] loss: 0.0754
[Step 23220] loss: 0.0791
[Step 23230] loss: 0.0884
[Step 23240] loss: 0.0705
[Step 23250] loss: 0.0686
[Step 23260] loss: 0.0876
[Step 23270] loss: 0.0880
[Step 23280] loss: 0.0668
[Step 23290] loss: 0.0890
[Step 23300] loss: 0.0722
[Step 23310] loss: 0.0798
[Step 23320] loss: 0.0625
[Step 23330] loss: 0.0839
[Step 23340] loss: 0.0828
[Step 23350] loss: 0.0799
[Step 23360] loss: 0.0736
[Step 23370] loss: 0.0846
[Step 23380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 24000] loss: 0.0734
[Step 24010] loss: 0.0831
[Step 24020] loss: 0.0924
[Step 24030] loss: 0.0860
[Step 24040] loss: 0.0879
[Step 24050] loss: 0.0746
[Step 24060] loss: 0.0761
[Step 24070] loss: 0.0797
[Step 24080] loss: 0.0698
[Step 24090] loss: 0.0898
[Step 24100] loss: 0.0815
[Step 24110] loss: 0.0884
[Step 24120] loss: 0.0792
[Step 24130] loss: 0.0664
[Step 24140] loss: 0.0725
[Step 24150] loss: 0.0719
[Step 24160] loss: 0.0898
[Step 24170] loss: 0.0792
[Step 24180] loss: 0.0801
[Step 24190] loss: 0.0885
[Step 24200] loss: 0.0805
[Step 24210] loss: 0.0870
[Step 24220] loss: 0.0748
[Step 24230] loss: 0.0764
[Step 24240] loss: 0.0808
[Step 24250] loss: 0.1043
[Step 24260] loss: 0.0719
[Step 24270] loss: 0.0881
[Step 24280] loss: 0.0752
[Step 24290] loss: 0.0847
[Step 24300] loss: 0.0735
[Step 24310] loss: 0.0592
[Step 24320] loss: 0.0657
[Step 24330] loss: 0.0834
[Step 24340] loss: 0.0747
[Step 24350] loss: 0.0730
[Step 24360] loss: 0.0720
[Step 24370] loss: 0.0707
[Step 24380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 25000] loss: 0.1010
[Step 25010] loss: 0.0826
[Step 25020] loss: 0.0835
[Step 25030] loss: 0.0735
[Step 25040] loss: 0.0738
[Step 25050] loss: 0.0671
[Step 25060] loss: 0.0926
[Step 25070] loss: 0.0702
[Step 25080] loss: 0.0862
[Step 25090] loss: 0.0655
[Step 25100] loss: 0.0838
[Step 25110] loss: 0.0763
[Step 25120] loss: 0.0772
[Step 25130] loss: 0.0884
[Step 25140] loss: 0.0862
[Step 25150] loss: 0.0893
[Step 25160] loss: 0.0622
[Step 25170] loss: 0.0695
[Step 25180] loss: 0.0638
[Step 25190] loss: 0.0664
[Step 25200] loss: 0.0849
[Step 25210] loss: 0.0925
[Step 25220] loss: 0.0761
[Step 25230] loss: 0.0720
[Step 25240] loss: 0.0838
[Step 25250] loss: 0.0687
[Step 25260] loss: 0.0711
[Step 25270] loss: 0.0771
[Step 25280] loss: 0.0832
[Step 25290] loss: 0.0708
[Step 25300] loss: 0.0901
[Step 25310] loss: 0.0738
[Step 25320] loss: 0.0825
[Step 25330] loss: 0.0706
[Step 25340] loss: 0.0698
[Step 25350] loss: 0.0982
[Step 25360] loss: 0.0939
[Step 25370] loss: 0.0719
[Step 25380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 26000] loss: 0.0887
[Step 26010] loss: 0.0822
[Step 26020] loss: 0.0831
[Step 26030] loss: 0.0761
[Step 26040] loss: 0.0705
[Step 26050] loss: 0.0924
[Step 26060] loss: 0.0728
[Step 26070] loss: 0.0804
[Step 26080] loss: 0.0812
[Step 26090] loss: 0.0953
[Step 26100] loss: 0.0835
[Step 26110] loss: 0.0821
[Step 26120] loss: 0.0695
[Step 26130] loss: 0.0967
[Step 26140] loss: 0.0812
[Step 26150] loss: 0.0590
[Step 26160] loss: 0.0720
[Step 26170] loss: 0.0885
[Step 26180] loss: 0.0853
[Step 26190] loss: 0.0672
[Step 26200] loss: 0.0738
[Step 26210] loss: 0.0775
[Step 26220] loss: 0.0640
[Step 26230] loss: 0.0700
[Step 26240] loss: 0.0684
[Step 26250] loss: 0.0718
[Step 26260] loss: 0.0687
[Step 26270] loss: 0.0855
[Step 26280] loss: 0.0727
[Step 26290] loss: 0.0845
[Step 26300] loss: 0.0718
[Step 26310] loss: 0.0655
[Step 26320] loss: 0.0768
[Step 26330] loss: 0.0626
[Step 26340] loss: 0.0785
[Step 26350] loss: 0.0605
[Step 26360] loss: 0.0792
[Step 26370] loss: 0.0845
[Step 26380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 27000] loss: 0.0735
[Step 27010] loss: 0.0694
[Step 27020] loss: 0.0899
[Step 27030] loss: 0.0692
[Step 27040] loss: 0.0697
[Step 27050] loss: 0.0753
[Step 27060] loss: 0.0763
[Step 27070] loss: 0.0844
[Step 27080] loss: 0.0798
[Step 27090] loss: 0.0644
[Step 27100] loss: 0.0854
[Step 27110] loss: 0.0635
[Step 27120] loss: 0.0820
[Step 27130] loss: 0.0664
[Step 27140] loss: 0.0707
[Step 27150] loss: 0.0675
[Step 27160] loss: 0.0649
[Step 27170] loss: 0.0755
[Step 27180] loss: 0.0929
[Step 27190] loss: 0.0768
[Step 27200] loss: 0.0656
[Step 27210] loss: 0.0638
[Step 27220] loss: 0.0714
[Step 27230] loss: 0.0654
[Step 27240] loss: 0.0666
[Step 27250] loss: 0.0803
[Step 27260] loss: 0.0735
[Step 27270] loss: 0.0702
[Step 27280] loss: 0.0642
[Step 27290] loss: 0.0700
[Step 27300] loss: 0.0876
[Step 27310] loss: 0.0795
[Step 27320] loss: 0.0783
[Step 27330] loss: 0.0748
[Step 27340] loss: 0.0849
[Step 27350] loss: 0.0843
[Step 27360] loss: 0.0700
[Step 27370] loss: 0.0688
[Step 27380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 28000] loss: 0.0700
[Step 28010] loss: 0.0646
[Step 28020] loss: 0.0817
[Step 28030] loss: 0.0663
[Step 28040] loss: 0.0697
[Step 28050] loss: 0.0748
[Step 28060] loss: 0.0729
[Step 28070] loss: 0.0654
[Step 28080] loss: 0.0714
[Step 28090] loss: 0.0597
[Step 28100] loss: 0.0858
[Step 28110] loss: 0.0830
[Step 28120] loss: 0.0867
[Step 28130] loss: 0.0726
[Step 28140] loss: 0.0748
[Step 28150] loss: 0.0699
[Step 28160] loss: 0.0721
[Step 28170] loss: 0.0747
[Step 28180] loss: 0.0839
[Step 28190] loss: 0.0771
[Step 28200] loss: 0.0794
[Step 28210] loss: 0.0705
[Step 28220] loss: 0.0856
[Step 28230] loss: 0.0911
[Step 28240] loss: 0.0765
[Step 28250] loss: 0.0756
[Step 28260] loss: 0.0737
[Step 28270] loss: 0.0693
[Step 28280] loss: 0.0812
[Step 28290] loss: 0.0795
[Step 28300] loss: 0.0718
[Step 28310] loss: 0.0784
[Step 28320] loss: 0.0745
[Step 28330] loss: 0.0691
[Step 28340] loss: 0.0758
[Step 28350] loss: 0.0613
[Step 28360] loss: 0.0748
[Step 28370] loss: 0.0678
[Step 28380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 29000] loss: 0.0765
[Step 29010] loss: 0.0719
[Step 29020] loss: 0.0657
[Step 29030] loss: 0.0791
[Step 29040] loss: 0.0773
[Step 29050] loss: 0.0763
[Step 29060] loss: 0.0874
[Step 29070] loss: 0.0595
[Step 29080] loss: 0.0779
[Step 29090] loss: 0.0811
[Step 29100] loss: 0.0789
[Step 29110] loss: 0.0643
[Step 29120] loss: 0.0854
[Step 29130] loss: 0.0847
[Step 29140] loss: 0.0847
[Step 29150] loss: 0.0847
[Step 29160] loss: 0.0816
[Step 29170] loss: 0.0700
[Step 29180] loss: 0.0721
[Step 29190] loss: 0.0676
[Step 29200] loss: 0.0814
[Step 29210] loss: 0.0681
[Step 29220] loss: 0.0729
[Step 29230] loss: 0.0927
[Step 29240] loss: 0.0708
[Step 29250] loss: 0.0755
[Step 29260] loss: 0.0662
[Step 29270] loss: 0.0759
[Step 29280] loss: 0.0784
[Step 29290] loss: 0.0893
[Step 29300] loss: 0.0707
[Step 29310] loss: 0.0608
[Step 29320] loss: 0.0773
[Step 29330] loss: 0.0655
[Step 29340] loss: 0.0661
[Step 29350] loss: 0.0828
[Step 29360] loss: 0.0631
[Step 29370] loss: 0.0788
[Step 29380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 30000] loss: 0.0816
[Step 30010] loss: 0.0696
[Step 30020] loss: 0.0756
[Step 30030] loss: 0.0908
[Step 30040] loss: 0.0692
[Step 30050] loss: 0.0837
[Step 30060] loss: 0.0782
[Step 30070] loss: 0.0807
[Step 30080] loss: 0.0816
[Step 30090] loss: 0.0761
[Step 30100] loss: 0.0711
[Step 30110] loss: 0.0925
[Step 30120] loss: 0.0834
[Step 30130] loss: 0.0671
[Step 30140] loss: 0.0838
[Step 30150] loss: 0.0648
[Step 30160] loss: 0.0804
[Step 30170] loss: 0.0783
[Step 30180] loss: 0.0731
[Step 30190] loss: 0.0803
[Step 30200] loss: 0.0680
[Step 30210] loss: 0.0648
[Step 30220] loss: 0.0678
[Step 30230] loss: 0.0671
[Step 30240] loss: 0.0782
[Step 30250] loss: 0.0630
[Step 30260] loss: 0.0710
[Step 30270] loss: 0.0660
[Step 30280] loss: 0.0605
[Step 30290] loss: 0.0739
[Step 30300] loss: 0.0759
[Step 30310] loss: 0.0658
[Step 30320] loss: 0.0650
[Step 30330] loss: 0.0855
[Step 30340] loss: 0.0608
[Step 30350] loss: 0.0700
[Step 30360] loss: 0.0667
[Step 30370] loss: 0.0773
[Step 30380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 31000] loss: 0.0852
[Step 31010] loss: 0.0661
[Step 31020] loss: 0.0737
[Step 31030] loss: 0.0747
[Step 31040] loss: 0.0614
[Step 31050] loss: 0.0707
[Step 31060] loss: 0.0738
[Step 31070] loss: 0.0707
[Step 31080] loss: 0.0597
[Step 31090] loss: 0.0765
[Step 31100] loss: 0.0552
[Step 31110] loss: 0.0578
[Step 31120] loss: 0.0631
[Step 31130] loss: 0.0819
[Step 31140] loss: 0.0639
[Step 31150] loss: 0.0723
[Step 31160] loss: 0.0650
[Step 31170] loss: 0.0675
[Step 31180] loss: 0.0847
[Step 31190] loss: 0.0523
[Step 31200] loss: 0.0715
[Step 31210] loss: 0.0780
[Step 31220] loss: 0.0607
[Step 31230] loss: 0.0551
[Step 31240] loss: 0.0664
[Step 31250] loss: 0.0718
[Step 31260] loss: 0.0718
[Step 31270] loss: 0.0715
[Step 31280] loss: 0.0756
[Step 31290] loss: 0.0687
[Step 31300] loss: 0.0701
[Step 31310] loss: 0.0706
[Step 31320] loss: 0.0736
[Step 31330] loss: 0.0714
[Step 31340] loss: 0.0828
[Step 31350] loss: 0.0949
[Step 31360] loss: 0.0607
[Step 31370] loss: 0.0647
[Step 31380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 32000] loss: 0.0802
[Step 32010] loss: 0.0809
[Step 32020] loss: 0.0705
[Step 32030] loss: 0.0565
[Step 32040] loss: 0.0807
[Step 32050] loss: 0.0751
[Step 32060] loss: 0.0777
[Step 32070] loss: 0.0689
[Step 32080] loss: 0.0863
[Step 32090] loss: 0.0598
[Step 32100] loss: 0.0776
[Step 32110] loss: 0.0743
[Step 32120] loss: 0.0753
[Step 32130] loss: 0.0625
[Step 32140] loss: 0.0680
[Step 32150] loss: 0.0718
[Step 32160] loss: 0.0889
[Step 32170] loss: 0.0756
[Step 32180] loss: 0.0744
[Step 32190] loss: 0.0726
[Step 32200] loss: 0.0673
[Step 32210] loss: 0.0768
[Step 32220] loss: 0.0548
[Step 32230] loss: 0.0878
[Step 32240] loss: 0.0718
[Step 32250] loss: 0.0653
[Step 32260] loss: 0.0648
[Step 32270] loss: 0.0795
[Step 32280] loss: 0.0664
[Step 32290] loss: 0.0761
[Step 32300] loss: 0.0735
[Step 32310] loss: 0.0581
[Step 32320] loss: 0.0600
[Step 32330] loss: 0.0773
[Step 32340] loss: 0.0612
[Step 32350] loss: 0.0739
[Step 32360] loss: 0.0764
[Step 32370] loss: 0.0724
[Step 32380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 33000] loss: 0.0646
[Step 33010] loss: 0.0788
[Step 33020] loss: 0.0653
[Step 33030] loss: 0.0686
[Step 33040] loss: 0.0676
[Step 33050] loss: 0.0757
[Step 33060] loss: 0.0588
[Step 33070] loss: 0.0648
[Step 33080] loss: 0.0700
[Step 33090] loss: 0.0726
[Step 33100] loss: 0.0656
[Step 33110] loss: 0.0739
[Step 33120] loss: 0.0581
[Step 33130] loss: 0.0626
[Step 33140] loss: 0.0724
[Step 33150] loss: 0.0670
[Step 33160] loss: 0.0684
[Step 33170] loss: 0.0601
[Step 33180] loss: 0.0489
[Step 33190] loss: 0.0724
[Step 33200] loss: 0.0740
[Step 33210] loss: 0.0605
[Step 33220] loss: 0.0779
[Step 33230] loss: 0.0804
[Step 33240] loss: 0.0685
[Step 33250] loss: 0.0617
[Step 33260] loss: 0.0745
[Step 33270] loss: 0.0805
[Step 33280] loss: 0.0726
[Step 33290] loss: 0.0780
[Step 33300] loss: 0.0758
[Step 33310] loss: 0.0555
[Step 33320] loss: 0.0610
[Step 33330] loss: 0.0693
[Step 33340] loss: 0.0746
[Step 33350] loss: 0.0640
[Step 33360] loss: 0.0604
[Step 33370] loss: 0.0623
[Step 33380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 34000] loss: 0.0640
[Step 34010] loss: 0.0613
[Step 34020] loss: 0.0667
[Step 34030] loss: 0.0544
[Step 34040] loss: 0.0646
[Step 34050] loss: 0.0721
[Step 34060] loss: 0.0607
[Step 34070] loss: 0.0697
[Step 34080] loss: 0.0644
[Step 34090] loss: 0.0657
[Step 34100] loss: 0.0705
[Step 34110] loss: 0.0546
[Step 34120] loss: 0.0835
[Step 34130] loss: 0.0585
[Step 34140] loss: 0.0698
[Step 34150] loss: 0.0806
[Step 34160] loss: 0.0717
[Step 34170] loss: 0.0779
[Step 34180] loss: 0.0700
[Step 34190] loss: 0.0518
[Step 34200] loss: 0.0571
[Step 34210] loss: 0.0775
[Step 34220] loss: 0.0713
[Step 34230] loss: 0.0707
[Step 34240] loss: 0.0697
[Step 34250] loss: 0.0790
[Step 34260] loss: 0.0656
[Step 34270] loss: 0.0657
[Step 34280] loss: 0.0685
[Step 34290] loss: 0.0513
[Step 34300] loss: 0.0753
[Step 34310] loss: 0.0585
[Step 34320] loss: 0.0557
[Step 34330] loss: 0.0736
[Step 34340] loss: 0.0601
[Step 34350] loss: 0.0653
[Step 34360] loss: 0.0727
[Step 34370] loss: 0.0720
[Step 34380]

DDIM sampling:   0%|          | 0/100 [00:00<?, ?it/s]

[Step 35000] loss: 0.0628
[Step 35010] loss: 0.0792
[Step 35020] loss: 0.0610
[Step 35030] loss: 0.0628
[Step 35040] loss: 0.0644
[Step 35050] loss: 0.0699
[Step 35060] loss: 0.0658
[Step 35070] loss: 0.0641
[Step 35080] loss: 0.0743
[Step 35090] loss: 0.0625
[Step 35100] loss: 0.0697
[Step 35110] loss: 0.0713
[Step 35120] loss: 0.0626
[Step 35130] loss: 0.0489
[Step 35140] loss: 0.0609
[Step 35150] loss: 0.0712
[Step 35160] loss: 0.0759
[Step 35170] loss: 0.0753
[Step 35180] loss: 0.0745
[Step 35190] loss: 0.0744
[Step 35200] loss: 0.0639
[Step 35210] loss: 0.0610
[Step 35220] loss: 0.0679
[Step 35230] loss: 0.0776
[Step 35240] loss: 0.0574
[Step 35250] loss: 0.0603
[Step 35260] loss: 0.0753
[Step 35270] loss: 0.0563
[Step 35280] loss: 0.0753
[Step 35290] loss: 0.0698
[Step 35300] loss: 0.0513
[Step 35310] loss: 0.0725
[Step 35320] loss: 0.0623
[Step 35330] loss: 0.0718
[Step 35340] loss: 0.0584
[Step 35350] loss: 0.0916
[Step 35360] loss: 0.0884
[Step 35370] loss: 0.0770
[Step 35380]

KeyboardInterrupt: 

## Inference

In [None]:
# Checkpoint file to restore, given relative to the notebook's working directory.
ckpt = './UNetTransformer-emb256.pt'
# Hand the path to the trainer's own loader. NOTE(review): `trainer.load` is
# defined outside this section; presumably it restores the model (and possibly
# EMA/optimizer) state in place — confirm against its definition.
trainer.load(ckpt)
# trainer.inference()

In [None]:
# Visualise one sampling trajectory: draw a single sample, keep every
# intermediate frame, and plot five evenly spaced frames from first to last.
diffusion = trainer.model
with torch.no_grad():
    all_steps = diffusion.sample(batch_size=1, return_all_timesteps=True)
# Assumed layout [1, frames, C, H, W] (the squeeze/permute below rely on it).
# NOTE(review): `frames` is T+1 only when every training timestep is visited;
# a strided sampler such as DDIM returns fewer frames — confirm in `sample`.

# Drop the batch axis and move channels last so each frame is H x W x C.
imgs = all_steps.squeeze(0).permute(0, 2, 3, 1).cpu().numpy()
# Noisy intermediate frames can fall outside [0, 1] after un-normalising;
# clip explicitly instead of relying on imshow's implicit clipping (which warns).
imgs = np.clip(unnormalize_to_zero_to_one(imgs), 0.0, 1.0)

# Pick the indices from the number of frames actually returned rather than
# from `diffusion.num_timesteps`, so this cannot index past the end when the
# sampler takes fewer steps than the training schedule. When the trajectory
# does have num_timesteps + 1 frames, the indices are the same as before.
last_frame = imgs.shape[0] - 1
timesteps = [0, last_frame // 4, last_frame // 2,
             3 * last_frame // 4, last_frame]
sel = [imgs[t] for t in timesteps]

# Plot
fig, axes = plt.subplots(1, len(sel), figsize=(15, 5))
for ax, im, step in zip(axes, sel, timesteps):
    ax.imshow(im)
    ax.axis('off')
    ax.set_title(f"Step {step}")
plt.tight_layout()
plt.show()