# Upper and lower bound estimators for NInGa

This notebook shows how to compute the upper and lower bound estimators for the performance measure `NInGa`. It uses the `Zero Inflated Gamma` and the `Zero Inflated LogNormal` distributions; other distributions can be implemented similarly. For derivations, see the paper:

[Bayesian Oracle for bounding information gain in neural encoding models](https://openreview.net/forum?id=iYC5hOMqUg)

In [None]:
import torch
import numpy as np

from neuralmetrics.datasets import simulate_neuron_data, simulate_neuron_data_advanced
from neuralmetrics.models.utils import get_zig_params_from_moments, get_zil_params_from_moments
from neuralmetrics.utils import bits_per_image
from neuralmetrics.models.gs_models import Gaussian_GS, Gamma_GS
from neuralmetrics.models.gs_zero_inflation import Zero_Inflation_Base
from neuralmetrics.models.priors import get_prior_for_gaussian, get_prior_for_q, get_prior_for_gamma, train_prior_for_gaussian
from neuralmetrics.models.flows.transforms import Log, Identity
from neuralmetrics.models.score_functions import compute_gs_loss_over_target_repeats, compute_null_loss
from neuralpredictors.measures.zero_inflated_losses import ZIGLoss, ZILLoss

from scipy.stats import beta as beta_distribution

from neuralpredictors.measures import corr


random_seed = 27121992
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU if no GPU is available

## Simulated Data

Here we simulate neural data to demonstrate the code. Replace this step with your own data loading code. In the end, the data must be a numpy array of shape `(n_repeats, n_images, n_neurons)`. If trials are missing in your data, replace them with `np.nan`. Note that for a missing trial, all entries along the neuron dimension must be missing (see the end of the following cell for an example).
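The layout requirement above can be checked before running the estimators. The following is a minimal sketch on a toy array; the helper `find_missing_trials` and all toy values are hypothetical and not part of `neuralmetrics`.

```python
import numpy as np

# Toy responses with shape (n_repeats, n_images, n_neurons); values are arbitrary.
rng = np.random.default_rng(0)
toy_resps = rng.gamma(shape=2.0, scale=1.0, size=(3, 4, 5))
toy_resps[1, 2, :] = np.nan  # mark repeat 1 of image 2 as a missing trial


def find_missing_trials(resps):
    """Return a (n_repeats, n_images) boolean mask of missing trials.

    Hypothetical helper: raises if a trial is NaN for only some neurons,
    since a missing trial must be missing along the whole neuron dimension.
    """
    nan_mask = np.isnan(resps)
    fully_missing = nan_mask.all(axis=-1)
    partially_missing = nan_mask.any(axis=-1) & ~fully_missing
    if partially_missing.any():
        raise ValueError("Some trials are NaN for only a subset of neurons.")
    return fully_missing


missing = find_missing_trials(toy_resps)
# Number of observed (non-NaN) responses, counted as in the next cell.
n_observed = int((~np.isnan(toy_resps)).sum())
```

Running the same check on your own array before the optimization step catches partially missing trials early.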

In [None]:
np.random.seed(random_seed)

exp_data = True
n_images = 360
n_repeats = 10
n_neurons = 100

mean = .5
variance = .01
A = (mean * (1 - mean) / variance - 1)
alpha = A * mean
beta = A * (1 - mean)
zero_inflation_level = beta_distribution(21, 117).rvs(n_neurons)
loc = np.exp(-10)

resps, gt_means, gt_variances, zil_params = simulate_neuron_data_advanced(n_images=n_images,
                                                      n_repeats=n_repeats,
                                                      n_neurons=n_neurons,
                                                      zero_inflation_level=zero_inflation_level,
                                                      loc=loc,
                                                      random_state=random_seed)

# If single trials are missing due to experimental errors, replace them by np.nan, for example:
resps[0, 0, :] = np.nan
# Number of observed (non-NaN) responses across repeats, images and neurons
n_trials = (n_repeats*n_images*n_neurons - np.isnan(resps).sum())

---

# Optimize prior params

The upper bound estimator (GS model) performs best when the prior hyperparameters are optimized. This takes a long time but only needs to be done once per dataset. The following cells show how to do this for the `Zero Inflated Gamma` and the `Zero Inflated LogNormal` distributions. Note that we choose reasonable initial values for the prior hyperparameters by fitting to the raw data using the functions `get_prior_for_q`, `get_prior_for_gaussian` and `get_prior_for_gamma`. This is not necessary but speeds up the optimization.

In [None]:
distribution = "Zero Inflated Gamma" #"Zero Inflated LogNormal"

#### Initialize GS model

In [None]:
loc = np.exp(-10)
slab_mask = np.ones_like(resps)
slab_mask[resps <= loc] = np.nan
print("Getting good init values for q prior parameters...")
q_prior_params = get_prior_for_q(torch.from_numpy(resps), loc)

# Initialize GS model
if distribution == "Zero Inflated LogNormal":
    transform = Log()
    resps_transformed, _ = transform(torch.from_numpy(resps) - loc)
    print("Getting good init values for slab prior parameters...")
    slab_prior_params = get_prior_for_gaussian(resps_transformed.numpy(),
                                                   per_neuron=False,
                                                   mask=slab_mask,
                                                   lr_decay_steps=1)
    dist_slab = Gaussian_GS(*slab_prior_params, train_prior_hyperparams=True, alpha_greater_one=True)
    
elif distribution == "Zero Inflated Gamma":
    transform = Identity()
    resps_transformed, _ = transform(torch.from_numpy(resps) - loc)
    print("Getting good init values for slab prior parameters...")
    slab_prior_params = get_prior_for_gamma(resps_transformed.numpy(),
                                                   per_neuron=False,
                                                   mask=slab_mask)
    dist_slab = Gamma_GS(*slab_prior_params, train_prior_hyperparams=True)
else:
    raise NotImplementedError(f"Unsupported distribution: {distribution}")

# Leave-one-out repeat counts that occur across images
possible_number_of_loo_repeats = np.unique([
    dist_slab.get_number_of_repeats(torch.from_numpy(resps[:, i, :])) - 1
    for i in range(resps.shape[1])
])
gs_model = Zero_Inflation_Base(
    loc,
    dist_slab,
    *q_prior_params,
    possible_number_of_loo_repeats=possible_number_of_loo_repeats,
    transform=transform,
).to(device)
gs_model.integrals_over_q_dict = gs_model.get_integrals_over_q()

#### Optimize prior params

This cell takes a long time to run. The warning about recomputing the integral over q can be ignored because the recomputation is already taken care of. Consider saving the optimized prior hyperparameters once the optimization has finished.
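Saving and restoring the hyperparameters could look roughly like the sketch below. It uses a hypothetical stand-in module `ToyPrior` with made-up parameter names instead of the actual GS model, and a temporary directory as the save location; only `torch.save`, `torch.load`, `state_dict` and `load_state_dict` are standard PyTorch calls.

```python
import os
import tempfile

import torch


class ToyPrior(torch.nn.Module):
    """Hypothetical stand-in for a model with trainable prior hyperparameters."""

    def __init__(self):
        super().__init__()
        self.prior_alpha = torch.nn.Parameter(torch.zeros(1))
        self.prior_beta = torch.nn.Parameter(torch.zeros(1))


# Pretend these values came out of a long optimization run.
trained = ToyPrior()
with torch.no_grad():
    trained.prior_alpha.copy_(torch.tensor([2.5]))
    trained.prior_beta.copy_(torch.tensor([0.7]))

# Save the parameters as plain tensors via the state dict.
path = os.path.join(tempfile.mkdtemp(), "optimized_prior_params.tar")
torch.save(trained.state_dict(), path)

# Later: build a fresh model and load the stored values instead of re-optimizing.
restored = ToyPrior()
restored.load_state_dict(torch.load(path, map_location="cpu"))
```

Storing the state dict rather than the parameter objects keeps the file independent of the autograd graph and the device the model was trained on.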

In [None]:
print("Optimizing prior parameters...")
gs_model, loss = train_prior_for_gaussian(resps, gs_model, max_iter=200, logger=False, use_map=False)

# Optionally save optimized prior params
prior_params = {k: v for k, v in gs_model.named_parameters()}
# torch.save(prior_params, "optimized_prior_params" + ".tar")

# Obtain upper and lower bounds

In [None]:
if distribution == "Zero Inflated LogNormal":
    params_from_moments_function = get_zil_params_from_moments
    loss_function = ZILLoss(per_neuron=True)

elif distribution == "Zero Inflated Gamma":
    params_from_moments_function = get_zig_params_from_moments
    loss_function = ZIGLoss(per_neuron=True)
else:
    raise NotImplementedError(f"Unsupported distribution: {distribution}")
    
# Get upper bound log-likelihood per repeat, image and neuron
loss_gs = compute_gs_loss_over_target_repeats(resps, gs_model, False).item()
upper_bound = -loss_gs / n_trials

# Get lower bound log-likelihood per repeat, image and neuron
loss_null = compute_null_loss(resps, params_from_moments_function, loss_function, torch.Tensor([loc]).to(device), device).sum()
lower_bound = -loss_null / n_trials

In [None]:
print(f"upper_bound: {upper_bound}")
print(f"lower_bound: {lower_bound}")
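The two bounds are meant to normalize the score of an encoding model. The sketch below assumes that `NInGa` is the model's log-likelihood gain over the lower bound divided by the gap between upper and lower bound, as described in the linked paper. The function names, `ll_model` and all numbers are hypothetical and do not come from this notebook.

```python
import math


def normalized_information_gain(ll_model, ll_lower, ll_upper):
    """Gain of the model over the lower bound, relative to the gap between the bounds.

    All arguments are average log-likelihoods in the same units (e.g. nats per trial).
    """
    gap = ll_upper - ll_lower
    if gap <= 0:
        raise ValueError("The upper bound must exceed the lower bound.")
    return (ll_model - ll_lower) / gap


def nats_to_bits(value_in_nats):
    """Convert a log-likelihood from nats to bits."""
    return value_in_nats / math.log(2)


# Hypothetical average log-likelihoods per trial, in nats.
example_lower, example_upper, example_model = -1.2, -0.8, -0.9

ninga = normalized_information_gain(example_model, example_lower, example_upper)
print(f"NInGa: {ninga:.3f}")  # roughly 0.75 for these made-up numbers
print(f"gap between bounds: {nats_to_bits(example_upper - example_lower):.3f} bits")
```

Under this definition, a value near 0 means the model is no better than the lower bound, and a value near 1 means it reaches the upper bound.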