<a href="https://colab.research.google.com/github/sushmit86/Statiistical_rethinking_jax/blob/main/sushmit_02_small_worlds_and_large_worlds_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [40]:
# Install packages that are not installed in colab
try:
  import google.colab
  !pip install watermark
  !pip install jaxopt
except:
  pass



In [41]:
# Load the watermark IPython extension, which provides the %watermark magic used below.
%load_ext watermark

The watermark extension is already loaded. To reload it, use:
  %reload_ext watermark


In [42]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import jax
import jaxopt
import pandas as pd
plt.style.use('fivethirtyeight')
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp_jax
import plotly.express as px
from jax.scipy.optimize import minimize
tfd_jax = tfp_jax.distributions
# jax.config.update("jax_enable_x64", True)

In [43]:
# Report the installed versions of the listed packages.
%watermark -p numpy,tensorflow,tensorflow_probability,scipy,pandas,jax,jaxopt

numpy                 : 1.23.5
tensorflow            : 2.15.0
tensorflow_probability: 0.22.0
scipy                 : 1.11.4
pandas                : 1.5.3
jax                   : 0.4.23
jaxopt                : 0.8.2



In [44]:
#@title Calculate the probability of observing 6 W out 9 with 0.5
# Observed counts; W and L are reused by the cells that follow.
W, L = 6, 3
n_tosses = W + L
# Binomial probability of exactly W successes in n_tosses trials at probs = 0.5.
tfd_jax.Binomial(total_count=n_tosses, probs=0.5).prob(W)

Array(0.16406271, dtype=float32)

In [45]:
#@title Grid approximation
# 100 evenly spaced candidate probabilities covering [0, 1].
grid = jnp.linspace(0, 1, 100)
# Uniform prior: the same weight at every grid point.
prior = jnp.ones_like(grid)
# Binomial likelihood of W successes in W + L trials, evaluated at each grid point.
likelihood = tfd_jax.Binomial(total_count=W + L, probs=grid).prob(W)
# Unnormalized posterior = prior * likelihood; dividing by its sum makes it sum to one.
unstd_post = prior * likelihood
std_post = unstd_post / unstd_post.sum()
fig = px.line(x=grid, y=std_post)
fig.update_layout(autosize=False, width=500, height=500)
fig.show()

In [46]:
#@title Quadratic Approximation
# Approximate the posterior with a Normal distribution built from the Hessian.
# Ingredients: a BFGS optimizer and gradients of the objective.
# First step: define the joint model whose log-probability is optimized.
# PRNG key (seed 11); nothing in the visible cells consumes it.
key = jax.random.PRNGKey(11)
dist = tfd_jax.JointDistributionNamed(
    dict(
        # Uniform(0, 1) prior on `probability`.
        probability=tfd_jax.Uniform(low=0.0, high=1.0),
        # `water` is Binomial with W + L trials and success probability `probability`.
        water=lambda probability: tfd_jax.Binomial(
            total_count=W + L, probs=probability
        ),
    )
)
def neg_log_prob(x):
    """Evaluate the negative joint log-probability at `x` together with its gradient.

    Args:
        x: Position to evaluate; only its last entry is read, and it is
            clipped into [0, 1] before being used as `probability`.

    Returns:
        Whatever `tfp_jax.math.value_and_gradient` returns for the objective
        at `x` (the objective value paired with its gradient).
    """
    def objective(position):
        # Keep the value handed to the model inside [0, 1].
        probability = jnp.clip(position[-1], 0, 1)
        return -dist.log_prob(water=W, probability=probability)

    return tfp_jax.math.value_and_gradient(objective, x)


# Find the mode of the joint log-probability (minimum of neg_log_prob) with
# BFGS, starting from p = 0.5.
optim_results = tfp_jax.optimizer.bfgs_minimize(
    neg_log_prob,
    initial_position=[0.5],
    tolerance=1e-6,
)
# The approximation below is only meaningful at a converged optimum.
assert optim_results.converged, "BFGS did not converge to a posterior mode"

print(optim_results)
mean = optim_results.position

# Quadratic approximation: the variance is the inverse Hessian of the negative
# log-probability at the mode. Compute that Hessian exactly with autodiff
# instead of reusing optim_results.inverse_hessian_estimate, which is only the
# optimizer's running quasi-Newton estimate and need not match the true
# curvature after a handful of iterations.
neg_log_posterior_hessian = jax.hessian(
    lambda p: -dist.log_prob(water=W, probability=jnp.clip(p[-1], 0, 1))
)(mean)
# Shape (1, 1), matching the shape the BFGS estimate had.
std_dev = jnp.sqrt(jnp.linalg.inv(neg_log_posterior_hessian))
quad_approx_dist = tfd_jax.Normal(loc=mean, scale=std_dev)

BfgsOptimizerResults(converged=Array(True, dtype=bool), failed=Array(False, dtype=bool), num_iterations=Array(4, dtype=int32), num_objective_evaluations=Array(13, dtype=int32), position=Array([0.6666667], dtype=float32), objective_value=Array(1.2978106, dtype=float32), objective_gradient=Array([9.536743e-07], dtype=float32), inverse_hessian_estimate=Array([[0.02470744]], dtype=float32))


In [47]:
# Disabled alternative: minimize the same negative log-probability with
# `minimize` (imported above from jax.scipy.optimize) using method="BFGS".
# Left commented out; nothing else in the visible cells depends on it.
# test_fn = lambda x: -dist.log_prob(water=W, probability= jnp.clip(x,0,1))[0]

# optim_scipy_results = minimize(test_fn,x0 = jnp.array([0.5]),method="BFGS",tol=1e-6)
# optim_scipy_results

In [48]:
# Closed-form comparison curve: Beta(1 + W, 1 + L).
analytical_solution = tfd_jax.Beta(concentration1=1 + W, concentration0=1 + L)
grid = jnp.linspace(0, 1, 100)
df_soln = pd.DataFrame(
    {
        'x': grid,
        'analytical_solution': analytical_solution.prob(grid),
        # The Normal approximation was built from array-valued loc/scale, so
        # its density is reshaped to a flat vector of length 100 to match `grid`.
        'quad_approx': jnp.reshape(quad_approx_dist.prob(grid), (100,)),
    }
)
# Overlay both densities on the same axes.
fig = px.line(df_soln, x='x', y=['analytical_solution', 'quad_approx'])
fig.update_layout(autosize=False, width=500, height=500)
fig.show()