<a href="https://colab.research.google.com/github/probabll/mixed-rv-vae/blob/master/MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#%load_ext autoreload
#%autoreload 2

On Colab, you will need to clone and install [probabll/dists.pt](https://github.com/probabll/dists.pt.git)

In [None]:
#!git clone https://github.com/probabll/dists.pt.git

In [None]:
#cd dists.pt

In [None]:
#!git pull

In [None]:
#pip install .

In [None]:
#cd ..

# Mixed RVs via Mixture of Dirichlets

In this notebook I develop a VAE whose latent code is a mixed rv on $\Delta_{K-1}$, that is, a sparse probability vector. I use the so-called *stratified representation*, which expresses the mixed rv as a finite mixture of distributions on the faces of the $(K-1)$-dimensional simplex.





## Stratified Representation


Let $Y$ be an rv taking on values in the simplex $\Delta_{K-1}$. The density assigned to an outcome $y \in \Delta_{K-1}$ is given by

\begin{align}
  p_{Y}(y|\alpha, \omega) &= \sum_{f} p_F(f|\omega) p_{Y|F}(y|f, \alpha) \\
  &= \sum_f \mathrm{Cat}(f|\omega) \mathrm{Dir}(y|\alpha_f)
\end{align}

where the sum ranges over the non-empty faces $f$ of $\Delta_{K-1}$, $\mathrm{dim}(f)$ counts the vertices of $f$ (i.e., the coordinates of $y$ that may be non-zero), and we choose a Dirichlet distribution with $\mathrm{dim}(f)$ concentration parameters for $Y|F=f$. There are $2^K-1$ non-empty faces, thus $\omega \in \Delta_{2^K-2}$. An efficient parameterisation of the Categorical distribution is possible (see the notebook on using FSAs to represent discrete distributions), but in this implementation we assume the parameter $\omega$ can be stored or predicted explicitly. Similarly, the parameters $\alpha_f \in \mathbb R_{>0}^{\mathrm{dim}(f)}$ can be stored or predicted for all faces.
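To make the stratified representation concrete, here is a minimal sketch in plain Python (standard library only, not `probabll`). It assumes faces are encoded as 0/1 tuples of length $K$ and that $\omega$ is stored explicitly as a list over the $2^K-1$ non-empty faces; the helper names (`nonempty_faces`, `sample_dirichlet`, `sample_stratified`) are hypothetical. It draws a face $f \sim \mathrm{Cat}(\omega)$, then draws the $\mathrm{dim}(f)$ active coordinates from a Dirichlet (via normalised Gamma draws) and leaves every other coordinate at exactly zero.

```python
import itertools
import random


def nonempty_faces(K):
    """All non-empty faces of the (K-1)-simplex as 0/1 tuples of length K."""
    return [f for f in itertools.product((0, 1), repeat=K) if any(f)]


def sample_dirichlet(alpha, rng):
    """Dirichlet sample obtained by normalising independent Gamma(alpha_i, 1) draws."""
    g = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(g)
    return [gi / total for gi in g]


def sample_stratified(omega, alphas, K, rng):
    """Draw f ~ Cat(omega) over faces, then y | f ~ Dir(alpha_f) embedded in R^K.

    omega:  probabilities of the 2^K - 1 faces, ordered as nonempty_faces(K)
    alphas: dict mapping each face to its dim(f) concentration parameters
    """
    faces = nonempty_faces(K)
    f = rng.choices(faces, weights=omega, k=1)[0]
    y_face = iter(sample_dirichlet(alphas[f], rng))
    # coordinates outside the face are exactly zero: this is the sparse part
    y = [next(y_face) if bit else 0.0 for bit in f]
    return f, y


# Usage: K = 3, uniform omega over faces, flat Dirichlet on every face
rng = random.Random(0)
K = 3
faces = nonempty_faces(K)
omega = [1.0 / len(faces)] * len(faces)
alphas = {f: [1.0] * sum(f) for f in faces}
f, y = sample_stratified(omega, alphas, K, rng)
print(len(faces), f, y)
```

Every sample sums to one and is exactly zero outside the chosen face, which is what makes $Y$ a mixed rv: it has probability mass on lower-dimensional faces rather than only a density on the interior.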

**Gradient estimation**

We assume samples from $Y|F=f, \alpha$ can be reparameterised in terms of samples from a fixed distribution $\Phi$ that does not depend on $\alpha$, i.e. $y = \mathcal T(\epsilon; \alpha_f)$ with $\epsilon = \mathcal T^{-1}(y; \alpha_f) \sim \Phi$, and we use score function estimation for the gradient with respect to $\omega$:

\begin{align}
    \nabla_{\alpha,\omega} \mathbb E_{Y|\alpha, \omega}\left[ \psi(y) \right] &= \nabla_{\alpha,\omega} \mathbb E_{F|\omega}\left[ \mathbb E_{Y|F,\alpha} \left[ \psi(y) \right] \right] \\
    &= \mathbb E_{F|\omega}\left[ \nabla_{\alpha} \mathbb E_{Y|F,\alpha} \left[ \psi(y) \right] + \mathbb E_{Y|F,\alpha} \left[ \psi(y) \right] \nabla_{\omega} \log p_F(f|\omega) \right] \\
    &= \mathbb E_{F|\omega}\left[ \mathbb E_{\Phi} \left[ \nabla_{y} \psi(y)\big|_{y=\mathcal T(\epsilon; \alpha_f)} \nabla_{\alpha_f} \mathcal T(\epsilon; \alpha_f) \right] + \mathbb E_{Y|F,\alpha} \left[ \psi(y) \right] \nabla_{\omega} \log p_F(f|\omega) \right]
\end{align}

For variance reduction we can subtract a baseline from the reward $\mathbb E_{Y|F,\alpha}[\psi(y)]$ in the score-function term. The baseline can be based on summary statistics (e.g., a running mean and standard deviation), on a trained MLP, or on the reward $\psi(y')$ assessed at an additional, independent sample $(f', y')$ (this is known as a self-critic). More sophisticated techniques are possible (but hopefully won't be needed).
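The score-function term and the self-critic baseline can be checked on a toy problem. The sketch below (plain Python, hypothetical helper names) treats $F$ as a Categorical with logits, uses a fixed, made-up reward vector `psi` in place of $\mathbb E_{Y|F,\alpha}[\psi(y)]$, and compares the Monte Carlo estimate with the exact gradient $p_k(\psi_k - \mathbb E[\psi])$. The self-critic baseline $\psi(f')$ uses an independent sample $f'$, so it does not bias the estimate.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    e = [math.exp(l - m) for l in logits]
    z = sum(e)
    return [v / z for v in e]


def exact_grad(logits, psi):
    """Gradient of E_{F ~ Cat(softmax(logits))}[psi(F)] w.r.t. the logits: p_k (psi_k - E[psi])."""
    p = softmax(logits)
    mean = sum(pk * sk for pk, sk in zip(p, psi))
    return [pk * (sk - mean) for pk, sk in zip(p, psi)]


def sfe_grad(logits, psi, num_samples, rng, self_critic=False):
    """Score-function estimate: average of (psi(f) - b) * grad_logits log p(f).

    For a softmax-parameterised Categorical, grad_logits log p(f) = onehot(f) - p.
    With self_critic=True the baseline is b = psi(f') for an independent f' ~ p,
    which keeps the estimator unbiased.
    """
    p = softmax(logits)
    K = len(p)
    grad = [0.0] * K
    for _ in range(num_samples):
        f = rng.choices(range(K), weights=p)[0]
        b = psi[rng.choices(range(K), weights=p)[0]] if self_critic else 0.0
        for k in range(K):
            grad[k] += (psi[f] - b) * ((1.0 if k == f else 0.0) - p[k])
    return [g / num_samples for g in grad]


logits = [0.5, -1.0, 0.2, 1.0]
psi = [1.0, 3.0, -2.0, 0.5]  # hypothetical rewards, one per outcome
print(exact_grad(logits, psi))
print(sfe_grad(logits, psi, 50_000, random.Random(0), self_critic=True))
```

Because the self-critic uses the difference $\psi(f) - \psi(f')$, it is invariant to adding a constant to the reward, which is the situation where a baseline typically helps most; whether it reduces variance in general depends on the reward.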

In [None]:
import torch
torch.__version__

In [None]:
import numpy as np
import torch
import torch.distributions as td
import probabll.distributions as pd

# VAE

A variational auto-encoder approximates the distribution of $X$ on sample space $\mathcal X$ as a marginal of a joint distribution $X,Y|\theta$ whose pdf factorises $p_{XY}(x, y|\theta) = p_{Y}(y|\theta)p_{X|Y}(x|y, \theta)$. Inference employs an independently parameterised approximation $Y|X=x, \lambda$ to the true posterior distribution $Y|X=x, \theta$. The parameters $\theta$ of the generative model and $\lambda$ of the inference model are estimated jointly to maximise a lower bound on the log-evidence (the ELBO):

\begin{align}
\theta^\star, \lambda^\star &= \arg\,\max_{\theta, \lambda} ~ \sum_{x \sim \mathcal D}\mathrm{ELBO}_x(\theta, \lambda) \\
&= \arg\,\max_{\theta, \lambda} ~ \sum_{x \sim \mathcal D} \mathbb E_{Y|X,\lambda}\left[ \log \frac{p_{XY}(x, y|\theta)}{q_{Y|X}(y|x, \lambda)} \right]
\end{align}

For us, $Y$ is a mixed rv, and we will use the stratified representation explained above. And, in this notebook, an observed data point $x$ is an MNIST digit, i.e., $x\in \{0, 1\}^{H\times W}$.
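A tiny discrete example illustrates why the ELBO is a lower bound. Assuming a latent $z$ with three values (hypothetical numbers, standing in for $Y$), the sketch computes $\log p(x)$ exactly and the ELBO for two choices of $q$: a uniform $q$ gives a strictly smaller value, and the exact posterior $p(z|x)$ closes the gap, since $\mathrm{ELBO}_x = \log p(x) - \mathrm{KL}(q \mid\mid p(z|x))$.

```python
import math


def log_evidence(prior, lik):
    """log p(x) = log sum_z p(z) p(x|z) for a discrete latent z."""
    return math.log(sum(pz * lz for pz, lz in zip(prior, lik)))


def elbo(prior, lik, q):
    """E_{q(z)}[log p(z) p(x|z) - log q(z)]; outcomes with q(z) = 0 contribute nothing."""
    return sum(qz * (math.log(pz * lz) - math.log(qz))
               for pz, lz, qz in zip(prior, lik, q) if qz > 0)


prior = [0.5, 0.3, 0.2]   # p(z), hypothetical
lik = [0.1, 0.6, 0.3]     # p(x|z) for one fixed observation x, hypothetical
evidence = sum(pz * lz for pz, lz in zip(prior, lik))  # = 0.05 + 0.18 + 0.06
posterior = [pz * lz / evidence for pz, lz in zip(prior, lik)]

print(log_evidence(prior, lik))       # log p(x) = log 0.29
print(elbo(prior, lik, [1 / 3] * 3))  # strictly smaller
print(elbo(prior, lik, posterior))    # equal to log p(x) up to rounding
```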

In [None]:
import torch.nn as nn


def assert_shape(t, shape, message):
    assert t.shape == shape, f"{message} has the wrong shape: got {t.shape}, expected {shape}"        


## Generative model

\begin{align}
p_{XYF}(x, y, f|\theta) &= p_{F}(f)p_{Y|F}(y|f, \theta)p_{X|Y}(x|y, \theta) \\
&= \mathrm{Gibbs}(f|\mathbf 0_K) \times \mathrm{Dir}(y|\mathbf 1_{\mathrm{dim}(f)}) \times \prod_{d=1}^D \mathrm{Bern}(x_d|b_d(y; \theta))
\end{align}

where the Gibbs distribution is constrained to support all but the empty face of $\Delta_{K-1}$, and $\mathbf b(y;\theta) \in (0,1)^D$ is the output of an NN with parameters $\theta$.
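With scores $\mathbf 0_K$, the constrained Gibbs prior assigns $\exp(0)=1$ to every non-empty face, i.e., it is uniform over the $2^K-1$ non-empty bit vectors. The sketch below (plain Python, hypothetical helpers, not the `pd.NonEmptyBitVector` implementation) samples it by rejection and computes the expected number of active coordinates, $\mathbb E[\mathrm{dim}(F)] = K 2^{K-1}/(2^K-1)$, which holds because each coordinate is on in exactly $2^{K-1}$ of the non-empty faces. For $K=10$ this is about 5.005, a useful reference point when reading the $\mathrm{dim}(f)$ histograms later in the notebook.

```python
import itertools
import random


def expected_dim_uniform_faces(K):
    """E[dim(F)] for F uniform over the 2^K - 1 non-empty bit vectors of length K.

    Each coordinate is 1 in exactly 2^(K-1) of the non-empty vectors.
    """
    return K * 2 ** (K - 1) / (2 ** K - 1)


def sample_uniform_nonempty(K, rng):
    """Rejection sampler: K fair coin flips, redrawn whenever all of them are 0."""
    while True:
        f = [1 if rng.random() < 0.5 else 0 for _ in range(K)]
        if any(f):
            return f


K = 10
rng = random.Random(0)
samples = [sample_uniform_nonempty(K, rng) for _ in range(20_000)]
print(expected_dim_uniform_faces(K))                # about 5.005
print(sum(sum(f) for f in samples) / len(samples))  # Monte Carlo estimate of the same
```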


In [None]:
class GenerativeModel(nn.Module):

    def __init__(self, latent_dim, data_dim, hidden_dec_size, p_drop=0.0):
        super().__init__()
        self._latent_dim = latent_dim    
        self._data_dim = data_dim   
        self._y_to_logits = nn.Sequential(
            nn.Dropout(p_drop),
            nn.Linear(latent_dim, hidden_dec_size),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden_dec_size, hidden_dec_size),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden_dec_size, data_dim),
        )
        self.register_buffer("_prior_scores", torch.zeros(latent_dim, requires_grad=False))

    @property
    def data_dim(self):
        return self._data_dim
    
    @property
    def latent_dim(self):
        return self._latent_dim

    def F(self, predictors=None):
        """
        :param predictors: input predictors, this is reserved for future use
        """
        return pd.NonEmptyBitVector(scores=self._prior_scores)

    def Y(self, f, predictors=None):
        """
        :param f: face-encoding [batch_size, K]
        :param predictors: input predictors, this is reserved for future use
        """
        return pd.MaskedDirichlet(f.bool(), torch.ones_like(f))

    def X(self, y, predictors=None):
        logits = self._y_to_logits(y)
        return td.Independent(td.Bernoulli(logits=logits), 1)  
    
    def sample(self, sample_shape=torch.Size([])):
        # [sample_shape, K]
        f = self.F().sample(sample_shape)
        # [sample_shape, K]
        y = self.Y(f).sample()
        # [sample_shape, D]
        x = self.X(y).sample()
        return f, y, x
    
    def log_prob(self, f, y, x, per_bit=False):
        if per_bit:
            return self.F().log_prob(f), self.Y(f).log_prob(y), self.X(y).base_dist.log_prob(x)
        else:
            return self.F().log_prob(f), self.Y(f).log_prob(y), self.X(y).log_prob(x)
    

## Inference model

\begin{align}
q_{YF|X}(y,f|x,\lambda) &= q_{F|X}(f|x, \lambda) q_{Y|FX}(y|f,x, \lambda) \\
&= \mathrm{Gibbs}(f|\mathbf s(x; \lambda)) \times \mathrm{Dir}(y|\mathbf a(x, f;\lambda))
\end{align}

where $\mathbf s(x; \lambda) \in \mathbb R^K$ is a vector of scores predicted by an NN, and $\mathbf a(x, f; \lambda) \in \mathbb R_{>0}^{\mathrm{dim}(f)}$ is a vector of concentrations predicted by an NN; $\lambda$ denotes the NN parameters.
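Because the Gibbs distribution factorises over coordinates except for the excluded empty face, its normaliser has a closed form: $Z(\mathbf s) = \prod_{k=1}^K (1 + e^{s_k}) - 1$, so $\log q_{F|X}(f|x,\lambda) = \mathbf s^\top f - \log Z(\mathbf s)$ costs $O(K)$ rather than a sum over $2^K-1$ faces. The sketch below (plain Python, hypothetical helper names; `probabll`'s `NonEmptyBitVector` may compute this differently) checks the closed form against brute-force enumeration.

```python
import itertools
import math


def softplus(s):
    """Numerically stable log(1 + exp(s))."""
    return max(s, 0.0) + math.log1p(math.exp(-abs(s)))


def log_normaliser(scores):
    """log Z(s) for q(f|s) proportional to exp(s . f) over non-empty bit vectors f.

    Z(s) = prod_k (1 + exp(s_k)) - 1: the unconstrained product minus the empty face.
    """
    a = sum(softplus(s) for s in scores)  # log prod_k (1 + exp(s_k)), always > 0
    # log(exp(a) - 1) = a + log(1 - exp(-a)), computed stably with expm1
    return a + math.log(-math.expm1(-a))


def log_prob(f, scores):
    """log q(f|s) = s . f - log Z(s) for a non-empty 0/1 vector f."""
    assert any(f), "the empty face has probability zero"
    return sum(s for s, bit in zip(scores, f) if bit) - log_normaliser(scores)


def brute_force_log_normaliser(scores):
    """Reference value by enumerating all 2^K - 1 non-empty faces."""
    faces = (f for f in itertools.product((0, 1), repeat=len(scores)) if any(f))
    return math.log(sum(math.exp(sum(s for s, b in zip(scores, f) if b)) for f in faces))


scores = [0.3, -1.2, 2.0, 0.0]  # hypothetical s(x; lambda) for K = 4
print(log_normaliser(scores), brute_force_log_normaliser(scores))
```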

In [None]:
class InferenceModel(nn.Module):

    def __init__(self, latent_dim, data_dim, hidden_enc_size, cond='fx', p_drop=0.0):
        super().__init__()

        assert cond in {'f', 'x', 'fx'}, f"The concentration net can take 'f', 'x', or 'fx', got {cond}"
        self._cond = cond

        self._scores_net = nn.Sequential(
            nn.Dropout(p_drop),
            nn.Linear(data_dim, hidden_enc_size),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden_enc_size, hidden_enc_size),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden_enc_size, latent_dim)
        )
        if cond == 'f':
            self._concentrations_net = nn.Sequential(
                nn.Dropout(p_drop),
                nn.Linear(latent_dim, latent_dim // 2),  # output size must match the next layer's input
                nn.ReLU(),
                nn.Dropout(p_drop),
                nn.Linear(latent_dim // 2, latent_dim // 2),
                nn.ReLU(),
                nn.Dropout(p_drop),
                nn.Linear(latent_dim // 2, latent_dim),
                nn.Softplus()
            )
        else: # x or fx
            self._concentrations_net = nn.Sequential(
                nn.Dropout(p_drop),
                nn.Linear(data_dim if cond == 'x' else latent_dim + data_dim, hidden_enc_size),
                nn.ReLU(),
                nn.Dropout(p_drop),
                nn.Linear(hidden_enc_size, hidden_enc_size),
                nn.ReLU(),
                nn.Dropout(p_drop),
                nn.Linear(hidden_enc_size, latent_dim),
                nn.Softplus()
            )
            
    def F(self, x, predictors=None):
        # [B, K]
        scores = self._scores_net(x) 
        # constrain scores?
        # e.g., by clipping?
        # 2.5 + tanh(NN(f,x)) * 2.5 + eps
        return pd.NonEmptyBitVector(scores)

    def Y(self, f, x, predictors=None):
        if self._cond == 'f':
            inputs = f  # [...,K]
        else:
            if len(f.shape) < len(x.shape):
                raise ValueError(f"f is missing dimensions: f has shape {f.shape} and x has shape {x.shape}")
            elif len(f.shape) > len(x.shape): 
                # deal with f having a larger sample_shape than x
                sample_dims = len(f.shape) - len(x.shape)
                sample_shape = f.shape[:sample_dims] 
                x = x.view((1,) * sample_dims + x.shape).expand(sample_shape + (-1,) * len(x.shape))
            if self._cond == 'x':
                inputs = x  # [...,D]
            else:
                assert f.shape[:-1] == x.shape[:-1], "f and x have different sample/batch shapes"
                inputs = torch.cat([f, x], -1)  # [...,K+D]
        # [...,K]
        concentration = self._concentrations_net(inputs) 
        # constrain concentration?
        # e.g., by clipping?
        # 2.5 + tanh(NN(f,x)) * 2.5 + eps
        return pd.MaskedDirichlet(f.bool(), concentration)
    
    def sample(self, x, sample_shape=torch.Size([])):
        """No gradients through this"""
        with torch.no_grad():
            # [sample_shape, B, K]
            f = self.F(x).sample(sample_shape)
            # [sample_shape, B, K]
            y = self.Y(f, x).sample()
            return f, y
    
    def log_prob(self, f, y, x):
        return self.F(x).log_prob(f), self.Y(f, x).log_prob(y)

## ELBO

For a single observation $x$, our ELBO corresponds to:

\begin{align}
\mathrm{ELBO}_x(\theta, \lambda) 
&=  \mathbb E_{F|X=x,\lambda} \left[ \mathbb E_{Y|X=x,F=f,\lambda}\left[ \log p_{X|Y}(x| y,\theta) \right] - \mathrm{KL}(Y|X=x,F=f,\lambda \mid\mid Y|F=f,\theta)\right]\\
&- \mathrm{KL}(F|X=x,\lambda \mid\mid F|\theta )
\end{align}
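This decomposition is exact. To check it numerically, the sketch below replaces the continuous $Y$ with a small discrete variable (an assumption made only so that every expectation is a finite sum; all numbers are hypothetical) and compares it with the joint form $\mathbb E_{q}[\log p_{XYF}(x,y,f|\theta) - \log q_{YF|X}(y,f|x,\lambda)]$.

```python
import math


def kl(q, p):
    """KL(q || p) for discrete distributions given as lists of probabilities."""
    return sum(qi * (math.log(qi) - math.log(pi)) for qi, pi in zip(q, p) if qi > 0)


def elbo_joint(p_f, p_y_f, p_x_y, q_f, q_y_f):
    """E_{q(f,y|x)}[log p(f) p(y|f) p(x|y) - log q(f|x) q(y|f,x)]."""
    total = 0.0
    for f, qf in enumerate(q_f):
        for y, qy in enumerate(q_y_f[f]):
            if qf * qy > 0:
                total += qf * qy * (math.log(p_f[f] * p_y_f[f][y] * p_x_y[y])
                                    - math.log(qf * qy))
    return total


def elbo_decomposed(p_f, p_y_f, p_x_y, q_f, q_y_f):
    """E_F[ E_{Y|F}[log p(x|y)] - KL(q(Y|f,x) || p(Y|f)) ] - KL(q(F|x) || p(F))."""
    inner = 0.0
    for f, qf in enumerate(q_f):
        expected_ll = sum(qy * math.log(p_x_y[y]) for y, qy in enumerate(q_y_f[f]) if qy > 0)
        inner += qf * (expected_ll - kl(q_y_f[f], p_y_f[f]))
    return inner - kl(q_f, p_f)


# Hypothetical toy model: F has 2 values, Y has 3 values, x is one fixed observation
p_f = [0.6, 0.4]
p_y_f = [[0.2, 0.5, 0.3], [0.7, 0.1, 0.2]]
p_x_y = [0.9, 0.2, 0.4]                      # p(x|y) evaluated at the observed x
q_f = [0.3, 0.7]
q_y_f = [[0.1, 0.6, 0.3], [0.5, 0.25, 0.25]]

print(elbo_joint(p_f, p_y_f, p_x_y, q_f, q_y_f))
print(elbo_decomposed(p_f, p_y_f, p_x_y, q_f, q_y_f))
```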

A gradient estimator for $\theta$ can be obtained via
\begin{align}
\nabla_\theta \mathrm{ELBO}_x(\theta, \lambda)&\approx \nabla_\theta \log p_{X|Y}(x|y, \theta) \\
&- \nabla_\theta \mathrm{KL}(Y|X=x,F=f,\lambda || Y|F=f,\theta) \\
&- \nabla_\theta \mathrm{KL}(F|X=x,\lambda \mid\mid F|\theta )
\end{align}
with $y,f \sim Y,F|X=x, \lambda$.

A gradient estimator for $\lambda$ can be obtained via
\begin{align}
\nabla_\lambda  \mathrm{ELBO}_x(\theta, \lambda) &\approx \nabla_{\lambda} \log p_{X|Y}(x|y=\mathcal T(\epsilon; \lambda), \theta) \\
&-\nabla_\lambda \mathrm{KL}(Y|X=x,F=f,\lambda || Y|F=f,\theta) \\
&-\nabla_\lambda \mathrm{KL}(F|X=x,\lambda \mid\mid F|\theta ) \\
&+\left( \log p_{X|Y}(x|y, \theta) -  \mathrm{KL}(Y|X=x,F=f,\lambda || Y|F=f,\theta) \right) \nabla_\lambda \log q_{F|X}(f|x,\lambda)
\end{align}
with $f \sim F|X=x,\lambda$ and $y = \mathcal T(\epsilon; \lambda)$ for $\epsilon \sim \Phi$.

A surrogate objective in `pytorch`, whose gradients with respect to $\theta$ and $\lambda$ are the estimators above, is given by
\begin{align}
&\log p_{X|Y}(x|y, \theta) \\
&- \mathrm{KL}(Y|X=x,F=f,\lambda || Y|F=f,\theta) \\
&- \mathrm{KL}(F|X=x,\lambda \mid\mid F|\theta )\\
&+ \mathrm{detach}\left( \log p_{X|Y}(x|y, \theta) -  \mathrm{KL}(Y|X=x,F=f,\lambda || Y|F=f,\theta) \right) \log q_{F|X}(f|x,\lambda)
\end{align}
again with $f \sim F|X=x,\lambda$ and $y = \mathcal T(\epsilon; \lambda)$ for $\epsilon \sim \Phi$.

In [None]:
from collections import OrderedDict, deque


class VAE:

    def __init__(self, p, q, use_self_critic=False, use_reward_standardisation=True):
        self.p = p
        self.q = q
        self.use_self_critic = use_self_critic
        self.use_reward_standardisation = use_reward_standardisation
        self._rewards = deque([])

    def train(self):
        self.p.train()
        self.q.train()

    def eval(self):
        self.p.eval()
        self.q.eval()

    def gen_parameters(self):
        return self.p.parameters()

    def inf_parameters(self):
        return self.q.parameters()   
    
    def critic(self, x_obs, q_F=None):
        """This estimates reward (w.r.t sampling of F) on a single sample for variance reduction"""
        B, K, D = x_obs.shape[0], p._latent_dim, p._data_dim
        with torch.no_grad():
            if q_F is None:
                q_F = self.q.F(x_obs)
            # [B, K]
            f = q_F.sample() 
            assert_shape(f, (B, K), "f ~ F|X=x, \lambda")
            q_Y = self.q.Y(f, x_obs)
            # [B, K]
            y = q_Y.sample() 
            assert_shape(y, (B, K), "y ~ Y|F=f, \lambda")

            p_F = self.p.F()
            if p_F.batch_shape != x_obs.shape[:1]:
                p_F = p_F.expand(x_obs.shape[:1] + p_F.batch_shape)

            p_Y = self.p.Y(f)  # we condition on f ~ q_F         
            p_X = self.p.X(y)  # we condition on y ~ q_Y

            # [B]
            ll = p_X.log_prob(x_obs)
            # [B]
            kl_Y_given_f = td.kl_divergence(q_Y, p_Y)
            # [B]
            return ll - kl_Y_given_f
        
    def update_reward_stats(self, reward):
        """Return the current statistics and update the vector"""
        if len(self._rewards) > 1:
            avg = np.mean(self._rewards)
            std = np.std(self._rewards)
        else:
            avg = 0.0
            std = 1.0
        if len(self._rewards) == 100:
            self._rewards.popleft()
        self._rewards.append(reward.mean(0).item())
        return avg, std

    def loss(self, x_obs):
        """
        :param x_obs: [B, D]
        """
        B, K, D = x_obs.shape[0], self.p.latent_dim, self.p.data_dim

        # [B, K]
        q_F = self.q.F(x_obs)
        f = q_F.sample()  # not rsample
        assert_shape(f, (B, K), r"f ~ F|X=x, \lambda")
        q_Y = self.q.Y(f, x_obs)
        y = q_Y.rsample()  # with reparameterisation! (important)
        assert_shape(y, (B, K), r"y ~ Y|F=f, \lambda")
        
        p_F = self.p.F()
        if p_F.batch_shape != x_obs.shape[:1]:
            p_F = p_F.expand(x_obs.shape[:1] + p_F.batch_shape)
        
        p_Y = self.p.Y(f)  # we condition on f ~ q_F         
        p_X = self.p.X(y)  # we condition on y ~ q_Y

        # ELBO: the first term (ll) is an MC estimate (we sampled (f, y)),
        # the second term (kl_F) is exact,
        # the third term (kl_Y_given_f) is an MC estimate (we sampled f)
        ll = p_X.log_prob(x_obs)
        kl_Y_given_f = td.kl_divergence(q_Y, p_Y)
        kl_F = td.kl_divergence(q_F, p_F)
        
        # E_FY[ log p(x|y) ] - KL(F) - E_F[ KL(Y) ]
        # = E_F[ E_Y[ log p(x|y) ] - KL(Y) ] - KL(F)
        # = E_F[ r(F) ] - KL(F) for r(f) = E_Y[ log p(x|y) ] - KL(Y|f), estimated with the sampled y
        # score-function surrogate: r(f).detach() * log q(f)
        reward = (ll - kl_Y_given_f).detach()
        
        # Variance reduction tricks
        if self.use_self_critic:
            criticised_reward = reward - self.critic(x_obs, q_F).detach()
        else:
            criticised_reward = reward
        
        if self.use_reward_standardisation:
            reward_avg, reward_std = self.update_reward_stats(criticised_reward)
            # divide by the running std, but never scale rewards up (guards against a tiny std)
            standardised_reward = (criticised_reward - reward_avg) / np.maximum(reward_std, 1.0)
        else:
            standardised_reward = criticised_reward
        
        # Gradient surrogates and loss
        sfe_surrogate = standardised_reward * q_F.log_prob(f)
        grep_surrogate = ll - kl_F - kl_Y_given_f
        loss = -(grep_surrogate + sfe_surrogate).mean(0)
        ret = OrderedDict(
            loss=loss.item(),
            LL=ll.mean(0).item(),
            KL_F=kl_F.mean(0).item(),
            KL_Y=kl_Y_given_f.mean(0).item(),
            SFE_reward=reward.mean(0).item()
        )
        if self.use_self_critic:
            ret['SFE_criticised_reward'] = criticised_reward.mean(0).item()
        if self.use_reward_standardisation:
            ret['SFE_standardised_reward'] = standardised_reward.mean(0).item()

        return loss, ret

    def estimate_ll(self, x_obs, num_samples):     
        with torch.no_grad():
            self.eval()
            # log 1/N \sum_{n} p(x, z_n)/q(z_n|x)
            # [N, B, K], [N, B, K]
            f, y = self.q.sample(x_obs, (num_samples,))
            # Here I compute: log p(f) + log p(y|f) + log p(x|f,y)
            # stack([N, B], [N, B], [N,B]) -> [N, B, 3], then reduce to [N, B]
            log_p = torch.stack(self.p.log_prob(f, y, x_obs), -1).sum(-1)
            # Here I compute: log q(f) + log q(y|f)
            # stack([N, B], [N, B]) -> [N, B, 2], then reduce to [N, B]
            log_q = torch.stack(self.q.log_prob(f, y, x_obs), -1).sum(-1)
            # [B]
            ll = torch.logsumexp(log_p - log_q, 0) - np.log(num_samples)                    
        return ll

    def estimate_ll_per_bit(self, x_obs, num_samples):             
        with torch.no_grad():
            # log 1/N \sum_{n} p(x, z_n)/q(z_n|x)
            # [N, B, K]
            f, y = self.q.sample(x_obs, (num_samples,))        
            # [N, B], [N, B], [N, B, D]
            log_pf, log_py, log_px = self.p.log_prob(f, y, x_obs, per_bit=True)
            # [N, B, D]
            log_p = log_pf.unsqueeze(-1) + log_py.unsqueeze(-1) + log_px
            # [N, B]
            log_q = torch.stack(self.q.log_prob(f, y, x_obs), -1).sum(-1)
            # [B, D]
            ll = torch.logsumexp(log_p - log_q.unsqueeze(-1), 0) - np.log(num_samples)                    
        return ll
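`estimate_ll` uses importance sampling with $q$ as the proposal, $\log p(x) \approx \log \frac{1}{N}\sum_{n} \frac{p(x, f_n, y_n)}{q(f_n, y_n|x)}$, computed with `logsumexp` because the log-densities of MNIST images are large negative numbers whose exponentials underflow. A plain-Python sketch of the same estimator on a toy discrete model (hypothetical numbers and helper names, independent of the notebook's classes):

```python
import math
import random


def logsumexp(values):
    """Stable log(sum(exp(v))): factor out the maximum before exponentiating."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def is_log_likelihood(log_p_joint, log_q):
    """log (1/N) sum_n p(x, z_n) / q(z_n|x), from per-sample log-densities with z_n ~ q."""
    n = len(log_p_joint)
    return logsumexp([lp - lq for lp, lq in zip(log_p_joint, log_q)]) - math.log(n)


# Hypothetical toy model with a 3-valued latent z and one fixed observation x
prior = [0.5, 0.3, 0.2]   # p(z)
lik = [0.1, 0.6, 0.3]     # p(x|z), so p(x) = 0.29
joint = [pz * lz for pz, lz in zip(prior, lik)]


def estimate(q, num_samples, rng):
    """Draw z_n ~ q and return the importance-sampled estimate of log p(x)."""
    zs = rng.choices(range(3), weights=q, k=num_samples)
    return is_log_likelihood([math.log(joint[z]) for z in zs], [math.log(q[z]) for z in zs])


rng = random.Random(0)
posterior = [j / sum(joint) for j in joint]
print(math.log(sum(joint)))
print(estimate(posterior, 10, rng))        # exact up to rounding: every weight equals p(x)
print(estimate([1 / 3] * 3, 20_000, rng))  # close to log p(x) for large N
print(is_log_likelihood([-1000.0, -1000.0], [0.0, 0.0]))  # no underflow
```

With the exact posterior as proposal every importance weight equals $p(x)$, so the estimate is exact for any $N$; with any other proposal it converges as $N$ grows, and its expectation lies between the ELBO and $\log p(x)$.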

# MNIST 

* Download
* Preprocess
* Batcher
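The preprocessing applied by the `Batcher` defined below when `binarize=True` is dynamic binarisation: each time a batch is drawn, pixel $d$ becomes 1 with probability equal to its grey-level intensity, so the model sees a fresh binary sample of each digit in every epoch, consistent with the Bernoulli likelihood. A plain-Python sketch of the same rule (`dynamic_binarise` is a hypothetical name):

```python
import random


def dynamic_binarise(intensities, rng):
    """x_d = 1 with probability intensity_d, i.e. the rule (x > torch.rand_like(x)).float()."""
    return [1.0 if p > rng.random() else 0.0 for p in intensities]


rng = random.Random(0)
pixels = [0.0, 0.25, 0.9, 1.0]  # hypothetical grey levels in [0, 1]
draws = [dynamic_binarise(pixels, rng) for _ in range(20_000)]
print([sum(col) / len(draws) for col in zip(*draws)])  # close to the intensities
```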

## Data and Batcher

In [None]:
import pathlib
import numpy as np
import torch
import torch.utils.data
from torchvision import datasets, transforms


def boolean_argument(string):
    return str(string).lower() in {"true", "yes", "1"}


def list_argument(dtype, separator=","):
    def constructor(string):
        return [dtype(x) for x in string.split(separator)]
    return constructor

def print_digit(matrix):
    rows = []
    for i in range(matrix.size(0)):
        row = ""
        for j in range(matrix.size(1)):
            row += "x" if matrix[i,j] >= 0.5 else " "
        rows.append(row)
    return "\n".join(rows)


def load_mnist(batch_size, save_to, height=28, width=28):
    """
    :param batch_size: the dataloader will create batches of this size
    :param save_to: a folder where we download the data into    
    :param height: using something other than 28 implies a Resize transformation
    :param width: using something other than 28 implies a Resize transformation
    :return: 3 data loaders
        training, validation, test
    """
    # create directory
    pathlib.Path(save_to).mkdir(parents=True, exist_ok=True)
    
    if height == width == 28:
        transform = transforms.ToTensor()    
    else:        
        transform = transforms.Compose([
            transforms.Resize((height, width)), 
            transforms.ToTensor()]
        )

    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(
            datasets.MNIST(
                save_to,
                train=True, 
                download=True, 
                transform=transform),
            indices=range(55000)), 
        batch_size=batch_size,
        shuffle=True
    )
    valid_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(
            datasets.MNIST(
                save_to,
                train=True, 
                download=True, 
                transform=transform),
            indices=range(55000, 60000)), 
        batch_size=batch_size,
        shuffle=True
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            save_to,
            train=False, 
            download=True, 
            transform=transform),
        batch_size=batch_size
    )
    return train_loader, valid_loader, test_loader


class Batcher:
    """
    Deals with options such as
        * dynamic binarization
        * change to device
        * shape
        * one-hot encoding of digits
    """
    
    def __init__(self, data_loader, height, width, device, binarize=False, onehot=False, num_classes=10): 
        self.data_loader = data_loader
        self.height = height
        self.width = width
        self.device = device
        self.binarize = binarize
        self.num_batches = len(data_loader)
        self.onehot = onehot
        self.num_classes = num_classes
            
    def __len__(self):
        return self.num_batches
    
    def __iter__(self):
        """        
        Yields
            x: [B, H, W], y: [B]
        or
            x: [B, H, W], y: [B, 10]
        """
        for x, y in self.data_loader: 
            # x: [B, C=1, H, W], y: [B]
            # [B, H, W]
            x = x.reshape(x.size(0), self.height, self.width).to(self.device)
            if self.binarize:
                x = (x > torch.rand_like(x)).float()
            # [B]
            y = y.to(self.device)
            if self.onehot:
                # [B, 10]
                y = torch.nn.functional.one_hot(y, num_classes=self.num_classes)
            yield x, y
                


## Hyperparameters

In [None]:
from collections import namedtuple
cfg = dict(
    # Data
    batch_size=200,
    data_dir='tmp',
    height=28,
    width=28, 
    # CUDA
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
    # Architecture
    hidden_enc_size=500,
    hidden_dec_size=500,
    cond='fx',
    latent_dim=10,    
    # Training
    epochs=100,    
    # Evaluation
    num_samples=100,    
    # Optimisation & regularisation
    gen_lr=1e-4,
    inf_lr=1e-4,  
    gen_l2=0.0,
    inf_l2=0.0,  
    gen_p_drop=0.3,  
    inf_p_drop=0.0,  # dropout for inference model is not well understood    
    grad_clip=5.0,
    # Variance reduction
    use_self_critic=True,
    use_reward_standardisation=False,
)
args = namedtuple('Config', cfg.keys())(*cfg.values())

In [None]:
# You can skip this on Colab

# Download MNIST
#!wget www.di.ens.fr/~lelarge/MNIST.tar.gz
#!tar -zxvf MNIST.tar.gz

In [None]:
train_loader, valid_loader, test_loader = load_mnist(
    args.batch_size, 
    save_to='./', 
    height=args.height, 
    width=args.width
)

In [None]:
def get_batcher(data_loader, args):
    batcher = Batcher(
        data_loader, 
        height=args.height, 
        width=args.width, 
        device=torch.device(args.device), 
        binarize=True, 
        num_classes=10,
        onehot=True
    )
    return batcher

## Model

In [None]:
p = GenerativeModel(
    latent_dim=args.latent_dim, 
    data_dim=args.height * args.width, 
    hidden_dec_size=args.hidden_dec_size,
    p_drop=args.gen_p_drop,
).to(torch.device(args.device))

In [None]:
q = InferenceModel(
    latent_dim=args.latent_dim, 
    data_dim=args.height * args.width, 
    hidden_enc_size=args.hidden_enc_size,
    cond=args.cond,
    p_drop=args.inf_p_drop,
).to(torch.device(args.device))

In [None]:
q.sample(p.sample()[-1])

In [None]:
vae = VAE(
    p, 
    q, 
    use_self_critic=args.use_self_critic, 
    use_reward_standardisation=args.use_reward_standardisation
)

## Visualisations

* Compare statistics of samples from $p$ to those of samples from $q$
* Visualise a few generations

In [None]:
from matplotlib import pyplot as plt
from collections import defaultdict

In [None]:
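def bitvec2str(f, as_set=False):
    """Render a face bit-vector as a bit string (e.g. '0110') or as a 1-based index set (e.g. '{2,3}')."""
    if as_set:
        return '{' + ','.join(f'{i:1d}' for i, b in enumerate(f, 1) if b) + '}'
    return ''.join('1' if b else '0' for b in f)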

In [None]:
def compare_p_q(vae, batcher: Batcher, args): 
    with torch.no_grad():
        vae.eval()        
    
        prior = defaultdict(list)
        posterior = defaultdict(list)
        other = defaultdict(list)
        num_obs = 0

        # Some visualisations
        for x_obs, y_obs in batcher:
            
            # [B, H*W]
            x_obs = x_obs.reshape(-1, args.height * args.width)
            num_obs += x_obs.shape[0]

            # [B, 10]
            context = None
            
            B, K, D = x_obs.shape[0], vae.p._latent_dim, vae.p._data_dim            
            
            # [B, K]
            f = vae.p.F().expand((B,)).sample()
            y = vae.p.Y(f).sample()
            #x = vae.p.X(y).sample()
            # [B, K]
            prior['f'].append(f.cpu().numpy())
            # [B]
            prior['dim'].append(f.sum(-1).cpu().numpy())
            # [B, K]
            prior['y'].append(y.cpu().numpy())
            
            # [B, K]
            f, y = vae.q.sample(x_obs)
            # [B, K]
            posterior['f'].append(f.cpu().numpy())
            # [B]
            posterior['dim'].append(f.sum(-1).cpu().numpy())
            # [B, K]
            posterior['y'].append(y.cpu().numpy())
            #print(posterior['f'][-1].shape, posterior['dim'][-1].shape, posterior['y'][-1].shape)
            
            # [B]
            other['KL_F'].append(td.kl_divergence(vae.q.F(x_obs), vae.p.F().expand((B,))).cpu().numpy())
            other['KL_Y_f'].append(td.kl_divergence(vae.q.Y(f, x_obs), vae.p.Y(f)).cpu().numpy())
            
        # KLs
        print("For a trained VAE: ")
        print(" 1. We want to see that KL(F|x || F) and KL(Y|f,x || Y|f) is generally > 0 for any x ~ D.")
        _ = plt.hist(np.concatenate(other['KL_F'], 0), bins=20)
        _ = plt.xlabel(r'$KL( F|x,\lambda || F| \theta )$')
        plt.show()
        
        _ = plt.hist(np.concatenate(other['KL_Y_f'], 0), bins=20)
        _ = plt.xlabel(r'$KL( Y|f,x,\lambda || Y|f, \theta )$')
        plt.show()
            
        
        print(" 2. But, marginally, we expect E[F|X] ~ F and E[Y|F,X] ~ E[Y|F].")
        # Pr(F_k = 1) compared to E_X[ I[F_k = 1] ]
        _ = plt.imshow(
            np.stack([np.concatenate(prior['f'], 0).mean(0), np.concatenate(posterior['f'], 0).mean(0)]), 
            interpolation='nearest',
        )
        _ = plt.colorbar()
        _ = plt.xlabel('k')
        _ = plt.yticks([0,1], ['F', 'E[F|X]'])
        _ = plt.title(r'Marginal probability that $F_k = 1$')
        plt.show()
        
        # Y_k compared to E_X[Y_k]
        
        _ = plt.imshow(
            np.stack([np.concatenate(prior['y'], 0).mean(0), np.concatenate(posterior['y'], 0).mean(0)]), 
            interpolation='nearest'
        )
        _ = plt.colorbar()
        _ = plt.xlabel('k')
        _ = plt.yticks([0, 1], ['E[Y|F]', 'E[Y|F,X]'])
        _ = plt.title(r'Average $Y_k$')
        plt.show()
        
        _ = plt.hist(
            np.concatenate(prior['dim'], 0), 
            alpha=0.3, label='dim(F)', bins=np.arange(0, 11))
        _ = plt.hist(
            np.concatenate(posterior['dim'], 0), 
            alpha=0.3, label='E[dim(F)|X]', bins=np.arange(0, 11))
        _ = plt.ylabel(f'Count')
        _ = plt.xlabel('dim(f)')
        _ = plt.xticks(np.arange(1,11), np.arange(1,11))
        _ = plt.title(f'Distribution of dim(f)')
        _ = plt.legend()
        plt.show()


In [None]:
compare_p_q(vae, get_batcher(valid_loader, args), args)

In [None]:
def visualize(vae, batcher: Batcher, args, N=4, num_figs=1): 

    assert N <= args.batch_size, "N should be no bigger than a batch"
    with torch.no_grad():
        vae.p.eval()        
        vae.q.eval()
            
        # Some visualisations
        for r, (x_obs, y_obs) in enumerate(batcher, 1):

            plt.figure(figsize=(2*N, 2*N))
            plt.subplots_adjust(wspace=0.5, hspace=0.5)        
        
            
            # [B, H*W]
            x_obs = x_obs.reshape(-1, args.height * args.width)
            x_obs = x_obs[:N]
            # [B, 10]
            context = None
            
            B, K, D = x_obs.shape[0], vae.p._latent_dim, vae.p._data_dim            
            # marginal probability
            prob = vae.estimate_ll_per_bit(x_obs, args.num_samples).exp()            
            # posterior samples
            f, y = vae.q.sample(x_obs)
            x = vae.p.X(y).sample()
            # prior samples
            f_, y_, x_ = vae.p.sample((N,))

            for i in range(N):
                plt.subplot(4, N, 0*N + i + 1)
                plt.imshow(x_obs[i].reshape(args.height, args.width).cpu(), cmap='Greys')
                plt.title("$x^{(%d)}$" % (i+1))

                plt.subplot(4, N, 1*N + i + 1)
                plt.imshow(prob[i].reshape(args.height, args.width).cpu(), cmap='Greys')
                plt.title("$p(x^{(%d)})$" % (i+1))
                
                plt.subplot(4, N, 2*N + i + 1)                
                #plt.axhline(y=args.height//2, c='red', linewidth=1, ls='--')
                plt.imshow(x[i].reshape(args.height, args.width).cpu(), cmap='Greys')
                plt.title("X,Y,F|$x^{(%d)}$" % (i+1))
                plt.xlabel(f'f={bitvec2str(f[i])}')
                
                plt.subplot(4, N, 3*N + i + 1)
                plt.imshow(x_[i].reshape(args.height, args.width).cpu(), cmap='Greys')
                plt.title("X,Y,F")
                plt.xlabel(f'f={bitvec2str(f_[i])}')
                
            plt.show()

            if r == num_figs:
                break

In [None]:
visualize(vae, get_batcher(valid_loader, args), args, N=5)

## Training

In [None]:
def validate(vae, batcher, num_samples):
    """
    Return average NLL and the average number of bits per dimension.
    """
    with torch.no_grad():
        vae.eval()
        
        nb_obs = 0
        nb_bits = 0.
        ll = 0.
        for x_obs, y_obs in batcher:
            # [B, H*W]
            x_obs = x_obs.reshape(-1, vae.p.data_dim)     
            # [B]
            ll = ll + vae.estimate_ll(x_obs, num_samples).sum(0)
            nb_bits += np.prod(x_obs.shape)
            nb_obs += x_obs.shape[0]

    nll = - (ll / nb_obs).cpu()
    return nll, nll / np.log(2) / vae.p.data_dim
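Bits per dimension normalises the negative log-likelihood by the number of observed dimensions $D = H \times W$ and converts nats to bits. A minimal sketch (hypothetical helper name), with a worked number: an NLL of 100 nats on a $28\times 28$ image is $100 / (784 \ln 2) \approx 0.184$ bits per pixel.

```python
import math


def bits_per_dim(nll_nats, num_dims):
    """Convert an average per-example NLL in nats into bits per observed dimension."""
    return nll_nats / (num_dims * math.log(2))


print(bits_per_dim(100.0, 28 * 28))  # about 0.184
```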

In [None]:
from tqdm.auto import tqdm
from itertools import chain
from collections import defaultdict

### Optimiser

In [None]:
p_opt = torch.optim.Adam(p.parameters(), lr=args.gen_lr, weight_decay=args.gen_l2)
q_opt = torch.optim.Adam(q.parameters(), lr=args.inf_lr, weight_decay=args.inf_l2)
stats_tr = defaultdict(list)
stats_val = defaultdict(list)

### Steps

In [None]:
val_metrics = validate(vae, get_batcher(valid_loader, args), args.num_samples)
print(f'Validation {0:3d}: nll={val_metrics[0]:.2f} bpd={val_metrics[1]:.2f}')

In [None]:
for epoch in range(args.epochs):

    iterator = tqdm(get_batcher(train_loader, args))

    for x_obs, y_obs in iterator:        
        # [B, H*W]
        x_obs = x_obs.reshape(-1, args.height * args.width)
        # [B, 10]
        context = None   
        
        vae.train()      
        loss, ret = vae.loss(x_obs)

        for k, v in ret.items():
            stats_tr[k].append(v)
                
        p_opt.zero_grad()
        q_opt.zero_grad()        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(
            chain(vae.gen_parameters(), vae.inf_parameters()), 
            args.grad_clip
        )        
        p_opt.step()
        q_opt.step()
    
        iterator.set_description(f'Epoch {epoch+1:3d}')
        iterator.set_postfix(ret)
    
    val_metrics = validate(vae, get_batcher(valid_loader, args), args.num_samples)
    stats_val['val_nll'].append(val_metrics[0])
    stats_val['val_bpd'].append(val_metrics[1])
    print(f'Validation {epoch+1:3d}: nll={val_metrics[0]:.2f} bpd={val_metrics[1]:.2f}')

### Training Curves

In [None]:
np_stats_tr = {k: np.array(v) for k, v in stats_tr.items()}
np_stats_val = {k: np.array(v) for k, v in stats_val.items()}

In [None]:
def smooth(v, kernel_size=100):
    if kernel_size is None:
        return v
    return np.convolve(v, np.ones(kernel_size)/kernel_size, 'valid')

In [None]:
for k, v in np_stats_tr.items():
    v = smooth(v)
    plt.plot(np.arange(1, v.size + 1), v, '.')    
    plt.ylabel(f'Training {k}')
    plt.xlabel('iteration')
    plt.show()

### Validation Curves

In [None]:
for k, v in np_stats_val.items():    
    plt.plot(np.arange(1, v.size + 1), v, 'o')
    plt.ylabel(f'Validation {k}')
    plt.xlabel('epoch')
    plt.show()

### Analysis

In [None]:
compare_p_q(vae, get_batcher(valid_loader, args), args)

In [None]:
visualize(vae, get_batcher(valid_loader, args), args, N=5, num_figs=2)