# Introduction

The goal of this notebook is to examine RLCT (real log canonical threshold) estimation with a two-dimensional parameter $w$ on three toy synthetic datasets. We compare the performance of implicit variational inference against explicit variational inference.

First, we set up the parameters that feed into our main function.

In [6]:
from __future__ import print_function
from main import *
from torch.distributions.uniform import Uniform
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal

class Args:
    """Default hyper-parameters for the RLCT experiments in this notebook.

    Every setting is a class attribute, so each ``Args()`` instance starts from
    these values. Later cells attach dataset-specific fields (``dataset``,
    ``network``, ``w_0``, ...) to the instance, and ``main`` mutates it further
    (``VItype``, ``epsilon_dim``, ``epsilon_mc``).
    """

    # synthetic data and parameter dimension
    syntheticsamplesize = 500
    batchsize = 100
    w_dim = 2
    dpower = None

    # prior over w
    prior = 'gaussian'

    # training schedule
    epochs = 200
    pretrainDepochs = 100
    trainDepochs = 50

    # network sizes (presumably D = discriminator, G = implicit-VI generator)
    n_hidden_D, num_hidden_layers_D = 128, 1
    n_hidden_G, num_hidden_layers_G = 128, 1

    # learning rates
    lr = 1e-2
    lr_primal = lr_dual = 1e-3

    elasticnet_alpha = 0.5

    # inverse-temperature grid (presumably consumed by set_betas)
    beta_auto_liberal = beta_auto_conservative = beta_auto_oracle = False
    betasbegin, betasend = 0.1, 1.9
    betalogscale = True
    numbetas = 3

    R = 200

    # runtime / output
    cuda = False
    log_interval = 100
    posterior_viz = True


args = Args()


The `main` function draws a dataset and then, for each inverse temperature, trains both explicit and implicit variational inference. Results are shown as posterior plots and as a least-squares fit of the RLCT.

In [7]:
def main(args, loader_kwargs=None, num_viz_samples=100):
    """Run explicit and implicit VI over a grid of inverse temperatures, then fit the RLCT.

    For each beta index in ``args.betas`` (populated by ``set_betas``), the function:
      * trains an explicit variational posterior, records the mean of the array
        returned by ``approxinf_nll_explicit`` and plots samples from it;
      * trains an implicit (generator-based) variational posterior, records the
        mean of the array returned by ``approxinf_nll_implicit`` and plots
        ``num_viz_samples`` generator outputs.
    Finally ``lsfit_lambda`` is called once for each VI type on the recorded values.

    Args:
        args: configuration object (e.g. an ``Args`` instance). It is mutated in
            place: ``set_betas`` is expected to set ``args.betas``, and this
            function sets ``args.VItype``, ``args.epsilon_dim`` and ``args.epsilon_mc``.
        loader_kwargs: keyword arguments forwarded to ``get_dataset_by_id``.
            Defaults to the notebook-level ``kwargs``, so ``main(args)`` behaves
            as before.
        num_viz_samples: number of noise vectors pushed through the implicit
            generator for the posterior plot (previously hard-coded to 100).
    """
    if loader_kwargs is None:
        # backward-compatible fallback to the global defined in the dataset cell
        loader_kwargs = kwargs

    # draw new training-testing split
    train_loader, valid_loader, test_loader = get_dataset_by_id(args, loader_kwargs)

    # get a grid of inverse temperatures [beta_1/log n, \ldots, beta_k/log n]
    set_betas(args)

    mc = 1
    saveimgpath = None
    nll_explicit = []
    nll_implicit = []

    for beta_index in range(args.betas.shape[0]):

        # explicit VI: set VItype *before* training so the training / nll helpers
        # never see the 'implicit' value left over from the previous iteration
        args.VItype = 'explicit'
        var_model = train_explicitVI(train_loader, valid_loader, args, mc, beta_index, True, saveimgpath)
        # record E nL_n(w)
        nll_explicit.append(approxinf_nll_explicit(train_loader, var_model, args).mean())

        # visualize EVI
        sampled_weights = sample_EVI(var_model, args)
        posterior_viz(train_loader, sampled_weights, args, beta_index, saveimgpath)

        # implicit VI
        args.epsilon_dim = args.w_dim
        args.epsilon_mc = args.batchsize
        args.VItype = 'implicit'
        G = train_implicitVI(train_loader, valid_loader, args, mc, beta_index, saveimgpath)
        nll_implicit.append(approxinf_nll_implicit(train_loader, G, args).mean())

        # visualize IVI: push Gaussian noise through the trained generator
        with torch.no_grad():
            eps = torch.randn(num_viz_samples, args.epsilon_dim)
            sampled_weights = G(eps)
            posterior_viz(train_loader, sampled_weights, args, beta_index, saveimgpath)

    # should observe a straight line below
    lsfit_lambda(np.asarray(nll_explicit, dtype=float), args, saveimgpath)
    lsfit_lambda(np.asarray(nll_implicit, dtype=float), args, saveimgpath)


## Logistic regression


The target is generated as $y \sim \mathrm{Bernoulli}(p)$, where $p = 1/\left(1 + e^{-(w^T x + b)}\right)$ and $x \sim \mathcal{N}(0, I)$.
We set the true parameters to $w = [0.5, 1]^T$ and $b = 0$.

In [8]:
# Configure the synthetic logistic-regression problem consumed by main().
args.dataset = 'logistic_synthetic'
args.network = 'logistic'
args.bias = False                 # presumably: no intercept in the model; consistent with the true b = 0 below
args.input_dim = args.w_dim
args.output_dim = 1

# DataLoader options passed to get_dataset_by_id (main() reads this global by default)
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

# True parameters of the data-generating model (w = [0.5, 1]^T, b = 0)
args.w_0 = torch.tensor([[0.5], [1.0]])
args.b = torch.tensor([0.0])

# Draw an illustrative sample from the model for the plot below. A dedicated,
# seeded generator makes the figure reproducible without reseeding the global
# RNG that main() uses for training.
viz_generator = torch.Generator().manual_seed(0)
X = torch.randn(2 * args.syntheticsamplesize, args.input_dim, generator=viz_generator)
affine = torch.mm(X, args.w_0) + args.b
m = torch.distributions.bernoulli.Bernoulli(torch.sigmoid(affine))
y = torch.bernoulli(m.probs, generator=viz_generator)

assert X.shape == (2 * args.syntheticsamplesize, args.input_dim)
assert y.shape == affine.shape == (2 * args.syntheticsamplesize, 1)

Let's first visualize the data.

In [9]:
# Plot the linear predictor w^T x + b against the true probability sigmoid(w^T x + b)
# and against the sampled Bernoulli labels.
affine_np = affine.squeeze(dim=1).detach().numpy()
fig, ax = plt.subplots()
ax.plot(affine_np, y.detach().numpy(), '.g', label='sampled y ~ Bernoulli(p)')
ax.plot(affine_np, torch.sigmoid(affine).detach().numpy(), '.r', label='p = sigmoid(w^T x + b)')
ax.set(title='synthetic logistic regression data: w^T x + b versus probabilities and Bernoulli(p)',
       xlabel='w^T x + b', ylabel='p / y')
ax.legend()
plt.show()