In [None]:
import os
import random
import time

import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision

In [None]:
# Optimisation hyperparameters used by the training loop.
NUM_EPOCHS = 500
LEARNING_RATE = 1e-4
BATCH_SIZE = 64

# Diffusion process: number of noising steps and the two endpoints of the
# linearly spaced beta schedule.
NUM_TIMESTEPS = 1000
BETA_START = 1e-4
BETA_END = 0.02

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Preprocessing applied to every dataset image: convert it to a tensor, then
# map values linearly via x -> 2x - 1 (so inputs in [0, 1] land in [-1, 1];
# the samplers undo this with (img + 1) / 2).
dataset_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(lambda x: 2*x - 1),
    ]
)

In [None]:
def get_sum_of_random_numbers(samples, num_of_samples):
    """Sum ``num_of_samples`` values drawn at random from ``samples``.

    Every draw is an independent ``random.choice(samples)``, so the same
    element may be picked more than once (sampling with replacement).

    Args:
        samples: Non-empty sequence of numbers to draw from.
        num_of_samples: How many draws to add together.

    Returns:
        The ``np.sum`` of the drawn values.
    """
    draws = [random.choice(samples) for _ in range(num_of_samples)]
    return np.sum(draws)

In [None]:
class LinearNoiseScheduler():
    """Linear-beta noise schedule with the forward (noising) and reverse (denoising) steps.

    The constructor precomputes, for every timestep ``0 .. num_timesteps - 1``:
    ``betas`` (linearly spaced between ``beta_start`` and ``beta_end``),
    ``alphas = 1 - betas``, their cumulative product ``alpha_cum_prod`` and the
    two square roots used by the closed-form noising equation.
    """

    def __init__(self, num_timesteps, beta_start, beta_end, device):
        """Precompute the schedule tensors and place them on ``device``."""
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.betas = torch.linspace(beta_start, beta_end, num_timesteps).to(device)
        self.alphas = 1. - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1. - self.alpha_cum_prod)

    def add_noise(self, original, noise, t):
        """Noise a batch to timestep ``t``.

        Computes ``sqrt(alpha_cum_prod[t]) * original + sqrt(1 - alpha_cum_prod[t]) * noise``
        with one coefficient per batch element.

        Args:
            original: Batch of clean samples, shape ``(B, ...)`` with any number
                of trailing dimensions.
            noise: Noise tensor broadcastable against ``original``.
            t: Timestep indices into the schedule holding ``B`` entries in total.

        Returns:
            The noised batch, same shape as ``original``.
        """
        batch_size = original.shape[0]
        # One coefficient per sample, followed by a singleton axis for every
        # remaining dimension of `original`, so the product broadcasts for any
        # rank instead of only for 4-D image batches.
        coeff_shape = (batch_size,) + (1,) * (original.dim() - 1)

        sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod[t].reshape(coeff_shape)
        sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod[t].reshape(coeff_shape)

        return sqrt_alpha_cum_prod*original + sqrt_one_minus_alpha_cum_prod*noise

    def sample_prev_timestep(self, xt, noise_pred, t):
        """Take one reverse-diffusion step from timestep ``t``.

        Args:
            xt: Current noisy sample.
            noise_pred: Noise predicted for ``xt`` at timestep ``t``.
            t: A single timestep (scalar); it is used both as an index and in
                the ``t == 0`` test, so a multi-element tensor is not supported.

        Returns:
            Tuple ``(x_prev, x0)``. ``x0`` is the clean-sample estimate clamped
            to ``[-1, 1]``. ``x_prev`` is the posterior mean, with Gaussian noise
            of the posterior standard deviation added for every ``t`` except 0.
        """
        # Estimate of the clean sample implied by the predicted noise.
        x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod[t] * noise_pred)) / torch.sqrt(self.alpha_cum_prod[t]))
        x0 = torch.clamp(x0, -1, 1)

        # Posterior mean: (xt - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t).
        mean = xt - ((self.betas[t]*noise_pred)/(self.sqrt_one_minus_alpha_cum_prod[t]))
        mean = mean/torch.sqrt(self.alphas[t])

        if t==0:
            # Final step: return the mean without injecting any noise.
            return mean, x0
        else:
            # Posterior variance: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).
            variance = (1 - self.alpha_cum_prod[t-1])/(1. - self.alpha_cum_prod[t])
            variance = variance*self.betas[t]
            sigma = variance**0.5
            z = torch.randn(xt.shape).to(xt.device)

            return mean + sigma*z, x0

In [None]:
def get_time_embedding(time_steps, t_emb_dim):
    """Build sinusoidal embeddings for a batch of timesteps.

    Args:
        time_steps: 1-D tensor of timesteps (it is indexed as ``time_steps[:, None]``).
        t_emb_dim: Requested embedding width. Half of it is used for sines and
            half for cosines, so the result has ``2 * (t_emb_dim // 2)`` columns.

    Returns:
        Tensor of shape ``(len(time_steps), 2 * (t_emb_dim // 2))`` holding the
        sine terms followed by the cosine terms.
    """
    half_dim = t_emb_dim // 2

    # Divisors grow geometrically: 1000 ** (k / half_dim) for k = 0 .. half_dim - 1.
    exponents = torch.arange(start=0, end=half_dim, device=time_steps.device) / half_dim
    factor = 1000 ** exponents

    angles = time_steps[:, None].repeat(1, half_dim) / factor
    return torch.cat([angles.sin(), angles.cos()], dim=-1)

In [None]:
class DownBlock(torch.nn.Module):
    """Encoder stage: ``num_layers`` x (time-conditioned ResNet layer + self-attention), then an optional downsample.

    Each layer runs GroupNorm/SiLU/Conv, adds the projected time embedding,
    runs a second GroupNorm/SiLU/Conv, adds a 1x1-conv shortcut of the layer
    input, and finally adds multi-head self-attention computed over the
    flattened spatial positions. When ``down_sample`` is true the result goes
    through a kernel-4, stride-2, padding-1 convolution.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, down_sample=True, num_heads=4, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample

        # Only the first layer consumes `in_channels`; later layers are out -> out.
        layer_inputs = [in_channels if idx == 0 else out_channels for idx in range(num_layers)]

        self.resnet_conv_first = torch.nn.ModuleList()
        for channels_in in layer_inputs:
            self.resnet_conv_first.append(
                torch.nn.Sequential(
                    torch.nn.GroupNorm(8, channels_in),
                    torch.nn.SiLU(),
                    torch.nn.Conv2d(channels_in, out_channels, kernel_size=3, stride=1, padding=1),
                )
            )

        self.t_emb_layers = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.t_emb_layers.append(
                torch.nn.Sequential(
                    torch.nn.SiLU(),
                    torch.nn.Linear(t_emb_dim, out_channels),
                )
            )

        self.resnet_conv_second = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.resnet_conv_second.append(
                torch.nn.Sequential(
                    torch.nn.GroupNorm(8, out_channels),
                    torch.nn.SiLU(),
                    torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                )
            )

        self.attention_norms = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.attention_norms.append(torch.nn.GroupNorm(8, out_channels))

        self.attentions = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.attentions.append(torch.nn.MultiheadAttention(out_channels, num_heads, batch_first=True))

        # 1x1 convolutions that bring the layer input to `out_channels` for the shortcut.
        self.residual_input_conv = torch.nn.ModuleList()
        for channels_in in layer_inputs:
            self.residual_input_conv.append(torch.nn.Conv2d(channels_in, out_channels, kernel_size=1))

        if self.down_sample:
            self.down_sample_conv = torch.nn.Conv2d(out_channels, out_channels, 4, 2, 1)
        else:
            self.down_sample_conv = torch.nn.Identity()

    def _resnet_layer(self, idx, features, t_emb):
        """Apply the ``idx``-th time-conditioned ResNet layer, shortcut included."""
        hidden = self.resnet_conv_first[idx](features)
        # Broadcast the per-sample time projection over both spatial axes.
        hidden = hidden + self.t_emb_layers[idx](t_emb)[:, :, None, None]
        hidden = self.resnet_conv_second[idx](hidden)
        return hidden + self.residual_input_conv[idx](features)

    def _attention_layer(self, idx, features):
        """Add the ``idx``-th self-attention over spatial positions to ``features``."""
        batch_size, channels, height, width = features.shape
        # (B, C, H*W) for the norm, then (B, H*W, C) tokens for batch-first attention.
        tokens = self.attention_norms[idx](features.reshape(batch_size, channels, height * width))
        tokens = tokens.transpose(1, 2)
        attended, _ = self.attentions[idx](tokens, tokens, tokens)
        attended = attended.transpose(1, 2).reshape(batch_size, channels, height, width)
        return features + attended

    def forward(self, x, t_emb):
        """Run every ResNet + attention layer on ``x`` and then the downsample step.

        Args:
            x: Feature map with ``in_channels`` channels.
            t_emb: Time embedding of width ``t_emb_dim``, one row per sample.
        """
        out = x
        for idx in range(self.num_layers):
            out = self._resnet_layer(idx, out, t_emb)
            out = self._attention_layer(idx, out)

        return self.down_sample_conv(out)


class MidBlock(torch.nn.Module):
    """Bottleneck stage: one ResNet layer followed by ``num_layers`` x (self-attention + ResNet layer).

    It therefore owns ``num_layers + 1`` ResNet layers but only ``num_layers``
    attention layers. The spatial size is left unchanged.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads=4, num_layers=1):
        super().__init__()
        self.num_layers = num_layers

        num_resnets = num_layers + 1
        # Only the first ResNet layer consumes `in_channels`; the rest are out -> out.
        resnet_inputs = [in_channels if idx == 0 else out_channels for idx in range(num_resnets)]

        self.resnet_conv_first = torch.nn.ModuleList()
        for channels_in in resnet_inputs:
            self.resnet_conv_first.append(
                torch.nn.Sequential(
                    torch.nn.GroupNorm(8, channels_in),
                    torch.nn.SiLU(),
                    torch.nn.Conv2d(channels_in, out_channels, kernel_size=3, stride=1, padding=1),
                )
            )

        self.t_emb_layers = torch.nn.ModuleList()
        for _ in range(num_resnets):
            self.t_emb_layers.append(
                torch.nn.Sequential(
                    torch.nn.SiLU(),
                    torch.nn.Linear(t_emb_dim, out_channels),
                )
            )

        self.resnet_conv_second = torch.nn.ModuleList()
        for _ in range(num_resnets):
            self.resnet_conv_second.append(
                torch.nn.Sequential(
                    torch.nn.GroupNorm(8, out_channels),
                    torch.nn.SiLU(),
                    torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                )
            )

        self.attention_norms = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.attention_norms.append(torch.nn.GroupNorm(8, out_channels))

        self.attentions = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.attentions.append(torch.nn.MultiheadAttention(out_channels, num_heads, batch_first=True))

        # 1x1 convolutions that bring the layer input to `out_channels` for the shortcut.
        self.residual_input_conv = torch.nn.ModuleList()
        for channels_in in resnet_inputs:
            self.residual_input_conv.append(torch.nn.Conv2d(channels_in, out_channels, kernel_size=1))

    def _resnet_layer(self, idx, features, t_emb):
        """Apply the ``idx``-th time-conditioned ResNet layer, shortcut included."""
        hidden = self.resnet_conv_first[idx](features)
        # Broadcast the per-sample time projection over both spatial axes.
        hidden = hidden + self.t_emb_layers[idx](t_emb)[:, :, None, None]
        hidden = self.resnet_conv_second[idx](hidden)
        return hidden + self.residual_input_conv[idx](features)

    def _attention_layer(self, idx, features):
        """Add the ``idx``-th self-attention over spatial positions to ``features``."""
        batch_size, channels, height, width = features.shape
        # (B, C, H*W) for the norm, then (B, H*W, C) tokens for batch-first attention.
        tokens = self.attention_norms[idx](features.reshape(batch_size, channels, height * width))
        tokens = tokens.transpose(1, 2)
        attended, _ = self.attentions[idx](tokens, tokens, tokens)
        attended = attended.transpose(1, 2).reshape(batch_size, channels, height, width)
        return features + attended

    def forward(self, x, t_emb):
        """Run ResNet, then alternate attention and ResNet ``num_layers`` times.

        Args:
            x: Feature map with ``in_channels`` channels.
            t_emb: Time embedding of width ``t_emb_dim``, one row per sample.
        """
        out = self._resnet_layer(0, x, t_emb)

        for idx in range(self.num_layers):
            out = self._attention_layer(idx, out)
            # ResNet layer 0 was consumed above, so attention `idx` pairs with ResNet `idx + 1`.
            out = self._resnet_layer(idx + 1, out, t_emb)

        return out


class UpBlock(torch.nn.Module):
    """Decoder stage: optional upsample, concatenation with a skip feature map, then ``num_layers`` x (ResNet layer + self-attention).

    The incoming features are expected to carry ``in_channels // 2`` channels;
    after concatenating the skip connection along the channel axis the first
    ResNet layer receives ``in_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample=True, num_heads=4, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample

        # Only the first layer consumes `in_channels`; later layers are out -> out.
        layer_inputs = [in_channels if idx == 0 else out_channels for idx in range(num_layers)]

        self.resnet_conv_first = torch.nn.ModuleList()
        for channels_in in layer_inputs:
            self.resnet_conv_first.append(
                torch.nn.Sequential(
                    torch.nn.GroupNorm(8, channels_in),
                    torch.nn.SiLU(),
                    torch.nn.Conv2d(channels_in, out_channels, kernel_size=3, stride=1, padding=1),
                )
            )

        self.t_emb_layers = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.t_emb_layers.append(
                torch.nn.Sequential(
                    torch.nn.SiLU(),
                    torch.nn.Linear(t_emb_dim, out_channels),
                )
            )

        self.resnet_conv_second = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.resnet_conv_second.append(
                torch.nn.Sequential(
                    torch.nn.GroupNorm(8, out_channels),
                    torch.nn.SiLU(),
                    torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                )
            )

        self.attention_norms = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.attention_norms.append(torch.nn.GroupNorm(8, out_channels))

        self.attentions = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.attentions.append(torch.nn.MultiheadAttention(out_channels, num_heads, batch_first=True))

        # 1x1 convolutions that bring the layer input to `out_channels` for the shortcut.
        self.residual_input_conv = torch.nn.ModuleList()
        for channels_in in layer_inputs:
            self.residual_input_conv.append(torch.nn.Conv2d(channels_in, out_channels, kernel_size=1))

        # The upsample acts before the skip concat, hence on half of `in_channels`.
        if self.up_sample:
            self.up_sample_conv = torch.nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 4, 2, 1)
        else:
            self.up_sample_conv = torch.nn.Identity()

    def _resnet_layer(self, idx, features, t_emb):
        """Apply the ``idx``-th time-conditioned ResNet layer, shortcut included."""
        hidden = self.resnet_conv_first[idx](features)
        # Broadcast the per-sample time projection over both spatial axes.
        hidden = hidden + self.t_emb_layers[idx](t_emb)[:, :, None, None]
        hidden = self.resnet_conv_second[idx](hidden)
        return hidden + self.residual_input_conv[idx](features)

    def _attention_layer(self, idx, features):
        """Add the ``idx``-th self-attention over spatial positions to ``features``."""
        batch_size, channels, height, width = features.shape
        # (B, C, H*W) for the norm, then (B, H*W, C) tokens for batch-first attention.
        tokens = self.attention_norms[idx](features.reshape(batch_size, channels, height * width))
        tokens = tokens.transpose(1, 2)
        attended, _ = self.attentions[idx](tokens, tokens, tokens)
        attended = attended.transpose(1, 2).reshape(batch_size, channels, height, width)
        return features + attended

    def forward(self, x, out_down, t_emb):
        """Upsample ``x``, concatenate ``out_down`` on the channel axis, then run the layers.

        Args:
            x: Features coming from the previous (deeper) stage.
            out_down: Skip-connection features; must match the upsampled ``x``
                in every dimension except channels.
            t_emb: Time embedding of width ``t_emb_dim``, one row per sample.
        """
        upsampled = self.up_sample_conv(x)
        out = torch.cat([upsampled, out_down], dim=1)

        for idx in range(self.num_layers):
            out = self._resnet_layer(idx, out, t_emb)
            out = self._attention_layer(idx, out)

        return out



In [None]:
class UNet(torch.nn.Module):
    """Time-conditioned U-Net that maps a noisy image and its timestep to a noise prediction.

    Layout: an input convolution, three ``DownBlock`` stages, two ``MidBlock``
    stages, three ``UpBlock`` stages fed with the saved encoder inputs as skip
    connections, and a GroupNorm/SiLU/Conv output head with ``im_channels``
    output channels.
    """

    def __init__(self, im_channels):
        super(UNet, self).__init__()

        # Architecture constants.
        self.down_channels = [32, 64, 128, 256]
        self.mid_channels = [256, 256, 128]
        self.t_emb_dim = 128
        self.down_sample = [True, True, False]
        self.num_layers = 2

        # Small MLP applied to the sinusoidal timestep embedding.
        self.t_proj = torch.nn.Sequential(
            torch.nn.Linear(self.t_emb_dim, self.t_emb_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(self.t_emb_dim, self.t_emb_dim)
        )

        self.up_sample = list(reversed(self.down_sample))
        self.conv_in = torch.nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1)

        num_stages = len(self.down_channels) - 1

        self.downs = torch.nn.ModuleList()
        for stage in range(num_stages):
            self.downs.append(
                DownBlock(
                    self.down_channels[stage],
                    self.down_channels[stage + 1],
                    self.t_emb_dim,
                    down_sample=self.down_sample[stage],
                    num_heads=4,
                    num_layers=self.num_layers,
                )
            )

        self.mids = torch.nn.ModuleList()
        for stage in range(len(self.mid_channels) - 1):
            self.mids.append(
                MidBlock(
                    self.mid_channels[stage],
                    self.mid_channels[stage + 1],
                    self.t_emb_dim,
                    num_heads=4,
                    num_layers=self.num_layers,
                )
            )

        # Decoder stages mirror the encoder from the deepest stage outwards.
        # Each takes twice the skip width (features + skip after the concat),
        # and the outermost stage narrows to 16 channels for the output head.
        self.ups = torch.nn.ModuleList()
        for stage in reversed(range(num_stages)):
            stage_out = 16 if stage == 0 else self.down_channels[stage - 1]
            self.ups.append(
                UpBlock(
                    self.down_channels[stage] * 2,
                    stage_out,
                    self.t_emb_dim,
                    up_sample=self.down_sample[stage],
                    num_heads=4,
                    num_layers=self.num_layers,
                )
            )

        self.norm_out = torch.nn.GroupNorm(8, 16)
        self.conv_out = torch.nn.Conv2d(16, im_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        """Predict the noise in ``x`` at timesteps ``t``.

        Args:
            x: Batch of (noisy) images with ``im_channels`` channels.
            t: 1-D tensor of timesteps passed to ``get_time_embedding``.

        Returns:
            Tensor with ``im_channels`` channels produced by the output head.
        """
        features = self.conv_in(x)
        t_emb = self.t_proj(get_time_embedding(t, self.t_emb_dim))

        # Save the input of every encoder stage; the decoder consumes them last-in, first-out.
        skips = []
        for down in self.downs:
            skips.append(features)
            features = down(features, t_emb)

        for mid in self.mids:
            features = mid(features, t_emb)

        for up in self.ups:
            features = up(features, skips.pop(), t_emb)

        features = torch.nn.SiLU()(self.norm_out(features))
        return self.conv_out(features)


In [None]:
# MNIST train and test splits, both preprocessed with `dataset_transforms`
# and served in shuffled mini-batches of BATCH_SIZE.
train_dataset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=dataset_transforms
)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

test_dataset = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=dataset_transforms
)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Noise schedule shared by training (add_noise) and ancestral sampling.
scheduler = LinearNoiseScheduler(
    num_timesteps=NUM_TIMESTEPS,
    beta_start=BETA_START,
    beta_end=BETA_END,
    device=device,
)

# Single-channel U-Net on the selected device, trained with Adam on an MSE loss.
model = UNet(im_channels=1).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = torch.nn.MSELoss()


In [None]:
# Train the U-Net to predict the noise that was mixed into each image.
for epoch_idx in range(NUM_EPOCHS):
    losses = []
    for im, _ in train_dataloader:
        im = im.to(device)

        # Draw the regression target (noise) and one random timestep per image,
        # then noise the batch to those timesteps.
        noise = torch.randn_like(im)
        t = torch.randint(0, NUM_TIMESTEPS, (im.shape[0],)).to(device)
        noisy_im = scheduler.add_noise(im, noise, t)

        # One optimisation step on the MSE between predicted and true noise.
        optimizer.zero_grad()
        noise_pred = model(noisy_im, t)
        loss = criterion(noise_pred, noise)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    # Report the mean batch loss of the epoch.
    print(f"Epoch: {epoch_idx+1}, Train Loss: {np.mean(losses):.5f}")

        

In [None]:
class DiffusionModelSampler():
    """Image samplers (ancestral DDPM, DDIM, DPM-Solver-2 style, PNDM) around a noise-prediction model.

    Timestep convention used by the multi-step samplers: sampler step ``t`` in
    ``1 .. T`` has signal level ``alpha_bar[t - 1]`` and is evaluated by the
    network at timestep ``t - 1``; step ``0`` denotes the clean image
    (``alpha_bar = 1``).

    Every public sampler returns a CPU tensor of shape
    ``(n_samples, image_channels, H, W)`` clamped to ``[-1, 1]`` and rescaled to
    ``[0, 1]``.
    """

    def __init__(self, T, model, beta_start, beta_end, device):
        """Store the model and precompute the linear-beta schedule tensors on ``device``.

        Args:
            T: Number of diffusion timesteps.
            model: Callable ``model(x, t)`` returning the predicted noise for ``x``.
            beta_start: First value of the linearly spaced beta schedule.
            beta_end: Last value of the linearly spaced beta schedule.
            device: Device on which the schedule tensors and samples are created.
        """
        self.T = T
        self.model = model
        self.device = device

        self.beta = torch.linspace(beta_start, beta_end, T).to(device)
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

        self.sqrt_alpha_bar = torch.sqrt(self.alpha_bar)
        self.sigmas = torch.sqrt(1.0 - self.alpha_bar)
        # Log signal-to-noise ratio log(sqrt(alpha_bar) / sigma), used by the DPM solver.
        self.lambdas = torch.log(self.sqrt_alpha_bar/self.sigmas)

    @torch.no_grad()
    def simple_sampling(self, scheduler, n_samples=1, image_channels=1, img_size=(28, 28)):
        """Ancestral sampling: one ``scheduler.sample_prev_timestep`` call per timestep, ``T - 1`` down to 0.

        Args:
            scheduler: Object exposing ``sample_prev_timestep(xt, noise_pred, t)``
                that returns ``(x_prev, x0_estimate)``.
            n_samples: Number of images to generate.
            image_channels: Channels per image.
            img_size: ``(height, width)`` of the generated images.
        """
        xt = torch.randn((n_samples, image_channels, img_size[0], img_size[1]), device=self.device)

        for t in reversed(range(self.T)):
            # Use the sampler's own device rather than a notebook-level global.
            t_scalar = torch.as_tensor(t).to(self.device)
            noise_pred = self.model(xt, t_scalar.unsqueeze(0))

            xt, _ = scheduler.sample_prev_timestep(xt, noise_pred, t_scalar)

        img = torch.clamp(xt, -1, 1).detach().cpu()
        img = (img+1)/2

        return img

    @torch.no_grad()
    def ddim_sampling(self, n_samples=1, image_channels=1, img_size=(32, 32), n_steps=50):
        """Deterministic DDIM sampling over ``n_steps`` evenly strided timesteps.

        NOTE(review): the default ``img_size`` is (32, 32) here while
        ``simple_sampling`` and ``pndm_sampling`` default to (28, 28); confirm
        which size is intended.
        """
        step_size = self.T//n_steps
        xt = torch.randn((n_samples, image_channels, img_size[0], img_size[1]), device=self.device)

        for tau in range(n_steps):
            t = self.T - tau*step_size # current sampler step (1-indexed)
            t_tensor = torch.ones(n_samples, dtype=torch.long, device=self.device)*t

            alpha_bar_t = self.alpha_bar[t-1].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            # The last stride lands on the clean image, whose alpha_bar is 1.
            alpha_bar_prev = self.alpha_bar[t-step_size-1].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) if t>step_size else torch.tensor(1.0).to(self.device)

            noise_pred = self.model(xt, t_tensor-1)
            x0_pred = (xt - torch.sqrt(1 - alpha_bar_t)*noise_pred)/torch.sqrt(alpha_bar_t)
            dir_xt = torch.sqrt(1 - alpha_bar_prev)*noise_pred

            xt = torch.sqrt(alpha_bar_prev)*x0_pred + dir_xt

        img = torch.clamp(xt, -1, 1).detach().cpu()
        img = (img+1)/2

        return img

    @torch.no_grad()
    def dpm_solver_sampling(self, n_samples=1, image_channels=1, img_size=(32, 32), n_steps=10):
        """Two-evaluation-per-step solver in log-SNR space (midpoint step, then full step).

        NOTE(review): the default ``img_size`` is (32, 32) here while
        ``simple_sampling`` and ``pndm_sampling`` default to (28, 28); confirm
        which size is intended.
        """
        step_size = self.T//n_steps

        x_tilde = torch.randn((n_samples, image_channels, img_size[0], img_size[1]), device=self.device)

        for tau in range(n_steps):
            t_prev = self.T - tau*step_size
            # Never step below sampler step 1 (index 0 of the schedule).
            t_cur = max(t_prev-step_size, 1)

            # Intermediate step s_i: the schedule entry whose log-SNR is closest
            # to the midpoint between the two endpoints.
            lam_mid = (self.lambdas[t_prev - 1] + self.lambdas[t_cur - 1])/2.
            s_i = torch.argmin(torch.abs(self.lambdas - lam_mid)).item() + 1

            h = self.lambdas[t_cur - 1] - self.lambdas[t_prev - 1]

            t_prev_tensor = torch.full((n_samples,), t_prev, dtype=torch.long, device=self.device)

            u_i = (self.sqrt_alpha_bar[s_i - 1] / self.sqrt_alpha_bar[t_prev - 1])*x_tilde - self.sigmas[s_i - 1]*(torch.exp(h*0.5) - 1)*self.model(x_tilde, t_prev_tensor - 1)

            t_s_tensor = torch.full((n_samples,), s_i, dtype=torch.long, device=self.device)

            x_tilde = (self.sqrt_alpha_bar[t_cur - 1]/self.sqrt_alpha_bar[t_prev - 1]) * x_tilde - self.sigmas[t_cur - 1]*(torch.exp(h) - 1)*self.model(u_i, t_s_tensor - 1)

        img = torch.clamp(x_tilde, -1, 1).detach().cpu()
        img = (img+1)/2

        return img

    @torch.no_grad()
    def pndm_sampling(self, n_samples=1, image_channels=1, img_size=(28, 28), n_steps=50):
        """PNDM sampling: three Runge-Kutta warm-up steps, then linear multistep steps.

        The sampler steps run from ``T`` down to 0 in strides of ``T // n_steps``.
        """
        step_size = self.T//n_steps
        time_steps = [self.T - i*step_size for i in range(n_steps)]

        # Always finish on the clean image (step 0).
        if time_steps[-1] != 0:
            time_steps.append(0)

        xt = torch.randn((n_samples, image_channels, img_size[0], img_size[1]), device=self.device)

        # Rolling window of the three most recent noise estimates.
        eps_buffer = list()

        for step_idx, (t, t_next) in enumerate(zip(time_steps[:-1], time_steps[1:])):
            # The multistep formula needs three past estimates, so warm up with RK.
            if step_idx<3:
                xt, e_t = self._step_prx(xt, t, t_next)
            else:
                xt, e_t = self._step_plms(xt, t, t_next, eps_buffer)

            eps_buffer.append(e_t)

            if len(eps_buffer)>3:
                eps_buffer.pop(0)

        img = torch.clamp(xt, -1, 1).detach().cpu()
        img = (img+1)/2

        return img

    def _model_timestep(self, t, batch_size):
        """Return the ``(batch_size,)`` long tensor the network expects for sampler step ``t``.

        Sampler step ``t`` uses ``alpha_bar[t - 1]`` (see ``_phi``), so the
        network is queried at ``t - 1``, as in ``ddim_sampling`` and
        ``dpm_solver_sampling``. Step 0 (clean image) is clamped to network
        timestep 0.
        """
        return torch.full((batch_size,), max(t - 1, 0), dtype=torch.long, device=self.device)

    def _step_prx(self, xt, t, t_next):
        """One pseudo Runge-Kutta step from ``t`` to ``t_next`` (four model evaluations).

        Returns:
            ``(x_next, e_prime)`` where ``e_prime`` is the weighted average
            ``(e1 + 2*e2 + 2*e3 + e4) / 6`` of the four noise estimates.
        """
        delta = t - t_next
        tm = int(t - delta/2)  # midpoint step

        batch_size = xt.shape[0]
        t_vec = self._model_timestep(t, batch_size)
        tm_vec = self._model_timestep(tm, batch_size)
        tnext_vec = self._model_timestep(t_next, batch_size)

        e1 = self.model(xt, t_vec)
        x1 = self._phi(xt, e1, t, tm)
        e2 = self.model(x1, tm_vec)
        x2 = self._phi(xt, e2, t, tm)
        e3 = self.model(x2, tm_vec)
        x3 = self._phi(xt, e3, t, t_next)
        e4 = self.model(x3, tnext_vec)

        e_prime = (e1 + 2*e2 + 2*e3 +e4)/6.0
        x_next = self._phi(xt, e_prime, t, t_next)

        return x_next, e_prime

    def _step_plms(self, xt, t, t_next, eps_buffer):
        """One linear multistep step from ``t`` to ``t_next`` (single model evaluation).

        Combines the current estimate with the last three entries of
        ``eps_buffer`` using the weights ``(55, -59, 37, -9) / 24``.

        Returns:
            ``(x_next, e_t)`` where ``e_t`` is the raw estimate at ``t``.
        """
        t_vec = self._model_timestep(t, xt.shape[0])
        e_t = self.model(xt, t_vec)

        past = torch.stack([e_t, eps_buffer[-1], eps_buffer[-2], eps_buffer[-3]], dim=0)
        e_prime = (55*past[0] - 59*past[1] + 37*past[2] - 9*past[3])/24.0

        x_next = self._phi(xt, e_prime, t, t_next)

        return x_next, e_t

    def _phi(self, xt, eps, t, t_next):
        """Transfer ``xt`` from sampler step ``t`` to ``t_next`` given the noise estimate ``eps``.

        Step ``k > 0`` uses ``alpha_bar[k - 1]``; step 0 uses ``alpha_bar = 1``.
        """
        if t>0:
            ab_t = self.alpha_bar[t-1]
        else:
            ab_t = torch.tensor(1.0, device=self.device)

        if t_next>0:
            ab_next = self.alpha_bar[t_next - 1]
        else:
            ab_next = torch.tensor(1.0, device=self.device)

        denom = ab_t.sqrt() * (((1-ab_next).sqrt())*ab_t.sqrt() + ((1 - ab_t).sqrt())*ab_next.sqrt())

        return (ab_next.sqrt()/ab_t.sqrt())*xt - ((ab_next - ab_t)/denom)*eps

    


In [None]:
# Persist the trained weights, then wrap the in-memory model in a sampler that
# uses the same timestep count and beta range as training.
torch.save(model.state_dict(), "DDPM_trained_SL.pt")
dm_sampler = DiffusionModelSampler(
    T=NUM_TIMESTEPS,
    model=model,
    beta_start=BETA_START,
    beta_end=BETA_END,
    device=device,
)

In [None]:
# Generate 100 images with the ancestral (DDPM) sampler, timing each sampling
# call and saving every image to disk.
# Create the output folder up front so savefig works on a fresh checkout.
os.makedirs("gen_img/DDPM", exist_ok=True)

generation_times = list()
for i in range(100):
    # Only the sampling call is timed; plotting and saving are excluded.
    tic = time.time()
    gen_img = dm_sampler.simple_sampling(scheduler=scheduler)
    toc = time.time()
    generation_times.append(toc-tic)

    # Start from an empty figure each time; otherwise every imshow call stacks
    # another image on the same axes and the figure keeps growing.
    plt.clf()
    plt.imshow(gen_img[0, 0], cmap="gray")
    plt.savefig(f"gen_img/DDPM/img_{i}.png", dpi=600)


In [None]:
# Estimate the time to generate 10 and 50 images by summing randomly re-drawn
# single-image timings; each estimate is repeated 10 times.
list_of_10_selected = [get_sum_of_random_numbers(generation_times, num_of_samples=10) for _ in range(10)]
list_of_50_selected = [get_sum_of_random_numbers(generation_times, num_of_samples=50) for _ in range(10)]

# These timings come from the DDPM (simple_sampling) run, so label them as DDPM.
print(f"DDPM generation time (1 sample): {np.mean(generation_times):.2f}+-{np.std(generation_times):.4f}")
print(f"DDPM generation time (10 samples): {np.mean(list_of_10_selected ):.2f}+-{np.std(list_of_10_selected):.4f}")
print(f"DDPM generation time (50 samples): {np.mean(list_of_50_selected):.2f}+-{np.std(list_of_50_selected):.4f}")

In [None]:
# Generate 100 images with the DDIM sampler, timing each sampling call and
# saving every image to disk.
# Create the output folder up front so savefig works on a fresh checkout.
os.makedirs("gen_img/DDIM", exist_ok=True)

generation_times = list()
for i in range(100):
    # Only the sampling call is timed; plotting and saving are excluded.
    tic = time.time()
    gen_img = dm_sampler.ddim_sampling()
    toc = time.time()
    generation_times.append(toc-tic)

    # Start from an empty figure each time; otherwise every imshow call stacks
    # another image on the same axes and the figure keeps growing.
    plt.clf()
    plt.imshow(gen_img[0, 0], cmap="gray")
    plt.savefig(f"gen_img/DDIM/img_{i}.png", dpi=600)




In [None]:
# Estimate the time to generate 10 and 50 images by summing randomly re-drawn
# single-image timings; each estimate is repeated 10 times.
list_of_10_selected = [
    get_sum_of_random_numbers(generation_times, num_of_samples=10)
    for _ in range(10)
]
list_of_50_selected = [
    get_sum_of_random_numbers(generation_times, num_of_samples=50)
    for _ in range(10)
]

# Report mean +- standard deviation for 1, 10 and 50 images, one line each.
print(
    f"DDIM generation time (1 sample): {np.mean(generation_times):.2f}+-{np.std(generation_times):.4f}",
    f"DDIM generation time (10 samples): {np.mean(list_of_10_selected ):.2f}+-{np.std(list_of_10_selected):.4f}",
    f"DDIM generation time (50 samples): {np.mean(list_of_50_selected):.2f}+-{np.std(list_of_50_selected):.4f}",
    sep="\n",
)

In [None]:
# Generate 100 images with the DPM-Solver sampler, timing each sampling call
# and saving every image to disk.
# Create the output folder up front so savefig works on a fresh checkout.
os.makedirs("gen_img/DPM", exist_ok=True)

generation_times = list()
for i in range(100):
    # Only the sampling call is timed; plotting and saving are excluded.
    tic = time.time()
    gen_img = dm_sampler.dpm_solver_sampling()
    toc = time.time()
    generation_times.append(toc-tic)

    # Start from an empty figure each time; otherwise every imshow call stacks
    # another image on the same axes and the figure keeps growing.
    plt.clf()
    plt.imshow(gen_img[0, 0], cmap="gray")
    plt.savefig(f"gen_img/DPM/img_{i}.png", dpi=600)



In [None]:
# Estimate the time to generate 10 and 50 images by summing randomly re-drawn
# single-image timings; each estimate is repeated 10 times.
list_of_10_selected = [get_sum_of_random_numbers(generation_times, num_of_samples=10) for _ in range(10)]
list_of_50_selected = [get_sum_of_random_numbers(generation_times, num_of_samples=50) for _ in range(10)]

# These timings come from the dpm_solver_sampling run, so label them as DPM-Solver.
print(f"DPM-Solver generation time (1 sample): {np.mean(generation_times):.2f}+-{np.std(generation_times):.4f}")
print(f"DPM-Solver generation time (10 samples): {np.mean(list_of_10_selected ):.2f}+-{np.std(list_of_10_selected):.4f}")
print(f"DPM-Solver generation time (50 samples): {np.mean(list_of_50_selected):.2f}+-{np.std(list_of_50_selected):.4f}")

In [None]:
# Generate 100 images with the PNDM sampler, timing each sampling call and
# saving every image to disk.
# Create the output folder up front so savefig works on a fresh checkout.
os.makedirs("gen_img/PN", exist_ok=True)

generation_times = list()
for i in range(100):
    # Only the sampling call is timed; plotting and saving are excluded.
    tic = time.time()
    gen_img = dm_sampler.pndm_sampling()
    toc = time.time()
    generation_times.append(toc-tic)

    # Start from an empty figure each time; otherwise every imshow call stacks
    # another image on the same axes and the figure keeps growing.
    plt.clf()
    plt.imshow(gen_img[0, 0], cmap="gray")
    plt.savefig(f"gen_img/PN/img_{i}.png", dpi=600)



In [None]:
# Estimate the time to generate 10 and 50 images by summing randomly re-drawn
# single-image timings; each estimate is repeated 10 times.
list_of_10_selected = [get_sum_of_random_numbers(generation_times, num_of_samples=10) for _ in range(10)]
list_of_50_selected = [get_sum_of_random_numbers(generation_times, num_of_samples=50) for _ in range(10)]

# These timings come from the pndm_sampling run, so label them as PNDM.
print(f"PNDM generation time (1 sample): {np.mean(generation_times):.2f}+-{np.std(generation_times):.4f}")
print(f"PNDM generation time (10 samples): {np.mean(list_of_10_selected ):.2f}+-{np.std(list_of_10_selected):.4f}")
print(f"PNDM generation time (50 samples): {np.mean(list_of_50_selected):.2f}+-{np.std(list_of_50_selected):.4f}")