In [1]:
from functools import partial
# from utils.networkHelper import *
import math
from inspect import isfunction
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from tqdm.auto import tqdm
from pathlib import Path

from einops.layers.torch import Rearrange
from einops import rearrange, reduce

import numpy as np
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.optim import Adam
from torchvision import transforms, datasets
from torchvision.utils import save_image

In [2]:
# global variables
image_size = 64          # images are resized to image_size x image_size
channels = 3             # RGB
device = "cuda" if torch.cuda.is_available() else "cpu"
train_epochs = 20
timesteps = 1000         # number of diffusion steps
# device = torch.device('cuda:'+str(torch.cuda.device_count()-1)
                    #   if torch.cuda.is_available() else 'cpu')
# NOTE(review): machine-specific absolute path; adjust before running elsewhere.
data_path_root = '/data/wumin/dataset/celeba/'
bs = 100                 # mini-batch size
results_folder = Path("./samples/ddpm/celeba")
# parents=True so the nested output directory is also created on a fresh
# checkout, where ./samples and ./samples/ddpm do not exist yet.
results_folder.mkdir(parents=True, exist_ok=True)

# CelebA dataset
transform = transforms.Compose([
    transforms.CenterCrop((178,178)),
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    # (x - 0.5) / 0.5: rescales ToTensor's [0, 1] output to [-1, 1]
    transforms.Normalize(0.5, 0.5)
    ])

train_dataset = datasets.CelebA(
    root=data_path_root, split='train', transform=transform, download=True)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=bs, shuffle=True)

Files already downloaded and verified


In [3]:
# Sanity check: take the first batch only, print its shape and save it as an
# image grid with 10 images per row.
for x, y in train_loader:
    print(x.shape)
    # Undo Normalize(0.5, 0.5): map [-1, 1] back to [0, 1] before writing the PNG.
    save_image(x.add(1).mul(0.5), results_folder/'real.png', nrow=10)
    break

torch.Size([100, 3, 64, 64])


In [3]:
def exists(x):
    """Tell whether ``x`` carries a value, i.e. is anything other than ``None``."""
    return not (x is None)


def default(val, d):
    """Return ``val`` unless it is ``None``, in which case fall back to ``d``.

    When ``d`` is a plain Python function (or lambda) it is called with no
    arguments and its result is used, so the fallback can be built lazily.
    Any other ``d`` is returned unchanged.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val


def num_to_groups(num, divisor):
    """Split ``num`` into chunks of size ``divisor`` plus one smaller trailing chunk.

    Example: ``num_to_groups(10, 3) -> [3, 3, 3, 1]``.
    """
    # not used for now
    full_chunks, leftover = divmod(num, divisor)
    chunks = [divisor for _ in range(full_chunks)]
    if leftover > 0:
        chunks += [leftover]
    return chunks


class Residual(nn.Module):
    """Skip-connection wrapper: adds the input back onto the wrapped callable's output."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        """Return ``fn(x, *args, **kwargs) + x``; extra arguments go to ``fn`` only."""
        branch = self.fn(x, *args, **kwargs)
        return branch + x


def Upsample(dim, dim_out=None):
    """Build a 2x nearest-neighbour upsampling stage followed by a 3x3 convolution.

    The convolution maps ``dim`` channels to ``dim_out`` channels (``dim`` when
    ``dim_out`` is ``None``); ``padding=1`` keeps the spatial size unchanged.
    """
    out_channels = default(dim_out, dim)
    stages = [
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(dim, out_channels, 3, padding=1),
    ]
    return nn.Sequential(*stages)


def Downsample(dim, dim_out=None):
    """Build a 2x downsampling stage without strided convolutions or pooling.

    Every 2x2 spatial patch is folded into the channel axis (so the channel
    count becomes ``dim * 4``), then a 1x1 convolution projects to ``dim_out``
    channels (``dim`` when ``dim_out`` is ``None``).
    """
    out_channels = default(dim_out, dim)
    fold_patches = Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2)
    project = nn.Conv2d(dim * 4, out_channels, 1)
    return nn.Sequential(fold_patches, project)


class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embedding of a 1-D batch of timesteps.

    The output has ``2 * (dim // 2)`` features per timestep: the sines of
    ``dim // 2`` geometrically spaced frequencies followed by the matching
    cosines. ``dim // 2`` must be at least 2, otherwise the frequency spacing
    below divides by zero.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed ``time`` (indexed as a 1-D tensor) into a ``(len(time), 2 * (dim // 2))`` tensor."""
        n_freqs = self.dim // 2
        # Log-spacing so the frequencies run from 1 down to 1/10000.
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=time.device) * -log_step)
        # Outer product: one row per timestep, one column per frequency.
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class WeightStandardizedConv2d(nn.Conv2d):
    """
    Convolution module with weight standardization
    https://arxiv.org/abs/1903.10520
    weight standardization purportedly works synergistically with group normalization
    """

    def forward(self, x):
        """Convolve ``x`` with the kernel standardized per output filter."""
        # A larger epsilon is used for every dtype other than float32.
        eps = 1e-3 if x.dtype != torch.float32 else 1e-5

        kernel = self.weight
        per_filter = "o ... -> o 1 1 1"  # reduce everything except the output-channel axis
        kernel_mean = reduce(kernel, per_filter, "mean")
        kernel_var = reduce(kernel, per_filter,
                            partial(torch.var, unbiased=False))
        standardized = (kernel - kernel_mean) * (kernel_var + eps).rsqrt()

        return F.conv2d(
            x,
            standardized,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )


class Block(nn.Module):
    """Weight-standardized 3x3 conv -> GroupNorm -> optional scale/shift -> SiLU."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        """Apply the block; ``scale_shift`` is an optional ``(scale, shift)`` pair.

        When given, the normalized features are modulated as
        ``x * (scale + 1) + shift`` before the activation.
        """
        h = self.norm(self.proj(x))

        if scale_shift is not None:
            scale, shift = scale_shift
            h = h * (scale + 1) + shift

        return self.act(h)


class ResnetBlock(nn.Module):
    """Two ``Block`` layers with a residual connection (https://arxiv.org/abs/1512.03385).

    If ``time_emb_dim`` is given, a SiLU + Linear head turns the time embedding
    into a per-channel ``(scale, shift)`` pair that modulates the first block.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        if exists(time_emb_dim):
            # dim_out * 2 outputs: one half becomes the scale, the other the shift.
            self.mlp = nn.Sequential(
                nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        # A 1x1 conv matches channel counts on the skip path when they differ.
        if dim != dim_out:
            self.res_conv = nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, time_emb=None):
        """Return ``block2(block1(x, scale_shift)) + res_conv(x)``."""
        modulation = None
        if exists(self.mlp) and exists(time_emb):
            cond = rearrange(self.mlp(time_emb), "b c -> b c 1 1")
            modulation = cond.chunk(2, dim=1)

        out = self.block1(x, scale_shift=modulation)
        out = self.block2(out)
        return out + self.res_conv(x)


class Attention(nn.Module):
    """Multi-head softmax self-attention over the flattened spatial positions of a feature map."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head
        # A single 1x1 conv produces queries, keys and values in one pass.
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x):
        """Attend across all spatial positions of a 4-D input and project back to ``dim`` channels."""
        _, _, height, width = x.shape
        q, k, v = (
            rearrange(part, "b (h c) x y -> b h c (x y)", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )
        q = q * self.scale

        scores = einsum("b h d i, b h d j -> b h i j", q, k)
        # Subtract the row maximum (detached) before the softmax.
        scores = scores - scores.amax(dim=-1, keepdim=True).detach()
        weights = scores.softmax(dim=-1)

        mixed = einsum("b h i j, b h d j -> b h i d", weights, v)
        mixed = rearrange(mixed, "b h (x y) d -> b (h d) x y", x=height, y=width)
        return self.to_out(mixed)


class LinearAttention(nn.Module):
    """Linear-attention variant: keys and values are contracted into a per-head
    context matrix first, so no position-by-position score matrix is formed."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(inner_dim, dim, 1),
                                    nn.GroupNorm(1, dim))

    def forward(self, x):
        """Apply linear attention to a 4-D input and project back to ``dim`` channels."""
        _, _, height, width = x.shape
        q, k, v = (
            rearrange(part, "b (h c) x y -> b h c (x y)", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )

        # Queries are normalized over the per-head channel axis, keys over positions.
        q = q.softmax(dim=-2) * self.scale
        k = k.softmax(dim=-1)

        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        mixed = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        mixed = rearrange(mixed, "b h c (x y) -> b (h c) x y",
                          h=self.heads, x=height, y=width)
        return self.to_out(mixed)


class PreNorm(nn.Module):
    """Apply a single-group ``GroupNorm`` to the input before handing it to ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        """Return ``fn(norm(x))``."""
        return self.fn(self.norm(x))


class Unet(nn.Module):
    """U-Net noise predictor conditioned on the diffusion timestep.

    Structure: a 1x1 input convolution, one down stage per entry of
    ``dim_mults`` (two ResNet blocks, linear attention, then a downsample or,
    on the last stage, a 3x3 conv), a bottleneck (ResNet block, full attention,
    ResNet block), mirrored up stages that consume the stored skip
    activations, and a final ResNet block plus 1x1 output convolution.

    Args:
        dim: Base channel width. Also the size of the sinusoidal time
            embedding; the time MLP widens it to ``dim * 4``.
        init_dim: Channels produced by the input convolution; defaults to ``dim``.
        out_dim: Channels of the prediction; defaults to ``channels``.
        dim_mults: Per-stage multipliers applied to ``dim``.
        channels: Channels of the input image.
        self_condition: If True, the input is concatenated with ``x_self_cond``
            (zeros when not supplied), doubling the input channels.
        resnet_block_groups: ``groups`` passed to every ``ResnetBlock``.
    """

    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        self_condition=False,
        resnet_block_groups=4,
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels
        self.self_condition = self_condition
        input_channels = channels * (2 if self_condition else 1)

        init_dim = default(init_dim, dim)
        # changed to 1 and 0 from 7,3
        self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0)

        # Channel width at every resolution, and the (in, out) pair per stage.
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        time_dim = dim * 4

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            # The last stage changes channels with a 3x3 conv instead of downsampling.
            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Downsample(dim_in, dim_out)
                        if not is_last
                        else nn.Conv2d(dim_in, dim_out, 3, padding=1),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            # Each ResNet block takes dim_out + dim_in channels because a skip
            # activation (dim_in channels) is concatenated in front of it.
            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out + dim_in, dim_out,
                                    time_emb_dim=time_dim),
                        block_klass(dim_out + dim_in, dim_out,
                                    time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Upsample(dim_out, dim_in)
                        if not is_last
                        else nn.Conv2d(dim_out, dim_in, 3, padding=1),
                    ]
                )
            )

        self.out_dim = default(out_dim, channels)

        # The last up stage emits init_dim channels and forward() concatenates
        # them with the init_conv output (also init_dim channels), so this
        # block sees init_dim * 2 channels. Using dim * 2 only worked when
        # init_dim was left at its default of dim.
        self.final_res_block = block_klass(init_dim * 2, dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)

    def forward(self, x, time, x_self_cond=None):
        """Run the network on input ``x`` at timesteps ``time``.

        Args:
            x: Input batch with ``channels`` channels. Each non-final down
                stage halves height and width, so both must be divisible by
                ``2 ** (len(dim_mults) - 1)``.
            time: One timestep per batch element, fed to the time MLP.
            x_self_cond: Optional tensor shaped like ``x``; only used when the
                model was built with ``self_condition=True``.

        Returns:
            The output of the final 1x1 convolution, with ``out_dim`` channels.
        """
        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim=1)

        x = self.init_conv(x)
        r = x.clone()  # kept for the long skip into final_res_block

        t = self.time_mlp(time)

        h = []  # stack of skip activations, popped in reverse on the way up

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim=1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)


def cosine_beta_schedule(timesteps, s=0.008):
    """
    Cosine noise schedule from https://arxiv.org/abs/2102.09672.

    Returns ``timesteps`` betas, derived from a squared-cosine cumulative
    alpha curve with offset ``s`` and clipped to [0.0001, 0.9999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    cum_alpha = torch.cos(
        ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    # Normalize so the curve starts at exactly 1.
    cum_alpha = cum_alpha / cum_alpha[0]
    raw_betas = 1 - (cum_alpha[1:] / cum_alpha[:-1])
    return raw_betas.clamp(0.0001, 0.9999)


def linear_beta_schedule(timesteps):
    """Return ``timesteps`` betas spaced evenly from 0.0001 to 0.02 inclusive."""
    first_beta, last_beta = 0.0001, 0.02
    return torch.linspace(first_beta, last_beta, timesteps)


def quadratic_beta_schedule(timesteps):
    """Return ``timesteps`` betas from 0.0001 to 0.02 whose square roots are evenly spaced."""
    first_beta, last_beta = 0.0001, 0.02
    roots = torch.linspace(first_beta ** 0.5, last_beta ** 0.5, timesteps)
    return roots ** 2


def sigmoid_beta_schedule(timesteps):
    """Return ``timesteps`` betas following a sigmoid over [-6, 6], rescaled into [0.0001, 0.02]."""
    first_beta, last_beta = 0.0001, 0.02
    ramp = torch.sigmoid(torch.linspace(-6, 6, timesteps))
    return ramp * (last_beta - first_beta) + first_beta


# define beta schedule
# Lookup tables for the diffusion process, one entry per timestep. They are
# created by plain torch calls without a device argument and are read through
# extract().
betas = linear_beta_schedule(timesteps=timesteps)

# define alphas
alphas = 1. - betas
# running product of the alphas up to and including step t
alphas_cumprod = torch.cumprod(alphas, axis=0)
# same table shifted right by one, with 1.0 padded in front (value for "step -1")
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# calculations for posterior q(x_{t-1} | x_t, x_0)
# entry 0 is exactly 0 because alphas_cumprod_prev[0] == 1
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)


def extract(a, t, x_shape):
    """Look up ``a[t]`` per batch element and reshape it to broadcast against ``x_shape``.

    Args:
        a: 1-D lookup table indexed by timestep (for example ``betas``).
        t: 1-D integer tensor of timestep indices, one per batch element.
        x_shape: Shape of the tensor the result will be multiplied with; only
            its length is used.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape)``
        dimensions, placed on ``t``'s device.
    """
    batch_size = t.shape[0]
    # gather() needs the index on the same device as the table. Follow the
    # table's device instead of assuming it lives on the CPU, so this keeps
    # working if the schedule tensors are moved to the GPU.
    out = a.gather(-1, t.to(a.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

# forward diffusion (using the nice property)


def q_sample(x_start, t, noise=None):
    """Draw ``x_t`` from ``x_start`` in one step of the forward (noising) process.

    Computes ``sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise``.
    Standard normal noise is sampled when ``noise`` is not supplied.
    """
    noise = torch.randn_like(x_start) if noise is None else noise

    signal_scale = extract(sqrt_alphas_cumprod, t, x_start.shape)
    noise_scale = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)

    return signal_scale * x_start + noise_scale * noise

def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    """Training loss: noise ``x_start`` to step ``t`` and score the model's noise prediction.

    ``loss_type`` selects ``"l1"``, ``"l2"`` (MSE) or ``"huber"`` (smooth L1).
    Any other value raises ``NotImplementedError``, after the model has run.
    """
    noise = torch.randn_like(x_start) if noise is None else noise

    noised = q_sample(x_start=x_start, t=t, noise=noise)
    predicted = denoise_model(noised, t)

    if loss_type == 'l1':
        criterion = F.l1_loss
    elif loss_type == 'l2':
        criterion = F.mse_loss
    elif loss_type == "huber":
        criterion = F.smooth_l1_loss
    else:
        raise NotImplementedError()

    return criterion(noise, predicted)


@torch.no_grad()
def p_sample(model, x, t, t_index):
    """One reverse-diffusion step: sample ``x_{t-1}`` from ``x_t``.

    ``t`` is the per-example timestep tensor passed to the model and used for
    the table lookups; ``t_index`` is the same step as a Python int. At
    ``t_index == 0`` the predicted mean is returned with no noise added.
    """
    beta_t = extract(betas, t, x.shape)
    noise_scale_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    recip_sqrt_alpha_t = extract(sqrt_recip_alphas, t, x.shape)

    # Equation 11 in the paper: the model's noise prediction gives the mean.
    predicted_noise = model(x, t)
    mean = recip_sqrt_alpha_t * (x - beta_t * predicted_noise / noise_scale_t)

    if t_index == 0:
        return mean

    # Algorithm 2 line 4: add noise scaled by the posterior standard deviation.
    variance_t = extract(posterior_variance, t, x.shape)
    fresh_noise = torch.randn_like(x)
    return mean + torch.sqrt(variance_t) * fresh_noise

# Algorithm 2 (including returning all images)


@torch.no_grad()
def p_sample_loop(model, shape):
    """Run the full reverse process from pure noise and return every intermediate image.

    The result stacks one tensor per step along a new leading axis, ordered
    from step ``timesteps - 1`` down to step 0, so the last entry is the
    final sample. All intermediates stay on the model's device.
    """
    device = next(model.parameters()).device
    batch = shape[0]

    # start from pure noise (for each example in the batch)
    current = torch.randn(shape, device=device)
    trajectory = []

    countdown = tqdm(range(timesteps - 1, -1, -1),
                     desc='sampling loop time step', total=timesteps)
    for step in countdown:
        step_tensor = torch.full((batch,), step, device=device, dtype=torch.long)
        current = p_sample(model, current, step_tensor, step)
        trajectory.append(current)

    return torch.stack(trajectory, dim=0)


@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    """Generate ``batch_size`` square images; returns ``p_sample_loop``'s stacked trajectory."""
    target_shape = (batch_size, channels, image_size, image_size)
    return p_sample_loop(model, shape=target_shape)


def num_to_groups(num, divisor):
    """Break ``num`` into a list of ``divisor``-sized groups, with any remainder as a last group.

    Note: this redefines (and shadows) an identical function defined earlier
    in the notebook.
    """
    groups = [divisor] * (num // divisor)
    tail = num % divisor
    return groups + [tail] if tail > 0 else groups


def show_result(model, epoch, batches=10, display_step=100, show=False, save=False,  path='result.png'):
    """Sample from ``model`` and save a grid of denoising trajectories for this epoch.

    Each grid row is one generated sample; its columns are every
    ``display_step``-th frame of the reverse process, ending with the final
    image. The grid is always written to
    ``results_folder / f'sample_epoch_{epoch}.png'``.

    Args:
        model: Noise-prediction network passed to ``sample``.
        epoch: Used only in the output file name.
        batches: Number of samples to generate (rows of the grid).
        display_step: Keep one frame every ``display_step`` sampling steps.
        show, save, path: Currently unused; kept so existing calls keep working.

    Raises:
        ValueError: If ``display_step`` is not in ``[1, timesteps]``.
    """
    ds = display_step
    if not 1 <= ds <= timesteps:
        # Fail before the expensive sampling run rather than inside save_image.
        raise ValueError(
            f"display_step must be between 1 and {timesteps}, got {ds}")

    # save generated images
    all_images_list = sample(model, image_size=image_size, batch_size=batches, channels=channels)
    # Pick the frames on the time axis BEFORE flattening. Striding over the
    # flattened (b t) axis only lines up with the rows when timesteps is a
    # multiple of display_step; slicing per time step is always aligned.
    frames = all_images_list[ds - 1::ds]
    all_images = rearrange(
        frames, 't b c h w -> (b t) c h w').add(1).mul(0.5)
    # nrow = frames per sample, so each row of the grid is one trajectory.
    save_image(all_images, str(results_folder /
               f'sample_epoch_{epoch}.png'), nrow=frames.shape[0])


def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):
    """Plot the per-epoch losses in ``hist['losses']`` against the epoch index.

    The figure is written to ``path`` when ``save`` is true. It is then shown
    when ``show`` is true, otherwise closed.
    """
    losses = hist['losses']
    epochs = range(len(losses))

    plt.plot(epochs, losses, label='Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(loc=5)
    plt.grid(True)
    plt.tight_layout()

    if save:
        plt.savefig(path)

    if not show:
        plt.close()
        return
    plt.show()
        
        
def train(model, optimizer, epoch, train_hist):
    """Train ``model`` for one pass over ``train_loader`` with the Huber noise-prediction loss.

    The loss is printed every 50 steps, and the epoch's mean loss is appended
    to ``train_hist['losses']`` and printed at the end.
    """
    step_losses = []
    for step, (images, _) in enumerate(train_loader):
        optimizer.zero_grad()

        images = images.to(device)
        n_images = images.shape[0]

        # Algorithm 1 line 3: sample t uniformly for every example in the batch
        t = torch.randint(0, timesteps, (n_images,), device=device).long()

        loss = p_losses(model, images, t, loss_type="huber")
        loss_value = loss.item()
        step_losses.append(loss_value)

        if step % 50 == 0:
            print("epoch: {} step: {} Loss: {:.5f}".format(
                epoch, step, loss_value))

        loss.backward()
        optimizer.step()

    epoch_mean = np.mean(step_losses)
    train_hist['losses'].append(epoch_mean)
    print("================> epoch: {} average Loss: {:.5f}".format(
                epoch, epoch_mean))
    

In [5]:
# Number of mini-batches per epoch (the DataLoader's length).
len(train_loader)

1628

In [6]:
# Noise-prediction U-Net with base width image_size * 3 and three resolution
# stages whose widths are dim * (1, 2, 4).
model = Unet(
    dim=image_size * 3,
    channels=channels,
    dim_mults=(1, 2, 4,)
).to(device)

# Disabled: resume from a previously pickled model instead of training from scratch.
# model = torch.load('models/ddpm_celeba.pth').to(device)
optimizer = Adam(model.parameters(), lr=2e-4)


In [7]:
# Per-epoch mean losses are collected here by train().
train_hist = {}
train_hist['losses'] = []
print("start training")

for epoch in range(train_epochs):
    # Halve the learning rate of the first parameter group on every 10th
    # epoch (epoch indices 9, 19, ...).
    if (epoch+1) % 10 == 0:
        optimizer.param_groups[0]['lr'] /= 2
    train(model, optimizer, epoch, train_hist)
    
    # save generated images
    # (runs a complete reverse-diffusion sampling pass after every epoch)
    show_result(model, epoch)

# show=False, save=True: write the loss curve to disk without displaying it.
show_train_hist(train_hist, False, True, results_folder / 'history_train_losses.png')

start training
epoch: 0 step: 0 Loss: 0.53420
epoch: 0 step: 50 Loss: 0.04274
epoch: 0 step: 100 Loss: 0.01921
epoch: 0 step: 150 Loss: 0.02371
epoch: 0 step: 200 Loss: 0.02255
epoch: 0 step: 250 Loss: 0.00853
epoch: 0 step: 300 Loss: 0.01424
epoch: 0 step: 350 Loss: 0.00805
epoch: 0 step: 400 Loss: 0.01630
epoch: 0 step: 450 Loss: 0.01760
epoch: 0 step: 500 Loss: 0.01693
epoch: 0 step: 550 Loss: 0.01494
epoch: 0 step: 600 Loss: 0.01629
epoch: 0 step: 650 Loss: 0.01541
epoch: 0 step: 700 Loss: 0.01683
epoch: 0 step: 750 Loss: 0.01003
epoch: 0 step: 800 Loss: 0.01048
epoch: 0 step: 850 Loss: 0.01338
epoch: 0 step: 900 Loss: 0.01212
epoch: 0 step: 950 Loss: 0.01143
epoch: 0 step: 1000 Loss: 0.01063
epoch: 0 step: 1050 Loss: 0.01329
epoch: 0 step: 1100 Loss: 0.01268
epoch: 0 step: 1150 Loss: 0.01115
epoch: 0 step: 1200 Loss: 0.00900
epoch: 0 step: 1250 Loss: 0.01111
epoch: 0 step: 1300 Loss: 0.01104
epoch: 0 step: 1350 Loss: 0.01077
epoch: 0 step: 1400 Loss: 0.01168
epoch: 0 step: 1450 Lo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 1 step: 0 Loss: 0.01040
epoch: 1 step: 50 Loss: 0.00665
epoch: 1 step: 100 Loss: 0.00941
epoch: 1 step: 150 Loss: 0.01146
epoch: 1 step: 200 Loss: 0.01117
epoch: 1 step: 250 Loss: 0.01205
epoch: 1 step: 300 Loss: 0.00984
epoch: 1 step: 350 Loss: 0.00759
epoch: 1 step: 400 Loss: 0.00941
epoch: 1 step: 450 Loss: 0.00811
epoch: 1 step: 500 Loss: 0.00878
epoch: 1 step: 550 Loss: 0.00546
epoch: 1 step: 600 Loss: 0.00933
epoch: 1 step: 650 Loss: 0.01537
epoch: 1 step: 700 Loss: 0.00800
epoch: 1 step: 750 Loss: 0.00757
epoch: 1 step: 800 Loss: 0.01089
epoch: 1 step: 850 Loss: 0.01022
epoch: 1 step: 900 Loss: 0.00510
epoch: 1 step: 950 Loss: 0.01194
epoch: 1 step: 1000 Loss: 0.00673
epoch: 1 step: 1050 Loss: 0.01009
epoch: 1 step: 1100 Loss: 0.01301
epoch: 1 step: 1150 Loss: 0.01007
epoch: 1 step: 1200 Loss: 0.00990
epoch: 1 step: 1250 Loss: 0.00886
epoch: 1 step: 1300 Loss: 0.01207
epoch: 1 step: 1350 Loss: 0.00964
epoch: 1 step: 1400 Loss: 0.01280
epoch: 1 step: 1450 Loss: 0.01198
epo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 2 step: 0 Loss: 0.00876
epoch: 2 step: 50 Loss: 0.00548
epoch: 2 step: 100 Loss: 0.00671
epoch: 2 step: 150 Loss: 0.00742
epoch: 2 step: 200 Loss: 0.01011
epoch: 2 step: 250 Loss: 0.00833
epoch: 2 step: 300 Loss: 0.01218
epoch: 2 step: 350 Loss: 0.01073
epoch: 2 step: 400 Loss: 0.00756
epoch: 2 step: 450 Loss: 0.00910
epoch: 2 step: 500 Loss: 0.00581
epoch: 2 step: 550 Loss: 0.00663
epoch: 2 step: 600 Loss: 0.00656
epoch: 2 step: 650 Loss: 0.00922
epoch: 2 step: 700 Loss: 0.00887
epoch: 2 step: 750 Loss: 0.01297
epoch: 2 step: 800 Loss: 0.00862
epoch: 2 step: 850 Loss: 0.00649
epoch: 2 step: 900 Loss: 0.00768
epoch: 2 step: 950 Loss: 0.00893
epoch: 2 step: 1000 Loss: 0.00715
epoch: 2 step: 1050 Loss: 0.01142
epoch: 2 step: 1100 Loss: 0.01011
epoch: 2 step: 1150 Loss: 0.00826
epoch: 2 step: 1200 Loss: 0.01550
epoch: 2 step: 1250 Loss: 0.00563
epoch: 2 step: 1300 Loss: 0.00966
epoch: 2 step: 1350 Loss: 0.00856
epoch: 2 step: 1400 Loss: 0.00787
epoch: 2 step: 1450 Loss: 0.00909
epo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 3 step: 0 Loss: 0.01048
epoch: 3 step: 50 Loss: 0.00960
epoch: 3 step: 100 Loss: 0.00689
epoch: 3 step: 150 Loss: 0.00771
epoch: 3 step: 200 Loss: 0.00822
epoch: 3 step: 250 Loss: 0.00811
epoch: 3 step: 300 Loss: 0.01693
epoch: 3 step: 350 Loss: 0.00774
epoch: 3 step: 400 Loss: 0.01165
epoch: 3 step: 450 Loss: 0.00784
epoch: 3 step: 500 Loss: 0.00839
epoch: 3 step: 550 Loss: 0.00835
epoch: 3 step: 600 Loss: 0.01090
epoch: 3 step: 650 Loss: 0.00769
epoch: 3 step: 700 Loss: 0.00657
epoch: 3 step: 750 Loss: 0.01216
epoch: 3 step: 800 Loss: 0.00774
epoch: 3 step: 850 Loss: 0.00854
epoch: 3 step: 900 Loss: 0.01410
epoch: 3 step: 950 Loss: 0.01302
epoch: 3 step: 1000 Loss: 0.01165
epoch: 3 step: 1050 Loss: 0.01050
epoch: 3 step: 1100 Loss: 0.00741
epoch: 3 step: 1150 Loss: 0.00712
epoch: 3 step: 1200 Loss: 0.00854
epoch: 3 step: 1250 Loss: 0.00979
epoch: 3 step: 1300 Loss: 0.00693
epoch: 3 step: 1350 Loss: 0.00648
epoch: 3 step: 1400 Loss: 0.01170
epoch: 3 step: 1450 Loss: 0.00824
epo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 4 step: 0 Loss: 0.00660
epoch: 4 step: 50 Loss: 0.01092
epoch: 4 step: 100 Loss: 0.01110
epoch: 4 step: 150 Loss: 0.00967
epoch: 4 step: 200 Loss: 0.01075
epoch: 4 step: 250 Loss: 0.00970
epoch: 4 step: 300 Loss: 0.00937
epoch: 4 step: 350 Loss: 0.00664
epoch: 4 step: 400 Loss: 0.00933
epoch: 4 step: 450 Loss: 0.01085
epoch: 4 step: 500 Loss: 0.00835
epoch: 4 step: 550 Loss: 0.00471
epoch: 4 step: 600 Loss: 0.00760
epoch: 4 step: 650 Loss: 0.00674
epoch: 4 step: 700 Loss: 0.00702
epoch: 4 step: 750 Loss: 0.01057
epoch: 4 step: 800 Loss: 0.00668
epoch: 4 step: 850 Loss: 0.00874
epoch: 4 step: 900 Loss: 0.01183
epoch: 4 step: 950 Loss: 0.00793
epoch: 4 step: 1000 Loss: 0.00847
epoch: 4 step: 1050 Loss: 0.00940
epoch: 4 step: 1100 Loss: 0.00574
epoch: 4 step: 1150 Loss: 0.01033
epoch: 4 step: 1200 Loss: 0.00750
epoch: 4 step: 1250 Loss: 0.00886
epoch: 4 step: 1300 Loss: 0.00965
epoch: 4 step: 1350 Loss: 0.00846
epoch: 4 step: 1400 Loss: 0.00776
epoch: 4 step: 1450 Loss: 0.00852
epo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 5 step: 0 Loss: 0.00799
epoch: 5 step: 50 Loss: 0.00849
epoch: 5 step: 100 Loss: 0.00843
epoch: 5 step: 150 Loss: 0.00787
epoch: 5 step: 200 Loss: 0.00618
epoch: 5 step: 250 Loss: 0.00732
epoch: 5 step: 300 Loss: 0.00726
epoch: 5 step: 350 Loss: 0.00704
epoch: 5 step: 400 Loss: 0.01133
epoch: 5 step: 450 Loss: 0.00838
epoch: 5 step: 500 Loss: 0.01240
epoch: 5 step: 550 Loss: 0.00686
epoch: 5 step: 600 Loss: 0.00789
epoch: 5 step: 650 Loss: 0.00730
epoch: 5 step: 700 Loss: 0.00752
epoch: 5 step: 750 Loss: 0.00710
epoch: 5 step: 800 Loss: 0.00775
epoch: 5 step: 850 Loss: 0.00648
epoch: 5 step: 900 Loss: 0.00906
epoch: 5 step: 950 Loss: 0.00971
epoch: 5 step: 1000 Loss: 0.01147
epoch: 5 step: 1050 Loss: 0.01135
epoch: 5 step: 1100 Loss: 0.00941
epoch: 5 step: 1150 Loss: 0.00992
epoch: 5 step: 1200 Loss: 0.00780
epoch: 5 step: 1250 Loss: 0.00879
epoch: 5 step: 1300 Loss: 0.00590
epoch: 5 step: 1350 Loss: 0.01503
epoch: 5 step: 1400 Loss: 0.01091
epoch: 5 step: 1450 Loss: 0.00650
epo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 6 step: 0 Loss: 0.00620
epoch: 6 step: 50 Loss: 0.00883
epoch: 6 step: 100 Loss: 0.00880
epoch: 6 step: 150 Loss: 0.00713
epoch: 6 step: 200 Loss: 0.00939
epoch: 6 step: 250 Loss: 0.00910
epoch: 6 step: 300 Loss: 0.00627
epoch: 6 step: 350 Loss: 0.01016
epoch: 6 step: 400 Loss: 0.00787
epoch: 6 step: 450 Loss: 0.01043
epoch: 6 step: 500 Loss: 0.00864
epoch: 6 step: 550 Loss: 0.00730
epoch: 6 step: 600 Loss: 0.00693
epoch: 6 step: 650 Loss: 0.00883
epoch: 6 step: 700 Loss: 0.01108
epoch: 6 step: 750 Loss: 0.00785
epoch: 6 step: 800 Loss: 0.01150
epoch: 6 step: 850 Loss: 0.01129
epoch: 6 step: 900 Loss: 0.00847
epoch: 6 step: 950 Loss: 0.01059
epoch: 6 step: 1000 Loss: 0.00529
epoch: 6 step: 1050 Loss: 0.00949
epoch: 6 step: 1100 Loss: 0.00663
epoch: 6 step: 1150 Loss: 0.00940
epoch: 6 step: 1200 Loss: 0.00875
epoch: 6 step: 1250 Loss: 0.00771
epoch: 6 step: 1300 Loss: 0.00800
epoch: 6 step: 1350 Loss: 0.01013
epoch: 6 step: 1400 Loss: 0.01395
epoch: 6 step: 1450 Loss: 0.00667
epo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 7 step: 0 Loss: 0.01266
epoch: 7 step: 50 Loss: 0.00783
epoch: 7 step: 100 Loss: 0.00528
epoch: 7 step: 150 Loss: 0.00874
epoch: 7 step: 200 Loss: 0.00958
epoch: 7 step: 250 Loss: 0.00677
epoch: 7 step: 300 Loss: 0.00822
epoch: 7 step: 350 Loss: 0.00886
epoch: 7 step: 400 Loss: 0.00774
epoch: 7 step: 450 Loss: 0.00572
epoch: 7 step: 500 Loss: 0.00635
epoch: 7 step: 550 Loss: 0.00915
epoch: 7 step: 600 Loss: 0.01258
epoch: 7 step: 650 Loss: 0.00788
epoch: 7 step: 700 Loss: 0.01111
epoch: 7 step: 750 Loss: 0.00591
epoch: 7 step: 800 Loss: 0.00694
epoch: 7 step: 850 Loss: 0.00993
epoch: 7 step: 900 Loss: 0.00856
epoch: 7 step: 950 Loss: 0.00710
epoch: 7 step: 1000 Loss: 0.00950
epoch: 7 step: 1050 Loss: 0.00728
epoch: 7 step: 1100 Loss: 0.00723
epoch: 7 step: 1150 Loss: 0.00925
epoch: 7 step: 1200 Loss: 0.00885
epoch: 7 step: 1250 Loss: 0.00976
epoch: 7 step: 1300 Loss: 0.00646
epoch: 7 step: 1350 Loss: 0.00735
epoch: 7 step: 1400 Loss: 0.01114
epoch: 7 step: 1450 Loss: 0.00756
epo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 8 step: 0 Loss: 0.00807
epoch: 8 step: 50 Loss: 0.00827
epoch: 8 step: 100 Loss: 0.00524
epoch: 8 step: 150 Loss: 0.00752
epoch: 8 step: 200 Loss: 0.00685
epoch: 8 step: 250 Loss: 0.00890
epoch: 8 step: 300 Loss: 0.00846
epoch: 8 step: 350 Loss: 0.01104
epoch: 8 step: 400 Loss: 0.00521
epoch: 8 step: 450 Loss: 0.01499
epoch: 8 step: 500 Loss: 0.00921
epoch: 8 step: 550 Loss: 0.00467
epoch: 8 step: 600 Loss: 0.00750
epoch: 8 step: 650 Loss: 0.01163
epoch: 8 step: 700 Loss: 0.00525
epoch: 8 step: 750 Loss: 0.00962
epoch: 8 step: 800 Loss: 0.00864
epoch: 8 step: 850 Loss: 0.00802
epoch: 8 step: 900 Loss: 0.00919
epoch: 8 step: 950 Loss: 0.00923
epoch: 8 step: 1000 Loss: 0.00984
epoch: 8 step: 1050 Loss: 0.01017
epoch: 8 step: 1100 Loss: 0.00780
epoch: 8 step: 1150 Loss: 0.00821
epoch: 8 step: 1200 Loss: 0.01162
epoch: 8 step: 1250 Loss: 0.00855
epoch: 8 step: 1300 Loss: 0.00925
epoch: 8 step: 1350 Loss: 0.00560
epoch: 8 step: 1400 Loss: 0.00673
epoch: 8 step: 1450 Loss: 0.00786
epo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 9 step: 0 Loss: 0.00596
epoch: 9 step: 50 Loss: 0.00809
epoch: 9 step: 100 Loss: 0.00965
epoch: 9 step: 150 Loss: 0.00924
epoch: 9 step: 200 Loss: 0.00762
epoch: 9 step: 250 Loss: 0.00744
epoch: 9 step: 300 Loss: 0.01075
epoch: 9 step: 350 Loss: 0.00528
epoch: 9 step: 400 Loss: 0.00919
epoch: 9 step: 450 Loss: 0.00676
epoch: 9 step: 500 Loss: 0.00729
epoch: 9 step: 550 Loss: 0.00640
epoch: 9 step: 600 Loss: 0.01086
epoch: 9 step: 650 Loss: 0.00784
epoch: 9 step: 700 Loss: 0.01179
epoch: 9 step: 750 Loss: 0.01102
epoch: 9 step: 800 Loss: 0.00893
epoch: 9 step: 850 Loss: 0.00726
epoch: 9 step: 900 Loss: 0.00749
epoch: 9 step: 950 Loss: 0.01376
epoch: 9 step: 1000 Loss: 0.00402
epoch: 9 step: 1050 Loss: 0.00893
epoch: 9 step: 1100 Loss: 0.00700
epoch: 9 step: 1150 Loss: 0.01137
epoch: 9 step: 1200 Loss: 0.00611
epoch: 9 step: 1250 Loss: 0.00822
epoch: 9 step: 1300 Loss: 0.01017
epoch: 9 step: 1350 Loss: 0.01015
epoch: 9 step: 1400 Loss: 0.00777
epoch: 9 step: 1450 Loss: 0.00620
epo

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 10 step: 0 Loss: 0.00896
epoch: 10 step: 50 Loss: 0.00807
epoch: 10 step: 100 Loss: 0.00518
epoch: 10 step: 150 Loss: 0.00961
epoch: 10 step: 200 Loss: 0.00387
epoch: 10 step: 250 Loss: 0.00766
epoch: 10 step: 300 Loss: 0.01585
epoch: 10 step: 350 Loss: 0.00770
epoch: 10 step: 400 Loss: 0.00811
epoch: 10 step: 450 Loss: 0.00504
epoch: 10 step: 500 Loss: 0.00802
epoch: 10 step: 550 Loss: 0.00868
epoch: 10 step: 600 Loss: 0.00871
epoch: 10 step: 650 Loss: 0.00811
epoch: 10 step: 700 Loss: 0.00699
epoch: 10 step: 750 Loss: 0.01108
epoch: 10 step: 800 Loss: 0.00904
epoch: 10 step: 850 Loss: 0.00852
epoch: 10 step: 900 Loss: 0.00751
epoch: 10 step: 950 Loss: 0.00675
epoch: 10 step: 1000 Loss: 0.00874
epoch: 10 step: 1050 Loss: 0.00911
epoch: 10 step: 1100 Loss: 0.00933
epoch: 10 step: 1150 Loss: 0.00835
epoch: 10 step: 1200 Loss: 0.00772
epoch: 10 step: 1250 Loss: 0.00787
epoch: 10 step: 1300 Loss: 0.00729
epoch: 10 step: 1350 Loss: 0.00857
epoch: 10 step: 1400 Loss: 0.00645
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 11 step: 0 Loss: 0.01165
epoch: 11 step: 50 Loss: 0.00935
epoch: 11 step: 100 Loss: 0.00509
epoch: 11 step: 150 Loss: 0.00593
epoch: 11 step: 200 Loss: 0.00645
epoch: 11 step: 250 Loss: 0.01111
epoch: 11 step: 300 Loss: 0.00602
epoch: 11 step: 350 Loss: 0.00767
epoch: 11 step: 400 Loss: 0.00819
epoch: 11 step: 450 Loss: 0.00758
epoch: 11 step: 500 Loss: 0.00951
epoch: 11 step: 550 Loss: 0.00680
epoch: 11 step: 600 Loss: 0.00758
epoch: 11 step: 650 Loss: 0.00690
epoch: 11 step: 700 Loss: 0.00651
epoch: 11 step: 750 Loss: 0.00538
epoch: 11 step: 800 Loss: 0.00833
epoch: 11 step: 850 Loss: 0.00912
epoch: 11 step: 900 Loss: 0.00572
epoch: 11 step: 950 Loss: 0.00620
epoch: 11 step: 1000 Loss: 0.00610
epoch: 11 step: 1050 Loss: 0.00721
epoch: 11 step: 1100 Loss: 0.00902
epoch: 11 step: 1150 Loss: 0.00738
epoch: 11 step: 1200 Loss: 0.01115
epoch: 11 step: 1250 Loss: 0.00855
epoch: 11 step: 1300 Loss: 0.00856
epoch: 11 step: 1350 Loss: 0.00880
epoch: 11 step: 1400 Loss: 0.00769
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 12 step: 0 Loss: 0.00784
epoch: 12 step: 50 Loss: 0.00930
epoch: 12 step: 100 Loss: 0.00868
epoch: 12 step: 150 Loss: 0.00670
epoch: 12 step: 200 Loss: 0.00641
epoch: 12 step: 250 Loss: 0.00676
epoch: 12 step: 300 Loss: 0.00610
epoch: 12 step: 350 Loss: 0.00993
epoch: 12 step: 400 Loss: 0.00990
epoch: 12 step: 450 Loss: 0.00831
epoch: 12 step: 500 Loss: 0.00804
epoch: 12 step: 550 Loss: 0.01030
epoch: 12 step: 600 Loss: 0.00800
epoch: 12 step: 650 Loss: 0.00625
epoch: 12 step: 700 Loss: 0.00585
epoch: 12 step: 750 Loss: 0.00829
epoch: 12 step: 800 Loss: 0.00864
epoch: 12 step: 850 Loss: 0.00657
epoch: 12 step: 900 Loss: 0.01118
epoch: 12 step: 950 Loss: 0.01058
epoch: 12 step: 1000 Loss: 0.00721
epoch: 12 step: 1050 Loss: 0.00901
epoch: 12 step: 1100 Loss: 0.01096
epoch: 12 step: 1150 Loss: 0.01055
epoch: 12 step: 1200 Loss: 0.00749
epoch: 12 step: 1250 Loss: 0.00893
epoch: 12 step: 1300 Loss: 0.00660
epoch: 12 step: 1350 Loss: 0.00980
epoch: 12 step: 1400 Loss: 0.01447
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 13 step: 0 Loss: 0.00578
epoch: 13 step: 50 Loss: 0.00542
epoch: 13 step: 100 Loss: 0.00660
epoch: 13 step: 150 Loss: 0.00813
epoch: 13 step: 200 Loss: 0.01158
epoch: 13 step: 250 Loss: 0.01148
epoch: 13 step: 300 Loss: 0.00947
epoch: 13 step: 350 Loss: 0.00893
epoch: 13 step: 400 Loss: 0.00701
epoch: 13 step: 450 Loss: 0.01463
epoch: 13 step: 500 Loss: 0.00685
epoch: 13 step: 550 Loss: 0.01006
epoch: 13 step: 600 Loss: 0.00724
epoch: 13 step: 650 Loss: 0.00969
epoch: 13 step: 700 Loss: 0.00727
epoch: 13 step: 750 Loss: 0.00609
epoch: 13 step: 800 Loss: 0.01090
epoch: 13 step: 850 Loss: 0.01217
epoch: 13 step: 900 Loss: 0.00771
epoch: 13 step: 950 Loss: 0.00833
epoch: 13 step: 1000 Loss: 0.00840
epoch: 13 step: 1050 Loss: 0.00552
epoch: 13 step: 1100 Loss: 0.00645
epoch: 13 step: 1150 Loss: 0.00787
epoch: 13 step: 1200 Loss: 0.00637
epoch: 13 step: 1250 Loss: 0.01031
epoch: 13 step: 1300 Loss: 0.01194
epoch: 13 step: 1350 Loss: 0.00419
epoch: 13 step: 1400 Loss: 0.00653
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 14 step: 0 Loss: 0.00503
epoch: 14 step: 50 Loss: 0.00660
epoch: 14 step: 100 Loss: 0.00633
epoch: 14 step: 150 Loss: 0.00899
epoch: 14 step: 200 Loss: 0.00844
epoch: 14 step: 250 Loss: 0.00839
epoch: 14 step: 300 Loss: 0.00862
epoch: 14 step: 350 Loss: 0.00730
epoch: 14 step: 400 Loss: 0.00792
epoch: 14 step: 450 Loss: 0.00862
epoch: 14 step: 500 Loss: 0.00726
epoch: 14 step: 550 Loss: 0.00500
epoch: 14 step: 600 Loss: 0.00831
epoch: 14 step: 650 Loss: 0.01075
epoch: 14 step: 700 Loss: 0.01180
epoch: 14 step: 750 Loss: 0.00735
epoch: 14 step: 800 Loss: 0.00972
epoch: 14 step: 850 Loss: 0.00795
epoch: 14 step: 900 Loss: 0.00574
epoch: 14 step: 950 Loss: 0.00727
epoch: 14 step: 1000 Loss: 0.00601
epoch: 14 step: 1050 Loss: 0.01108
epoch: 14 step: 1100 Loss: 0.00931
epoch: 14 step: 1150 Loss: 0.00677
epoch: 14 step: 1200 Loss: 0.00789
epoch: 14 step: 1250 Loss: 0.00625
epoch: 14 step: 1300 Loss: 0.00945
epoch: 14 step: 1350 Loss: 0.00761
epoch: 14 step: 1400 Loss: 0.00584
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 15 step: 0 Loss: 0.00569
epoch: 15 step: 50 Loss: 0.00945
epoch: 15 step: 100 Loss: 0.00668
epoch: 15 step: 150 Loss: 0.00949
epoch: 15 step: 200 Loss: 0.00727
epoch: 15 step: 250 Loss: 0.00809
epoch: 15 step: 300 Loss: 0.00901
epoch: 15 step: 350 Loss: 0.00799
epoch: 15 step: 400 Loss: 0.01004
epoch: 15 step: 450 Loss: 0.00874
epoch: 15 step: 500 Loss: 0.00814
epoch: 15 step: 550 Loss: 0.01077
epoch: 15 step: 600 Loss: 0.00515
epoch: 15 step: 650 Loss: 0.00656
epoch: 15 step: 700 Loss: 0.00920
epoch: 15 step: 750 Loss: 0.00840
epoch: 15 step: 800 Loss: 0.00932
epoch: 15 step: 850 Loss: 0.00880
epoch: 15 step: 900 Loss: 0.00655
epoch: 15 step: 950 Loss: 0.00649
epoch: 15 step: 1000 Loss: 0.00835
epoch: 15 step: 1050 Loss: 0.01108
epoch: 15 step: 1100 Loss: 0.00826
epoch: 15 step: 1150 Loss: 0.00996
epoch: 15 step: 1200 Loss: 0.00839
epoch: 15 step: 1250 Loss: 0.00826
epoch: 15 step: 1300 Loss: 0.00566
epoch: 15 step: 1350 Loss: 0.00887
epoch: 15 step: 1400 Loss: 0.00773
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 16 step: 0 Loss: 0.01006
epoch: 16 step: 50 Loss: 0.00909
epoch: 16 step: 100 Loss: 0.00840
epoch: 16 step: 150 Loss: 0.00783
epoch: 16 step: 200 Loss: 0.00770
epoch: 16 step: 250 Loss: 0.00829
epoch: 16 step: 300 Loss: 0.00903
epoch: 16 step: 350 Loss: 0.01018
epoch: 16 step: 400 Loss: 0.00870
epoch: 16 step: 450 Loss: 0.00974
epoch: 16 step: 500 Loss: 0.01140
epoch: 16 step: 550 Loss: 0.00989
epoch: 16 step: 600 Loss: 0.00652
epoch: 16 step: 650 Loss: 0.00854
epoch: 16 step: 700 Loss: 0.01234
epoch: 16 step: 750 Loss: 0.00706
epoch: 16 step: 800 Loss: 0.00824
epoch: 16 step: 850 Loss: 0.00812
epoch: 16 step: 900 Loss: 0.00777
epoch: 16 step: 950 Loss: 0.00732
epoch: 16 step: 1000 Loss: 0.00578
epoch: 16 step: 1050 Loss: 0.00525
epoch: 16 step: 1100 Loss: 0.01209
epoch: 16 step: 1150 Loss: 0.00630
epoch: 16 step: 1200 Loss: 0.00873
epoch: 16 step: 1250 Loss: 0.01041
epoch: 16 step: 1300 Loss: 0.00983
epoch: 16 step: 1350 Loss: 0.01057
epoch: 16 step: 1400 Loss: 0.00935
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 17 step: 0 Loss: 0.00602
epoch: 17 step: 50 Loss: 0.00908
epoch: 17 step: 100 Loss: 0.00862
epoch: 17 step: 150 Loss: 0.00658
epoch: 17 step: 200 Loss: 0.00675
epoch: 17 step: 250 Loss: 0.00617
epoch: 17 step: 300 Loss: 0.00656
epoch: 17 step: 350 Loss: 0.00865
epoch: 17 step: 400 Loss: 0.00457
epoch: 17 step: 450 Loss: 0.00910
epoch: 17 step: 500 Loss: 0.00751
epoch: 17 step: 550 Loss: 0.00836
epoch: 17 step: 600 Loss: 0.00670
epoch: 17 step: 650 Loss: 0.00730
epoch: 17 step: 700 Loss: 0.00542
epoch: 17 step: 750 Loss: 0.00879
epoch: 17 step: 800 Loss: 0.00568
epoch: 17 step: 850 Loss: 0.00810
epoch: 17 step: 900 Loss: 0.00640
epoch: 17 step: 950 Loss: 0.00975
epoch: 17 step: 1000 Loss: 0.00725
epoch: 17 step: 1050 Loss: 0.00634
epoch: 17 step: 1100 Loss: 0.00710
epoch: 17 step: 1150 Loss: 0.00693
epoch: 17 step: 1200 Loss: 0.00874
epoch: 17 step: 1250 Loss: 0.01031
epoch: 17 step: 1300 Loss: 0.00666
epoch: 17 step: 1350 Loss: 0.00695
epoch: 17 step: 1400 Loss: 0.00831
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 18 step: 0 Loss: 0.00903
epoch: 18 step: 50 Loss: 0.00870
epoch: 18 step: 100 Loss: 0.00785
epoch: 18 step: 150 Loss: 0.00883
epoch: 18 step: 200 Loss: 0.00800
epoch: 18 step: 250 Loss: 0.00520
epoch: 18 step: 300 Loss: 0.00543
epoch: 18 step: 350 Loss: 0.01014
epoch: 18 step: 400 Loss: 0.00908
epoch: 18 step: 450 Loss: 0.00456
epoch: 18 step: 500 Loss: 0.00968
epoch: 18 step: 550 Loss: 0.00688
epoch: 18 step: 600 Loss: 0.00670
epoch: 18 step: 650 Loss: 0.00922
epoch: 18 step: 700 Loss: 0.00770
epoch: 18 step: 750 Loss: 0.00959
epoch: 18 step: 800 Loss: 0.00870
epoch: 18 step: 850 Loss: 0.00630
epoch: 18 step: 900 Loss: 0.00981
epoch: 18 step: 950 Loss: 0.00964
epoch: 18 step: 1000 Loss: 0.00691
epoch: 18 step: 1050 Loss: 0.00638
epoch: 18 step: 1100 Loss: 0.00411
epoch: 18 step: 1150 Loss: 0.00550
epoch: 18 step: 1200 Loss: 0.00819
epoch: 18 step: 1250 Loss: 0.00746
epoch: 18 step: 1300 Loss: 0.00824
epoch: 18 step: 1350 Loss: 0.00848
epoch: 18 step: 1400 Loss: 0.00809
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

epoch: 19 step: 0 Loss: 0.00815
epoch: 19 step: 50 Loss: 0.00800
epoch: 19 step: 100 Loss: 0.00895
epoch: 19 step: 150 Loss: 0.00966
epoch: 19 step: 200 Loss: 0.00643
epoch: 19 step: 250 Loss: 0.00801
epoch: 19 step: 300 Loss: 0.00682
epoch: 19 step: 350 Loss: 0.00536
epoch: 19 step: 400 Loss: 0.00859
epoch: 19 step: 450 Loss: 0.00748
epoch: 19 step: 500 Loss: 0.00919
epoch: 19 step: 550 Loss: 0.00945
epoch: 19 step: 600 Loss: 0.00877
epoch: 19 step: 650 Loss: 0.00612
epoch: 19 step: 700 Loss: 0.00983
epoch: 19 step: 750 Loss: 0.00772
epoch: 19 step: 800 Loss: 0.00714
epoch: 19 step: 850 Loss: 0.01009
epoch: 19 step: 900 Loss: 0.00598
epoch: 19 step: 950 Loss: 0.00735
epoch: 19 step: 1000 Loss: 0.00592
epoch: 19 step: 1050 Loss: 0.00825
epoch: 19 step: 1100 Loss: 0.00887
epoch: 19 step: 1150 Loss: 0.00645
epoch: 19 step: 1200 Loss: 0.00812
epoch: 19 step: 1250 Loss: 0.00751
epoch: 19 step: 1300 Loss: 0.00807
epoch: 19 step: 1350 Loss: 0.00788
epoch: 19 step: 1400 Loss: 0.00764
epoch: 1

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

In [12]:
# Persist the trained network. The target directory is created first so the
# save cannot fail on a fresh checkout after a long training run.
# Note: this pickles the whole module object (not just its state_dict), so
# loading it later requires the model's class definitions to be importable.
Path('models').mkdir(parents=True, exist_ok=True)
torch.save(model, 'models/ddpm_celeba.pth')

In [9]:
from PIL import Image
import requests
from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize

# Inverse of the training-time normalisation: turns a single CHW float tensor
# (nominally in [-1, 1]) back into a PIL image.
reverse_transform = Compose([
    Lambda(lambda t: (t + 1) / 2),  # [-1, 1] -> [0, 1]
    # Noised tensors can fall outside [-1, 1]; clamp so the uint8 cast below
    # saturates instead of wrapping around.
    Lambda(lambda t: t.clamp(0, 1)),
    Lambda(lambda t: t.permute(1, 2, 0)),  # CHW to HWC
    Lambda(lambda t: t * 255.),
    # detach + cpu so tensors living on the GPU (or tracking gradients) can be
    # converted to numpy as well.
    Lambda(lambda t: t.detach().cpu().numpy().astype(np.uint8)),
    ToPILImage(),
])
def get_noisy_image(x_start, t):
    """Return ``x_start`` noised via ``q_sample`` at timestep ``t`` as a PIL image.

    The noised tensor has all size-1 dimensions squeezed out before it is
    passed through ``reverse_transform``.
    """
    return reverse_transform(q_sample(x_start, t=t).squeeze())

In [17]:
# Take only the first batch from the loader: save it as a grid, then noise it
# with q_sample at timestep index 299 and save the noisy grid as well.
# Both grids are rescaled with (v + 1) * 0.5 before saving.
for x, _ in train_loader:
    save_image(((x + 1) * 0.5).cpu(), results_folder / 'raw.png', nrow=10)
    noise_step = torch.tensor([300-1])
    x_noisy = q_sample(x, t=noise_step).to(device)
    save_image(((x_noisy + 1) * 0.5).cpu(), results_folder / 'noisy.png', nrow=10)
    break

In [19]:
# Run the reverse process on the noised batch, stepping i = 299 ... 0, and save
# the result as a grid.
img = x_noisy
with torch.no_grad():  # inference only; no autograd graph is needed
    for i in reversed(range(300)):
        # One timestep index per image; sized from the batch itself rather than
        # a hard-coded 100 so any batch size works.
        step_idx = torch.full(
            (img.shape[0],), i, device=device, dtype=torch.long)
        img = p_sample(model, img, step_idx, i)
save_image(img.add(1).mul(0.5), results_folder / 'reconstruction.png', nrow=10)

In [20]:
# Generate a batch of 100 samples, keep the last entry of the list returned by
# `sample`, rescale it with (v + 1) * 0.5 and save it as a 10-per-row grid.
all_images_list = sample(model, image_size=image_size, batch_size=100, channels=channels)
all_images = all_images_list[-1].add(1).mul(0.5)
# Plain Path argument (as in the other save_image calls); no f-string is
# needed because the file name has no placeholders.
save_image(all_images, results_folder / 'sample_random.png', nrow=10)

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

In [None]:
# TODO test interpolation

In [29]:
@torch.no_grad()
def p_sample_from(x_start, t, n):
    """Denoise ``n`` copies of ``x_start`` through ``t`` reverse steps.

    ``x_start`` is tiled ``n`` times along dim 0, saved as
    ``sample_start.png``, passed through ``p_sample`` for i = t-1 ... 0 using
    the global ``model``, and the final batch is saved as ``sample_end.png``.
    Both files are written to ``results_folder``.

    Args:
        x_start: starting tensor; assumed to be 4-D with a leading batch
            dimension of size 1 -- TODO confirm against callers.
        t: number of reverse steps to run.
        n: number of copies of ``x_start`` to denoise.
    """
    img = x_start.repeat(n, 1, 1, 1)
    # Roughly square grid; a plain Python int rather than a 0-d tensor.
    grid_cols = int(math.sqrt(n))
    save_image(img.add(1).mul(0.5).cpu(), results_folder / 'sample_start.png', nrow=grid_cols)
    for i in tqdm(reversed(range(t)), total=t):
        # One timestep index per copy in the batch.
        step_idx = torch.full((n,), i, device=device, dtype=torch.long)
        img = p_sample(model, img, step_idx, i)
    save_image(img.add(1).mul(0.5).cpu(), results_folder / 'sample_end.png', nrow=grid_cols)

In [39]:
t = 400
n = 16
# Walk the dataset until the first sample whose target entry 15 equals 1
# (presumably a CelebA attribute flag -- verify which attribute index 15 is),
# save it, noise it at timestep index t - 1 and denoise n copies of it.
for x, y in train_dataset:
    if y[15] != 1:
        continue
    save_image(((x + 1) * 0.5).cpu(), results_folder / 'sample_real.png')
    start_step = torch.tensor([t-1])
    x_start = q_sample(x, start_step).to(device)
    p_sample_from(x_start, t, n)
    break

  0%|          | 0/400 [00:00<?, ?it/s]

In [28]:
# Look at the first dataset sample only: prepend a batch axis and print the shape.
for x, y in train_dataset:
    x = x[None]  # same as unsqueeze(0)
    print(x.shape)
    break

torch.Size([1, 3, 64, 64])
