### Install and import packages

In [1]:
# Import packages
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
sns.set_palette("tab10")

import random as rnd
import numpy as np
rnd.seed(0)
np.random.seed(0)

import scipy.stats as stats
import bayes_logistic

from utils import *

### Set parameters

In [2]:
# Set parameters
# data_x parameters
num_data = 2
num_data_half = num_data // 2
num_feats = 2

# num_samples 
num_samples = 10000

# weights prior distribution parameters
weights_prior_params = [
    [0.0, 0.0], [[1.0, 0.0], [0.0, 1.0]]]

# init sigma used in numerical optimization for laplace approximation
laplace_num_iters = 1000

### Generate data_x

In [3]:
# Generate data_x
num_data = 50
# data_x = stats.uniform.rvs(0, 1, size=(num_data, num_feats), random_state=12345)
data_x = np.array([[0, 1], [1, 0]])
print(data_x.shape)

(2, 2)


### Sample weights, generate sample y from sample weights and visualize data

### Generate prior and posterior samples

Generate two set of prior samples A and B

$$
\begin{align*}
    & A = \{ \theta_i \}_{i=1}^N, \; \theta_i \sim p(\theta) \\
    & B = \{ \theta_i \}_{i=1}^N, \; \theta_i \sim p(\theta)
\end{align*}
$$

Generate a set of posterior from prior samples A

Notes: 
- $x$ is fixed and generated from the above procedure. Only $y_i$ is generated from $\theta_i$.

$$
\begin{align*}
    & C' = \{ \theta_i' \}_{i=1}^N, \\
    & \theta_i' \sim p(\theta|x, y_i), \\
    & y_i \sim p(y_i|x, \theta_i), \; \theta_i \in A
\end{align*}
$$

In [4]:
### Generate prior and posterior samples

# generate samples A of weights prior
samples_a_weights_prior = stats.multivariate_normal.rvs(
    weights_prior_params[0], weights_prior_params[1], 
    size=(num_samples), random_state=1)

# generate samples B of weights prior
samples_b_weights_prior = stats.multivariate_normal.rvs(
    weights_prior_params[0], weights_prior_params[1], 
    size=(num_samples), random_state=11)

# generate samples C of weights posterior using samples A of weights prior
samples_a_weights_posterior = []
samples_a_weights_posterior_params = []
for sidx in range(num_samples):
    # for each sample w_i in A
    sample_a_weights_prior = samples_a_weights_prior[sidx].reshape(1, num_feats)
    
    # generate sample y_i from Ber(x, w_i)
    sample_a_logit = 1.0 / (1 + np.exp(-np.matmul(data_x, sample_a_weights_prior.T)))
    sample_a_y = stats.bernoulli.rvs(sample_a_logit)
    
    # fit laplace approximation for pair (x, y_i)
    w_map, h_map = bayes_logistic.fit_bayes_logistic(
        y = sample_a_y.squeeze(-1),
        X = data_x, 
        wprior = np.array(weights_prior_params[0]), # initialize wprior same as prior params
        H = np.linalg.inv(np.array(weights_prior_params[1])),
        weights = None,
        solver = "Newton-CG",
        bounds = None,
        maxiter = laplace_num_iters
    )
    cov_map = np.linalg.inv(h_map)
    
    # sample weights' posterior p(w|x,y_i)
    sample_a_weights_posterior = stats.multivariate_normal.rvs(w_map, cov_map)
    samples_a_weights_posterior.append(sample_a_weights_posterior)  
    samples_a_weights_posterior_pdf.append(
        stats.multivariate_normal.pdf(
            sample_a_weights_posterior, w_map, cov_map))
    # samples_a_weights_posterior_params.append([w_map, cov_map])

samples_a_weights_prior = np.vstack(samples_a_weights_prior)
samples_b_weights_prior = np.vstack(samples_b_weights_prior)
samples_a_weights_posterior = np.vstack(samples_a_weights_posterior)
samples_a_weights_posterior_pdf = np.vstack(samples_a_weights_posterior)
print(samples_a_weights_prior.shape)
# print(samples_b_weights_prior)
# print(samples_a_weights_posterior)

NameError: name 'samples_a_weights_posterior_pdf' is not defined

### Visualize the generated prior and posterior samples

In [None]:
# # Visualize the generated prior and posterior samples, individual features
# nrows = 2
# fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, sharey=True, figsize=(8,8))
# axes = axes.flatten()

# for i in range(nrows):
#     sns.kdeplot(samples_a_weights_prior[:,i], fill=False, color="red", label="samples_a_prior", ax=axes[i])
#     sns.kdeplot(samples_b_weights_prior[:,i], fill=False, color="green", label="samples_b_prior", ax=axes[i])
#     sns.kdeplot(samples_a_weights_posterior[:,i], fill=False, color="blue", label="samples_a_posterior", ax=axes[i])
#     axes[i].legend(loc="upper right")
# plt.show()

In [None]:
# Visualize the generated prior and posterior samples, individual features
nrows = 2
fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, sharey=True, figsize=(10,8))
axes = axes.flatten()

num_hist = 5
samples_a_weights_posterior_all = []
for j in range(num_hist):
    idx = np.random.randint(num_samples)
    w_map_j, cov_map_j = samples_a_weights_posterior_params[idx]
    samples_a_weights_posterior_j = stats.multivariate_normal.rvs(w_map_j, cov_map_j, size=(10000), random_state=j)
    samples_a_weights_posterior_all.append(samples_a_weights_posterior_j)
    
for i in range(nrows):
    sns.kdeplot(samples_a_weights_prior[:,i], fill=False, color="red", label="samples_a_prior", ax=axes[i])
    sns.kdeplot(samples_b_weights_prior[:,i], fill=False, color="green", label="samples_b_prior", ax=axes[i])
    sns.kdeplot(samples_a_weights_posterior[:,i], fill=False, color="blue", label="samples_a_posterior", ax=axes[i])
    
    for j in range(num_hist):
        samples_a_weights_posterior_j = samples_a_weights_posterior_all[j]
        sns.kdeplot(samples_a_weights_posterior_j[:,i], fill=False, color="lightsteelblue", linestyle="--",
                    label=f"samples_a_posterior_{j}", ax=axes[i])
        
    axes[i].legend(loc="upper right")
plt.show()

In [None]:
# Visualize the generated prior and posterior samples, pair of features
fig, axes = plt.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(10,10))
axes = axes.flatten()

sns.kdeplot(x=samples_a_weights_prior[:,0], y=samples_a_weights_prior[:,1], n_levels=20, 
            cmap="inferno", fill=False, ax=axes[0])

sns.kdeplot(x=samples_b_weights_prior[:,0], y=samples_b_weights_prior[:,1], n_levels=20, 
            cmap="inferno", fill=False, ax=axes[1])

sns.kdeplot(x=samples_a_weights_posterior[:,0], y=samples_a_weights_posterior[:,1], n_levels=20, 
            cmap="inferno", fill=False, ax=axes[2])
axes[0].set_aspect(aspect="equal")
axes[1].set_aspect(aspect="equal")
axes[2].set_aspect(aspect="equal")
axes[0].set_title("samples_a_prior")
axes[1].set_title("samples_b_prior")
axes[2].set_title("samples_a_posterior")
plt.show()

### Measure the differences between the prior and samples

* Kernelized two sample test: maximum mean distance with RBF kernel
* Wasserstein distance of two samples
* Difference between the standard deviations (from true mean) of two samples

In [None]:
# Maximum mean distance with RBF kernel
mmd_rbf_prior_a_prior_b = compute_mmd_rbf(samples_a_weights_prior, samples_b_weights_prior)
mmd_rbf_posterior_a_prior_b = compute_mmd_rbf(samples_a_weights_posterior, samples_b_weights_prior)
print(f"MMD between prior a and prior b: {mmd_rbf_prior_a_prior_b:0.5f}")
print(f"MMD between posterior a and prior b: {mmd_rbf_posterior_a_prior_b:0.5f}")

In [None]:
# Wasserstein distance with RBF kernel
wd_prior_a_prior_b = compute_wasserstein_distance(samples_a_weights_prior, samples_b_weights_prior)
wd_posterior_a_prior_b = compute_wasserstein_distance(samples_a_weights_posterior, samples_b_weights_prior)
print(f"Wasserstein distance between prior a and prior b: {wd_prior_a_prior_b:0.5f}")
print(f"Wasserstein distance between posterior a and prior b: {wd_posterior_a_prior_b:0.5f}")

In [None]:
# Difference between the standard deviations (from true mean) of two samples
diff_std_prior_a_prior_b = compute_diff_std(samples_a_weights_prior, samples_b_weights_prior, weights_prior_params[0])
diff_std_posterior_a_prior_b = compute_diff_std(samples_a_weights_posterior, samples_b_weights_prior, weights_prior_params[0])
print(f"Difference standard deviations between between prior a and prior b: {diff_std_prior_a_prior_b:0.5f}")
print(f"Difference standard deviations between posterior a and prior b: {diff_std_posterior_a_prior_b:0.5f}")