In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np

In [None]:
# model

import math
from torch.nn import init

class Swish(nn.Module):
    """Swish activation: every element is multiplied by its own sigmoid."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate


class TimeEmbedding(nn.Module):
    """Turn integer timesteps into feature vectors of size ``dim``.

    A ``[T, d_model]`` sinusoidal table (sin and cos interleaved along the last
    axis) is used as the initial weight of a trainable ``nn.Embedding``. Rows
    looked up from it go through Linear -> Swish -> Linear.

    Args:
        T: number of rows in the table (one per timestep index ``0..T-1``).
        d_model: width of the sinusoidal table; must be even.
        dim: width of the returned embedding.
    """

    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        half = d_model // 2

        # Frequencies exp(-log(10000) * k / d_model) for k = 0, 2, 4, ...
        exponents = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
        freqs = torch.exp(-exponents)
        steps = torch.arange(T).float()
        angles = steps[:, None] * freqs[None, :]
        assert list(angles.shape) == [T, half]

        # Stacking on a trailing axis and flattening interleaves sin/cos pairs.
        table = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
        assert list(table.shape) == [T, half, 2]
        table = table.view(T, d_model)

        # freeze=False keeps the table trainable after the sinusoidal init.
        lookup = nn.Embedding.from_pretrained(table, freeze=False)
        self.timembedding = nn.Sequential(
            lookup,
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )

    def forward(self, t):
        """Return the embedding for the timestep indices in ``t``."""
        return self.timembedding(t)


class ConditionalEmbedding(nn.Module):
    """Turn integer class labels into feature vectors of size ``dim``.

    The lookup table has ``num_labels + 1`` rows; row 0 is the padding index,
    so its embedding is fixed at zero. The looked-up rows go through
    Linear -> Swish -> Linear.

    Args:
        num_labels: number of real labels (indices ``1..num_labels``).
        d_model: width of the lookup table; must be even.
        dim: width of the returned embedding.
    """

    def __init__(self, num_labels, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        label_table = nn.Embedding(
            num_embeddings=num_labels + 1,
            embedding_dim=d_model,
            padding_idx=0,
        )
        self.condEmbedding = nn.Sequential(
            label_table,
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )

    def forward(self, t):
        """Return the embedding for the label indices in ``t``."""
        return self.condEmbedding(t)


class DownSample(nn.Module):
    """Stride-2 downsampling: the sum of a 3x3 and a 5x5 strided convolution.

    Both convolutions keep the channel count at ``in_ch``.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.c1 = nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=2, padding=1)
        self.c2 = nn.Conv2d(in_ch, in_ch, kernel_size=5, stride=2, padding=2)

    def forward(self, x, temb, cemb):
        """Downsample ``x``; ``temb`` and ``cemb`` are accepted but not used."""
        small_kernel = self.c1(x)
        large_kernel = self.c2(x)
        return small_kernel + large_kernel


class UpSample(nn.Module):
    """Double the spatial size with a transposed convolution, then smooth with a 3x3 conv.

    ``self.t`` is a 5x5 transposed convolution with stride 2, padding 2 and
    output_padding 1, which maps a side length ``n`` to ``2 * n``. ``self.c``
    is a stride-1 3x3 convolution applied afterwards. Both keep ``in_ch``
    channels.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.c = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.t = nn.ConvTranspose2d(in_ch, in_ch, 5, 2, 2, 1)

    def forward(self, x, temb, cemb):
        """Upsample ``x``; ``temb`` and ``cemb`` are accepted but not used."""
        # The input shape is not needed: the transposed convolution alone
        # determines the output size.
        x = self.t(x)
        x = self.c(x)
        return x


class AttnBlock(nn.Module):
    """Single-head self-attention across the spatial positions of a feature map.

    The input is group-normalised, projected to queries, keys and values with
    1x1 convolutions, mixed with softmax attention over all ``H * W``
    positions, projected once more and added back to the input.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return ``x`` plus the attention output; the shape is unchanged."""
        B, C, H, W = x.shape
        n_pos = H * W

        normed = self.group_norm(x)
        query = self.proj_q(normed)
        key = self.proj_k(normed)
        value = self.proj_v(normed)

        # Queries as [B, HW, C] against keys as [B, C, HW] gives one score per
        # pair of positions, scaled by 1 / sqrt(C).
        query = query.view(B, C, n_pos).transpose(1, 2)
        key = key.view(B, C, n_pos)
        scores = torch.bmm(query, key) * (int(C) ** (-0.5))
        assert list(scores.shape) == [B, n_pos, n_pos]
        weights = F.softmax(scores, dim=-1)

        value = value.view(B, C, n_pos).transpose(1, 2)
        mixed = torch.bmm(weights, value)
        assert list(mixed.shape) == [B, n_pos, C]

        # Back to channel-first image layout before the output projection.
        mixed = mixed.transpose(1, 2).view(B, C, H, W)
        return x + self.proj(mixed)



class ResBlock(nn.Module):
    """Residual block conditioned on a timestep embedding and a label embedding.

    Layout: GroupNorm -> Swish -> 3x3 conv, then both embeddings are projected
    to ``out_ch`` and added as per-channel biases, followed by
    GroupNorm -> Swish -> Dropout -> 3x3 conv. The input is added back through
    a 1x1 convolution when the channel count changes (identity otherwise), and
    the sum optionally goes through an ``AttnBlock``.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        tdim: width of both embedding vectors.
        dropout: dropout probability used in the second half of the block.
        attn: when true, append self-attention after the residual sum.
    """

    def __init__(self, in_ch, out_ch, tdim, dropout, attn=True):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(Swish(), nn.Linear(tdim, out_ch))
        self.cond_proj = nn.Sequential(Swish(), nn.Linear(tdim, out_ch))
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
        )
        # A 1x1 conv on the skip path only when the channel counts differ.
        self.shortcut = (
            nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0)
            if in_ch != out_ch
            else nn.Identity()
        )
        self.attn = AttnBlock(out_ch) if attn else nn.Identity()

    def forward(self, x, temb, labels):
        """Apply the block.

        ``temb`` and ``labels`` are embedding vectors of width ``tdim`` (one per
        sample); each is projected and broadcast over the spatial dimensions.
        """
        hidden = self.block1(x)
        time_bias = self.temb_proj(temb)[:, :, None, None]
        hidden = hidden + time_bias
        label_bias = self.cond_proj(labels)[:, :, None, None]
        hidden = hidden + label_bias
        hidden = self.block2(hidden)

        residual_sum = hidden + self.shortcut(x)
        return self.attn(residual_sum)


class UNet(nn.Module):
    """U-Net noise predictor conditioned on a timestep and a class label.

    Structure, as built in ``__init__``:
      * ``head``: 3x3 conv from 3 image channels to ``ch``.
      * ``downblocks``: for each entry of ``ch_mult``, ``num_res_blocks``
        ResBlocks (attention left at its default, i.e. enabled), with a
        DownSample between levels (none after the last level).
      * ``middleblocks``: two ResBlocks, the first with attention.
      * ``upblocks``: for each level in reverse, ``num_res_blocks + 1``
        ResBlocks without attention, each fed the current activation
        concatenated with one stored skip activation, with an UpSample
        between levels (none after the last).
      * ``tail``: GroupNorm -> Swish -> 3x3 conv back to 3 channels.

    Args:
        T: number of timesteps, forwarded to ``TimeEmbedding``.
        num_labels: number of class labels, forwarded to ``ConditionalEmbedding``.
        ch: base channel count; both embeddings have width ``ch * 4``.
        ch_mult: per-level channel multipliers (level width is ``ch * mult``).
        num_res_blocks: ResBlocks per level on the down path.
        dropout: dropout probability passed to every ResBlock.

    Note: the blocks use ``GroupNorm(32, ...)``, so the channel counts that
    reach them must be divisible by 32.
    """

    def __init__(self, T, num_labels, ch, ch_mult, num_res_blocks, dropout):
        super().__init__()
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)
        self.cond_embedding = ConditionalEmbedding(num_labels, ch, tdim)
        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # stack of channel counts of every activation saved on the down path (head output first)
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(in_ch=now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout))
                now_ch = out_ch
                chs.append(now_ch)
            # No DownSample after the last level.
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            # One more ResBlock per level than on the way down, so that every
            # entry pushed onto `chs` (head output, ResBlock outputs and
            # DownSample outputs) is consumed by exactly one block.
            for _ in range(num_res_blocks + 1):
                # Input width = popped skip channels + current channels (they are concatenated in forward).
                self.upblocks.append(ResBlock(in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=False))
                now_ch = out_ch
            # No UpSample after the final (full-resolution) level.
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        # Every recorded skip width must have been matched by an up-path block.
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )
 

    def forward(self, x, t, labels):
        """Run the network on image batch ``x`` for timestep indices ``t`` and label indices ``labels``.

        Returns a tensor with 3 channels produced by ``tail``.
        NOTE(review): the spatial size presumably has to survive the stride-2
        down/up round trips so that skip concatenation lines up — confirm for
        inputs other than the 32x32 images used in this notebook.
        """
        # Timestep and label embeddings, shared by every block
        temb = self.time_embedding(t)
        cemb = self.cond_embedding(labels)
        # Downsampling: keep every intermediate activation for the skip connections
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb, cemb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb, cemb)
        # Upsampling: only ResBlocks consume a skip activation; UpSample layers do not
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb, cemb)
        h = self.tail(h)

        # All stored skip activations must have been used.
        assert len(hs) == 0
        return h

In [None]:
# Diffusion schedule: T steps, betas spaced linearly between beta_1 and beta_T.
T = 1000
beta_1, beta_T = 1e-4, 0.028

# Not referenced by any cell shown in this notebook; presumably left over from
# training (p_condi looks like a conditioning probability -- TODO confirm).
n_epochs, learning_rate = 100, 1e-3
p_condi = 0.1

# Guidance weight handed to the samplers.
w = 1.8

In [None]:
# Linear beta schedule (float64) and the cumulative products derived from it.
betas = torch.linspace(beta_1, beta_T, T).double()
alphas = 1.0 - betas
alphas_bar = alphas.cumprod(dim=0)
sqrt_alphas_bar = alphas_bar.sqrt()
sqrt_one_minus_alphas_bar = (1.0 - alphas_bar).sqrt()

# Coefficients of the reverse-step mean used by `sample`:
#   mean = coeff1 * x_t - coeff2 * epsilon
coeff1 = (1.0 / alphas).sqrt()
coeff2 = coeff1 * (1.0 - alphas) / (1.0 - alphas_bar).sqrt()

# Per-step variance; F.pad prepends 1 so that index t reads alphas_bar[t - 1]
# (and 1 for t == 0).
var = betas * (1.0 - F.pad(alphas_bar, [1, 0], value=1)[:T]) / (1.0 - alphas_bar)
used_var1 = var
# Alternative variance table: var[1] followed by betas[1:].
used_var = torch.cat([var[1:2], betas[1:]])

In [None]:
def extract(v, t, x_shape):
    """Pick ``v[t[i]]`` for every sample and shape it for broadcasting.

    Args:
        v: 1-D table indexed by timestep.
        t: 1-D integer tensor of indices, one per sample.
        x_shape: shape of the tensor the result will be broadcast against;
            only its number of dimensions is used.

    Returns:
        A float32 tensor on ``t``'s device with shape ``[len(t), 1, ..., 1]``
        and as many dimensions as ``x_shape``.
    """
    picked = v.gather(0, t).float().to(t.device)
    trailing_ones = [1] * (len(x_shape) - 1)
    return picked.view([t.shape[0], *trailing_ones])

In [None]:
import gc
def DDIM_sample(x_T, model, label, w, eta, simple_var, ddim_step):
    """Sample with DDIM using classifier-free guidance.

    Walks ``ddim_step`` steps along timesteps spaced linearly from ``T`` down
    to 0 (module-level ``T`` and ``alphas_bar`` are read), starting from
    ``x_T``. Gradients are disabled and ``model`` is put in eval mode as a
    side effect.

    Args:
        x_T: starting noise batch.
        model: callable ``model(x_t, t, label)`` returning the predicted noise.
        label: integer labels, one per sample. An all-zero copy is used for
            the second (unguided) prediction.
        w: guidance weight; the noise estimate is
            ``(1 + w) * model(x, t, label) - w * model(x, t, 0)``.
        eta: DDIM stochasticity. ``0`` gives the deterministic sampler and
            ``1`` the full variance. The standard deviation scales linearly
            with ``eta``, i.e. the variance scales with ``eta ** 2``.
        simple_var: when ``False`` the injected noise uses the eta-scaled
            variance; otherwise it uses ``1 - alphas_bar_cur / alphas_bar_prev``
            while the deterministic part still uses the eta-scaled variance.
        ddim_step: number of sampling steps.

    Returns:
        The batch after the final step.
    """
    with torch.no_grad():
        model.eval()
        x_t = x_T
        device = x_t.device
        label0 = torch.zeros_like(label)
        ts = torch.linspace(T, 0, (ddim_step + 1)).to(device).to(torch.long)
        for i in range(1, ddim_step + 1):
            cur_t = ts[i - 1]
            prev_t = ts[i]
            # Timestep k (1-based) lives at index k - 1; "timestep 0" means
            # the clean image, for which alpha_bar is 1.
            alphas_bar_cur = alphas_bar[cur_t - 1]
            alphas_bar_prev = alphas_bar[prev_t - 1] if prev_t >= 1 else 1

            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * (cur_t - 1)
            epsilon = (1 + w) * model(x_t, t, label) - w * model(x_t, t, label0)

            # sigma = eta * sqrt((1 - a_prev) / (1 - a_cur)) * sqrt(1 - a_cur / a_prev),
            # so the variance carries eta squared. (eta in {0, 1} is unaffected.)
            var = eta ** 2 * (1 - alphas_bar_prev) / (1 - alphas_bar_cur) * (1 - alphas_bar_cur / alphas_bar_prev)
            noise_var = var if simple_var == False else (1 - alphas_bar_cur / alphas_bar_prev)

            x_scale = (alphas_bar_prev / alphas_bar_cur) ** 0.5
            eps_scale = (
                (1 - alphas_bar_prev - var) ** 0.5
                - (alphas_bar_prev * (1 - alphas_bar_cur) / alphas_bar_cur) ** 0.5
            )
            x_t = x_scale * x_t + eps_scale * epsilon + torch.randn_like(x_t) * noise_var ** 0.5
        return x_t

In [None]:
import gc
def sample(x_T, model, label, w):
    """Run the full ``T``-step ancestral sampler with classifier-free guidance.

    Reads the module-level schedule tables ``coeff1``, ``coeff2`` and
    ``used_var1``. The function does not disable gradients or change the
    model's mode; callers in this notebook wrap it in ``torch.no_grad()`` and
    call ``model.eval()`` first.

    Args:
        x_T: starting noise batch.
        model: callable ``model(x_t, t, label)`` returning the predicted noise.
        label: integer labels, one per sample. An all-zero copy is used for
            the second (unguided) prediction.
        w: guidance weight; the noise estimate is
            ``(1 + w) * model(x, t, label) - w * model(x, t, 0)``.

    Returns:
        The batch after timestep 0.

    Raises:
        AssertionError: if a NaN appears in the running sample.
    """
    x_t = x_T
    device = x_t.device
    label0 = torch.zeros_like(label)

    # The schedule tables do not change during sampling: move them once
    # instead of once per step.
    coeff1_dev = coeff1.to(device)
    coeff2_dev = coeff2.to(device)
    var_dev = used_var1.to(device)

    for time_step in reversed(range(T)):
        if(time_step % 100 == 0):
            print(time_step)
        t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step

        epsilon = model(x_t, t, label)
        if w != 0:
            # With w == 0 the second prediction would be multiplied by zero,
            # so its forward pass is skipped entirely.
            epsilon = (1 + w) * epsilon - w * model(x_t, t, label0)

        mean = extract(coeff1_dev, t, x_t.shape) * x_t - extract(coeff2_dev, t, x_t.shape) * epsilon
        var = extract(var_dev, t, x_t.shape)
        # no noise when t == 0
        if time_step > 0:
            noise = torch.randn_like(x_t)
        else:
            noise = 0
        x_t = mean + torch.sqrt(var) * noise
        assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
        # Release cached memory after every step (kept from the original code).
        gc.collect()
        torch.cuda.empty_cache()
    return x_t

In [None]:
# Reload: rebuild the network and restore its weights from the saved checkpoint.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(T=1000, num_labels=10, ch=128, ch_mult=[1, 2, 2, 2], num_res_blocks=2, dropout=0.15).to(device)
# map_location makes a checkpoint written on a GPU machine loadable when this
# cell falls back to CPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load('./model_and_optimizer.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Guided generation: num0 samples for each label 1..10, drawn with the full
# T-step sampler and guidance weight w.
from torchvision.utils import save_image

num0 = 10
x = torch.randn((num0 * 10, 3, 32, 32)).to(device)

# Labels grouped by class: num0 copies of 1, then num0 copies of 2, ...
label = torch.tensor([[cls] * num0 for cls in range(1, 11)]).view(-1)
label = label.to(device)

model.eval()
with torch.no_grad():
    x_0 = sample(x, model, label, w)

# Rescale with x * 0.5 + 0.5, resize to 50x50 and save num0 images per row,
# so each row of the grid holds one class.
resized_image = transforms.Resize((50, 50))(x_0 * 0.5 + 0.5)
save_image(resized_image, './pictures/genera.png', nrow=num0)

In [None]:
# Conditional generation: num0 samples for each label 1..10 with guidance
# weight 0, so only the label-conditioned prediction contributes.
from torchvision.utils import save_image

num0 = 8
x = torch.randn((num0 * 10, 3, 32, 32)).to(device)

# Labels grouped by class: num0 copies of 1, then num0 copies of 2, ...
label = torch.tensor([[cls] * num0 for cls in range(1, 11)]).view(-1)
label = label.to(device)

model.eval()
with torch.no_grad():
    x_0 = sample(x, model, label, 0)

# Rescale with x * 0.5 + 0.5, resize to 50x50 and save num0 images per row.
resized_image = transforms.Resize((50, 50))(x_0 * 0.5 + 0.5)
save_image(resized_image, './pictures/condi_genera.png', nrow=num0)

In [None]:
# Denoising process: run the guided T-step reverse chain for one sample of each
# label 1..10 and keep the state at the timesteps listed in show_list.
from torchvision.utils import save_image

x_T = torch.randn((10, 3, 32, 32)).to(device)
label = torch.tensor([i+1 for i in range(10)]).to(device)
show_list = [900, 300, 200, 100, 90, 80, 70, 60, 50, 40, 30, 20, 10, 0]

with torch.no_grad():
    model.eval()

    x_t = x_T
    label0 = torch.zeros_like(label)
    # Snapshots are collected in a list and concatenated once at the end.
    snapshots = [x_t]
    for time_step in reversed(range(T)):

        t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
        epsilon = (1 + w) * model(x_t, t, label) - w * model(x_t, t, label0)
        mean = extract(coeff1.to(device), t, x_t.shape) * x_t - extract(coeff2.to(device), t, x_t.shape) * epsilon
        # Named step_var so the module-level schedule tensor `var` is not overwritten.
        step_var = extract(used_var1.to(device), t, x_t.shape)
        # no noise when t == 0
        if time_step > 0:
            noise = torch.randn_like(x_t)
        else:
            noise = 0
        x_t = mean + torch.sqrt(step_var) * noise

        if(time_step in show_list):
            print(time_step)
            snapshots.append(x_t)

        assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
        gc.collect()
        torch.cuda.empty_cache()

    # X is ordered snapshot-major: X[i * 10 + j] is sample j at snapshot i.
    X = torch.cat(snapshots, dim=0)

# Reorder to sample-major so each grid row shows one sample over time:
# Y[j * n_frames + i] == X[i * 10 + j].
n_frames = len(show_list) + 1
Y = X.view(n_frames, 10, *X.shape[1:]).transpose(0, 1).reshape(-1, *X.shape[1:])

resized_image = torchvision.transforms.Resize((50, 50))(Y*0.5 + 0.5)
save_image(resized_image, './pictures/denoising.png', nrow=len(show_list)+1)

In [None]:
# Compare DDIM noise settings on the same five starting noises: one grid row
# per eta in [0, 0.2, 0.5, 1.0], then a last row with simple_var=True (eta = 1).
import matplotlib.pyplot as plt

label = torch.tensor([1, 2, 3, 8, 9]).to(device)
x = torch.randn((5, 3, 32, 32)).to(device)
time_steps = 1000

model.eval()
with torch.no_grad():
    sample_rows = []
    for eta in [0, 0.2, 0.5, 1.0]:
        sample_rows.append(DDIM_sample(x, model, label, w, eta, False, time_steps))

    x_0 = DDIM_sample(x, model, label, w, 1, True, time_steps)
    sample_rows.append(x_0)
    X = torch.cat(sample_rows, dim=0)

# Build the image grid (5 images per row) and convert it to HWC for imshow.
grid = torchvision.utils.make_grid(X * 0.5 + 0.5, nrow=5).cpu()
np_grid = grid.permute(1, 2, 0).numpy()

fig, ax = plt.subplots(figsize=(3, 3))
ax.imshow(np_grid)
ax.axis('off')

rows = 5
cols = 5
text_list = ['0', '0.2', '0.5', '1', r'$\hat{\sigma}$']
# One caption to the left of each row; the 35 px spacing presumably matches
# the row height of the grid -- TODO confirm.
for i, caption in zip(range(rows), text_list):
    ax.text(-23, 35 * i + 17, caption, verticalalignment='center', fontsize=12, color='black')
fig.savefig('./pictures/var.png')
plt.close(fig)

In [None]:
# Inputs for the DDIM run in the following cell: num0 noise images for each of
# the labels 1..10, grouped by label.
num0 = 8
x = torch.randn((num0 * 10, 3, 32, 32)).to(device)

label = torch.tensor([[cls] * num0 for cls in range(1, 11)]).view(-1)
label = label.to(device)

In [None]:
# DDIM with a reduced number of steps: 100 steps, eta = 0, guidance weight w.
# The elapsed wall-clock time (sampling plus saving) is printed at the end.
import time
from torchvision.utils import save_image

start_time = time.time()
model.eval()
with torch.no_grad():
    x_0 = DDIM_sample(x, model, label, w, 0, False, 100)

resized_image = transforms.Resize((50, 50))(x_0 * 0.5 + 0.5)
save_image(resized_image, './pictures/ddim_genera.png', nrow=num0)
print(time.time() - start_time)

In [None]:
# Effect of x_T: for each label 1..10, draw one noise image, make num0 copies
# shifted by small constant offsets, and sample each copy with 10-step DDIM.
from torchvision.utils import save_image

num0 = 8
diff = 0.001

per_label = []
for k in range(10):

    original_x = torch.randn((1, 3, 32, 32)).to(device)

    # Copy i is the base noise plus the constant (i - num0 / 2) * diff.
    x = torch.cat([original_x - diff*num0/2 + i*diff for i in range(num0)], dim=0)

    label = torch.tensor([k+1]*num0)
    label = label.to(device)

    with torch.no_grad():
        model.eval()
        # eta = 0 and simple_var = False: deterministic DDIM, so differences
        # between the outputs come from the shifted starting noise only.
        x0 = DDIM_sample(x, model, label, w, 0, False, 10)

    per_label.append(x0)

# One row of num0 images per label.
x_0 = torch.cat(per_label, 0)

resized_image = torchvision.transforms.Resize((50, 50))(x_0*0.5 + 0.5)
save_image(resized_image, './pictures/effect.png', nrow=num0)

#not very sensitive to x_T, can try inverse function

In [None]:
# Disabled helper, kept as a string literal so that it does not execute.
# The snippet inside would pick 30000 random CIFAR-10 training images and write
# them as PNG files to '../samples/original', the folder the metric cells read.
# NOTE(review): the snippet calls os.path / os.makedirs but never imports `os`;
# add `import os` before re-enabling it.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
from torch.utils.data import DataLoader, Subset

root = '../../data/CIFAR10/'
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_dataset = torchvision.datasets.CIFAR10(root=root, train=True, transform=transform, download=False)

num_samples = 30000
indices = np.random.choice(len(train_dataset), num_samples, replace=False)
subset_train_dataset = Subset(train_dataset, indices)

print("start deleting")
import shutil
save_dir = '../samples/original'
if os.path.exists(save_dir) and os.path.isdir(save_dir):
    shutil.rmtree(save_dir)
else:
    print(f"Directory does not exist.")
print("done deleting")

os.makedirs(save_dir, exist_ok=True)

print("start saving")
def save_images(dataset, save_dir):
    for idx, (image, label) in enumerate(dataset):
        if(idx%1000 == 0):
            print(idx)
        image = transforms.ToPILImage()(image)
        image.save(os.path.join(save_dir, f'image_{idx}.png'))
save_images(subset_train_dataset, save_dir)
print("done saving")
'''

In [None]:
# IS: Inception Score of the reference images in '../samples/original'.
# torch_fidelity is imported here because the notebook's only other import of
# it sits in a later cell, which made this cell fail on a fresh kernel.
import torch_fidelity

metrics_dict = torch_fidelity.calculate_metrics(
    input1= '../samples/original',
    cuda=True,
    isc=True
)

In [None]:
import gc
def DDIM_fast_sample(x_T, model, label, eta, simple_var, ddim_step):
    """Sample with DDIM using only the label-conditioned prediction (one forward pass per step).

    Walks ``ddim_step`` steps along timesteps spaced linearly from ``T`` down
    to 0 (module-level ``T`` and ``alphas_bar`` are read), starting from
    ``x_T``. Gradients are disabled and ``model`` is put in eval mode as a
    side effect.

    Args:
        x_T: starting noise batch.
        model: callable ``model(x_t, t, label)`` returning the predicted noise.
        label: integer labels, one per sample.
        eta: DDIM stochasticity. ``0`` gives the deterministic sampler and
            ``1`` the full variance. The standard deviation scales linearly
            with ``eta``, i.e. the variance scales with ``eta ** 2``.
        simple_var: when ``False`` the injected noise uses the eta-scaled
            variance; otherwise it uses ``1 - alphas_bar_cur / alphas_bar_prev``
            while the deterministic part still uses the eta-scaled variance.
        ddim_step: number of sampling steps.

    Returns:
        The batch after the final step.
    """
    with torch.no_grad():
        model.eval()
        x_t = x_T
        device = x_t.device
        ts = torch.linspace(T, 0, (ddim_step + 1)).to(device).to(torch.long)
        for i in range(1, ddim_step + 1):
            cur_t = ts[i - 1]
            prev_t = ts[i]
            # Timestep k (1-based) lives at index k - 1; "timestep 0" means
            # the clean image, for which alpha_bar is 1.
            alphas_bar_cur = alphas_bar[cur_t - 1]
            alphas_bar_prev = alphas_bar[prev_t - 1] if prev_t >= 1 else 1

            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * (cur_t - 1)
            epsilon = model(x_t, t, label)

            # sigma = eta * sqrt((1 - a_prev) / (1 - a_cur)) * sqrt(1 - a_cur / a_prev),
            # so the variance carries eta squared. (eta in {0, 1} is unaffected.)
            var = eta ** 2 * (1 - alphas_bar_prev) / (1 - alphas_bar_cur) * (1 - alphas_bar_cur / alphas_bar_prev)
            noise_var = var if simple_var == False else (1 - alphas_bar_cur / alphas_bar_prev)

            x_scale = (alphas_bar_prev / alphas_bar_cur) ** 0.5
            eps_scale = (
                (1 - alphas_bar_prev - var) ** 0.5
                - (alphas_bar_prev * (1 - alphas_bar_cur) / alphas_bar_cur) ** 0.5
            )
            x_t = x_scale * x_t + eps_scale * epsilon + torch.randn_like(x_t) * noise_var ** 0.5
        return x_t

In [None]:
# Generate 100 batches of 300 images with random labels 1..10 (1000-step DDIM,
# eta = 1) and collect them in x_0, rescaled with x * 0.5 + 0.5.
import time
start_time = time.time()

with torch.no_grad():
    model.eval()

    num_samples = 300
    # Batches are kept in a list and concatenated once; concatenating onto a
    # growing tensor would re-copy all earlier images on every iteration.
    generated_batches = []
    for i in range(100):
        print(i)

        x = torch.randn((num_samples, 3, 32, 32)).to(device)
        label = torch.randint(10, (x.shape[0], )) + 1
        label = label.to(device)

        x0 = DDIM_fast_sample(x, model, label, 1, False, 1000)
        x0 = x0*0.5 +0.5
        generated_batches.append(x0)

    x_0 = torch.cat(generated_batches, 0)

print(time.time() - start_time)

In [None]:
# Write every image in x_0 to '../samples/generated' as a PNG, after clearing
# any previous contents of that folder.
import os
import shutil

from torchvision.transforms import ToPILImage

save_dir = '../samples/generated'

print("start deleting")
if os.path.isdir(save_dir):
    shutil.rmtree(save_dir)
else:
    print("Directory does not exist.")
print("done deleting")

os.makedirs(save_dir, exist_ok=True)
to_pil = ToPILImage()

print("start saving")
for i in range(x_0.size(0)):
    if(i % 10000 == 0):
        print(i)
    # The sampler output is not bounded, and the float -> 8-bit conversion in
    # ToPILImage does not clamp, so values outside [0, 1] would wrap around.
    image_tensor = x_0[i].clamp(0, 1)
    image = to_pil(image_tensor)
    save_path = os.path.join(save_dir, f'image_{i}.png')
    image.save(save_path)
print("done saving")

In [None]:
import torch_fidelity

In [None]:
# IS and FID: Inception Score of the generated images and their Frechet
# Inception Distance to the reference images.
metrics_dict = torch_fidelity.calculate_metrics(
    **{
        'input1': '../samples/generated',
        'input2': '../samples/original',
        'cuda': True,
        'isc': True,
        'fid': True,
    }
)

In [None]:
# Save the two reported metrics to output.txt: Inception Score mean on the
# first line, FID on the second.
with open("output.txt", "w") as file:
    reported = (
        metrics_dict['inception_score_mean'],
        metrics_dict['frechet_inception_distance'],
    )
    file.write('\n'.join(str(value) for value in reported))