# In [1]:
import torch
import pyro
import pyro.distributions as dist
import pyro.contrib.examples.util  # patches torchvision
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

# Mini-batch size shared by the training and test loaders.
bs = 100

# MNIST is read from a local copy under /root/data/ (download is disabled);
# every image is converted to a tensor by ToTensor.
_to_tensor = transforms.ToTensor()
train_dataset = datasets.MNIST(
    root='/root/data/', train=True, transform=_to_tensor, download=False)
test_dataset = datasets.MNIST(
    root='/root/data/', train=False, transform=_to_tensor, download=False)

# NOTE(review): the training loader does not shuffle, so batches are visited
# in the same order every epoch - confirm this is intentional.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=bs, shuffle=False)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=bs, shuffle=False)


# In [2]:
class ResidualBlock(nn.Module):
    """Two-convolution residual block that preserves channel count and size.

    The block computes ``relu(x + bn2(conv2(relu(bn1(conv1(x))))))`` where
    ``conv1`` is a 3x3 convolution (stride 1, padding 1) and ``conv2`` is a
    1x1 convolution, both mapping ``dim`` channels to ``dim`` channels.

    Each convolution has its own batch-norm layer. Previously one
    ``BatchNorm2d`` instance was applied after both convolutions, so two
    different activation distributions shared the same affine parameters and
    running statistics.

    Args:
        dim: Number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim, dim, 1)
        # `batchnorm2d` keeps its original attribute name (it normalises the
        # output of conv1); `batchnorm2d_2` normalises the output of conv2.
        self.batchnorm2d = nn.BatchNorm2d(dim)
        self.batchnorm2d_2 = nn.BatchNorm2d(dim)

    def forward(self, x):
        """Apply the residual block to ``x`` and return a same-shaped tensor."""
        tmp = self.conv1(x)
        tmp = self.batchnorm2d(tmp)
        tmp = self.relu(tmp)
        tmp = self.conv2(tmp)
        tmp = self.batchnorm2d_2(tmp)
        # Skip connection followed by the final non-linearity.
        tmp = x + tmp
        tmp = self.relu(tmp)
        return tmp

class Encoder(nn.Module):
    """Convolutional encoder producing the mean and std of q(z|x).

    The network is sized for single-channel 28x28 inputs: the first
    convolution always takes 1 input channel and the spatial bookkeeping
    starts from a side length of 28. One stride-2 stage
    (Conv2d -> BatchNorm2d -> ReLU) is built for every entry of
    ``channels[1:]``; ``channels[0]`` is not read here.

    Args:
        channels: Sequence of channel counts; entries after the first give
            the output channels of each stage.
        latent_dim: Size of the latent vector.

    Attributes set for downstream use:
        latent_dim, last_pre_channel (channels after the final stage) and
        last_img_length (spatial side length after the final stage).
    """

    def __init__(self, channels, latent_dim) -> None:
        super().__init__()

        stages = []
        in_channels = 1
        side = 28
        for out_channels in channels[1:]:
            stage = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            )
            stages.append(stage)
            in_channels = out_channels
            # Output side of a kernel-3 / stride-2 / padding-1 convolution.
            side = (side - 1) // 2 + 1

        flat_features = in_channels * side * side
        self.encoder = nn.Sequential(*stages)
        self.mean_linear = nn.Linear(flat_features, latent_dim)
        self.var_linear = nn.Linear(flat_features, latent_dim)
        self.latent_dim = latent_dim
        self.last_pre_channel = in_channels
        self.last_img_length = side

    def forward(self, x):
        """Return ``(mean, std)`` of the approximate posterior for batch ``x``."""
        features = self.encoder(x).flatten(1)
        mean = self.mean_linear(features)
        # The second head is interpreted as a log-variance; halve and
        # exponentiate to obtain a strictly positive standard deviation.
        logvar = self.var_linear(features)
        return mean, torch.exp(logvar / 2)

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to an image.

    A linear layer first projects ``z`` to ``pre_channel * img_length**2``
    features, which are reshaped to ``(pre_channel, img_length, img_length)``.
    The feature map then passes through one stage
    (ConvTranspose2d -> BatchNorm2d -> ReLU) per adjacent pair of
    ``channels`` walked from last to first; each stage uses stride 2 with
    ``output_padding=1``, doubling the spatial side length.

    Args:
        channels: Channel counts in encoder order; traversed in reverse here.
        latent_dim: Size of the latent vector.
        pre_channel: Channel count of the reshaped projection.
        img_length: Spatial side length of the reshaped projection.
    """

    def __init__(self, channels, latent_dim, pre_channel, img_length) -> None:
        super().__init__()

        projected_features = pre_channel * img_length * img_length
        self.decoder_projection = nn.Linear(latent_dim, projected_features)
        self.decoder_input_chw = (pre_channel, img_length, img_length)

        # Pair every channel count with its predecessor, last to first.
        reversed_channels = channels[::-1]
        stages = [
            nn.Sequential(
                nn.ConvTranspose2d(src, dst,
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(dst),
                nn.ReLU(),
            )
            for src, dst in zip(reversed_channels, reversed_channels[1:])
        ]
        self.decoder_layers = nn.Sequential(*stages)

    def forward(self, z):
        """Decode latent batch ``z`` into an image batch."""
        projected = self.decoder_projection(z)
        feature_map = projected.reshape((-1, *self.decoder_input_chw))
        return self.decoder_layers(feature_map)
    
class VAE(nn.Module):
    """Variational auto-encoder exposing a Pyro model/guide pair.

    Args:
        channels: Channel counts passed to both the encoder and the decoder.
        latent_dim: Size of the latent vector ``z``.
    """

    def __init__(self, channels, latent_dim):
        super().__init__()
        # The decoder is sized from the feature-map shape the encoder ends on.
        self.encoder = Encoder(channels, latent_dim)
        self.decoder = Decoder(
            channels,
            latent_dim,
            self.encoder.last_pre_channel,
            self.encoder.last_img_length,
        )
        self.z_dim = latent_dim

    def model(self, x):
        """Generative model p(x|z)p(z) for a batch of images ``x``."""
        # Make the decoder's parameters visible to Pyro.
        pyro.module("decoder", self.decoder)
        batch_size = x.shape[0]
        flat_pixels = 1 * 28 * 28
        with pyro.plate("data", batch_size):
            # Standard-normal prior over the latent code.
            prior_shape = torch.Size((batch_size, self.z_dim))
            prior_mean = x.new_zeros(prior_shape)
            prior_std = x.new_ones(prior_shape)
            # During ELBO computation this value is supplied by the guide.
            z = pyro.sample(
                "latent", dist.Normal(prior_mean, prior_std).to_event(1))
            loc_img = self.decoder.forward(z)
            # Gaussian likelihood with a fixed scale of 0.3 over the
            # flattened pixels, conditioned on the observed images.
            likelihood = dist.Normal(
                loc_img.reshape(-1, flat_pixels), 0.3).to_event(1)
            pyro.sample("obs", likelihood, obs=x.reshape(-1, flat_pixels))

    def guide(self, x):
        """Variational distribution q(z|x) parameterised by the encoder."""
        # Make the encoder's parameters visible to Pyro.
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            mean, std = self.encoder.forward(x)
            pyro.sample("latent", dist.Normal(mean, std).to_event(1))

    def reconstruct_img(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            Tuple ``(loc_img, mean, logvar)``: the decoder output (no sampling
            happens in image space), the posterior mean, and the posterior
            log-variance recovered from the encoder's standard deviation.
        """
        mean, std = self.encoder(x)
        z = dist.Normal(mean, std).sample()
        loc_img = self.decoder(z)
        logvar = 2 * torch.log(std)
        return loc_img, mean, logvar

def initialize_parameters(model):
    """Overwrite every parameter of ``model`` in place with N(0, 0.01**2) draws.

    All tensors returned by ``model.parameters()`` are re-initialised,
    whatever layer they belong to.
    """
    with torch.no_grad():
        for weights in model.parameters():
            weights.normal_(mean=0, std=0.01)

# In [3]:
from time import time

# Step size placed into the optimizer keyword arguments below.
LEARNING_RATE = 1e-3
# Number of passes over the training loader performed by train().
n_epochs = 500
# Keyword arguments handed to the optimizer constructor in main().
adam_args = dict(lr=LEARNING_RATE)

def reconst_loss(x, x_hat):
    """Return the element-wise mean squared error divided by ``len(x)``.

    Note that the MSE is already averaged over all elements; the result is
    additionally divided by the size of the first dimension of ``x``.
    """
    batch_size = len(x)
    mean_squared_error = F.mse_loss(x_hat, x)
    return mean_squared_error / batch_size

def kl_loss(mean, logvar):
    """Return KL(N(mean, exp(logvar)) || N(0, 1)) summed over all elements
    and divided by ``len(mean)``.
    """
    per_element = 1 + logvar - mean.pow(2) - logvar.exp()
    total = per_element.sum()
    return -0.5 * total / len(mean)

def _evaluate_subset(svi, model, dataset, n_samples, device):
    """Estimate losses on ``n_samples`` randomly chosen items of ``dataset``.

    Must be called with the model in eval mode and gradients disabled.

    Returns:
        Tuple ``(elbo_loss_per_sample, reconst_loss, kl_loss)`` where the
        first entry is a float and the other two are tensors.
    """
    indices = torch.randperm(len(dataset))[:n_samples].tolist()
    x = torch.stack([dataset[idx][0] for idx in indices]).to(device)
    loss = svi.evaluate_loss(x) / len(x)
    x_hat, mean, logvar = model.reconstruct_img(x)
    return loss, reconst_loss(x, x_hat), kl_loss(mean, logvar)


def _print_losses(epoch, loss, loss_recons, loss_kl):
    """Print the three loss values for ``epoch``."""
    print('====> Epoch: {} elbo loss: {:.7f}'.format(epoch, loss))
    print('====> Epoch: {} reconst loss: {:.7f}'.format(epoch, loss_recons))
    print('====> Epoch: {} kl loss: {:.7f}'.format(epoch, loss_kl))


def _save_image_grid(images, path):
    """Resize ``images`` to 50x50 and save them to ``path`` as a 7-wide grid."""
    resized_image = torchvision.transforms.Resize((50, 50))(images)
    save_image(resized_image, path, nrow=7)


def _save_sample_pictures(device, model):
    """Save original, reconstructed and generated image grids (49 images each)."""
    # Originals: the first 49 images of the first training batch.
    batch_x, _ = next(iter(train_loader))
    true_imgs = batch_x[0:49].view(-1, 1, 28, 28)
    _save_image_grid(true_imgs, '/root/pyro/MNIST/VAE/pictures/oring.png')

    # Reconstructions of those same images.
    x = true_imgs.to(device)
    reconst_x = model.reconstruct_img(x)[0]
    _save_image_grid(reconst_x.view(-1, 1, 28, 28),
                     '/root/pyro/MNIST/VAE/pictures/recons.png')
    print(reconst_loss(x, reconst_x))

    # Generations decoded from standard-normal noise. The latent size is
    # taken from the model rather than hard-coded, so it always matches
    # the decoder's input width.
    noise = torch.randn(49, model.z_dim).to(device)
    _save_image_grid(model.decoder(noise),
                     '/root/pyro/MNIST/VAE/pictures/genera.png')


def train(device, svi, model):
    """Train ``model`` with ``svi`` for ``n_epochs`` epochs and log progress.

    After every epoch the ELBO, reconstruction and KL losses are estimated
    on 100 random training and 100 random test images and appended as one
    line to ``loss.txt``. Every 10th epoch the losses and elapsed time are
    also printed and sample image grids are written to the pictures folder.

    Args:
        device: Device that batches are moved to before each step.
        svi: Pyro SVI object providing ``step`` and ``evaluate_loss``.
        model: The VAE; must provide ``reconstruct_img``, ``decoder`` and
            ``z_dim``.
    """
    import os

    each_epoch = 10
    n_samples = 100

    # Make sure the output locations exist before anything is written.
    os.makedirs('/root/pyro/MNIST/VAE/pictures', exist_ok=True)

    begin_time = time()
    with open('/root/pyro/MNIST/VAE/loss.txt', 'w') as file:
        for i in range(n_epochs):
            # Re-enter training mode every epoch: the evaluation below
            # switches to eval mode, which would otherwise persist and
            # freeze the BatchNorm statistics for all later epochs.
            model.train()
            for x, _ in train_loader:
                svi.step(x.to(device))

            report = (i % each_epoch == 0)

            # estimate loss
            model.eval()
            with torch.no_grad():
                fields = []
                for dataset in (train_dataset, test_dataset):
                    loss, loss_recons, loss_kl = _evaluate_subset(
                        svi, model, dataset, n_samples, device)
                    if report:
                        _print_losses(i, loss, loss_recons, loss_kl)
                    fields += [str(loss),
                               str(loss_recons.item()),
                               str(loss_kl.item())]
                # One line per epoch: train triple followed by test triple.
                file.write(' '.join(fields) + '\n')

                # reconstruct and generation
                if report:
                    _save_sample_pictures(device, model)

            # time
            if report:
                training_time = time() - begin_time
                minute = int(training_time // 60)
                second = int(training_time % 60)
                print(f'time loss {minute}:{second}')

        #torch.save(model.state_dict(), '/root/pyro/MNIST/VAE/model.pth')

# In [None]:
def main():
    """Build the VAE, wire it into Pyro SVI and run training.

    The first CUDA device is used when available; otherwise the run falls
    back to the CPU instead of failing on machines without a GPU.
    """
    pyro.clear_param_store()
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    vae = VAE(channels=[1, 500, 500], latent_dim=4).to(device)
    optimizer = Adam(adam_args)
    # Replace the default initialisation of every parameter.
    initialize_parameters(vae)
    svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())
    train(device, svi, vae)


if __name__ == '__main__':
    main()