In [1]:
import torch
import torch.nn as nn
import torchvision
from utils import make_dataloaders
import matplotlib.pyplot as plt
from models.vaes import VAE, IWAE, AIS_VAE, AIWAE
import yaml
import numpy as np
%matplotlib widget

In [2]:
from models.samplers import HMC, MALA, run_chain

In [None]:
# Run on the first GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Fix RNG seeds so the stochastic cells below (random chain initialisation,
# MCMC steps, target sampling and — presumably — loader shuffling) are
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
# Integrator settings passed to the HMC kernel in the next cell.
n_leapfrogs, step_size = 10, 0.05

In [5]:
# Transition kernel used for both the toy check and the latent interpolations.
# partial_ref / use_barker presumably select partial momentum refreshment and
# Barker's acceptance rule — see models.samplers.HMC to confirm.
hmc = HMC(
    n_leapfrogs=n_leapfrogs,
    step_size=step_size,
    partial_ref=True,
    use_barker=True,
).to(device)
# Alternative kernel to try: MALA(step_size, use_barker=True, learnable=False)

In [6]:
# Toy target for the sampler sanity check: independent unit-scale Gaussians
# centred at (10, 10), far from the N(0, I) chain initialisation.
target = torch.distributions.Normal(
    loc=torch.full((2,), 10., device=device),
    scale=torch.ones(2, device=device),
)

In [7]:
def target_density(z, x):
    """Log-density of the toy Gaussian ``target`` at ``z``, summed over the last dim.

    ``x`` is accepted only to match the ``target(z, x)`` signature used with
    ``run_chain`` in this notebook; it is ignored.
    """
    return target.log_prob(z).sum(-1)

In [8]:
# 1000 kernel steps from a single N(0, I) start; keep the full trace for plotting.
samples = run_chain(
    kernel=hmc,
    z_init=torch.randn((1, 2), device=device),
    target=target_density,
    return_trace=True,
    n_steps=1000,
)

In [9]:
samples.size()  # one row per chain step: (n_steps, 2)

torch.Size([1000, 2])

In [10]:
# Exact draws from the target, used as the reference cloud in the scatter plot.
target_samples = target.sample(torch.Size([1000]))

In [11]:
# Compare the HMC trace with exact samples; the clouds should overlap if the
# chain has mixed. The legend is needed for the `label=`s to be shown.
plt.close()
plt.figure()
plt.scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), label='samples', alpha=0.5)
plt.scatter(target_samples[:, 0].cpu(), target_samples[:, 1].cpu(), label='target', alpha=0.5)
plt.xlabel('$z_1$')
plt.ylabel('$z_2$')
plt.title('HMC chain vs. exact target samples')
plt.legend()
plt.tight_layout()
plt.show();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [18]:
# Small batches are enough: only one training batch is used, to pick two images.
train_loader, val_loader = make_dataloaders(
    dataset='celeba',
    batch_size=10,
    val_batch_size=10,
    binarize=False,
)

In [19]:
# Lightning run to restore: lightning_logs/default/version_<version>,
# checkpoint file epoch=<epoch>.ckpt.
version, epoch = 2, 29

In [20]:
hparams_path = f'lightning_logs/default/version_{version}/hparams.yaml'
with open(hparams_path) as file:
    # yaml.FullLoader rather than yaml.safe_load: the file stores Python objects
    # by name (the output shows `act_func` restored as a torch class), which
    # safe_load refuses. Only load hparams files you produced yourself.
    hparams = yaml.load(file, Loader=yaml.FullLoader)
print(hparams)

{'act_func': <class 'torch.nn.modules.activation.LeakyReLU'>, 'dataset': 'celeba', 'hidden_dim': 64, 'name': 'VAE', 'net_type': 'conv', 'num_samples': 1, 'shape': 64}


In [21]:
model = VAE(**hparams).to(device)
# map_location puts the stored tensors on `device`, so a checkpoint written on
# a GPU can also be restored on a CPU-only machine.
# NOTE(review): torch>=2.6 defaults to weights_only=True; if this checkpoint
# stores hparams objects (e.g. the act_func class), pass weights_only=False
# for trusted checkpoints.
checkpoint = torch.load(
    f'lightning_logs/default/version_{version}/checkpoints/epoch={epoch}.ckpt',
    map_location=device,
)
# The model is only used for inference below; eval() disables dropout /
# batch-norm statistic updates if the architecture has any.
model.eval()
model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [22]:
# Take a single training batch; only the images are needed, labels are dropped.
# Raises StopIteration on an empty loader instead of silently leaving `x` stale.
x, _ = next(iter(train_loader))

In [23]:
plt.close()
plt.figure()
# First image of the batch, kept as a batch of one: the interpolation start.
obj1 = x[0][None].to(device)
if obj1.shape[1] == 1:
    # Single-channel image: show it in grayscale (as `visualize` does) rather
    # than with the default colormap.
    plt.imshow(obj1[0][0].cpu(), cmap='gray')
else:
    # imshow expects channels last: (C, H, W) -> (H, W, C).
    plt.imshow(obj1[0].permute((1, 2, 0)).cpu())
plt.title('obj1 (interpolation start)')
plt.axis('off')
plt.tight_layout()
plt.show();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [24]:
plt.close()
plt.figure()
# Last image of the batch, kept as a batch of one: the interpolation end.
obj2 = x[-1][None].to(device)
if obj2.shape[1] == 1:
    # Single-channel image: show it in grayscale (as `visualize` does) rather
    # than with the default colormap.
    plt.imshow(obj2[0][0].cpu(), cmap='gray')
else:
    # imshow expects channels last: (C, H, W) -> (H, W, C).
    plt.imshow(obj2[0].permute((1, 2, 0)).cpu())
plt.title('obj2 (interpolation end)')
plt.axis('off')
plt.tight_layout()
plt.show();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [30]:
def interpolate_annealing(model, obj1, obj2, kernel, T=10, n_steps=10):
    """Interpolate from ``obj1`` to ``obj2`` with an annealed MCMC chain in latent space.

    The chain starts at the encoder output for ``obj1`` (first element returned
    by ``model.enc_rep``, averaged over dim 0 and kept as a batch of one). For
    each ``t`` in ``np.linspace(0, 1, T)`` it runs ``n_steps`` of ``kernel`` on

        (1 - t) * joint(z, obj1) + t * joint(z, obj2),

    where ``joint = model.joint_density()`` (presumably the log joint
    log p(x, z) — confirm). Each temperature warm-starts from the previous
    state, and the state reached there is recorded.

    Args:
        model: VAE-like model exposing ``enc_rep(x)`` and ``joint_density()``.
        obj1: start of the interpolation, fed to the model as ``x``.
        obj2: end of the interpolation, fed to the model as ``x``.
        kernel: MCMC transition kernel forwarded to ``run_chain``.
        T: number of interpolation points (temperatures).
        n_steps: MCMC steps run at each temperature.

    Returns:
        The ``T`` chain states concatenated along dim 0, one per temperature.
        The last one is sampled at t=1, i.e. targets ``obj2``. An empty float32
        tensor is returned when ``T == 0``.
    """
    # NOTE(review): HMC/MALA need gradients of the target, yet this runs under
    # no_grad; it relies on run_chain re-enabling autograd internally — confirm.
    with torch.no_grad():
        joint = model.joint_density()  # loop-invariant: build it once

        def path_target(t):
            # Targets given to run_chain in this notebook take (z, x); x is
            # unused because the endpoints are fixed to obj1 / obj2.
            def target(z, x):
                return (1 - t) * joint(z=z, x=obj1) + t * joint(z=z, x=obj2)
            return target

        z_current = torch.mean(model.enc_rep(obj1)[0], 0)[None]
        frames = []
        for t in np.linspace(0., 1., T):
            z_current = run_chain(kernel=kernel, z_init=z_current, target=path_target(float(t)),
                                  return_trace=False, n_steps=n_steps)
            # Record the state *reached* at this temperature, so the final
            # t=1 run (toward obj2) ends up in the output.
            frames.append(z_current)
        if not frames:
            return torch.tensor([], dtype=torch.float32, device=obj1.device)
        # Concatenate once instead of growing a tensor inside the loop.
        return torch.cat(frames)

def interpolate_mixture(model, obj1, obj2, kernel, T=10, n_steps=10):
    """Interpolate from ``obj1`` to ``obj2`` by MCMC on a moving two-component mixture.

    The chain starts at the encoder output for ``obj1`` (first element returned
    by ``model.enc_rep``, averaged over dim 0 and kept as a batch of one). For
    each ``t`` in ``np.linspace(0.01, 0.99, T)`` it runs ``n_steps`` of
    ``kernel`` on the log of the mixture

        log((1 - t) * exp(joint(z, obj1)) + t * exp(joint(z, obj2))),

    where ``joint = model.joint_density()`` (presumably the log joint
    log p(x, z) — confirm). Each weight warm-starts from the previous state,
    and the state reached there is recorded. The endpoints stop short of 0 and
    1 so that neither log-weight is log(0).

    Args:
        model: VAE-like model exposing ``enc_rep(x)`` and ``joint_density()``.
        obj1: first mixture component's input, fed to the model as ``x``.
        obj2: second mixture component's input, fed to the model as ``x``.
        kernel: MCMC transition kernel forwarded to ``run_chain``.
        T: number of interpolation points (mixture weights).
        n_steps: MCMC steps run at each weight.

    Returns:
        The ``T`` chain states concatenated along dim 0, one per mixture
        weight. The last one is sampled at t=0.99. An empty float32 tensor is
        returned when ``T == 0``.
    """
    # NOTE(review): HMC/MALA need gradients of the target, yet this runs under
    # no_grad; it relies on run_chain re-enabling autograd internally — confirm.
    with torch.no_grad():
        joint = model.joint_density()  # loop-invariant: build it once

        def mixture_target(t):
            log_w1, log_w2 = float(np.log(1 - t)), float(np.log(t))

            # Targets given to run_chain in this notebook take (z, x); x is
            # unused because the components are fixed to obj1 / obj2.
            def target(z, x):
                return torch.logsumexp(torch.stack([log_w1 + joint(z=z, x=obj1),
                                                    log_w2 + joint(z=z, x=obj2)]), dim=0)
            return target

        z_current = torch.mean(model.enc_rep(obj1)[0], 0)[None]
        frames = []
        for t in np.linspace(0.01, 0.99, T):
            z_current = run_chain(kernel=kernel, z_init=z_current, target=mixture_target(float(t)),
                                  return_trace=False, n_steps=n_steps)
            # Record the state *reached* at this weight, so the final run
            # (closest to obj2) ends up in the output.
            frames.append(z_current)
        if not frames:
            return torch.tensor([], dtype=torch.float32, device=obj1.device)
        # Concatenate once instead of growing a tensor inside the loop.
        return torch.cat(frames)

def interpolate_linear(model, obj1, obj2, T=10):
    """Straight-line interpolation between the encodings of two inputs.

    Each endpoint is the first element returned by ``model.enc_rep(...)``,
    averaged over dim 0 and kept as a batch of one.

    Returns:
        ``(1 - t) * z_1 + t * z_2`` for every ``t`` in ``np.linspace(0, 1, T)``,
        concatenated along dim 0. An empty float32 tensor is returned when
        ``T == 0``.
    """
    with torch.no_grad():
        start = torch.mean(model.enc_rep(obj1)[0], 0)[None]
        end = torch.mean(model.enc_rep(obj2)[0], 0)[None]
        points = [(1 - t) * start + t * end for t in np.linspace(0., 1., T)]
        if not points:
            return torch.tensor([], dtype=torch.float32, device=obj1.device)
        return torch.cat(points)

def visualize(model, z, shape=(-1, 1, 28, 28)):
    """Decode latent codes and show the resulting images as a single grid.

    Args:
        model: callable mapping ``z`` to decoder outputs. A sigmoid is applied
            here, so these are presumably logits.
        z: batch of latent codes, one image per row.
        shape: ``view`` shape for the decoded batch, ``(-1, C, H, W)``.
            Single-channel images are drawn in grayscale.
    """
    with torch.no_grad():
        images = model(z).sigmoid().view(shape).cpu()
        is_grayscale = shape[1] == 1
        plt.close()
        plt.figure()
        grid = torchvision.utils.make_grid(images, nrow=15)
        if is_grayscale:
            # make_grid replicates a single channel to three; average back to 2-D.
            plt.imshow(grid.mean(0), 'gray')
        else:
            # imshow expects channels last.
            plt.imshow(grid.permute((1, 2, 0)))
        plt.tight_layout()
        plt.show();

In [26]:
# Annealed-MCMC path: 10 frames, 20 HMC steps per temperature, decoded as
# 3x64x64 images (hparams['shape'] == 64).
all_z = interpolate_annealing(model, obj1, obj2, kernel=hmc, T=10, n_steps=20)
visualize(model, all_z, shape=(-1, 3, 64, 64))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [27]:
# Baseline: straight line between the two encodings, 10 frames.
all_z = interpolate_linear(model, obj1, obj2, T=10)
visualize(model, all_z, shape=(-1, 3, 64, 64))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [33]:
# Mixture-target path: 10 frames, 30 HMC steps per mixture weight.
all_z = interpolate_mixture(model, obj1, obj2, kernel=hmc, T=10, n_steps=30)
visualize(model, all_z, shape=(-1, 3, 64, 64))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …