In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

# Mini-batch size used by both data loaders.
bs = 100

# MNIST is read from a local copy under /root/data/; nothing is downloaded here.
# (Switch download to True on a first run if the files are not there yet.)
_mnist_kwargs = dict(root='/root/data/', transform=transforms.ToTensor(), download=False)
train_dataset = datasets.MNIST(train=True, **_mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **_mnist_kwargs)

# Only the training split is shuffled; the test split keeps its stored order.
_make_loader = torch.utils.data.DataLoader
train_loader = _make_loader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = _make_loader(dataset=test_dataset, batch_size=bs, shuffle=False)

In [23]:
import torch
import torch.nn as nn

# Default size of the latent vector; read once as the default of VAE.__init__.
latent_dim = 10

class VAE(nn.Module):
    """Convolutional variational autoencoder for 28x28 images.

    The encoder is a stack of stride-2 convolutions followed by two linear
    heads producing the mean and log-variance of the latent distribution.
    The decoder projects a latent vector back to the encoder's final feature
    map and upsamples it with stride-2 transposed convolutions.

    Args:
        back_kernels: kernel size of each decoder (transposed-conv) layer; its
            length also sets the number of encoder layers. Never mutated.
        channels: channel counts, ``channels[0]`` being the image channels and
            ``channels[i+1]`` the output channels of encoder layer ``i``. The
            decoder walks the same list backwards. Never mutated.
        latent_dim: size of the latent vector.
    """

    def __init__(self, back_kernels = [3, 2, 3, 3], channels = [1, 8, 32, 128, 256], latent_dim = latent_dim) -> None:
        super().__init__()
        n_layers = len(back_kernels)

        # encoder
        # Start from channels[0] (not a hard-coded 1) so the encoder input
        # matches the channel count the decoder emits.
        pre_channel = channels[0]
        modules = []
        img_length = 28
        for i in range(n_layers):
            modules.append(
                nn.Sequential(
                    nn.Conv2d(pre_channel,
                              channels[i+1],
                              kernel_size=3,
                              stride=2,
                              padding=1),
                    nn.BatchNorm2d(channels[i+1]),
                    nn.ReLU()
                )
            )
            pre_channel = channels[i+1]
            # Spatial size after a k=3, s=2, p=1 convolution.
            img_length = (img_length-1)//2+1
        self.encoder = nn.Sequential(*modules)
        flat_features = pre_channel * img_length * img_length
        self.mean_linear = nn.Linear(flat_features, latent_dim)
        self.var_linear = nn.Linear(flat_features, latent_dim)
        self.latent_dim = latent_dim

        # decoder
        modules = []
        self.decoder_projection = nn.Linear(latent_dim, flat_features)
        self.decoder_input_chw = (pre_channel, img_length, img_length)
        for i in range(n_layers):
            in_ch = channels[n_layers-i]
            out_ch = channels[n_layers-i-1]
            modules.append(
                nn.Sequential(
                    # With s=2, p=1, output_padding=1 the output size is
                    # 2*in + kernel - 3, so the default kernels [3, 2, 3, 3]
                    # give 2 -> 4 -> 7 -> 14 -> 28.
                    nn.ConvTranspose2d(in_ch,
                                       out_ch,
                                       kernel_size=back_kernels[i],
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU()
                )
            )
        self.decoder = nn.Sequential(*modules)

    def _decode(self, z):
        """Map latent vectors of shape (N, latent_dim) to decoded images."""
        x = self.decoder_projection(z)
        x = torch.reshape(x, (-1, *self.decoder_input_chw))
        return self.decoder(x)

    def forward(self, x):
        """Encode ``x``, sample a latent vector, and decode it.

        Returns:
            Tuple ``(decoded, mean, logvar)``; ``mean`` and ``logvar`` have
            shape (N, latent_dim).
        """
        encoded = torch.flatten(self.encoder(x), 1)
        mean = self.mean_linear(encoded)
        logvar = self.var_linear(encoded)
        # Reparameterisation: z = mean + std * eps with eps ~ N(0, I).
        # Noise is drawn on every call, including in eval mode.
        eps = torch.randn_like(logvar)
        std = torch.exp(logvar / 2)
        z = eps * std + mean
        return self._decode(z), mean, logvar

    def sample(self, device):
        """Decode one latent vector drawn from a standard normal distribution."""
        z = torch.randn(1, self.latent_dim).to(device)
        return self._decode(z)

# Default beta; read once as the default of BVAE.__init__.
beta = 1.5

class BVAE(nn.Module):
    """Convolutional VAE for 28x28 images that also carries a ``beta`` weight.

    The network is built exactly like a plain convolutional VAE: stride-2
    convolutions, linear mean/log-variance heads, a linear projection and
    stride-2 transposed convolutions. ``beta`` does not affect the forward
    pass; it is kept on the instance so the training code can read it
    (presumably as the KL-term weight of the loss -- confirm against the
    training loop, which is not part of this class).

    Args:
        back_kernels: kernel size of each decoder (transposed-conv) layer; its
            length also sets the number of encoder layers. Never mutated.
        channels: channel counts, ``channels[0]`` being the image channels and
            ``channels[i+1]`` the output channels of encoder layer ``i``.
            Never mutated.
        latent_dim: size of the latent vector.
        beta: scalar stored as ``self.beta``.
    """

    def __init__(self, back_kernels = [3, 2, 3, 3], channels = [1, 8, 32, 128, 256], latent_dim = 10, beta = beta) -> None:
        super().__init__()
        n_layers = len(back_kernels)
        # Previously this argument was accepted and then dropped; keep it so
        # the configured value is recoverable from the model. A plain float
        # attribute does not enter the state_dict.
        self.beta = beta

        # encoder
        # Start from channels[0] (not a hard-coded 1) so the encoder input
        # matches the channel count the decoder emits.
        pre_channel = channels[0]
        modules = []
        img_length = 28
        for i in range(n_layers):
            modules.append(
                nn.Sequential(
                    nn.Conv2d(pre_channel,
                              channels[i+1],
                              kernel_size=3,
                              stride=2,
                              padding=1),
                    nn.BatchNorm2d(channels[i+1]),
                    nn.ReLU()
                )
            )
            pre_channel = channels[i+1]
            # Spatial size after a k=3, s=2, p=1 convolution.
            img_length = (img_length-1)//2+1
        self.encoder = nn.Sequential(*modules)
        flat_features = pre_channel * img_length * img_length
        self.mean_linear = nn.Linear(flat_features, latent_dim)
        self.var_linear = nn.Linear(flat_features, latent_dim)
        self.latent_dim = latent_dim

        # decoder
        modules = []
        self.decoder_projection = nn.Linear(latent_dim, flat_features)
        self.decoder_input_chw = (pre_channel, img_length, img_length)
        for i in range(n_layers):
            in_ch = channels[n_layers-i]
            out_ch = channels[n_layers-i-1]
            modules.append(
                nn.Sequential(
                    # With s=2, p=1, output_padding=1 the output size is
                    # 2*in + kernel - 3 (2 -> 4 -> 7 -> 14 -> 28 by default).
                    nn.ConvTranspose2d(in_ch,
                                       out_ch,
                                       kernel_size=back_kernels[i],
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU()
                )
            )
        self.decoder = nn.Sequential(*modules)

    def _decode(self, z):
        """Map latent vectors of shape (N, latent_dim) to decoded images."""
        x = self.decoder_projection(z)
        x = torch.reshape(x, (-1, *self.decoder_input_chw))
        return self.decoder(x)

    def forward(self, x):
        """Encode ``x``, sample a latent vector, and decode it.

        Returns:
            Tuple ``(decoded, mean, logvar)``; ``mean`` and ``logvar`` have
            shape (N, latent_dim).
        """
        encoded = torch.flatten(self.encoder(x), 1)
        mean = self.mean_linear(encoded)
        logvar = self.var_linear(encoded)
        # Reparameterisation: z = mean + std * eps with eps ~ N(0, I).
        # Noise is drawn on every call, including in eval mode.
        eps = torch.randn_like(logvar)
        std = torch.exp(logvar / 2)
        z = eps * std + mean
        return self._decode(z), mean, logvar

    def sample(self, device):
        """Decode one latent vector drawn from a standard normal distribution."""
        z = torch.randn(1, self.latent_dim).to(device)
        return self._decode(z)
    
class CVAE(nn.Module):
    """Conditional convolutional VAE for 28x28 images.

    A condition tensor ``c`` of shape (N, con_dim) is concatenated to the
    flattened encoder features before the mean/log-variance heads, and to the
    latent vector before the decoder projection.

    Args:
        back_kernels: kernel size of each decoder (transposed-conv) layer; its
            length also sets the number of encoder layers. Never mutated.
        channels: channel counts, ``channels[0]`` being the image channels and
            ``channels[i+1]`` the output channels of encoder layer ``i``.
            Never mutated.
        latent_dim: size of the latent vector.
        con_dim: number of condition features per sample.
    """

    def __init__(self, back_kernels = [3, 2, 3, 3], channels = [1, 8, 32, 128, 256], latent_dim = 10, con_dim = 1) -> None:
        super().__init__()
        n_layers = len(back_kernels)

        # encoder
        # Start from channels[0] (not a hard-coded 1) so the encoder input
        # matches the channel count the decoder emits.
        pre_channel = channels[0]
        modules = []
        img_length = 28
        for i in range(n_layers):
            modules.append(
                nn.Sequential(
                    nn.Conv2d(pre_channel,
                              channels[i+1],
                              kernel_size=3,
                              stride=2,
                              padding=1),
                    nn.BatchNorm2d(channels[i+1]),
                    nn.ReLU()
                )
            )
            pre_channel = channels[i+1]
            # Spatial size after a k=3, s=2, p=1 convolution.
            img_length = (img_length-1)//2+1
        self.encoder = nn.Sequential(*modules)
        flat_features = pre_channel * img_length * img_length
        # Both heads see the image features plus the condition.
        self.mean_linear = nn.Linear(flat_features + con_dim, latent_dim)
        self.var_linear = nn.Linear(flat_features + con_dim, latent_dim)
        self.latent_dim = latent_dim

        # decoder
        modules = []
        self.decoder_projection = nn.Linear(latent_dim + con_dim, flat_features)
        self.decoder_input_chw = (pre_channel, img_length, img_length)
        for i in range(n_layers):
            in_ch = channels[n_layers-i]
            out_ch = channels[n_layers-i-1]
            modules.append(
                nn.Sequential(
                    # With s=2, p=1, output_padding=1 the output size is
                    # 2*in + kernel - 3 (2 -> 4 -> 7 -> 14 -> 28 by default).
                    nn.ConvTranspose2d(in_ch,
                                       out_ch,
                                       kernel_size=back_kernels[i],
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU()
                )
            )
        self.decoder = nn.Sequential(*modules)

    def _decode(self, z, c):
        """Decode latent vectors ``z`` (N, latent_dim) given conditions ``c`` (N, con_dim)."""
        x = self.decoder_projection(torch.cat([z, c], dim = 1))
        x = torch.reshape(x, (-1, *self.decoder_input_chw))
        return self.decoder(x)

    def forward(self, x, c):
        """Encode ``x`` under condition ``c``, sample a latent vector, and decode it.

        Args:
            x: image batch.
            c: condition tensor of shape (N, con_dim).

        Returns:
            Tuple ``(decoded, mean, logvar)``; ``mean`` and ``logvar`` have
            shape (N, latent_dim).
        """
        encoded = torch.flatten(self.encoder(x), 1)
        encoded = torch.cat([encoded, c], dim = 1)
        mean = self.mean_linear(encoded)
        logvar = self.var_linear(encoded)
        # Reparameterisation: z = mean + std * eps with eps ~ N(0, I).
        # Noise is drawn on every call, including in eval mode.
        eps = torch.randn_like(logvar)
        std = torch.exp(logvar / 2)
        z = eps * std + mean
        return self._decode(z, c), mean, logvar

    def sample(self, device, c):
        """Decode standard-normal latent vectors, one per row of ``c``.

        Args:
            device: device the latent noise is moved to.
            c: condition tensor of shape (N, con_dim); expected to already
                live on the same device as the model.

        Returns:
            Decoded images, one per condition row.
        """
        # Draw as many latent vectors as there are conditions. The previous
        # fixed batch of 1 made the concatenation fail for N > 1.
        z = torch.randn(c.shape[0], self.latent_dim).to(device)
        return self._decode(z, c)

class VQVAE(nn.Module):
    """Convolutional vector-quantised autoencoder.

    The encoder maps an image to a feature map with ``latent_dim`` channels.
    Every spatial position of that map is replaced by its nearest row of a
    learned codebook (``vq_embedding``) before being decoded.

    Args:
        back_kernels: kernel size of each decoder (transposed-conv) layer; its
            length also sets the number of stride-2 encoder layers. Never
            mutated.
        channels: channel counts, ``channels[0]`` being the image channels and
            ``channels[i+1]`` the output channels of encoder layer ``i``.
            Never mutated.
        latent_dim: number of channels of the latent map, i.e. the length of
            each codebook vector.
        n_embedding: number of codebook vectors.
    """

    def __init__(self, back_kernels = [3, 3], channels = [1, 64, 256], latent_dim = 2, n_embedding = 10) -> None:
        super().__init__()
        n_layers = len(back_kernels)

        # encoder
        # Start from channels[0] (not a hard-coded 1) so the encoder input
        # matches the channel count the decoder emits.
        pre_channel = channels[0]
        modules = []
        for i in range(n_layers):
            modules.append(
                nn.Sequential(
                    nn.Conv2d(pre_channel,
                              channels[i+1],
                              kernel_size=3,
                              stride=2,
                              padding=1),
                    nn.BatchNorm2d(channels[i+1]),
                    nn.ReLU()
                )
            )
            pre_channel = channels[i+1]
        # Size-preserving convolution down to latent_dim channels; no
        # normalisation or activation, so the latent map is unconstrained.
        modules.append(
            nn.Sequential(
                nn.Conv2d(pre_channel,
                          latent_dim,
                          kernel_size=3,
                          stride=1,
                          padding=1),
            )
        )
        self.encoder = nn.Sequential(*modules)
        self.vq_embedding = nn.Embedding(n_embedding, latent_dim)
        # Small symmetric uniform initialisation of the codebook.
        self.vq_embedding.weight.data.uniform_(-1.0 / n_embedding, 1.0 / n_embedding)

        # decoder
        modules = []
        # Mirror of the encoder's last layer: latent_dim -> widest channel count.
        modules.append(
            nn.Sequential(
                nn.Conv2d(latent_dim,
                          pre_channel,
                          kernel_size=3,
                          stride=1,
                          padding=1),
                nn.BatchNorm2d(pre_channel),
                nn.ReLU()
            )
        )
        for i in range(n_layers):
            in_ch = channels[n_layers-i]
            out_ch = channels[n_layers-i-1]
            modules.append(
                nn.Sequential(
                    # With s=2, p=1, output_padding=1 the output size is
                    # 2*in + kernel - 3.
                    nn.ConvTranspose2d(in_ch,
                                       out_ch,
                                       kernel_size=back_kernels[i],
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU()
                )
            )
        self.decoder = nn.Sequential(*modules)

    def forward(self, x):
        """Encode ``x``, quantise the latent map, and decode it.

        Returns:
            Tuple ``(x_hat, ze, zq)``: the reconstruction, the encoder output
            of shape (N, latent_dim, H, W), and its quantised counterpart of
            the same shape.
        """
        ze = self.encoder(x)
        # ze: [Num, Channel, Height, Width]
        # codebook: [emb_num, Channel]
        # The nearest-neighbour search needs values only, so it uses a view of
        # the codebook that is detached from the autograd graph.
        codebook = self.vq_embedding.weight.detach()
        Num, Channel, Height, Width = ze.shape
        emb_num, _ = codebook.shape
        codebook_broad = codebook.reshape(1, emb_num, Channel, 1, 1)
        ze_broad = ze.reshape(Num, 1, Channel, Height, Width)
        # Squared Euclidean distance to every codebook vector:
        # [Num, emb_num, Height, Width] after summing over Channel.
        distance = torch.sum((codebook_broad - ze_broad)**2, 2)
        nearest_nei = torch.argmin(distance, 1)
        # Embedding lookup yields [Num, Height, Width, Channel]; move Channel back.
        zq = self.vq_embedding(nearest_nei).permute(0, 3, 1, 2)
        # Straight-through: the value equals zq, but the gradient of the
        # decoder input flows to ze because the difference is detached.
        decoder_input = ze + (zq - ze).detach()
        # decode
        x_hat = self.decoder(decoder_input)
        return x_hat, ze, zq

In [26]:
from time import time
from PIL import Image, ImageDraw, ImageFont
import torch
import torch.nn as nn
import torch.nn.functional as F

def reconstructs(device):
    """Reconstruct the first 8 training digits with every trained model and
    save a labelled comparison image.

    For each of VAE, BVAE, CVAE and VQVAE the saved checkpoint is loaded, the
    same 8 images are passed through the model, and the reconstructions are
    stacked below the inputs. The result is written to
    '/root/VAE/pictures/compare.png', then reopened to add a 100-pixel white
    margin on the left carrying one text label per row.

    Args:
        device: device the models run on; checkpoints are mapped onto it too.
    """
    num = 8
    samples = [train_dataset[i] for i in range(num)]
    image_batch = torch.stack([image for image, _ in samples])
    # CVAE takes the digit label as a single condition column of shape (num, 1).
    labels = torch.tensor([label for _, label in samples]).view(-1, 1)
    x = image_batch.to(device)
    c = labels.to(device)

    # (row title, model class, constructor kwargs, checkpoint path, needs condition)
    # listed in the top-to-bottom order of the rows in the output image.
    model_specs = [
        ("VAE", VAE, {'latent_dim': 10}, '/root/VAE/VAE_model.pth', False),
        ("BVAE", BVAE, {'beta': 1.5}, '/root/BVAE/BVAE_model.pth', False),
        ("CVAE", CVAE, {}, '/root/CVAE/CVAE_model.pth', True),
        ("VQVAE", VQVAE, {}, '/root/VQVAE/VQVAE_model.pth', False),
    ]

    titles = ["Input"]
    rows = [image_batch.detach().cpu()]
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for title, model_cls, kwargs, checkpoint, conditional in model_specs:
            model = model_cls(**kwargs).to(device)
            # Map the checkpoint onto the requested device rather than a
            # hard-coded 'cuda:0', so the `device` argument is honoured.
            model.load_state_dict(torch.load(checkpoint, map_location=device))
            # eval() fixes the BatchNorm statistics; the VAE-style models still
            # draw fresh latent noise in forward, so outputs vary run to run.
            model.eval()
            output = model(x, c)[0] if conditional else model(x)[0]
            titles.append(title)
            rows.append(output.cpu())
    # Stack along the height axis: each grid column shows one digit, input on
    # top and one reconstruction per model below it.
    concatenated_images = torch.cat(rows, dim=2)

    save_image(concatenated_images, '/root/VAE/pictures/compare.png')

    # Reopen the grid and add a white left margin holding the row titles.
    num_pic = len(titles)
    original_image = Image.open('/root/VAE/pictures/compare.png')
    original_width, original_height = original_image.size
    new_width = original_width + 100
    new_height = original_height
    new_image = Image.new("RGB", (new_width, new_height), color=(255, 255, 255))
    new_image.paste(original_image, (100, 0))
    draw = ImageDraw.Draw(new_image)
    # Approximate height of one row; the small grid padding is ignored.
    title_height = original_height // num_pic
    for row_index, title in enumerate(titles):
        draw.text((10, title_height*row_index + 10), title, fill=(0, 0, 0))
    new_image.save('/root/VAE/pictures/compare.png')

In [27]:
def main():
    """Pick a device and produce the model-comparison image."""
    # Prefer the first GPU; fall back to the CPU when CUDA is unavailable
    # instead of passing on a device string that cannot be used. The CPU path
    # additionally requires checkpoint loading to honour `device`.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Other experiments that were toggled here by (un)commenting a call:
    #   train(device, model, 10), generate(device, model),
    #   reconstruct(device, model), latent_space(device, model),
    #   explore_latent()
    # They are not defined in the cells shown here, so they stay disabled.

    reconstructs(device)

if __name__ == '__main__':
    main()