In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib widget

In [2]:
# NLL values recorded during training for the ULA VAE. The name suffix is the
# number of transitions K (it is used as the x value in the 'Transitions' plot).
nll_ula_1 = [
    88.78, 86.90, 86.41, 86.23, 86.23,
    86.13, 86.06, 86.19, 86.07, 85.91,
]
nll_ula_3 = [
    88.21, 86.57, 86.05, 85.86, 85.94,
    85.80, 85.69, 85.79, 85.79, 85.61,
]
nll_ula_5 = [
    87.67, 85.89, 85.46, 85.30, 85.24,
    85.27, 85.29, 85.05, 85.18, 85.02,
]
nll_ula_7 = [
    88.24, 86.44, 85.70, 85.63, 85.61,
    85.50, 85.44, 85.46, 85.38, 85.48,
]
nll_ula_10 = [
    87.35, 85.76, 85.28, 85.14, 84.83,
    85.04, 84.86, 84.86, 84.80, 84.85,
]

In [3]:
# Baseline runs: IWAE (suffix 10, labelled K=10 in the plots) and a plain VAE.
# NOTE(review): the meaning of the trailing "84.52" note on the IWAE run is not
# recorded here (perhaps a separate final evaluation) — confirm.
nll_iwae_10 = [
    88.18, 86.61, 85.94, 85.50, 85.24,
    85.11, 85.12, 84.88, 85.23, 84.95,
]  # 84.52
nll_vae = [
    89.01, 87.26, 86.49, 85.85, 85.58,
    85.74, 85.33, 85.29, 85.10, 85.14,
]

In [10]:
# Final-epoch NLL of the ULA VAE, keyed by the number of transitions K.
# Every K with recorded results is plotted (K=7 included), so each x-tick has a
# matching data point.
final_nll_ula = {
    1: nll_ula_1[-1],
    3: nll_ula_3[-1],
    5: nll_ula_5[-1],
    7: nll_ula_7[-1],
    10: nll_ula_10[-1],
}
transitions = sorted(final_nll_ula)

plt.close()
plt.figure(figsize=(5, 3), dpi=200)
# Log likelihood is plotted as the negated NLL.
plt.plot(transitions, -np.array([final_nll_ula[k] for k in transitions]), '-o', label='ULA VAE')
# VAE baseline (disabled): plt.axhline(-nll_vae[-1], linestyle='--', color='red', label='VAE')
plt.axhline(-nll_iwae_10[-1], linestyle='--', color='green', label='IWAE (K=10)')
plt.xlabel('Transitions')
plt.ylabel('Log likelihood')
plt.xlim(1, 10)
plt.xticks(transitions)
plt.grid()
plt.legend()
plt.tight_layout()
plt.savefig('nll.pdf', format='pdf')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [8]:
# Log likelihood (negated NLL) over training for the three models. The x
# position of the i-th recorded value is 10 * i. color=None leaves the ULA curve
# on matplotlib's default colour cycle.
epochs = 10 * np.arange(10)
plt.close()
plt.figure(figsize=(5, 3), dpi=200)
for nll_curve, curve_label, curve_color in [
    (nll_ula_10, 'ULA VAE (K=10)', None),
    (nll_vae, 'VAE', 'red'),
    (nll_iwae_10, 'IWAE (K=10)', 'green'),
]:
    plt.plot(epochs, -np.array(nll_curve), '-o', label=curve_label, color=curve_color)
plt.xlabel('Epoch')
plt.ylabel('Log likelihood')
plt.grid()
plt.legend()
plt.tight_layout();
plt.savefig('nll_training.pdf', format='pdf')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Experiments with models that have a 2-D latent space

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from utils import make_dataloaders
from models.vaes import Base, VAE, IWAE, AIS_VAE, AIWAE, ULA_VAE, Stacked_VAE, VAE_with_flows
from models.samplers import HMC, MALA, ULA, run_chain
import yaml
import numpy as np
from scipy.stats import norm

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

%matplotlib widget

# Scatter colour for the i-th object in plot_posterior_samples (keyed by position 0..5).
colors = dict(enumerate(['blue', 'red', 'green', 'yellow', 'black', 'orange']))

In [2]:
# Prefer the second GPU (index 1) when there is one, then the first GPU, then the
# CPU. A hard-coded 'cuda:1' only fails later, at the first `.to(device)`, on a
# machine with a single GPU.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [3]:
# Binarized MNIST loaders with batches of 100 for both training and validation.
train_loader, val_loader = make_dataloaders(
    dataset='mnist',
    binarize=True,
    batch_size=100,
    val_batch_size=100,
)



In [4]:
version: int = 292  # checkpoint version to load; 292 = AIS_VAE without alpha annealing

In [5]:
# Checkpoint versions under lightning_logs/default/:
# 268 - IWAE
# 269 - VAE
# 292/293 - AIS_VAE without/with alpha annealing

In [6]:
def load_model(version):
    '''
    Load a trained model from lightning_logs/default/version_{version}/.

    Reads and prints the hyperparameters in hparams.yaml, loads a checkpoint from
    the version's checkpoints/ directory, and returns the first class among
    VAE, IWAE, ULA_VAE, AIS_VAE that can be built from those hyperparameters
    and accepts the checkpoint's state_dict. The model is moved to `device`.

    NOTE(review): model.eval() is not called, so the model stays in whatever
    mode its constructor leaves it — confirm whether that matters for these
    architectures.

    Raises:
        FileNotFoundError: if the checkpoints directory is empty.
        RuntimeError: if no candidate class can load the checkpoint.
    '''
    log_dir = f'lightning_logs/default/version_{version}'
    with open(f'{log_dir}/hparams.yaml') as file:
        # The stored hparams include Python objects (e.g. the act_func class),
        # which the safe loader cannot construct. Only load files you trust.
        hparams = yaml.load(file, Loader=yaml.FullLoader)
    print(hparams)

    path = f'{log_dir}/checkpoints/'
    # Sort so the choice is deterministic; os.listdir order is arbitrary.
    checkpoint_files = sorted(os.listdir(path))
    if not checkpoint_files:
        raise FileNotFoundError(f'no checkpoint files found in {path}')
    # map_location lets a checkpoint saved on one GPU load onto `device` (or CPU).
    checkpoint = torch.load(f'{path}{checkpoint_files[0]}', map_location=device)

    errors = {}
    for current_model in [VAE, IWAE, ULA_VAE, AIS_VAE]:
        try:
            model = current_model(**hparams).to(device)
            model.load_state_dict(checkpoint['state_dict'])
        except Exception as exc:  # this class does not match the checkpoint; try the next
            errors[current_model.__name__] = exc
        else:
            print(f'loaded {model.name}')
            return model
    raise RuntimeError(f'could not load version {version} with any model class: {errors}')

In [7]:
model = load_model(version)  # checkpoint selected by `version`

{'K': 1, 'acceptance_rate_target': 0.8, 'act_func': <class 'torch.nn.modules.activation.GELU'>, 'beta': None, 'dataset': 'mnist', 'grad_clip_val': 0.0, 'grad_skip_val': 0.0, 'hidden_dim': 2, 'learnable_transitions': False, 'name': 'AIS_VAE', 'net_type': 'conv', 'num_samples': 5, 'shape': 28, 'step_size': 0.01, 'use_alpha_annealing': False, 'use_barker': False, 'use_cloned_decoder': False, 'variance_sensitive_step': True}
loaded AIS_VAE


In [8]:
# Fixed latent codes for the "generated" panel of plot_digit_samples. A dedicated,
# seeded generator makes these images reproducible across kernel restarts without
# consuming the global torch RNG, which other cells may rely on.
random_vector = torch.randn(
    (10, model.hidden_dim), generator=torch.Generator().manual_seed(0)
).to(device)

def reconstruct_image(model, pics, num_samples=50):
    '''
    Reconstruct images by decoding one posterior sample per image.

    The images are passed to get_posterior_samples, which also runs the
    transitions for ULA_VAE / AIS_VAE. The sample is then decoded and passed
    through a sigmoid. Returns a CPU tensor.

    NOTE(review): `num_samples` is currently ignored. Exactly one posterior
    sample is drawn per image whatever its value.
    '''
    with torch.no_grad():
        latent = torch.tensor(
            get_posterior_samples(model, pics.to(device), n_samples=1),
            device=device,
            dtype=torch.float32,
        )
        decoded = model.decode(latent)
        reconstruction = torch.sigmoid(decoded).cpu()
    return reconstruction

## Function to generate images from latent vectors:
def generate_image(model, random_vector):
    '''
    Decode latent vectors into images (sigmoid of model.decode), without
    tracking gradients. Returns a CPU tensor.
    '''
    with torch.no_grad():
        decoded = model.decode(random_vector)
        images = decoded.sigmoid().cpu()
    return images


def plot_digit_samples(original, reconstucted, generated):
    """
    Plot original, reconstructed and generated digits side by side.

    Each argument is a sequence of images that can be assigned into a 28x28
    numpy slice (e.g. tensors of shape [1, 28, 28], as produced by slicing a
    batch). Each sequence is tiled row-major into a grid 5 images wide and
    shown in its own panel. The grid has 2 rows, plus as many extra rows as
    needed for sequences longer than 10 images.
    """
    grid_w = 5
    data_h = 28
    data_w = 28
    images_list = [list(images) for images in (original, reconstucted, generated)]
    names = ['original', 'reconstructed', 'generated']
    # Keep the 2-row layout, but grow it instead of failing on an empty slice
    # when more than 10 images are passed.
    longest = max(len(images) for images in images_list)
    grid_h = max(2, -(-longest // grid_w))

    plt.close()
    fig, ax = plt.subplots(ncols=3, figsize=(5, 3), dpi=200)
    for pos, (name, images) in enumerate(zip(names, images_list)):
        # Turn the samples into one large image
        tiled_img = np.zeros((data_h * grid_h, data_w * grid_w))
        for idx, image in enumerate(images):
            row, col = divmod(idx, grid_w)
            tiled_img[row * data_h:(row + 1) * data_h,
                      col * data_w:(col + 1) * data_w] = image

        ax[pos].set_title(name)
        ax[pos].axis('off')
        ax[pos].imshow(tiled_img, cmap='gray')
    plt.tight_layout()
    plt.show()


def get_posterior_samples(model, X, n_samples=1000):
    '''
    Draw posterior samples for each object in X.

    For every x in X, `model.enc_rep` draws `n_samples` latent samples. For
    ULA_VAE / AIS_VAE those samples are then passed through
    `model.run_transitions`. The per-object sample arrays are stacked along a
    new leading axis. Presumably this gives [n_objects, n_samples, dims], as
    plot_posterior_samples expects — confirm against enc_rep's output shape.
    An empty X returns an empty array of shape (0,).
    '''
    per_object = []
    with torch.no_grad():
        for x in X:
            x = x[None].to(device)
            model_samples, mu, logvar = model.enc_rep(x=x, n_samples=n_samples)
            if model.name in ['ULA_VAE', 'AIS_VAE']:
                # Refine the encoder samples with the model's transitions; the
                # image is repeated so each sample has a matching copy.
                model_samples = model.run_transitions(
                    z=model_samples, x=x.repeat(n_samples, 1, 1, 1), mu=mu, logvar=logvar
                )[0]
            per_object.append(model_samples.cpu().numpy())
    if not per_object:
        return np.array([])
    # Stack once at the end instead of re-concatenating every iteration (quadratic copying).
    return np.stack(per_object)

def plot_image(x):
    '''
    Plot a single image tensor x laid out as [channels, height, width].

    A single-channel image is drawn as a 2-D grayscale array, matching the
    other plotting helpers. A multi-channel image is permuted to
    [height, width, channels] for imshow.
    '''
    plt.close()
    plt.figure()
    # The channel axis of a single [C, H, W] image (e.g. batch[7]) is dim 0;
    # dim 1 is the height.
    if x.shape[0] == 1:
        plt.imshow(x[0].cpu(), cmap='gray')
    else:
        plt.imshow(x.permute((1, 2, 0)).cpu())
    plt.tight_layout()
    plt.show()
    
def plot_posterior_samples(samples, labels=None):
    '''
    Scatter-plot posterior samples, one colour per object.

    samples: array of shape [n_objects, n_samples, dims]; only the first two
        dims are plotted.
    labels: optional sequence of scalar tensors (e.g. from form_objects) used
        as legend entries. When omitted, no legend is drawn.
    Colours come from the module-level `colors` palette and wrap around when
    there are more objects than colours.
    '''
    plt.close()
    plt.figure()
    for i, object_samples in enumerate(samples):
        label = labels[i].item() if labels is not None else None
        plt.scatter(object_samples[:, 0], object_samples[:, 1],
                    c=colors[i % len(colors)], label=label)
    plt.axis('equal')
    if labels is not None:
        plt.legend()
    plt.tight_layout()
    plt.show()
    
def form_objects(indices, images=None, targets=None):
    '''
    Collect (objects, labels) for the given indices.

    indices: a single index, or a list/tuple/range of indices.
    images, targets: tensors to index. They default to the notebook-level
        `batch` and `labels`, so existing calls work unchanged; passing them
        explicitly removes the dependency on cell execution order.
    Returns the selected images stacked along a new first axis and the
    selected labels as a 1-D tensor. The labels keep their original dtype;
    they are no longer promoted to float. An empty index list returns two
    empty tensors.
    '''
    if images is None:
        images = batch
    if targets is None:
        targets = labels
    if isinstance(indices, (list, tuple, range)):
        index_list = list(indices)
    else:
        index_list = [indices]
    if not index_list:
        return torch.tensor([]), torch.tensor([])
    formed_objects = torch.stack([images[i] for i in index_list])
    formed_labels = torch.stack([targets[i] for i in index_list])
    return formed_objects, formed_labels

def plot_pics_manifold(model, n=15):
    '''
    Plot the decoder's images over an n x n grid of 2-D latent points.

    Both latent coordinates take the standard-normal quantiles of n evenly
    spaced probabilities in [0.05, 0.95]. Each grid point is decoded (sigmoid
    of model.decode) into a 28x28 image and tiled into one mosaic. Row i uses
    the i-th quantile for the second latent coordinate; column j uses the
    j-th quantile for the first.
    '''
    image_size = 28
    figure = np.zeros((image_size * n, image_size * n))
    # The same quantile grid is used on both axes.
    grid_x = norm.ppf(np.linspace(0.05, 0.95, n)).astype(np.float32)
    grid_y = grid_x.copy()

    with torch.no_grad():
        for i, yi in enumerate(grid_y):
            for j, xi in enumerate(grid_x):
                z_sample = torch.tensor(np.array([[xi, yi]]), dtype=torch.float32, device=model.device)
                x_decoded = torch.sigmoid(model.decode(z_sample)).cpu().numpy()
                figure[i * image_size:(i + 1) * image_size,
                       j * image_size:(j + 1) * image_size] = x_decoded[0].reshape(image_size, image_size)

    # Close the current figure first, as the other plotting helpers do, so
    # repeated calls do not accumulate open figures.
    plt.close()
    plt.figure(figsize=(10, 10))
    plt.imshow(figure, cmap='Greys_r')
    plt.tight_layout()
    plt.show()

In [9]:
# Take the first validation batch (images and their MNIST digit labels) for the
# experiments below. On an empty loader next(iter(...)) raises StopIteration,
# instead of silently leaving `batch`/`labels` undefined.
batch, labels = next(iter(val_loader))

In [17]:
plot_image(x=batch[7])  # one validation image

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [18]:
formed_objects, formed_labels = form_objects(indices=[4, 7, 19])  # three validation objects

In [19]:
posterior_samples = get_posterior_samples(model=model, X=formed_objects)

In [20]:
plot_posterior_samples(samples=posterior_samples, labels=formed_labels)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [14]:
plot_pics_manifold(model=model, n=50)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
batch.size()

torch.Size([100, 1, 28, 28])

In [16]:
# Compare ten validation images, their reconstructions, and images decoded from random_vector.
reconstructed = reconstruct_image(model=model, pics=batch[:10], num_samples=1)
generated = generate_image(model=model, random_vector=random_vector)
plot_digit_samples(batch[:10], reconstructed, generated)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …