In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from models.vaes import VAE, IWAE, AIS_VAE, AIWAE
import yaml
import numpy as np

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

%matplotlib widget

In [2]:
# Prefer the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def generate_dataset(N, eps0, d=2, sigma=1.):
    """Sample ``N`` points of the toy quadratic-latent dataset.

    Each point is ``eps0 * z**2 + sigma * noise`` with ``z`` and ``noise``
    drawn i.i.d. standard normal of shape ``(N, d)`` from numpy's global RNG
    (latent first, then noise).

    Args:
        N: number of samples.
        eps0: scale multiplied element-wise with ``z**2`` (broadcast against
            ``(N, d)``; the notebook passes a ``(1, d)`` array).
        d: dimensionality of each sample.
        sigma: standard deviation of the additive Gaussian noise.

    Returns:
        numpy array of shape ``(N, d)``.
    """
    latent = np.random.randn(N, d)
    noise = np.random.randn(N, d) * sigma
    return eps0 * np.square(latent) + noise

In [None]:
class VAE_Toy(VAE):
    """VAE variant for the low-dimensional, real-valued toy dataset.

    Replaces the base training loss with a squared-error reconstruction term
    plus the analytic Gaussian KL, and exposes an unnormalised joint
    log-density for MCMC samplers.
    """

    def loss_function(self, recon_x, x, mu, logvar):
        """Return the training loss ``MSE + KLD``.

        Args:
            recon_x: decoder output; viewed to ``(mu.shape[0], -1)``.
            x: targets; also viewed to ``(mu.shape[0], -1)``, so presumably
                already repeated ``self.num_samples`` times along dim 0 by the
                base class — TODO confirm.
            mu, logvar: posterior parameters with ``mu.shape[0]`` rows and
                ``self.hidden_dim`` columns.

        Returns:
            Scalar tensor: per-example summed squared error (averaged over the
            ``num_samples`` draws, then over the batch) plus the analogous
            KL(q(z|x) || N(0, I)) term.
        """
        batch_size = mu.shape[0] // self.num_samples
        per_dim_se = F.mse_loss(recon_x.view(mu.shape[0], -1), x.view(mu.shape[0], -1),
                                reduction='none')
        MSE = per_dim_se.view((self.num_samples, batch_size, -1)).mean(0).sum(-1).mean()
        KLD = -0.5 * torch.mean((1 + logvar - mu.pow(2) - logvar.exp()).view(
            (self.num_samples, -1, self.hidden_dim)).mean(0).sum(-1))
        # Real-valued toy data uses the squared-error term computed above; the
        # image-VAE name ``BCE`` does not exist in this method.
        loss = MSE + KLD
        return loss

    def joint_logdensity(self):
        """Build a closure ``density(z, x)`` giving an unnormalised log p(x, z).

        The returned function decodes ``z`` with ``self(z)`` and returns, per
        row, ``-sum((decoded - x)**2) + sum(log N(z; 0, 1))``.
        """
        def density(z, x):
            z = z.clone()
            x_reconst = self(z)
            # Standard-normal prior over the latent code, summed over latent dims.
            log_Pr = torch.distributions.Normal(loc=torch.tensor(0., device=x.device, dtype=torch.float32),
                                                scale=torch.tensor(1., device=x.device, dtype=torch.float32)).log_prob(
                z).sum(-1)
            # Squared-error "log-likelihood", matching the MSE term of loss_function.
            return -F.mse_loss(x_reconst.view(x_reconst.shape[0], -1),
                               x.view(x_reconst.shape[0], -1), reduction='none').sum(-1) + log_Pr

        return density

In [None]:
class ToyDataset(Dataset):
    """Map-style dataset over an in-memory collection of toy samples.

    Each item is ``(sample, -1.)``: the sample as a float32 tensor plus a dummy
    label, so batches unpack as ``(x, y)`` like image datasets do.

    Args:
        data: indexable collection of samples (the notebook passes an
            ``(N, d)`` numpy array).
        device: device for the returned tensors. ``None`` (default) keeps the
            original behaviour of using the notebook-global ``device``.
    """

    def __init__(self, data, device=None):
        super(ToyDataset, self).__init__()
        self.data = data
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        # Fall back to the global ``device`` so existing call sites are unchanged.
        # Pass device='cpu' explicitly when using DataLoader workers or
        # pin_memory, which do not work with CUDA tensors created here.
        target_device = self.device if self.device is not None else device
        sample = torch.tensor(self.data[item], dtype=torch.float32, device=target_device)
        return sample, -1.

In [None]:
class ToyEncoder(nn.Module):
    """Two-layer MLP encoder for the toy data: ``d -> d -> 2*d``.

    The ``2*d`` outputs are presumably split into mean and log-variance by
    the base VAE — TODO confirm against ``models.vaes``.
    """

    def __init__(self, d):
        super().__init__()
        layers = [
            nn.Linear(d, d),
            nn.LeakyReLU(),
            nn.Linear(d, 2 * d),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
    
class ToyDecoder(nn.Module):
    """Two-layer MLP decoder for the toy data: ``d -> d -> d``."""

    def __init__(self, d):
        super().__init__()
        layers = [
            nn.Linear(d, d),
            nn.LeakyReLU(),
            nn.Linear(d, d),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

In [None]:
# Toy data: x = eps * z**2 + sigma * noise, with a random per-dimension scale eps.
# Fixed seeds make eps, both splits, and the torch initialisation done in later
# cells reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

N = 10000
d = 2
sigma = 1
eps = np.random.randn(1, d)
X_train = generate_dataset(N=N, eps0=eps, d=d, sigma=sigma)
X_val = generate_dataset(N=N // 100, eps0=eps, d=d, sigma=sigma)

In [8]:
# Histogram of each coordinate of the training data; with a 2-D array plt.hist
# draws one histogram per column, so label them to tell the dimensions apart.
plt.close()
plt.figure()
plt.title('True data')
plt.hist(x=X_train, bins=100, label=[f'dim {i}' for i in range(X_train.shape[1])])
plt.xlabel('value')
plt.ylabel('count')
plt.legend()
plt.show();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
# Wrap both numpy splits so DataLoader can index and batch them.
train_dataset, val_dataset = ToyDataset(data=X_train), ToyDataset(data=X_val)

In [10]:
# Shuffle only the training split; validation order stays fixed between epochs.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=32)

In [11]:
# Construct the base model, then replace whatever encoder/decoder the base
# constructor built for net_type='conv' with the small MLPs for d-dim data.
# Statement order is kept so parameter initialisation consumes the RNG identically.
model = VAE_Toy(
    shape=28,
    act_func=nn.LeakyReLU,
    num_samples=1,
    hidden_dim=d,
    net_type='conv',
    dataset='toy',
)
model.encoder_net = ToyEncoder(d=d)
model.decoder_net = ToyDecoder(d=d)
model = model.to(device)

In [12]:
# Train with TensorBoard logging under lightning_logs/.
# `train_dataloader=` matches the Lightning version this notebook was run with
# (later releases renamed the keyword to `train_dataloaders`).
tb_logger = pl_loggers.TensorBoardLogger('lightning_logs/')
trainer = pl.Trainer(
    logger=tb_logger,
    fast_dev_run=False,
    max_epochs=51,
)
trainer.fit(
    model,
    train_dataloader=train_loader,
    val_dataloaders=val_loader,
)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name            | Type       | Params
-----------------------------------------------
0 | encoder_net     | ToyEncoder | 18    
1 | decoder_net     | ToyDecoder | 12    
2 | transitions_nll | ModuleList | 8     


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

> [0;32m<ipython-input-4-ad6d06d0739d>[0m(6)[0;36mloss_function[0;34m()[0m
[0;32m      4 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 6 [0;31m        MSE = F.mse_loss(recon_x.view(mu.shape[0], -1), x.view(mu.shape[0], -1),
[0m[0;32m      7 [0;31m                                                 [0mreduction[0m[0;34m=[0m[0;34m'none'[0m[0;34m)[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      8 [0;31m            (self.num_samples, batch_size, -1)).mean(0).sum(-1).mean()
[0m


ipdb>  n


> [0;32m<ipython-input-4-ad6d06d0739d>[0m(7)[0;36mloss_function[0;34m()[0m
[0;32m      5 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      6 [0;31m        MSE = F.mse_loss(recon_x.view(mu.shape[0], -1), x.view(mu.shape[0], -1),
[0m[0;32m----> 7 [0;31m                                                 [0mreduction[0m[0;34m=[0m[0;34m'none'[0m[0;34m)[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      8 [0;31m            (self.num_samples, batch_size, -1)).mean(0).sum(-1).mean()
[0m[0;32m      9 [0;31m        KLD = -0.5 * torch.mean((1 + logvar - mu.pow(2) - logvar.exp()).view(
[0m


ipdb>  n


> [0;32m<ipython-input-4-ad6d06d0739d>[0m(6)[0;36mloss_function[0;34m()[0m
[0;32m      4 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 6 [0;31m        MSE = F.mse_loss(recon_x.view(mu.shape[0], -1), x.view(mu.shape[0], -1),
[0m[0;32m      7 [0;31m                                                 [0mreduction[0m[0;34m=[0m[0;34m'none'[0m[0;34m)[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      8 [0;31m            (self.num_samples, batch_size, -1)).mean(0).sum(-1).mean()
[0m


ipdb>  n


> [0;32m<ipython-input-4-ad6d06d0739d>[0m(8)[0;36mloss_function[0;34m()[0m
[0;32m      6 [0;31m        MSE = F.mse_loss(recon_x.view(mu.shape[0], -1), x.view(mu.shape[0], -1),
[0m[0;32m      7 [0;31m                                                 [0mreduction[0m[0;34m=[0m[0;34m'none'[0m[0;34m)[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 8 [0;31m            (self.num_samples, batch_size, -1)).mean(0).sum(-1).mean()
[0m[0;32m      9 [0;31m        KLD = -0.5 * torch.mean((1 + logvar - mu.pow(2) - logvar.exp()).view(
[0m[0;32m     10 [0;31m            (self.num_samples, -1, self.hidden_dim)).mean(0).sum(-1))
[0m


ipdb>  n


> [0;32m<ipython-input-4-ad6d06d0739d>[0m(6)[0;36mloss_function[0;34m()[0m
[0;32m      4 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 6 [0;31m        MSE = F.mse_loss(recon_x.view(mu.shape[0], -1), x.view(mu.shape[0], -1),
[0m[0;32m      7 [0;31m                                                 [0mreduction[0m[0;34m=[0m[0;34m'none'[0m[0;34m)[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      8 [0;31m            (self.num_samples, batch_size, -1)).mean(0).sum(-1).mean()
[0m


ipdb>  n


> [0;32m<ipython-input-4-ad6d06d0739d>[0m(8)[0;36mloss_function[0;34m()[0m
[0;32m      6 [0;31m        MSE = F.mse_loss(recon_x.view(mu.shape[0], -1), x.view(mu.shape[0], -1),
[0m[0;32m      7 [0;31m                                                 [0mreduction[0m[0;34m=[0m[0;34m'none'[0m[0;34m)[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 8 [0;31m            (self.num_samples, batch_size, -1)).mean(0).sum(-1).mean()
[0m[0;32m      9 [0;31m        KLD = -0.5 * torch.mean((1 + logvar - mu.pow(2) - logvar.exp()).view(
[0m[0;32m     10 [0;31m            (self.num_samples, -1, self.hidden_dim)).mean(0).sum(-1))
[0m


ipdb>  n


> [0;32m<ipython-input-4-ad6d06d0739d>[0m(6)[0;36mloss_function[0;34m()[0m
[0;32m      4 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 6 [0;31m        MSE = F.mse_loss(recon_x.view(mu.shape[0], -1), x.view(mu.shape[0], -1),
[0m[0;32m      7 [0;31m                                                 [0mreduction[0m[0;34m=[0m[0;34m'none'[0m[0;34m)[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      8 [0;31m            (self.num_samples, batch_size, -1)).mean(0).sum(-1).mean()
[0m


ipdb>  n


> [0;32m<ipython-input-4-ad6d06d0739d>[0m(8)[0;36mloss_function[0;34m()[0m
[0;32m      6 [0;31m        MSE = F.mse_loss(recon_x.view(mu.shape[0], -1), x.view(mu.shape[0], -1),
[0m[0;32m      7 [0;31m                                                 [0mreduction[0m[0;34m=[0m[0;34m'none'[0m[0;34m)[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 8 [0;31m            (self.num_samples, batch_size, -1)).mean(0).sum(-1).mean()
[0m[0;32m      9 [0;31m        KLD = -0.5 * torch.mean((1 + logvar - mu.pow(2) - logvar.exp()).view(
[0m[0;32m     10 [0;31m            (self.num_samples, -1, self.hidden_dim)).mean(0).sum(-1))
[0m


ipdb>  n


> [0;32m<ipython-input-4-ad6d06d0739d>[0m(6)[0;36mloss_function[0;34m()[0m
[0;32m      4 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 6 [0;31m        MSE = F.mse_loss(recon_x.view(mu.shape[0], -1), x.view(mu.shape[0], -1),
[0m[0;32m      7 [0;31m                                                 [0mreduction[0m[0;34m=[0m[0;34m'none'[0m[0;34m)[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      8 [0;31m            (self.num_samples, batch_size, -1)).mean(0).sum(-1).mean()
[0m


ipdb>  n


> [0;32m<ipython-input-4-ad6d06d0739d>[0m(9)[0;36mloss_function[0;34m()[0m
[0;32m      7 [0;31m                                                 [0mreduction[0m[0;34m=[0m[0;34m'none'[0m[0;34m)[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      8 [0;31m            (self.num_samples, batch_size, -1)).mean(0).sum(-1).mean()
[0m[0;32m----> 9 [0;31m        KLD = -0.5 * torch.mean((1 + logvar - mu.pow(2) - logvar.exp()).view(
[0m[0;32m     10 [0;31m            (self.num_samples, -1, self.hidden_dim)).mean(0).sum(-1))
[0m[0;32m     11 [0;31m        [0mloss[0m [0;34m=[0m [0mBCE[0m [0;34m+[0m [0mKLD[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  MSE


tensor(8.5624, device='cuda:0')


ipdb>  q


BdbQuit: 

In [26]:
# Decode 10k standard-normal latents to draw samples from the trained model.
z = torch.randn((10000, d), dtype=torch.float32, device=device)
with torch.no_grad():
    decoded = model(z)
    generated_samples = decoded.cpu().numpy()

In [29]:
# Per-coordinate histograms of the generated samples, labelled like the
# true-data plot so the two figures can be compared dimension by dimension.
plt.close()
plt.figure()
plt.title('Generated data')
plt.hist(x=generated_samples, bins=100,
         label=[f'dim {i}' for i in range(generated_samples.shape[1])])
plt.xlabel('value')
plt.ylabel('count')
plt.legend()
plt.show();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [2]:
from models.samplers import HMC, MALA, run_chain

In [None]:
# Same device choice as the first session; repeated so this section also runs
# in a fresh kernel. Prefer the first GPU, otherwise use the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
# Settings for the HMC kernel built in the next cell.
n_leapfrogs, step_size = 10, 0.05

In [5]:
# HMC kernel from models.samplers; `partial_ref` / `use_barker` presumably select
# partial momentum refreshment and the Barker acceptance rule — see that module.
# Alternative that was tried: MALA(step_size, use_barker=True, learnable=False).
hmc = HMC(
    n_leapfrogs=n_leapfrogs,
    step_size=step_size,
    partial_ref=True,
    use_barker=True,
).to(device)

In [6]:
# Sanity-check target: independent Normal(10, 1) in each of two coordinates.
target = torch.distributions.Normal(
    loc=torch.tensor([10., 10.], device=device),
    scale=torch.tensor([1., 1.], device=device),
)

In [7]:
def target_density(z, x):
    """Log-density of ``target`` summed over the last dim; ``x`` is ignored."""
    return target.log_prob(z).sum(-1)

In [8]:
# One 2-D chain started from a standard-normal point; keep the full trace.
samples = run_chain(
    kernel=hmc,
    z_init=torch.randn((1, 2), device=device),
    target=target_density,
    return_trace=True,
    n_steps=1000,
)

In [9]:
samples.size()

torch.Size([1000, 2])

In [10]:
# Exact draws from the target, matching the chain length for the comparison plot.
target_samples = target.sample(torch.Size([1000]))

In [11]:
# Overlay the HMC trace on exact target draws to eyeball convergence/mixing.
plt.close()
plt.figure()
plt.scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), label='samples')
plt.scatter(target_samples[:, 0].cpu(), target_samples[:, 1].cpu(), label='target')
plt.title('HMC samples vs. target')
plt.legend()  # the `label=`s above are only shown once a legend is drawn
plt.tight_layout()
plt.show();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [18]:
# TODO(review): `make_dataloaders` is never imported in this notebook, so this
# cell fails on a fresh kernel; add the import from the project's data module.
train_loader, val_loader = make_dataloaders(
    dataset='celeba',
    batch_size=10,
    val_batch_size=10,
    binarize=False,
)

In [19]:
# Which Lightning run (lightning_logs/default/version_<version>) and epoch checkpoint to load.
version, epoch = 2, 29

In [20]:
# Hyperparameters Lightning saved next to the checkpoint. The file contains a
# Python-specific tag (the activation class, see output), which safe loading
# rejects, hence FullLoader; only use it on trusted, locally produced files.
with open(f'lightning_logs/default/version_{version}/hparams.yaml') as file:
    hparams = yaml.load(file, Loader=yaml.FullLoader)
fruits_list = hparams  # legacy alias kept so later cells referring to it still work
print(hparams)

{'act_func': <class 'torch.nn.modules.activation.LeakyReLU'>, 'dataset': 'celeba', 'hidden_dim': 64, 'name': 'VAE', 'net_type': 'conv', 'num_samples': 1, 'shape': 64}


In [21]:
# Rebuild the model from the saved hyperparameters and restore its weights.
model = VAE(**hparams).to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
checkpoint = torch.load(f'lightning_logs/default/version_{version}/checkpoints/epoch={epoch}.ckpt',
                        map_location=device)
load_result = model.load_state_dict(checkpoint['state_dict'])
# Everything below is inference: put dropout / batch-norm layers (if any) in eval mode.
model.eval()
load_result

<All keys matched successfully>

In [22]:
# Take the first training batch as the source of interpolation endpoints.
# next(iter(...)) fails fast on an empty loader instead of silently leaving a
# stale `x` from an earlier cell in the namespace.
batch = next(iter(train_loader))
x, _ = batch

In [23]:
# Start point of the interpolation: first image of the batch, kept with a batch dim.
plt.close()
plt.figure()
obj1 = x[0][None].to(device)
image = obj1[0]
if image.shape[0] != 1:
    # Colour image: move channels last for imshow.
    plt.imshow(image.permute((1, 2, 0)).cpu())
else:
    plt.imshow(image[0].cpu())
plt.tight_layout()
plt.show();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [24]:
# End point of the interpolation: last image of the batch, kept with a batch dim.
plt.close()
plt.figure()
obj2 = x[-1][None].to(device)
image = obj2[0]
if image.shape[0] != 1:
    # Colour image: move channels last for imshow.
    plt.imshow(image.permute((1, 2, 0)).cpu())
else:
    plt.imshow(image[0].cpu())
plt.tight_layout()
plt.show();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Interpolation functions

In [30]:
def interpolate_annealing(model, obj1, obj2, kernel, T=10, n_steps=10):
    """Interpolate between two inputs with an annealed MCMC chain in latent space.

    The chain starts at the encoder mean of ``obj1``. For each ``t`` on an even
    grid from 0 to 1 it runs ``n_steps`` of ``kernel`` against
    ``(1 - t) * f(z, obj1) + t * f(z, obj2)``, where ``f = model.joint_density()``
    (presumably an unnormalised log-density), warm-starting from the previous state.

    Args:
        model: exposes ``enc_rep(x)`` (element 0 is averaged over dim 0 as the
            starting latent) and ``joint_density()`` returning ``f(z, x)``.
        obj1, obj2: endpoints; the notebook passes single inputs with a leading
            batch dim of 1.
        kernel: MCMC kernel forwarded to ``run_chain``.
        T: number of annealing stages (= interpolation points).
        n_steps: MCMC steps per stage.

    Returns:
        The chain state reached after each of the ``T`` stages, concatenated
        along dim 0 (an empty 1-D tensor when ``T == 0``).
    """
    def make_target(density, t):
        # Freeze ``t`` for this stage; the ``x`` passed in by run_chain is
        # ignored because both endpoints are fixed.
        def target(z, x):
            return (1 - t) * density(z=z, x=obj1) + t * density(z=z, x=obj2)
        return target

    # NOTE(review): HMC needs gradients w.r.t. z; this presumably works under
    # no_grad only because run_chain re-enables autograd internally — TODO confirm.
    with torch.no_grad():
        density = model.joint_density()  # build the closure once, not per evaluation
        z_current = torch.mean(model.enc_rep(obj1)[0], 0)[None]
        stages = [torch.tensor([], dtype=torch.float32, device=obj1.device)]
        for t in np.linspace(0., 1., T):
            z_current = run_chain(kernel=kernel, z_init=z_current, target=make_target(density, t),
                                  return_trace=False, n_steps=n_steps)
            # Record the state *reached* at this stage so the final t=1 chain
            # (pure obj2 target) is part of the result instead of being discarded.
            stages.append(z_current)
        all_z = torch.cat(stages)
    return all_z

def interpolate_mixture(model, obj1, obj2, kernel, T=10, n_steps=10):
    """Interpolate between two inputs by sampling a moving two-component mixture.

    The chain starts at the encoder mean of ``obj1``. For each ``t`` on an even
    grid in [0.01, 0.99] it runs ``n_steps`` of ``kernel`` against
    ``logsumexp(log(1 - t) + f(z, obj1), log(t) + f(z, obj2))`` with
    ``f = model.joint_density()``, warm-starting from the previous state. The
    grid avoids 0 and 1 so both log-weights stay finite.

    Args:
        model: exposes ``enc_rep(x)`` (element 0 is averaged over dim 0 as the
            starting latent) and ``joint_density()`` returning ``f(z, x)``.
        obj1, obj2: endpoints; the notebook passes single inputs with a leading
            batch dim of 1.
        kernel: MCMC kernel forwarded to ``run_chain``.
        T: number of mixture weights (= interpolation points).
        n_steps: MCMC steps per weight.

    Returns:
        The chain state reached after each of the ``T`` stages, concatenated
        along dim 0 (an empty 1-D tensor when ``T == 0``).
    """
    def make_target(density, t):
        # Python floats keep the tensor dtype unchanged when added below.
        log_w1 = float(np.log(1 - t))
        log_w2 = float(np.log(t))

        def target(z, x):
            # ``x`` from run_chain is ignored; both endpoints are fixed.
            return torch.logsumexp(torch.stack([log_w1 + density(z=z, x=obj1),
                                                log_w2 + density(z=z, x=obj2)]), dim=0)
        return target

    # NOTE(review): HMC needs gradients w.r.t. z; this presumably works under
    # no_grad only because run_chain re-enables autograd internally — TODO confirm.
    with torch.no_grad():
        density = model.joint_density()  # build the closure once, not per evaluation
        z_current = torch.mean(model.enc_rep(obj1)[0], 0)[None]
        stages = [torch.tensor([], dtype=torch.float32, device=obj1.device)]
        for t in np.linspace(0.01, 0.99, T):
            z_current = run_chain(kernel=kernel, z_init=z_current, target=make_target(density, t),
                                  return_trace=False, n_steps=n_steps)
            # Record the state *reached* at this stage so the last (mostly obj2)
            # chain is part of the result instead of being discarded.
            stages.append(z_current)
        all_z = torch.cat(stages)
    return all_z

def interpolate_linear(model, obj1, obj2, T=10):
    """Straight-line interpolation between the encoder means of two inputs.

    Returns ``T`` latents ``(1 - t) * z_1 + t * z_2`` for ``t`` evenly spaced in
    [0, 1], concatenated along dim 0 (an empty 1-D tensor when ``T == 0``).
    ``z_i`` is element 0 of ``model.enc_rep(obj_i)`` averaged over dim 0.
    """
    with torch.no_grad():
        start = torch.mean(model.enc_rep(obj1)[0], 0)[None]
        end = torch.mean(model.enc_rep(obj2)[0], 0)[None]
        # Seed with an empty 1-D tensor so T == 0 yields the same empty result.
        pieces = [torch.tensor([], dtype=torch.float32, device=obj1.device)]
        pieces.extend((1 - t) * start + t * end for t in np.linspace(0., 1., T))
        result = torch.cat(pieces)
    return result

def visualize(model, z, shape=(-1, 1, 28, 28)):
    """Decode latents ``z`` and display them as a single image grid (15 per row).

    ``model(z)`` is passed through a sigmoid (i.e. treated as logits) and
    reshaped to ``shape``. When ``shape[1] == 1`` the grid is shown in grey
    (channel-averaged); otherwise channels are moved last for ``imshow``.
    """
    with torch.no_grad():
        images = torch.sigmoid(model(z)).view(shape).cpu()
        plt.close()
        plt.figure()
        grid = torchvision.utils.make_grid(images, nrow=15)
        if shape[1] != 1:
            plt.imshow(grid.permute((1, 2, 0)))
        else:
            plt.imshow(grid.mean(0), 'gray')
        plt.tight_layout()
        plt.show()

In [26]:
# Annealed-MCMC interpolation: 10 stages, 20 HMC steps per stage, 64x64 RGB output.
all_z = interpolate_annealing(model, obj1, obj2, kernel=hmc, T=10, n_steps=20)
visualize(model, all_z, shape=(-1, 3, 64, 64))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [27]:
# Baseline: straight line between the two encoder means, 10 points.
all_z = interpolate_linear(model, obj1, obj2, T=10)
visualize(model, all_z, shape=(-1, 3, 64, 64))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [33]:
# Mixture-target interpolation: 10 mixture weights, 30 HMC steps each.
all_z = interpolate_mixture(model, obj1, obj2, kernel=hmc, T=10, n_steps=30)
visualize(model, all_z, shape=(-1, 3, 64, 64))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …