In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

# Mini-batch size shared by the training and the test loader.
bs = 100

# MNIST splits, read from an existing local copy under /root/data/
# (download=False, so nothing is fetched). ToTensor turns each image into a tensor.
train_dataset = datasets.MNIST('/root/data/', train=True, download=False, transform=transforms.ToTensor())
test_dataset = datasets.MNIST('/root/data/', train=False, download=False, transform=transforms.ToTensor())

# Training batches are reshuffled on every pass; test batches keep the dataset order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=bs, shuffle=False)

In [2]:
import torch
import torch.nn as nn

class VAE(nn.Module):
    """Fully connected variational autoencoder for square images.

    Two independent MLPs map the flattened image to the mean and to the
    log-variance of the latent Gaussian; a third MLP (hidden widths in reverse
    order) maps a latent code back to a flattened image, which is reshaped to
    ``(channels, img_length, img_length)`` and squashed with a sigmoid.

    Args:
        latent_dim: Size of the latent code.
        hiddens: Sequence of hidden-layer widths used by both encoders; the
            decoder uses the same widths reversed.
        in_channels: Number of image channels (default 1).
        img_length: Height and width of the square input image (default 28).
    """

    def __init__(self, latent_dim, hiddens, in_channels=1, img_length=28) -> None:
        super().__init__()

        flat_dim = in_channels * img_length * img_length

        # Submodules are created in the order mean encoder -> log-variance
        # encoder -> decoder and keep their attribute names, so parameter
        # ordering and state_dict keys stay stable.
        self.mean_encoder = self._build_mlp(flat_dim, hiddens, latent_dim)
        self.logvar_encoder = self._build_mlp(flat_dim, hiddens, latent_dim)
        self.latent_dim = latent_dim

        # The decoder mirrors the encoders: hidden widths are walked back to front.
        self.decoder_linear = self._build_mlp(latent_dim, hiddens[::-1], flat_dim)
        self.decoder_input_chw = (in_channels, img_length, img_length)

    @staticmethod
    def _build_mlp(in_features, hidden_sizes, out_features):
        """Stack ``Linear + ReLU`` pairs for every hidden width, then a final plain ``Linear``."""
        layers = []
        width = in_features
        for hidden in hidden_sizes:
            layers.append(
                nn.Sequential(
                    nn.Linear(width, hidden),
                    nn.ReLU(),
                )
            )
            width = hidden
        layers.append(nn.Linear(width, out_features))
        return nn.Sequential(*layers)

    def encoder(self, x):
        """Flatten everything after the batch axis and return ``(mean, logvar)``."""
        x = torch.flatten(x, 1)
        mean = self.mean_encoder(x)
        logvar = self.logvar_encoder(x)
        return mean, logvar

    def decoder(self, x):
        """Map latent codes to images of shape ``(-1, *decoder_input_chw)`` with values from a sigmoid."""
        x = self.decoder_linear(x)
        x = torch.reshape(x, (-1, *self.decoder_input_chw))
        decoded = torch.sigmoid(x)
        return decoded

    def forward(self, x):
        """Encode ``x``, draw one latent sample and decode it.

        Returns:
            Tuple ``(decoded, mean, logvar)``.
        """
        mean, logvar = self.encoder(x)
        # Reparameterisation: z = mean + std * eps with eps ~ N(0, I),
        # where std = exp(logvar / 2).
        eps = torch.randn_like(logvar)
        std = torch.exp(logvar / 2)
        z = eps * std + mean
        decoded = self.decoder(z)
        return decoded, mean, logvar

    def sample(self, device):
        """Decode a single latent code drawn from a standard normal and return the image batch of size 1."""
        z = torch.randn(1, self.latent_dim).to(device)
        decoded = self.decoder(z)
        return decoded


In [3]:
from time import time

# Training hyper-parameters read by train() and elbo_loss().
n_epochs = 2_000      # full passes over train_loader
learning_rate = 0.01  # step size handed to the optimizer
beta = 5              # multiplier on the KL term inside elbo_loss

# pi as a 0-dim tensor.
pi = torch.tensor(torch.pi)

def elbo_loss(x, reconst_x, mean, log_var):
    """Return the negative ELBO, with the KL term scaled by the global ``beta``, averaged over the batch.

    Args:
        x: Target images.
        reconst_x: Reconstructed images (probabilities fed to binary cross-entropy).
        mean: Latent means.
        log_var: Latent log-variances.

    Returns:
        Scalar tensor ``(BCE_sum + beta * KL_sum) / len(x)``.
    """
    bce_total = torch.nn.functional.binary_cross_entropy(reconst_x, x, reduction='sum')
    kl_inner = 1 + log_var - mean.pow(2) - log_var.exp()
    kl_total = -0.5 * torch.sum(kl_inner)
    # Minimising this quantity is the same as maximising the beta-weighted ELBO.
    return (bce_total + beta * kl_total) / len(x)

def reconst_loss(x, reconst_x):
    """Return the binary cross-entropy between ``reconst_x`` and ``x``, summed over all elements and divided by ``len(x)``."""
    summed_bce = torch.nn.functional.binary_cross_entropy(
        input=reconst_x,
        target=x,
        reduction='sum',
    )
    batch_size = len(x)
    return summed_bce / batch_size

def kl_loss(mean, log_var):
    """Return ``-0.5 * sum(1 + log_var - mean^2 - exp(log_var))`` divided by ``len(mean)``."""
    per_element = 1 + log_var - mean.pow(2) - log_var.exp()
    summed = -0.5 * torch.sum(per_element)
    batch_size = len(mean)
    return summed / batch_size

def _estimate_losses(model, dataset, device, n_samples):
    """Evaluate the three losses on ``n_samples`` randomly chosen items of ``dataset``.

    Returns:
        Tuple of Python floats ``(elbo, reconstruction, kl)``.
    """
    indices = torch.randperm(len(dataset))[:n_samples].tolist()
    x = torch.stack([dataset[j][0] for j in indices]).to(device)
    x_hat, mean, logvar = model(x)
    return (
        elbo_loss(x, x_hat, mean, logvar).item(),
        reconst_loss(x, x_hat).item(),
        kl_loss(mean, logvar).item(),
    )


def train(device, model):
    """Train ``model`` for ``n_epochs`` epochs with Adagrad and log loss estimates.

    After every epoch the ELBO, reconstruction and KL losses are estimated on
    100 random training images and 100 random test images and appended to
    ``loss_2.txt`` as one line of six space-separated numbers (train triple,
    then test triple). Every 10th epoch the same numbers and the elapsed time
    are printed. When training finishes the whole model object is pickled to
    ``model_2.pth``.

    Relies on the module-level ``n_epochs``, ``learning_rate``, ``train_loader``,
    ``train_dataset`` and ``test_dataset``.

    Args:
        device: Device the batches are moved to; the model must already be there.
        model: Module whose forward pass returns ``(x_hat, mean, logvar)``.
    """
    each_epoch = 10   # epochs between console reports
    n_samples = 100   # images per loss estimate

    optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)
    begin_time = time()
    # train
    with open('/root/new/BVAE/findone/loss_2.txt', 'w') as file:
        for i in range(n_epochs):
            # The estimate below switches to eval mode; switch back so that every
            # epoch (not only the first one) is optimised in training mode.
            model.train()
            for x, _ in train_loader:
                x = x.to(device)
                x_hat, mean, logvar = model(x)
                loss = elbo_loss(x, x_hat, mean, logvar)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # estimate loss (train subset first, then test subset)
            model.eval()
            with torch.no_grad():
                train_losses = _estimate_losses(model, train_dataset, device, n_samples)
                test_losses = _estimate_losses(model, test_dataset, device, n_samples)

            file.write(' '.join(str(value) for value in train_losses + test_losses) + '\n')

            if i % each_epoch == 0:
                for elbo, recons, kl in (train_losses, test_losses):
                    print('====> Epoch: {} elbo loss: {:.4f}'.format(i, elbo))
                    print('====> Epoch: {} reconst loss: {:.4f}'.format(i, recons))
                    # This line used to be labelled "reconst loss" although it shows the KL term.
                    print('====> Epoch: {} kl loss: {:.4f}'.format(i, kl))

                #time
                training_time = time() - begin_time
                minute = int(training_time // 60)
                second = int(training_time % 60)
                print(f'time loss {minute}:{second}')

        # The complete module is pickled (not just its state_dict), so the VAE
        # class has to be importable wherever this file is loaded again.
        torch.save(model, '/root/new/BVAE/findone/model_2.pth')

    tot_training_time = time() - begin_time
    minute = int(tot_training_time // 60)
    second = int(tot_training_time % 60)
    print(f'total time loss {minute}:{second}')

def initialize_parameters(model):
    """Overwrite every parameter of ``model`` in place with samples from N(0, 0.01^2).

    This applies to all tensors returned by ``model.parameters()``, biases included.
    """
    # In-place update under no_grad instead of going through the legacy `.data`
    # attribute; the values drawn are the same.
    with torch.no_grad():
        for param in model.parameters():
            param.normal_(mean=0, std=0.01)

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
        
import numpy as np
import pandas as pd

In [None]:
def main():
    """Build a fresh VAE, re-initialise its parameters and train it."""
    # Use the first GPU when there is one; otherwise fall back to the CPU
    # instead of failing on the hard-coded CUDA device.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Load the model
    #model = torch.load('/root/new/BVAE/bvae_1.pth', map_location=device)

    #Choose which to play

    model = VAE(latent_dim=3, hiddens=[500, 500]).to(device)
    initialize_parameters(model)
    train(device, model)

    #generate(device, model)
    #reconstruct(device, model)
    #latent_space(device, model)
    #explore_latent()
    #print(sum(p.numel() for p in model.parameters() if p.requires_grad))


if __name__ == '__main__':
    main()

In [5]:
device = 'cpu'
# The checkpoint written by train() is a pickle of the whole module, so the VAE
# class must already be defined in this session before loading.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which rejects full-module pickles; confirm against the installed version.
model = torch.load('/root/new/BVAE/findone/model_2.pth', map_location=device)
# The cells below only run inference, so make evaluation mode explicit.
model.eval()

In [6]:
# Generation
with torch.no_grad():
    # Latent codes come from a standard normal. Their width is read from the
    # loaded model rather than hard-coded, so the cell also works for
    # checkpoints trained with a different latent size.
    noise = torch.randn(100, model.latent_dim).to(device)
    generated_imgs = model.decoder(noise)
    # 100 images arranged as a grid with 10 per row.
    save_image(generated_imgs, '/root/new/BVAE/findone/pictures/genera_2.png', nrow=10)

In [7]:
# Reconstruction
with torch.no_grad():
    # Only the first test batch is needed, so take it straight from the
    # iterator instead of starting a loop and breaking out of it.
    batch_x, _ = next(iter(test_loader))
    true_imgs = batch_x.view(-1, 1, 28, 28)
    save_image(true_imgs, '/root/new/BVAE/findone/pictures/oring_2.png', nrow=10)

    x = true_imgs.to(device).view(-1, 784)

    # The model returns (decoded, mean, logvar); only the decoded images are kept.
    reconst_x = model(x)[0]
    reconst_imgs = reconst_x.view(-1, 1, 28, 28)
    save_image(reconst_imgs, '/root/new/BVAE/findone/pictures/recons_2.png', nrow=10)

In [None]:
#drawing

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
        
import numpy as np
import pandas as pd

def moving_average(data, window_size):
    """Return the trailing rolling mean of ``data`` as a NumPy array.

    The first ``window_size - 1`` outputs average over however many values are
    available so far (``min_periods=1``), so the result has the same length as
    the input.
    """
    series = pd.Series(data)
    windows = series.rolling(window=window_size, min_periods=1)
    smoothed = windows.mean()
    return smoothed.values

def _read_loss_log(path):
    """Read the loss log and return six lists, one per column.

    Column order: train ELBO, train reconstruction, train KL, test ELBO,
    test reconstruction, test KL. Lines with fewer than six fields (blank or
    cut off mid-write) are skipped.
    """
    columns = [[] for _ in range(6)]
    with open(path, 'r') as file:
        for line in file:
            parts = line.split()
            if len(parts) < 6:
                continue
            # zip stops after six columns, so any extra fields are ignored.
            for column, value in zip(columns, parts):
                column.append(float(value))
    return columns


def _tail_mean(values, size):
    """Mean of the last ``size`` entries, or of all entries when there are fewer; NaN when empty."""
    tail = values[-size:]
    return sum(tail) / len(tail) if tail else float('nan')


def drawing():
    """Print recent average losses and plot the smoothed loss curves from the training log.

    Reads ``loss_2.txt``, prints the mean reconstruction and KL losses over the
    last 100 logged epochs (train pair first, then test pair), then plots all
    six curves smoothed with a 50-epoch moving average against the cumulative
    number of training samples on a log-scaled x axis. The figure is saved to
    ``train_2.png`` and shown.

    Relies on the module-level ``train_dataset`` (for the x axis) and
    ``moving_average``.
    """
    trains, trains_recon, trains_kl, tests, tests_recon, tests_kl = _read_loss_log(
        '/root/new/BVAE/findone/loss_2.txt'
    )

    # One log line per epoch; the x axis counts training samples seen so far.
    epochs = [len(train_dataset) * i for i in range(1, len(trains) + 1)]

    # Averages use the number of rows actually present, so runs shorter than
    # `size` epochs are no longer divided by a fixed 100.
    size = 100
    # train
    print("Reconstruction Loss: ", _tail_mean(trains_recon, size))
    print("KL Loss: ", _tail_mean(trains_kl, size))
    # test
    print("Reconstruction Loss: ", _tail_mean(tests_recon, size))
    print("KL Loss: ", _tail_mean(tests_kl, size))

    window_size = 50
    curves = [
        (trains, 'Train ELBO Loss'),
        (tests, 'Test ELBO Loss'),
        (trains_recon, 'Train reconstruction Loss'),
        (tests_recon, 'Test reconstruction Loss'),
        (trains_kl, 'Train kl Loss'),
        (tests_kl, 'Test kl Loss'),
    ]

    # A single explicit figure/axes pair; the previous clf() + figure() sequence
    # could leave an extra blank figure behind.
    fig, ax = plt.subplots(figsize=(8, 8))
    for values, label in curves:
        ax.plot(epochs, moving_average(values, window_size), label=label)

    ax.set_xlabel('Samples')
    ax.set_ylabel('Loss')
    ax.set_title('MNIST')
    ax.legend()
    ax.set_xscale('log')
    fig.savefig('/root/new/BVAE/findone/pictures/train_2.png', bbox_inches='tight')
    plt.show()


drawing()