In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torch.utils.data import Subset

device = "cpu"  # target for torch.load(map_location=...) and for input tensors moved with .to(device)

In [8]:
class VAE(nn.Module):
    """Fully-connected variational autoencoder with a Gaussian latent space.

    Layers (built in ``__init__``):
        encoder: input_size -> hidden_size (tanh), then two linear heads of
                 size ``latent_dim`` giving the posterior mean and log-variance;
        decoder: latent_dim -> hidden_size (tanh) -> input_size (sigmoid).

    The decoder output is reshaped to ``(-1, 1, 28, 28)``, so ``input_size``
    must be 784 (28 * 28) for ``decode``/``forward`` to succeed.

    NOTE: this notebook loads checkpoints with ``torch.load`` and then uses the
    result as a module, so they appear to be whole pickled modules. Unpickling
    resolves this class by name and restores attributes directly, so keep the
    class name and the ``fc1``/``fc21``/``fc22``/``fc3``/``fc4`` attribute
    names stable.
    """

    def __init__(self, input_size, hidden_size, latent_dim) -> None:
        super().__init__()

        # Encoder
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc21 = nn.Linear(hidden_size, latent_dim)  # posterior mean head
        self.fc22 = nn.Linear(hidden_size, latent_dim)  # posterior log-variance head
        # Decoder
        self.fc3 = nn.Linear(latent_dim, hidden_size)
        self.fc4 = nn.Linear(hidden_size, input_size)

    def encode(self, x):
        """Return ``(mean, logvar)`` of q(z|x); ``x`` is flattened from dim 1 onward."""
        hidden = torch.tanh(self.fc1(torch.flatten(x, 1)))
        return self.fc21(hidden), self.fc22(hidden)

    @staticmethod
    def reparameterize(mean, logvar):
        """Sample z = mean + eps * exp(logvar / 2) with eps ~ N(0, I).

        Draws exactly one ``torch.randn_like(logvar)``, so results follow the
        global torch RNG state.
        """
        eps = torch.randn_like(logvar)
        std = torch.exp(logvar / 2)
        return eps * std + mean

    def decode(self, z):
        """Map latent codes ``z`` (batch, latent_dim) to images of shape (batch, 1, 28, 28) in [0, 1]."""
        hidden = torch.tanh(self.fc3(z))
        pixels = torch.sigmoid(self.fc4(hidden))
        return torch.reshape(pixels, (-1, 1, 28, 28))

    def forward(self, x):
        """Encode ``x``, sample one latent per row, and decode it.

        Returns:
            ``(mean, logvar, decoded)``: posterior parameters of shape
            (batch, latent_dim) and the stochastic reconstruction of shape
            (batch, 1, 28, 28).
        """
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        return mean, logvar, self.decode(z)
    
# This notebook uses the loaded object directly as a module (model.fc3, model.fc4),
# so the checkpoint is presumably a whole pickled VAE rather than a state_dict.
# That requires the VAE class above to be defined first. Since PyTorch 2.6,
# torch.load defaults to weights_only=True, which refuses to unpickle full
# modules, so opt out explicitly (the keyword needs torch >= 1.13). Unpickling
# can execute arbitrary code: only do this for checkpoints you produced yourself.
model = torch.load('/root/paper/elbo/mnist_2_model.pth', map_location=device, weights_only=False)
model.eval()  # inference only; the VAE has no dropout/batch-norm, so outputs are unchanged

In [28]:
import numpy as np
from scipy.stats import norm

def produce_z_values(nrows, ncolumes, scenario=1, eps=1e-5, z_max=4.0):
    """Build a regular grid of 2-D latent points for plotting a decoder manifold.

    Points are laid out row-major. The first coordinate increases across
    columns, and the second decreases down rows. Decoding the points in order
    and tiling ``ncolumes`` images per row therefore puts the largest second
    coordinate in the top row.

    Args:
        nrows: number of grid rows (distinct values of the second coordinate).
        ncolumes: number of grid columns (distinct values of the first coordinate).
        scenario: 1 -> coordinates are standard-normal quantiles of
            probabilities evenly spaced in ``[eps, 1 - eps]``, which makes the
            grid denser near the origin. 2 -> coordinates are evenly spaced in
            ``[-z_max, z_max]``.
        eps: tail probability clipped off in scenario 1 (``norm.ppf`` is
            infinite at 0 and 1). Defaults to the original 1e-5.
        z_max: half-width of the linear range in scenario 2. Defaults to the
            original 4.0.

    Returns:
        float64 ndarray of shape ``(nrows * ncolumes, 2)``.

    Raises:
        ValueError: if ``scenario`` is not 1 or 2.
    """
    if scenario == 1:
        axis1 = norm.ppf(np.linspace(eps, 1 - eps, ncolumes))
        axis2 = norm.ppf(np.linspace(1 - eps, eps, nrows))
    elif scenario == 2:
        axis1 = np.linspace(-z_max, z_max, ncolumes)
        axis2 = np.linspace(z_max, -z_max, nrows)
    else:
        raise ValueError('The argument \"scenario\" must be an integer from the set {1, 2}.')
    mat_z1, mat_z2 = np.meshgrid(axis1, axis2)
    return np.column_stack((mat_z1.ravel(), mat_z2.ravel()))

# Manifold of the 2-D model: decode a MANIFOLD_GRID_N x MANIFOLD_GRID_N grid of
# latent points (Gaussian-quantile spacing, scenario 1). Tile the images so
# each one sits at its latent coordinate.
MANIFOLD_GRID_N = 20  # grid side; also the number of images per row in the PNG
assert model.fc3.in_features == 2, 'the manifold grid is 2-D; use the 2-D latent model'
z_grid = torch.from_numpy(
    produce_z_values(nrows=MANIFOLD_GRID_N, ncolumes=MANIFOLD_GRID_N)
).float().to(device)
with torch.no_grad():
    # Same decoder path as VAE.forward: fc3 -> tanh -> fc4 -> sigmoid.
    decoded_pixels = torch.sigmoid(model.fc4(torch.tanh(model.fc3(z_grid))))
    generated_imgs = decoded_pixels.view(-1, 1, 28, 28)
    save_image(generated_imgs, '/root/paper/plots/pictures/manifold.png', nrow=MANIFOLD_GRID_N)

In [30]:
# Load one trained VAE per latent dimensionality. Later cells call these models
# and use their fc3/fc4 layers directly, so the checkpoints are presumably whole
# pickled modules. Loading them needs weights_only=False on PyTorch >= 2.6 (the
# keyword requires torch >= 1.13). Only do this for trusted, self-produced files.
LATENT_DIMS = (2, 5, 10, 20)
models_by_dim = {
    dim: torch.load(f'/root/paper/elbo/mnist_{dim}_model.pth', map_location=device, weights_only=False).eval()
    for dim in LATENT_DIMS
}
# Keep the individual names that later cells use.
model_2D, model_5D, model_10D, model_20D = (models_by_dim[dim] for dim in LATENT_DIMS)
    

In [None]:
# Prior sampling was disabled in this notebook; set to True to regenerate genera_<dim>.png.
GENERATE_SAMPLES = False


@torch.no_grad()
def save_generated_samples(vae, latent_dim, path, n_samples=100, nrow=10):
    """Decode ``n_samples`` draws from the N(0, I) prior and save them as an image grid.

    Args:
        vae: trained model exposing the decoder layers ``fc3`` and ``fc4``.
        latent_dim: size of the latent space (must match ``vae.fc3``'s input size).
        path: output image path passed to ``save_image``.
        n_samples: number of prior samples to decode.
        nrow: number of images per row in the saved grid.
    """
    noise = torch.randn(n_samples, latent_dim).to(device)
    # Same decoder path as VAE.forward: fc3 -> tanh -> fc4 -> sigmoid.
    pixels = torch.sigmoid(vae.fc4(torch.tanh(vae.fc3(noise))))
    save_image(pixels.view(-1, 1, 28, 28), path, nrow=nrow)


if GENERATE_SAMPLES:
    for dim, vae in ((2, model_2D), (5, model_5D), (10, model_10D), (20, model_20D)):
        save_generated_samples(vae, dim, f'/root/paper/plots/pictures/genera_{dim}.png')
    


In [35]:

bs = 100
RECON_SEED = 0  # fixes both the shuffled test batch and the sampled latents, so figures are reproducible

# MNIST Dataset (set download=True on a fresh machine to fetch it into /root/data/)
train_dataset = datasets.MNIST(root='/root/data/', train=True, transform=transforms.ToTensor(), download=False)
test_dataset = datasets.MNIST(root='/root/data/', train=False, transform=transforms.ToTensor(), download=False)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=True)

torch.manual_seed(RECON_SEED)
with torch.no_grad():
    # Reconstruction: take one (shuffled) test batch and save the originals.
    batch_x, _ = next(iter(test_loader))
    true_imgs = batch_x.view(-1, 1, 28, 28)
    save_image(true_imgs, '/root/paper/plots/pictures/oring.png', nrow=10)

    x = true_imgs.to(device).view(-1, 784)

    # vae(x) samples z with the reparameterization trick, so these are stochastic
    # reconstructions. Decode the posterior mean instead for deterministic ones.
    for dim, vae in ((2, model_2D), (5, model_5D), (10, model_10D), (20, model_20D)):
        reconst_x = vae(x)[-1]
        reconst_imgs = reconst_x.view(-1, 1, 28, 28)
        save_image(reconst_imgs, f'/root/paper/plots/pictures/recons_{dim}.png', nrow=10)
    