# Diffusion Transformer

## Libraries

In [1]:
import os
import gc
import math
import numpy as np
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils
from torchvision import transforms
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset
from accelerate import Accelerator
from tqdm.auto import tqdm
from einops import rearrange

import matplotlib.pyplot as plt
from ema_pytorch import EMA
from PIL import Image
from torchvision.transforms import ToPILImage

## Data preprocessing

In [None]:
class CustomDataset(Dataset):
    """Pixel-space dataset over a preprocessed ``.npy`` array of sprites.

    The whole array is loaded eagerly with ``np.load`` and indexed along its
    first axis. Each item goes through ``ToPILImage``, random flip / resized
    crop / color jitter, and is returned as a tensor normalized to [-1, 1].

    NOTE(review): ``ToPILImage`` only infers an RGB mode for 3-channel
    ``uint8`` arrays, so the stored array is presumably ``uint8`` with shape
    (N, H, W, C) — a 3-channel float array would be rejected. Confirm.
    """

    def __init__(self, npy_path: str, image_size: int):
        self.sprites = np.load(npy_path)
        self.image_size = image_size
        # Raw array shape, kept for legacy code that reads it directly.
        self.sprites_shape = self.sprites.shape

        augment = [
            ToPILImage(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomResizedCrop(image_size, scale=(0.9, 1.0)),
            transforms.ColorJitter(0.1, 0.1, 0.1, 0.05),
        ]
        to_tensor = [
            transforms.ToTensor(),  # uint8 pixels -> float in [0, 1]
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # -> [-1, 1]
        ]
        self.transform = transforms.Compose(augment + to_tensor)

    def __len__(self):
        return self.sprites_shape[0]

    def __getitem__(self, idx):
        # One H x W x C array in, one augmented [C, image_size, image_size] tensor out.
        return self.transform(self.sprites[idx])

# Define a wrapper if using Hugging Face or custom VAE
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import os

# Load a pretrained KL autoencoder from diffusers (swap in a custom VAE here if
# needed). It is only used to encode images, so it is put in eval mode, moved
# to the GPU and frozen.
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae = vae.eval().cuda()
vae.requires_grad_(False)

class CustomLatentDataset(Dataset):
    """Dataset that encodes images from a ``.npy`` array into VAE latents on the fly.

    Args:
        npy_path: path to an image array, either channel-last ``[N, H, W, 3]``
            or channel-first ``[N, 3, H, W]``. ``uint8`` pixels are used as-is;
            floating-point pixels are expected in [0, 1] (``ToPILImage``
            scales float tensors by 255).
        image_size: side length in *pixels* that each image is resized and
            center-cropped to before encoding.
        vae: autoencoder exposing ``encode(x).latent_dist.sample()``
            (e.g. diffusers ``AutoencoderKL``).

    Each item is ``vae.encode(img).latent_dist.sample() * 0.18215`` with the
    batch dimension removed. It is computed on the device of the VAE's
    parameters.

    NOTE: the VAE runs inside ``__getitem__``. With a CUDA VAE, use
    ``num_workers=0`` in the DataLoader, because CUDA cannot be initialized in
    forked worker processes.
    """

    def __init__(self, npy_path, image_size, vae):
        self.vae = vae
        # Keep the images as a host-side NumPy array. ndarrays have no ``.to``
        # method, and the PIL-based transforms below run on the CPU anyway.
        self.images = np.load(npy_path)
        self.image_size = image_size

        # Channel-last arrays are transposed per item in __getitem__.
        self.is_channel_last = (self.images.shape[-1] == 3)

        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),                     # [0, 1]
            transforms.Normalize([0.5], [0.5])          # to [-1, 1]
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]

        # ToPILImage expects tensors as CHW.
        if self.is_channel_last:
            img = np.transpose(img, (2, 0, 1))  # HWC -> CHW

        # Keep the stored dtype. ToPILImage multiplies floating-point tensors
        # by 255 before casting to uint8, so casting uint8 pixels to float
        # first would make them overflow.
        img = torch.from_numpy(np.ascontiguousarray(img))
        img = self.transform(img)  # [3, image_size, image_size] in [-1, 1]

        # Encode on whatever device the VAE lives on instead of assuming CUDA.
        device = next(self.vae.parameters()).device
        img = img.unsqueeze(0).to(device)

        with torch.no_grad():
            latent = self.vae.encode(img).latent_dist.sample() * 0.18215

        return latent.squeeze(0)


2025-05-21 01:56:15.216639: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747763775.233822     842 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747763775.238922     842 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-21 01:56:15.256304: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

## Noise schedules

In [None]:
def linear_beta_schedule(timesteps):
    """Linear beta schedule from the original DDPM paper.

    The endpoints 1e-4 and 0.02 are defined for 1000 steps. They are rescaled
    by ``1000 / timesteps`` so that other step counts cover a comparable noise
    range. Returns a float64 tensor of length ``timesteps``.
    """
    factor = 1000 / timesteps
    first, last = factor * 0.0001, factor * 0.02
    return torch.linspace(first, last, timesteps, dtype=torch.float64)

# e.g. from Nichol & Dhariwal’s “improved DDPM”:
def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine beta schedule from Nichol & Dhariwal, "Improved DDPM".

    alpha_bar(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2), normalized so that
    alpha_bar(0) = 1, and beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1).

    Args:
        timesteps: number of diffusion steps T.
        s: small offset that keeps beta_t from being too small near t = 0.

    Returns:
        float64 tensor of length ``timesteps``. Betas are clipped to at most
        0.999, because the last steps otherwise approach 1.
    """
    steps = timesteps + 1
    # Compute in float64, matching linear_beta_schedule. alpha_bar shrinks
    # towards 0 near t = T, and the ratio of consecutive values loses
    # precision in float32.
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return betas.clamp(max=0.999)

def extract(a, t, x_shape):
    """Pick ``a[t]`` for each batch element and shape it for broadcasting.

    The result has shape ``[B, 1, ..., 1]`` with ``len(x_shape)`` dimensions,
    where ``B = t.shape[0]``. It can therefore scale a tensor shaped like
    ``x_shape`` element-wise per sample.
    """
    picked = a.gather(-1, t)
    trailing_ones = [1] * (len(x_shape) - 1)
    return picked.reshape(t.shape[0], *trailing_ones)

## Network modules

In [None]:
class PatchEmbedding(nn.Module):
    """Non-overlapping ViT patch embedding with a learned positional table.

    A convolution with kernel and stride ``patch_size`` maps each patch to an
    ``emb_dim`` vector. The ``(img_size // patch_size) ** 2`` patches are
    returned row-major as ``[B, N, emb_dim]`` tokens, with ``pos_embedding``
    (shape ``[N, emb_dim]``) added to every sample.
    """

    def __init__(self, in_channels=3, patch_size=8, emb_dim=512, img_size=64):
        super().__init__()
        self.patch_size = patch_size
        self.emb_dim = emb_dim
        self.img_size = img_size
        num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, emb_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embedding = nn.Parameter(torch.randn(num_patches, emb_dim))

    def forward(self, x):
        grid = self.proj(x)                        # [B, emb_dim, H/p, W/p]
        tokens = grid.flatten(2).transpose(1, 2)   # [B, (H/p)*(W/p), emb_dim]
        return tokens + self.pos_embedding

class ConvPatchEmbedding(nn.Module):
    """Conv stem followed by (possibly overlapping) patch tokenization.

    For an input ``[B, in_channels, img_size, img_size]``:
      1. stem: two 3x3 convs (stride 1, padding 1) -> ``[B, emb_dim, H, W]``
      2. unfold: ``patch_size`` windows every ``stride`` pixels
         -> ``[B, emb_dim * P * P, N]``
      3. a linear projection of each flattened window -> ``[B, N, emb_dim]``
      4. learned positional embeddings ``[1, N, emb_dim]`` are added

    where ``N = ((img_size - patch_size) // stride + 1) ** 2``.
    """

    def __init__(self, in_channels, emb_dim, img_size, patch_size=2, stride=1):
        super().__init__()
        self.patch_size = patch_size
        self.stride     = stride

        # 1) small conv stem: in_channels -> emb_dim, keeps H x W
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, emb_dim//2, kernel_size=3, padding=1, stride=1),
            nn.GELU(),
            nn.Conv2d(emb_dim//2, emb_dim,   kernel_size=3, padding=1, stride=1),
            nn.GELU(),
        )

        # 2) sliding patch_size x patch_size windows
        self.unfold = nn.Unfold(kernel_size=patch_size, stride=stride)

        num_patches_per_dim = (img_size - patch_size) // stride + 1
        num_patches = num_patches_per_dim**2

        # 3) The windows are cut from the *stem output*, which has emb_dim
        #    channels, so each flattened patch has emb_dim * P * P features.
        #    Sizing this by in_channels only worked when in_channels == emb_dim.
        self.proj = nn.Linear(emb_dim * patch_size * patch_size, emb_dim)

        # 4) learned positional embeddings
        self.pos_emb = nn.Parameter(torch.randn(1, num_patches, emb_dim))

    def forward(self, x):
        """Map ``[B, in_channels, H, W]`` to ``[B, N, emb_dim]`` tokens."""
        x = self.stem(x)                          # [B, D, H, W]
        patches = self.unfold(x)                  # [B, D*P*P, N]
        patches = patches.permute(0, 2, 1)        # [B, N, D*P*P]
        tokens = self.proj(patches)               # [B, N, D]
        return tokens + self.pos_emb              # [B, N, D]

class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D batch of timesteps.

    ``forward(t)`` maps ``B`` timesteps to ``[B, dim]``. The first
    ``dim // 2`` columns are sines and the next ``dim // 2`` cosines of ``t``
    times geometrically spaced frequencies from 1 down to 1/10000. For odd
    ``dim`` a zero column is appended, so the width always equals ``dim`` and
    matches an ``nn.Linear(dim, ...)`` that follows.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half_dim = self.dim // 2
        # For dim in {2, 3}, half_dim - 1 == 0; use a single frequency of 1
        # instead of dividing by zero.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
        args = t[:, None] * freqs[None, :]
        emb = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Without padding, odd dims produced only dim - 1 columns.
            emb = F.pad(emb, (0, 1))
        return emb  # [B, dim]

class TransformerBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        emb_dim: int,
        num_heads: int,
        dim_feedforward: int | None = None,
        dropout: float = 0.0
    ):
        """
        A single ViT‐style block that:
          1) optionally projects in_channels → emb_dim via 1×1 conv
          2) adds a (B,emb_dim) time embedding
          3) runs multihead self‐attention + residual + LayerNorm
          4) runs MLP (feed‐forward) + residual + LayerNorm
          5) optionally projects emb_dim → in_channels via 1×1 conv
        """
        super().__init__()
        # 1×1 conv to match channels → emb_dim (if needed)
        self.input_proj = (
            nn.Conv2d(in_channels, emb_dim, kernel_size=1)
            if in_channels != emb_dim else nn.Identity()
        )
        # Multi‐head self‐attention
        self.attn = nn.MultiheadAttention(
            embed_dim=emb_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )
        self.norm1 = nn.LayerNorm(emb_dim)
        # Feed-forward
        ff_dim = dim_feedforward or (emb_dim * 4)
        self.ff = nn.Sequential(
            nn.Linear(emb_dim, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, emb_dim),
        )
        self.norm2 = nn.LayerNorm(emb_dim)
        # Project back to in_channels if needed
        self.output_proj = (
            nn.Conv2d(emb_dim, in_channels, kernel_size=1)
            if in_channels != emb_dim else nn.Identity()
        )

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        """
        x:     [B, in_channels, H, W]
        t_emb: [B, emb_dim]  (output of your time‐MLP)
        returns [B, in_channels, H, W]
        """
        B, C, H, W = x.shape

        # 1) project to emb_dim
        x0 = self.input_proj(x)          # [B, emb_dim, H, W]

        # 2) add time embedding (broadcast to spatial dims)
        t = t_emb[:, :, None, None]      # [B, emb_dim, 1, 1]
        x0 = x0 + t                      # [B, emb_dim, H, W]

        # 3) flatten to (B, N, emb_dim) for attention
        x_flat = rearrange(x0, 'b c h w -> b (h w) c')  # N = H*W

        # 4) self‐attention + residual + LN
        attn_out, _ = self.attn(x_flat, x_flat, x_flat, need_weights=False)
        x1 = self.norm1(x_flat + attn_out)

        # 5) feed-forward + residual + LN
        ff_out = self.ff(x1)
        x2 = self.norm2(x1 + ff_out)

        # 6) reshape back to [B, emb_dim, H, W]
        x3 = rearrange(x2, 'b (h w) c -> b c h w', h=H, w=W)

        # 7) project back to in_channels if needed
        out = self.output_proj(x3)       # [B, in_channels, H, W]
        return out


class TransformerBackbone(nn.Module):
    """ViT denoiser: ConvPatchEmbedding -> TransformerEncoder -> per-patch linear decoder.

    Args:
        in_channels: channels of the input (and output) image.
        img_size: side length of the square input.
        patch_size: side length of each patch token.
        stride: step between patches. ``stride < patch_size`` gives
            overlapping patches, whose predictions are averaged when folding
            back into an image.
        emb_dim: token width.
        depth: number of encoder layers.
        num_heads: attention heads per layer.
        time_emb_dim: width of the sinusoidal timestep embedding before the MLP.
        dim_feedforward: hidden width of each encoder layer's MLP.

    ``forward(x, t)`` maps ``[B, in_channels, img_size, img_size]`` plus ``B``
    timesteps to a tensor of the same shape as ``x``.
    """

    def __init__(self, in_channels=3, img_size=64, patch_size=2, stride=1, emb_dim=512, depth=6, num_heads=8, time_emb_dim=128, dim_feedforward=1024):
        super().__init__()
        self.channels = in_channels
        self.embed = ConvPatchEmbedding(
            in_channels, emb_dim, img_size,
            patch_size=patch_size, stride=stride
        )

        self.time_embedding = SinusoidalPosEmb(time_emb_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim)
        )

        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=dim_feedforward, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.output_proj = nn.Sequential(
            nn.Linear(emb_dim, patch_size * patch_size * in_channels),
        )
        self.patch_size = patch_size
        self.stride = stride
        self.img_size = img_size

    def forward(self, x, t):
        patches = self.embed(x)                          # [B, N, D]

        t_emb = self.time_mlp(self.time_embedding(t))    # [B, D]
        patches = patches + t_emb.unsqueeze(1)           # same time embedding on every token

        encoded = self.encoder(patches)                  # [B, N, D]
        decoded = self.output_proj(encoded)              # [B, N, C*P*P]

        # Fold per-patch predictions back into an image. F.fold sums
        # overlapping windows, so divide by how many windows cover each pixel.
        # For stride == patch_size every pixel is covered exactly once and this
        # equals the plain non-overlapping reshape. A fixed
        # (img_size // patch_size) ** 2 reshape cannot handle the
        # ((img_size - patch_size) // stride + 1) ** 2 tokens that
        # ConvPatchEmbedding produces for the default stride=1.
        cols = decoded.transpose(1, 2)                   # [B, C*P*P, N]
        fold_args = dict(
            output_size=(self.img_size, self.img_size),
            kernel_size=self.patch_size,
            stride=self.stride,
        )
        image = F.fold(cols, **fold_args)                # [B, C, H, W]
        coverage = F.fold(torch.ones_like(cols[:1]), **fold_args)
        # clamp: pixels no window reaches (when the stride does not tile the
        # image exactly) stay 0 instead of dividing by zero.
        return image / coverage.clamp(min=1.0)
    
class ConvSkipViT(nn.Module):
    """TransformerBackbone wrapped between two 3x3 convolutions, plus an input skip.

    ``conv_in`` lifts the ``in_channels`` image to ``emb_dim`` channels at full
    resolution. The backbone (built with ``in_channels=emb_dim``) processes it
    together with the timesteps, and ``conv_out`` projects back to
    ``in_channels``. Finally the original input ``x`` is added to the result.

    ``channels`` and ``image_size`` are exposed for the diffusion wrapper.
    """

    def __init__(
        self,
        in_channels=3,
        img_size=64,
        patch_size=2,
        stride=1,
        emb_dim=512,
        depth=6,
        num_heads=8,
        time_emb_dim=128,
    ):
        super().__init__()
        backbone_cfg = dict(
            img_size=img_size,
            patch_size=patch_size,
            stride=stride,
            emb_dim=emb_dim,
            depth=depth,
            num_heads=num_heads,
            time_emb_dim=time_emb_dim,
        )
        self.conv_in = nn.Conv2d(in_channels, emb_dim, kernel_size=3, padding=1)
        self.vit = TransformerBackbone(in_channels=emb_dim, **backbone_cfg)
        self.conv_out = nn.Conv2d(emb_dim, in_channels, kernel_size=3, padding=1)

        self.channels = in_channels
        self.image_size = img_size

    def forward(self, x, t):
        """Map ``x`` [B, in_channels, H, W] and timesteps ``t`` [B] to [B, in_channels, H, W]."""
        hidden = self.vit(self.conv_in(x), t)   # [B, emb_dim, H, W]
        return x + self.conv_out(hidden)
    
class DoubleConv(nn.Module):
    """Two 3x3 same-padding convolutions, each followed by GELU.

    Maps ``[B, in_ch, H, W]`` to ``[B, out_ch, H, W]``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for src, dst in ((in_ch, out_ch), (out_ch, out_ch)):
            layers += [nn.Conv2d(src, dst, 3, padding=1), nn.GELU()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class UNetTransformer(nn.Module):
    """Conv U-Net with a time-conditioned TransformerBlock after every stage.

    With ``c = base_channels``, the encoder widths are c, 2c, 4c and 8c. Each
    ``down*`` conv halves the resolution and each ``up*`` transposed conv
    doubles it. The decoder concatenates the matching pre-downsampling encoder
    features, so H and W should be divisible by 8. A single ``[B, emb_dim]``
    time embedding from ``time_mlp`` is shared by every TransformerBlock. The
    output has ``in_channels`` channels at the input resolution.

    Requires the global ``TransformerBlock`` to be the conv-style
    ``(in_channels, emb_dim, num_heads)`` class at construction time.
    """

    def __init__(
        self,
        *,
        in_channels=3,
        base_channels=64,
        emb_dim=256,
        num_heads=8,
        time_emb_dim=128
    ):
        super().__init__()
        self.channels = in_channels
        c = base_channels

        # Timestep -> [B, emb_dim] conditioning vector.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_emb_dim),
            nn.Linear(time_emb_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        )

        # ---- Encoder (attribute names fixed: they are the checkpoint keys) ----
        self.enc1 = DoubleConv(in_channels, c)                           # in -> c
        self.trans1 = TransformerBlock(c, emb_dim, num_heads)

        self.enc2 = DoubleConv(c, 2 * c)                                 # c -> 2c
        self.down2 = nn.Conv2d(2 * c, 2 * c, 4, stride=2, padding=1)     # /2
        self.trans2 = TransformerBlock(2 * c, emb_dim, num_heads)

        self.enc3 = DoubleConv(2 * c, 4 * c)                             # 2c -> 4c
        self.down3 = nn.Conv2d(4 * c, 4 * c, 4, stride=2, padding=1)     # /2
        self.trans3 = TransformerBlock(4 * c, emb_dim, num_heads)

        self.enc4 = DoubleConv(4 * c, 8 * c)                             # 4c -> 8c
        self.down4 = nn.Conv2d(8 * c, 8 * c, 4, stride=2, padding=1)     # /2
        self.trans4 = TransformerBlock(8 * c, emb_dim, num_heads)

        # ---- Decoder: upsample x2, concat skip, fuse, attend ----
        self.up4 = nn.ConvTranspose2d(8 * c, 4 * c, 4, stride=2, padding=1)
        self.dec4 = DoubleConv(4 * c + 8 * c, 4 * c)        # concat with enc4 output (8c)
        self.trans_up3 = TransformerBlock(4 * c, emb_dim, num_heads)

        self.up3 = nn.ConvTranspose2d(4 * c, 2 * c, 4, stride=2, padding=1)
        self.dec3 = DoubleConv(2 * c + 4 * c, 2 * c)        # concat with enc3 output (4c)
        self.trans_up2 = TransformerBlock(2 * c, emb_dim, num_heads)

        self.up2 = nn.ConvTranspose2d(2 * c, c, 4, stride=2, padding=1)
        self.dec2 = DoubleConv(c + 2 * c, c)                # concat with enc2 output (2c)
        self.trans_up1 = TransformerBlock(c, emb_dim, num_heads)

        # Final 1x1 conv back to the input channel count.
        self.out_conv = nn.Conv2d(c, in_channels, 1)

    def forward(self, x, t):
        """Map ``x`` [B, in_channels, H, W] and timesteps ``t`` [B] to [B, in_channels, H, W]."""
        cond = self.time_mlp(t)                             # [B, emb_dim]

        # Encoder. skip_* are pre-downsampling features reused by the decoder.
        h = self.trans1(self.enc1(x), cond)                 # [B, c,  H,   W]
        skip_2 = self.enc2(h)                               # [B, 2c, H,   W]
        h = self.trans2(self.down2(skip_2), cond)           # [B, 2c, H/2, W/2]
        skip_3 = self.enc3(h)                               # [B, 4c, H/2, W/2]
        h = self.trans3(self.down3(skip_3), cond)           # [B, 4c, H/4, W/4]
        skip_4 = self.enc4(h)                               # [B, 8c, H/4, W/4]
        h = self.trans4(self.down4(skip_4), cond)           # [B, 8c, H/8, W/8]

        # Decoder.
        h = torch.cat([self.up4(h), skip_4], dim=1)         # [B, 4c+8c, H/4, W/4]
        h = self.trans_up3(self.dec4(h), cond)              # [B, 4c,    H/4, W/4]
        h = torch.cat([self.up3(h), skip_3], dim=1)         # [B, 2c+4c, H/2, W/2]
        h = self.trans_up2(self.dec3(h), cond)              # [B, 2c,    H/2, W/2]
        h = torch.cat([self.up2(h), skip_2], dim=1)         # [B, c+2c,  H,   W]
        h = self.trans_up1(self.dec2(h), cond)              # [B, c,     H,   W]

        return self.out_conv(h)                             # [B, in_channels, H, W]
    


## From PixArt-alpha

In [4]:
class Encoder(nn.Module):
    """Convolutional VAE encoder for 3 x 64 x 64 images.

    Three 4x4 stride-2 convs (3 -> 64 -> 128 -> 256 channels) are flattened
    and fed to two linear heads. These return the posterior mean and
    log-variance, each ``[B, latent_dim]``. The hard-coded 256 * 8 * 8
    flattened size fixes the input resolution at 64 x 64.
    """

    def __init__(self, latent_dim):
        super().__init__()
        widths = (3, 64, 128, 256)
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Conv2d(c_in, c_out, 4, 2, 1), nn.GELU()]
        layers.append(nn.Flatten())
        self.conv = nn.Sequential(*layers)

        flat = 256 * 8 * 8
        self.fc_mu = nn.Linear(flat, latent_dim)
        self.fc_logvar = nn.Linear(flat, latent_dim)

    def forward(self, x):
        features = self.conv(x)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        return mu, logvar
    
class Decoder(nn.Module):
    """Convolutional VAE decoder producing 3 x 64 x 64 images.

    A linear layer maps ``[B, latent_dim]`` to a 256 x 8 x 8 feature map.
    Three 4x4 stride-2 transposed convs (256 -> 128 -> 64 -> 3 channels)
    upsample it to ``[B, 3, 64, 64]``, and a final Tanh bounds the output to
    (-1, 1).
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 256 * 8 * 8)
        widths = (256, 128, 64, 3)
        last = len(widths) - 2
        layers = []
        for i, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, 2, 1))
            layers.append(nn.Tanh() if i == last else nn.GELU())
        self.deconv = nn.Sequential(*layers)

    def forward(self, z):
        grid = self.fc(z).reshape(-1, 256, 8, 8)
        return self.deconv(grid)
    
def get_timestep_embedding(timesteps, dim):
    """Build sinusoidal timestep embeddings.

    Args:
        timesteps: 1-D tensor of ``B`` timesteps (cast to float32).
        dim: embedding width.

    Returns:
        ``[B, dim]`` float32 tensor with ``dim // 2`` sine columns, then
        ``dim // 2`` cosine columns, plus one zero column when ``dim`` is odd.
    """
    assert timesteps.dim() == 1
    half = dim // 2
    step = math.log(10000) / (half - 1)
    freqs = torch.exp(-step * torch.arange(half, dtype=torch.float32, device=timesteps.device))
    angles = timesteps.float()[:, None] * freqs[None, :]  # [B, half]
    out = torch.cat([angles.sin(), angles.cos()], dim=1)
    return F.pad(out, (0, 1)) if dim % 2 == 1 else out

class UViTBlock(nn.Module):
    """Pre-norm transformer block on token sequences ``[B, N, dim]``.

    x <- x + MHA(LN(x)); x <- x + MLP(LN(x)), with a 4x-wide GELU MLP.

    Deliberately not named ``TransformerBlock``: redefining that name would
    replace the conv-style ``TransformerBlock(in_channels, emb_dim, num_heads)``
    that ``UNetTransformer`` constructs, and make ``UNetTransformer(...)``
    fail with a signature error once this cell has run.
    """

    def __init__(self, dim, n_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        h = self.norm1(x)  # normalize once; used as query, key and value
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.ff(self.norm2(x))
        return x


class UViT(nn.Module):
    """U-ViT-style denoiser over latent feature maps.

    Every spatial location of ``latents`` ``[B, C, H, W]`` becomes one token
    via a linear map C -> dim. A timestep embedding is added to all tokens.
    The sequence then runs through ``depth`` encoder blocks, a middle block
    and ``depth`` decoder blocks. Decoder block i receives ``x + skip`` from
    the mirrored encoder block. The output has the same shape as ``latents``.

    Args:
        latent_shape: (H, W) of the latents. Only used to size the positional
            embedding when ``pos_emb`` is True.
        in_channels: latent channels C.
        dim: transformer width.
        depth: number of encoder blocks (and of decoder blocks).
        heads: attention heads per block.
        pos_emb: add a learned per-token positional embedding. Defaults to
            False for checkpoint compatibility. Without it the tokens carry
            no spatial position information: attention and the per-token
            layers treat them as an unordered set.
    """

    def __init__(self, latent_shape=(32, 32), in_channels=4, dim=512, depth=6, heads=8, pos_emb=False):
        super().__init__()
        self.latent_shape = latent_shape  # (H, W)
        self.in_channels = in_channels
        self.dim = dim
        self.channels = in_channels  # required by the diffusion wrapper

        # Map each token from its channel dimension
        self.patch_embed = nn.Linear(in_channels, dim)
        self.time_embed = nn.Linear(dim, dim)  # projects the timestep embedding

        self.encoder = nn.ModuleList([UViTBlock(dim, heads) for _ in range(depth)])
        self.mid_block = UViTBlock(dim, heads)
        self.decoder = nn.ModuleList([UViTBlock(dim, heads) for _ in range(depth)])

        self.out = nn.Linear(dim, in_channels)

        if pos_emb:
            num_tokens = latent_shape[0] * latent_shape[1]
            # Zero init: training starts from the position-free model.
            self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, dim))
        else:
            self.pos_embed = None

    def forward(self, latents, t_emb):
        """Predict a tensor shaped like ``latents`` [B, C, H, W].

        ``t_emb`` is either a 1-D tensor of ``B`` raw timesteps (any int or
        float dtype), which is turned into a sinusoidal ``[B, dim]``
        embedding, or an already computed ``[B, dim]`` float embedding.
        """
        B, C, H, W = latents.shape

        if t_emb.ndim == 1:
            t_emb = get_timestep_embedding(t_emb, self.dim).to(latents.dtype)

        # Each spatial location becomes a token: [B, C, H, W] -> [B, H*W, C]
        x = latents.permute(0, 2, 3, 1).reshape(B, H * W, C)
        x = self.patch_embed(x)  # [B, H*W, dim]
        if self.pos_embed is not None:
            x = x + self.pos_embed

        # Same time embedding for every token: [B, dim] -> [B, 1, dim]
        x = x + self.time_embed(t_emb).unsqueeze(1)

        skips = []
        for block in self.encoder:
            x = block(x)
            skips.append(x)

        x = self.mid_block(x)

        # Long skips: the first decoder block pairs with the last encoder block.
        for block, skip in zip(self.decoder, reversed(skips)):
            x = block(x + skip)

        x = self.out(x)  # [B, H*W, C]
        return x.view(B, H, W, C).permute(0, 3, 1, 2)  # [B, C, H, W]

## Diffusion process

In [5]:
# normalization functions
def normalize_to_neg_one_to_one(img):
    """Affinely map values from [0, 1] to [-1, 1]."""
    return 2 * img - 1

def unnormalize_to_zero_to_one(t):
    """Affinely map values from [-1, 1] back to [0, 1]."""
    return 0.5 * (t + 1)

class GaussianDiffusion(nn.Module):
    """DDPM training / DDPM+DDIM sampling wrapper around a noise-prediction network.

    The wrapped ``model`` is called as ``model(x_t, t)`` with a batch of
    integer timesteps ``t``. It must return a prediction of the noise mixed
    into ``x_t``, and it must expose a ``channels`` attribute, which sets the
    channel count of generated samples.

    Args:
        model: the noise-prediction network.
        image_size: side length of the square tensors the model operates on.
            ``forward`` asserts it and ``sample`` generates this size.
        timesteps: number of diffusion steps T.
        beta_schedule: ``'linear'`` or ``'cosine'``.
        auto_normalize: if True, ``forward`` maps inputs from [0, 1] to
            [-1, 1] and the sampling loops map outputs back to [0, 1].
        lambda_l1: stored on the instance but not used by ``p_losses``.

    Raises:
        ValueError: if ``beta_schedule`` is not recognized.
    """

    def __init__(
        self,
        model: nn.Module,
        image_size: int,
        timesteps: int,
        beta_schedule: str = 'linear',
        auto_normalize: bool = True,
        lambda_l1: float = 0.1,
    ):
        super().__init__()

        self.model = model
        self.channels = model.channels
        self.image_size = image_size
        self.num_timesteps = timesteps
        # NOTE: currently unused; p_losses computes the noise MSE only.
        self.lambda_l1 = lambda_l1

        # 1) build beta schedule
        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f'Unknown beta schedule: {beta_schedule}')

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        # 2) per-timestep coefficients as float32 buffers, so they follow the
        #    module across devices and are saved in state_dict
        self.register_buffer('betas', betas.float())
        self.register_buffer('alphas_cumprod', alphas_cumprod.float())
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev.float())
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod).float())
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - alphas_cumprod).float())
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1.0 / alphas_cumprod).float())
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1.0 / alphas_cumprod - 1.0).float())

        # posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance.float())
        # variance is 0 at t = 0, so clamp before taking the log
        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)).float())
        self.register_buffer('posterior_mean_coef1', (betas * torch.sqrt(alphas_cumprod_prev) / (1 - alphas_cumprod)).float())
        self.register_buffer('posterior_mean_coef2', ((1 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1 - alphas_cumprod)).float())

        # 3) Normalization functions
        self.normalize = normalize_to_neg_one_to_one if auto_normalize else (lambda x: x)
        self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else (lambda x: x)

    def dynamic_threshold(self, x0, percentile=0.995):
        """Dynamically threshold a predicted x0 into [-1, 1].

        Per sample, s is the ``percentile`` quantile of |x0|, but at least 1.
        x0 is clamped to [-s, s] and divided by s.

        NOTE(review): this assumes data in roughly [-1, 1]. For VAE latents
        that may be too tight; sample with ``clip_denoised=False`` to skip it.
        """
        s = torch.quantile(x0.abs().flatten(1), percentile, dim=1)
        s = torch.maximum(s, torch.ones_like(s))[:, None, None, None]
        return x0.clamp(-s, s) / s

    def predict_start_from_noise(self, x_t, t, noise):
        """Estimate x_0 from x_t and the predicted noise."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        var = extract(self.posterior_variance, t, x_t.shape)
        log_var = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return mean, var, log_var

    @torch.no_grad()
    def p_sample(self, x, t: int, clip_denoised: bool = True):
        """One ancestral (DDPM) reverse step x_t -> x_{t-1}.

        Args:
            x: current sample x_t.
            t: integer timestep shared by the whole batch.
            clip_denoised: apply ``dynamic_threshold`` to the predicted x0.
        """
        b = x.shape[0]
        t_batch = torch.full((b,), t, device=x.device, dtype=torch.long)
        pred_noise = self.model(x, t_batch)
        x0 = self.predict_start_from_noise(x, t_batch, pred_noise)
        if clip_denoised:
            x0 = self.dynamic_threshold(x0)

        mean, _, log_var = self.q_posterior(x0, x, t_batch)
        # The final step (t == 0) returns the posterior mean without noise.
        noise = torch.randn_like(x) if t > 0 else 0.0
        return mean + torch.exp(0.5 * log_var) * noise

    @torch.no_grad()
    def p_sample_loop(self, shape, return_all_timesteps: bool = False, clip_denoised: bool = True):
        """Run all T ancestral steps starting from Gaussian noise of ``shape``.

        Returns the final (unnormalized) sample. With
        ``return_all_timesteps``, it instead returns all T + 1 states
        (starting with the initial noise) stacked on dim 1.
        """
        img = torch.randn(shape, device=self.betas.device)
        all_imgs = [img]
        for t in tqdm(reversed(range(self.num_timesteps)), desc='sampling'):
            img = self.p_sample(img, t, clip_denoised=clip_denoised)
            if return_all_timesteps:
                all_imgs.append(img)
        out = torch.stack(all_imgs, dim=1) if return_all_timesteps else img
        return self.unnormalize(out)

    @torch.no_grad()
    def p_sample_ddim(self, x, t, t_prev, eta=0.0, clip_denoised=True):
        """One DDIM step from x_t to x_{t_prev}. ``t_prev < 0`` means "to x_0".

        ``eta = 0`` is deterministic DDIM; ``eta = 1`` matches DDPM's
        posterior variance. Returns ``(x_prev, predicted_x0)``.
        """
        b = x.shape[0]
        t_batch = torch.full((b, ), t, device=x.device, dtype=torch.long)
        pred_noise = self.model(x, t_batch)
        x0 = self.predict_start_from_noise(x, t_batch, pred_noise)
        if clip_denoised:
            x0 = self.dynamic_threshold(x0)

        alpha_t = self.alphas_cumprod[t]
        alpha_prev = self.alphas_cumprod[t_prev] if t_prev >= 0 else torch.tensor(1.0, device=x.device)
        sigma_t = eta * ((1 - alpha_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_prev)).sqrt()
        # Non-negative in exact arithmetic; clamp so rounding cannot yield NaN.
        pred_dir = (1 - alpha_prev - sigma_t ** 2).clamp(min=0).sqrt() * pred_noise
        noise = sigma_t * torch.randn_like(x) if t_prev >= 0 else 0.0
        x_prev = alpha_prev.sqrt() * x0 + pred_dir + noise

        return x_prev, x0

    @torch.no_grad()
    def ddim_sample_loop(self, shape, num_ddim_steps=50, eta=0.0, return_all_timesteps=False, clip_denoised=True):
        """Run the full DDIM sampling loop over ``num_ddim_steps`` evenly spaced timesteps."""
        device = self.betas.device
        img = torch.randn(shape, device=device)
        # Only keep intermediates when the caller asked for them.
        all_imgs = [img] if return_all_timesteps else None

        # Evenly spaced integer timesteps in [0, T - 1]
        ddim_timesteps = np.linspace(0, self.num_timesteps - 1, num_ddim_steps, dtype=int)

        for i in tqdm(range(num_ddim_steps - 1, -1, -1), desc='DDIM sampling'):
            t = ddim_timesteps[i]
            t_prev = ddim_timesteps[i - 1] if i > 0 else -1

            img, _ = self.p_sample_ddim(img, t, t_prev, eta=eta, clip_denoised=clip_denoised)
            if return_all_timesteps:
                all_imgs.append(img)

        out = torch.stack(all_imgs, dim=1) if return_all_timesteps else img
        return self.unnormalize(out)

    def sample(self, batch_size=16, use_ddim=False, num_ddim_steps=50, eta=0.0, return_all_timesteps=False, clip_denoised=True):
        """Generate ``batch_size`` samples of shape [channels, image_size, image_size]."""
        shape = (batch_size, self.channels, self.image_size, self.image_size)
        if use_ddim:
            return self.ddim_sample_loop(
                shape, num_ddim_steps=num_ddim_steps, eta=eta,
                return_all_timesteps=return_all_timesteps, clip_denoised=clip_denoised,
            )
        return self.p_sample_loop(shape, return_all_timesteps=return_all_timesteps, clip_denoised=clip_denoised)

    def q_sample(self, x_start, t, noise=None):
        """Forward noising: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise."""
        if noise is None:
            noise = torch.randn_like(x_start)
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, noise=None):
        """MSE between the true noise and the model's prediction.

        The error is averaged over all non-batch dims, then over the batch.
        ``self.lambda_l1`` is not applied.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start, t, noise)
        pred_noise = self.model(x_noisy, t)
        mse_loss = F.mse_loss(pred_noise, noise, reduction='none')
        mse_loss = mse_loss.mean(dim=list(range(1, mse_loss.ndim)))  # per-sample mean
        return mse_loss.mean()

    def forward(self, img):
        """Training entry point: sample random timesteps and return the loss."""
        b, c, h, w = img.shape
        assert h == self.image_size and w == self.image_size
        t = torch.randint(0, self.num_timesteps, (b, ), device=img.device)
        img = self.normalize(img)
        return self.p_losses(img, t)

## Trainer

In [None]:
def num_to_groups(num, divisor):
    """Split ``num`` into chunks of size ``divisor`` plus a final partial chunk.

    >>> num_to_groups(10, 4)
    [4, 4, 2]
    """
    full, rest = divmod(num, divisor)
    tail = [rest] if rest > 0 else []
    return [divisor] * full + tail

def cycle(dl):
    """Yield items from ``dl`` forever, starting a new pass each time it is exhausted.

    ``dl`` must be re-iterable (e.g. a list or a DataLoader), not a one-shot
    iterator.
    """
    while True:
        yield from dl

class Trainer:
    """Training harness for a diffusion model.

    Wraps the model and an Adam optimizer with Hugging Face ``Accelerator``
    (fp16 mixed precision) and keeps an EMA copy of the weights. The learning
    rate is annealed with ``CosineAnnealingLR`` (T_max=num_steps, eta_min=1e-6).
    Every ``save_interval`` optimizer steps, an EMA sample grid and a
    checkpoint are written into ``results_folder``.
    """

    def __init__(
        self,
        diffusion_model: nn.Module,
        data_folder: str,
        batch_size: int = 16,
        lr: float = 1e-4,
        num_steps: int = 100000,
        grad_accum_steps: int = 1,
        ema_decay: float = 0.995,
        save_interval: int = 1000,
        num_samples: int = 25,
        results_folder: str = './results',
        use_ddim = False,
        num_ddim_steps=50,
        eta = 0.0
    ):
        """Build the accelerator, optimizer, LR schedule, EMA and data pipeline.

        Args:
            diffusion_model: module whose ``forward(batch)`` returns the scalar
                training loss and which exposes ``image_size`` and ``sample(...)``.
            data_folder: path passed straight to ``CustomLatentDataset``.
            batch_size: batch size used for training and for sampling.
            lr: initial Adam learning rate (cosine-annealed down to 1e-6).
            num_steps: total optimizer steps; also the cosine schedule length.
            grad_accum_steps: micro-batches accumulated per optimizer step.
            ema_decay: EMA ``beta``.
            save_interval: optimizer steps between sample grids and checkpoints.
            num_samples: number of images in each sample grid.
            results_folder: output directory, created if missing.
            use_ddim, num_ddim_steps, eta: forwarded to ``sample`` when sampling.
        """
        # Accelerator
        self.accelerator = Accelerator(mixed_precision='fp16')
        self.device = self.accelerator.device

        # Training state
        self.batch_size = batch_size
        self.grad_accum_steps = grad_accum_steps
        self.num_steps = num_steps
        self.save_interval = save_interval
        self.num_samples = num_samples

        # Model, optimizer, LR schedule
        self.model = diffusion_model.to(self.device)
        self.optimizer = Adam(self.model.parameters(), lr=lr)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=self.num_steps, eta_min=1e-6)
        self.model, self.optimizer = self.accelerator.prepare(
            self.model, self.optimizer
        )

        # EMA tracks the unwrapped module (without Accelerate/DDP wrappers).
        self.ema = EMA(self.accelerator.unwrap_model(self.model), beta=ema_decay)

        # Data
        # NOTE(review): neither `CustomLatentDataset` nor `vae` is a parameter or
        # attribute here; both must exist as notebook globals before construction.
        ds = CustomLatentDataset(data_folder, diffusion_model.image_size, vae)
        # os.cpu_count() returns None when the CPU count cannot be determined;
        # fall back to in-process loading rather than crashing inside DataLoader.
        dl = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=os.cpu_count() or 0)
        self.dl = cycle(self.accelerator.prepare(dl))

        # Checkpoints & samples
        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(parents=True, exist_ok=True)
        self.step = 0
        self.use_ddim = use_ddim
        self.num_ddim_steps = num_ddim_steps
        self.eta = eta

    def save(self, milestone: int):
        """Save model, optimizer, LR scheduler, EMA and step to ``model-{milestone}.pt``.

        Only the main process writes. The scheduler state is included so that a
        resumed run continues the cosine schedule instead of restarting it.
        """
        if not self.accelerator.is_main_process:
            return
        ckpt = self.results_folder / f'model-{milestone}.pt'
        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.model),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'ema': self.ema.state_dict(),
        }
        torch.save(data, ckpt)

    def load(self, ckpt_path: str):
        """Restore model, optimizer, EMA, step and (if present) LR scheduler state.

        Checkpoints written before the scheduler was saved have no
        ``'scheduler'`` entry. For those, the scheduler keeps its current state.
        """
        data = torch.load(ckpt_path, map_location=self.device)
        raw_model = self.accelerator.unwrap_model(self.model)
        raw_model.load_state_dict(data['model'])
        self.optimizer.load_state_dict(data['optimizer'])
        if 'scheduler' in data:
            self.scheduler.load_state_dict(data['scheduler'])
        self.ema.load_state_dict(data['ema'])
        self.step = data['step']

    def _sample_and_save(self, milestone: int):
        """Generate ``num_samples`` images with the EMA model and save them as one grid."""
        self.ema.ema_model.eval()
        batches = num_to_groups(self.num_samples, self.batch_size)
        if not batches:  # num_samples <= 0: nothing to draw, and torch.cat([]) would fail
            return
        # Sampling never needs gradients; no_grad keeps activation memory bounded.
        with torch.no_grad():
            imgs = torch.cat([
                self.ema.ema_model.sample(
                    batch_size=n,
                    use_ddim=self.use_ddim,
                    num_ddim_steps=self.num_ddim_steps,
                    eta=self.eta,
                )
                for n in batches
            ], dim=0)
        path = self.results_folder / f'sample-{milestone}.png'
        vutils.save_image(imgs, path, nrow=max(1, math.isqrt(self.num_samples)))

    def train(self):
        """Run training until ``self.step`` reaches ``num_steps``.

        Each optimizer step does the following:
        - accumulates gradients over ``grad_accum_steps`` batches;
        - clips the gradient norm to 1.0;
        - steps the optimizer and LR scheduler;
        - updates the EMA (main process only);
        - every ``save_interval`` steps, saves a sample grid and a checkpoint.

        The accumulated loss is printed every 10 steps.
        """
        while self.step < self.num_steps:
            total_loss = 0.0

            # Gradient accumulation: scale each micro-batch loss so the summed
            # gradient matches the average over all micro-batches.
            for _ in range(self.grad_accum_steps):
                batch = next(self.dl).to(self.device)
                with self.accelerator.autocast():
                    loss = self.model(batch) / self.grad_accum_steps
                total_loss += loss.item()
                self.accelerator.backward(loss)

            # Optimizer step
            self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()
            self.step += 1

            # EMA, sampling, checkpointing and logging happen on the main process only.
            if self.accelerator.is_main_process:
                self.ema.update()
                if self.step % self.save_interval == 0:
                    milestone = self.step // self.save_interval
                    self._sample_and_save(milestone)
                    self.save(milestone)
                if self.step % 10 == 0:
                    print(f"[Step {self.step}] loss: {total_loss:.4f}")

        if self.accelerator.is_main_process:
            print('Training complete.')

    def inference(self, total: int = 1000, output_path: str = './submission'):
        """Generate ``total`` images with the EMA model and save them as numbered JPEGs.

        Images are sampled in batches of ``self.batch_size`` and written to
        ``{output_path}/1.jpg`` through ``{output_path}/{total}.jpg``. The CUDA
        cache is cleared before sampling and after each batch.
        """
        Path(output_path).mkdir(parents=True, exist_ok=True)
        ema_model = self.ema.ema_model
        ema_model.eval()
        count = 0
        batch_size = self.batch_size

        def _release_cuda_cache():
            # Drop unreachable Python objects first so their CUDA blocks become freeable.
            if torch.cuda.is_available():
                gc.collect()
                torch.cuda.empty_cache()

        _release_cuda_cache()
        with torch.no_grad():
            while count < total:
                n = min(batch_size, total - count)

                imgs = ema_model.sample(batch_size=n, use_ddim=self.use_ddim, num_ddim_steps=self.num_ddim_steps, eta=self.eta)

                # Move to CPU and save one file per image.
                imgs = imgs.cpu()
                for img in imgs:
                    count += 1
                    vutils.save_image(img, f"{output_path}/{count}.jpg")

                del imgs
                _release_cuda_cache()

        print("Inference complete.")

## Usage

In [8]:
# Collect unreachable Python objects, then release cached CUDA allocator blocks,
# so the model built in the next cell starts from a clean memory state.
gc.collect()
torch.cuda.empty_cache()

In [9]:
# Alternative backbones tried earlier, kept for reference (not used below):
# model = TransformerBackbone(
#     in_channels=3, img_size=16, patch_size=1,
#     emb_dim=1024, depth=8, num_heads=8, time_emb_dim=128
# )

# model = UNetTransformer(
#     in_channels=3,
#     base_channels=64,
#     emb_dim=256,
#     num_heads=4,
#     time_emb_dim=128,
# )

# Backbone actually trained: UViT (defined elsewhere in the notebook). The
# positional args (16, 16) and 3 presumably mean image size and channel count;
# TODO confirm against UViT's signature.
model = UViT((16, 16), 3, dim=256, depth=6, heads=8)

# Diffusion wrapper: 1000 timesteps with a cosine beta schedule.
# NOTE(review): auto_normalize=True presumably rescales inputs from [0, 1] to
# [-1, 1] inside forward(). If the dataset already emits [-1, 1] (CustomDataset's
# Normalize((0.5,)*3, (0.5,)*3) does), the data would be rescaled twice.
# Confirm what CustomLatentDataset returns.
diffusion = GaussianDiffusion(
    model=model, image_size=16, timesteps=1000,
    beta_schedule='cosine', auto_normalize=True
)
# 50k optimizer steps at batch 32 with EMA decay 0.999. A sample grid and a
# checkpoint go to ./results every 1000 steps. use_ddim=False selects the
# full p_sample_loop sampler rather than DDIM.
trainer = Trainer(
    diffusion_model=diffusion,
    data_folder = './sprites.npy',
    batch_size=32, lr=1e-4, num_steps=50000,
    grad_accum_steps=1, ema_decay=0.999,
    save_interval=1000, num_samples=25,
    results_folder='./results',
    use_ddim=False, num_ddim_steps=50, eta=0.0
)


In [10]:
# Report the backbone size in millions of parameters (all parameters, trainable or not).
total_params = sum(map(torch.numel, model.parameters()))
print("Params: ", total_params / 1e6, "M")

Params:  10.334467 M


In [11]:
# Run the full training loop. The loss is printed every 10 steps; a sample grid
# and a checkpoint are written every save_interval steps.
trainer.train()

[Step 10] loss: 5.7415
[Step 20] loss: 2.0250
[Step 30] loss: 1.0543
[Step 40] loss: 0.8597
[Step 50] loss: 0.6562
[Step 60] loss: 0.5841
[Step 70] loss: 0.4760
[Step 80] loss: 0.5589
[Step 90] loss: 0.4919
[Step 100] loss: 0.5668
[Step 110] loss: 0.4235
[Step 120] loss: 0.4329
[Step 130] loss: 0.4620
[Step 140] loss: 0.4182
[Step 150] loss: 0.5452
[Step 160] loss: 0.3483
[Step 170] loss: 0.4825
[Step 180] loss: 0.5201
[Step 190] loss: 0.3940
[Step 200] loss: 0.4752
[Step 210] loss: 0.4016
[Step 220] loss: 0.4636
[Step 230] loss: 0.3973
[Step 240] loss: 0.3456
[Step 250] loss: 0.3587
[Step 260] loss: 0.4227
[Step 270] loss: 0.3488
[Step 280] loss: 0.4658
[Step 290] loss: 0.4055
[Step 300] loss: 0.3921
[Step 310] loss: 0.4518
[Step 320] loss: 0.3281
[Step 330] loss: 0.3635
[Step 340] loss: 0.3296
[Step 350] loss: 0.4560
[Step 360] loss: 0.4826
[Step 370] loss: 0.4448
[Step 380] loss: 0.3343
[Step 390] loss: 0.2936
[Step 400] loss: 0.4284
[Step 410] loss: 0.2878
[Step 420] loss: 0.2962
[

sampling: 0it [00:00, ?it/s]

[Step 1000] loss: 0.2902
[Step 1010] loss: 0.2847
[Step 1020] loss: 0.3017
[Step 1030] loss: 0.2664
[Step 1040] loss: 0.2650
[Step 1050] loss: 0.2778
[Step 1060] loss: 0.2289
[Step 1070] loss: 0.2946
[Step 1080] loss: 0.2845
[Step 1090] loss: 0.2071
[Step 1100] loss: 0.2505
[Step 1110] loss: 0.2999
[Step 1120] loss: 0.3292
[Step 1130] loss: 0.2806
[Step 1140] loss: 0.3170
[Step 1150] loss: 0.2414
[Step 1160] loss: 0.2475
[Step 1170] loss: 0.3011
[Step 1180] loss: 0.2982
[Step 1190] loss: 0.2678
[Step 1200] loss: 0.2122
[Step 1210] loss: 0.2548
[Step 1220] loss: 0.2431
[Step 1230] loss: 0.2515
[Step 1240] loss: 0.2530
[Step 1250] loss: 0.2582
[Step 1260] loss: 0.2410
[Step 1270] loss: 0.2247
[Step 1280] loss: 0.2010
[Step 1290] loss: 0.2267
[Step 1300] loss: 0.2834
[Step 1310] loss: 0.2206
[Step 1320] loss: 0.2566
[Step 1330] loss: 0.2398
[Step 1340] loss: 0.2713
[Step 1350] loss: 0.2447
[Step 1360] loss: 0.2195
[Step 1370] loss: 0.2673
[Step 1380] loss: 0.2251
[Step 1390] loss: 0.2077


sampling: 0it [00:00, ?it/s]

[Step 2000] loss: 0.2010
[Step 2010] loss: 0.2711
[Step 2020] loss: 0.2977
[Step 2030] loss: 0.2233
[Step 2040] loss: 0.1921
[Step 2050] loss: 0.2432
[Step 2060] loss: 0.2433
[Step 2070] loss: 0.1935
[Step 2080] loss: 0.2157
[Step 2090] loss: 0.2113
[Step 2100] loss: 0.2520
[Step 2110] loss: 0.2082
[Step 2120] loss: 0.2330
[Step 2130] loss: 0.1976
[Step 2140] loss: 0.2190
[Step 2150] loss: 0.2361
[Step 2160] loss: 0.2676
[Step 2170] loss: 0.2452
[Step 2180] loss: 0.2267
[Step 2190] loss: 0.2390
[Step 2200] loss: 0.2157
[Step 2210] loss: 0.1870
[Step 2220] loss: 0.1963
[Step 2230] loss: 0.2199
[Step 2240] loss: 0.2371
[Step 2250] loss: 0.2463
[Step 2260] loss: 0.2572
[Step 2270] loss: 0.2103
[Step 2280] loss: 0.2244
[Step 2290] loss: 0.2763
[Step 2300] loss: 0.2367
[Step 2310] loss: 0.2051
[Step 2320] loss: 0.2498
[Step 2330] loss: 0.2073
[Step 2340] loss: 0.2441
[Step 2350] loss: 0.2211
[Step 2360] loss: 0.1999
[Step 2370] loss: 0.2429
[Step 2380] loss: 0.2169
[Step 2390] loss: 0.2669


sampling: 0it [00:00, ?it/s]

[Step 3000] loss: 0.2235
[Step 3010] loss: 0.2086
[Step 3020] loss: 0.2234
[Step 3030] loss: 0.2116
[Step 3040] loss: 0.1970
[Step 3050] loss: 0.2232
[Step 3060] loss: 0.2236
[Step 3070] loss: 0.2430
[Step 3080] loss: 0.2201
[Step 3090] loss: 0.2199
[Step 3100] loss: 0.2413
[Step 3110] loss: 0.2271
[Step 3120] loss: 0.2242
[Step 3130] loss: 0.2147
[Step 3140] loss: 0.2286
[Step 3150] loss: 0.2190
[Step 3160] loss: 0.2549
[Step 3170] loss: 0.2185
[Step 3180] loss: 0.2329
[Step 3190] loss: 0.2082
[Step 3200] loss: 0.2467
[Step 3210] loss: 0.2664
[Step 3220] loss: 0.2183
[Step 3230] loss: 0.2250
[Step 3240] loss: 0.2171
[Step 3250] loss: 0.1922
[Step 3260] loss: 0.2394
[Step 3270] loss: 0.1834
[Step 3280] loss: 0.2305
[Step 3290] loss: 0.1938
[Step 3300] loss: 0.2424
[Step 3310] loss: 0.2228
[Step 3320] loss: 0.2144
[Step 3330] loss: 0.2198
[Step 3340] loss: 0.1746
[Step 3350] loss: 0.2164
[Step 3360] loss: 0.2456
[Step 3370] loss: 0.2231
[Step 3380] loss: 0.1944
[Step 3390] loss: 0.2566


sampling: 0it [00:00, ?it/s]

[Step 4000] loss: 0.2274
[Step 4010] loss: 0.2523
[Step 4020] loss: 0.1663
[Step 4030] loss: 0.1955
[Step 4040] loss: 0.2114
[Step 4050] loss: 0.1635
[Step 4060] loss: 0.2218
[Step 4070] loss: 0.2134
[Step 4080] loss: 0.1745
[Step 4090] loss: 0.2384
[Step 4100] loss: 0.2042
[Step 4110] loss: 0.1963
[Step 4120] loss: 0.2257
[Step 4130] loss: 0.2337
[Step 4140] loss: 0.2074
[Step 4150] loss: 0.2015
[Step 4160] loss: 0.2109
[Step 4170] loss: 0.2502
[Step 4180] loss: 0.2024
[Step 4190] loss: 0.1879
[Step 4200] loss: 0.2378
[Step 4210] loss: 0.2090
[Step 4220] loss: 0.2556
[Step 4230] loss: 0.2219
[Step 4240] loss: 0.1908
[Step 4250] loss: 0.2445
[Step 4260] loss: 0.2461
[Step 4270] loss: 0.2402
[Step 4280] loss: 0.2137
[Step 4290] loss: 0.2067
[Step 4300] loss: 0.1925
[Step 4310] loss: 0.2667
[Step 4320] loss: 0.2096
[Step 4330] loss: 0.2184
[Step 4340] loss: 0.2414
[Step 4350] loss: 0.2414
[Step 4360] loss: 0.2474
[Step 4370] loss: 0.2078
[Step 4380] loss: 0.1993
[Step 4390] loss: 0.1639


sampling: 0it [00:00, ?it/s]

[Step 5000] loss: 0.1739
[Step 5010] loss: 0.2154
[Step 5020] loss: 0.2350
[Step 5030] loss: 0.2218
[Step 5040] loss: 0.1653
[Step 5050] loss: 0.2117
[Step 5060] loss: 0.2121
[Step 5070] loss: 0.2310
[Step 5080] loss: 0.2248
[Step 5090] loss: 0.1979
[Step 5100] loss: 0.1833
[Step 5110] loss: 0.2209
[Step 5120] loss: 0.1725
[Step 5130] loss: 0.1971
[Step 5140] loss: 0.2336
[Step 5150] loss: 0.2182
[Step 5160] loss: 0.2410
[Step 5170] loss: 0.2024
[Step 5180] loss: 0.2083
[Step 5190] loss: 0.2252
[Step 5200] loss: 0.1975
[Step 5210] loss: 0.1926
[Step 5220] loss: 0.2271
[Step 5230] loss: 0.2275
[Step 5240] loss: 0.1943
[Step 5250] loss: 0.2094
[Step 5260] loss: 0.2388
[Step 5270] loss: 0.1995
[Step 5280] loss: 0.2022
[Step 5290] loss: 0.2516
[Step 5300] loss: 0.2279
[Step 5310] loss: 0.2325
[Step 5320] loss: 0.2113
[Step 5330] loss: 0.1925
[Step 5340] loss: 0.2168
[Step 5350] loss: 0.1933
[Step 5360] loss: 0.1906
[Step 5370] loss: 0.2108
[Step 5380] loss: 0.2147
[Step 5390] loss: 0.2210


sampling: 0it [00:00, ?it/s]

[Step 6000] loss: 0.2130
[Step 6010] loss: 0.1665
[Step 6020] loss: 0.2055
[Step 6030] loss: 0.2222
[Step 6040] loss: 0.2064
[Step 6050] loss: 0.1973
[Step 6060] loss: 0.2113
[Step 6070] loss: 0.2056
[Step 6080] loss: 0.2526
[Step 6090] loss: 0.1966
[Step 6100] loss: 0.2348
[Step 6110] loss: 0.2067
[Step 6120] loss: 0.2399
[Step 6130] loss: 0.1940
[Step 6140] loss: 0.2605
[Step 6150] loss: 0.2292
[Step 6160] loss: 0.2322
[Step 6170] loss: 0.2277
[Step 6180] loss: 0.2460
[Step 6190] loss: 0.1865
[Step 6200] loss: 0.2047
[Step 6210] loss: 0.2319
[Step 6220] loss: 0.2022
[Step 6230] loss: 0.2365
[Step 6240] loss: 0.2117
[Step 6250] loss: 0.2491
[Step 6260] loss: 0.2385
[Step 6270] loss: 0.1809
[Step 6280] loss: 0.2161
[Step 6290] loss: 0.1846
[Step 6300] loss: 0.2056
[Step 6310] loss: 0.2144
[Step 6320] loss: 0.2114
[Step 6330] loss: 0.2322
[Step 6340] loss: 0.1996
[Step 6350] loss: 0.1754
[Step 6360] loss: 0.2466
[Step 6370] loss: 0.2166
[Step 6380] loss: 0.2102
[Step 6390] loss: 0.2084


sampling: 0it [00:00, ?it/s]

[Step 7000] loss: 0.1981
[Step 7010] loss: 0.2310
[Step 7020] loss: 0.1878
[Step 7030] loss: 0.1679
[Step 7040] loss: 0.2005
[Step 7050] loss: 0.1932
[Step 7060] loss: 0.1814
[Step 7070] loss: 0.2296
[Step 7080] loss: 0.2159
[Step 7090] loss: 0.2114
[Step 7100] loss: 0.2235
[Step 7110] loss: 0.2298
[Step 7120] loss: 0.1971
[Step 7130] loss: 0.1894
[Step 7140] loss: 0.2040
[Step 7150] loss: 0.1999
[Step 7160] loss: 0.1933
[Step 7170] loss: 0.1840
[Step 7180] loss: 0.1757
[Step 7190] loss: 0.1867
[Step 7200] loss: 0.2335
[Step 7210] loss: 0.2172
[Step 7220] loss: 0.2222
[Step 7230] loss: 0.2504
[Step 7240] loss: 0.2171
[Step 7250] loss: 0.2070
[Step 7260] loss: 0.1990
[Step 7270] loss: 0.1877
[Step 7280] loss: 0.2232
[Step 7290] loss: 0.2101
[Step 7300] loss: 0.2293
[Step 7310] loss: 0.1837
[Step 7320] loss: 0.2097
[Step 7330] loss: 0.1983
[Step 7340] loss: 0.2141
[Step 7350] loss: 0.2468
[Step 7360] loss: 0.2270
[Step 7370] loss: 0.2265
[Step 7380] loss: 0.2305
[Step 7390] loss: 0.2205


sampling: 0it [00:00, ?it/s]

[Step 8000] loss: 0.2111
[Step 8010] loss: 0.2099
[Step 8020] loss: 0.2138
[Step 8030] loss: 0.2163
[Step 8040] loss: 0.2168
[Step 8050] loss: 0.2252
[Step 8060] loss: 0.2001
[Step 8070] loss: 0.2129
[Step 8080] loss: 0.1983
[Step 8090] loss: 0.1854
[Step 8100] loss: 0.2035
[Step 8110] loss: 0.2342
[Step 8120] loss: 0.2130
[Step 8130] loss: 0.1772
[Step 8140] loss: 0.1920
[Step 8150] loss: 0.2062
[Step 8160] loss: 0.2418
[Step 8170] loss: 0.2540
[Step 8180] loss: 0.2004
[Step 8190] loss: 0.1782
[Step 8200] loss: 0.1990
[Step 8210] loss: 0.2217
[Step 8220] loss: 0.1983
[Step 8230] loss: 0.2008
[Step 8240] loss: 0.2047
[Step 8250] loss: 0.2310
[Step 8260] loss: 0.2051
[Step 8270] loss: 0.2017
[Step 8280] loss: 0.2108
[Step 8290] loss: 0.2120
[Step 8300] loss: 0.2109
[Step 8310] loss: 0.1910
[Step 8320] loss: 0.2086
[Step 8330] loss: 0.2002
[Step 8340] loss: 0.2133
[Step 8350] loss: 0.2035
[Step 8360] loss: 0.2066
[Step 8370] loss: 0.2354
[Step 8380] loss: 0.1683
[Step 8390] loss: 0.2356


sampling: 0it [00:00, ?it/s]

[Step 9000] loss: 0.2233
[Step 9010] loss: 0.2342
[Step 9020] loss: 0.2302
[Step 9030] loss: 0.1857
[Step 9040] loss: 0.1879
[Step 9050] loss: 0.2117
[Step 9060] loss: 0.1908
[Step 9070] loss: 0.1723
[Step 9080] loss: 0.1930
[Step 9090] loss: 0.2190
[Step 9100] loss: 0.2253
[Step 9110] loss: 0.1800
[Step 9120] loss: 0.2260
[Step 9130] loss: 0.2350
[Step 9140] loss: 0.2433
[Step 9150] loss: 0.2215
[Step 9160] loss: 0.1874
[Step 9170] loss: 0.1750
[Step 9180] loss: 0.2130
[Step 9190] loss: 0.1805
[Step 9200] loss: 0.2042
[Step 9210] loss: 0.1553
[Step 9220] loss: 0.2223
[Step 9230] loss: 0.2027
[Step 9240] loss: 0.1784
[Step 9250] loss: 0.2299
[Step 9260] loss: 0.1898
[Step 9270] loss: 0.2120
[Step 9280] loss: 0.2131
[Step 9290] loss: 0.2280
[Step 9300] loss: 0.1912
[Step 9310] loss: 0.1846
[Step 9320] loss: 0.1840
[Step 9330] loss: 0.1983
[Step 9340] loss: 0.2326
[Step 9350] loss: 0.1687
[Step 9360] loss: 0.2458
[Step 9370] loss: 0.2088
[Step 9380] loss: 0.2849
[Step 9390] loss: 0.2239


sampling: 0it [00:00, ?it/s]

[Step 10000] loss: 0.1938
[Step 10010] loss: 0.2245
[Step 10020] loss: 0.2270
[Step 10030] loss: 0.2125
[Step 10040] loss: 0.2255
[Step 10050] loss: 0.1881
[Step 10060] loss: 0.1901
[Step 10070] loss: 0.1991
[Step 10080] loss: 0.2143
[Step 10090] loss: 0.1833
[Step 10100] loss: 0.1989
[Step 10110] loss: 0.2019
[Step 10120] loss: 0.1644
[Step 10130] loss: 0.1932
[Step 10140] loss: 0.1509
[Step 10150] loss: 0.1988
[Step 10160] loss: 0.1792
[Step 10170] loss: 0.2070
[Step 10180] loss: 0.2036
[Step 10190] loss: 0.2394
[Step 10200] loss: 0.2202
[Step 10210] loss: 0.2035
[Step 10220] loss: 0.1824
[Step 10230] loss: 0.1658
[Step 10240] loss: 0.2104
[Step 10250] loss: 0.2207
[Step 10260] loss: 0.2306
[Step 10270] loss: 0.2038
[Step 10280] loss: 0.1992
[Step 10290] loss: 0.2193
[Step 10300] loss: 0.1857
[Step 10310] loss: 0.1959
[Step 10320] loss: 0.1908
[Step 10330] loss: 0.2246
[Step 10340] loss: 0.2106
[Step 10350] loss: 0.1917
[Step 10360] loss: 0.2177
[Step 10370] loss: 0.2129
[Step 10380]

sampling: 0it [00:00, ?it/s]

[Step 11000] loss: 0.2023
[Step 11010] loss: 0.2038
[Step 11020] loss: 0.2225
[Step 11030] loss: 0.1845
[Step 11040] loss: 0.1837
[Step 11050] loss: 0.2137
[Step 11060] loss: 0.2425
[Step 11070] loss: 0.1991
[Step 11080] loss: 0.2108
[Step 11090] loss: 0.2181
[Step 11100] loss: 0.2301
[Step 11110] loss: 0.1985
[Step 11120] loss: 0.1741
[Step 11130] loss: 0.1803
[Step 11140] loss: 0.2276
[Step 11150] loss: 0.2361
[Step 11160] loss: 0.1843
[Step 11170] loss: 0.2544
[Step 11180] loss: 0.2030
[Step 11190] loss: 0.2074
[Step 11200] loss: 0.1704
[Step 11210] loss: 0.1793
[Step 11220] loss: 0.2074
[Step 11230] loss: 0.1756
[Step 11240] loss: 0.2248
[Step 11250] loss: 0.1949
[Step 11260] loss: 0.1939
[Step 11270] loss: 0.2183
[Step 11280] loss: 0.2320
[Step 11290] loss: 0.2139
[Step 11300] loss: 0.1698
[Step 11310] loss: 0.2141
[Step 11320] loss: 0.1924
[Step 11330] loss: 0.2031
[Step 11340] loss: 0.2154
[Step 11350] loss: 0.2438
[Step 11360] loss: 0.2083
[Step 11370] loss: 0.2019
[Step 11380]

sampling: 0it [00:00, ?it/s]

[Step 12000] loss: 0.2575
[Step 12010] loss: 0.1861
[Step 12020] loss: 0.2024
[Step 12030] loss: 0.2082
[Step 12040] loss: 0.2367
[Step 12050] loss: 0.2204
[Step 12060] loss: 0.2108
[Step 12070] loss: 0.2038
[Step 12080] loss: 0.2148
[Step 12090] loss: 0.2009
[Step 12100] loss: 0.2299
[Step 12110] loss: 0.1688
[Step 12120] loss: 0.2063
[Step 12130] loss: 0.2495
[Step 12140] loss: 0.2198
[Step 12150] loss: 0.2545
[Step 12160] loss: 0.2417
[Step 12170] loss: 0.2119
[Step 12180] loss: 0.1778
[Step 12190] loss: 0.1944
[Step 12200] loss: 0.2023
[Step 12210] loss: 0.1966
[Step 12220] loss: 0.2265
[Step 12230] loss: 0.1873
[Step 12240] loss: 0.2131
[Step 12250] loss: 0.2073
[Step 12260] loss: 0.1970
[Step 12270] loss: 0.1879
[Step 12280] loss: 0.1794
[Step 12290] loss: 0.2243
[Step 12300] loss: 0.2134
[Step 12310] loss: 0.2081
[Step 12320] loss: 0.2463
[Step 12330] loss: 0.1919
[Step 12340] loss: 0.2037
[Step 12350] loss: 0.1913
[Step 12360] loss: 0.2228
[Step 12370] loss: 0.2351
[Step 12380]

sampling: 0it [00:00, ?it/s]

[Step 13000] loss: 0.1913
[Step 13010] loss: 0.2104
[Step 13020] loss: 0.2220
[Step 13030] loss: 0.2030
[Step 13040] loss: 0.2163
[Step 13050] loss: 0.1683
[Step 13060] loss: 0.1713
[Step 13070] loss: 0.1830
[Step 13080] loss: 0.2139
[Step 13090] loss: 0.1944
[Step 13100] loss: 0.1707
[Step 13110] loss: 0.1864
[Step 13120] loss: 0.2117
[Step 13130] loss: 0.1856
[Step 13140] loss: 0.1754
[Step 13150] loss: 0.2062
[Step 13160] loss: 0.2199
[Step 13170] loss: 0.2206
[Step 13180] loss: 0.2019
[Step 13190] loss: 0.1764
[Step 13200] loss: 0.2133
[Step 13210] loss: 0.1758
[Step 13220] loss: 0.1966
[Step 13230] loss: 0.2194
[Step 13240] loss: 0.2187
[Step 13250] loss: 0.1730
[Step 13260] loss: 0.2126
[Step 13270] loss: 0.2029
[Step 13280] loss: 0.2442
[Step 13290] loss: 0.2106
[Step 13300] loss: 0.2019
[Step 13310] loss: 0.2085
[Step 13320] loss: 0.2012
[Step 13330] loss: 0.2272
[Step 13340] loss: 0.1960
[Step 13350] loss: 0.2057
[Step 13360] loss: 0.1935
[Step 13370] loss: 0.1834
[Step 13380]

sampling: 0it [00:00, ?it/s]

[Step 14000] loss: 0.2245
[Step 14010] loss: 0.2498
[Step 14020] loss: 0.2000
[Step 14030] loss: 0.2625
[Step 14040] loss: 0.2042
[Step 14050] loss: 0.1971
[Step 14060] loss: 0.2122
[Step 14070] loss: 0.2197
[Step 14080] loss: 0.2121
[Step 14090] loss: 0.1873
[Step 14100] loss: 0.1770
[Step 14110] loss: 0.1748
[Step 14120] loss: 0.2321
[Step 14130] loss: 0.2127
[Step 14140] loss: 0.2081
[Step 14150] loss: 0.2079
[Step 14160] loss: 0.1948
[Step 14170] loss: 0.1941
[Step 14180] loss: 0.1958
[Step 14190] loss: 0.1934
[Step 14200] loss: 0.1996
[Step 14210] loss: 0.1920
[Step 14220] loss: 0.2157
[Step 14230] loss: 0.2200
[Step 14240] loss: 0.1950
[Step 14250] loss: 0.2209
[Step 14260] loss: 0.1932
[Step 14270] loss: 0.2087
[Step 14280] loss: 0.2277
[Step 14290] loss: 0.2753
[Step 14300] loss: 0.2412
[Step 14310] loss: 0.2001
[Step 14320] loss: 0.2007
[Step 14330] loss: 0.1973
[Step 14340] loss: 0.2120
[Step 14350] loss: 0.1815
[Step 14360] loss: 0.1809
[Step 14370] loss: 0.2187
[Step 14380]

sampling: 0it [00:00, ?it/s]

[Step 15000] loss: 0.1997
[Step 15010] loss: 0.2045
[Step 15020] loss: 0.2004
[Step 15030] loss: 0.1997
[Step 15040] loss: 0.2062
[Step 15050] loss: 0.2322
[Step 15060] loss: 0.2437
[Step 15070] loss: 0.1980
[Step 15080] loss: 0.1939
[Step 15090] loss: 0.2125
[Step 15100] loss: 0.2043
[Step 15110] loss: 0.1883
[Step 15120] loss: 0.1988
[Step 15130] loss: 0.2348
[Step 15140] loss: 0.2217
[Step 15150] loss: 0.1606
[Step 15160] loss: 0.2266
[Step 15170] loss: 0.1873
[Step 15180] loss: 0.1984
[Step 15190] loss: 0.2079
[Step 15200] loss: 0.2147
[Step 15210] loss: 0.2118
[Step 15220] loss: 0.1969
[Step 15230] loss: 0.2302
[Step 15240] loss: 0.2267
[Step 15250] loss: 0.1785
[Step 15260] loss: 0.2345
[Step 15270] loss: 0.1882
[Step 15280] loss: 0.2438
[Step 15290] loss: 0.2069
[Step 15300] loss: 0.2098
[Step 15310] loss: 0.1963
[Step 15320] loss: 0.2277
[Step 15330] loss: 0.2412
[Step 15340] loss: 0.2350
[Step 15350] loss: 0.2633
[Step 15360] loss: 0.2184
[Step 15370] loss: 0.2107
[Step 15380]

sampling: 0it [00:00, ?it/s]

[Step 16000] loss: 0.2036
[Step 16010] loss: 0.1642
[Step 16020] loss: 0.2223
[Step 16030] loss: 0.2087
[Step 16040] loss: 0.2153
[Step 16050] loss: 0.2375
[Step 16060] loss: 0.2191
[Step 16070] loss: 0.2226
[Step 16080] loss: 0.2285
[Step 16090] loss: 0.2098
[Step 16100] loss: 0.2266
[Step 16110] loss: 0.1616
[Step 16120] loss: 0.1941
[Step 16130] loss: 0.2309
[Step 16140] loss: 0.1988
[Step 16150] loss: 0.1993
[Step 16160] loss: 0.2182
[Step 16170] loss: 0.1811
[Step 16180] loss: 0.2026
[Step 16190] loss: 0.2129
[Step 16200] loss: 0.2032
[Step 16210] loss: 0.2135
[Step 16220] loss: 0.1522
[Step 16230] loss: 0.1871
[Step 16240] loss: 0.2495
[Step 16250] loss: 0.2207
[Step 16260] loss: 0.2257
[Step 16270] loss: 0.1896
[Step 16280] loss: 0.1793
[Step 16290] loss: 0.2097
[Step 16300] loss: 0.2073
[Step 16310] loss: 0.2134
[Step 16320] loss: 0.2358
[Step 16330] loss: 0.2155
[Step 16340] loss: 0.2037
[Step 16350] loss: 0.2338
[Step 16360] loss: 0.1903
[Step 16370] loss: 0.1897
[Step 16380]

sampling: 0it [00:00, ?it/s]

[Step 17000] loss: 0.2124
[Step 17010] loss: 0.1867
[Step 17020] loss: 0.2181
[Step 17030] loss: 0.1974
[Step 17040] loss: 0.2096
[Step 17050] loss: 0.2077
[Step 17060] loss: 0.2143
[Step 17070] loss: 0.1877
[Step 17080] loss: 0.2318
[Step 17090] loss: 0.1784
[Step 17100] loss: 0.2507
[Step 17110] loss: 0.2072
[Step 17120] loss: 0.1975
[Step 17130] loss: 0.2104
[Step 17140] loss: 0.1812
[Step 17150] loss: 0.1810
[Step 17160] loss: 0.1752
[Step 17170] loss: 0.2373
[Step 17180] loss: 0.2253
[Step 17190] loss: 0.2453
[Step 17200] loss: 0.2173
[Step 17210] loss: 0.2577
[Step 17220] loss: 0.2234
[Step 17230] loss: 0.2202
[Step 17240] loss: 0.1768
[Step 17250] loss: 0.2054
[Step 17260] loss: 0.2087
[Step 17270] loss: 0.2228
[Step 17280] loss: 0.1727
[Step 17290] loss: 0.2077
[Step 17300] loss: 0.2192
[Step 17310] loss: 0.2365
[Step 17320] loss: 0.2102
[Step 17330] loss: 0.2088
[Step 17340] loss: 0.2274
[Step 17350] loss: 0.2135
[Step 17360] loss: 0.1840
[Step 17370] loss: 0.1962
[Step 17380]

sampling: 0it [00:00, ?it/s]

[Step 18000] loss: 0.1746
[Step 18010] loss: 0.2440
[Step 18020] loss: 0.2226
[Step 18030] loss: 0.1951
[Step 18040] loss: 0.2274
[Step 18050] loss: 0.2301
[Step 18060] loss: 0.2004
[Step 18070] loss: 0.1839
[Step 18080] loss: 0.2405
[Step 18090] loss: 0.2017
[Step 18100] loss: 0.2236
[Step 18110] loss: 0.2178
[Step 18120] loss: 0.1979
[Step 18130] loss: 0.2102
[Step 18140] loss: 0.2246
[Step 18150] loss: 0.1699
[Step 18160] loss: 0.2136
[Step 18170] loss: 0.2295
[Step 18180] loss: 0.1988
[Step 18190] loss: 0.2158
[Step 18200] loss: 0.2245
[Step 18210] loss: 0.2037
[Step 18220] loss: 0.2199
[Step 18230] loss: 0.2260
[Step 18240] loss: 0.1769
[Step 18250] loss: 0.1505
[Step 18260] loss: 0.2133
[Step 18270] loss: 0.2232
[Step 18280] loss: 0.2169
[Step 18290] loss: 0.2064
[Step 18300] loss: 0.2060
[Step 18310] loss: 0.1954
[Step 18320] loss: 0.1639
[Step 18330] loss: 0.2061
[Step 18340] loss: 0.1907
[Step 18350] loss: 0.1979
[Step 18360] loss: 0.2208
[Step 18370] loss: 0.1814
[Step 18380]

sampling: 0it [00:00, ?it/s]

[Step 19000] loss: 0.1834
[Step 19010] loss: 0.1962
[Step 19020] loss: 0.2081
[Step 19030] loss: 0.2109
[Step 19040] loss: 0.1962
[Step 19050] loss: 0.1999
[Step 19060] loss: 0.2409
[Step 19070] loss: 0.2219
[Step 19080] loss: 0.2075
[Step 19090] loss: 0.2018
[Step 19100] loss: 0.2385
[Step 19110] loss: 0.2049
[Step 19120] loss: 0.1714
[Step 19130] loss: 0.2089
[Step 19140] loss: 0.1726
[Step 19150] loss: 0.2294
[Step 19160] loss: 0.1961
[Step 19170] loss: 0.2515
[Step 19180] loss: 0.1959
[Step 19190] loss: 0.2119
[Step 19200] loss: 0.2039
[Step 19210] loss: 0.1966
[Step 19220] loss: 0.1832
[Step 19230] loss: 0.2219
[Step 19240] loss: 0.2232
[Step 19250] loss: 0.2307
[Step 19260] loss: 0.1810
[Step 19270] loss: 0.2314
[Step 19280] loss: 0.1704
[Step 19290] loss: 0.1935
[Step 19300] loss: 0.2137
[Step 19310] loss: 0.2039
[Step 19320] loss: 0.2003
[Step 19330] loss: 0.2319
[Step 19340] loss: 0.1854
[Step 19350] loss: 0.2143
[Step 19360] loss: 0.2104
[Step 19370] loss: 0.1949
[Step 19380]

sampling: 0it [00:00, ?it/s]

[Step 20000] loss: 0.2299
[Step 20010] loss: 0.2030
[Step 20020] loss: 0.1963
[Step 20030] loss: 0.2318
[Step 20040] loss: 0.1678
[Step 20050] loss: 0.1778
[Step 20060] loss: 0.1745
[Step 20070] loss: 0.1955
[Step 20080] loss: 0.1902
[Step 20090] loss: 0.2245
[Step 20100] loss: 0.2019
[Step 20110] loss: 0.2310
[Step 20120] loss: 0.1835
[Step 20130] loss: 0.1616
[Step 20140] loss: 0.2261
[Step 20150] loss: 0.1998
[Step 20160] loss: 0.1742
[Step 20170] loss: 0.2049
[Step 20180] loss: 0.1952
[Step 20190] loss: 0.2030
[Step 20200] loss: 0.2751
[Step 20210] loss: 0.2255
[Step 20220] loss: 0.2104
[Step 20230] loss: 0.2067
[Step 20240] loss: 0.1840
[Step 20250] loss: 0.2303
[Step 20260] loss: 0.2080
[Step 20270] loss: 0.2121
[Step 20280] loss: 0.2041
[Step 20290] loss: 0.1993
[Step 20300] loss: 0.2149
[Step 20310] loss: 0.1897
[Step 20320] loss: 0.1957
[Step 20330] loss: 0.2184
[Step 20340] loss: 0.2018
[Step 20350] loss: 0.2060
[Step 20360] loss: 0.2174
[Step 20370] loss: 0.1981
[Step 20380]

sampling: 0it [00:00, ?it/s]

[Step 21000] loss: 0.1997
[Step 21010] loss: 0.2042
[Step 21020] loss: 0.2070
[Step 21030] loss: 0.2076
[Step 21040] loss: 0.2086
[Step 21050] loss: 0.1857
[Step 21060] loss: 0.2065
[Step 21070] loss: 0.2334
[Step 21080] loss: 0.1894
[Step 21090] loss: 0.2024
[Step 21100] loss: 0.2195
[Step 21110] loss: 0.2026
[Step 21120] loss: 0.1784
[Step 21130] loss: 0.1886
[Step 21140] loss: 0.2048
[Step 21150] loss: 0.1752
[Step 21160] loss: 0.2057
[Step 21170] loss: 0.1801
[Step 21180] loss: 0.1974
[Step 21190] loss: 0.2121
[Step 21200] loss: 0.1832
[Step 21210] loss: 0.2159
[Step 21220] loss: 0.2272
[Step 21230] loss: 0.1924
[Step 21240] loss: 0.1917
[Step 21250] loss: 0.2347
[Step 21260] loss: 0.1904
[Step 21270] loss: 0.2098
[Step 21280] loss: 0.2061
[Step 21290] loss: 0.2301
[Step 21300] loss: 0.2027
[Step 21310] loss: 0.2477
[Step 21320] loss: 0.2141
[Step 21330] loss: 0.2192
[Step 21340] loss: 0.2143
[Step 21350] loss: 0.2195
[Step 21360] loss: 0.1898
[Step 21370] loss: 0.2323
[Step 21380]

sampling: 0it [00:00, ?it/s]

[Step 22000] loss: 0.1984
[Step 22010] loss: 0.1619
[Step 22020] loss: 0.2136
[Step 22030] loss: 0.1557
[Step 22040] loss: 0.2042
[Step 22050] loss: 0.1992
[Step 22060] loss: 0.2181
[Step 22070] loss: 0.2138
[Step 22080] loss: 0.1652
[Step 22090] loss: 0.1955
[Step 22100] loss: 0.1856
[Step 22110] loss: 0.2223
[Step 22120] loss: 0.2396
[Step 22130] loss: 0.1989
[Step 22140] loss: 0.2115
[Step 22150] loss: 0.2123
[Step 22160] loss: 0.2054
[Step 22170] loss: 0.2296
[Step 22180] loss: 0.1911
[Step 22190] loss: 0.2346
[Step 22200] loss: 0.1690
[Step 22210] loss: 0.1898
[Step 22220] loss: 0.1657
[Step 22230] loss: 0.1978
[Step 22240] loss: 0.2370
[Step 22250] loss: 0.2133
[Step 22260] loss: 0.2088
[Step 22270] loss: 0.1858
[Step 22280] loss: 0.1977
[Step 22290] loss: 0.1961
[Step 22300] loss: 0.1658
[Step 22310] loss: 0.2055
[Step 22320] loss: 0.1886
[Step 22330] loss: 0.2190
[Step 22340] loss: 0.2346
[Step 22350] loss: 0.1898
[Step 22360] loss: 0.2092
[Step 22370] loss: 0.1930
[Step 22380]

sampling: 0it [00:00, ?it/s]

[Step 23000] loss: 0.1853
[Step 23010] loss: 0.1775
[Step 23020] loss: 0.2132
[Step 23030] loss: 0.1658
[Step 23040] loss: 0.1836
[Step 23050] loss: 0.1948
[Step 23060] loss: 0.1924
[Step 23070] loss: 0.2090
[Step 23080] loss: 0.2443
[Step 23090] loss: 0.2026
[Step 23100] loss: 0.2423
[Step 23110] loss: 0.2034
[Step 23120] loss: 0.2014
[Step 23130] loss: 0.1818
[Step 23140] loss: 0.2248
[Step 23150] loss: 0.2215
[Step 23160] loss: 0.1971
[Step 23170] loss: 0.2674
[Step 23180] loss: 0.1715
[Step 23190] loss: 0.2171
[Step 23200] loss: 0.2165
[Step 23210] loss: 0.2154
[Step 23220] loss: 0.1758
[Step 23230] loss: 0.2102
[Step 23240] loss: 0.1865
[Step 23250] loss: 0.2261
[Step 23260] loss: 0.2186
[Step 23270] loss: 0.2183
[Step 23280] loss: 0.1920
[Step 23290] loss: 0.2159
[Step 23300] loss: 0.2340
[Step 23310] loss: 0.1825
[Step 23320] loss: 0.1836
[Step 23330] loss: 0.2175
[Step 23340] loss: 0.2108
[Step 23350] loss: 0.2017
[Step 23360] loss: 0.2288
[Step 23370] loss: 0.2326
[Step 23380]

sampling: 0it [00:00, ?it/s]

[Step 24000] loss: 0.2281
[Step 24010] loss: 0.1822
[Step 24020] loss: 0.1814
[Step 24030] loss: 0.2143
[Step 24040] loss: 0.1946
[Step 24050] loss: 0.2058
[Step 24060] loss: 0.2390
[Step 24070] loss: 0.1831
[Step 24080] loss: 0.2447
[Step 24090] loss: 0.1689
[Step 24100] loss: 0.1757
[Step 24110] loss: 0.2101
[Step 24120] loss: 0.2104
[Step 24130] loss: 0.2049
[Step 24140] loss: 0.2167
[Step 24150] loss: 0.1787
[Step 24160] loss: 0.2167
[Step 24170] loss: 0.2209
[Step 24180] loss: 0.2125
[Step 24190] loss: 0.2246
[Step 24200] loss: 0.1627
[Step 24210] loss: 0.1941
[Step 24220] loss: 0.2170
[Step 24230] loss: 0.2108
[Step 24240] loss: 0.2144
[Step 24250] loss: 0.2123
[Step 24260] loss: 0.2080
[Step 24270] loss: 0.1640
[Step 24280] loss: 0.2241
[Step 24290] loss: 0.2289
[Step 24300] loss: 0.1971
[Step 24310] loss: 0.2286
[Step 24320] loss: 0.2090
[Step 24330] loss: 0.1919
[Step 24340] loss: 0.1950
[Step 24350] loss: 0.1986
[Step 24360] loss: 0.2094
[Step 24370] loss: 0.2114
[Step 24380]

sampling: 0it [00:00, ?it/s]

[Step 25000] loss: 0.2236
[Step 25010] loss: 0.2110
[Step 25020] loss: 0.1822
[Step 25030] loss: 0.1799
[Step 25040] loss: 0.1995
[Step 25050] loss: 0.1744
[Step 25060] loss: 0.1769
[Step 25070] loss: 0.2093
[Step 25080] loss: 0.2133
[Step 25090] loss: 0.1794
[Step 25100] loss: 0.2042
[Step 25110] loss: 0.2257
[Step 25120] loss: 0.1694
[Step 25130] loss: 0.1876
[Step 25140] loss: 0.2096
[Step 25150] loss: 0.2287
[Step 25160] loss: 0.2115
[Step 25170] loss: 0.1846
[Step 25180] loss: 0.2126
[Step 25190] loss: 0.2051
[Step 25200] loss: 0.1877
[Step 25210] loss: 0.2085
[Step 25220] loss: 0.1760
[Step 25230] loss: 0.2064
[Step 25240] loss: 0.2017
[Step 25250] loss: 0.1691
[Step 25260] loss: 0.1819
[Step 25270] loss: 0.1976
[Step 25280] loss: 0.1852
[Step 25290] loss: 0.1980
[Step 25300] loss: 0.2028
[Step 25310] loss: 0.1909
[Step 25320] loss: 0.2139
[Step 25330] loss: 0.1734
[Step 25340] loss: 0.1631
[Step 25350] loss: 0.1770
[Step 25360] loss: 0.1897
[Step 25370] loss: 0.1859
[Step 25380]

sampling: 0it [00:00, ?it/s]

[Step 26000] loss: 0.2205
[Step 26010] loss: 0.2023
[Step 26020] loss: 0.1954
[Step 26030] loss: 0.1679
[Step 26040] loss: 0.2085
[Step 26050] loss: 0.2278
[Step 26060] loss: 0.2218
[Step 26070] loss: 0.2114
[Step 26080] loss: 0.2075
[Step 26090] loss: 0.2118
[Step 26100] loss: 0.2112
[Step 26110] loss: 0.2178
[Step 26120] loss: 0.1928
[Step 26130] loss: 0.1923
[Step 26140] loss: 0.1895
[Step 26150] loss: 0.1761
[Step 26160] loss: 0.1798
[Step 26170] loss: 0.2278
[Step 26180] loss: 0.2205
[Step 26190] loss: 0.2004
[Step 26200] loss: 0.1907
[Step 26210] loss: 0.2314
[Step 26220] loss: 0.1919
[Step 26230] loss: 0.1834
[Step 26240] loss: 0.2292
[Step 26250] loss: 0.1862
[Step 26260] loss: 0.2032
[Step 26270] loss: 0.1802
[Step 26280] loss: 0.2234
[Step 26290] loss: 0.2103
[Step 26300] loss: 0.2029
[Step 26310] loss: 0.2064
[Step 26320] loss: 0.1662
[Step 26330] loss: 0.1995
[Step 26340] loss: 0.2281
[Step 26350] loss: 0.1723
[Step 26360] loss: 0.2213
[Step 26370] loss: 0.2532
[Step 26380]

sampling: 0it [00:00, ?it/s]

[Step 27000] loss: 0.1732
[Step 27010] loss: 0.1893
[Step 27020] loss: 0.2472
[Step 27030] loss: 0.2171
[Step 27040] loss: 0.2561
[Step 27050] loss: 0.2354
[Step 27060] loss: 0.1945
[Step 27070] loss: 0.2037
[Step 27080] loss: 0.1813
[Step 27090] loss: 0.2024
[Step 27100] loss: 0.2466
[Step 27110] loss: 0.2204
[Step 27120] loss: 0.2106
[Step 27130] loss: 0.1752
[Step 27140] loss: 0.2283
[Step 27150] loss: 0.1919
[Step 27160] loss: 0.2208
[Step 27170] loss: 0.1834
[Step 27180] loss: 0.1949
[Step 27190] loss: 0.2265
[Step 27200] loss: 0.2216
[Step 27210] loss: 0.2192
[Step 27220] loss: 0.1799
[Step 27230] loss: 0.1995
[Step 27240] loss: 0.2192
[Step 27250] loss: 0.2015
[Step 27260] loss: 0.2161
[Step 27270] loss: 0.1794
[Step 27280] loss: 0.2119
[Step 27290] loss: 0.2314
[Step 27300] loss: 0.1655
[Step 27310] loss: 0.2069
[Step 27320] loss: 0.1977
[Step 27330] loss: 0.1749
[Step 27340] loss: 0.1861
[Step 27350] loss: 0.1871
[Step 27360] loss: 0.1949
[Step 27370] loss: 0.1935
[Step 27380]

sampling: 0it [00:00, ?it/s]

[Step 28000] loss: 0.1986
[Step 28010] loss: 0.2364
[Step 28020] loss: 0.1889
[Step 28030] loss: 0.1933
[Step 28040] loss: 0.1992
[Step 28050] loss: 0.2248
[Step 28060] loss: 0.2049
[Step 28070] loss: 0.1870
[Step 28080] loss: 0.2065
[Step 28090] loss: 0.2061
[Step 28100] loss: 0.2278
[Step 28110] loss: 0.2373
[Step 28120] loss: 0.2221
[Step 28130] loss: 0.2074
[Step 28140] loss: 0.1574
[Step 28150] loss: 0.1802
[Step 28160] loss: 0.1677
[Step 28170] loss: 0.2199
[Step 28180] loss: 0.2129
[Step 28190] loss: 0.1843
[Step 28200] loss: 0.1912
[Step 28210] loss: 0.2512
[Step 28220] loss: 0.1907
[Step 28230] loss: 0.2162
[Step 28240] loss: 0.1635
[Step 28250] loss: 0.2273
[Step 28260] loss: 0.1797
[Step 28270] loss: 0.2418
[Step 28280] loss: 0.2326
[Step 28290] loss: 0.2464
[Step 28300] loss: 0.2295
[Step 28310] loss: 0.1820
[Step 28320] loss: 0.1981
[Step 28330] loss: 0.1975
[Step 28340] loss: 0.2025
[Step 28350] loss: 0.2018
[Step 28360] loss: 0.1735
[Step 28370] loss: 0.1837
[Step 28380]

sampling: 0it [00:00, ?it/s]

[Step 29000] loss: 0.2129
[Step 29010] loss: 0.2241
[Step 29020] loss: 0.1807
[Step 29030] loss: 0.2111
[Step 29040] loss: 0.2094
[Step 29050] loss: 0.2016
[Step 29060] loss: 0.1794
[Step 29070] loss: 0.1786
[Step 29080] loss: 0.2160
[Step 29090] loss: 0.1757
[Step 29100] loss: 0.1941
[Step 29110] loss: 0.1744
[Step 29120] loss: 0.2060
[Step 29130] loss: 0.1967
[Step 29140] loss: 0.1681
[Step 29150] loss: 0.2239
[Step 29160] loss: 0.1900
[Step 29170] loss: 0.2035
[Step 29180] loss: 0.2032
[Step 29190] loss: 0.2258
[Step 29200] loss: 0.2269
[Step 29210] loss: 0.2247
[Step 29220] loss: 0.2191
[Step 29230] loss: 0.1794
[Step 29240] loss: 0.1888
[Step 29250] loss: 0.2325
[Step 29260] loss: 0.2105
[Step 29270] loss: 0.2148
[Step 29280] loss: 0.2210
[Step 29290] loss: 0.2056
[Step 29300] loss: 0.2062
[Step 29310] loss: 0.1937
[Step 29320] loss: 0.1946
[Step 29330] loss: 0.2314
[Step 29340] loss: 0.2036
[Step 29350] loss: 0.2320
[Step 29360] loss: 0.2233
[Step 29370] loss: 0.1987
[Step 29380]

sampling: 0it [00:00, ?it/s]

[Step 30000] loss: 0.2009
[Step 30010] loss: 0.2159
[Step 30020] loss: 0.1991
[Step 30030] loss: 0.1965
[Step 30040] loss: 0.1971
[Step 30050] loss: 0.2465
[Step 30060] loss: 0.2039
[Step 30070] loss: 0.2046
[Step 30080] loss: 0.2071
[Step 30090] loss: 0.1777
[Step 30100] loss: 0.2292
[Step 30110] loss: 0.2488
[Step 30120] loss: 0.2110
[Step 30130] loss: 0.2071
[Step 30140] loss: 0.2006
[Step 30150] loss: 0.1993
[Step 30160] loss: 0.2412
[Step 30170] loss: 0.2219
[Step 30180] loss: 0.2073
[Step 30190] loss: 0.2127
[Step 30200] loss: 0.2052
[Step 30210] loss: 0.2017
[Step 30220] loss: 0.2109
[Step 30230] loss: 0.2294
[Step 30240] loss: 0.2080
[Step 30250] loss: 0.2186
[Step 30260] loss: 0.2205
[Step 30270] loss: 0.2173
[Step 30280] loss: 0.1940
[Step 30290] loss: 0.2082
[Step 30300] loss: 0.2410
[Step 30310] loss: 0.1950
[Step 30320] loss: 0.1960
[Step 30330] loss: 0.2232
[Step 30340] loss: 0.2137
[Step 30350] loss: 0.2059
[Step 30360] loss: 0.1855
[Step 30370] loss: 0.2010
[Step 30380]

sampling: 0it [00:00, ?it/s]

[Step 31000] loss: 0.1844
[Step 31010] loss: 0.2342
[Step 31020] loss: 0.1603
[Step 31030] loss: 0.2026
[Step 31040] loss: 0.2151
[Step 31050] loss: 0.1870
[Step 31060] loss: 0.2130
[Step 31070] loss: 0.2096
[Step 31080] loss: 0.2008
[Step 31090] loss: 0.2227
[Step 31100] loss: 0.1980
[Step 31110] loss: 0.2136
[Step 31120] loss: 0.1540
[Step 31130] loss: 0.2135
[Step 31140] loss: 0.1737
[Step 31150] loss: 0.1993
[Step 31160] loss: 0.2091
[Step 31170] loss: 0.2261
[Step 31180] loss: 0.2033
[Step 31190] loss: 0.1959
[Step 31200] loss: 0.1745
[Step 31210] loss: 0.1942
[Step 31220] loss: 0.2099
[Step 31230] loss: 0.2207
[Step 31240] loss: 0.2098
[Step 31250] loss: 0.2044
[Step 31260] loss: 0.2052
[Step 31270] loss: 0.2051
[Step 31280] loss: 0.1586
[Step 31290] loss: 0.1908
[Step 31300] loss: 0.2414
[Step 31310] loss: 0.2020
[Step 31320] loss: 0.2004
[Step 31330] loss: 0.1510
[Step 31340] loss: 0.1956
[Step 31350] loss: 0.1945
[Step 31360] loss: 0.2094
[Step 31370] loss: 0.2012
[Step 31380]

sampling: 0it [00:00, ?it/s]

[Step 32000] loss: 0.1987
[Step 32010] loss: 0.2045
[Step 32020] loss: 0.2192
[Step 32030] loss: 0.2110
[Step 32040] loss: 0.2186
[Step 32050] loss: 0.1677
[Step 32060] loss: 0.2066
[Step 32070] loss: 0.2094
[Step 32080] loss: 0.1665
[Step 32090] loss: 0.2308
[Step 32100] loss: 0.2040
[Step 32110] loss: 0.1718
[Step 32120] loss: 0.1975
[Step 32130] loss: 0.2062
[Step 32140] loss: 0.2165
[Step 32150] loss: 0.1814
[Step 32160] loss: 0.2150
[Step 32170] loss: 0.2104
[Step 32180] loss: 0.1900
[Step 32190] loss: 0.1855
[Step 32200] loss: 0.1889
[Step 32210] loss: 0.1955
[Step 32220] loss: 0.2181
[Step 32230] loss: 0.2074
[Step 32240] loss: 0.1971
[Step 32250] loss: 0.1619
[Step 32260] loss: 0.2180
[Step 32270] loss: 0.2118
[Step 32280] loss: 0.1569
[Step 32290] loss: 0.1838
[Step 32300] loss: 0.2108
[Step 32310] loss: 0.2031
[Step 32320] loss: 0.1957
[Step 32330] loss: 0.2110
[Step 32340] loss: 0.1815
[Step 32350] loss: 0.2229
[Step 32360] loss: 0.2079
[Step 32370] loss: 0.1920
[Step 32380]

sampling: 0it [00:00, ?it/s]

[Step 33000] loss: 0.2249
[Step 33010] loss: 0.1816
[Step 33020] loss: 0.2159
[Step 33030] loss: 0.2612
[Step 33040] loss: 0.1876
[Step 33050] loss: 0.1553
[Step 33060] loss: 0.1564
[Step 33070] loss: 0.2076
[Step 33080] loss: 0.2075
[Step 33090] loss: 0.2062
[Step 33100] loss: 0.1700
[Step 33110] loss: 0.2428
[Step 33120] loss: 0.1734
[Step 33130] loss: 0.2012
[Step 33140] loss: 0.2350
[Step 33150] loss: 0.2126
[Step 33160] loss: 0.2045
[Step 33170] loss: 0.2060
[Step 33180] loss: 0.1793
[Step 33190] loss: 0.2226
[Step 33200] loss: 0.1980
[Step 33210] loss: 0.2119
[Step 33220] loss: 0.2163
[Step 33230] loss: 0.2182
[Step 33240] loss: 0.2205
[Step 33250] loss: 0.2317
[Step 33260] loss: 0.1928
[Step 33270] loss: 0.1835
[Step 33280] loss: 0.1939
[Step 33290] loss: 0.2036
[Step 33300] loss: 0.2126
[Step 33310] loss: 0.2289
[Step 33320] loss: 0.2333
[Step 33330] loss: 0.2104
[Step 33340] loss: 0.1696
[Step 33350] loss: 0.1874
[Step 33360] loss: 0.2117
[Step 33370] loss: 0.2201
[Step 33380]

sampling: 0it [00:00, ?it/s]

[Step 34000] loss: 0.2229
[Step 34010] loss: 0.1990
[Step 34020] loss: 0.2045
[Step 34030] loss: 0.1906
[Step 34040] loss: 0.1956
[Step 34050] loss: 0.2005
[Step 34060] loss: 0.2020
[Step 34070] loss: 0.2000
[Step 34080] loss: 0.1973
[Step 34090] loss: 0.1596
[Step 34100] loss: 0.2045
[Step 34110] loss: 0.2283
[Step 34120] loss: 0.2232
[Step 34130] loss: 0.1934
[Step 34140] loss: 0.2211
[Step 34150] loss: 0.1677
[Step 34160] loss: 0.1974
[Step 34170] loss: 0.2310
[Step 34180] loss: 0.1932
[Step 34190] loss: 0.1998
[Step 34200] loss: 0.1668
[Step 34210] loss: 0.1753
[Step 34220] loss: 0.2261
[Step 34230] loss: 0.1796
[Step 34240] loss: 0.2459
[Step 34250] loss: 0.2112
[Step 34260] loss: 0.2371
[Step 34270] loss: 0.1660
[Step 34280] loss: 0.2387
[Step 34290] loss: 0.2203
[Step 34300] loss: 0.1963
[Step 34310] loss: 0.2007
[Step 34320] loss: 0.2151
[Step 34330] loss: 0.2606
[Step 34340] loss: 0.2097
[Step 34350] loss: 0.2002
[Step 34360] loss: 0.1770
[Step 34370] loss: 0.2019
[Step 34380]

sampling: 0it [00:00, ?it/s]

[Step 35000] loss: 0.2250
[Step 35010] loss: 0.2070
[Step 35020] loss: 0.1709
[Step 35030] loss: 0.2047
[Step 35040] loss: 0.2176
[Step 35050] loss: 0.1922
[Step 35060] loss: 0.2078
[Step 35070] loss: 0.1732
[Step 35080] loss: 0.2337
[Step 35090] loss: 0.2158
[Step 35100] loss: 0.2359
[Step 35110] loss: 0.1749
[Step 35120] loss: 0.2148
[Step 35130] loss: 0.2017
[Step 35140] loss: 0.1601
[Step 35150] loss: 0.2220
[Step 35160] loss: 0.2261
[Step 35170] loss: 0.2173
[Step 35180] loss: 0.2277
[Step 35190] loss: 0.2020
[Step 35200] loss: 0.2051
[Step 35210] loss: 0.1573
[Step 35220] loss: 0.2000
[Step 35230] loss: 0.1675
[Step 35240] loss: 0.1926
[Step 35250] loss: 0.1962
[Step 35260] loss: 0.2056
[Step 35270] loss: 0.1931
[Step 35280] loss: 0.1947
[Step 35290] loss: 0.1932
[Step 35300] loss: 0.2590
[Step 35310] loss: 0.2118
[Step 35320] loss: 0.2103
[Step 35330] loss: 0.1806
[Step 35340] loss: 0.1834
[Step 35350] loss: 0.1857
[Step 35360] loss: 0.2067
[Step 35370] loss: 0.1747
[Step 35380]

sampling: 0it [00:00, ?it/s]

[Step 36000] loss: 0.2116
[Step 36010] loss: 0.1929
[Step 36020] loss: 0.1907
[Step 36030] loss: 0.1739
[Step 36040] loss: 0.2009
[Step 36050] loss: 0.2109
[Step 36060] loss: 0.1676
[Step 36070] loss: 0.1898
[Step 36080] loss: 0.1737
[Step 36090] loss: 0.2226
[Step 36100] loss: 0.2077
[Step 36110] loss: 0.1892
[Step 36120] loss: 0.2349
[Step 36130] loss: 0.1969
[Step 36140] loss: 0.2203
[Step 36150] loss: 0.1973
[Step 36160] loss: 0.1984
[Step 36170] loss: 0.1974
[Step 36180] loss: 0.1742
[Step 36190] loss: 0.1910
[Step 36200] loss: 0.1639
[Step 36210] loss: 0.2238
[Step 36220] loss: 0.2233
[Step 36230] loss: 0.2237
[Step 36240] loss: 0.2319
[Step 36250] loss: 0.1795
[Step 36260] loss: 0.1791
[Step 36270] loss: 0.2238
[Step 36280] loss: 0.2172
[Step 36290] loss: 0.1876
[Step 36300] loss: 0.1862
[Step 36310] loss: 0.2158
[Step 36320] loss: 0.1823
[Step 36330] loss: 0.2188
[Step 36340] loss: 0.2341
[Step 36350] loss: 0.1931
[Step 36360] loss: 0.1926
[Step 36370] loss: 0.1836
[Step 36380]

sampling: 0it [00:00, ?it/s]

[Step 37000] loss: 0.1828
[Step 37010] loss: 0.1831
[Step 37020] loss: 0.2004
[Step 37030] loss: 0.2223
[Step 37040] loss: 0.1655
[Step 37050] loss: 0.2580
[Step 37060] loss: 0.2074
[Step 37070] loss: 0.1762
[Step 37080] loss: 0.2298
[Step 37090] loss: 0.2039
[Step 37100] loss: 0.2271
[Step 37110] loss: 0.1960
[Step 37120] loss: 0.2175
[Step 37130] loss: 0.2325
[Step 37140] loss: 0.1561
[Step 37150] loss: 0.2073
[Step 37160] loss: 0.1884
[Step 37170] loss: 0.2170
[Step 37180] loss: 0.1538
[Step 37190] loss: 0.2169
[Step 37200] loss: 0.1745
[Step 37210] loss: 0.1723
[Step 37220] loss: 0.1708
[Step 37230] loss: 0.2231
[Step 37240] loss: 0.2227
[Step 37250] loss: 0.1933
[Step 37260] loss: 0.1822
[Step 37270] loss: 0.1685
[Step 37280] loss: 0.1952
[Step 37290] loss: 0.1956
[Step 37300] loss: 0.1855
[Step 37310] loss: 0.2230
[Step 37320] loss: 0.2019
[Step 37330] loss: 0.2016
[Step 37340] loss: 0.2048
[Step 37350] loss: 0.2094
[Step 37360] loss: 0.1654
[Step 37370] loss: 0.1862
[Step 37380]

sampling: 0it [00:00, ?it/s]

[Step 38000] loss: 0.1882
[Step 38010] loss: 0.1547
[Step 38020] loss: 0.2370
[Step 38030] loss: 0.2379
[Step 38040] loss: 0.2435
[Step 38050] loss: 0.2038
[Step 38060] loss: 0.1999
[Step 38070] loss: 0.1967
[Step 38080] loss: 0.2318
[Step 38090] loss: 0.2215
[Step 38100] loss: 0.1872
[Step 38110] loss: 0.1956
[Step 38120] loss: 0.1890
[Step 38130] loss: 0.1628
[Step 38140] loss: 0.2050
[Step 38150] loss: 0.2202
[Step 38160] loss: 0.1689
[Step 38170] loss: 0.2138
[Step 38180] loss: 0.1712
[Step 38190] loss: 0.2314
[Step 38200] loss: 0.1935
[Step 38210] loss: 0.1965
[Step 38220] loss: 0.2509
[Step 38230] loss: 0.2055
[Step 38240] loss: 0.2107
[Step 38250] loss: 0.1824
[Step 38260] loss: 0.1932
[Step 38270] loss: 0.1689
[Step 38280] loss: 0.1984
[Step 38290] loss: 0.2079
[Step 38300] loss: 0.2019
[Step 38310] loss: 0.2036
[Step 38320] loss: 0.2148
[Step 38330] loss: 0.2091
[Step 38340] loss: 0.2065
[Step 38350] loss: 0.2422
[Step 38360] loss: 0.1708
[Step 38370] loss: 0.1752
[Step 38380]

sampling: 0it [00:00, ?it/s]

[Step 39000] loss: 0.1844
[Step 39010] loss: 0.1889
[Step 39020] loss: 0.1912
[Step 39030] loss: 0.1809
[Step 39040] loss: 0.2254
[Step 39050] loss: 0.2119
[Step 39060] loss: 0.1924
[Step 39070] loss: 0.2377
[Step 39080] loss: 0.1946
[Step 39090] loss: 0.2588
[Step 39100] loss: 0.1785
[Step 39110] loss: 0.1521
[Step 39120] loss: 0.1682
[Step 39130] loss: 0.1998
[Step 39140] loss: 0.2143
[Step 39150] loss: 0.2324
[Step 39160] loss: 0.1734
[Step 39170] loss: 0.2295
[Step 39180] loss: 0.2045
[Step 39190] loss: 0.1995
[Step 39200] loss: 0.2023
[Step 39210] loss: 0.2111
[Step 39220] loss: 0.2015
[Step 39230] loss: 0.1848
[Step 39240] loss: 0.1982
[Step 39250] loss: 0.1940
[Step 39260] loss: 0.2034
[Step 39270] loss: 0.2022
[Step 39280] loss: 0.2187
[Step 39290] loss: 0.1918
[Step 39300] loss: 0.1708
[Step 39310] loss: 0.1618
[Step 39320] loss: 0.1950
[Step 39330] loss: 0.2174
[Step 39340] loss: 0.2088
[Step 39350] loss: 0.1592
[Step 39360] loss: 0.2050
[Step 39370] loss: 0.1867
[Step 39380]

sampling: 0it [00:00, ?it/s]

[Step 40000] loss: 0.1848
[Step 40010] loss: 0.2123
[Step 40020] loss: 0.2234
[Step 40030] loss: 0.1574
[Step 40040] loss: 0.1667
[Step 40050] loss: 0.2153
[Step 40060] loss: 0.2033
[Step 40070] loss: 0.2161
[Step 40080] loss: 0.1945
[Step 40090] loss: 0.2028
[Step 40100] loss: 0.1778
[Step 40110] loss: 0.1937
[Step 40120] loss: 0.2216
[Step 40130] loss: 0.1804
[Step 40140] loss: 0.1902
[Step 40150] loss: 0.1872
[Step 40160] loss: 0.1950
[Step 40170] loss: 0.1750
[Step 40180] loss: 0.1965
[Step 40190] loss: 0.2223
[Step 40200] loss: 0.1993
[Step 40210] loss: 0.2043
[Step 40220] loss: 0.2295
[Step 40230] loss: 0.2252
[Step 40240] loss: 0.1893
[Step 40250] loss: 0.1738
[Step 40260] loss: 0.1984
[Step 40270] loss: 0.2007
[Step 40280] loss: 0.2082
[Step 40290] loss: 0.2325
[Step 40300] loss: 0.1752
[Step 40310] loss: 0.1865
[Step 40320] loss: 0.2302
[Step 40330] loss: 0.2521
[Step 40340] loss: 0.2022
[Step 40350] loss: 0.2354
[Step 40360] loss: 0.2170
[Step 40370] loss: 0.2104
[Step 40380]

sampling: 0it [00:00, ?it/s]

[Step 41000] loss: 0.2268
[Step 41010] loss: 0.2327
[Step 41020] loss: 0.2067
[Step 41030] loss: 0.1851
[Step 41040] loss: 0.1940
[Step 41050] loss: 0.2133
[Step 41060] loss: 0.1880
[Step 41070] loss: 0.1939
[Step 41080] loss: 0.1997
[Step 41090] loss: 0.1855
[Step 41100] loss: 0.1696
[Step 41110] loss: 0.2405
[Step 41120] loss: 0.1913
[Step 41130] loss: 0.1917
[Step 41140] loss: 0.1783
[Step 41150] loss: 0.2064
[Step 41160] loss: 0.2503
[Step 41170] loss: 0.2055
[Step 41180] loss: 0.2278
[Step 41190] loss: 0.1739
[Step 41200] loss: 0.1910
[Step 41210] loss: 0.2149
[Step 41220] loss: 0.2156
[Step 41230] loss: 0.2068
[Step 41240] loss: 0.1892
[Step 41250] loss: 0.1478
[Step 41260] loss: 0.2163
[Step 41270] loss: 0.2153
[Step 41280] loss: 0.2120
[Step 41290] loss: 0.1930
[Step 41300] loss: 0.1963
[Step 41310] loss: 0.2075
[Step 41320] loss: 0.1711
[Step 41330] loss: 0.1728
[Step 41340] loss: 0.2308
[Step 41350] loss: 0.2051
[Step 41360] loss: 0.1864
[Step 41370] loss: 0.1994
[Step 41380]

sampling: 0it [00:00, ?it/s]

[Step 42000] loss: 0.1888
[Step 42010] loss: 0.2009
[Step 42020] loss: 0.2473
[Step 42030] loss: 0.1925
[Step 42040] loss: 0.2001
[Step 42050] loss: 0.2204
[Step 42060] loss: 0.1720
[Step 42070] loss: 0.1704
[Step 42080] loss: 0.2013
[Step 42090] loss: 0.1938
[Step 42100] loss: 0.2028
[Step 42110] loss: 0.2181
[Step 42120] loss: 0.2025
[Step 42130] loss: 0.2004
[Step 42140] loss: 0.1808
[Step 42150] loss: 0.2042
[Step 42160] loss: 0.1666
[Step 42170] loss: 0.2153
[Step 42180] loss: 0.2137
[Step 42190] loss: 0.2101
[Step 42200] loss: 0.1891
[Step 42210] loss: 0.2035
[Step 42220] loss: 0.2410
[Step 42230] loss: 0.1705
[Step 42240] loss: 0.2111
[Step 42250] loss: 0.1851
[Step 42260] loss: 0.1666
[Step 42270] loss: 0.2103
[Step 42280] loss: 0.1562
[Step 42290] loss: 0.2105
[Step 42300] loss: 0.2399
[Step 42310] loss: 0.2410
[Step 42320] loss: 0.1818
[Step 42330] loss: 0.1734
[Step 42340] loss: 0.2193
[Step 42350] loss: 0.2009
[Step 42360] loss: 0.2016
[Step 42370] loss: 0.1959
[Step 42380]

sampling: 0it [00:00, ?it/s]

[Step 43000] loss: 0.1921
[Step 43010] loss: 0.2285
[Step 43020] loss: 0.1996
[Step 43030] loss: 0.1777
[Step 43040] loss: 0.2077
[Step 43050] loss: 0.1923
[Step 43060] loss: 0.1991
[Step 43070] loss: 0.1877
[Step 43080] loss: 0.2197
[Step 43090] loss: 0.2388
[Step 43100] loss: 0.2227
[Step 43110] loss: 0.1983
[Step 43120] loss: 0.1753
[Step 43130] loss: 0.2132
[Step 43140] loss: 0.1889
[Step 43150] loss: 0.2120
[Step 43160] loss: 0.2048
[Step 43170] loss: 0.2667
[Step 43180] loss: 0.1927
[Step 43190] loss: 0.1666
[Step 43200] loss: 0.2037
[Step 43210] loss: 0.1580
[Step 43220] loss: 0.1998
[Step 43230] loss: 0.1683
[Step 43240] loss: 0.1947
[Step 43250] loss: 0.2045
[Step 43260] loss: 0.2087
[Step 43270] loss: 0.2285
[Step 43280] loss: 0.2349
[Step 43290] loss: 0.2292
[Step 43300] loss: 0.2116
[Step 43310] loss: 0.2031
[Step 43320] loss: 0.2145
[Step 43330] loss: 0.2169
[Step 43340] loss: 0.1920
[Step 43350] loss: 0.1538
[Step 43360] loss: 0.2274
[Step 43370] loss: 0.1922
[Step 43380]

sampling: 0it [00:00, ?it/s]

[Step 44000] loss: 0.2126
[Step 44010] loss: 0.2220
[Step 44020] loss: 0.1629
[Step 44030] loss: 0.1966
[Step 44040] loss: 0.2050
[Step 44050] loss: 0.1978
[Step 44060] loss: 0.2165
[Step 44070] loss: 0.2023
[Step 44080] loss: 0.1861
[Step 44090] loss: 0.2081
[Step 44100] loss: 0.1917
[Step 44110] loss: 0.1943
[Step 44120] loss: 0.2170
[Step 44130] loss: 0.2152
[Step 44140] loss: 0.1690
[Step 44150] loss: 0.2088
[Step 44160] loss: 0.2215
[Step 44170] loss: 0.1936
[Step 44180] loss: 0.1754
[Step 44190] loss: 0.1787
[Step 44200] loss: 0.2431
[Step 44210] loss: 0.2275
[Step 44220] loss: 0.2101
[Step 44230] loss: 0.2259
[Step 44240] loss: 0.1921
[Step 44250] loss: 0.1961
[Step 44260] loss: 0.2377
[Step 44270] loss: 0.2123
[Step 44280] loss: 0.1859
[Step 44290] loss: 0.2456
[Step 44300] loss: 0.2084
[Step 44310] loss: 0.2171
[Step 44320] loss: 0.1760
[Step 44330] loss: 0.2418
[Step 44340] loss: 0.2034
[Step 44350] loss: 0.1582
[Step 44360] loss: 0.2030
[Step 44370] loss: 0.1930
[Step 44380]

sampling: 0it [00:00, ?it/s]

[Step 45000] loss: 0.2133
[Step 45010] loss: 0.2465
[Step 45020] loss: 0.2099
[Step 45030] loss: 0.2077
[Step 45040] loss: 0.2047
[Step 45050] loss: 0.2364
[Step 45060] loss: 0.2129
[Step 45070] loss: 0.2083
[Step 45080] loss: 0.1692
[Step 45090] loss: 0.1844
[Step 45100] loss: 0.1799
[Step 45110] loss: 0.1858
[Step 45120] loss: 0.1706
[Step 45130] loss: 0.2118
[Step 45140] loss: 0.1862
[Step 45150] loss: 0.2149
[Step 45160] loss: 0.1576
[Step 45170] loss: 0.1591
[Step 45180] loss: 0.1981
[Step 45190] loss: 0.2136
[Step 45200] loss: 0.2097
[Step 45210] loss: 0.2537
[Step 45220] loss: 0.2126
[Step 45230] loss: 0.1883
[Step 45240] loss: 0.1677
[Step 45250] loss: 0.1837
[Step 45260] loss: 0.2072
[Step 45270] loss: 0.2353
[Step 45280] loss: 0.1673
[Step 45290] loss: 0.1926
[Step 45300] loss: 0.2028
[Step 45310] loss: 0.1865
[Step 45320] loss: 0.2075
[Step 45330] loss: 0.1845
[Step 45340] loss: 0.1823
[Step 45350] loss: 0.1913
[Step 45360] loss: 0.1762
[Step 45370] loss: 0.2060
[Step 45380]

sampling: 0it [00:00, ?it/s]

[Step 46000] loss: 0.1764
[Step 46010] loss: 0.1963
[Step 46020] loss: 0.2142
[Step 46030] loss: 0.2405
[Step 46040] loss: 0.2270
[Step 46050] loss: 0.2002
[Step 46060] loss: 0.2168
[Step 46070] loss: 0.1879
[Step 46080] loss: 0.1810
[Step 46090] loss: 0.2048
[Step 46100] loss: 0.2282
[Step 46110] loss: 0.1933
[Step 46120] loss: 0.1961
[Step 46130] loss: 0.1810
[Step 46140] loss: 0.1958
[Step 46150] loss: 0.2138
[Step 46160] loss: 0.1934
[Step 46170] loss: 0.2108
[Step 46180] loss: 0.2071
[Step 46190] loss: 0.1811
[Step 46200] loss: 0.2041
[Step 46210] loss: 0.2100
[Step 46220] loss: 0.2016
[Step 46230] loss: 0.2073
[Step 46240] loss: 0.2365
[Step 46250] loss: 0.1729
[Step 46260] loss: 0.1944
[Step 46270] loss: 0.2346
[Step 46280] loss: 0.2295
[Step 46290] loss: 0.2111
[Step 46300] loss: 0.1696
[Step 46310] loss: 0.2099
[Step 46320] loss: 0.1972
[Step 46330] loss: 0.2009
[Step 46340] loss: 0.1760
[Step 46350] loss: 0.2203
[Step 46360] loss: 0.2364
[Step 46370] loss: 0.1931
[Step 46380]

sampling: 0it [00:00, ?it/s]

[Step 47000] loss: 0.1662
[Step 47010] loss: 0.1905
[Step 47020] loss: 0.1969
[Step 47030] loss: 0.2039
[Step 47040] loss: 0.2452
[Step 47050] loss: 0.2205
[Step 47060] loss: 0.1612
[Step 47070] loss: 0.2106
[Step 47080] loss: 0.1982
[Step 47090] loss: 0.1864
[Step 47100] loss: 0.2021
[Step 47110] loss: 0.2054
[Step 47120] loss: 0.2176
[Step 47130] loss: 0.1801
[Step 47140] loss: 0.1842
[Step 47150] loss: 0.2022
[Step 47160] loss: 0.1316
[Step 47170] loss: 0.1682
[Step 47180] loss: 0.2119
[Step 47190] loss: 0.1891
[Step 47200] loss: 0.2105
[Step 47210] loss: 0.1794
[Step 47220] loss: 0.2269
[Step 47230] loss: 0.1655
[Step 47240] loss: 0.2081
[Step 47250] loss: 0.1877
[Step 47260] loss: 0.2140
[Step 47270] loss: 0.1974
[Step 47280] loss: 0.1690
[Step 47290] loss: 0.1872
[Step 47300] loss: 0.2421
[Step 47310] loss: 0.1903
[Step 47320] loss: 0.1933
[Step 47330] loss: 0.2216
[Step 47340] loss: 0.2228
[Step 47350] loss: 0.1768
[Step 47360] loss: 0.2034
[Step 47370] loss: 0.2048
[Step 47380]

sampling: 0it [00:00, ?it/s]

[Step 48000] loss: 0.2017
[Step 48010] loss: 0.1958
[Step 48020] loss: 0.1930
[Step 48030] loss: 0.2051
[Step 48040] loss: 0.2195
[Step 48050] loss: 0.1925
[Step 48060] loss: 0.1864
[Step 48070] loss: 0.1387
[Step 48080] loss: 0.2068
[Step 48090] loss: 0.2087
[Step 48100] loss: 0.1909
[Step 48110] loss: 0.1902
[Step 48120] loss: 0.1832
[Step 48130] loss: 0.2198
[Step 48140] loss: 0.1798
[Step 48150] loss: 0.1786
[Step 48160] loss: 0.1920
[Step 48170] loss: 0.1927
[Step 48180] loss: 0.2591
[Step 48190] loss: 0.1942
[Step 48200] loss: 0.2100
[Step 48210] loss: 0.1863
[Step 48220] loss: 0.1858
[Step 48230] loss: 0.1904
[Step 48240] loss: 0.1801
[Step 48250] loss: 0.2079
[Step 48260] loss: 0.1664
[Step 48270] loss: 0.1749
[Step 48280] loss: 0.1840
[Step 48290] loss: 0.1997
[Step 48300] loss: 0.1996
[Step 48310] loss: 0.1815
[Step 48320] loss: 0.1813
[Step 48330] loss: 0.2380
[Step 48340] loss: 0.2313
[Step 48350] loss: 0.2043
[Step 48360] loss: 0.2052
[Step 48370] loss: 0.2202
[Step 48380]

sampling: 0it [00:00, ?it/s]

[Step 49000] loss: 0.1670
[Step 49010] loss: 0.2011
[Step 49020] loss: 0.1817
[Step 49030] loss: 0.1899
[Step 49040] loss: 0.2210
[Step 49050] loss: 0.1970
[Step 49060] loss: 0.1814
[Step 49070] loss: 0.1966
[Step 49080] loss: 0.1904
[Step 49090] loss: 0.2125
[Step 49100] loss: 0.2252
[Step 49110] loss: 0.2013
[Step 49120] loss: 0.2187
[Step 49130] loss: 0.1993
[Step 49140] loss: 0.2195
[Step 49150] loss: 0.1997
[Step 49160] loss: 0.2421
[Step 49170] loss: 0.2029
[Step 49180] loss: 0.2066
[Step 49190] loss: 0.2030
[Step 49200] loss: 0.1984
[Step 49210] loss: 0.2371
[Step 49220] loss: 0.1701
[Step 49230] loss: 0.1872
[Step 49240] loss: 0.1996
[Step 49250] loss: 0.1800
[Step 49260] loss: 0.2194
[Step 49270] loss: 0.1780
[Step 49280] loss: 0.2175
[Step 49290] loss: 0.1974
[Step 49300] loss: 0.2351
[Step 49310] loss: 0.2132
[Step 49320] loss: 0.2008
[Step 49330] loss: 0.2148
[Step 49340] loss: 0.2131
[Step 49350] loss: 0.1927
[Step 49360] loss: 0.1916
[Step 49370] loss: 0.2050
[Step 49380]

sampling: 0it [00:00, ?it/s]

[Step 50000] loss: 0.2231
Training complete.


## Inference

In [None]:
# Restore the trained model from a checkpoint file saved during training.
# NOTE(review): exactly what Trainer.load restores (model / EMA / optimizer
# state) is defined outside this view — confirm before relying on it.
ckpt = './UNetTransformer-emb256.pt'
trainer.load(ckpt)
# trainer.inference()  # optional; intentionally left disabled

In [None]:
# Run the reverse diffusion process once and visualise evenly spaced
# snapshots of it, from the first returned frame (left) to the last (right).
diffusion = trainer.model
with torch.no_grad():
    # return_all_timesteps=True keeps every intermediate frame, stacked along
    # dim 1: shape [batch=1, num_frames, C, H, W]. num_frames is presumably
    # T + 1 for full ancestral sampling but may be smaller if the sampler
    # skips steps — TODO confirm against diffusion.sample.
    all_steps = diffusion.sample(batch_size=1, return_all_timesteps=True)

# [1, F, C, H, W] -> [F, H, W, C] numpy array, the channel-last layout
# plt.imshow expects for RGB images.
imgs = all_steps[0].permute(0, 2, 3, 1).cpu().numpy()
# NOTE(review): if diffusion.sample already maps its output back to [0, 1],
# this second rescale squeezes values into [0.5, 1] and washes the images
# out — confirm which range sample() returns.
imgs = unnormalize_to_zero_to_one(imgs)
# imshow clips float RGB data to [0, 1] anyway (with a warning); do it
# explicitly so the displayed values are well-defined.
imgs = np.clip(imgs, 0.0, 1.0)

# Derive the panel indices from the number of frames actually returned, not
# from diffusion.num_timesteps: indexing with num_timesteps raises IndexError
# whenever the sampler returns fewer than T + 1 frames. With exactly T + 1
# frames these are the same indices as [0, T//4, T//2, 3*T//4, T].
num_panels = 5
last = imgs.shape[0] - 1
timesteps = [i * last // (num_panels - 1) for i in range(num_panels)]
sel = [imgs[t] for t in timesteps]

# Plot one panel per selected frame, titled with its frame index.
fig, axes = plt.subplots(1, num_panels, figsize=(3 * num_panels, 5))
for ax, im, step in zip(axes, sel, timesteps):
    ax.imshow(im)
    ax.axis('off')
    ax.set_title(f"Step {step}")
plt.tight_layout()
plt.show()