# Calculation of Marginal Likelihood

## Likelihood from NLE

In [None]:
# Essentials
from scipy.special import logsumexp
import numpy as np
import pandas as pd
import pickle
import sbi.utils as utils
import torch

# Amortized likelihood estimator.
# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only load estimator files that come from a trusted source.
with open('posteriors/posterior_chuong_snle_10.pkl', 'rb') as pkl_file:
    lik = pickle.load(pkl_file)  # the handle is closed as soon as loading ends

# Model Prior: box-uniform over log10-transformed parameters
# (bounds are written in linear scale and converted with log10).
prior_min = np.log10(np.array([1e-2,1e-7,1e-8]))
prior_max = np.log10(np.array([1,1e-2,1e-2]))
prior = utils.BoxUniform(low=torch.tensor(prior_min), 
                         high=torch.tensor(prior_max))

## $P(X) = \int P(X|\zeta)P(\zeta)d\zeta$

In [4]:
# Function for P(X)
def get_PX(lik, prior, x, n):
    """Approximate the log marginal likelihood log P(X) on a regular grid.

    Evaluates P(X) = ∫ P(X|θ) P(θ) dθ as a Riemann sum over a grid with ``n``
    points per parameter axis spanning the box prior. The sum is carried out
    in log space with ``logsumexp``.

    Args:
        lik: object exposing ``potential(x=..., theta=...)``. Its output is
            used as the log-likelihood of ``x`` for every row of ``theta``
            (assumed to be one value per grid row -- confirm against the
            estimator in use).
        prior: distribution exposing ``base_dist.low`` and ``base_dist.high``
            with one bound per parameter; it is treated as uniform on that
            box. Any number of parameters is supported.
        x: observation, handed to ``lik.potential`` unchanged.
        n: number of grid points per parameter axis (at least 1).

    Returns:
        float: log of the approximated marginal likelihood.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")

    low = torch.as_tensor(prior.base_dist.low).detach().cpu().numpy()
    high = torch.as_tensor(prior.base_dist.high).detach().cpu().numpy()
    low = low.astype(np.float64).reshape(-1)
    high = high.astype(np.float64).reshape(-1)
    n_dims = low.size

    # Explored space: n evenly spaced points per parameter, end points included
    axes = [torch.tensor(np.linspace(lo, hi, n), dtype=torch.float32)
            for lo, hi in zip(low, high)]

    # cartesian_prod varies its LAST argument fastest. Feeding the axes in
    # reverse and flipping the columns back yields rows (θ_0, ..., θ_{d-1})
    # with θ_0 varying fastest. reshape covers the single-parameter case,
    # where cartesian_prod returns a 1-D tensor.
    theta = torch.cartesian_prod(*reversed(axes)).reshape(-1, n_dims).flip(1)

    # Likelihood on the grid (potential = log-likelihood), vectorized
    log_lik = torch.as_tensor(lik.potential(x=x, theta=theta))
    log_lik = log_lik.detach().cpu().numpy().astype(np.float64)

    A = np.prod(high - low)  # Prior volume -> P(θ) = 1/A
    dt = A / (n ** n_dims)  # Granularity: volume represented by one grid point
    y = log_lik + np.log(1 / A) + np.log(dt)
    return float(logsumexp(y))  # Riemann sum ~ integral -> marginal likelihood

# Rounded log marginal likelihood of every empirical replicate:
# one table row per line, one column per replicate.
lines = ['wt','ltr','ars','all']
cont_df = pd.DataFrame(index = lines, columns = [f'rep {i+1}' for i in range(8)])
for row, line in enumerate(lines):
    X = pd.read_csv(f'empirical_data/{line}.csv', index_col=0) # unimputed data
    conts = []
    for rep in range(len(X)):
        log_px = get_PX(lik, prior, X.iloc[rep,:], 100)  # 100 grid points per axis
        conts.append(round(log_px))
    # Lines with fewer than 8 replicates leave the remaining cells as NaN
    cont_df.iloc[row,:len(conts)] = conts
cont_df.replace(np.nan, '') # aesthetics

  x = atleast_2d(torch.as_tensor(x, dtype=float32))
  x = atleast_2d(torch.as_tensor(x, dtype=float32))
  x = atleast_2d(torch.as_tensor(x, dtype=float32))
  x = atleast_2d(torch.as_tensor(x, dtype=float32))
  cont_df.replace(np.nan, '') # aesthetics


Unnamed: 0,rep 1,rep 2,rep 3,rep 4,rep 5,rep 6,rep 7,rep 8
wt,-1963,-4188,-1094,-22541,-1558,,,
ltr,-2480,-1511,-12135,-953,-271,-153.0,-301.0,
ars,-5363,-3624,-3460,-366,-1774,-4685.0,-4489.0,
all,-911,-1065,-5163,-1853,-3174,-2125.0,-906.0,-1648.0


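wt rep 4 (-22541) and ltr rep 3 (-12135) have far lower log marginal likelihoods than the other replicates of their lines and are possibly unrepresentative.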