In [1]:
import os
from pathlib import Path

import jax.numpy as jnp
import matplotlib.pyplot as plt

# Absolute path of the current working directory, as a plain string.
# The figure output paths in the plotting cells are joined onto it.
this_dir = str(Path().resolve())

In [2]:
def lognormal(x, mu, sigma):
    """Probability density of a log-normal distribution.

    Evaluates ``exp(-(log(x) - mu)**2 / (2 * sigma**2)) / (x * sigma * sqrt(2 * pi))``
    element-wise.

    Args:
        x: Point(s) at which to evaluate the density (scalar or array-like).
        mu: Mean of ``log(x)``.
        sigma: Standard deviation of ``log(x)``.

    Returns:
        A JAX array with the density at each ``x``. The density is 0 for
        ``x <= 0``, where the closed-form expression would otherwise
        evaluate to NaN (``inf * 0`` at 0, ``log`` of a negative below it).
    """
    x = jnp.asarray(x)
    positive = x > 0
    # Substitute a harmless value on the masked entries so that neither
    # the division nor the log produces inf/NaN before the final select.
    safe_x = jnp.where(positive, x, 1.0)
    fact1 = 1 / (safe_x * sigma * jnp.sqrt(2 * jnp.pi))
    fact2 = jnp.exp(-((jnp.log(safe_x) - mu) ** 2) / (2 * sigma**2))
    return jnp.where(positive, fact1 * fact2, 0.0)

In [None]:
# Log-normal priors for alpha and beta, plotted on a common axis.
# Each prior is specified through its mode: a log-normal density peaks at
# exp(mu - sigma**2), hence mu = log(mode) + sigma**2.

# alpha
mode = 1.2
sigma_alpha = 0.3
mu_alpha = jnp.log(mode) + sigma_alpha**2
print(mu_alpha)

# beta
mode = 8.0
sigma_beta = 0.2
mu_beta = jnp.log(mode) + sigma_beta**2
print(mu_beta)

mus = [mu_alpha, mu_beta]
sigmas = [sigma_alpha, sigma_beta]

# Evaluation grid; starts just above 0 because the density is only
# defined for positive x.
x = jnp.linspace(0.01, 15, 1000)

labels = [r"$\alpha$",  r"$\beta$"]
lss = ["-"] * 4  # line styles; not referenced by the plotting code in this cell
cs = ["r", "b"]  # one fill colour per prior (alternative: ['k']*4)
alpha = 0.5  # fill opacity (unrelated to the alpha prior above)

# ts scales every font size relative to the figure height fs;
# rat is the width/height ratio of the figure.
ts = 3
fs, rat = 6, 2.5
fig, ax = plt.subplots(1, 1, figsize=(fs * rat, fs))

# Loop over however many priors are listed instead of a hard-coded count.
for i in range(len(mus)):
    f = lognormal(x, mus[i], sigmas[i])
    ax.fill_between(x, f, color=cs[i], alpha=alpha, lw=fs * 0.2, label=labels[i])

ax.grid(True, alpha=0.5)

ax.set_xlim(0, x.max())
ax.set_ylim(0, 1.3)

ax.set_xlabel("x", fontsize=ts * fs)
ax.set_ylabel("pdf", fontsize=ts * fs)

ax.tick_params(labelsize=ts * fs)
ax.legend(fontsize=fs*ts)

# NOTE(review): the following cell of this notebook saves to this same path,
# so on a full run this figure is overwritten — confirm which figure the
# article needs and give the other one its own file name.
out_path = os.path.join(this_dir, "article/b_priors.png")
# savefig does not create missing directories; make sure the target exists.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
fig.savefig(out_path, bbox_inches="tight")
# Close this specific figure rather than whatever pyplot considers current.
plt.close(fig)

In [None]:
# Log-normal priors for alpha, rho_hp, e_hp and beta, plotted on a common axis.
# Each prior is specified through its mode: a log-normal density peaks at
# exp(mu - sigma**2), hence mu = log(mode) + sigma**2.

# alpha
mode = 1.0
sigma_alpha = 0.3
mu_alpha = jnp.log(mode) + sigma_alpha**2

# rho_hp
mode = 10.0
sigma_rho_hp = 0.1
mu_rho_hp = jnp.log(mode) + sigma_rho_hp**2

# e_hp
mode = 1.0
sigma_e_hp = 0.4
mu_e_hp = jnp.log(mode) + sigma_e_hp**2

# beta
mode = 8.0
sigma_beta = 0.2
mu_beta = jnp.log(mode) + sigma_beta**2

mus = [mu_alpha, mu_rho_hp, mu_e_hp, mu_beta]
sigmas = [sigma_alpha, sigma_rho_hp, sigma_e_hp, sigma_beta]

# Evaluation grid; starts just above 0 because the density is only
# defined for positive x.
x = jnp.linspace(0.01, 15, 1000)

# NOTE(review): \text{} inside mathtext is, as far as I know, only accepted by
# recent Matplotlib releases — confirm against the version this project pins.
labels = [r"$\alpha$", r"$\rho_{\text{hp}}$",  r"$\epsilon_{\text{hp}}$",  r"$\beta$"]
lss = ["-"] * 4  # line styles; not referenced by the plotting code in this cell
cs = ["r", "b", "g", "y"]  # one fill colour per prior (alternative: ['k']*4)
alpha = 0.5  # fill opacity (unrelated to the alpha prior above)

# ts scales every font size relative to the figure height fs;
# rat is the width/height ratio of the figure.
ts = 2
fs, rat = 6, 1.5
fig, ax = plt.subplots(1, 1, figsize=(fs * rat, fs))

# Loop over however many priors are listed instead of a hard-coded count.
for i in range(len(mus)):
    f = lognormal(x, mus[i], sigmas[i])
    ax.fill_between(x, f, color=cs[i], alpha=alpha, lw=fs * 0.2, label=labels[i])

ax.grid(True, alpha=0.5)

ax.set_xlim(0, x.max())
ax.set_ylim(0, 1.3)

ax.set_xlabel("x", fontsize=ts * fs)
ax.set_ylabel("pdf", fontsize=ts * fs)

ax.tick_params(labelsize=ts * fs)
ax.legend(fontsize=fs*ts)

# NOTE(review): the preceding cell of this notebook saves to this same path,
# so on a full run its figure is replaced by this one — confirm which figure
# the article needs and give the other one its own file name.
out_path = os.path.join(this_dir, "article/b_priors.png")
# savefig does not create missing directories; make sure the target exists.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
fig.savefig(out_path, bbox_inches="tight")
# Close this specific figure rather than whatever pyplot considers current.
plt.close(fig)