In [1]:
import torch
import torch.nn
import math

def logdet(L):
    """
    Calculate the log determinant of L L^T.
    """
    return (L.diag() ** 2).log().sum()

## Create dataset

In [2]:
# Create a very simple dataset
torch.manual_seed(0)
d = 5
Phi = torch.randn(d, N)
coeff = torch.ones(d)
ts = torch.where((coeff @ Phi > 0.), torch.tensor(1.), torch.tensor(0.))
print(ts)

# For test
Ntest = 50
Phi_test = torch.randn(d, Ntest)
ts_test = torch.where((coeff @ Phi_test > 0.), torch.tensor(1.), torch.tensor(0.))
print(ts_test)

tensor([0., 0., 1., 0., 1., 1., 0., 0., 1., 0., 1., 1., 0., 1., 1., 1., 0., 0.,
        0., 1., 1., 0., 0., 1., 0., 1., 0., 1., 1., 1., 1., 0., 0., 1., 0., 1.,
        1., 1., 1., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1.,
        1., 1., 0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1.,
        0., 1., 0., 0., 0., 1., 1., 0., 0., 1., 0., 1., 0., 1., 1., 0., 1., 1.,
        0., 0., 0., 0., 1., 1., 1., 0., 0., 1., 0., 0., 1., 0., 0., 1., 1., 0.,
        1., 0., 1., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 1., 1., 1., 1., 1.,
        1., 1., 0., 1., 1., 0., 1., 0., 1., 0., 0., 1., 1., 0., 1., 0., 1., 1.,
        1., 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 1., 1., 0.,
        1., 0., 0., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 0., 0., 1., 0., 0.,
        1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 1., 1., 1., 1., 0., 1.,
        1., 1.])
tensor([0., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 1., 0., 0., 1., 1., 0., 1.,
        0., 1., 0., 0.,

## Describe the model and calculate the ELBO

Model Description: $p(w)=\mathcal{N}(0, \frac{1}{\alpha} I), y_i = \sigma(w^T \phi(x_i))$ and observation $t_i\sim \text{Bernoulli}(y_i)=y_i^{t_i} (1-y_i)^{1-t_i}$.

Using $q(w)=\mathcal{N}(m, LL^T)$ as our variational distribution, where $L$ is a lower triangular matrix, we can calculate the ELBO as below:
\begin{align*}
\text{ELBO}&=E_{q(w)}[\ln p(w) - \ln q(w) + \ln p(t|w)] \\
&=E_{q(w)}[(\frac{d}{2} \ln \alpha - \frac{\alpha}{2} w^T w) + (\frac{1}{2} \ln \det(L L^T) + \frac{1}{2} (w-m)^T L^{-T} L^{-1} (w-m)) + (\sum_i t_i \ln y_i + (1-t_i)\ln(1-y_i))].
\end{align*}

We use the function below to evaluate the quality of $q(w)$. It returns $\log E_{q(w)}[p(t|w)]$ and the misclassification rate.

In [3]:
def evaluate(m, L, Phi_test, ts_test):
    mu = m @ Phi_test
    var = torch.sum((L@Phi_test)**2, 0)
    kappa = (1 + math.pi * var / 8) ** (-0.5)
    
    loss_func = torch.nn.BCEWithLogitsLoss()
    log_loss = loss_func(kappa * mu, ts_test)
    
    prob = torch.sigmoid(kappa * mu)
    prediction = (prob > 0.5).float()
    misclassification = torch.sum((prediction != ts_test).int()).item() / ts_test.numel()
    return log_loss, misclassification

## Reparametrization trick

Let $w=m+L \epsilon, \epsilon \sim \mathcal{N}(0, I)$. "triangular_solve" efficiently computes $L^{-1} A$ where $L$ is a lower triangular matrix and $A$ is any other matrix. A similar function called "cholesky_solve" computes $L^{-T} L^{-1} A$.

In [4]:
def ELBO_reparametrization(alpha, m, L, Phi, ts, num_samples=1):
    """
    Use reparametrization trick to compute ELBO.
    :param alpha: prior parameter.
    :param m: variational mean.
    :param L: cholesky factor of variational covariance.
    :param Phi: Phi matrix of xs. Each colomn represent a phi(x).
    :param ts: observations.
    """
    d = Phi.size(0)
    
    L = L.tril(0)
    
    eps = torch.randn(d, num_samples)
    ws = m[:, None] + (L @ eps)  # ws is now of size (d, num_samples)
    
    lnp = d * torch.log(alpha) /2 - alpha * (ws **2).sum(0) / 2
    
    L_inv_w_m = torch.triangular_solve(ws - m[:, None], L, upper=False)[0]
    lnq = logdet(L) / 2 + (L_inv_w_m ** 2).sum(0) / 2
    
    ys = torch.sigmoid(ws.t() @ Phi)
    logl = (ts * torch.log(ys) + (1 - ts) * torch.log(1-ys)).sum(1)
    
    return (lnp + lnq + logl).sum() / num_samples

Show the result on the simple dataset.

In [5]:
num_samples = 10
alpha = torch.tensor(1.)
dim = 5

m = torch.zeros(dim, requires_grad=True)
L = torch.eye(dim).clone().detach().requires_grad_(True)

optimizer = torch.optim.Adam([m, L], lr=0.05)
for i in range(100):
    optimizer.zero_grad()
    loss = -ELBO_reparametrization(alpha, m, L, Phi, ts, num_samples=num_samples)
    loss.backward()
    optimizer.step()
    # evaluate
    with torch.no_grad():
        logloss, mis = evaluate(m, L, Phi_test, ts_test)
    if i % 10 == 9:
        print('Iteration: {0:04d} Loss: {1: .6f} Test Performance: {2: .6f} {3: .1f}'.format(i, loss.item(), logloss.item(), mis))

Iteration: 0009 Loss:  129.097626 Test Performance:  0.390509  0.0
Iteration: 0019 Loss:  64.714828 Test Performance:  0.215732  0.0
Iteration: 0029 Loss:  56.918682 Test Performance:  0.158729  0.0
Iteration: 0039 Loss:  50.749977 Test Performance:  0.138594  0.0
Iteration: 0049 Loss:  47.075630 Test Performance:  0.126056  0.0
Iteration: 0059 Loss:  45.723824 Test Performance:  0.116105  0.0
Iteration: 0069 Loss:  45.048424 Test Performance:  0.108743  0.0
Iteration: 0079 Loss:  44.142651 Test Performance:  0.103115  0.0
Iteration: 0089 Loss:  43.614021 Test Performance:  0.099320  0.0
Iteration: 0099 Loss:  43.450909 Test Performance:  0.096284  0.0


torch.linalg.solve_triangular has its arguments reversed and does not return a copy of one of the inputs.
X = torch.triangular_solve(B, A).solution
should be replaced with
X = torch.linalg.solve_triangular(A, B). (Triggered internally at /Users/runner/work/_temp/anaconda/conda-bld/pytorch_1678402353079/work/aten/src/ATen/native/BatchLinearAlgebra.cpp:2197.)
  L_inv_w_m = torch.triangular_solve(ws - m[:, None], L, upper=False)[0]


## Black box variational inference

Black Box variational inference uses $f_{\lambda}(z)=\nabla_{\lambda} \log q(z|\lambda) (\log p(x|z) + \log p(z) - \log q(z))$ to compute the gradients of ELBO. An easier way to implement is use the autograd framework to compute the $\nabla_\lambda \log q(z|\lambda)$.

In [6]:
def BBVIAuto(m, L, alpha, Phi, ts, num_samples=5):
    L = L.tril(0)
    q_dist = torch.distributions.multivariate_normal.MultivariateNormal(loc=m, scale_tril=L)
    p_dist = torch.distributions.multivariate_normal.MultivariateNormal(loc=torch.zeros_like(m), \
                                                                        covariance_matrix=1/alpha * torch.eye(m.numel()))
    samples = q_dist.sample((num_samples,)) # of size (num_samples, d)
    
    lnp = p_dist.log_prob(samples)
    lnq = q_dist.log_prob(samples)
        
    # Calculate log likelihood. Use BCEWithLogitsLoss for numerical stability.
    loss_func = torch.nn.BCEWithLogitsLoss(reduction='none')
    logl = -loss_func(samples @ Phi, ts.repeat(num_samples, 1)).sum(1)
    
    # We do not want to take the gradients of log q(z) inside the parentheses, so we detach it.
    value = lnp + logl - lnq.detach().clone()
    return (value * lnq).sum() / num_samples

In [7]:
num_samples = 50
alpha = torch.tensor(1.)
dim = 5

m = torch.zeros(dim, requires_grad=True)
L = torch.eye(dim).clone().detach().requires_grad_(True)

optimizer = torch.optim.Adam([m, L], lr=0.01)
for i in range(500):
    optimizer.zero_grad()
    loss = -BBVIAuto(m, L, alpha, Phi, ts, num_samples)
    loss.backward()
    optimizer.step()
    # evaluate
    with torch.no_grad():
        logloss, mis = evaluate(m, L, Phi_test, ts_test)
    if i % 20 == 19:
        print('Iteration: {0:04d} Loss: {1: .6f} Test Performance: {2: .6f} {3: .1f}'.format(i, loss.item(), logloss.item(), mis))

Iteration: 0019 Loss: -1270.076172 Test Performance:  0.589947  0.0
Iteration: 0039 Loss: -947.644287 Test Performance:  0.515785  0.0
Iteration: 0059 Loss: -799.911499 Test Performance:  0.455023  0.0
Iteration: 0079 Loss: -609.904663 Test Performance:  0.395307  0.0
Iteration: 0099 Loss: -486.943512 Test Performance:  0.349302  0.0
Iteration: 0119 Loss: -448.278015 Test Performance:  0.312564  0.0
Iteration: 0139 Loss: -255.392273 Test Performance:  0.281208  0.0
Iteration: 0159 Loss: -232.433105 Test Performance:  0.258804  0.0
Iteration: 0179 Loss: -201.783279 Test Performance:  0.238110  0.0
Iteration: 0199 Loss: -175.608307 Test Performance:  0.224246  0.0
Iteration: 0219 Loss: -123.408623 Test Performance:  0.212608  0.0
Iteration: 0239 Loss: -151.047577 Test Performance:  0.200374  0.0
Iteration: 0259 Loss: -105.120094 Test Performance:  0.189645  0.0
Iteration: 0279 Loss: -83.250877 Test Performance:  0.178871  0.0
Iteration: 0299 Loss: -83.944405 Test Performance:  0.169706  

However, using the autograd framework in pytorch gives us no access to the gradients and thus it is hard to add the control variates. Instead, we can define our own loss function along with the backward method. The value returned by forward function is not important but we have to calculate the gradients inside the forward function and save them in "context". Then in backward function, we extract gradients from context and return them.

Notice that $f_{\lambda}(z)=\nabla_{\lambda} \log q(z|\lambda) (\log p(x|z) + \log p(z) - \log q(z))$ is the gradient and $h(z)=\nabla_\lambda \log q(z|\lambda)$ is the control variate. We have to manually calculate $\nabla_m \log q(z|m, L) = L^{-T} L^{-1} (z - m) $ and $\nabla_L \log q(z|m, L) = -L^{-T} + L^{-T} L^{-1} (z-m)(z-m)^T L^{-T}$.

In [8]:
class BBVILoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, m, L, alpha, Phi, ts, num_samples=5):
        L = L.tril(0)
        q_dist = torch.distributions.multivariate_normal.MultivariateNormal(loc=m, scale_tril=L)
        p_dist = torch.distributions.multivariate_normal.MultivariateNormal(loc=torch.zeros_like(m), \
                                                                            covariance_matrix=1/alpha * torch.eye(m.numel()))
        samples = q_dist.sample((num_samples,)) # of size (num_samples, d)
        
        lnp = p_dist.log_prob(samples)
        lnq = q_dist.log_prob(samples)
        
        # Calculate log likelihood. Use BCEWithLogitsLoss for numerical stability.
        loss_func = torch.nn.BCEWithLogitsLoss(reduction='none')
        logl = -loss_func(samples @ Phi, ts.repeat(num_samples, 1)).sum(1)
        value = lnp + logl - lnq
        
        dlnq_dm = torch.zeros(num_samples, d)
        dlnq_dL = torch.zeros(num_samples, d, d)
        L_inv_T = (torch.triangular_solve(torch.eye(L.size(0)), L)[0]).t()
        for i in range(num_samples):
            temp = (samples[i] - m).unsqueeze(-1)
            dlnq_dm[i] = (torch.cholesky_solve(temp, L)).squeeze()
            dlnq_dL[i] = -L_inv_T + torch.cholesky_solve(temp @ temp.t() @ L_inv_T, L)
        
        # calculate 'a' for m
        fm = value[:, None] * dlnq_dm
        hm = dlnq_dm
        fm_mean = fm.sum(0) / num_samples
        hm_mean = hm.sum(0) / num_samples
        cov_fm_hm = ((fm-fm_mean)*hm).sum(0) / (num_samples - 1)
        var_hm = (hm**2).sum(0) / (num_samples - 1)
        am = cov_fm_hm / var_hm
        
        # calculate 'a' for L
        fL = value[:, None, None] * dlnq_dL
        hL = dlnq_dL
        fL_mean = fL.sum(0) / num_samples
        hL_mean = hL.sum(0) / num_samples
        cov_fL_hL = ((fL-fL_mean)*(hL-hL_mean)).sum(0) / (num_samples - 1)
        var_hL = ((hL-hL_mean)**2).sum(0) / (num_samples - 1)
        aL = cov_fL_hL / var_hL
        
        delbo_dm = (value[:, None] * dlnq_dm - am[None, :] * dlnq_dm).sum(0).div(num_samples)
        delbo_dL = (value[:, None, None]* dlnq_dL - aL[None, :, :] * dlnq_dL).sum(0).div(num_samples)
        
        ctx.save_for_backward(delbo_dm, delbo_dL)
        return (lnp + logl - lnq).sum().div(num_samples)
    
    @staticmethod
    def backward(ctx, grad_output):
        grad_m = grad_L = None
        delbo_dm, delbo_dL = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_m = grad_output * delbo_dm
        if ctx.needs_input_grad[1]:
            grad_L = grad_output * delbo_dL
        return grad_m, grad_L, None, None, None, None

In [9]:
num_samples = 50
alpha = torch.tensor(1.)
dim = 5

m = torch.zeros(dim, requires_grad=True)
L = torch.eye(dim).clone().detach().requires_grad_(True)

optimizer = torch.optim.Adam([m, L], lr=0.01)
for i in range(500):
    optimizer.zero_grad()
    loss = -BBVILoss.apply(m, L, alpha, Phi, ts, num_samples)
    loss.backward()
    optimizer.step()
    # evaluate
    with torch.no_grad():
        logloss, mis = evaluate(m, L, Phi_test, ts_test)
    if i % 20 == 19:
        print('Iteration: {0:04d} Loss: {1: .6f} Test Performance: {2: .6f} {3: .1f}'.format(i, loss.item(), logloss.item(), mis))

Iteration: 0019 Loss:  170.104980 Test Performance:  0.584444  0.0
Iteration: 0039 Loss:  134.841248 Test Performance:  0.476614  0.0
Iteration: 0059 Loss:  104.949593 Test Performance:  0.390326  0.0
Iteration: 0079 Loss:  90.097000 Test Performance:  0.322188  0.0
Iteration: 0099 Loss:  75.086990 Test Performance:  0.269097  0.0
Iteration: 0119 Loss:  67.526443 Test Performance:  0.236580  0.0
Iteration: 0139 Loss:  63.855225 Test Performance:  0.216595  0.0
Iteration: 0159 Loss:  60.096058 Test Performance:  0.200095  0.0
Iteration: 0179 Loss:  57.848900 Test Performance:  0.186758  0.0
Iteration: 0199 Loss:  55.048630 Test Performance:  0.173956  0.0
Iteration: 0219 Loss:  53.447769 Test Performance:  0.165810  0.0
Iteration: 0239 Loss:  53.225605 Test Performance:  0.159034  0.0
Iteration: 0259 Loss:  51.165985 Test Performance:  0.152136  0.0
Iteration: 0279 Loss:  50.513100 Test Performance:  0.146309  0.0
Iteration: 0299 Loss:  50.353901 Test Performance:  0.140852  0.0
Iterati