In [1]:
import math
import os
from functools import partial
from inspect import isfunction
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from torch import nn, einsum
from torch.optim import Adam
from torchvision import transforms, datasets
from torchvision.utils import save_image
from tqdm.auto import tqdm

# from utils.networkHelper import *

In [2]:
# Global variables
image_size = 128
channels = 3
train_epochs = 20
timesteps = 1000
# Use the highest-indexed visible GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda:'+str(torch.cuda.device_count()-1)
                      if torch.cuda.is_available() else 'cpu')
# Dataset root; set the CELEBA_ROOT environment variable to override the
# machine-specific default path.
data_path_root = os.environ.get('CELEBA_ROOT', '/data/wumin/dataset/celeba/')
bs = 100
results_folder = Path("./samples/ddpm/celeba_128")
# parents=True so the nested ./samples/ddpm/ directories are created on a fresh
# checkout (without it, mkdir raises FileNotFoundError).
results_folder.mkdir(parents=True, exist_ok=True)

# CelebA dataset: center-crop to 178x178, resize to image_size, and map pixels
# from [0, 1] to [-1, 1] via Normalize(mean=0.5, std=0.5).
transform = transforms.Compose([
    transforms.CenterCrop((178,178)),
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5)
    ])

train_dataset = datasets.CelebA(
    root=data_path_root, split='train', transform=transform, download=True)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=bs, shuffle=True)

Files already downloaded and verified


In [3]:
# Sanity check: look at one batch and save an 8x8 grid of real images for reference.
for x, y in train_loader:
    print(x.shape)
    # Normalize(0.5, 0.5) put pixels in [-1, 1]; map back to [0, 1] before saving.
    real_grid = x[:64].add(1).mul(0.5)
    save_image(real_grid, results_folder / 'real.png', nrow=8)
    break

torch.Size([100, 3, 128, 128])


In [4]:
def exists(x):
    """Return True when ``x`` is anything other than ``None``."""
    return not (x is None)


def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise return the fallback ``d``.

    When ``d`` is a plain Python function (including a lambda) it is called
    lazily to build the fallback; any other object is returned unchanged.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d


def num_to_groups(num, divisor):
    """Split ``num`` into chunks of size ``divisor`` plus one final remainder chunk.

    Example: ``num_to_groups(10, 4) -> [4, 4, 2]``.
    Not used anywhere yet.
    """
    full_chunks, leftover = divmod(num, divisor)
    chunks = [divisor for _ in range(full_chunks)]
    if leftover > 0:
        chunks.append(leftover)
    return chunks


class Residual(nn.Module):
    """Skip-connection wrapper: ``forward(x, ...)`` returns ``fn(x, ...) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        branch = self.fn(x, *args, **kwargs)
        return branch + x


def Upsample(dim, dim_out=None):
    """2x nearest-neighbour upsampling followed by a 3x3 conv.

    The conv maps ``dim`` input channels to ``dim_out`` channels (``dim`` when
    ``dim_out`` is None) and keeps the spatial size (padding 1).
    """
    out_channels = dim if dim_out is None else dim_out
    layers = [
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(dim, out_channels, 3, padding=1),
    ]
    return nn.Sequential(*layers)


def Downsample(dim, dim_out=None):
    """Halve the spatial resolution without strided convolution or pooling.

    Each 2x2 spatial patch is folded into the channel axis (giving ``4 * dim``
    channels at half height and width). A 1x1 conv then maps those channels
    to ``dim_out`` (``dim`` when ``dim_out`` is None). Height and width must
    be even.
    """
    out_channels = dim if dim_out is None else dim_out
    space_to_depth = Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2)
    channel_mix = nn.Conv2d(dim * 4, out_channels, 1)
    return nn.Sequential(space_to_depth, channel_mix)


class SinusoidalPositionEmbeddings(nn.Module):
    """Transformer-style sinusoidal embedding of diffusion timesteps.

    Maps a 1-D tensor of ``B`` timesteps to a ``(B, dim)`` float tensor. The
    first half of the features are sines and the second half are cosines, at
    geometrically spaced frequencies from 1 down to 1/10000 (for ``dim >= 4``).
    For odd ``dim`` a zero column is appended, so the output always has
    exactly ``dim`` features and matches a following ``nn.Linear(dim, ...)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half_dim = self.dim // 2
        # Guard tiny dims: for dim in {2, 3} the denominator half_dim - 1 is 0.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -log_step)
        args = time[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves only cover 2 * (dim // 2) features; pad to dim.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings


class WeightStandardizedConv2d(nn.Conv2d):
    """
    Convolution whose kernel is standardized before every forward pass.

    Each output channel's weights are shifted to zero mean and scaled to unit
    (biased) variance, then used in an ordinary 2-D convolution with this
    layer's bias, stride, padding, dilation and groups.
    https://arxiv.org/abs/1903.10520
    Weight standardization purportedly works synergistically with group normalization.
    """

    def forward(self, x):
        # Looser epsilon for any non-float32 input (presumably reduced precision).
        eps = 1e-3
        if x.dtype == torch.float32:
            eps = 1e-5

        w = self.weight
        w_mean = reduce(w, "o ... -> o 1 1 1", "mean")
        w_var = reduce(w, "o ... -> o 1 1 1",
                       partial(torch.var, unbiased=False))
        inv_std = (w_var + eps).rsqrt()
        standardized = (w - w_mean) * inv_std

        return F.conv2d(x, standardized, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


class Block(nn.Module):
    """Weight-standardized 3x3 conv -> GroupNorm -> optional (scale, shift) -> SiLU."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        h = self.norm(self.proj(x))
        if scale_shift is not None:
            gamma, beta = scale_shift
            # Scale is applied as an offset from 1, so zero scale/shift is the identity.
            h = h * (gamma + 1) + beta
        return self.act(h)


class ResnetBlock(nn.Module):
    """Residual block of two ``Block``s with optional timestep conditioning.

    When ``time_emb_dim`` is given, the time embedding passes through
    SiLU + Linear to ``2 * dim_out`` values. These are split into a
    per-channel (scale, shift) pair that modulates the first ``Block``. The
    skip path is a 1x1 conv when ``dim != dim_out`` and the identity otherwise.
    https://arxiv.org/abs/1512.03385
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        if time_emb_dim is not None:
            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        if dim != dim_out:
            self.res_conv = nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if self.mlp is not None and time_emb is not None:
            cond = self.mlp(time_emb)
            cond = rearrange(cond, "b c -> b c 1 1")
            scale_shift = cond.chunk(2, dim=1)

        out = self.block2(self.block1(x, scale_shift=scale_shift))
        return out + self.res_conv(x)


class Attention(nn.Module):
    """Multi-head softmax self-attention over all spatial positions of a feature map.

    Maps ``(b, dim, h, w)`` to ``(b, dim, h, w)``. The attention matrix is
    ``(h*w) x (h*w)`` per head, so cost grows quadratically with the number
    of positions.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        _, _, height, width = x.shape
        # A single 1x1 conv produces q, k and v stacked on the channel axis;
        # split them per head and flatten space into height * width tokens.
        q, k, v = (
            rearrange(part, "b (h c) x y -> b h c (x y)", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )
        q = q * self.scale

        logits = einsum("b h d i, b h d j -> b h i j", q, k)
        # Subtract the (gradient-free) row max for numerical stability.
        logits = logits - logits.amax(dim=-1, keepdim=True).detach()
        weights = logits.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", weights, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=height, y=width)
        return self.to_out(out)


class LinearAttention(nn.Module):
    """Multi-head linear attention over the spatial positions of a feature map.

    Queries are softmax-normalised over their feature axis and keys over
    positions. Values are first aggregated into a small per-head context
    matrix, so cost is linear in the number of positions. The output
    projection is a 1x1 conv followed by ``GroupNorm(1, dim)``.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1),
                                    nn.GroupNorm(1, dim))

    def forward(self, x):
        _, _, height, width = x.shape
        q, k, v = (
            rearrange(part, "b (h c) x y -> b h c (x y)", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )

        q = q.softmax(dim=-2) * self.scale
        k = k.softmax(dim=-1)

        # (d x e) context per head, summed over all n positions.
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y",
                        h=self.heads, x=height, y=width)
        return self.to_out(out)


class PreNorm(nn.Module):
    """Normalise the input with ``GroupNorm(1, dim)`` before passing it to ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        return self.fn(self.norm(x))


class Unet(nn.Module):
    """U-Net that predicts the noise contained in an image at a diffusion step.

    Layout (one level per entry of ``dim_mults``):

    * a 1x1 stem conv to ``init_dim`` channels. Its output is also kept as a
      long skip into the final block.
    * encoder: per level, ResnetBlock -> ResnetBlock -> residual pre-norm
      LinearAttention. The features after the first block and after the
      attention are saved as skips. This is followed by a 2x Downsample,
      except at the deepest level, where a 3x3 conv only changes the width.
    * bottleneck: ResnetBlock -> residual pre-norm softmax Attention -> ResnetBlock.
    * decoder: mirrors the encoder, concatenating one saved skip before each
      of its two ResnetBlocks, then upsampling 2x except at the last level.
    * a final ResnetBlock on ``[decoder output, stem features]`` and a 1x1
      conv to ``out_dim`` channels.

    All ResnetBlocks are conditioned on an MLP embedding of the timestep.

    Args:
        dim: base channel width; also the sinusoidal time-embedding size.
        init_dim: width after the stem conv (defaults to ``dim``).
        out_dim: output channels (defaults to ``channels``).
        dim_mults: per-level width multipliers of ``dim``.
        channels: image channels.
        self_condition: if True, ``x_self_cond`` (zeros when omitted) is
            concatenated to the input, doubling the stem's input channels.
        resnet_block_groups: GroupNorm groups inside every ResnetBlock.
    """

    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        self_condition=False,
        resnet_block_groups=4,
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels
        self.self_condition = self_condition
        input_channels = channels * (2 if self_condition else 1)

        init_dim = default(init_dim, dim)
        # 1x1 stem (kernel 1, padding 0) instead of the usual 7x7 / padding 3.
        self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        time_dim = dim * 4

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind == (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Downsample(dim_in, dim_out)
                        if not is_last
                        else nn.Conv2d(dim_in, dim_out, 3, padding=1),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        # Inputs are concatenated with a saved encoder skip (dim_in channels).
                        block_klass(dim_out + dim_in, dim_out,
                                    time_emb_dim=time_dim),
                        block_klass(dim_out + dim_in, dim_out,
                                    time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Upsample(dim_out, dim_in)
                        if not is_last
                        else nn.Conv2d(dim_out, dim_in, 3, padding=1),
                    ]
                )
            )

        self.out_dim = default(out_dim, channels)

        # The decoder ends at dims[0] == init_dim channels and is concatenated
        # with the init_dim-channel stem features, so the input is init_dim * 2
        # (using dim * 2 here fails whenever init_dim != dim).
        self.final_res_block = block_klass(init_dim * 2, dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)

    def forward(self, x, time, x_self_cond=None):
        """Predict the noise in ``x`` at diffusion step(s) ``time``.

        Args:
            x: ``(b, channels, H, W)`` noisy images. H and W must be divisible
                by ``2 ** (len(dim_mults) - 1)`` because of the space-to-depth
                Downsample layers.
            time: 1-D tensor of timesteps, one per batch element.
            x_self_cond: optional self-conditioning input; used only when
                ``self_condition`` is True (zeros if omitted).
        Returns:
            ``(b, out_dim, H, W)`` tensor.
        """
        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim=1)

        x = self.init_conv(x)
        r = x.clone()  # long skip into final_res_block

        t = self.time_mlp(time)

        skips = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            skips.append(x)

            x = block2(x, t)
            x = attn(x)
            skips.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # Skips are consumed in reverse (LIFO) order to match decoder levels.
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, skips.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, skips.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim=1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)


def cosine_beta_schedule(timesteps, s=0.008):
    """
    Cosine variance schedule as proposed in https://arxiv.org/abs/2102.09672

    alpha_bar(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2), rescaled so that
    alpha_bar(0) = 1. Then beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1),
    clipped to [0.0001, 0.9999]. Returns a 1-D tensor of length ``timesteps``.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    phase = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    alpha_bar = torch.cos(phase) ** 2
    alpha_bar = alpha_bar / alpha_bar[0]
    ratios = alpha_bar[1:] / alpha_bar[:-1]
    return torch.clip(1 - ratios, 0.0001, 0.9999)


def linear_beta_schedule(timesteps):
    """Return ``timesteps`` betas spaced linearly from 1e-4 to 0.02 (inclusive)."""
    start, end = 0.0001, 0.02
    return torch.linspace(start, end, timesteps)


def quadratic_beta_schedule(timesteps):
    """Return betas whose square roots are linearly spaced from sqrt(1e-4) to sqrt(0.02)."""
    lo = 0.0001 ** 0.5
    hi = 0.02 ** 0.5
    return torch.linspace(lo, hi, timesteps).pow(2)


def sigmoid_beta_schedule(timesteps):
    """Return betas following a sigmoid over [-6, 6], rescaled into (1e-4, 0.02)."""
    lo, hi = 0.0001, 0.02
    x = torch.linspace(-6, 6, timesteps)
    return lo + torch.sigmoid(x) * (hi - lo)


# Noise schedule: every per-timestep coefficient is precomputed once as a 1-D tensor.
betas = linear_beta_schedule(timesteps=timesteps)

# alpha_t = 1 - beta_t and its running product alpha_bar_t
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# alpha_bar_{t-1}, with the value before the first step defined as 1
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = (1.0 / alphas).sqrt()

# Coefficients of the closed-form forward process q(x_t | x_0)
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1. - alphas_cumprod).sqrt()

# Variance of the posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)


def extract(a, t, x_shape):
    """Gather per-timestep coefficients and reshape them for broadcasting.

    Args:
        a: 1-D tensor of coefficients indexed by timestep (e.g. ``betas``).
        t: 1-D int64 tensor of timesteps, one per batch element.
        x_shape: shape of the tensor the result will be multiplied with.
    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` (same rank as ``x_shape``) on
        ``t.device``, where entry ``i`` is ``a[t[i]]``.
    """
    batch_size = t.shape[0]
    # Index on the device that holds ``a``, so this works whether the schedule
    # tensors stay on the CPU (as here) or are moved to the GPU.
    out = a.gather(-1, t.to(a.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

# forward diffusion: sample x_t directly from x_0 using the closed form of q(x_t | x_0)


def q_sample(x_start, t, noise=None):
    """Draw x_t ~ q(x_t | x_0) in closed form.

    x_t = sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise.
    Fresh standard-normal ``noise`` is drawn when none is supplied.
    """
    noise = torch.randn_like(x_start) if noise is None else noise
    signal_coef = extract(sqrt_alphas_cumprod, t, x_start.shape)
    noise_coef = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return signal_coef * x_start + noise_coef * noise

def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    """Noise-prediction training loss.

    Noises ``x_start`` to step ``t``, asks ``denoise_model`` to predict that
    noise, and compares prediction and true noise with the chosen criterion:
    ``"l1"``, ``"l2"`` (MSE) or ``"huber"`` (smooth L1). Any other
    ``loss_type`` raises NotImplementedError.
    """
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t)

    if loss_type == 'l1':
        return F.l1_loss(noise, predicted_noise)
    if loss_type == 'l2':
        return F.mse_loss(noise, predicted_noise)
    if loss_type == "huber":
        return F.smooth_l1_loss(noise, predicted_noise)
    raise NotImplementedError()


@torch.no_grad()
def p_sample(model, x, t, t_index):
    """One reverse-diffusion step x_t -> x_{t-1}.

    The mean comes from the model's noise prediction (Equation 11 in the
    paper). Gaussian noise scaled by the posterior standard deviation is
    added (Algorithm 2 line 4), except at the final step (``t_index == 0``),
    where the mean is returned directly.
    """
    beta_t = extract(betas, t, x.shape)
    sigma_bar_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    inv_sqrt_alpha_t = extract(sqrt_recip_alphas, t, x.shape)

    eps_hat = model(x, t)
    mean = inv_sqrt_alpha_t * (x - beta_t * eps_hat / sigma_bar_t)

    if t_index == 0:
        return mean

    var_t = extract(posterior_variance, t, x.shape)
    return mean + torch.sqrt(var_t) * torch.randn_like(x)

# Algorithm 2 (sampling), returning the image after every reverse step, not only the final one


@torch.no_grad()
def p_sample_loop(model, shape):
    """Run the full reverse chain starting from pure Gaussian noise.

    Returns a tensor of shape ``(timesteps, *shape)`` holding the sample after
    every reverse step, ordered from t = timesteps - 1 down to t = 0 (the last
    entry is the final sample). All intermediates stay on the model's device.
    """
    device = next(model.parameters()).device
    batch = shape[0]
    img = torch.randn(shape, device=device)  # start from pure noise

    trajectory = []
    steps = range(timesteps - 1, -1, -1)
    for i in tqdm(steps, desc='sampling loop time step', total=timesteps):
        t = torch.full((batch,), i, device=device, dtype=torch.long)
        img = p_sample(model, img, t, i)
        trajectory.append(img)
    return torch.stack(trajectory, dim=0)


@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    """Generate ``batch_size`` square images, returning every intermediate reverse step."""
    shape = (batch_size, channels, image_size, image_size)
    return p_sample_loop(model, shape=shape)


def num_to_groups(num, divisor):
    """Split ``num`` into ``divisor``-sized chunks plus a smaller remainder chunk.

    NOTE(review): this redefines the identical ``num_to_groups`` from earlier
    in this cell and silently shadows it; one of the two copies should be removed.
    """
    quotient = num // divisor
    rest = num % divisor
    return [divisor] * quotient + ([rest] if rest > 0 else [])


def show_result(model, epoch, batches=8, display_step=125, show=False, save=False,  path='result.png'):
    # save generated images
    ds = display_step
    all_images_list = sample(model, image_size=image_size, batch_size=batches, channels=channels)
    all_images = rearrange(
        all_images_list, 't b c h w -> (b t) c h w').add(1).mul(0.5)
    all_images = all_images[torch.arange(ds-1, all_images.shape[0], ds),]
    save_image(all_images, str(results_folder /
               f'sample_epoch_{epoch}.png'), nrow=timesteps//ds)


def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):
    """Plot the per-epoch average training loss stored in ``hist['losses']``.

    Args:
        hist: dict with a ``'losses'`` sequence, one value per epoch.
        show: display the figure; otherwise it is closed after plotting.
        save: write the figure to ``path``.
        path: output file used when ``save`` is True.
    """
    # A dedicated figure means the curve is never drawn onto whatever pyplot
    # figure happens to be "current", and closing it cannot close another one.
    fig, ax = plt.subplots()
    losses = hist['losses']
    ax.plot(range(len(losses)), losses, label='Loss')

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')

    ax.legend(loc=5)
    ax.grid(True)
    fig.tight_layout()

    if save:
        fig.savefig(path)

    if show:
        plt.show()
    else:
        plt.close(fig)
        
        
def train(model, optimizer, epoch, train_hist):
    """Run one training epoch over ``train_loader``.

    Each step samples a uniform random timestep per image, computes the Huber
    noise-prediction loss and takes one optimizer step. Progress is printed
    every 50 steps, and the epoch's mean loss is appended to
    ``train_hist['losses']``.
    """
    step_losses = []
    for step, (images, _) in enumerate(train_loader):
        optimizer.zero_grad()

        images = images.to(device)
        n = images.shape[0]

        # Algorithm 1 line 3: sample t uniformly for every example in the batch
        t = torch.randint(0, timesteps, (n,), device=device).long()

        loss = p_losses(model, images, t, loss_type="huber")
        step_losses.append(loss.item())

        if step % 50 == 0:
            print("epoch: {} step: {} Loss: {:.5f}".format(epoch, step, loss.item()))

        loss.backward()
        optimizer.step()

    train_hist['losses'].append(np.mean(step_losses))
    print("================> epoch: {} average Loss: {:.5f}".format(
        epoch, train_hist['losses'][-1]))
    

In [5]:
# Number of mini-batches per epoch.
num_batches = len(train_loader)
num_batches

1628

In [7]:
# Base width image_size // 4 (32 for 128x128 inputs) with three resolution levels.
model = Unet(
    dim=image_size // 4,
    dim_mults=(1, 2, 4),
    channels=channels,
).to(device)

# To resume instead: model = torch.load('models/ddpm_celeba.pth').to(device)
# (torch.load unpickles arbitrary objects; only load trusted checkpoints.)
optimizer = Adam(model.parameters(), lr=2e-4)


In [8]:
train_hist = {'losses': []}
print("start training")

for epoch in range(train_epochs):
    # Halve the learning rate before every 10th epoch (0-based epochs 9, 19, ...).
    if (epoch + 1) % 10 == 0:
        optimizer.param_groups[0]['lr'] /= 2
    train(model, optimizer, epoch, train_hist)

    # Save a grid of sampling trajectories after each epoch.
    show_result(model, epoch)

show_train_hist(train_hist, show=False, save=True,
                path=results_folder / 'history_train_losses.png')

start training
epoch: 0 step: 0 Loss: 0.48544
epoch: 0 step: 50 Loss: 0.04851
epoch: 0 step: 100 Loss: 0.03521
epoch: 0 step: 150 Loss: 0.02749
epoch: 0 step: 200 Loss: 0.02323
epoch: 0 step: 250 Loss: 0.01555
epoch: 0 step: 300 Loss: 0.01533
epoch: 0 step: 350 Loss: 0.01613
epoch: 0 step: 400 Loss: 0.00973
epoch: 0 step: 450 Loss: 0.01617
epoch: 0 step: 500 Loss: 0.01568
epoch: 0 step: 550 Loss: 0.01227
epoch: 0 step: 600 Loss: 0.01395
epoch: 0 step: 650 Loss: 0.01206
epoch: 0 step: 700 Loss: 0.00840
epoch: 0 step: 750 Loss: 0.00990
epoch: 0 step: 800 Loss: 0.00708
epoch: 0 step: 850 Loss: 0.01567
epoch: 0 step: 900 Loss: 0.00833
epoch: 0 step: 950 Loss: 0.01036
epoch: 0 step: 1000 Loss: 0.00733
epoch: 0 step: 1050 Loss: 0.00844
epoch: 0 step: 1100 Loss: 0.00792
epoch: 0 step: 1150 Loss: 0.01589
epoch: 0 step: 1200 Loss: 0.01062
epoch: 0 step: 1250 Loss: 0.01025
epoch: 0 step: 1300 Loss: 0.01105
epoch: 0 step: 1350 Loss: 0.00957
epoch: 0 step: 1400 Loss: 0.00757
epoch: 0 step: 1450 Lo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 1 step: 0 Loss: 0.01128
epoch: 1 step: 50 Loss: 0.00723
epoch: 1 step: 100 Loss: 0.00783
epoch: 1 step: 150 Loss: 0.01154
epoch: 1 step: 200 Loss: 0.00432
epoch: 1 step: 250 Loss: 0.00832
epoch: 1 step: 300 Loss: 0.01061
epoch: 1 step: 350 Loss: 0.00555
epoch: 1 step: 400 Loss: 0.00819
epoch: 1 step: 450 Loss: 0.00587
epoch: 1 step: 500 Loss: 0.00625
epoch: 1 step: 550 Loss: 0.00845
epoch: 1 step: 600 Loss: 0.00809
epoch: 1 step: 650 Loss: 0.00699
epoch: 1 step: 700 Loss: 0.00755
epoch: 1 step: 750 Loss: 0.01080
epoch: 1 step: 800 Loss: 0.00677
epoch: 1 step: 850 Loss: 0.00810
epoch: 1 step: 900 Loss: 0.00615
epoch: 1 step: 950 Loss: 0.00840
epoch: 1 step: 1000 Loss: 0.00709
epoch: 1 step: 1050 Loss: 0.01490
epoch: 1 step: 1100 Loss: 0.00716
epoch: 1 step: 1150 Loss: 0.00841
epoch: 1 step: 1200 Loss: 0.00669
epoch: 1 step: 1250 Loss: 0.01118
epoch: 1 step: 1300 Loss: 0.00863
epoch: 1 step: 1350 Loss: 0.00453
epoch: 1 step: 1400 Loss: 0.00690
epoch: 1 step: 1450 Loss: 0.01368
epo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 2 step: 0 Loss: 0.00706
epoch: 2 step: 50 Loss: 0.00644
epoch: 2 step: 100 Loss: 0.00566
epoch: 2 step: 150 Loss: 0.00644
epoch: 2 step: 200 Loss: 0.00475
epoch: 2 step: 250 Loss: 0.00668
epoch: 2 step: 300 Loss: 0.00366
epoch: 2 step: 350 Loss: 0.00396
epoch: 2 step: 400 Loss: 0.00727
epoch: 2 step: 450 Loss: 0.00714
epoch: 2 step: 500 Loss: 0.00824
epoch: 2 step: 550 Loss: 0.00593
epoch: 2 step: 600 Loss: 0.00705
epoch: 2 step: 650 Loss: 0.01115
epoch: 2 step: 700 Loss: 0.00594
epoch: 2 step: 750 Loss: 0.00781
epoch: 2 step: 800 Loss: 0.01032
epoch: 2 step: 850 Loss: 0.00496
epoch: 2 step: 900 Loss: 0.00549
epoch: 2 step: 950 Loss: 0.00719
epoch: 2 step: 1000 Loss: 0.00837
epoch: 2 step: 1050 Loss: 0.00676
epoch: 2 step: 1100 Loss: 0.00711
epoch: 2 step: 1150 Loss: 0.00926
epoch: 2 step: 1200 Loss: 0.00941
epoch: 2 step: 1250 Loss: 0.00456
epoch: 2 step: 1300 Loss: 0.00538
epoch: 2 step: 1350 Loss: 0.00780
epoch: 2 step: 1400 Loss: 0.00663
epoch: 2 step: 1450 Loss: 0.00545
epo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 3 step: 0 Loss: 0.00404
epoch: 3 step: 50 Loss: 0.00560
epoch: 3 step: 100 Loss: 0.00540
epoch: 3 step: 150 Loss: 0.00632
epoch: 3 step: 200 Loss: 0.00827
epoch: 3 step: 250 Loss: 0.00792
epoch: 3 step: 300 Loss: 0.00647
epoch: 3 step: 350 Loss: 0.00875
epoch: 3 step: 400 Loss: 0.01034
epoch: 3 step: 450 Loss: 0.00410
epoch: 3 step: 500 Loss: 0.00985
epoch: 3 step: 550 Loss: 0.00692
epoch: 3 step: 600 Loss: 0.00663
epoch: 3 step: 650 Loss: 0.00585
epoch: 3 step: 700 Loss: 0.00538
epoch: 3 step: 750 Loss: 0.00778
epoch: 3 step: 800 Loss: 0.00593
epoch: 3 step: 850 Loss: 0.00751
epoch: 3 step: 900 Loss: 0.00854
epoch: 3 step: 950 Loss: 0.00528
epoch: 3 step: 1000 Loss: 0.00652
epoch: 3 step: 1050 Loss: 0.00683
epoch: 3 step: 1100 Loss: 0.00504
epoch: 3 step: 1150 Loss: 0.00463
epoch: 3 step: 1200 Loss: 0.00566
epoch: 3 step: 1250 Loss: 0.00511
epoch: 3 step: 1300 Loss: 0.00706
epoch: 3 step: 1350 Loss: 0.00504
epoch: 3 step: 1400 Loss: 0.00540
epoch: 3 step: 1450 Loss: 0.00586
epo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 4 step: 0 Loss: 0.00910
epoch: 4 step: 50 Loss: 0.00640
epoch: 4 step: 100 Loss: 0.00588
epoch: 4 step: 150 Loss: 0.00530
epoch: 4 step: 200 Loss: 0.00897
epoch: 4 step: 250 Loss: 0.00668
epoch: 4 step: 300 Loss: 0.00387
epoch: 4 step: 350 Loss: 0.01217
epoch: 4 step: 400 Loss: 0.00647
epoch: 4 step: 450 Loss: 0.00550
epoch: 4 step: 500 Loss: 0.00951
epoch: 4 step: 550 Loss: 0.00613
epoch: 4 step: 600 Loss: 0.00575
epoch: 4 step: 650 Loss: 0.00576
epoch: 4 step: 700 Loss: 0.00524
epoch: 4 step: 750 Loss: 0.00635
epoch: 4 step: 800 Loss: 0.00559
epoch: 4 step: 850 Loss: 0.00828
epoch: 4 step: 900 Loss: 0.00562
epoch: 4 step: 950 Loss: 0.00597
epoch: 4 step: 1000 Loss: 0.00702
epoch: 4 step: 1050 Loss: 0.01008
epoch: 4 step: 1100 Loss: 0.00549
epoch: 4 step: 1150 Loss: 0.00497
epoch: 4 step: 1200 Loss: 0.00465
epoch: 4 step: 1250 Loss: 0.00515
epoch: 4 step: 1300 Loss: 0.00737
epoch: 4 step: 1350 Loss: 0.00512
epoch: 4 step: 1400 Loss: 0.00637
epoch: 4 step: 1450 Loss: 0.00946
epo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 5 step: 0 Loss: 0.00689
epoch: 5 step: 50 Loss: 0.00355
epoch: 5 step: 100 Loss: 0.00559
epoch: 5 step: 150 Loss: 0.00553
epoch: 5 step: 200 Loss: 0.00473
epoch: 5 step: 250 Loss: 0.00576
epoch: 5 step: 300 Loss: 0.00542
epoch: 5 step: 350 Loss: 0.00712
epoch: 5 step: 400 Loss: 0.00861
epoch: 5 step: 450 Loss: 0.00773
epoch: 5 step: 500 Loss: 0.00477
epoch: 5 step: 550 Loss: 0.00637
epoch: 5 step: 600 Loss: 0.00545
epoch: 5 step: 650 Loss: 0.00842
epoch: 5 step: 700 Loss: 0.00618
epoch: 5 step: 750 Loss: 0.00503
epoch: 5 step: 800 Loss: 0.00609
epoch: 5 step: 850 Loss: 0.00856
epoch: 5 step: 900 Loss: 0.00425
epoch: 5 step: 950 Loss: 0.00627
epoch: 5 step: 1000 Loss: 0.00505
epoch: 5 step: 1050 Loss: 0.00745
epoch: 5 step: 1100 Loss: 0.00591
epoch: 5 step: 1150 Loss: 0.00901
epoch: 5 step: 1200 Loss: 0.00494
epoch: 5 step: 1250 Loss: 0.00537
epoch: 5 step: 1300 Loss: 0.00519
epoch: 5 step: 1350 Loss: 0.00493
epoch: 5 step: 1400 Loss: 0.00659
epoch: 5 step: 1450 Loss: 0.00782
epo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 6 step: 0 Loss: 0.00918
epoch: 6 step: 50 Loss: 0.00516
epoch: 6 step: 100 Loss: 0.00632
epoch: 6 step: 150 Loss: 0.00775
epoch: 6 step: 200 Loss: 0.00809
epoch: 6 step: 250 Loss: 0.00682
epoch: 6 step: 300 Loss: 0.00598
epoch: 6 step: 350 Loss: 0.00634
epoch: 6 step: 400 Loss: 0.00515
epoch: 6 step: 450 Loss: 0.00409
epoch: 6 step: 500 Loss: 0.00853
epoch: 6 step: 550 Loss: 0.00486
epoch: 6 step: 600 Loss: 0.00644
epoch: 6 step: 650 Loss: 0.00440
epoch: 6 step: 700 Loss: 0.00619
epoch: 6 step: 750 Loss: 0.00559
epoch: 6 step: 800 Loss: 0.00776
epoch: 6 step: 850 Loss: 0.00798
epoch: 6 step: 900 Loss: 0.00580
epoch: 6 step: 950 Loss: 0.00669
epoch: 6 step: 1000 Loss: 0.00725
epoch: 6 step: 1050 Loss: 0.00560
epoch: 6 step: 1100 Loss: 0.00735
epoch: 6 step: 1150 Loss: 0.00564
epoch: 6 step: 1200 Loss: 0.00627
epoch: 6 step: 1250 Loss: 0.00747
epoch: 6 step: 1300 Loss: 0.00562
epoch: 6 step: 1350 Loss: 0.00623
epoch: 6 step: 1400 Loss: 0.00433
epoch: 6 step: 1450 Loss: 0.00896
epo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 7 step: 0 Loss: 0.00868
epoch: 7 step: 50 Loss: 0.00918
epoch: 7 step: 100 Loss: 0.00588
epoch: 7 step: 150 Loss: 0.00524
epoch: 7 step: 200 Loss: 0.00397
epoch: 7 step: 250 Loss: 0.00770
epoch: 7 step: 300 Loss: 0.00596
epoch: 7 step: 350 Loss: 0.00499
epoch: 7 step: 400 Loss: 0.00644
epoch: 7 step: 450 Loss: 0.01144
epoch: 7 step: 500 Loss: 0.00635
epoch: 7 step: 550 Loss: 0.00545
epoch: 7 step: 600 Loss: 0.00770
epoch: 7 step: 650 Loss: 0.00830
epoch: 7 step: 700 Loss: 0.00529
epoch: 7 step: 750 Loss: 0.00564
epoch: 7 step: 800 Loss: 0.00584
epoch: 7 step: 850 Loss: 0.00615
epoch: 7 step: 900 Loss: 0.00849
epoch: 7 step: 950 Loss: 0.00777
epoch: 7 step: 1000 Loss: 0.00544
epoch: 7 step: 1050 Loss: 0.00727
epoch: 7 step: 1100 Loss: 0.00556
epoch: 7 step: 1150 Loss: 0.00973
epoch: 7 step: 1200 Loss: 0.00703
epoch: 7 step: 1250 Loss: 0.00463
epoch: 7 step: 1300 Loss: 0.00672
epoch: 7 step: 1350 Loss: 0.00684
epoch: 7 step: 1400 Loss: 0.01252
epoch: 7 step: 1450 Loss: 0.00384
epo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 8 step: 0 Loss: 0.00538
epoch: 8 step: 50 Loss: 0.00546
epoch: 8 step: 100 Loss: 0.00380
epoch: 8 step: 150 Loss: 0.00453
epoch: 8 step: 200 Loss: 0.00546
epoch: 8 step: 250 Loss: 0.00572
epoch: 8 step: 300 Loss: 0.00619
epoch: 8 step: 350 Loss: 0.00500
epoch: 8 step: 400 Loss: 0.00599
epoch: 8 step: 450 Loss: 0.00604
epoch: 8 step: 500 Loss: 0.00465
epoch: 8 step: 550 Loss: 0.00674
epoch: 8 step: 600 Loss: 0.00564
epoch: 8 step: 650 Loss: 0.00596
epoch: 8 step: 700 Loss: 0.00981
epoch: 8 step: 750 Loss: 0.00570
epoch: 8 step: 800 Loss: 0.00956
epoch: 8 step: 850 Loss: 0.00719
epoch: 8 step: 900 Loss: 0.00525
epoch: 8 step: 950 Loss: 0.00437
epoch: 8 step: 1000 Loss: 0.00407
epoch: 8 step: 1050 Loss: 0.00481
epoch: 8 step: 1100 Loss: 0.00433
epoch: 8 step: 1150 Loss: 0.00390
epoch: 8 step: 1200 Loss: 0.00365
epoch: 8 step: 1250 Loss: 0.00703
epoch: 8 step: 1300 Loss: 0.00927
epoch: 8 step: 1350 Loss: 0.00993
epoch: 8 step: 1400 Loss: 0.00460
epoch: 8 step: 1450 Loss: 0.00679
epo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 9 step: 0 Loss: 0.00532
epoch: 9 step: 50 Loss: 0.00776
epoch: 9 step: 100 Loss: 0.00407
epoch: 9 step: 150 Loss: 0.00768
epoch: 9 step: 200 Loss: 0.00447
epoch: 9 step: 250 Loss: 0.00372
epoch: 9 step: 300 Loss: 0.00559
epoch: 9 step: 350 Loss: 0.00679
epoch: 9 step: 400 Loss: 0.00501
epoch: 9 step: 450 Loss: 0.00489
epoch: 9 step: 500 Loss: 0.00554
epoch: 9 step: 550 Loss: 0.00570
epoch: 9 step: 600 Loss: 0.00747
epoch: 9 step: 650 Loss: 0.00645
epoch: 9 step: 700 Loss: 0.00593
epoch: 9 step: 750 Loss: 0.00604
epoch: 9 step: 800 Loss: 0.00999
epoch: 9 step: 850 Loss: 0.00834
epoch: 9 step: 900 Loss: 0.00751
epoch: 9 step: 950 Loss: 0.00621
epoch: 9 step: 1000 Loss: 0.00494
epoch: 9 step: 1050 Loss: 0.00639
epoch: 9 step: 1100 Loss: 0.00512
epoch: 9 step: 1150 Loss: 0.00919
epoch: 9 step: 1200 Loss: 0.00810
epoch: 9 step: 1250 Loss: 0.00506
epoch: 9 step: 1300 Loss: 0.00609
epoch: 9 step: 1350 Loss: 0.00390
epoch: 9 step: 1400 Loss: 0.01092
epoch: 9 step: 1450 Loss: 0.00568
epo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 10 step: 0 Loss: 0.00500
epoch: 10 step: 50 Loss: 0.00574
epoch: 10 step: 100 Loss: 0.00732
epoch: 10 step: 150 Loss: 0.00539
epoch: 10 step: 200 Loss: 0.00826
epoch: 10 step: 250 Loss: 0.00779
epoch: 10 step: 300 Loss: 0.00736
epoch: 10 step: 350 Loss: 0.00691
epoch: 10 step: 400 Loss: 0.00494
epoch: 10 step: 450 Loss: 0.00499
epoch: 10 step: 500 Loss: 0.00696
epoch: 10 step: 550 Loss: 0.00708
epoch: 10 step: 600 Loss: 0.00531
epoch: 10 step: 650 Loss: 0.00795
epoch: 10 step: 700 Loss: 0.00622
epoch: 10 step: 750 Loss: 0.00548
epoch: 10 step: 800 Loss: 0.00812
epoch: 10 step: 850 Loss: 0.00814
epoch: 10 step: 900 Loss: 0.00423
epoch: 10 step: 950 Loss: 0.00534
epoch: 10 step: 1000 Loss: 0.00719
epoch: 10 step: 1050 Loss: 0.00773
epoch: 10 step: 1100 Loss: 0.00704
epoch: 10 step: 1150 Loss: 0.00855
epoch: 10 step: 1200 Loss: 0.00701
epoch: 10 step: 1250 Loss: 0.00560
epoch: 10 step: 1300 Loss: 0.00548
epoch: 10 step: 1350 Loss: 0.00746
epoch: 10 step: 1400 Loss: 0.00607
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 11 step: 0 Loss: 0.00537
epoch: 11 step: 50 Loss: 0.00723
epoch: 11 step: 100 Loss: 0.00779
epoch: 11 step: 150 Loss: 0.00777
epoch: 11 step: 200 Loss: 0.00476
epoch: 11 step: 250 Loss: 0.00539
epoch: 11 step: 300 Loss: 0.00727
epoch: 11 step: 350 Loss: 0.00611
epoch: 11 step: 400 Loss: 0.00618
epoch: 11 step: 450 Loss: 0.00571
epoch: 11 step: 500 Loss: 0.00538
epoch: 11 step: 550 Loss: 0.00464
epoch: 11 step: 600 Loss: 0.00424
epoch: 11 step: 650 Loss: 0.00553
epoch: 11 step: 700 Loss: 0.00694
epoch: 11 step: 750 Loss: 0.00432
epoch: 11 step: 800 Loss: 0.00612
epoch: 11 step: 850 Loss: 0.00476
epoch: 11 step: 900 Loss: 0.00672
epoch: 11 step: 950 Loss: 0.00679
epoch: 11 step: 1000 Loss: 0.00380
epoch: 11 step: 1050 Loss: 0.00541
epoch: 11 step: 1100 Loss: 0.00569
epoch: 11 step: 1150 Loss: 0.00556
epoch: 11 step: 1200 Loss: 0.00855
epoch: 11 step: 1250 Loss: 0.00458
epoch: 11 step: 1300 Loss: 0.00643
epoch: 11 step: 1350 Loss: 0.00474
epoch: 11 step: 1400 Loss: 0.00816
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 12 step: 0 Loss: 0.01090
epoch: 12 step: 50 Loss: 0.00592
epoch: 12 step: 100 Loss: 0.00569
epoch: 12 step: 150 Loss: 0.00324
epoch: 12 step: 200 Loss: 0.00630
epoch: 12 step: 250 Loss: 0.00691
epoch: 12 step: 300 Loss: 0.00760
epoch: 12 step: 350 Loss: 0.00699
epoch: 12 step: 400 Loss: 0.00715
epoch: 12 step: 450 Loss: 0.00441
epoch: 12 step: 500 Loss: 0.00660
epoch: 12 step: 550 Loss: 0.00748
epoch: 12 step: 600 Loss: 0.00699
epoch: 12 step: 650 Loss: 0.00671
epoch: 12 step: 700 Loss: 0.00520
epoch: 12 step: 750 Loss: 0.00597
epoch: 12 step: 800 Loss: 0.00605
epoch: 12 step: 850 Loss: 0.00555
epoch: 12 step: 900 Loss: 0.00621
epoch: 12 step: 950 Loss: 0.00695
epoch: 12 step: 1000 Loss: 0.00801
epoch: 12 step: 1050 Loss: 0.00681
epoch: 12 step: 1100 Loss: 0.00600
epoch: 12 step: 1150 Loss: 0.00606
epoch: 12 step: 1200 Loss: 0.00791
epoch: 12 step: 1250 Loss: 0.00773
epoch: 12 step: 1300 Loss: 0.00590
epoch: 12 step: 1350 Loss: 0.00579
epoch: 12 step: 1400 Loss: 0.00414
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 13 step: 0 Loss: 0.00696
epoch: 13 step: 50 Loss: 0.00527
epoch: 13 step: 100 Loss: 0.00374
epoch: 13 step: 150 Loss: 0.00764
epoch: 13 step: 200 Loss: 0.00436
epoch: 13 step: 250 Loss: 0.00502
epoch: 13 step: 300 Loss: 0.00529
epoch: 13 step: 350 Loss: 0.00784
epoch: 13 step: 400 Loss: 0.00709
epoch: 13 step: 450 Loss: 0.00538
epoch: 13 step: 500 Loss: 0.00478
epoch: 13 step: 550 Loss: 0.00514
epoch: 13 step: 600 Loss: 0.00485
epoch: 13 step: 650 Loss: 0.00553
epoch: 13 step: 700 Loss: 0.00448
epoch: 13 step: 750 Loss: 0.00659
epoch: 13 step: 800 Loss: 0.00376
epoch: 13 step: 850 Loss: 0.00536
epoch: 13 step: 900 Loss: 0.00398
epoch: 13 step: 950 Loss: 0.00669
epoch: 13 step: 1000 Loss: 0.00560
epoch: 13 step: 1050 Loss: 0.00418
epoch: 13 step: 1100 Loss: 0.00465
epoch: 13 step: 1150 Loss: 0.00841
epoch: 13 step: 1200 Loss: 0.00701
epoch: 13 step: 1250 Loss: 0.00626
epoch: 13 step: 1300 Loss: 0.00551
epoch: 13 step: 1350 Loss: 0.00888
epoch: 13 step: 1400 Loss: 0.00505
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 14 step: 0 Loss: 0.00678
epoch: 14 step: 50 Loss: 0.00448
epoch: 14 step: 100 Loss: 0.00718
epoch: 14 step: 150 Loss: 0.00843
epoch: 14 step: 200 Loss: 0.00456
epoch: 14 step: 250 Loss: 0.00650
epoch: 14 step: 300 Loss: 0.00704
epoch: 14 step: 350 Loss: 0.00524
epoch: 14 step: 400 Loss: 0.00569
epoch: 14 step: 450 Loss: 0.00826
epoch: 14 step: 500 Loss: 0.00850
epoch: 14 step: 550 Loss: 0.00750
epoch: 14 step: 600 Loss: 0.00680
epoch: 14 step: 650 Loss: 0.00591
epoch: 14 step: 700 Loss: 0.00454
epoch: 14 step: 750 Loss: 0.00586
epoch: 14 step: 800 Loss: 0.00607
epoch: 14 step: 850 Loss: 0.00637
epoch: 14 step: 900 Loss: 0.00743
epoch: 14 step: 950 Loss: 0.00613
epoch: 14 step: 1000 Loss: 0.00570
epoch: 14 step: 1050 Loss: 0.00614
epoch: 14 step: 1100 Loss: 0.00599
epoch: 14 step: 1150 Loss: 0.00589
epoch: 14 step: 1200 Loss: 0.00445
epoch: 14 step: 1250 Loss: 0.00508
epoch: 14 step: 1300 Loss: 0.00572
epoch: 14 step: 1350 Loss: 0.00566
epoch: 14 step: 1400 Loss: 0.00506
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 15 step: 0 Loss: 0.00584
epoch: 15 step: 50 Loss: 0.00828
epoch: 15 step: 100 Loss: 0.00610
epoch: 15 step: 150 Loss: 0.00489
epoch: 15 step: 200 Loss: 0.00807
epoch: 15 step: 250 Loss: 0.00581
epoch: 15 step: 300 Loss: 0.00841
epoch: 15 step: 350 Loss: 0.00566
epoch: 15 step: 400 Loss: 0.00565
epoch: 15 step: 450 Loss: 0.00603
epoch: 15 step: 500 Loss: 0.00717
epoch: 15 step: 550 Loss: 0.00724
epoch: 15 step: 600 Loss: 0.00582
epoch: 15 step: 650 Loss: 0.00747
epoch: 15 step: 700 Loss: 0.00564
epoch: 15 step: 750 Loss: 0.00501
epoch: 15 step: 800 Loss: 0.00538
epoch: 15 step: 850 Loss: 0.00452
epoch: 15 step: 900 Loss: 0.00609
epoch: 15 step: 950 Loss: 0.00396
epoch: 15 step: 1000 Loss: 0.00572
epoch: 15 step: 1050 Loss: 0.00532
epoch: 15 step: 1100 Loss: 0.00869
epoch: 15 step: 1150 Loss: 0.00456
epoch: 15 step: 1200 Loss: 0.00582
epoch: 15 step: 1250 Loss: 0.00592
epoch: 15 step: 1300 Loss: 0.00639
epoch: 15 step: 1350 Loss: 0.00384
epoch: 15 step: 1400 Loss: 0.00668
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 16 step: 0 Loss: 0.00626
epoch: 16 step: 50 Loss: 0.00419
epoch: 16 step: 100 Loss: 0.00674
epoch: 16 step: 150 Loss: 0.00603
epoch: 16 step: 200 Loss: 0.00816
epoch: 16 step: 250 Loss: 0.00454
epoch: 16 step: 300 Loss: 0.00937
epoch: 16 step: 350 Loss: 0.00436
epoch: 16 step: 400 Loss: 0.00658
epoch: 16 step: 450 Loss: 0.00711
epoch: 16 step: 500 Loss: 0.00761
epoch: 16 step: 550 Loss: 0.00922
epoch: 16 step: 600 Loss: 0.00628
epoch: 16 step: 650 Loss: 0.00582
epoch: 16 step: 700 Loss: 0.00582
epoch: 16 step: 750 Loss: 0.00636
epoch: 16 step: 800 Loss: 0.00460
epoch: 16 step: 850 Loss: 0.00390
epoch: 16 step: 900 Loss: 0.00707
epoch: 16 step: 950 Loss: 0.00610
epoch: 16 step: 1000 Loss: 0.00554
epoch: 16 step: 1050 Loss: 0.00666
epoch: 16 step: 1100 Loss: 0.00572
epoch: 16 step: 1150 Loss: 0.00581
epoch: 16 step: 1200 Loss: 0.00751
epoch: 16 step: 1250 Loss: 0.00841
epoch: 16 step: 1300 Loss: 0.00993
epoch: 16 step: 1350 Loss: 0.00493
epoch: 16 step: 1400 Loss: 0.01052
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 17 step: 0 Loss: 0.00557
epoch: 17 step: 50 Loss: 0.00780
epoch: 17 step: 100 Loss: 0.00608
epoch: 17 step: 150 Loss: 0.01002
epoch: 17 step: 200 Loss: 0.00738
epoch: 17 step: 250 Loss: 0.00453
epoch: 17 step: 300 Loss: 0.00662
epoch: 17 step: 350 Loss: 0.00538
epoch: 17 step: 400 Loss: 0.00569
epoch: 17 step: 450 Loss: 0.00571
epoch: 17 step: 500 Loss: 0.00742
epoch: 17 step: 550 Loss: 0.00894
epoch: 17 step: 600 Loss: 0.00803
epoch: 17 step: 650 Loss: 0.00858
epoch: 17 step: 700 Loss: 0.00667
epoch: 17 step: 750 Loss: 0.00681
epoch: 17 step: 800 Loss: 0.00608
epoch: 17 step: 850 Loss: 0.00555
epoch: 17 step: 900 Loss: 0.00533
epoch: 17 step: 950 Loss: 0.00650
epoch: 17 step: 1000 Loss: 0.00556
epoch: 17 step: 1050 Loss: 0.00727
epoch: 17 step: 1100 Loss: 0.00595
epoch: 17 step: 1150 Loss: 0.00577
epoch: 17 step: 1200 Loss: 0.00664
epoch: 17 step: 1250 Loss: 0.00960
epoch: 17 step: 1300 Loss: 0.00682
epoch: 17 step: 1350 Loss: 0.00541
epoch: 17 step: 1400 Loss: 0.00582
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 18 step: 0 Loss: 0.00693
epoch: 18 step: 50 Loss: 0.00572
epoch: 18 step: 100 Loss: 0.00617
epoch: 18 step: 150 Loss: 0.00630
epoch: 18 step: 200 Loss: 0.00855
epoch: 18 step: 250 Loss: 0.01055
epoch: 18 step: 300 Loss: 0.00416
epoch: 18 step: 350 Loss: 0.00835
epoch: 18 step: 400 Loss: 0.00591
epoch: 18 step: 450 Loss: 0.00613
epoch: 18 step: 500 Loss: 0.00749
epoch: 18 step: 550 Loss: 0.00588
epoch: 18 step: 600 Loss: 0.00568
epoch: 18 step: 650 Loss: 0.00566
epoch: 18 step: 700 Loss: 0.00428
epoch: 18 step: 750 Loss: 0.00729
epoch: 18 step: 800 Loss: 0.00496
epoch: 18 step: 850 Loss: 0.00706
epoch: 18 step: 900 Loss: 0.00399
epoch: 18 step: 950 Loss: 0.00628
epoch: 18 step: 1000 Loss: 0.00527
epoch: 18 step: 1050 Loss: 0.00397
epoch: 18 step: 1100 Loss: 0.00787
epoch: 18 step: 1150 Loss: 0.00586
epoch: 18 step: 1200 Loss: 0.00690
epoch: 18 step: 1250 Loss: 0.00349
epoch: 18 step: 1300 Loss: 0.00554
epoch: 18 step: 1350 Loss: 0.00475
epoch: 18 step: 1400 Loss: 0.00791
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 19 step: 0 Loss: 0.00691
epoch: 19 step: 50 Loss: 0.00557
epoch: 19 step: 100 Loss: 0.00613
epoch: 19 step: 150 Loss: 0.00482
epoch: 19 step: 200 Loss: 0.00489
epoch: 19 step: 250 Loss: 0.00701
epoch: 19 step: 300 Loss: 0.00653
epoch: 19 step: 350 Loss: 0.00758
epoch: 19 step: 400 Loss: 0.00561
epoch: 19 step: 450 Loss: 0.00682
epoch: 19 step: 500 Loss: 0.00695
epoch: 19 step: 550 Loss: 0.00723
epoch: 19 step: 600 Loss: 0.00710
epoch: 19 step: 650 Loss: 0.00771
epoch: 19 step: 700 Loss: 0.00538
epoch: 19 step: 750 Loss: 0.00636
epoch: 19 step: 800 Loss: 0.00688
epoch: 19 step: 850 Loss: 0.00833
epoch: 19 step: 900 Loss: 0.00399
epoch: 19 step: 950 Loss: 0.00446
epoch: 19 step: 1000 Loss: 0.00431
epoch: 19 step: 1050 Loss: 0.00590
epoch: 19 step: 1100 Loss: 0.00709
epoch: 19 step: 1150 Loss: 0.00589
epoch: 19 step: 1200 Loss: 0.00705
epoch: 19 step: 1250 Loss: 0.00506
epoch: 19 step: 1300 Loss: 0.00617
epoch: 19 step: 1350 Loss: 0.00347
epoch: 19 step: 1400 Loss: 0.00621
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

In [None]:
# Persist the trained network. torch.save(model, ...) pickles the whole module,
# so loading it later requires the model's class definition to be importable;
# saving model.state_dict() instead would be more portable.
model_path = Path('models') / 'ddpm_celeba_128.pth'
# torch.save does not create missing directories, so make sure models/ exists.
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model, model_path)

In [None]:
from PIL import Image
import requests
from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize

# Inverse of the training `transform` (whose Normalize(0.5, 0.5) maps [0, 1] to
# [-1, 1]): turns a single CHW tensor back into an HWC uint8 PIL image.
reverse_transform = Compose([
    Lambda(lambda t: t.detach().cpu()),     # .numpy() below fails on CUDA / grad-tracking tensors
    Lambda(lambda t: (t + 1) / 2),          # [-1, 1] -> [0, 1]
    # Noised tensors (e.g. q_sample output) can fall outside [0, 1]; clamp so the
    # uint8 cast below does not wrap out-of-range values into bright/dark speckles.
    Lambda(lambda t: t.clamp(0, 1)),
    Lambda(lambda t: t.permute(1, 2, 0)),   # CHW to HWC
    Lambda(lambda t: t * 255.),
    Lambda(lambda t: t.numpy().astype(np.uint8)),
    ToPILImage(),
])
def get_noisy_image(x_start, t):
    """Forward-diffuse ``x_start`` to timestep ``t`` and return it as a PIL image.

    Args:
        x_start: clean image tensor handed to ``q_sample``. Presumably a batch of
            one, (1, C, H, W), so that ``squeeze()`` leaves the CHW tensor that
            ``reverse_transform`` expects — TODO confirm at call sites.
        t: timestep tensor forwarded unchanged to ``q_sample``.

    Returns:
        The noised image after conversion by ``reverse_transform``.
    """
    return reverse_transform(q_sample(x_start, t=t).squeeze())

In [None]:
# Take the first (shuffled) batch of real images and save it, then push it
# forward to t = 299 with q_sample and save the noised version as well.
# Images are stored in [-1, 1], so (v + 1) * 0.5 maps them back to [0, 1].
x, _ = next(iter(train_loader))
save_image(((x + 1) * 0.5).cpu(), results_folder / 'raw.png', nrow=10)
x_noisy = q_sample(x, t=torch.tensor([300 - 1])).to(device)
save_image(((x_noisy + 1) * 0.5).cpu(), results_folder / 'noisy.png', nrow=10)

In [None]:
# Run the learned reverse process from t = 299 back down to t = 0 on the batch
# that was noised to t = 300 - 1 above, then save the reconstruction grid.
img = x_noisy
# Sampling only: without no_grad the autograd graph would chain through all 300
# model calls (unless p_sample already disables grad), which can exhaust memory.
with torch.no_grad():
    for i in reversed(range(300)):
        # One timestep per image; follow the actual batch size instead of a
        # hard-coded 100 so the cell keeps working if `bs` changes.
        t_batch = torch.full((img.shape[0],), i, device=device, dtype=torch.long)
        img = p_sample(model, img, t_batch, i)
save_image(img.add(1).mul(0.5), results_folder / 'reconstruction.png', nrow=10)

In [None]:
# Generate 100 images from pure noise. `sample` returns a list; the last entry is
# taken as the final images (presumably the t = 0 step — confirm against
# `sample`). That entry is mapped from [-1, 1] to [0, 1] and saved as a 10x10 grid.
all_images_list = sample(model, image_size=image_size, batch_size=100, channels=channels)
all_images = (all_images_list[-1] + 1) * 0.5
save_image(all_images, results_folder / 'sample_random.png', nrow=10)

sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]