## Regular ACE
This is essentially a copy of AF's implementation in `pyro`. Because the current `pyro` release does not include some of AF's code from SGBOED, several of his custom functions are copied here.

The goal is to compare the lower and upper bounds from SGBOED on a simple linear regression task.

In [2]:
import time
import numbers
import math

import torch
from torch import nn
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
import pyro.optim as optim
from pyro import poutine
from pyro.contrib.util import rmv, lexpand

def is_bad(a):
    """
    Return a truthy value when ``a`` contains a NaN or an infinite entry.

    ``a`` may be a plain number or a tensor; the check is delegated to
    ``torch_isnan`` and ``torch_isinf``.
    """
    nan_found = torch_isnan(a)
    if nan_found:
        return nan_found
    # Only look for infinities when no NaN was found.
    return torch_isinf(a)
    

def torch_isnan(x):
    """
    Report whether ``x`` holds a NaN.

    Plain Python numbers are handled directly; anything else is treated as a
    tensor and the result is true if at least one element is NaN.
    """
    if not isinstance(x, numbers.Number):
        return torch.isnan(x).any()
    # NaN is the only value that compares unequal to itself.
    return x != x


def torch_isinf(x):
    """
    Check whether ``x`` contains a positive or negative infinity.

    Works for plain Python numbers as well as tensors. For a tensor the
    result is a 0-dim boolean tensor that is true if any element is
    ``+inf`` or ``-inf``; for a number the result is a Python bool.
    """
    if isinstance(x, numbers.Number):
        return x == float('inf') or x == -float('inf')
    # torch.isinf flags both signs in a single pass.
    return torch.isinf(x).any()


def _safe_mean_terms(terms):
    mask = torch.isnan(terms) | (terms == float('-inf')) | (terms == float('inf'))
    if terms.dtype is torch.float32:
        nonnan = (~mask).sum(0).float()
    elif terms.dtype is torch.float64:
        nonnan = (~mask).sum(0).double()
    terms[mask] = 0.
    loss = terms.sum(0) / nonnan
    agg_loss = loss.sum()
    return agg_loss, loss


def make_regression_model(w_loc, w_scale, sigma_scale, xi_init, observation_label="y"):
    """
    Build a Pyro linear-regression model whose design is a learnable parameter.

    The returned model ignores the values of its ``design_prototype``
    argument and uses only its shape: the actual design is the Pyro
    parameter ``"xi"`` (initialised from ``xi_init``), normalised by its L1
    norm along the last dimension and expanded to the prototype's shape.

    :param w_loc: location of the Laplace prior on ``w``.
    :param w_scale: scale of the Laplace prior on ``w``.
    :param sigma_scale: rate passed to the Exponential prior on ``sigma``.
    :param xi_init: initial value for the ``"xi"`` parameter.
    :param observation_label: name of the observation sample site.
    :returns: a callable ``regression_model(design_prototype)`` that returns
        the sampled observation.
    """
    def regression_model(design_prototype):
        xi = pyro.param("xi", xi_init)
        l1_norm = xi.norm(dim=-1, p=1, keepdim=True)
        design = (xi / l1_norm).expand(design_prototype.shape)
        if is_bad(design):
            raise ArithmeticError("bad design, contains nan or inf")

        with pyro.plate_stack("plate_stack", design.shape[:-2]):
            # Independent Laplace prior over the components of `w`.
            w_prior = dist.Laplace(w_loc, w_scale).to_event(1)
            w = pyro.sample("w", w_prior)
            # Scalar noise level, offset slightly so it is never exactly zero.
            noise_sd = 1e-6 + pyro.sample("sigma", dist.Exponential(sigma_scale)).unsqueeze(-1)
            likelihood = dist.Normal(rmv(design, w), noise_sd).to_event(1)
            return pyro.sample(observation_label, likelihood)

    return regression_model


class TensorLinear(nn.Module):
    """
    Linear layer with an independent weight matrix per leading batch index.

    Constructed as ``TensorLinear(*batch_dims, in_features, out_features)``.
    The weight has shape ``batch_dims + (out_features, in_features)`` and the
    optional bias has shape ``batch_dims + (out_features,)``.

    :param shape: batch dimensions followed by ``in_features, out_features``.
    :param bias: when false, no bias parameter is created or added.
    """
    __constants__ = ['bias']

    def __init__(self, *shape, bias=True):
        super().__init__()
        self.in_features = shape[-2]
        self.out_features = shape[-1]
        self.batch_dims = shape[:-2]
        # torch.empty makes the "uninitialised storage" intent explicit;
        # reset_parameters() below fills in the actual values.
        self.weight = nn.Parameter(torch.empty(*self.batch_dims, self.out_features, self.in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(*self.batch_dims, self.out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise weight and bias with the same scheme ``nn.Linear`` uses."""
        # NOTE(review): the fan-in helper is written for 2-D / conv layouts;
        # with leading batch dims the fan it derives may not equal
        # `in_features` -- confirm this initialisation scale is intended.
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """
        Apply the batched affine map to the last dimension of ``input``.

        :param input: tensor whose last dimension is ``in_features`` and whose
            preceding dimensions broadcast against ``batch_dims``.
        :returns: tensor whose last dimension is ``out_features``.
        """
        # Batched matrix-vector product over the rightmost dimensions
        # (the same operation as pyro.contrib.util.rmv).
        out = torch.matmul(self.weight, input.unsqueeze(-1)).squeeze(-1)
        if self.bias is not None:
            out = out + self.bias
        return out


class PosteriorGuide(nn.Module):
    """
    Amortised guide over ``"w"`` and ``"sigma"`` conditioned on an observation.

    A three-layer ``TensorLinear`` network maps the observation to the mean of
    a multivariate normal over ``w``, the concentration and rate of a Gamma
    over ``sigma``, and a positive multiplier applied to a learned
    lower-Cholesky factor shared across observations.

    :param y_dim: size of the observation's last dimension.
    :param w_dim: size of ``w``.
    :param batching: tuple of leading batch dimensions for the layers.
    """

    def __init__(self, y_dim, w_dim, batching):
        super().__init__()
        hidden_width = 64
        self.linear1 = TensorLinear(*batching, y_dim, hidden_width)
        self.linear2 = TensorLinear(*batching, hidden_width, hidden_width)
        # w_dim outputs for the mean plus three scalar heads.
        self.output_layer = TensorLinear(*batching, hidden_width, w_dim + 3)
        self.covariance_shape = batching + (w_dim, w_dim)
        self.softplus = nn.Softplus()
        self.relu = nn.ReLU()

    def forward(self, y_dict, design_prototype, observation_labels, target_labels):
        """
        Sample ``"sigma"`` and ``"w"`` from the amortised posterior.

        The observation is read from the hard-coded key ``"y"`` of ``y_dict``;
        ``observation_labels`` and ``target_labels`` are accepted but unused.
        Only the shape of ``design_prototype`` is used, to size the plates.
        """
        centred_y = y_dict["y"] - .5
        hidden = self.relu(self.linear1(centred_y))
        hidden = self.relu(self.linear2(hidden))
        head = self.output_layer(hidden)

        # Split the network output: everything but the last three entries is
        # the mean; the last three are mapped to positive scalars.
        posterior_mean = head[..., :-3]
        raw_concentration = head[..., -3]
        raw_rate = head[..., -2]
        raw_multiplier = head[..., -1]
        gamma_concentration = 1e-6 + self.softplus(raw_concentration)
        gamma_rate = 1. + self.softplus(raw_rate)
        scale_tril_multiplier = 1e-6 + self.softplus(raw_multiplier)

        pyro.module("posterior_guide", self)

        w_dim = posterior_mean.shape[-1]
        identity = torch.eye(w_dim, device=posterior_mean.device)
        base_scale_tril = pyro.param(
            "posterior_scale_tril",
            identity.expand(self.covariance_shape),
            constraint=constraints.lower_cholesky
        )
        # Rescale the shared Cholesky factor per observation.
        posterior_scale_tril = base_scale_tril * scale_tril_multiplier[..., None, None]

        plate_shape = design_prototype.shape[:-2]
        with pyro.plate_stack("guide_plate_stack", plate_shape):
            sigma_posterior = dist.Gamma(gamma_concentration, gamma_rate)
            pyro.sample("sigma", sigma_posterior)
            w_posterior = dist.MultivariateNormal(posterior_mean, scale_tril=posterior_scale_tril)
            pyro.sample("w", w_posterior)


def _ace_eig_loss(model, guide, M, observation_labels, target_labels):
    """
    Build the ACE (adaptive contrastive estimation) loss function.

    :param model: Pyro model taking a design tensor.
    :param guide: callable ``guide(y_dict, design, observation_labels,
        target_labels)`` that samples the target sites.
    :param M: number of contrastive samples drawn from the guide per
        observation.
    :param observation_labels: observation site name, or a list of names.
    :param target_labels: target site name, or a list of names.
    :returns: ``loss_fn(design, num_particles, evaluation=False, **kwargs)``
        returning ``(agg_loss, loss)`` as produced by ``_safe_mean_terms``.
        Extra keyword arguments are accepted and ignored.
    """
    # A bare string would otherwise be iterated character by character, which
    # only happens to work for one-character site names.
    if isinstance(observation_labels, str):
        observation_labels = [observation_labels]
    if isinstance(target_labels, str):
        target_labels = [target_labels]

    def loss_fn(design, num_particles, evaluation=False, **kwargs):
        N = num_particles
        expanded_design = lexpand(design, N)

        # Sample from p(y, theta | d)
        trace = poutine.trace(model).get_trace(expanded_design)
        y_dict_exp = {l: lexpand(trace.nodes[l]["value"], M) for l in observation_labels}
        y_dict = {l: trace.nodes[l]["value"] for l in observation_labels}
        theta_dict = {l: trace.nodes[l]["value"] for l in target_labels}

        # Joint log-density of the original (theta, y) draw under the model ...
        trace.compute_log_prob()
        marginal_terms_cross = sum(trace.nodes[l]["log_prob"] for l in target_labels)
        marginal_terms_cross += sum(trace.nodes[l]["log_prob"] for l in observation_labels)

        # ... minus the guide's log-density of that same theta given y.
        reguide_trace = poutine.trace(
            pyro.condition(guide, data=theta_dict)).get_trace(
            y_dict, expanded_design, observation_labels, target_labels)
        # Here's a spot where you could update each model's parameters based on log_prob
        reguide_trace.compute_log_prob()
        marginal_terms_cross -= sum(reguide_trace.nodes[l]["log_prob"] for l in target_labels)

        # Sample M times from q(theta | y, d) for each y
        reexpanded_design = lexpand(expanded_design, M)
        guide_trace = poutine.trace(guide).get_trace(
            y_dict, reexpanded_design, observation_labels, target_labels
        )
        theta_y_dict = {l: guide_trace.nodes[l]["value"] for l in target_labels}
        theta_y_dict.update(y_dict_exp)
        guide_trace.compute_log_prob()

        # Re-run that through the model to compute the joint
        model_trace = poutine.trace(
            pyro.condition(model, data=theta_y_dict)).get_trace(reexpanded_design)
        model_trace.compute_log_prob()

        # Log importance weights p(theta, y | d) / q(theta | y, d) of the
        # M contrastive samples.
        marginal_terms_proposal = -sum(guide_trace.nodes[l]["log_prob"] for l in target_labels)
        marginal_terms_proposal += sum(model_trace.nodes[l]["log_prob"] for l in target_labels)
        marginal_terms_proposal += sum(model_trace.nodes[l]["log_prob"] for l in observation_labels)

        # Stack the original sample's weight on top of the M contrastive ones
        # and average all M + 1 in log space.
        marginal_terms = torch.cat([lexpand(marginal_terms_cross, 1), marginal_terms_proposal])
        terms = -marginal_terms.logsumexp(0) + math.log(M + 1)

        # At eval time, add p(y | theta, d) terms
        if evaluation:
            terms += sum(trace.nodes[l]["log_prob"] for l in observation_labels)
        return _safe_mean_terms(terms)

    return loss_fn


def _vnmc_eig_loss(model, guide, observation_labels, target_labels):
    """
    Build the VNMC loss function.

    To evaluate directly, Pyro's ``vnmc_eig`` can be used with
    ``num_steps=0``.

    :param model: Pyro model taking a design tensor.
    :param guide: callable ``guide(y_dict, design, observation_labels,
        target_labels)`` that samples the target sites.
    :param observation_labels: observation site name, or a list of names.
    :param target_labels: target site name, or a list of names.
    :returns: ``loss_fn(design, num_particles, evaluation=False, **kwargs)``
        where ``num_particles`` is an ``(N, M)`` pair of outer and inner
        sample counts; returns ``(agg_loss, loss)`` as produced by
        ``_safe_mean_terms``. Extra keyword arguments are accepted and ignored.
    """
    # A bare string would otherwise be iterated character by character, which
    # only happens to work for one-character site names.
    if isinstance(observation_labels, str):
        observation_labels = [observation_labels]
    if isinstance(target_labels, str):
        target_labels = [target_labels]

    def loss_fn(design, num_particles, evaluation=False, **kwargs):
        N, M = num_particles
        expanded_design = lexpand(design, N)

        # Sample from p(y, theta | d)
        trace = poutine.trace(model).get_trace(expanded_design)
        y_dict_unexp = {l: trace.nodes[l]["value"] for l in observation_labels}
        y_dict = {l: lexpand(trace.nodes[l]["value"], M) for l in observation_labels}

        # Sample M times from q(theta | y, d) for each y
        reexpanded_design = lexpand(expanded_design, M)
        guide_trace = poutine.trace(guide).get_trace(
            y_dict_unexp, reexpanded_design, observation_labels, target_labels)
        theta_y_dict = {l: guide_trace.nodes[l]["value"] for l in target_labels}
        theta_y_dict.update(y_dict)
        guide_trace.compute_log_prob()

        # Re-run that through the model to compute the joint
        modelp = pyro.condition(model, data=theta_y_dict)
        model_trace = poutine.trace(modelp).get_trace(reexpanded_design)
        model_trace.compute_log_prob()

        # Log importance weights p(theta, y | d) / q(theta | y, d), averaged
        # over the M inner samples in log space.
        terms = -sum(guide_trace.nodes[l]["log_prob"] for l in target_labels)
        terms += sum(model_trace.nodes[l]["log_prob"] for l in target_labels)
        terms += sum(model_trace.nodes[l]["log_prob"] for l in observation_labels)
        terms = -terms.logsumexp(0) + math.log(M)

        # At eval time, add p(y | theta, d) terms
        if evaluation:
            trace.compute_log_prob()
            terms += sum(trace.nodes[l]["log_prob"] for l in observation_labels)

        return _safe_mean_terms(terms)

    return loss_fn


def neg_loss(loss):
    """
    Wrap ``loss`` so that every value it returns is negated.

    :param loss: callable returning an iterable of values (for example an
        ``(agg_loss, loss)`` pair).
    :returns: a callable with the same arguments that returns a tuple of the
        negated values, so the result can be unpacked, indexed and reused.
    """
    def new_loss(*args, **kwargs):
        # A tuple rather than a one-shot generator: same shape as the
        # wrapped function's result.
        return tuple(-a for a in loss(*args, **kwargs))
    return new_loss


def opt_eig_loss_w_history(design, loss_fn, num_samples, num_steps, optim, time_budget):
    """
    Run gradient steps on ``loss_fn`` and record the per-step loss and timing.

    :param design: tensor handed to ``loss_fn`` as its first argument.
    :param loss_fn: callable ``loss_fn(design, num_samples, evaluation=True,
        control_variate=...)`` returning ``(agg_loss, loss)``.
    :param num_samples: forwarded to ``loss_fn`` unchanged.
    :param num_steps: maximum number of optimisation steps.
    :param optim: optimiser object that is called with the set of parameters
        and then has ``.step()`` invoked once per iteration.
    :param time_budget: wall-clock limit in seconds; falsy disables the limit.
    :returns: ``(xi_history, est_loss_history, wall_times)``. ``xi_history``
        is a stack holding a single snapshot of the ``"xi"`` parameter taken
        after the loop; the other two have one entry per completed step.
    :raises ArithmeticError: if the aggregate loss becomes NaN.
    """
    params = None
    est_loss_history = []
    xi_history = []
    baseline = 0.
    t = time.time()
    wall_times = []
    for step in range(num_steps):
        if params is not None:
            pyro.infer.util.zero_grads(params)
        # Record which parameters the loss touches so they can be optimised.
        with poutine.trace(param_only=True) as param_capture:
            agg_loss, loss = loss_fn(design, num_samples, evaluation=True, control_variate=baseline)
        baseline = -loss.detach()
        params = {site["value"].unconstrained()
                  for site in param_capture.trace.nodes.values()}
        if torch.isnan(agg_loss):
            raise ArithmeticError("Encountered NaN loss in opt_eig_loss_w_history")
        agg_loss.backward(retain_graph=True)
        est_loss_history.append(loss.detach())
        wall_times.append(time.time() - t)
        optim(params)
        optim.step()
        print(pyro.param("xi")[0, 0, ...])
        print(step)
        print('eig', baseline.squeeze())
        if time_budget and time.time() - t > time_budget:
            break

    # NOTE(review): only the final design is recorded, so `xi_history` has a
    # single entry despite its name -- confirm whether per-step snapshots
    # were intended.
    xi_history.append(pyro.param('xi').detach().clone())

    # torch.stack rejects an empty list, which happens when no step ran.
    est_loss_history = torch.stack(est_loss_history) if est_loss_history else torch.empty(0)
    xi_history = torch.stack(xi_history)
    wall_times = torch.tensor(wall_times)

    return xi_history, est_loss_history, wall_times


# --------------------
# Start Pyro code implementation
# Experiment configuration.
num_steps = 1000
num_samples = 10
time_budget = 1000
seed = 420
num_parallel = 1
start_lr = 0.001
end_lr = 0.001
device = 'cpu'
n = 1
p = 1
scale = 1.

pyro.clear_param_store()
# A negative seed requests a randomly drawn one, which is kept for the results.
if seed >= 0:
    pyro.set_rng_seed(seed)
else:
    seed = int(torch.rand(tuple()) * 2 ** 30)
    pyro.set_rng_seed(seed)

xi_init = torch.randn((num_parallel, n, p), device=device)
# Change the prior distribution here
# prior params
w_prior_loc = torch.zeros(p, device=device)
w_prior_scale = scale * torch.ones(p, device=device)
sigma_prior_scale = scale * torch.tensor(1., device=device)

model_learn_xi = make_regression_model(
    w_prior_loc, w_prior_scale, sigma_prior_scale, xi_init)

contrastive_samples = num_samples

# Fix correct loss
targets = ["w", "sigma"]
observations = ["y"]

guide = PosteriorGuide(n, p, (num_parallel,)).to(device)
eig_loss = _ace_eig_loss(model_learn_xi, guide, contrastive_samples, observations, targets)
# The optimiser minimises, so train on the negated EIG estimate.
loss = neg_loss(eig_loss)

# Train guide
print("Training")
# Per-step decay factor that takes the learning rate from start_lr to end_lr
# over num_steps steps.
gamma = (end_lr / start_lr) ** (1 / num_steps)
scheduler = pyro.optim.ExponentialLR({'optimizer': torch.optim.Adam, 'optim_args': {'lr': start_lr},
                                      'gamma': gamma})

design_prototype = torch.zeros(num_parallel, n, p, device=device)  # this is annoying, code needs refactor

xi_history, est_loss_history, wall_times = opt_eig_loss_w_history(
    design_prototype, loss, num_samples=num_samples, num_steps=num_steps, optim=scheduler,
    time_budget=time_budget)

est_eig_history = -est_loss_history

# Evaluate
print("Evaluation")
num_inner_samples = 2500
num_outer_samples = 100000
# Labels are passed as lists so that site names are never iterated
# character by character.
lower_loss = _ace_eig_loss(model_learn_xi, guide, num_inner_samples, observations, targets)
upper_loss = _vnmc_eig_loss(model_learn_xi, guide, observations, targets)
lower, upper = 0., 0.
# Cap the number of samples held in memory at once and accumulate in batches.
max_samples = 10000
n_per_batch = max_samples // num_inner_samples
n_batches = num_inner_samples * num_outer_samples // max_samples
# Evaluation only needs values, so skip building autograd graphs.
with torch.no_grad():
    for i in range(n_batches):
        print(i)
        lower += lower_loss(xi_history[-1], n_per_batch, evaluation=True)[1].detach().cpu()
        upper += upper_loss(xi_history[-1], (n_per_batch, num_inner_samples), evaluation=True)[1].detach().cpu()

# The accumulators already live on the CPU.
final_upper_bound = upper / n_batches
final_lower_bound = lower / n_batches

print(final_lower_bound, final_upper_bound)

# Maybe add final/lower bounds?
results = {'seed': seed, 'xi_history': xi_history.cpu(), 'est_eig_history': est_eig_history.cpu(),
           'wall_times': wall_times.cpu()}

tensor([-1.6977], grad_fn=<SelectBackward0>)
0
eig tensor(-0.2472)
tensor([-1.6979], grad_fn=<SelectBackward0>)
1
eig tensor(0.5802)
tensor([-1.6981], grad_fn=<SelectBackward0>)
2
eig tensor(0.1910)
tensor([-1.6984], grad_fn=<SelectBackward0>)
3
eig tensor(0.5558)
tensor([-1.6987], grad_fn=<SelectBackward0>)
4
eig tensor(-3.8025)
tensor([-1.6989], grad_fn=<SelectBackward0>)
5
eig tensor(0.4813)
tensor([-1.6992], grad_fn=<SelectBackward0>)
6
eig tensor(0.6277)
tensor([-1.6994], grad_fn=<SelectBackward0>)
7
eig tensor(-0.5476)
tensor([-1.6996], grad_fn=<SelectBackward0>)
8
eig tensor(-0.4766)
tensor([-1.6997], grad_fn=<SelectBackward0>)
9
eig tensor(-1.8593)
tensor([-1.6999], grad_fn=<SelectBackward0>)
10
eig tensor(-0.7433)
tensor([-1.7000], grad_fn=<SelectBackward0>)
11
eig tensor(0.6771)
tensor([-1.7001], grad_fn=<SelectBackward0>)
12
eig tensor(0.1910)
tensor([-1.6998], grad_fn=<SelectBackward0>)
13
eig tensor(-0.1653)
tensor([-1.6995], grad_fn=<SelectBackward0>)
14
eig tensor(1.2694