In [1]:
import os
import torch
import gc
torch.cuda.empty_cache()
gc.collect()

  from .autonotebook import tqdm as notebook_tqdm


0

In [2]:
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= ""  # Set the GPU

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

In [228]:
# torch.cuda.set_device(device) # change allocation of current GPU
# torch.cuda.set_per_process_memory_fraction(0.4, device=device)

# print('Device:', device)
# print('Current cuda device:', torch.cuda.current_device())
# print('Available devices:', torch.cuda.device_count())

## VAE_test

In [229]:
import os
import torch


if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        print(f"# DEVICE {i}: {torch.cuda.get_device_name(i)}")
        print("- Memory Usage:")
        print(torch.cuda.list_gpu_processes(i))
        print("------------------------")
        print(f"  Allocated: {round(torch.cuda.memory_allocated(i)/1024**3,1)} GB")
        print(f"  Cached:    {round(torch.cuda.memory_reserved(i)/1024**3,1)} GB\n")
        
else:
    print("# GPU is not available")

# GPU is not available


In [230]:
from models import VanillaVAE
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from typing import List, Callable, Union, Any, TypeVar, Tuple
from torch import tensor as Tensor

Tensor = TypeVar('torch.tensor')

In [231]:
import numpy as np
import pandas as pd

# Set random seed for reproducibility
np.random.seed(10)
input_data = np.random.randint(0,6,(100,50))
input_data

array([[1, 5, 4, ..., 0, 3, 4],
       [2, 0, 1, ..., 1, 5, 0],
       [2, 3, 5, ..., 3, 1, 3],
       ...,
       [2, 3, 1, ..., 1, 4, 3],
       [5, 4, 1, ..., 1, 0, 3],
       [0, 5, 0, ..., 2, 1, 0]])

In [232]:
test_input = torch.tensor(input_data[81:]).float()
train_input = torch.tensor(input_data[0:80]).float()

In [233]:
data_dim = train_input.size(0)*train_input.size(1)
data_dim

4000

In [199]:
latent_dim = 4
hidden_dim = 8
batch_size = 10
x_dim = len(input_data[0])

In [197]:
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
kwargs = {'num_workers': 1, 'pin_memory': True} 

train_loader = DataLoader(dataset=train, batch_size = batch_size, shuffle=True, **kwargs)
test_loader = DataLoader(dataset=test, batch_size = batch_size, shuffle=False, **kwargs)

In [200]:
for batch_idx, x in enumerate(train_loader):
    print(x)
    print(batch_idx)

tensor([[3., 5., 3., 2., 4., 1., 1., 0., 4., 1., 0., 4., 2., 0., 3., 3., 4., 1.,
         2., 5., 3., 2., 4., 3., 0., 5., 1., 5., 0., 5., 3., 0., 0., 1., 1., 3.,
         0., 2., 3., 2., 1., 4., 3., 2., 4., 2., 2., 4., 5., 2.],
        [0., 0., 5., 3., 2., 1., 4., 0., 4., 1., 2., 1., 3., 3., 1., 0., 1., 3.,
         4., 0., 1., 0., 3., 3., 3., 5., 1., 0., 0., 2., 0., 3., 4., 3., 3., 5.,
         2., 0., 5., 2., 5., 5., 4., 3., 1., 4., 4., 3., 3., 5.],
        [5., 1., 1., 3., 3., 1., 4., 0., 0., 3., 0., 2., 0., 1., 0., 4., 2., 2.,
         3., 0., 3., 3., 3., 1., 3., 1., 4., 0., 2., 1., 0., 0., 0., 1., 4., 3.,
         5., 0., 2., 3., 2., 1., 4., 1., 1., 5., 3., 0., 3., 0.],
        [0., 3., 4., 0., 5., 1., 0., 1., 4., 4., 2., 4., 1., 3., 5., 1., 5., 4.,
         5., 0., 0., 2., 5., 5., 1., 1., 4., 5., 1., 1., 5., 3., 2., 5., 1., 4.,
         5., 2., 0., 0., 2., 4., 2., 1., 3., 0., 1., 4., 4., 3.],
        [4., 2., 5., 1., 1., 1., 3., 5., 5., 1., 5., 2., 2., 4., 0., 1., 2., 2.,
       

In [162]:
class Encoder(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()

        self.FC_input = nn.Linear(input_dim, hidden_dim)
        # self.FC_input2 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_mean  = nn.Linear(hidden_dim, latent_dim)
        self.FC_var   = nn.Linear(hidden_dim, latent_dim)
        
        self.LeakyReLU = nn.LeakyReLU(0.2)
        
        self.training = True
        
    def forward(self, x):
        h_       = self.LeakyReLU(self.FC_input(x))
        # h_       = self.LeakyReLU(self.FC_input2(h_))
        mean     = self.FC_mean(h_)
        log_var  = self.FC_var(h_)                     # encoder produces mean and log of variance 
                                                       #             (i.e., parateters of simple tractable normal distribution "q"
        
        return mean, log_var

In [163]:
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.FC_hidden = nn.Linear(latent_dim, hidden_dim)
        # self.FC_hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_output = nn.Linear(hidden_dim, output_dim)
        
        self.LeakyReLU = nn.LeakyReLU(0.2)
        
    def forward(self, x):
        h     = self.LeakyReLU(self.FC_hidden(x))
        # h     = self.LeakyReLU(self.FC_hidden2(h))
        
        x_hat = torch.sigmoid(self.FC_output(h))
        return x_hat

In [164]:
class Model(nn.Module):
    def __init__(self, Encoder, Decoder):
        super(Model, self).__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder
        
    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var).to(device)        # sampling epsilon        
        z = mean + var*epsilon                          # reparameterization trick
        return z
        
                
    def forward(self, x):
        mean, log_var = self.Encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var)) # takes exponential function (log var -> var)
        x_hat = self.Decoder(z)
        
        return x_hat, mean, log_var

In [165]:
encoder = Encoder(input_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)
decoder = Decoder(latent_dim=latent_dim, hidden_dim = hidden_dim, output_dim = x_dim)

model = Model(Encoder=encoder, Decoder=decoder).to(device)

In [166]:
latent = model(train)

In [167]:
latent

(tensor([[0.6242, 0.4938, 0.5504,  ..., 0.6364, 0.7783, 0.5218],
         [0.5014, 0.6090, 0.4313,  ..., 0.4894, 0.4942, 0.4171],
         [0.5418, 0.5462, 0.4615,  ..., 0.4897, 0.3771, 0.3884],
         ...,
         [0.6277, 0.4487, 0.5076,  ..., 0.6175, 0.6082, 0.4912],
         [0.4431, 0.5234, 0.4125,  ..., 0.5887, 0.2842, 0.2681],
         [0.6313, 0.6774, 0.4565,  ..., 0.3652, 0.7372, 0.5550]],
        grad_fn=<SigmoidBackward0>),
 tensor([[ 0.9455, -1.6483, -1.1045, -0.0197],
         [ 0.6648, -1.0683, -1.2934,  0.1030],
         [ 0.5216, -1.4020, -0.8126, -0.1195],
         [ 0.3730, -1.0509, -0.2236, -0.2629],
         [ 0.8461, -1.6346, -0.5925,  0.0124],
         [ 0.6205, -1.3636, -0.6137, -0.1860],
         [ 0.7227, -1.2924, -2.0313,  0.1454],
         [ 0.3135, -0.5820,  0.0812, -0.3032],
         [ 0.6396, -1.8365, -1.3478,  0.0521],
         [ 0.7029, -1.4180, -0.7664, -0.2673],
         [ 0.8016, -1.5009, -0.2992, -0.3197],
         [ 0.8535, -1.6618, -0.8213, -0.0

In [168]:
from torch.optim import SGD

BCE_loss = nn.BCELoss()

def loss_function(x, x_hat, mean, log_var):
    reproduction_loss = F.cross_entropy(x_hat, x, reduction='sum')
    KLD      = - 0.5 * torch.sum(1+ log_var - mean.pow(2) - log_var.exp())

    return reproduction_loss + KLD

# from torch.optim import Adam

optimizer = SGD(model.parameters(), lr=0.1)

In [169]:
epochs = 10

In [170]:
print("Start training VAE...")
model.train()

for epoch in range(epochs):
    overall_loss = 0
    for batch_idx, x in enumerate(train_loader):
        # x = x.view(x.size(0), x_dim)
        x = x.to(device)

        optimizer.zero_grad()

        x_hat, mean, log_var = model(x)
        loss = loss_function(x, x_hat, mean, log_var)
        
        overall_loss += loss.item()
        
        loss.backward()
        optimizer.step()
        
    print("\tEpoch", epoch + 1, "complete!", "\tAverage Loss: ", overall_loss / batch_idx)
    
print("Finish!!")

Start training VAE...
	Epoch 1 complete! 	Average Loss:  108.44813586504031
	Epoch 2 complete! 	Average Loss:  108.22938850598457
	Epoch 3 complete! 	Average Loss:  108.17118815886668
	Epoch 4 complete! 	Average Loss:  108.16479575328337
	Epoch 5 complete! 	Average Loss:  108.15182857024364
	Epoch 6 complete! 	Average Loss:  108.1175280839969
	Epoch 7 complete! 	Average Loss:  108.20703134781274
	Epoch 8 complete! 	Average Loss:  108.14195300371219
	Epoch 9 complete! 	Average Loss:  108.13398116674179
	Epoch 10 complete! 	Average Loss:  108.11548873705742
Finish!!


In [171]:
model.eval()

with torch.no_grad():
    for batch_idx, x in enumerate(test_loader):
        x = x.to(device)
        
        x_hat = model(x)


        break

In [172]:
x

tensor([3., 0., 0., 1., 1., 3., 0., 2., 3., 2., 1., 4., 3., 2., 4., 2.])

## VAE_bayes

In [119]:
import jax.numpy as np
import pandas as pd

# Set random seed for reproducibility
np.random.seed(10)
in_data = np.random.randint(0,6,(100,50))
in_data

AttributeError: module 'jax.numpy' has no attribute 'random'

In [105]:
test_input = np.array(in_data[81:])
train_input = np.array(in_data[0:80])
train_input

array([[1, 5, 4, ..., 0, 3, 4],
       [2, 0, 1, ..., 1, 5, 0],
       [2, 3, 5, ..., 3, 1, 3],
       ...,
       [3, 3, 2, ..., 5, 3, 3],
       [0, 3, 4, ..., 2, 5, 4],
       [2, 0, 2, ..., 4, 2, 1]])

In [81]:
# test_input = torch.tensor(input_data[81:]).float()
# train_input = torch.tensor(input_data[0:80]).float()

In [106]:
batch_size = 10

In [99]:
"""
Implements auto-encoding variational Bayes (variational autoencoder).
"""

from __future__ import absolute_import, division
from __future__ import print_function
import jax.random as random
import numpy as onp
from jax.scipy.stats import norm
from jax.nn import sigmoid
from jax import vmap, grad, value_and_grad, jit, tree_util
from jax.example_libraries.optimizers import adam

from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib as mpl
from collections import defaultdict
import pickle

def diag_gaussian_log_density(x, mu, log_sigma):
  """
  Args:
    x: random variable
    mu: mean
    log_sigma: log standard deviation
  Return:
    log normal density.
  """
  assert x.ndim == 1
  return np.sum(norm.logpdf(x, mu, np.exp(log_sigma)), axis=-1)


def unpack_gaussian_params(params):
  """
  Args: 
    params of a diagonal Gaussian.
  Return:
    mean, log standard deviation
  """
  D = np.shape(params)[-1] // 2
  print("params shape", params.shape)
  mu, log_sigma = params[:D], params[D:]
  return mu, log_sigma


def sample_diag_gaussian(mu, log_std, subkey):
  """Reparameterization trick for getting z from x.
  """
  return random.normal(subkey, mu.shape) * np.exp(log_std) + mu


def gaussian_log_density(x, mu, logvar):
  """
  Args: 
    x: input value
    mu: mean of the Gaussian distribution
    logvar: log variance of the Gaussian distribution
  Return: 
    log N(x | mu, var)
  """
  c = -0.5 * math.log(2 * math.pi)
  return c - 0.5 * logvar - (x - mu) ** 2 / (2 * torch.exp(logvar))


def init_net_params(scale, layer_sizes, key):
  """
  Args:
    scale: scaling factor
    layer_sizes: List[number of neurons per layer]
  Return:
    Tuple[weights, biases] for all layers."""
  k1, k2 = random.split(key, 2)
  return [
      (
          scale * random.normal(k1, (m, n)),  # weight matrix
          scale * random.normal(k2, (n,)))  # bias vector
      for m, n in zip(layer_sizes[:-1], layer_sizes[1:])
  ]


def neural_net_predict(params, inputs):
  """
  Args:
    params: List[Tuple(weights, bias)]
    inputs: an (N x D) matrix, (batch 2D latent vector Dz x B), here Dz = 2
  Return:
    Applies batch normalization to every layer but the last."""
  for W, b in params:
    print(f"W {W.shape} b {b.shape}")
    if W.shape[0] == inputs.shape[0]:
      outputs = np.dot(inputs.T, W) + b
    else:
      outputs = np.dot(inputs, W) + b  # linear transformation
    inputs = np.tanh(outputs)  # nonlinear transformation
  return outputs


def nn_predict_gaussian(params, inputs):
  """
  Args:
    params: variational parameters
    inputs: batch of datas
  Return:
    means and diagonal variances
  """
  return unpack_gaussian_params(neural_net_predict(params, inputs))


def log_prior(z):
  """
  Args:
    z: latent variable
  Return:
    computes log of pior over digit's latent represenation z.
  """
  assert z.ndim == 1
  return diag_gaussian_log_density(z, 0, 0)


def decoder(z, params):
  """
  Args:
    z: latent representation
    params: decoder network parameters, theta
  Return: 
    784-D (28x28 pixels) mean vector of prod of Bern
  """
  logits = neural_net_predict(params, z)
  return logits


def log_likelihood(z, x, params):
  """
  Args:
    z: latent representation
    x: binarized digit
    params: logits from decoder/generator theta
  Return: 
    log likelihood log p(x|z), p data given latent
  """
  mu = decoder(z, params)  # logits
  likelihood = gaussian_log_density(x, mu)
  assert likelihood.ndim == 1
  return np.sum(likelihood) # sum over pixels


def generate_from_prior(gen_params,
                        num_samples,
                        noise_dim,
                        key=random.PRNGKey(2)):
  """
  Args:
    gen_params: decoder parameters
    num_samples: number of latent variable samples
  Return: 
    Fake data: Bernouilli means p(x|z)
  """
  latents = random.normal(key, (num_samples, noise_dim))
  return sigmoid(neural_net_predict(gen_params, latents))


def joint_log_density(x, z, params):
  """
  Args:
      z: latent representation
      x: binarized digit
    params: logits from decoder
  Return: 
    log p(z, x) for a single data
  """
  return log_prior(z) + log_likelihood(z, x, params)


def encoder(x, params):
  """
  Args:
    x: batch of datas
    params: variational parameters mu and sigma
    phi: recognition parameters
  Return: 
    mean and log std of factorized Gaussian with D = 2
  """
  mu, log_sigma = nn_predict_gaussian(params, x)
  return mu, log_sigma


def log_q(z, mu, log_sigma):
  """
  Args:
    z: latent representation
    mu, log_sigma: variational distribution parameters
  Return: 
    p(x|params) likelihood of x
  """
  return diag_gaussian_log_density(z, mu, log_sigma)


def elbo(x, params, subkey):
  """
  Args:
    x: batch of B datas, D_x x B
      only need to sample single z for each data in the batch
    params: {encoder (recognition network): encoder_params phi,
             decoder (likelihood): decoder_params theta}
    subkey: jax random key
  Return: 
    scalar, unbiased estimate of mean variaitonal elbo on datas
  """
  encoder_params, decoder_params = params['enc'], params['dec']
  # latent means and log stds
  mu_qz, log_sigma_qz = encoder(x, encoder_params) 
  # Monte Carlo est of KL divergence of q from prior p (both Gaussian)
  # KL(q(z | x) || p(z)),  q ~ N(z | mu(x), sigma(x)) and p ~ N(0, I_DzxDz)
  kl = -1 / 2 * np.sum(
      np.log(np.square(np.exp(log_sigma_qz))) + 1 -
      np.square(np.exp(log_sigma_qz)) - np.square(mu_qz))
  # latent variables
  z = sample_diag_gaussian(mu_qz, log_sigma_qz, subkey)
  # p(data x | latents z)
  ll = log_likelihood(z, x, decoder_params)

  return ll - kl


def loss(*args, **kwargs):
  # Note: negate ll for the elbo loss to minimize
  return -elbo(*args, **kwargs)


def batch_loss(*args, **kwargs):
  """Negative elbo estimate over batch of data."""
  loss_ = vmap(
      loss, in_axes=(0, None, 0))(*args,
                                  **kwargs)  # correspond each sample with input
  return np.mean(loss_)

In [110]:
# Model hyper-parameters
latent_dim = 4
# data_dim = train_input.size(0)*test_input.size(1)  # How many pixels in each data (28x28).
data_dim = 4000
gen_layer_sizes = [latent_dim, 20, data_dim]  # decoder has 200 hidden
rec_layer_sizes = [data_dim, 20, latent_dim * 2]  # encoder has 200 hidden

# Training parameters
param_scale = 0.01
# batch_size = batch_size
num_epochs = 10  # train for 100 epochs
learning_rate = 0.1

key = random.PRNGKey(seed)
key, enc_k, dec_k = random.split(key, 3)
init_gen_params = init_net_params(param_scale, gen_layer_sizes,
                                dec_k)  # encoder
init_rec_params = init_net_params(param_scale, rec_layer_sizes,
                                enc_k)  # decoder
combined_init_params = dict(dec=init_gen_params, enc=init_rec_params)

num_batches = int(np.ceil(len(train_input) / batch_size))

def batch_indices(iter):
    idx = iter % num_batches
    return slice(idx * batch_size, (idx + 1) * batch_size)

In [111]:
objective_grad = jit(value_and_grad(batch_loss,
                                      argnums=1))  # differentiate w.r.t params

opt_init, opt_update, opt_get_params = adam(step_size=learning_rate)
opt_state = opt_init(combined_init_params)

it = 0


In [118]:
for epoch in tqdm(range(num_epochs)):
    for batch in tqdm(range(num_batches)):
        batch_x = train_input[batch_indices(batch)]
        params = opt_get_params(opt_state)
        key, *subkeys = random.split(key, batch_size + 1)
        subkeys = np.stack(subkeys, axis=0)
        print("subkeys shape: ", subkeys.shape)
        loss_, grad_ = objective_grad(batch_x, params, subkeys)
        opt_state = opt_update(it, grad_, opt_state)
        
        if it % 100 == 0: # save samples during training
            gen_params, rec_params = params['dec'], params['enc']
            fake_data = generate_from_prior(gen_params, 20, latent_dim, key)
            save_data(fake_data, 'vae_samples.png', vmin=0, vmax=1)
        
        if it == 0 or (it + 1) % 100 == 0:
            test_size = test_input.shape[0]
            print("test size: ", test_input.shape, train_input.shape)
            key, *subkeys = random.split(key, test_size + 1)
            subkeys = np.stack(subkeys, axis=0)
            # print performance
            loss_t = batch_loss(test_input, params, subkeys)
            message = f"Epoch: {epoch} \t Batch: {batch} \t Loss: {loss_:.3f} \t Test Loss: {loss_t:.3f}"
            tqdm.write(message)
        it += 1

# pickle to save trained weights
params = opt_get_params(opt_state)
with open(param_dump, 'wb') as file:
    pickle.dump(   params, file, protocol=pickle.HIGHEST_PROTOCOL)

  0%|          | 0/1 [00:00<?, ?it/s]]
  0%|          | 0/10 [00:00<?, ?it/s]

subkeys shape:  (80, 2)
W (4000, 20) b (20,)





TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[50])>with<BatchTrace(level=4/0)> with
  val = Traced<ShapedArray(int32[80,50])>with<DynamicJaxprTrace(level=1/0)>
  batch_dim = 0
This BatchTracer with object id 140171901179264 was created on line:
  /tmp/ipykernel_128713/1397141302.py:233 (batch_loss)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

In [25]:
batch_size = 10

In [29]:
############## TRAINING VAE ##############

def train(train_data, test_data, param_dump='opt-params.pkl', seed=0):
  """
  Optimize gradients of weights over batches of data with elbo estimate.
  """
  # Model hyper-parameters
  latent_dim = 4
  data_dim = train_input.size(0)*test_input.size(1)  # How many pixels in each data (28x28).
  gen_layer_sizes = [latent_dim, 200, data_dim]  # decoder has 200 hidden
  rec_layer_sizes = [data_dim, 200, latent_dim * 2]  # encoder has 200 hidden

  # Training parameters
  param_scale = 0.01
  # batch_size = batch_size
  num_epochs = 10  # train for 100 epochs
  learning_rate = 0.1

  key = random.PRNGKey(seed)
  key, enc_k, dec_k = random.split(key, 3)
  init_gen_params = init_net_params(param_scale, gen_layer_sizes,
                                    dec_k)  # encoder
  init_rec_params = init_net_params(param_scale, rec_layer_sizes,
                                    enc_k)  # decoder
  combined_init_params = dict(dec=init_gen_params, enc=init_rec_params)

  num_batches = int(np.ceil(len(train_data) / batch_size))

  def batch_indices(iter):
    idx = iter % num_batches
    return slice(idx * batch_size, (idx + 1) * batch_size)

  objective_grad = jit(value_and_grad(batch_loss,
                                      argnums=1))  # differentiate w.r.t params

  opt_init, opt_update, opt_get_params = adam(step_size=learning_rate)
  opt_state = opt_init(combined_init_params)

  it = 0
  for epoch in tqdm(range(num_epochs)):
    for batch in tqdm(range(num_batches)):
      batch_x = train_data[batch_indices(batch)]
      params = opt_get_params(opt_state)
      key, *subkeys = random.split(key, batch_size + 1)
      subkeys = np.stack(subkeys, axis=0)
      print("subkeys shape: ", subkeys.shape)
      loss_, grad_ = objective_grad(batch_x, params, subkeys)
      opt_state = opt_update(it, grad_, opt_state)

      if it % 100 == 0: # save samples during training
        gen_params, rec_params = params['dec'], params['enc']
        fake_data = generate_from_prior(gen_params, 20, latent_dim, key)
        save_data(fake_data, 'vae_samples.png', vmin=0, vmax=1)

      if it == 0 or (it + 1) % 100 == 0:
        test_size = test_data.shape[0]
        print("test size: ", test_data.shape, train_data.shape)
        key, *subkeys = random.split(key, test_size + 1)
        subkeys = np.stack(subkeys, axis=0)
        # print performance
        loss_t = batch_loss(test_data, params, subkeys)
        message = f"Epoch: {epoch} \t Batch: {batch} \t Loss: {loss_:.3f} \t Test Loss: {loss_t:.3f}"
        tqdm.write(message)
      it += 1

  # pickle to save trained weights
  params = opt_get_params(opt_state)
  with open(param_dump, 'wb') as file:
    pickle.dump(params, file, protocol=pickle.HIGHEST_PROTOCOL)



In [30]:
train(train_input, test_input, param_dump='opt-params.pkl', seed=0)

  0%|          | 0/8 [00:00<?, ?it/s]]
  0%|          | 0/10 [00:00<?, ?it/s]

subkeys shape:  (10, 2)





TypeError: dtype torch.float32 not understood

In [27]:
############## VISUALIZE APPROXIMATE POSTERIOR #############

def load_params(file='params2.pkl'):
  with open(file, 'rb') as f:
    params = pickle.load(f)
  # JAX does not recognize pickled file, must re-format
  # params: List[[Tuple(weights), Tuple(bias)]]
  num_layers = 2
  for k in ['dec', 'enc']:
    params[k] = list(params[k])
    for l in range(num_layers):
      params[k][l] = tuple(params[k][l])

  print("after loaded params", type(params), type(params['enc']),
        type(params['dec'][0]), type(params['dec'][0][0]))

  return params


def sample_gen(params, num_samples=10, seed=0):
  """
  Args: 
    params: the variational parameters
    num_samples: number of times to sample from distributino
    seed: random seed
  Plot samples from trained generative model using ancestral sampling.
  """
  key = random.PRNGKey(seed)
  key, k1, k2 = random.split(key, 3)
  # sample z from prior num_samples times
  # use generative model to compute bernouilli means over pixels of x given z
  means = generate_from_prior(params['dec'], num_samples, 2, k1)
  # plot means as greyscale data
  mean_data = means.reshape([-1, 28, 28])
  # sample binary data x from product of Bern and plot as data
  sample_means = random.gaussian(k2, mean_data)
  # concatenate plots: row 1, bernouilli means, row 2 corresponding binary img sampled from 1
  plot_means = np.stack([mean_data, sample_means])
  data_ = onp.zeros([2 * 28, 10 * 28])
  num_rows = 2
  num_cols = 10
  for i in range(num_rows):
    for j in range(num_cols):
      data_[i * 28:(i + 1) * 28, j * 28:(j + 1) * 28] = plot_means[i, j, ...]
  plt.imshow(data_, cmap=plt.cm.binary)
  plt.axis('off')
  plt.savefig('gen_samples.png', bbox_inches='tight')


def latent_means(params, train_data, train_labels):
  """
  Args:
    params: List[Tuple(W, b)] for each layer in NN
    train_data, train_labels = (10k, 784), (10k, 10)
  Latent space scatter plot, each point is a different data in training set.
  Visualizes which part of latent space corresponds to which kinds of data.
  """
  # encode each data in the train set
  mus, log_sigmas = vmap(
      encoder, in_axes=(0, None))(train_data, params['enc'])
  # one hot encode -> continuous
  labels = np.argmax(train_labels, axis=-1)
  num_labels = train_labels.shape[-1]

  # 2D mean vector of each encoding q_phi(z|x)
  # plot mean vectors in 2D latent space
  # color each point to class label (0, 9)
  def cmap_process(cmap, N):
    if type(cmap) == str:
      cmap = plt.get_cmap(cmap)
    col_idx = onp.concatenate((onp.linspace(0, 1., N), (0., 0., 0., 0.)))
    col_rgb = cmap(col_idx)
    idx = onp.linspace(0, 1., N + 1)
    cols = {}
    for k_i, k in enumerate(('red', 'green', 'blue')):
      cols[k] = [
          (idx[i], col_rgb[i - 1, k_i], col_rgb[i, k_i]) for i in range(N + 1)
      ]

    return mpl.colors.LinearSegmentedColormap(f"{cmap.name}-{N}", cols, 1024)

  def color_index(num_colors, cmap):
    cmap = cmap_process(cmap, num_colors)
    color_map = mpl.cm.ScalarMappable(cmap=cmap)
    color_map.set_array([])
    color_map.set_clim(-0.5, num_colors + 0.5)
    color_bar = plt.colorbar(color_map, fraction=0.045, pad=0.04)
    color_bar.set_ticks(onp.linspace(0, num_colors, num_colors))
    color_bar.set_ticklabels(range(num_colors))

    return color_bar

  fig, ax = plt.subplots()
  cmap = plt.cm.jet
  ax.scatter(mus[:, 0], mus[:, 1], c=labels, s=1, cmap=cmap)
  cb = color_index(num_labels, cmap)
  ratio = 1.0
  left, right = ax.get_xlim()
  low, hi = ax.get_ylim()
  ax.set_aspect(abs((right - left) / (low - hi)) * ratio)
  ax.set_xlabel(r'$\mu_z(x)_0$')
  ax.set_ylabel(r'$\mu_z(x)_1$')
  ax.set_title("Latent posterior mean given data")
  fig.set_size_inches([6, 6], forward=True)
  plt.savefig("latent_posterior.png", bbox_inches='tight')


# def lin_interpolate(params, train_data, train_labels, examples):
#   """
#   Args:
#     params: List[Tuple(W, b)] for each layer in NN
#     train_data, train_labels = (10k, 784), (10k, 10)
#     examples: List[Tuple[digit 1, digit 2]] samples to interpolate
#   Examining latent variable model with continuous latent variables by 
#   linearly interpolating between latent reps (mean vecs of encodings) of two points.
#   """

#   def interpolate(za, zb, alpha):
#     """Linear interpolation z_alpha = alpha * z_a + (1-a) * z_b
#     """
#     z_alpha = alpha * za + (1 - alpha) * zb
#     return z_alpha

#   # sample 3 pairs of datas, each having a different class
#   labels_to_data = defaultdict(list)
#   # encode data and get mean vectors
#   labels = np.argmax(train_labels, axis=-1)
#   # linearly interpolate between mean vectors
#   for im, lab in tqdm(zip(train_data, labels)):
#     labels_to_data[lab].append(im)
#   print("labels to datas", labels_to_data.keys())
#   # plot Bernoulli means p(x|z_\alpha) at 10 equally spaced points
#   data_ = onp.zeros([3 * 28, 10 * 28])
#   # plot generative distribution along linear interpolation
#   for row, pair in enumerate(examples):
#     datas = [labels_to_data[pair[0]][0], labels_to_data[pair[1]][0]]
#     datas = np.stack(datas)
#     mus, log_sigmas = vmap(encoder, in_axes=(0, None))(datas, params['enc'])
#     alphas = np.linspace(0, 1, 10)[::-1]
#     interpolated_means = [interpolate(mus[0], mus[1], a) for a in alphas]
#     interpolated_means = np.stack(interpolated_means)
#     bern_mus = sigmoid(
#         vmap(decoder, in_axes=(0, None))(interpolated_means, params['dec']))
#     bern_ims = bern_mus.reshape([-1, 28, 28])
#     print("bern ims", bern_ims.shape)
#     for col in range(10):
#       data_[row * 28:(row + 1) * 28, col * 28:(col + 1) *
#              28] = bern_ims[col, ...]

#   fig, ax = plt.subplots()
#   plt.imshow(data_, cmap=plt.cm.binary)
#   plt.axis('off')
#   plt.savefig('interpolated_means.png', bbox_inches='tight')

In [28]:
############ STOCHASTIC VARIATIONAL INFERENCE #############

def top_half(x):
  """
  Args:
    x: data
  Return: 
    top half of 28x28 data array.
  """
  assert x.shape == (28, 28)
  return x[:14, :]


def log_like_top_half(x, z, params):
  """
  Args:
    z: latent vector
    x: data
    params: decoder parameters
  Return: 
    log p(top half of data x | z) integrated out exactly for
      all unobserved dimensions of x are leaf nodes since ll factorizes
  """
  x = x.reshape([28, 28])
  mu_logits = decoder(z, params)  # unnormalized_logprob
  mu_data = mu_logits.reshape([28, 28])
  data_top_half = top_half(x)
  mu_top_half = top_half(mu_data)
  gaus_density = gaussian_log_density(data_top_half, mu_top_half)
  return np.sum(gaus_density)


def joint_ll_top_half(x, zs, params):
  """
  Args:
    x: data
    zs: array
    params; decoder parameters
  Return: 
    log joint density log p(z, top half data x) for each z
  """
  return log_prior(zs) + log_like_top_half(x, zs, params)


def init_var_params(subkey):
  """
  Args:
    subkey: jax key
  Return:
    Initialized variational parameters phi_mu and phi_logsigma for
    variational distribution q(z|top half of x).
  """
  return random.normal(subkey, (4,))


@jit
def elbo_half(*args, **kwargs):
  """
  ELBO estimate over K samples, batched for half of data x.
  """

  def elbo_k(x, qz_params, dec_params, subkey):
    """
    Estimate of ELBO over K samples z ~ q(z | top half of x).
    """
    mu_qz, log_sigma_qz = unpack_gaussian_params(qz_params)
    kl = -1 / 2 * np.sum(
        np.log(np.square(np.exp(log_sigma_qz))) + 1 -
        np.square(np.exp(log_sigma_qz)) - np.square(mu_qz))
    z = sample_diag_gaussian(mu_qz, log_sigma_qz, subkey)
    ll = log_like_top_half(x, z, dec_params)
    return ll - kl

  loss_ = vmap(elbo_k, in_axes=(None, None, None, 0))(*args, *kwargs)
  return np.mean(loss_)


def optimize_params(params, train_data, seed):
  """
  Args:
    params: variational and generator model parameters
    train_data: single digit from training datas.
  Return:
    Optimized phi_mu and phi_logsigma for one digit from set.
  """
  key = random.PRNGKey(seed)
  key, subkey = random.split(key)
  qz_params = init_var_params(subkey)
  grad_elbo = jit(grad(elbo_half, argnums=1))

  n = 2500
  K = 100
  lr = 0.001
  for it in tqdm(range(n)):
    key, *subs = random.split(key, K + 1)
    qz_params = qz_params + lr * grad_elbo(train_data, qz_params,
                                           params['dec'], np.stack(subs))
    if it == 0 or (it + 1) % 100 == 0:
      loss_ = elbo_half(train_data, qz_params, params['dec'], np.stack(subs))
      tqdm.write(f"Iteration {it} \t | \t ELBO {loss_:.3f}")

  return qz_params


def joint_isocountors(params, qz_params, train_data):
  """
  Args:
    params: variational (encoder) and generator network parameters
    qz_params: approximate posterior optimizer parameters

  Plot isocontours of joint distribution p(z, top half of data x) and 
  optimized approximate posterior q_phi (z | top half of data x).
  """

  def plt_isocontours(ax,
                      fn,
                      xlim=[-6, 6],
                      ylim=[-6, 6],
                      numticks=101,
                      colors=None,
                      levels=10):
    """Plot isocountours of distributions."""
    x = onp.linspace(*xlim, num=numticks)
    y = onp.linspace(*ylim, num=numticks)
    X, Y = onp.meshgrid(x, y)
    inputs = onp.concatenate(
        [onp.atleast_2d(X.ravel()),
         onp.atleast_2d(Y.ravel())])
    zs = onp.array(fn(inputs.T))
    Z = zs.reshape(X.shape)
    cs = plt.contour(X, Y, Z, colors=colors, levels=levels)
    plt.clabel(cs, inline=1, fontsize=10, fmt='%.2g')

  fig = plt.figure(figsize=(8, 8), facecolor='white')
  ax = fig.add_subplot(111, frameon=False)
  plt_isocontours(
      ax,
      lambda z: vmap(joint_ll_top_half, in_axes=(None, 0, None))
      (train_data, z, params['dec']),
      colors='g')
  plt_isocontours(
      ax,
      lambda z: vmap(diag_gaussian_log_density, in_axes=(0, None, None))
      (z, *unpack_gaussian_params(qz_params)),
      colors='b')
  plt.grid()
  plt.xlabel(r"$z_0$")
  plt.ylabel(r"$z_1$")
  lines = [
      mpl.lines.Line2D([0], [0], color='g'),
      mpl.lines.Line2D([0], [0], color='b')
  ]
  plt.title(r'Isocountours of $\log p$ and $\log q$ posteriors')
  ax.legend(lines, ['true log posterior p', 'variational log posterior q'])
  plt.tight_layout(rect=(0, 0, 1, 1))
  plt.savefig('isocountours.png')


# def infer_bottom_half(params, qz_params, train_data, seed=412):
#   """
#   Args:
#     params: decoder
#     qz_params: variational optimized posterior params
#     train_data: single digit trained on

#   Plots original whole data beside inferred greyscale.
#   """
#   key = random.PRNGKey(seed)
#   key, subkey = random.split(key)
#   # sample z ~ approximate posterior q, feed it to decoder to find
#   # Bernoulli means of p(bottom half of data | x).
#   z = sample_diag_gaussian(*unpack_gaussian_params(qz_params), subkey)
#   x = sigmoid(decoder(z, params['dec']))

#   data_ = onp.zeros((28, 28))
#   data_[:14, :] = train_data.reshape([28, 28])[:14, :]  # original top half
#   data_[14:, :] = x.reshape([28, 28])[14:, :]  # inferred bottom half

#   plt_im = onp.zeros((28, 28 * 2))
#   plt_im[:, :28] = data_
#   plt_im[:, 28:] = train_data.reshape([28, 28])

#   fig, ax = plt.subplots()
#   plt.imshow(plt_im, cmap=plt.cm.binary)
#   plt.axis('off')
#   plt.savefig('frankenstein_bottom_to_top.png', bbox_inches='tight')


if __name__ == '__main__':
  train_data, test_data = train_input, test_input

  # change the seed
  seed = 412
  num_samples = 10
  train(train_data, test_data, 'params.pkl', seed)

  # plot samples form generative model
  opt_params = load_params('params.pkl')
  sample_gen(opt_params, num_samples, seed)
  latent_means(opt_params, train_data, train_labels)
  interpolate_ex = [(1, 2), (3, 8), (4, 5)]
  lin_interpolate(opt_params, train_data, train_labels, interpolate_ex)

  # non-amortized inference (we are selecting one good sample)
  select_im_good = train_data[1]
  qz_params = optimize_params(opt_params, select_im_good, seed)
  joint_isocountors(opt_params, qz_params, select_im_good)
  infer_bottom_half(opt_params, qz_params, select_im_good)

  0%|          | 0/8 [00:00<?, ?it/s]]
  0%|          | 0/10 [00:00<?, ?it/s]

subkeys shape:  (10, 2)





TypeError: dtype torch.float32 not understood