Implementation of REBAR (https://arxiv.org/abs/1703.07370), a low-variance, unbiased gradient estimator for discrete latent variable models. This notebook is focused on the generative modeling experiments on the MNIST and Omniglot datasets from Section 5.2.1.

The problem being solved is $\max_\theta \, \mathbb{E}_{p(b)} [f(b, \theta)]$, where $b \sim \text{Bernoulli}(\theta)$.

For generative modeling, the objective is to maximize a single-sample variational lower bound on the log-likelihood. There are two networks, one to model $q(b|x,\theta)$ and one to model $p(x,b|\theta)$. The former is the variational distribution and the latter is the joint probability distribution over the data and latent stochastic variables $b$.

The **ELBO** (evidence lower bound), which we seek to maximize, is:

$$
\log p(x \vert \theta) \geq \mathbb{E}_{q(b \vert x,\theta)} [ \log p(x,b\vert\theta) - \log q(b \vert x,\theta)]
$$

In practice, the Q-network has its own set of parameters $\phi$ and the generator network $P$ has its own parameters $\theta$.

I'll refer to the learning signal $\log p(x,b\vert\theta) - \log q(b \vert x,\theta)$ as $l(x,b)$ for shorthand.
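
To make the bound concrete, here is a minimal sketch that enumerates a toy model with a single binary latent $b$ and a single binary observation $x$. All of the numbers (prior, likelihood table, variational distribution) are made up for illustration and are not part of the SBN below. It checks that the expected learning signal $\mathbb{E}_q[l(x,b)]$ stays below $\log p(x)$ and that the bound becomes tight when $q$ equals the true posterior.

```python
import math

# Hypothetical toy model: one binary latent b, one binary observation x.
p_b1 = 0.3                       # prior p(b=1)
p_x1_given_b = {0: 0.2, 1: 0.9}  # likelihood p(x=1 | b)

def log_joint(x, b):
    """log p(x, b) = log p(b) + log p(x | b)."""
    log_prior = math.log(p_b1 if b == 1 else 1.0 - p_b1)
    px1 = p_x1_given_b[b]
    log_lik = math.log(px1 if x == 1 else 1.0 - px1)
    return log_prior + log_lik

def elbo(x, q_b1):
    """Exact E_q[log p(x,b) - log q(b|x)], enumerating b in {0, 1}."""
    total = 0.0
    for b, q in ((0, 1.0 - q_b1), (1, q_b1)):
        total += q * (log_joint(x, b) - math.log(q))
    return total

x = 1
log_px = math.log(sum(math.exp(log_joint(x, b)) for b in (0, 1)))
posterior_b1 = math.exp(log_joint(x, 1) - log_px)  # true p(b=1 | x)

loose = elbo(x, q_b1=0.5)           # an arbitrary variational distribution
tight = elbo(x, q_b1=posterior_b1)  # q equal to the true posterior
print(log_px, loose, tight)
```

The gap between `log_px` and `loose` is the KL divergence from $q$ to the true posterior, which is what training the Q-network shrinks.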

The following is an implementation of a Sigmoid Belief Network (SBN) with REBAR gradient updates. I tried to follow the [authors' TensorFlow implementation](https://github.com/tensorflow/models/blob/master/research/rebar/rebar.py) closely for correctness; there is a lot of computational statistics going on that needs to be implemented carefully.

For an in-depth treatment of SBNs, see [this paper](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.63.1777&rep=rep1&type=pdf) by R. Neal.

We're just going to focus on the nonlinear SBN REBAR model.
The model is pretty complex, so I'll implement it as separate modules and try to explain them
one by one.

In [16]:
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.autograd import grad
import matplotlib.pyplot as plt
import numpy as np

import rebar.datasets

import rebar.util as U

%matplotlib inline

In [17]:
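# Some global parameters we'll need later.
# SimpleNamespace gives attribute access (e.g. hparams.n_hidden),
# which is how the modules below read these values.
from types import SimpleNamespace

hparams = SimpleNamespace(
    model='SBNGumbel',
    learning_rate=3e-4,
    n_hidden=200,
    n_input=784,
    temperature=0.5,
    batch_size=24,
    task='sbn',
    n_layers=2,
)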

We'll define samplers for producing the "hard" and "soft" reparameterized samples needed for computing the REBAR gradient.

In [18]:
def random_sample(log_alpha, u, layer, uniform_samples_v, temperature=None):
    """Returns hard (0/1) samples b = H(z) parameterized by log_alpha.

    temperature is unused; it is accepted so all samplers share one signature.
    """
    # Generate tied randomness for later
    if layer not in uniform_samples_v:
        uniform_samples_v[layer] = u_to_v(log_alpha, u)

    # Sample random variable underlying softmax/argmax
    x = log_alpha + U.safe_log_prob(u) - U.safe_log_prob(1 - u)
    samples = (x > 0).float().detach()

    return {
        'preactivation': x,
        'activation': samples,
        'log_param': log_alpha,
    }, uniform_samples_v

def random_sample_soft(log_alpha, u, layer, uniform_samples_v, temperature=None):
    """Returns relaxed samples sigmoid(z / temperature) parameterized by log_alpha."""

    # Sample random variable underlying softmax/argmax
    x = log_alpha + U.safe_log_prob(u) - U.safe_log_prob(1 - u)
    x = x / temperature.view(-1)
    y = torch.sigmoid(x)

    return {
        'preactivation': x,
        'activation': y,
        'log_param': log_alpha
    }, uniform_samples_v

def random_sample_soft_v(log_alpha, _, layer, uniform_samples_v, temperature=None):
    """Returns relaxed samples driven by the tied noise v instead of u."""
    v = uniform_samples_v[layer]
    return random_sample_soft(log_alpha, v, layer, uniform_samples_v, temperature)
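
As a scalar illustration of the hard and soft samplers, the sketch below uses plain Python floats instead of tensors. The helper names (`safe_log`, `logistic_sample`, `hard`, `soft`) are made up for this example, and the clipped log only mimics what I assume `U.safe_log_prob` does. It shows that thresholding $z$ at zero is the same as comparing $u$ against $\sigma(-\log\alpha)$, and that the relaxed sample approaches the hard one as the temperature shrinks.

```python
import math

def safe_log(p, eps=1e-8):
    """Clipped log; an assumed stand-in for U.safe_log_prob."""
    return math.log(min(max(p, eps), 1.0))

def sigmoid(x):
    """Logistic function, written to avoid overflow in exp for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def logistic_sample(log_alpha, u):
    """z = log_alpha + log(u) - log(1 - u), the variable underlying the Bernoulli."""
    return log_alpha + safe_log(u) - safe_log(1.0 - u)

def hard(z):
    """Heaviside step: the discrete sample b = H(z)."""
    return 1.0 if z > 0 else 0.0

def soft(z, temperature):
    """Relaxed sample sigmoid(z / temperature)."""
    return sigmoid(z / temperature)

log_alpha = 0.8                  # logit of theta, arbitrary value
threshold = sigmoid(-log_alpha)  # b = 1 exactly when u > threshold
for u in (0.05, 0.25, 0.5, 0.9):
    z = logistic_sample(log_alpha, u)
    print(u, hard(z), round(soft(z, 0.5), 3), round(soft(z, 0.01), 3))
```

Lower temperatures make the relaxation closer to the discrete sample but also make its gradient spikier, which is the trade-off REBAR's control variate is designed around.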

This next bit, for producing common random numbers, is for variance reduction. [The general idea is easy](https://en.wikipedia.org/wiki/Variance_reduction), but what the authors were doing here is a bit more subtle. According to Appendix G.2, they're correlating u and v to reduce the variance of the gradient by first sampling u and then using that to determine v.

In [12]:
# Random samplers
def u_to_v(log_alpha, u, eps=1e-8):
    """Convert u to tied randomness in v."""
    u_prime = torch.sigmoid(-log_alpha)  # g(u') = 0

    v_1 = (u - u_prime) / torch.clamp(1 - u_prime, eps, 1)
    v_1 = torch.clamp(v_1, 0, 1).detach()
    v_1 = v_1 * (1 - u_prime) + u_prime
    v_0 = u / torch.clamp(u_prime, eps, 1)
    v_0 = torch.clamp(v_0, 0, 1).detach()
    v_0 = v_0 * u_prime
    # Elementwise select; a Python `if` cannot branch on a batched tensor
    v = torch.where(u > u_prime, v_1, v_0)
    # PyTorch stand-in for tf.check_numerics
    assert torch.isfinite(v).all(), 'v sampling is not numerically stable.'
    v = v + (-v + u).detach()  # v and u are the same up to numerical errors
    return v
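
To see what the tied randomness does, the scalar sketch below mirrors the arithmetic of `u_to_v` without the gradient bookkeeping. The function names are made up for this example. Given $u$, it works out the hard sample $b$, rescales $u$ to a value $v$ in $[0, 1]$ within the interval that produced $b$, and then maps $v$ back. The round trip recovers $u$, which is why the soft sample built from $v$ has the same value as the one built from $u$; in the tensor version only the gradient path through $\theta$ differs.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def to_v(log_alpha, u):
    """Rescale u to v in [0, 1] within the interval that fixes b = H(z)."""
    u_prime = sigmoid(-log_alpha)   # the u at which z = 0
    if u > u_prime:                 # b = 1: u lies in (u', 1)
        return (u - u_prime) / (1.0 - u_prime), 1
    return u / u_prime, 0           # b = 0: u lies in (0, u']

def from_v(log_alpha, v, b):
    """Map v back to the u-scale, conditioned on the hard sample b."""
    u_prime = sigmoid(-log_alpha)
    return v * (1.0 - u_prime) + u_prime if b == 1 else v * u_prime

def z_of(log_alpha, u):
    """z = log_alpha + log(u / (1 - u))."""
    return log_alpha + math.log(u) - math.log(1.0 - u)

log_alpha = -0.4  # arbitrary logit
for u in (0.1, 0.45, 0.7, 0.95):
    v, b = to_v(log_alpha, u)
    u_back = from_v(log_alpha, v, b)
    z, z_tilde = z_of(log_alpha, u), z_of(log_alpha, u_back)
    print(u, b, round(v, 4), round(z, 4), round(z_tilde, 4))
```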

This is the deterministic mapping we'll use to construct the stochastic layers of the Q- and P-networks.

In [13]:
class Transformation(nn.Module):
    """
    Deterministic transformation between stochastic layers
    
        x -> FC -> Tanh -> FC -> Tanh -> FC -> logits
    
    """
    def __init__(self, n_input, n_hidden, n_output):
        super(Transformation, self).__init__()
        self.h = nn.Sequential(
            nn.Linear(n_input, n_hidden),
            nn.Tanh(),
            nn.Linear(n_hidden, n_hidden),
            nn.Tanh(),
            nn.Linear(n_hidden, n_output))
    
    def forward(self, x):
        return self.h(x)

The RecognitionNet is the variational distribution (Q-network) and the GeneratorNet is the joint distribution of the data and latent variables (P-network). It looks like this for an unrolled 2-layer SBN, where Sample is the stochastic layer of Bernoulli units:

// Replace with figure?

x -> Transformation(x) -> Sample(x) -> Transformation(x) -> Sample(x)

In [7]:
class RecognitionNet(nn.Module):
    """
    given x values, samples from Q and returns log Q(h|x)
    """
    def __init__(self, mean_xs, sampler):
        super(RecognitionNet, self).__init__()
        # mean_xs is assumed to be array-like with shape [n_input]
        self.mean_xs = torch.as_tensor(mean_xs, dtype=torch.float32)
        self.sampler = sampler
        # The first layer reads the data; deeper layers read the previous stochastic layer
        self.transforms = nn.ModuleList([
            Transformation(hparams.n_input if i == 0 else hparams.n_hidden,
                           hparams.n_hidden, hparams.n_hidden)
            for i in range(hparams.n_layers)])
        self.uniform_samples = dict()
        self.uniform_samples_v = dict()
        # generate randomness
        # NOTE: drawn once at construction, so every forward pass reuses the same noise
        for i in range(hparams.n_layers):
            self.uniform_samples[i] = torch.rand(hparams.batch_size, hparams.n_hidden)

    def forward(self, x, sampler_=None):
        sampler = sampler_ if sampler_ is not None else self.sampler
        samples = {}
        # Center the input, then rescale it like the binary activations
        samples[-1] = {'activation': (x - self.mean_xs + 1) / 2.}
        logQ = []
        for i, t in enumerate(self.transforms):
            input = 2 * samples[i-1]['activation'] - 1.0
            logits = t(input)
            # expect sampler to return a dictionary with key 'activation'
            samples[i], self.uniform_samples_v = sampler(logits, self.uniform_samples[i],
                                                         i, self.uniform_samples_v)
            # log q(h_i | h_{i-1}) of the sample just drawn at this layer
            logQ.append(U.binary_log_likelihood(samples[i]['activation'], logits))
        # logQHard, samples
        return logQ, samples

class GeneratorNet(nn.Module):
    """
    Returns learning signal and function. Reconstructs the input.
    """
    def __init__(self, mean_xs):
        super(GeneratorNet, self).__init__()
        self.transforms = nn.ModuleList()
        for i in range(hparams.n_layers):
            if i == 0:
                n_output = hparams.n_input
            else:
                n_output = hparams.n_hidden
            # Every generator layer reads a stochastic layer of size n_hidden
            self.transforms.append(Transformation(hparams.n_hidden,
                                                 hparams.n_hidden, n_output))
        self.prior = torch.zeros(hparams.n_hidden)
        # Output bias: logit of the clipped mean pixel intensities (mean_xs assumed array-like)
        self.train_bias = torch.as_tensor(
            -np.log(1./np.clip(mean_xs, 0.001, 0.999)-1.).astype(np.float32))

    def forward(self, x, samples, logQ):
        """
        Args:
            x: batch of observed inputs
            samples: dictionary of sampled latent variables
            logQ: list of log q(h_i) terms
        """
        logPPrior = U.binary_log_likelihood(
            samples[hparams.n_layers-1]['activation'], self.prior)
        for i in reversed(range(hparams.n_layers)):
            # Set up the input to the layer
            input = 2 * samples[i]['activation'] - 1.0
            h = self.transforms[i](input)
            if i == 0:
                logP = U.binary_log_likelihood(x, h + self.train_bias)
            else:
                logPPrior = logPPrior + U.binary_log_likelihood(samples[i-1]['activation'], h)
        # Note that logP(x,b) = logP(x|b) + logP(b)
        # reinforce_learning_signal (l(x,b)), reinforce_model_grad
        return logP + logPPrior - sum(logQ), logP + logPPrior
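
`U.binary_log_likelihood` lives in the helper module and is not shown in this notebook. Assuming it mirrors the helper of the same name in the authors' TensorFlow code, it takes binary targets and logits and returns the Bernoulli log-likelihood, written with softplus for numerical stability. Here is a scalar sketch of that assumed formula (the tensor version presumably also sums over the unit dimension, which is omitted here):

```python
import math

def softplus(x):
    """log(1 + exp(x)), computed without overflow for large |x|."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def binary_log_likelihood(y, logit):
    """log Bernoulli(y; sigmoid(logit)) for a single unit (assumed form)."""
    log_p1 = -softplus(-logit)          # log sigmoid(logit)
    log_p0 = -logit - softplus(-logit)  # log (1 - sigmoid(logit))
    return y * log_p1 + (1.0 - y) * log_p0

theta = 0.7
logit = math.log(theta / (1.0 - theta))
print(binary_log_likelihood(1.0, logit), math.log(theta))
print(binary_log_likelihood(0.0, logit), math.log(1.0 - theta))
```

Working in logit space like this avoids taking the log of a probability that has saturated to 0 or 1.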

Now we can put these modules together inside the SBNRebar module.

In [20]:
class SBNRebar(nn.Module):
    def __init__(self, mean_xs):
        super(SBNRebar, self).__init__()
        self.mean_xs = mean_xs
        # For centering the learning signal, from the NVIL paper (2.3.1) https://arxiv.org/pdf/1402.0030.pdf
        # Input dependent baseline that is trained to minimize the MSE with the learning signal
        self._baseline = nn.Sequential(
           nn.Linear(hparams.n_input, 100),
           nn.Tanh(),
           nn.Linear(100, 1))
        self._temperature = torch.tensor([hparams.temperature])
        self._recognition_network = RecognitionNet(mean_xs, random_sample)
        self._generator_network = GeneratorNet(mean_xs)

    def forward(self, x):
        """
        All of the passes through the Q- and P-networks are here
        """
        # hardELBO is the non-differentiable learning signal, l(x,b)
        #
        # reinforce_model_grad is the log joint of interest, log p(x,b|\theta),
        #   and the gradient of l(x,b) wrt the P-network params is grad E[logP + logPPrior]
        #   = grad E[reinforce_model_grad]
        logQHard, hardSamples = self._recognition_network(x)
        hardELBO, reinforce_model_grad = self._generator_network(x, hardSamples, logQHard)
        logQHard = sum(logQHard)
        # [batch, 1] -> [batch], to line up with the per-example learning signal
        baseline = self._baseline(x).squeeze(-1)

        # compute Gumbel control variate
        logQ, softSamples = self._recognition_network(x, functools.partial(
            random_sample_soft, temperature=self._temperature))
        softELBO, _ = self._generator_network(x, softSamples, logQ)
        logQ = sum(logQ)

        # compute softELBO_v (same value as softELBO, different grads) :- zsquiggle = g(v, b, \theta)
        # NOTE: !!!Super Tricky!!! Because of the common random numbers (u_to_v), z is distributed as z|b.
        # So the reparameterization for p(z|b) is just
        #   g(v,b,\theta) == g(v,\theta) == log(\theta/(1-\theta)) + log(v/(1-v))
        # This is why random_sample_soft_v() just calls random_sample_soft()
        logQ_v, softSamples_v = self._recognition_network(x, functools.partial(
            random_sample_soft_v, temperature=self._temperature))
        softELBO_v, _ = self._generator_network(x, softSamples_v, logQ_v)
        logQ_v = sum(logQ_v)

        gumbel_cv_learning_signal = softELBO_v.detach()

        # Gumbel CV
        gumbel_cv = gumbel_cv_learning_signal * logQHard - softELBO + softELBO_v

        return {
            'logQHard': logQHard,
            'hardELBO': hardELBO,
            'reinforce_model_grad': reinforce_model_grad,
            'baseline': baseline,
            'gumbel_cv': gumbel_cv
        }

Next, we will compute the REBAR gradients using PyTorch's autograd functionality.


In [None]:
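def get_rebar_gradient(sbn_outs):
    # NVIL-style surrogate: baseline-centered REINFORCE term plus the model term.
    # The Gumbel control variate (sbn_outs['gumbel_cv']) is not combined in here yet.
    nvil_gradient = ((sbn_outs['hardELBO'].detach() - sbn_outs['baseline']) * sbn_outs['logQHard']
                     + sbn_outs['reinforce_model_grad'])
    return nvil_gradient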
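
Subtracting the baseline is safe because the score function has zero mean: any term that does not depend on $b$ drops out in expectation, while the variance can shrink considerably. A minimal sketch for a single Bernoulli variable, enumerating both outcomes exactly (the function $f$, the value of $\theta$ and the constant baseline are arbitrary illustrative choices, not values from the SBN):

```python
def score(b, theta):
    """d/dtheta log Bernoulli(b; theta)."""
    return b / theta - (1 - b) / (1 - theta)

def reinforce_moments(f, theta, baseline=0.0):
    """Exact mean and variance of (f(b) - baseline) * score(b) over b in {0, 1}."""
    probs = {0: 1 - theta, 1: theta}
    terms = {b: (f(b) - baseline) * score(b, theta) for b in (0, 1)}
    mean = sum(probs[b] * terms[b] for b in (0, 1))
    var = sum(probs[b] * (terms[b] - mean) ** 2 for b in (0, 1))
    return mean, var

f = lambda b: (b - 0.45) ** 2  # a simple quadratic loss
theta = 0.3

mean_plain, var_plain = reinforce_moments(f, theta)
mean_base, var_base = reinforce_moments(f, theta, baseline=0.25)
true_grad = f(1) - f(0)        # d/dtheta E[f(b)] for a Bernoulli
print(mean_plain, mean_base, true_grad)
print(var_plain, var_base)
```

Both estimators have the same mean, equal to the true gradient, but the baselined one has a much smaller variance. NVIL learns an input-dependent baseline instead of the constant used here.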

One of the main concepts to understand is that the Concrete relaxation is applied to the discrete RV $b \sim$ Bernoulli($\theta$), s.t. $b = H(z)$, where $H$ is the Heaviside function and $z$ is the logit of $\theta$ perturbed by Logistic noise (the binary counterpart of the Gumbel-max trick).
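
Written out for a single Bernoulli variable, with $u, v \sim \text{Uniform}(0,1)$, $\sigma_\lambda(z) = \sigma(z/\lambda)$ and scaling factor $\eta$, these are the pieces that the training loop further down combines. This is my transcription of the single-variable estimator from Section 3 of the paper; the form of $\tilde{g}$ shown is one valid parameterization of $z \mid b$, and a helper such as `reparam_pz_b` may use an equivalent one.

$$
z = g(u, \theta) = \log\frac{\theta}{1-\theta} + \log\frac{u}{1-u}, \qquad b = H(z)
$$

$$
\tilde{z} = \tilde{g}(v, b, \theta) =
\begin{cases}
\log\left(\dfrac{v}{1-v}\,\dfrac{1}{1-\theta} + 1\right) & \text{if } b = 1 \\
-\log\left(\dfrac{v}{1-v}\,\dfrac{1}{\theta} + 1\right) & \text{if } b = 0
\end{cases}
$$

$$
\hat{g}_{\text{REBAR}} = \left[ f(H(z)) - \eta f(\sigma_\lambda(\tilde{z})) \right] \frac{\partial}{\partial\theta} \log p(b) \Big\vert_{b = H(z)}
+ \eta \frac{\partial}{\partial\theta} f(\sigma_\lambda(z))
- \eta \frac{\partial}{\partial\theta} f(\sigma_\lambda(\tilde{z}))
$$

The first term is REINFORCE with the relaxed objective as a control variate; the last two terms are reparameterization gradients that cancel, in expectation, the bias the control variate would otherwise introduce.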

Now, we can initialize a few things.

In [None]:
random_seed = 1337
torch.manual_seed(random_seed)
# hyperparams
rebar_eta_z = 0.1
rebar_eta_zb = 0.1
rebar_lamda=0.5
concrete_lamda = 0.5
batch_size = 128
train_steps = 8000

# Initialize three models to compare the REINFORCE, Concrete(0.5), and REBAR estimators
reinforce = SimpleBernoulli()
concrete = SimpleBernoulli()
rebar = SimpleBernoulli()
reinforce_opt = optim.Adam(reinforce.parameters(), lr=1e-3)
concrete_opt = optim.Adam(concrete.parameters(), lr=1e-3)
rebar_opt = optim.Adam(rebar.parameters(), lr=1e-3)
mse = nn.MSELoss()
# labels
targets = Variable(torch.FloatTensor([0.45]).repeat(batch_size), requires_grad=False)

Now for the main training loop, where most of the REBAR magic happens:

In [None]:
reinforce_loss = []
concrete_loss = []
rebar_loss = []
for i in range(train_steps):
    # For each iteration of the loop, we will compute a
    # single-sample MC estimate of the gradient
    # Get the latest estimate of $\theta$ and copy it to form a minibatch
    reinforce_theta = reinforce.forward().repeat(batch_size)
    concrete_theta = concrete.forward().repeat(batch_size)
    rebar_theta = rebar.forward().repeat(batch_size)
    
    # sample batch_size pairs of Unif(0,1). You're supposed to couple u,v
    # to do the reparameterizations, but we omit that for this toy problem
    uv = Variable(torch.FloatTensor(2, batch_size).uniform_(0, 1), requires_grad=False)
    u = uv[0] + 1e-9 # for numerical stability
    v = uv[1] + 1e-9 # for numerical stability
    
    ########## First, we'll compute the REINFORCE estimator ##########
    
    # Let's record where the loss currently stands
    discrete_reinforce_preds = torch.bernoulli(reinforce_theta.detach())
    reinforce_loss.append(mse(discrete_reinforce_preds, targets).data.numpy())
    
    # Now, the REINFORCE estimator (Eq. 2 of the paper, beg. of Section 3)
    reinforce_z = reparam_pz(u, reinforce_theta)
    reinforce_Hz = H(reinforce_z) # this is the non-differentiable reparameterization
    # evaluate f
    reinforce_f_Hz = (reinforce_Hz - targets) ** 2
    # This is  d_log_P(b) / d_$\theta$
    grad_logP = grad(binary_log_likelihood(reinforce_Hz, \
                     torch.log(reinforce_theta)).split(1), reinforce_theta)[0]
    
    # Apply the Monte Carlo REINFORCE gradient estimator
    reinforce_grad_est = (reinforce_f_Hz * grad_logP).mean()
    reinforce_opt.zero_grad()
    reinforce.theta.grad = reinforce_grad_est
    reinforce_opt.step()
    
    ########## Next up, the Concrete(0.5) estimator ##########
    
    discrete_concrete_preds = torch.bernoulli(concrete_theta.detach())
    concrete_loss.append(mse(discrete_concrete_preds, targets).data.numpy())

    # Now, the Concrete(0.5) estimator. We compute the continuous relaxation of
    # the reparameterization and differentiate through it (end of Section 2 of the paper)
    concrete_z = reparam_pz(u, concrete_theta)
    soft_concrete_z = F.sigmoid(concrete_z / concrete_lamda) + 1e-9
    # evaluate f
    f_soft_concrete_z = (soft_concrete_z - targets) ** 2
    grad_f = grad(f_soft_concrete_z.split(1), concrete_theta)[0]

    # Apply the Monte Carlo Concrete gradient estimator
    concrete_grad_est = grad_f.mean()
    concrete_opt.zero_grad()
    concrete.theta.grad = concrete_grad_est
    concrete_opt.step()
    
    ########## Finally, we tie it all together with REBAR ##########
    
    discrete_rebar_preds = torch.bernoulli(rebar_theta.detach())
    rebar_loss.append(mse(discrete_rebar_preds, targets).data.numpy())

    # We compute the continuous relaxation of the reparameterization 
    # as well as the REINFORCE estimator and combine them.

    rebar_z = reparam_pz(u, rebar_theta)
    # "hard" bc this is non-differentiable
    hard_concrete_rebar_z = H(rebar_z)
    # We also need to compute the reparam for p(z|b) - see the paper
    # for explanation of this conditional marginalization as control variate
    rebar_zb = reparam_pz_b(v, hard_concrete_rebar_z, rebar_theta)
    # "soft" relaxations
    soft_concrete_rebar_z = F.sigmoid(rebar_z / rebar_lamda) + 1e-9 
    soft_concrete_rebar_zb = F.sigmoid(rebar_zb / rebar_lamda) + 1e-9
    # evaluate f
    f_hard_concrete_rebar_z = (hard_concrete_rebar_z - targets) ** 2
    f_soft_concrete_rebar_z = (soft_concrete_rebar_z - targets) ** 2
    f_soft_concrete_rebar_zb = (soft_concrete_rebar_zb - targets) ** 2
    # compute the necessary derivatives
    grad_logP = grad(binary_log_likelihood(hard_concrete_rebar_z, \
                     torch.log(rebar_theta)).split(1), rebar_theta, retain_graph=True)[0]
    grad_sc_z = grad(f_soft_concrete_rebar_z.split(1), rebar_theta, retain_graph=True)[0]
    grad_sc_zb = grad(f_soft_concrete_rebar_zb.split(1), rebar_theta)[0]
    
    # Notice how we combine the REINFORCE and concrete estimators
    rebar_grad_est = (((f_hard_concrete_rebar_z - rebar_eta_zb * f_soft_concrete_rebar_zb) \
                       * grad_logP) + rebar_eta_zb * grad_sc_z - rebar_eta_zb * grad_sc_zb).mean()
    
    # Apply the Monte Carlo REBAR gradient estimator
    rebar_opt.zero_grad()
    rebar.theta.grad = rebar_grad_est
    rebar_opt.step()
    
    if (i+1) % 1000 == 0:
        print("step: {}".format(i+1))
        print("reinforce_loss {}".format(reinforce_loss[-1]))
        print("concrete(0.5)_loss {}".format(concrete_loss[-1]))
        print("rebar_loss {}\n".format(rebar_loss[-1]))

We can plot the loss per train step to see if we can replicate the results from the paper.

In [None]:
# @hidden_cell
fig = plt.figure(figsize=(12, 9))
plt.plot(reinforce_loss, 'm', label="REINFORCE", alpha=0.7)
plt.plot(concrete_loss, 'r', label="Concrete(0.5)", alpha=0.7)
plt.plot(rebar_loss, 'b', label="REBAR", alpha=0.7)
plt.title("Optimal loss is 0.2025")
plt.xlabel("train_steps")
plt.ylabel("loss")
plt.ylim(0.2, 0.32)
plt.grid(True)
plt.legend()
plt.show()
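
The value in the plot title follows directly from the toy objective. With target $t = 0.45$ and $b \sim \text{Bernoulli}(\theta)$, the expected loss is $\mathbb{E}[(b - t)^2] = \theta (1-t)^2 + (1-\theta) t^2 = t^2 + \theta (1 - 2t)$. For $t < 0.5$ this is linear and increasing in $\theta$, so the minimum $t^2 = 0.2025$ is attained at $\theta = 0$. A quick numerical check (the function name is made up for this example):

```python
def expected_loss(theta, target=0.45):
    """Exact E[(b - target)^2] for b ~ Bernoulli(theta)."""
    return theta * (1.0 - target) ** 2 + (1.0 - theta) * target ** 2

grid = [i / 100 for i in range(101)]  # theta in {0.00, 0.01, ..., 1.00}
best_theta = min(grid, key=expected_loss)
print(best_theta, expected_loss(best_theta))
```

The worst case, $\theta = 1$, gives $0.55^2 = 0.3025$, which is why the plot's y-axis range of 0.2 to 0.32 covers every reachable expected loss.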

## Some final thoughts

The variance of the loss in the above plot appears to be significantly greater than in the plots from the paper. Unfortunately, the hyperparameters used for the toy problem were not revealed. My plot was generated with a batch size of 128. The variance increases a lot with smaller batch sizes.

It was mentioned in the paper that the scaling factor, $\eta$, can be computed by

$$\frac{\text{Cov}(f,g)}{\text{Var}(g)}.$$

I tried this, but a value of $0.1$ performed better. I may not have been computing $\eta$ correctly, though.
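
For reference, this ratio is the usual variance-minimizing coefficient for a control variate: $\text{Var}(f - \eta g)$ is a quadratic in $\eta$ whose minimum sits at $\text{Cov}(f,g)/\text{Var}(g)$. A minimal sketch of estimating it from paired samples, using synthetic data rather than the estimator terms from the training loop (the function names and the data are made up for illustration):

```python
import random

def optimal_eta(f_samples, g_samples):
    """Sample estimate of Cov(f, g) / Var(g)."""
    n = len(f_samples)
    f_mean = sum(f_samples) / n
    g_mean = sum(g_samples) / n
    cov = sum((f - f_mean) * (g - g_mean)
              for f, g in zip(f_samples, g_samples)) / n
    var_g = sum((g - g_mean) ** 2 for g in g_samples) / n
    return cov / var_g

def variance(xs):
    m = sum(xs) / len(xs)
    return sum((x - m) ** 2 for x in xs) / len(xs)

random.seed(0)
g = [random.gauss(0.0, 1.0) for _ in range(2000)]
# Synthetic f: correlated with g (true slope 2) plus independent noise
f = [2.0 * gi + random.gauss(0.0, 0.5) for gi in g]

eta = optimal_eta(f, g)
residual = [fi - eta * gi for fi, gi in zip(f, g)]
print(eta, variance(f), variance(residual))
```

In REBAR the quantities being scaled are gradient estimates rather than scalars like these, so this only illustrates the formula, not how $\eta$ should be tuned in the loop above.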