## Question 2

In [2]:
import numpy as np
from scipy.stats import norm, t
import math

In [3]:
# naive implementation
def logsumexp_naive(nums):
    """Compute log(sum(exp(nums))) directly, with no overflow protection.

    Kept deliberately naive to contrast with ``logsumexp_stable``: for float64
    inputs, any element large enough that ``np.exp`` overflows (e.g. 1000)
    turns the sum into ``inf`` and so the result into ``inf``.
    """
    exponentials = np.exp(nums)
    total = exponentials.sum()
    return np.log(total)

In [4]:
# stable implementation
def logsumexp_stable(values):
    """Numerically stable log(sum(exp(values))).

    Subtracts the maximum before exponentiating, so the largest term becomes
    exp(0) = 1 and nothing can overflow; the maximum is added back afterwards.

    Parameters
    ----------
    values : array_like
        Real numbers, reduced over all elements. Must be non-empty
        (``np.max`` raises ValueError on empty input).

    Returns
    -------
    numpy scalar
        log(sum(exp(values))). Returns -inf when every element is -inf
        (a sum of zeros), +inf when any element is +inf, and nan if the
        maximum is nan.
    """
    values = np.asarray(values)
    max_val = np.max(values)
    if not np.isfinite(max_val):
        # Shifting by an infinite max would compute inf - inf = nan;
        # in these cases the answer is the max itself.
        return max_val
    return max_val + np.log(np.sum(np.exp(values - max_val)))

In [5]:
# Demonstrate the overflow: exp() of these inputs exceeds the float64 range.
test_arr = np.array([1000, 1010, 1002])

n, s = logsumexp_naive(test_arr), logsumexp_stable(test_arr)
print('naive logsumexp', n)
print("stable logsumexp", s)

naive logsumexp inf
stable logsumexp 1010.000380790048


  return np.log(np.sum(np.exp(nums)))


In [6]:
# Define the Banana density function
def log_banana_density(x):
    """Unnormalised log-density of the 2-D "banana" prior.

        log p~(x) = -x0^2/10 - x1^2/10 - 2 (x1 - x0^2)^2

    Parameters
    ----------
    x : sequence of float
        Point whose first two components ``x[0]`` and ``x[1]`` are used.

    Returns
    -------
    float
        The log-density up to an additive constant (no normalising term).
    """
    a, b = x[0], x[1]
    bend = b - a**2
    return -(a**2) / 10 - b**2 / 10 - 2 * bend**2

# Define the proposal density q(x)
def log_proposal_density(x, s):
    """Normalised log-density of the proposal q(x).

    q(x) = N(x0; 0, s0^2) * N(x1; x0^2, s1^2), i.e. x0 is zero-mean Gaussian
    and x1 is Gaussian around x0^2.

    Parameters
    ----------
    x : sequence of float
        Point whose first two components are used.
    s : sequence of float
        Scales ``(s0, s1)`` of the two Gaussian factors.

    Returns
    -------
    float
        log q(x), including the normalising term log(2*pi*s0*s1).
    """
    x0, x1 = x[0], x[1]
    s0, s1 = s[0], s[1]
    marginal_term = x0**2 / (2 * s0**2)
    conditional_term = (x1 - x0**2)**2 / (2 * s1**2)
    log_norm = np.log(2 * np.pi * s0 * s1)
    return -marginal_term - conditional_term - log_norm

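### SNIS log-marginal likelihood

With samples $x_i \sim q$ and importance weights $w_i = \tilde p(x_i) / q(x_i)$ (unnormalised banana prior $\tilde p$, proposal $q$), the self-normalised importance-sampling estimate of the evidence is

$$
\hat p_{\text{SNIS}}(y) = \frac{\sum_i w_i \prod_j p(y_j \mid x_i)}{\sum_i w_i}.
$$

The $\tfrac{1}{N}$ factors of the numerator and denominator cancel, so no $\log N$ term appears. Taking logs and writing every sum as a sum of exponentials:

$$
\log \hat p_{\text{SNIS}}(y) = \log\left(\sum_i \exp\left(\sum_j \log p(y_j \mid x_i) + \log w_i\right)\right) - \log\left(\sum_i e^{\log w_i}\right)
$$

Finally, both terms are computed with a stable log-sum-exp over the sample index $i$:

$$
\log \hat p_{\text{SNIS}}(y) = \text{LSE}_i\left(\sum_j \log p(y_j \mid x_i) + \log w_i\right) - \text{LSE}_i(\log w_i)
$$

where

$$
\text{LSE}(a_1, a_2, \dots, a_n) = \log\left(\sum_{i=1}^n e^{a_i}\right).
$$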

In [7]:
# Inputs: samples x_i (weighted below as draws from the proposal q) and observations y_j.
x_samples, y_data = np.load("samples.npy"), np.load("y.npy")

# Observation-noise variance and the corresponding scale.
sigma_sq = 0.1
sigma = np.sqrt(sigma_sq)

# Proposal scales (s0, s1) for log_proposal_density: s0 = sqrt(5) gives the
# x0^2 / 10 term and s1 = 0.5 gives the 2 (x1 - x0^2)^2 term, matching the
# corresponding terms of log_banana_density.
proposal_s = [math.sqrt(5), 0.5]

# Degrees of freedom for the Student-t likelihood.
nu = 5

def gaussian_likelihood(y, x):
    """Per-observation Gaussian densities p(y_j | x).

    Each observation is scored under N(x[0], sigma^2) using the module-level
    ``sigma``; only the first component of ``x`` is used.

    Returns
    -------
    float or ndarray
        Densities with the shape of ``y``.
    """
    observation_model = norm(loc=x[0], scale=sigma)
    return observation_model.pdf(y)

def student_t_likelihood(y, x, nu):
    """Per-observation Student-t densities p(y_j | x).

    Each observation is scored under a location-scale Student-t with ``nu``
    degrees of freedom, location x[0] and the module-level ``sigma`` as scale.

    Returns
    -------
    float or ndarray
        Densities with the shape of ``y``.
    """
    observation_model = t(df=nu, loc=x[0], scale=sigma)
    return observation_model.pdf(y)

def snis_estimator(y_data, x_samples, model='gaussian', nu=None, epsilon=0):
    log_likelihoods = []
    weights = []
    
    for x in x_samples:
        match model:
            case 'gaussian':
                l = gaussian_likelihood(y_data, x)
            case 'student_t':
                l = student_t_likelihood(y_data, x, nu)
            case _:
                print('unknown model')
                return

        l = l[l > epsilon]
        log_l = np.sum(np.log(l))
        log_likelihoods.append(log_l)

        log_prior = log_banana_density(x)
        log_proposal = log_proposal_density(x, proposal_s)
        log_weight = log_prior - log_proposal

        log_l_w = log_l + log_weight

        log_likelihoods.append(log_l_w)
        weights.append(log_weight)
    
    lse = logsumexp_stable(log_likelihoods)
    lse_w = logsumexp_stable(weights)
    print(lse, lse_w, np.log(len(x_samples)))
    
    # Use logsumexp for numerical stability
    return lse - np.log(len(x_samples)) - lse_w

# Evidence estimates under both observation models.
for label, options in (
    ("Gaussian:", {'model': 'gaussian'}),
    (f"Student t, nu={nu}:", {'model': 'student_t', 'nu': nu}),
):
    print(label, snis_estimator(y_data, x_samples, **options))


-18547.555027364462 12.612291524246025 10.819778284410283
Gaussian: -18570.987097173118
-7561.639003403769 12.612291524246025 10.819778284410283
Student t, nu=5: -7585.071073212424


In [87]:
# Keep the Gaussian estimate in `e` (used by a later cell).
e = snis_estimator(y_data=y_data, x_samples=x_samples, model="gaussian")

-18547.555027364462 12.612291524246025 10.819778284410283


In [72]:
# NOTE(review): `lsew` is not assigned by any visible cell (snis_estimator keeps
# `lse_w` local and does not return it), so this cell only runs because of
# leftover kernel state; note the out-of-order execution count, and that a
# later-executed cell displays `lsew` as a list. TODO: define it explicitly or
# delete this cell.
# `len(y_data)/2 * np.log(5/np.pi)` equals -len(y_data)/2 * log(2*pi*sigma_sq)
# for sigma_sq = 0.1, i.e. the Gaussian likelihood's log-normaliser summed over
# the data points; the intent of adding it here is unclear — confirm.
e + lsew + len(y_data)/2 * np.log(5/np.pi)

-16234.9815295856

In [88]:
# NOTE(review): `lsew` is never assigned in any visible cell; this displays
# leftover kernel state (here a long list of what appear to be per-sample log
# terms, whereas the In[72] cell produced a scalar from it). Delete before
# sharing, and show a slice (e.g. lsew[:5]) rather than dumping the whole list.
lsew

[-29756.73470722124,
 -29755.088407392588,
 -648584.668798602,
 -648588.1325146836,
 -46725.5523662257,
 -46724.599729384085,
 -60365.834822548946,
 -60363.916002948456,
 -64212.47181882475,
 -64210.53740828356,
 -24825.286848608892,
 -24823.3389473836,
 -21601.434347734565,
 -21599.490684712826,
 -369619.5650889899,
 -369617.98447489244,
 -30143.594577190703,
 -30141.645552328682,
 -167274.90895196894,
 -167273.0951523343,
 -43655.59697142557,
 -43653.65301006907,
 -63978.132206725226,
 -63978.00755744448,
 -89434.44249827445,
 -89432.54431600514,
 -162375.3142102379,
 -162373.40660172133,
 -44999.27951610209,
 -44997.99014311999,
 -193069.18681964604,
 -193067.25878998908,
 -255501.23467281053,
 -255499.61227465826,
 -34238.30900556761,
 -34236.38874539135,
 -195903.36129231946,
 -195901.43393320788,
 -198850.24037527142,
 -198848.29773028972,
 -117300.04472298181,
 -117298.1053521541,
 -284560.2880325669,
 -284558.5695650549,
 -25604.919656424077,
 -25603.15966610186,
 -171208.72244