In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

# Convert each image to a tensor, then shift/scale every channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

root = '../../data/CIFAR10/'

# Load the CIFAR-10 train and test splits from `root`.
# download=False: the dataset files must already be present on disk.
train_dataset = torchvision.datasets.CIFAR10(root=root, train=True, transform=transform, download=False)
test_dataset = torchvision.datasets.CIFAR10(root=root, train=False, transform=transform, download=False)

bs = 100
# The training loader is reshuffled every epoch so that SGD does not see the
# batches in the same fixed order each time; the test loader needs no shuffling.
# (The original had these two flags swapped.)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

In [None]:
# model

import math
from torch.nn import init

class Swish(nn.Module):
    """Parameter-free activation computing ``x * sigmoid(x)`` elementwise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate


class TimeEmbedding(nn.Module):
    """Embed integer timesteps with a sinusoidal lookup table followed by a small MLP.

    Args:
        T: number of rows in the lookup table (valid timestep indices are 0 .. T-1).
        d_model: width of the sinusoidal table; must be even.
        dim: width of the returned embedding.
    """

    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        half = d_model // 2

        # One frequency per (sin, cos) pair: exp(-log(10000) * k / d_model), k = 0, 2, 4, ...
        exponents = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
        freqs = torch.exp(-exponents)
        steps = torch.arange(T).float()
        angles = steps[:, None] * freqs[None, :]
        assert list(angles.shape) == [T, half]

        # Stacking on a new last axis and flattening interleaves sin and cos
        # along the feature dimension: [sin_0, cos_0, sin_1, cos_1, ...].
        table = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
        assert list(table.shape) == [T, half, 2]
        table = table.view(T, d_model)

        layers = [
            # freeze=False: the sinusoidal table is only an initialisation and stays trainable
            nn.Embedding.from_pretrained(table, freeze=False),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        ]
        self.timembedding = nn.Sequential(*layers)

    def forward(self, t):
        """Return the embedding for the integer timestep tensor ``t``."""
        return self.timembedding(t)


class ConditionalEmbedding(nn.Module):
    """Embed integer class labels with a learned table followed by a small MLP.

    The table has ``num_labels + 1`` rows and index 0 is the padding index,
    so index 0 is available as a separate "no label" entry.

    Args:
        num_labels: number of real classes.
        d_model: width of the label table; must be even.
        dim: width of the returned embedding.
    """

    def __init__(self, num_labels, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        label_table = nn.Embedding(
            num_embeddings=num_labels + 1,
            embedding_dim=d_model,
            padding_idx=0,
        )
        self.condEmbedding = nn.Sequential(
            label_table,
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )

    def forward(self, t):
        """Return the embedding for the integer label tensor ``t``."""
        return self.condEmbedding(t)


class DownSample(nn.Module):
    """Stride-2 downsampling: the sum of a 3x3 and a 5x5 strided convolution.

    The channel count is unchanged. ``temb`` and ``cemb`` are accepted so the
    layer can be called like the other UNet blocks, but they are not used.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.c1 = nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=2, padding=1)
        self.c2 = nn.Conv2d(in_ch, in_ch, kernel_size=5, stride=2, padding=2)

    def forward(self, x, temb, cemb):
        small_kernel = self.c1(x)
        large_kernel = self.c2(x)
        return small_kernel + large_kernel


class UpSample(nn.Module):
    """2x upsampling: a stride-2 transposed convolution followed by a 3x3 convolution.

    The channel count is unchanged. ``temb`` and ``cemb`` are accepted so the
    layer can be called like the other UNet blocks, but they are not used.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.c = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        # kernel 5, stride 2, padding 2, output_padding 1:
        # (H - 1) * 2 - 2 * 2 + (5 - 1) + 1 + 1 == 2 * H
        self.t = nn.ConvTranspose2d(in_ch, in_ch, 5, 2, 2, 1)

    def forward(self, x, temb, cemb):
        # The original unpacked x.shape into unused locals here; that dead
        # statement has been dropped since it never influenced the output.
        x = self.t(x)
        x = self.c(x)
        return x


class AttnBlock(nn.Module):
    """Single-head self-attention over the H*W spatial positions, with a residual connection.

    ``in_ch`` must be divisible by 32 because the input is normalised with
    a 32-group GroupNorm.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        B, C, H, W = x.shape
        n_tokens = H * W
        normed = self.group_norm(x)

        # Queries and values are laid out as (B, HW, C); keys stay as (B, C, HW)
        # so that bmm(query, key) directly yields the (B, HW, HW) score matrix.
        query = self.proj_q(normed).permute(0, 2, 3, 1).view(B, n_tokens, C)
        key = self.proj_k(normed).view(B, C, n_tokens)
        value = self.proj_v(normed).permute(0, 2, 3, 1).view(B, n_tokens, C)

        # Scaled dot-product scores, normalised over the key positions.
        scores = torch.bmm(query, key) * (int(C) ** (-0.5))
        assert list(scores.shape) == [B, n_tokens, n_tokens]
        weights = F.softmax(scores, dim=-1)

        attended = torch.bmm(weights, value)
        assert list(attended.shape) == [B, n_tokens, C]
        # Back to the (B, C, H, W) image layout before the output projection.
        attended = attended.view(B, H, W, C).permute(0, 3, 1, 2)

        return x + self.proj(attended)



class ResBlock(nn.Module):
    """Residual block conditioned on a time embedding and a label embedding.

    Args:
        in_ch: input channels (must be divisible by 32 for the GroupNorm).
        out_ch: output channels (must be divisible by 32 for the GroupNorm).
        tdim: width of both conditioning embeddings.
        dropout: dropout probability used inside the second conv stack.
        attn: when true, an ``AttnBlock`` is applied to the block output.
    """

    def __init__(self, in_ch, out_ch, tdim, dropout, attn=True):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        # Both embeddings are projected to one value per output channel.
        self.temb_proj = nn.Sequential(Swish(), nn.Linear(tdim, out_ch))
        self.cond_proj = nn.Sequential(Swish(), nn.Linear(tdim, out_ch))
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        # A 1x1 conv matches the channel count on the skip path only when it changes.
        self.shortcut = (
            nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
            if in_ch != out_ch
            else nn.Identity()
        )
        self.attn = AttnBlock(out_ch) if attn else nn.Identity()

    def forward(self, x, temb, labels):
        """Apply the block; ``labels`` is the label *embedding*, not raw class ids."""
        h = self.block1(x)
        # Broadcast the per-channel conditioning over the spatial dimensions.
        time_bias = self.temb_proj(temb)[:, :, None, None]
        h = h + time_bias
        label_bias = self.cond_proj(labels)[:, :, None, None]
        h = h + label_bias
        h = self.block2(h)

        h = h + self.shortcut(x)
        return self.attn(h)


class UNet(nn.Module):
    """Label- and time-conditioned UNet mapping a 3-channel image to a 3-channel output.

    Args:
        T: number of diffusion timesteps (size of the time-embedding table).
        num_labels: number of real classes; label 0 is the extra padding entry
            of ``ConditionalEmbedding``.
        ch: base channel width; both embeddings have width ``4 * ch``.
        ch_mult: per-resolution channel multipliers; its length is the number
            of resolution levels.
        num_res_blocks: residual blocks per level on the way down (the way up
            uses one more per level).
        dropout: dropout probability forwarded to every ``ResBlock``.
    """

    def __init__(self, T, num_labels, ch, ch_mult, num_res_blocks, dropout):
        super().__init__()
        emb_dim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, emb_dim)
        self.cond_embedding = ConditionalEmbedding(num_labels, ch, emb_dim)
        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)

        # --- encoder -------------------------------------------------------
        # skip_channels records the channel count of every tensor the forward
        # pass pushes onto its skip stack, so the decoder can size its inputs.
        self.downblocks = nn.ModuleList()
        skip_channels = [ch]
        cur_ch = ch
        for level, mult in enumerate(ch_mult):
            width = ch * mult
            for _ in range(num_res_blocks):
                # attn is left at ResBlock's default (True) for every encoder block
                self.downblocks.append(ResBlock(in_ch=cur_ch, out_ch=width, tdim=emb_dim, dropout=dropout))
                cur_ch = width
                skip_channels.append(cur_ch)
            # every level except the last one ends with a downsampling layer
            if level != len(ch_mult) - 1:
                self.downblocks.append(DownSample(cur_ch))
                skip_channels.append(cur_ch)

        # --- bottleneck ----------------------------------------------------
        self.middleblocks = nn.ModuleList([
            ResBlock(cur_ch, cur_ch, emb_dim, dropout, attn=True),
            ResBlock(cur_ch, cur_ch, emb_dim, dropout, attn=False),
        ])

        # --- decoder -------------------------------------------------------
        self.upblocks = nn.ModuleList()
        for level, mult in reversed(list(enumerate(ch_mult))):
            width = ch * mult
            # one extra block per level consumes the skip pushed by head / DownSample
            for _ in range(num_res_blocks + 1):
                skip_ch = skip_channels.pop()
                self.upblocks.append(ResBlock(in_ch=skip_ch + cur_ch, out_ch=width, tdim=emb_dim, dropout=dropout, attn=False))
                cur_ch = width
            # every level except the first one ends with an upsampling layer
            if level != 0:
                self.upblocks.append(UpSample(cur_ch))
        assert len(skip_channels) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, cur_ch),
            Swish(),
            nn.Conv2d(cur_ch, 3, 3, stride=1, padding=1)
        )

    def forward(self, x, t, labels):
        """Run the network on images ``x`` at integer timesteps ``t`` with integer ``labels``."""
        # Conditioning embeddings shared by every block
        temb = self.time_embedding(t)
        cemb = self.cond_embedding(labels)

        # Encoder: keep every intermediate activation for the skip connections
        h = self.head(x)
        skips = [h]
        for block in self.downblocks:
            h = block(h, temb, cemb)
            skips.append(h)

        # Bottleneck
        for block in self.middleblocks:
            h = block(h, temb, cemb)

        # Decoder: only residual blocks consume a skip; UpSample layers do not
        for block in self.upblocks:
            if isinstance(block, ResBlock):
                h = torch.cat([h, skips.pop()], dim=1)
            h = block(h, temb, cemb)
        out = self.tail(h)

        assert len(skips) == 0
        return out

In [None]:
# Diffusion process
T = 1000                        # number of diffusion timesteps
beta_1, beta_T = 1e-4, 0.028    # first and last value of the linear beta schedule

# Optimisation
n_epochs = 500
learning_rate = 1e-3

# Classifier-free guidance
w = 1.8          # guidance weight handed to sample()
p_condi = 0.1    # probability of replacing a batch's labels with the null label 0

In [None]:
# Linear beta schedule; linspace produces float32 values which are then upcast to float64.
betas = torch.linspace(beta_1, beta_T, T).double()
alphas = 1.0 - betas
alphas_bar = alphas.cumprod(dim=0)

# Forward process: x_t = sqrt_alphas_bar[t] * x_0 + sqrt_one_minus_alphas_bar[t] * epsilon
sqrt_alphas_bar = alphas_bar.sqrt()
sqrt_one_minus_alphas_bar = (1.0 - alphas_bar).sqrt()

# Reverse-step mean: coeff1[t] * x_t - coeff2[t] * epsilon
coeff1 = (1.0 / alphas).sqrt()
coeff2 = coeff1 * (1.0 - alphas) / sqrt_one_minus_alphas_bar

# Reverse-step variance. F.pad prepends a 1 so that entry t holds alphas_bar[t-1]
# (and 1 for t == 0), which makes var[0] exactly zero.
var = betas * (1.0 - F.pad(alphas_bar, [1, 0], value=1)[:T]) / (1.0 - alphas_bar)
# Alternative variance table (betas with the first entry replaced by var[1]);
# it is not referenced by the other cells visible in this notebook.
used_var = torch.cat([var[1:2], betas[1:]])
# Variance table actually read by sample().
used_var1 = var

In [None]:
def extract(v, t, x_shape):
    """Look up per-sample coefficients and shape them for broadcasting.

    Args:
        v: 1-D tensor of coefficients, one per timestep.
        t: 1-D integer tensor of timestep indices, one per sample.
        x_shape: shape of the tensor the result will be broadcast against;
            only its length is used.

    Returns:
        A float32 tensor on ``t``'s device holding ``v[t]`` with shape
        ``[len(t), 1, ..., 1]`` (as many dimensions as ``x_shape``).
    """
    device = t.device
    # Move the table BEFORE gathering: gather needs `v` and `t` on the same
    # device, so the original post-gather `.to(device)` came too late to help.
    out = torch.gather(v.to(device), index=t, dim=0).float()
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))

In [None]:
def loss(epsilon, estimated_epsilon):
    """Sum of squared errors between true and predicted noise, divided by the batch size.

    The squared error is summed over every element and divided by
    ``len(epsilon)`` (the size of the first dimension) only.
    """
    residual = epsilon - estimated_epsilon
    batch_size = len(epsilon)
    return residual.pow(2).sum() / batch_size

In [None]:
import gc
def sample(x_T, model, label, w):
    """Run the reverse diffusion chain from ``x_T`` down to timestep 0.

    At every step the noise estimate is the classifier-free-guidance mix
    ``(1 + w) * model(x_t, t, label) - w * model(x_t, t, 0)``.

    Args:
        x_T: starting noise; its device and batch size are used throughout.
        model: callable ``model(x_t, t, label)`` returning a noise estimate.
        label: integer label tensor (0 is the null label).
        w: guidance weight; with ``w == 0`` only the conditional estimate is used.

    Returns:
        The sample after the step for timestep 0.

    Raises:
        AssertionError: if a NaN appears in the running sample.

    This function does not disable gradients or switch the model to eval
    mode itself; the callers in this notebook do both before calling it.
    """
    x_t = x_T
    device = x_t.device
    batch_size = x_T.shape[0]
    null_label = torch.zeros_like(label)

    # The schedule tables are loop-invariant: move them to the device once
    # instead of three times per timestep.
    mean_x_coeff = coeff1.to(device)
    mean_eps_coeff = coeff2.to(device)
    step_var = used_var1.to(device)

    for time_step in reversed(range(T)):
        if time_step % 100 == 0:
            print(time_step)
        t = x_t.new_ones([batch_size, ], dtype=torch.long) * time_step

        epsilon = model(x_t, t, label)
        if w != 0:
            # With w == 0 the unconditional term is multiplied by zero, so the
            # second forward pass is skipped entirely in that case.
            epsilon = (1 + w) * epsilon - w * model(x_t, t, null_label)

        mean = extract(mean_x_coeff, t, x_t.shape) * x_t - extract(mean_eps_coeff, t, x_t.shape) * epsilon
        var = extract(step_var, t, x_t.shape)
        # no noise when t == 0
        if time_step > 0:
            noise = torch.randn_like(x_t)
        else:
            noise = 0
        x_t = mean + torch.sqrt(var) * noise
        assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
        # Kept from the original: release Python garbage and cached CUDA blocks
        # after every step (presumably to limit peak memory).
        gc.collect()
        torch.cuda.empty_cache()
    return x_t

In [None]:
import time

def _monitor_loss(dataset, model, device, n_samples):
    """Noise-prediction loss of ``model`` on ``n_samples`` random items of ``dataset``.

    Mirrors one training step without the optimizer: labels are shifted by +1
    (0 is the null label), replaced by 0 with probability ``p_condi``, and the
    images are noised at random timesteps. The caller is responsible for
    ``torch.no_grad()`` and ``model.eval()``.
    """
    indices = torch.randperm(len(dataset))[:n_samples].tolist()
    picked = [dataset[i] for i in indices]
    x = torch.stack([image for image, _ in picked]).to(device)
    label = torch.tensor([target for _, target in picked]).to(device) + 1
    if np.random.rand() < p_condi:
        label = torch.zeros_like(label)
    t = torch.randint(T, (x.shape[0], )).to(device)
    epsilon = torch.randn_like(x)
    x_t = extract(sqrt_alphas_bar.to(device), t, x.shape) * x + extract(sqrt_one_minus_alphas_bar.to(device), t, x.shape) * epsilon
    return loss(epsilon, model(x_t, t, label))


def train(device, model, optimizer):
    """Train ``model`` for ``n_epochs`` epochs on ``train_loader``.

    After every epoch the loss on 100 random training images and 100 random
    test images is printed and appended to ``./loss.txt`` (one line per epoch:
    training loss, space, testing loss). The file is opened in write mode, so
    any previous content is replaced. The model and optimizer state are saved
    to ``./model_and_optimizer.pth`` every 10 epochs and after the last one.

    Args:
        device: device the batches are moved to.
        model: noise-prediction network called as ``model(x_t, t, label)``.
        optimizer: optimizer over ``model``'s parameters.
    """
    begin_time = time.time()
    each_epoch = 1     # print every `each_epoch` epochs
    n_samples = 100    # images per monitoring loss

    # Loop-invariant schedule tables: move them to the device once.
    signal_coeff = sqrt_alphas_bar.to(device)
    noise_coeff = sqrt_one_minus_alphas_bar.to(device)

    # train
    with open('./loss.txt', 'w') as file:
        for i in range(n_epochs):
            # The monitoring block below switches to eval mode; switch back
            # here, otherwise every epoch after the first trains with dropout
            # disabled.
            model.train()
            for x, label in train_loader:
                optimizer.zero_grad()
                x = x.to(device)
                # shift class ids by one: 0 is reserved for the null label
                label = label.to(device) + 1
                # with probability p_condi the WHOLE batch is trained unconditionally
                if np.random.rand() < p_condi:
                    label = torch.zeros_like(label)
                epsilon = torch.randn_like(x)
                t = torch.randint(T, (x.shape[0], )).to(device)
                x_t = extract(signal_coeff, t, x.shape) * x + extract(noise_coeff, t, x.shape) * epsilon
                estimated_epsilon = model(x_t, t, label)
                loss_0 = loss(epsilon, estimated_epsilon)
                loss_0.backward()
                optimizer.step()

            with torch.no_grad():
                model.eval()
                #train
                loss_0 = _monitor_loss(train_dataset, model, device, n_samples)
                if i % each_epoch == 0:
                    print("epoch: ", i, ", training loss: ", loss_0.item())
                file.write(str(loss_0.item()) + ' ')
                #test
                loss_0 = _monitor_loss(test_dataset, model, device, n_samples)
                if i % each_epoch == 0:
                    print("epoch: ", i, ", testing loss: ", loss_0.item())
                file.write(str(loss_0.item()) + '\n')

            if i % each_epoch == 0:
                training_time = time.time() - begin_time
                minute = int(training_time // 60)
                second = int(training_time % 60)
                print(f'time loss {minute}:{second}')

            if i % 10 == 0 or i == n_epochs - 1:
                torch.save(
                    {
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                    },
                    './model_and_optimizer.pth'
                )
                
                

In [None]:

def main():
    """Build the model and optimizer, resume from the checkpoint if there is one, and train.

    If ``./model_and_optimizer.pth`` exists, the model and optimizer state are
    restored from it ("keep training"); otherwise training starts from a
    freshly initialised model. Previously the fresh start was only available
    by manually swapping in code that was parked inside a string literal.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    #round 1: 100 150min
    #round 2: 200 300min
    #round 3: 200 300min
    #round 4: 500 20hr
    #round 5: 500 20hr
    #round 6: 500 9hr
    #round 7: 500 20hr
    #totally: 81.5 hr

    model = UNet(T=1000, num_labels=10, ch=128, ch_mult=[1, 2, 2, 2], num_res_blocks=2, dropout=0.15).to(device)
    optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)

    checkpoint_path = './model_and_optimizer.pth'
    if os.path.exists(checkpoint_path):
        #keep training
        # map_location lets a checkpoint written on a GPU machine be loaded
        # when only the CPU is available (and vice versa).
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    train(device, model, optimizer)

if __name__ == '__main__':
    main()


In [None]:
#reload
# Rebuild the network with the same hyperparameters used for training and
# restore its weights from the saved checkpoint.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(T=1000, num_labels=10, ch=128, ch_mult=[1, 2, 2, 2], num_res_blocks=2, dropout=0.15).to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU machine also loads when only the CPU is available.
checkpoint = torch.load('./model_and_optimizer.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
#Random classifier-free-guided generation (random classes, guidance weight w)
from torchvision.utils import save_image

# Create the output directory up front so a missing folder is not discovered
# only after the whole sampling chain has run.
os.makedirs('./pictures', exist_ok=True)

x = torch.randn((49, 3, 32, 32)).to(device)

# Draw one class per image; +1 because label 0 is the null (unconditional) label.
label = torch.randint(10, (x.shape[0], )) + 1
label = label.to(device)

with torch.no_grad():
    model.eval()
    x_0 = sample(x, model, label, w)

# x_0*0.5 + 0.5 undoes the Normalize(0.5, 0.5) applied to the training images.
resized_image = torchvision.transforms.Resize((50, 50))(x_0*0.5 + 0.5)
save_image(resized_image, './pictures/genera.png', nrow=7)

In [None]:
#Classifier-free-guided generation (num0 images per class, guidance weight w)
from torchvision.utils import save_image

# Create the output directory up front so a missing folder is not discovered
# only after the whole sampling chain has run.
os.makedirs('./pictures', exist_ok=True)

num0 = 3
x = torch.randn((num0*10, 3, 32, 32)).to(device)
# Labels 1..10, each repeated num0 times (label 0 is the null label).
label = torch.tensor([i+1 for i in range(10) for _ in range(num0)]).to(device)

with torch.no_grad():
    model.eval()
    x_0 = sample(x, model, label, w)

# x_0*0.5 + 0.5 undoes the Normalize(0.5, 0.5) applied to the training images.
resized_image = torchvision.transforms.Resize((50, 50))(x_0*0.5 + 0.5)
# nrow=num0 puts the num0 images of one class on each grid row.
save_image(resized_image, './pictures/guid_genera.png', nrow=num0)

In [None]:
#Conditional Generation (num0 images per class, guidance weight 0: no guidance)
from torchvision.utils import save_image

# Create the output directory up front so a missing folder is not discovered
# only after the whole sampling chain has run.
os.makedirs('./pictures', exist_ok=True)

num0 = 3
x = torch.randn((num0*10, 3, 32, 32)).to(device)
# Labels 1..10, each repeated num0 times (label 0 is the null label).
label = torch.tensor([i+1 for i in range(10) for _ in range(num0)]).to(device)

with torch.no_grad():
    model.eval()
    # w = 0 reduces the guided estimate to the plain conditional prediction.
    x_0 = sample(x, model, label, 0)

# x_0*0.5 + 0.5 undoes the Normalize(0.5, 0.5) applied to the training images.
resized_image = torchvision.transforms.Resize((50, 50))(x_0*0.5 + 0.5)
# nrow=num0 puts the num0 images of one class on each grid row.
save_image(resized_image, './pictures/condi_genera.png', nrow=num0)