Goal
- Understand the behavior of the lognormal approximation to the gamma distribution.

In [None]:
# Reference:
#    
# Notebook setup: numerical libraries, plotting style, and project helpers.
import scipy
import numpy as onp
onp.set_printoptions(precision=3,suppress=True)

import jax
import jax.numpy as np
from jax import grad, jit, vmap, device_put, random
from flax import linen as nn
from jax.scipy.stats import dirichlet

import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.tri as tri
# https://matplotlib.org/3.1.1/gallery/style_sheets/style_sheets_reference.html
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 25
mpl.rcParams['font.family'] = 'Times New Roman' 
# `plt.cm.get_cmap` (i.e. `matplotlib.cm.get_cmap`) was deprecated in matplotlib 3.7
# and removed in 3.9; `plt.get_cmap` returns the same colormap on old and new versions.
cmap = plt.get_cmap('bwr')

from tabulate import tabulate

from plt_utils import *
from gpax import *


In [None]:
# Dirichlet densities on the 2-simplex for three concentration vectors;
# each panel also marks the Dirichlet mean α/Σα in red.
αs = np.array([[1.,1.,1.],
               [2,1.,1.],
               [10,1.,1.]])

fig, axs = plt.subplots(1, 3, figsize=(15,5))

for ax, α in zip(axs, αs):
    plt_2simplex_dirichlet_pdf(ax, α, levels=200)
    dirichlet_mean = α/np.sum(α)
    plt_2simplex_scatter(ax, [dirichlet_mean], c='r')

In [None]:

# Left: lnΓ(Z) vs. (Z-1/2)ln Z - Z + ln(2π)/2.  Right: digamma ψ0(Z) vs. ln Z - 1/(2Z).
fig, axs = plt.subplots(1, 2, figsize=(20,10))

# NOTE(review): the first sample Z=0 is singular for every curve below, so it
# evaluates to a non-finite value and is not drawn.
Z = np.linspace(0,10,100)

lnΓ = jax.scipy.special.gammaln(Z)
lnΓ_apx = Z*np.log(Z) - Z - .5*np.log(Z) + .5*np.log(2*np.pi) 

Ψ0 = jax.scipy.special.polygamma(0, Z)
Ψ1 = jax.scipy.special.polygamma(1, Z)  # trigamma; computed but not plotted here
Ψ0_apx = np.log(Z) - 1/(2*Z)

# One row per panel: (exact values, exact label, approximation, approximation label, title).
panels = [
    (lnΓ, 'lnΓ(Z)', lnΓ_apx, 'lnΓ_apx(Z) = (Z-.5)lnZ-Z+.5ln(2π)', 'Approximate lnΓ'),
    (Ψ0, 'ψ0(Z)', Ψ0_apx, 'ψ0_apx(Z) = ln(Z)-1/(2Z)', 'Approximate digamma ψ0'),
]
for ax, (exact, exact_label, approx, approx_label, title) in zip(axs, panels):
    ax.plot(Z, exact, label=exact_label)
    ax.plot(Z, approx, label=approx_label)
    ax.legend()
    ax.grid()
    ax.set_title(title)

fig.tight_layout()
plt_savefig(fig, 'summary/assets/note_lognormal_gamma_apx_polyagamma.png')
    



In [None]:
# Compare different Lognormal approximation to Gamma(α, 1):
# for each conversion type, plot the lognormal location μ it assigns to each α.
# NOTE(review): curve labels assume gamma_to_lognormal's 'kl' / 'moment' modes
# use the formulas written in the labels — confirm against gpax.

fig, ax = plt.subplots(1, 1, figsize=(10,10))

# 100 log-spaced shapes from e^-3 to e^3 (`np.ℯ` in identifiers normalizes to `np.e`).
α = np.logspace(-3,3,100,base=np.e)

conversions = (
    ('kl', 'ln(α)-.5/α -> α'),
    ('moment', 'ln(α)-.5ln(1/α+1) -> α'),
)
for approx_type, curve_label in conversions:
    μ_conv, _ = gamma_to_lognormal(α, approx_type=approx_type)
    ax.plot(μ_conv, α, label=curve_label)

ax.set_xlabel('μ', fontsize=40)
ax.set_ylabel('α', fontsize=40)
ax.legend()
ax.grid()
ax.set_yscale('log')

fig.tight_layout()
plt_savefig(fig, 'summary/assets/note_lognormal_gamma_conversion.png')
    




In [None]:
# Overlay Gamma(α, 1) with two lognormal approximations, one panel per α:
#   MT (moment matching): σ̃2 = ln(1/α + 1),  ỹ = ln α - σ̃2/2
#   KL                  : σ2  = 1/α,          μ = ln α - σ2/2
key = random.PRNGKey(0)

n = 300
# (n, 1) column so each pdf broadcasts against α to shape (n, len(α)).
Z = np.linspace(0+1e-10,10,n).reshape(-1,1)

α = np.array([.9,1,1.1,2,10.])
σ̃2 = np.log(1/α + 1)
ỹ = np.log(α) - .5*σ̃2
σ2 = 1/α
μ = np.log(α) - .5*σ2

X̃pdf = jax.scipy.stats.gamma.pdf(Z, α)
Xpdf = scipy.stats.lognorm.pdf(Z, s=np.sqrt(σ̃2), scale=np.exp(ỹ))
X̂pdf = scipy.stats.lognorm.pdf(Z, s=np.sqrt(σ2), scale=np.exp(μ))

m = len(α)
fig, axs = plt.subplots(m,1,figsize=(10,m*10))


for i in range(X̃pdf.shape[1]):
    ax = axs[i]
    ax.plot(Z, X̃pdf[:,i], label=f'Gamma(α={α[i]:.3f},1)')
    # Fixed-point formatting: the general `:.1`/`:.2` formats kept only 1-2
    # significant digits and printed σ2 ≈ 1 as '1e+00'.
    ax.plot(Z, Xpdf[:,i], label=f'LognormalMT(μ={ỹ[i]:.3f},σ2={σ̃2[i]:.3f})')
    ax.plot(Z, X̂pdf[:,i], label=f'LogNormalKL(μ={μ[i]:.3f},σ2={σ2[i]:.3f})')
    ax.legend()
    ax.grid()





In [None]:
# Rate rescaling check: if X ~ LN(μ, σ2) then βX ~ LN(μ + ln β, σ2).
# α = 1/σ2 and β = 1/(σ2·exp(μ + σ2/2)) give Gamma(α, β) the lognormal's mean
# exp(μ + σ2/2); βX is then compared against Gamma(α, 1).
key = random.PRNGKey(0)

n = 300
Z = np.linspace(0+1e-10,10,n).reshape(-1,1)

μ = np.array([.5, 1, 2, 3])
σ2 = np.array([1,1,1,1.])*.2
β = 1/(σ2*np.exp(μ+σ2/2))
α = 1/σ2
μ_scaled = μ + np.log(β)   # log-location of βX


X1pdf = jax.scipy.stats.gamma.pdf(Z, α, scale=1)
X2pdf = scipy.stats.lognorm.pdf(Z, s=np.sqrt(σ2), scale=np.exp(μ_scaled))

m = len(α)
fig, axs = plt.subplots(m,1,figsize=(10,m*10))


for i in range(X1pdf.shape[1]):
    ax = axs[i]
    ax.plot(Z, X1pdf[:,i], label=f'Gamma(α={α[i]:.3f},1)')
    # Label with this cell's own lognormal parameters (not ỹ/σ̃2, which belong
    # to a different cell's distributions and may not even be defined).
    ax.plot(Z, X2pdf[:,i], label=f'LN(μ={μ_scaled[i]:.3f},σ2={σ2[i]:.3f})')
    ax.legend()
    ax.grid()





In [None]:
# Gamma(α, β) vs LN(μ, σ2) with α = 1/σ2 and β = 1/(σ2·exp(μ + σ2/2)), so both
# have mean exp(μ + σ2/2). The "2" curves rebuild each density another way:
#   Gamma2     : change of variables from Gamma(α, 1): p(z) = β · Gamma(βz; α, 1)
#   LogNormal2 : LN(ln α - 1/(2α) - ln β, σ2); with these α, β the location
#                simplifies to exactly μ, so it should coincide with LogNormal.
key = random.PRNGKey(0)

n = 300
Z = np.linspace(0+1e-10,20,n).reshape(-1,1)

μ = np.array([.5, 1, 2, 3])
σ2 = np.array([1,1,1,1.])*.2
α = 1/σ2
β = 1/(σ2*np.exp(μ+σ2/2))

gapdf = jax.scipy.stats.gamma.pdf(Z, α, scale=1/β)
lnpdf = scipy.stats.lognorm.pdf(Z, s=np.sqrt(σ2), scale=np.exp(μ))
# Multiply by the Jacobian β of z -> βz; without it this curve equals gapdf/β
# and is not the Gamma(α, β) density its label claims.
gapdf2 = β*jax.scipy.stats.gamma.pdf(Z*β, α)
lnpdf2 = scipy.stats.lognorm.pdf(Z, s=np.sqrt(σ2), scale=np.exp(np.log(α) -1/(2*α) - np.log(β)))

m = len(α)
fig, axs = plt.subplots(m,1,figsize=(10,m*10))

for i in range(m):
    ax = axs[i]
    ax.plot(Z, gapdf[:,i], label=f'Gamma(α={α[i]:.3f},β={β[i]:.3f})')
    ax.plot(Z, lnpdf[:,i], '--',label=f'LogNormal(μ={μ[i]:.2},σ2={σ2[i]:.1})')
    ax.plot(Z, gapdf2[:,i], label=f'Gamma2(α={α[i]:.3f},β={β[i]:.3f})')
    ax.plot(Z, lnpdf2[:,i], '.', label=f'LogNormal2(μ={μ[i]:.2},σ2={σ2[i]:.1})')
    ax.legend()
    ax.grid()



In [None]:
# Keep α fixed, and hence the location μ = ln α - 1/(2α) (inverting:
# α = 1/(2·W(e^{-μ}/2)), W = Lambert W), and vary the lognormal variance σ2
# to see its effect relative to Gamma(α, β).


key = random.PRNGKey(0)

n = 300
Z = np.linspace(0+1e-10,20,n).reshape(-1,1)

α = np.ones((5,))*4
β = np.ones((5,))
μ = np.log(α) - 1/(2*α)
σ2 = np.array([.1,.5,1,2.,5])


gapdf = jax.scipy.stats.gamma.pdf(Z, α, scale=1/β)
lnpdf = scipy.stats.lognorm.pdf(Z, s=np.sqrt(σ2), scale=np.exp(μ))
m = len(α)
fig, axs = plt.subplots(m,1,figsize=(10,m*10))

for i in range(m):
    ax = axs[i]
    ax.plot(Z, gapdf[:,i], label=f'Gamma(α={α[i]:.3f},β={β[i]:.3f})')
    # Fixed-point format: the general `:.1` format printed σ2 = 1, 2, 5 as
    # '1e+00', '2e+00', '5e+00'.
    ax.plot(Z, lnpdf[:,i], '--',label=f'LogNormal(μ={μ[i]:.3f},σ2={σ2[i]:.3f})')
    ax.legend()
    ax.grid()



