In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

# Preprocessing pipeline shared by both MNIST splits (currently only ToTensor).
transform = transforms.Compose([
    transforms.ToTensor(),
    #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Directory expected to already hold the MNIST files (download=False below).
root = '/root/data/MNIST/'

# Mini-batch size used by both data loaders.
bs = 100
# MNIST Dataset
# Both splits go through the shared `transform` so that editing the pipeline
# above actually changes what the model sees.
train_dataset = datasets.MNIST(root=root, train=True, transform=transform, download=False)
test_dataset = datasets.MNIST(root=root, train=False, transform=transform, download=False)

# Only the training split is shuffled; the test split keeps its stored order.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

In [None]:
class ResidualBlock(nn.Module):
    """Residual block that keeps the channel count and spatial size unchanged.

    The input goes through a 3x3 convolution (stride 1, padding 1) and a 1x1
    convolution, each followed by batch normalisation, and is added back to
    the result before the final ReLU.

    Args:
        dim: Number of input and output channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim, dim, 1)
        # One normalisation layer per site: reusing a single BatchNorm2d after
        # both convolutions would make the two sites share affine parameters
        # and running statistics.
        self.batchnorm2d = nn.BatchNorm2d(dim)
        self.batchnorm2d_2 = nn.BatchNorm2d(dim)

    def forward(self, x):
        """Return ``relu(x + bn(conv2(relu(bn(conv1(x))))))``."""
        tmp = self.conv1(x)
        tmp = self.batchnorm2d(tmp)
        tmp = self.relu(tmp)
        tmp = self.conv2(tmp)
        tmp = self.batchnorm2d_2(tmp)
        # Skip connection: add the untouched input back before the activation.
        tmp = x + tmp
        tmp = self.relu(tmp)
        return tmp

class CVAE(nn.Module):
    """Convolutional conditional variational auto-encoder for 28x28 images.

    The encoder is a stack of stride-2 convolutions whose flattened output is
    concatenated with the condition vector and projected before the mean and
    log-variance heads. The decoder concatenates a latent code with the
    condition vector, projects it back to the encoder's feature-map shape and
    upsamples it with stride-2 transposed convolutions.

    Args:
        condi_dim: Length of the condition vector.
        channels: Channel counts. ``channels[0]`` is the number of image
            channels; each further entry adds one stride-2 encoder
            convolution. The decoder walks the list in reverse.
        latent_dim: Size of the latent code.
    """

    def __init__(self, condi_dim, channels, latent_dim) -> None:
        super().__init__()

        # encoder
        # Start from channels[0] so the encoder consumes as many channels as
        # the decoder's last layer produces (previously hard-coded to 1).
        pre_channel = channels[0]
        modules = []
        # Spatial side length of the input; the linear layers are sized for it.
        img_length = 28

        for out_channel in channels[1:]:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(pre_channel,
                              out_channel,
                              kernel_size=3,
                              stride=2,
                              padding=1),
                    nn.BatchNorm2d(out_channel),
                    nn.ReLU()
                )
            )
            pre_channel = out_channel
            # Output side length of a kernel-3 / stride-2 / padding-1 convolution.
            img_length = (img_length-1)//2+1

        # Number of features once the last encoder feature map is flattened.
        flat_dim = pre_channel * img_length * img_length

        # Mixes the flattened image features with the condition vector.
        self.encoder_projection = nn.Sequential(
                nn.Linear(flat_dim + condi_dim, flat_dim),
                nn.ReLU()
        )

        self.encoder = nn.Sequential(*modules)
        self.mean_linear = nn.Linear(flat_dim, latent_dim)
        # Despite the name, forward() treats this head's output as log-variance.
        self.var_linear = nn.Linear(flat_dim, latent_dim)
        self.latent_dim = latent_dim

        # decoder
        modules = []
        self.decoder_projection = nn.Linear(latent_dim + condi_dim, flat_dim)
        # Shape the projected vector is reshaped to before upsampling.
        self.decoder_input_chw = (pre_channel, img_length, img_length)

        # Mirror the encoder: walk the channel list from last to first.
        reversed_channels = channels[::-1]
        for in_channel, out_channel in zip(reversed_channels[:-1], reversed_channels[1:]):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(in_channel,
                                       out_channel,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(out_channel),
                    nn.ReLU()
                )
            )
        self.decoder_layers = nn.Sequential(*modules)

    def decoder(self, z, c):
        """Decode latent codes ``z`` conditioned on ``c``.

        Args:
            z: Latent codes, shape ``(batch, latent_dim)``.
            c: Condition vectors, shape ``(batch, condi_dim)``.

        Returns:
            The output of the transposed-convolution stack.
        """
        z = torch.cat([z, c], dim = 1)
        z = self.decoder_projection(z)
        z = torch.reshape(z, (-1, *self.decoder_input_chw))
        decoded = self.decoder_layers(z)
        return decoded

    def forward(self, x, c):
        """Encode ``x``, sample a latent code and decode it again.

        Args:
            x: Images, shape ``(batch, channels[0], 28, 28)``.
            c: Condition vectors, shape ``(batch, condi_dim)``.

        Returns:
            Tuple ``(decoded, mean, logvar)``.
        """
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        x = torch.cat([x, c], dim = 1)
        encoded = self.encoder_projection(x)
        mean = self.mean_linear(encoded)
        logvar = self.var_linear(encoded)
        # Reparameterisation: z = mean + std * eps with eps ~ N(0, I).
        eps = torch.randn_like(logvar)
        std = torch.exp(logvar / 2)
        z = eps * std + mean
        decoded = self.decoder(z, c)
        return decoded, mean, logvar

In [None]:
from time import time
import numpy as np

# Number of passes over the training loader performed by train().
n_epochs = 200
# Learning rate handed to the optimizer in train().
learning_rate = 1e-2
# Factor applied to the KL term inside elbo_loss().
kl_weight = 1e-5
# Width of the one-hot condition vector. Index 0 is the "no condition" slot
# used by train(); real labels are shifted by +1 before encoding.
num_class = 11
# Probability that train() replaces a whole batch's labels with class 0.
p_condi = 0.1

def one_hot_encode(labels, num_classes):
    """Turn a sequence of class indices into one-hot rows.

    Args:
        labels: Sequence of integer class indices, one per sample.
        num_classes: Length of every one-hot row.

    Returns:
        A list of ``len(labels)`` lists of floats, each holding a single 1.0
        at the position given by the corresponding label.
    """
    row_ids = np.arange(len(labels))
    encoded = np.zeros((row_ids.size, num_classes))
    # Fancy indexing sets exactly one entry per row.
    encoded[row_ids, labels] = 1
    return encoded.tolist()

def elbo_loss(x, x_hat, mean, logvar):
    """Training objective: reconstruction error plus weighted KL divergence.

    The reconstruction term is ``F.mse_loss`` with its default reduction; the
    KL term is summed over every element of ``mean`` / ``logvar`` and scaled
    by the module-level ``kl_weight``. The sum of both is divided by the batch
    size ``len(x)``.

    Args:
        x: Target images.
        x_hat: Reconstructed images, same shape as ``x``.
        mean: Latent means produced by the encoder.
        logvar: Latent log-variances produced by the encoder.

    Returns:
        Scalar loss tensor.
    """
    reconstruction = F.mse_loss(x_hat, x)
    divergence = -0.5 * torch.sum(1 + logvar - mean**2 - torch.exp(logvar))
    weighted_total = reconstruction + divergence * kl_weight
    return weighted_total / len(x)

def reconst_loss(x, x_hat):
    """Return the mean squared reconstruction error divided by the batch size.

    Args:
        x: Target images.
        x_hat: Reconstructed images, same shape as ``x``.
    """
    batch_size = len(x)
    squared_error = F.mse_loss(x_hat, x)
    return squared_error / batch_size

def kl_loss(mean, logvar):
    """Return the KL term summed over all elements, divided by ``len(mean)``.

    Args:
        mean: Latent means.
        logvar: Latent log-variances, same shape as ``mean``.
    """
    per_element = 1 + logvar - mean**2 - torch.exp(logvar)
    summed = -0.5 * torch.sum(per_element)
    return summed / len(mean)

def _evaluate_subset(model, dataset, device, n_samples):
    """Score ``model`` on ``n_samples`` randomly chosen items of ``dataset``.

    Every sample is conditioned on its own label, shifted by one because class
    0 is the "no condition" slot. The caller is expected to have switched the
    model to eval mode and disabled gradient tracking.

    Returns:
        Tuple ``(elbo, reconstruction, kl)`` of scalar loss tensors.
    """
    indices = torch.randperm(len(dataset))[:n_samples].tolist()
    # Fetch each sample once and reuse both the image and its label.
    samples = [dataset[idx] for idx in indices]
    x = torch.stack([image for image, _ in samples]).to(device)
    label = one_hot_encode([target+1 for _, target in samples], num_classes = num_class)
    label = torch.tensor(label).to(device)
    x_hat, mean, logvar = model(x, label)
    return elbo_loss(x, x_hat, mean, logvar), reconst_loss(x, x_hat), kl_loss(mean, logvar)


def _report_losses(epoch, loss, loss_recons, loss_kl):
    """Print the three loss values of one evaluation pass."""
    print('====> Epoch: {} elbo loss: {:.7f}'.format(epoch, loss))
    print('====> Epoch: {} reconst loss: {:.7f}'.format(epoch, loss_recons))
    print('====> Epoch: {} kl loss: {:.7f}'.format(epoch, loss_kl))


def train(device, model):
    """Train ``model`` on ``train_loader`` for ``n_epochs`` epochs.

    After every epoch the model is scored on a random training subset and a
    random test subset. The six resulting numbers are appended as one line to
    ``./loss.txt`` (train elbo/reconst/kl, then test elbo/reconst/kl) and are
    printed every tenth epoch together with the elapsed time. When training
    finishes the whole module is saved to ``./model.pth``.

    Args:
        device: Device the batches are moved to; ``model`` must already live there.
        model: Network called as ``model(x, label)`` and returning
            ``(x_hat, mean, logvar)``.
    """
    optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)
    begin_time = time()
    each_epoch = 10   # progress is printed every `each_epoch` epochs
    n_samples = 100   # size of each random evaluation subset
    # train
    with open('./loss.txt', 'w') as file:
        for i in range(n_epochs):
            # The evaluation below leaves the model in eval mode. Switch back
            # so BatchNorm keeps using and updating batch statistics in every
            # epoch, not only the first.
            model.train()
            for x, label in train_loader:
                x = x.to(device)
                # With probability p_condi the whole batch is trained without
                # a condition (class 0); otherwise labels are shifted to 1..N.
                if np.random.rand() < p_condi:
                    label = torch.zeros_like(label)
                else:
                    label = label+1
                label = one_hot_encode(label, num_classes = num_class)
                label = torch.tensor(label)
                label = label.to(device)
                x_hat, mean, logvar = model(x, label)
                loss = elbo_loss(x, x_hat, mean, logvar)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # estimate loss
            model.eval()
            with torch.no_grad():
                # Train subset first, then test subset; together they form one
                # line of the log file.
                for dataset, terminator in ((train_dataset, ' '), (test_dataset, '\n')):
                    losses = _evaluate_subset(model, dataset, device, n_samples)
                    if(i % each_epoch == 0):
                        _report_losses(i, *losses)
                    file.write(' '.join(str(value.item()) for value in losses) + terminator)

            #time
            if(i % each_epoch == 0):
                training_time = time() - begin_time
                minute = int(training_time // 60)
                second = int(training_time % 60)
                print(f'time loss {minute}:{second}')

        # The model is still in eval mode here, so the saved module is ready
        # for inference as soon as it is loaded.
        torch.save(model, './model.pth')

    tot_training_time = time() - begin_time
    minute = int(tot_training_time // 60)
    second = int(tot_training_time % 60)
    print(f'total time loss {minute}:{second}')

def initialize_parameters(model):
    """Overwrite every parameter of ``model`` in place with N(0, 0.01^2) samples.

    All tensors yielded by ``model.parameters()`` are re-drawn, whatever layer
    they belong to.
    """
    with torch.no_grad():
        for weights in model.parameters():
            weights.normal_(mean=0, std=0.01)

In [None]:
def main():
    """Build the CVAE, re-initialise its parameters and run training.

    Training uses the first CUDA device when one is available and falls back
    to the CPU otherwise, so the cell also runs on machines without a GPU.
    """
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    model = CVAE(condi_dim = num_class, channels = [1, 500, 500], latent_dim = 500).to(device)
    initialize_parameters(model)

    train(device, model)

if __name__ == '__main__':
    main()

In [None]:
# Generation runs on the CPU; map_location remaps the stored tensors on load.
device = 'cpu'
# './model.pth' is written by train() via torch.save(model, ...), i.e. it holds
# the whole module object rather than a state dict.
# NOTE(review): presumably the CVAE class must already be defined in this
# session for the load to succeed — confirm.
# NOTE(review): a fully pickled module can execute code on load; only open
# checkpoints you trust. Recent PyTorch releases may also require
# weights_only=False for this kind of file — confirm against the installed version.
model = torch.load('./model.pth', map_location=device)

In [None]:
# Width of the noise vectors fed to model.decoder() in the generation cell;
# it has to equal the latent_dim the model was built with in main() (500).
latent_dim = 500

In [None]:
import os
import torchvision
from torchvision.utils import save_image
# Generation
# Guidance weight: the output is (1+w)*conditional - w*unconditional, so
# w = 0 uses the conditional decoder output only.
w = 0
import random

# Decode with BatchNorm in inference mode regardless of how the model was saved.
model.eval()
# Make sure the output directory exists before the image is written.
os.makedirs('./pictures', exist_ok=True)

with torch.no_grad():

    noise = torch.randn(100, latent_dim).to(device)
    # Ten copies of each class index 1..10; with nrow=10 every row of the
    # saved grid shows one class.
    c = [i+1 for i in range(10) for _ in range(10)]
    # Class 0 is the "no condition" slot used during training.
    c0 = [0]*100
    c = one_hot_encode(c, num_classes = num_class)
    c0 = one_hot_encode(c0, num_classes = num_class)
    c = torch.tensor(c).to(device)
    c0 = torch.tensor(c0).to(device)
    generated_imgs = (1+w)*model.decoder(noise, c) - w*model.decoder(noise, c0)
    resized_image = torchvision.transforms.Resize((50, 50))(generated_imgs)
    save_image(resized_image, './pictures/genera.png', nrow=10)