In [95]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

# Mini-batch size shared by the training and the test loader.
bs = 100

# MNIST Dataset
# Both splits are read from /root/data/ with download=False, i.e. the files are
# expected to be there already (switch to download=True to fetch them once).
# ToTensor() converts every image to a tensor on access.
_mnist_options = dict(root='/root/data/', transform=transforms.ToTensor(), download=False)
train_dataset = datasets.MNIST(train=True, **_mnist_options)
test_dataset = datasets.MNIST(train=False, **_mnist_options)

# Only the training batches are shuffled; the test loader keeps dataset order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=bs,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=bs,
    shuffle=False,
)

In [96]:
import torch
import torch.nn as nn

# Default KL multiplier. It is the default of BVAE's `beta` argument and is also
# read as a module-level global by loss_fn further down in the notebook.
beta = 1.5


class BVAE(nn.Module):
    """Convolutional beta-VAE for 28x28 images.

    The encoder is a stack of stride-2 ``Conv2d -> BatchNorm2d -> ReLU`` stages
    followed by two linear heads that produce the mean and the log-variance of
    the latent Gaussian. The decoder projects a latent vector back to the
    encoder's final feature-map shape and upsamples it with stride-2
    ``ConvTranspose2d -> BatchNorm2d -> ReLU`` stages.

    Args:
        back_kernels: kernel size of each decoder stage, first decoder stage
            first. Its length also sets the number of encoder stages (the
            encoder always uses 3x3 kernels). With the defaults the spatial
            size goes 28 -> 14 -> 7 -> 4 -> 2 in the encoder and
            2 -> 4 -> 7 -> 14 -> 28 in the decoder.
        channels: channel count before the first encoder stage followed by the
            output channels of each stage; the decoder walks the same list
            backwards, so ``channels[0]`` is both the input and the output
            channel count.
        latent_dim: size of the latent vector.
        beta: KL multiplier associated with this model, stored as ``self.beta``.
            The model itself never applies it; the loss function does.

    Note: the list defaults are shared between calls but are only read here.
    """

    def __init__(self, back_kernels = [3, 2, 3, 3], channels = [1, 8, 32, 128, 256], latent_dim = 10, beta = beta) -> None:
        super().__init__()
        num_stages = len(back_kernels)
        self.latent_dim = latent_dim
        # Previously this argument was accepted and then dropped; keep it so the
        # value a model was built with can be inspected later.
        self.beta = beta

        # encoder
        # Start from channels[0] (not a literal 1) so the encoder input matches
        # the decoder output, which ends at channels[0].
        pre_channel = channels[0]
        img_length = 28  # the spatial bookkeeping below assumes 28x28 inputs
        modules = []
        for i in range(num_stages):
            modules.append(
                nn.Sequential(
                    nn.Conv2d(pre_channel,
                              channels[i+1],
                              kernel_size=3,
                              stride=2,
                              padding=1),
                    nn.BatchNorm2d(channels[i+1]),
                    nn.ReLU()
                )
            )
            pre_channel = channels[i+1]
            # Output length of a kernel-3 / stride-2 / padding-1 convolution.
            img_length = (img_length-1)//2+1
        self.encoder = nn.Sequential(*modules)

        flat_features = pre_channel * img_length * img_length
        self.mean_linear = nn.Linear(flat_features, latent_dim)
        self.var_linear = nn.Linear(flat_features, latent_dim)

        # decoder
        self.decoder_projection = nn.Linear(latent_dim, flat_features)
        # (C, H, W) the projected latent vector is reshaped to before upsampling.
        self.decoder_input_chw = (pre_channel, img_length, img_length)
        modules = []
        for i in range(num_stages):
            in_channels = channels[num_stages-i]
            out_channels = channels[num_stages-i-1]
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(in_channels,
                                       out_channels,
                                       kernel_size=back_kernels[i],
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU()
                )
            )
        self.decoder = nn.Sequential(*modules)

    def _decode(self, z):
        """Map latent vectors ``z`` of shape (N, latent_dim) to decoded images."""
        x = self.decoder_projection(z)
        x = torch.reshape(x, (-1, *self.decoder_input_chw))
        return self.decoder(x)

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            ``(decoded, mean, logvar)`` where ``mean`` and ``logvar`` are the
            outputs of the two linear heads and ``decoded`` is the decoder
            output for ``z = mean + eps * exp(logvar / 2)``, ``eps ~ N(0, I)``.
        """
        encoded = self.encoder(x)
        encoded = torch.flatten(encoded, 1)
        mean = self.mean_linear(encoded)
        logvar = self.var_linear(encoded)
        # Reparameterisation: the noise is drawn outside the graph so gradients
        # flow through mean and logvar.
        eps = torch.randn_like(logvar)
        std = torch.exp(logvar / 2)
        z = eps * std + mean
        decoded = self._decode(z)
        return decoded, mean, logvar

    def sample(self, device):
        """Decode a single latent vector drawn from N(0, I) on ``device``."""
        z = torch.randn(1, self.latent_dim).to(device)
        return self._decode(z)


In [101]:
from time import time

import torch
import torch.nn as nn
import torch.nn.functional as F

# Number of passes over train_loader performed by train().
n_epochs = 50
# Constant factor applied to the KL term in loss_fn, on top of beta.
kl_weight = 0.00025
# Learning rate handed to the Adam optimizer in train().
lr = 0.005

def loss_fn(x, x_hat, mean, logvar, beta=None):
    """Beta-VAE objective: reconstruction MSE plus a weighted KL term.

    Args:
        x: input batch.
        x_hat: reconstruction of ``x``.
        mean: latent means, one row per sample.
        logvar: latent log-variances, same shape as ``mean``.
        beta: KL multiplier. ``None`` (the default) falls back to the
            module-level ``beta`` so existing four-argument calls behave as
            before.

    Returns:
        Scalar tensor ``mse + beta * kl * kl_weight`` where ``mse`` is averaged
        over all elements and ``kl`` is summed over latent dimensions and
        averaged over the batch.
    """
    if beta is None:
        # The parameter shadows the global of the same name, hence globals().
        beta = globals()['beta']
    recons_loss = F.mse_loss(x_hat, x)
    kl_per_sample = -0.5 * torch.sum(1 + logvar - mean**2 - torch.exp(logvar), 1)
    kl_loss = torch.mean(kl_per_sample, 0)
    return recons_loss + beta * kl_loss * kl_weight


def loss_mse(x, x_hat, mean, logvar):
    """Reconstruction-only loss; ``mean`` and ``logvar`` are accepted but unused."""
    return F.mse_loss(x_hat, x)


def train(device, model, beta):
    """Train ``model`` on ``train_loader`` for ``n_epochs`` epochs with Adam.

    Args:
        device: device the input batches are moved to.
        model: module whose forward returns ``(x_hat, mean, logvar)``.
        beta: KL multiplier used for the training loss. It is passed to
            ``loss_fn`` explicitly; before, it was only printed and the loss
            silently used the module-level ``beta``.

    Returns:
        The value recorded for the last epoch: the sum of the per-batch losses
        divided by ``len(train_loader.dataset)``.
    """
    print(beta)
    # Make sure BatchNorm uses batch statistics even if the model was put in
    # eval mode earlier (generate/reconstruct/latent_space all do that).
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr)
    begin_time = time()
    # train
    loss_list = []
    for epoch in range(n_epochs):
        loss_sum = 0
        for batch_idx, (x, _) in enumerate(train_loader):
            x = x.to(device)
            x_hat, mean, logvar = model(x)
            loss = loss_fn(x, x_hat, mean, logvar, beta)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            if batch_idx % 100 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch,
                    batch_idx * len(x),
                    len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(x)))
        # NOTE(review): each batch loss is already a mean, so dividing by the
        # number of samples (here and in the per-batch print above) scales the
        # reported value down by the batch size. Kept as-is so logged numbers
        # stay comparable with earlier runs.
        epoch_loss = loss_sum / len(train_loader.dataset)
        print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, epoch_loss))
        loss_list.append(epoch_loss)
        training_time = time() - begin_time
        minute = int(training_time // 60)
        second = int(training_time % 60)
        print(f'time loss {minute}:{second}')
        #torch.save(model.state_dict(), '/root/BVAE/BVAE_model.pth')
    tot_training_time = time() - begin_time
    minute = int(tot_training_time // 60)
    second = int(tot_training_time % 60)
    print(f'total time loss {minute}:{second}')
    # Optional: plot the loss per epoch.
    # epochs = range(1, len(loss_list) + 1)
    # plt.plot(epochs, loss_list, marker='o', linestyle='-')
    # plt.xlabel('Epoch')
    # plt.ylabel('Loss')
    # plt.title('Loss per Epoch')
    # plt.grid(True)
    # plt.savefig('/root/BVAE/pictures/training.png')
    # plt.show()
    return loss_list[-1]
        
def generate(device, model):
    """Decode 64 latent vectors drawn from N(0, I) and save them as one image.

    The latent size and the decoder input shape are taken from the model
    (``latent_dim`` / ``decoder_input_chw``) instead of being hard-coded, so the
    function keeps working for models built with non-default arguments. The
    result is written to '/root/BVAE/pictures/sample.png'.
    """
    model.eval()
    num = 64
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        z = torch.randn(num, model.latent_dim).to(device)
        x = model.decoder_projection(z)
        x = torch.reshape(x, (-1, *model.decoder_input_chw))
        sample = model.decoder(x)
    # The view assumes single-channel 28x28 decoder output.
    save_image(sample.view(num, 1, 28, 28), '/root/BVAE/pictures/sample.png')

from PIL import Image, ImageDraw, ImageFont

def reconstruct(device, model):
    """Save a labelled input-vs-reconstruction comparison image.

    Takes the first 8 images of ``train_dataset``, runs them through ``model``
    and stacks each input on top of its reconstruction. The grid is written to
    '/root/BVAE/pictures/compare.png', then reopened and re-saved with a
    100-pixel white margin on the left carrying the row labels "Input" and
    "BVAE".
    """
    compare_path = '/root/BVAE/pictures/compare.png'
    model.eval()
    num = 8
    batch = torch.stack([train_dataset[i][0] for i in range(num)])
    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        output = model(batch.to(device))[0].cpu()
    # dim=2 is the height axis: inputs end up above their reconstructions.
    concatenated_images = torch.cat((batch, output), dim=2)
    save_image(concatenated_images, compare_path)

    # Close the file before it is overwritten below; the pixels are copied into
    # new_image by paste() while the handle is still open.
    with Image.open(compare_path) as original_image:
        original_width, original_height = original_image.size
        new_width = original_width + 100
        new_height = original_height
        new_image = Image.new("RGB", (new_width, new_height), color=(255, 255, 255))
        new_image.paste(original_image, (100, 0))
    draw = ImageDraw.Draw(new_image)
    title_height = original_height // 2
    draw.text((10, 10), "Input", fill=(0, 0, 0))
    draw.text((10, title_height + 10), "BVAE", fill=(0, 0, 0))
    new_image.save(compare_path)
    
import matplotlib.pyplot as plt
import numpy as np

def latent_space(device, model):
    """Scatter-plot the first two latent coordinates of 1000 training images.

    The first 1000 samples of ``train_dataset`` are encoded, a latent vector is
    sampled for each from the predicted mean and log-variance, and coordinates
    0 and 1 are plotted coloured by class label. The figure is saved to
    '/root/BVAE/pictures/latent_space.png' and shown.
    """
    model.eval()
    num = 1000
    samples = [train_dataset[i] for i in range(num)]
    batch = torch.stack([image for image, _ in samples])
    label_list = np.array([label for _, label in samples])

    # Inference only: avoid keeping the graph of a 1000-image forward pass.
    with torch.no_grad():
        _, mean, logvar = model(batch.to(device))
        eps = torch.randn_like(logvar)
        std = torch.exp(logvar / 2)
        z = eps * std + mean
    z_cpu = z.cpu().numpy()

    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'orange', 'purple', 'brown']
    legend_handles = []
    # One scatter call per class (in order of first appearance) instead of one
    # per point; only the stacking order of overlapping markers can differ.
    for label in dict.fromkeys(label_list.tolist()):
        color = colors[label]
        points = z_cpu[label_list == label]
        plt.scatter(points[:, 0], points[:, 1], color=color, s=10)
        legend_handles.append(plt.Line2D([0], [0], marker='o', color='w', label=f'{label}', markerfacecolor=color, markersize=10))
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.title('Scatter Plot with Different Colors')
    plt.legend(handles=legend_handles)
    plt.savefig('/root/BVAE/pictures/latent_space.png')
    plt.show()

def check_bvae_and_vae():
    """Save a latent-traversal image for a stored checkpoint.

    Loads '/root/BVAE/check_VAE_model.pth' into a default ``BVAE``. For every
    latent dimension it decodes 8 vectors that are zero everywhere except that
    dimension, which takes the values 1.5, 1.0, ..., -2.0. The 8 decoded images
    of one dimension form one band; bands are stacked along the height axis and
    saved to '/root/BVAE/pictures/check_vae.png'.

    Fix: the step was previously subtracted from column 1 regardless of the
    dimension being traversed, so every band varied latent 1 while the
    traversed latent stayed fixed at 2.
    """
    #check the difference between beta-vae and vae
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = BVAE().to(device)
    model.load_state_dict(torch.load('/root/BVAE/check_VAE_model.pth', device))
    #model.load_state_dict(torch.load('/root/BVAE/check_BVAE_model.pth', device))
    model.eval()
    num = 8
    latent_dim = model.latent_dim
    start_value = 2
    reduce_len = 0.5
    # Row i of a traversal sits at start_value - (i + 1) * reduce_len.
    steps = start_value - reduce_len * torch.arange(1, num + 1, dtype=torch.float32, device=device)

    bands = []
    with torch.no_grad():
        for dim in range(latent_dim):
            # A fresh zero tensor per dimension keeps traversals independent.
            z = torch.zeros(num, latent_dim, device=device)
            z[:, dim] = steps
            x = model.decoder_projection(z)
            x = torch.reshape(x, (-1, *model.decoder_input_chw))
            sample = model.decoder(x)
            bands.append(sample.view(num, 1, 28, 28))
    img = torch.cat(bands, dim=2)
    save_image(img, '/root/BVAE/pictures/check_vae.png')
    #save_image(img, '/root/BVAE/pictures/check_beta.png')

def explore_beta():
    """Train one fresh BVAE per beta value and plot reconstruction loss vs beta.

    For each beta a new model is built with ``BVAE(beta=beta)`` and trained via
    ``train(device, model, beta)``. Afterwards ``loss_mse`` is accumulated over
    ``train_loader`` in eval mode. The betas and resulting values are printed,
    plotted, saved to '/root/BVAE/pictures/loss_vs_beta.png' and shown.
    """
    betas = [1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75]
    #betas = [1.0, 1.25, 1.5]
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    loss_list = []
    for beta in betas:
        model = BVAE(beta = beta).to(device)
        train(device, model, beta)
        #loss
        loss_sum = 0
        model.eval()
        # Evaluation only: no gradients are needed for the reconstruction error.
        with torch.no_grad():
            for x, _ in train_loader:
                x = x.to(device)
                x_hat, mean, logvar = model(x)
                loss_sum += loss_mse(x, x_hat, mean, logvar).item()
        # NOTE(review): loss_mse is already a per-batch mean, so dividing by the
        # number of samples (rather than batches) scales every point by the same
        # constant; the shape of the curve is unaffected.
        loss_list.append(loss_sum / len(train_loader.dataset))
    #plot epoch
    print(betas)
    print(loss_list)
    plt.plot(betas, loss_list, marker='o', linestyle='-')
    plt.xlabel('beta')
    plt.ylabel('MSE Loss')
    plt.title('Loss per beta')
    plt.grid(True)
    plt.xlim(0.75, max(betas)+0.25)
    plt.savefig('/root/BVAE/pictures/loss_vs_beta.png')
    plt.show()

In [None]:
def main():
    """Build a BVAE and run the selected experiment.

    Uncomment one of the calls below to train, sample, reconstruct, plot the
    latent space, run the traversal check or sweep beta. As written it only
    prints the number of trainable parameters.
    """
    # Fall back to the CPU so the cell still runs on machines without CUDA.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    model = BVAE(beta = 1.5).to(device)

    # Load the model
    #model.load_state_dict(torch.load('/root/BVAE/BVAE_model.pth', device))

    #Choose which to play
    #train(device, model, 1.5)
    #generate(device, model)
    #reconstruct(device, model)
    #latent_space(device, model)
    #check_bvae_and_vae()
    #explore_beta()
    print(sum(p.numel() for p in model.parameters() if p.requires_grad))


if __name__ == '__main__':
    main()