# How to use the normal approximations module

In [None]:
from os import getcwd, path

from pandas import read_csv, read_json

from bayes_chime.normal.models import SEIRModel
from bayes_chime.normal.utilities import one_minus_logistic_fcn

import bayes_chime.normal.fitting as ft
import bayes_chime.normal.plotting as pl

from seaborn import FacetGrid, distplot
from matplotlib.pylab import show as show_plot

from importlib import reload

In [None]:
# Project root: the parent of the directory the notebook is launched from.
ROOT = path.split(getcwd())[0]
# Timestamped name of the run folder to inspect (looked up under ROOT/output).
RUN = "2020_04_22_09_07_17"

In [None]:
# All artifacts for the selected run live under <ROOT>/output/<RUN>.
OUTPUT = path.join(ROOT, "output", RUN)
# OUTPUT already begins with ROOT, so it must not be joined onto ROOT again:
# with a relative ROOT that would produce ROOT/ROOT/output/..., and with an
# absolute ROOT the extra component is silently discarded anyway.
DATA = path.join(OUTPUT, "parameters")

In [None]:
# Load the run's census time series, indexed by date with integer values.
census_csv = path.join(DATA, "census_ts.csv")
raw_census = read_csv(census_csv, parse_dates=["date"])
# Drop columns that are entirely empty, then treat remaining gaps as zero.
data_df = raw_census.dropna(how="all", axis=1)
data_df = data_df.fillna(0).set_index("date").astype(int)
data_df.head()

In [None]:
# Prior parameter table for the run. `priors` is used further down as a
# dict-like mapping keyed by parameter name (see `priors.get(col, 0)`).
prior_df = read_csv(path.join(DATA, "params.csv"))
priors = ft.fit_norm_to_prior_df(prior_df)
prior_df

In [None]:
# One panel per non-constant parameter, each showing its prior fit.
g = FacetGrid(
    prior_df.query("distribution != 'constant'"),
    col="param",
    col_wrap=5,
    sharex=False,
    sharey=False,
)
g.map_dataframe(pl.plot_prior_fit)
# pyplot.show takes no figure/grid argument (signature: show(*, block=None));
# it renders every open figure, including the one owned by the grid.
show_plot()

The cell below may take a while to run; `HDF5` might be a more suitable format for storing the chains.

In [None]:
# Load the sampled chains: JSON-lines records, bz2-compressed (pandas infers
# the compression from the file extension by default).
posterior_df = read_json(
    path.join(OUTPUT, "output", "chains.json.bz2"), orient="records", lines=True
)
# Keep only columns that match a prior parameter, plus "offset". Build the
# allowed set once instead of scanning the parameter array for every column.
keep_cols = set(prior_df.param.values) | {"offset"}
drop_cols = [col for col in posterior_df.columns if col not in keep_cols]
posterior_df = posterior_df.drop(columns=drop_cols)

In [None]:
posterior_df.head(n=5)  # preview the first rows of the filtered chains

In [None]:
# Fit a normal approximation to each sampled column's ensemble. Columns whose
# prior entry is a plain float (presumably constants) are skipped; columns
# missing from `priors` default to 0 (an int), so they are still fitted.
posteriors = {
    name: ft.fit_norm_dist_to_ens(posterior_df[name].values)
    for name in posterior_df.columns
    if not isinstance(priors.get(name, 0), float)
}

posteriors

In [None]:
def fcn(**kwargs):
    """Plot the distribution of the ``x`` column of one facet's data.

    Meant to be passed to ``FacetGrid.map_dataframe``, which supplies the
    facet's subset as the ``data`` keyword; any other keywords are ignored.
    NOTE(review): seaborn deprecated ``distplot`` in 0.11; consider
    ``histplot``/``kdeplot`` if the seaborn version is upgraded.
    """
    facet = kwargs["data"]
    distplot(a=facet.x.values)

In [None]:
# Long-format samples: one row per (sample, parameter) with columns
# "param" (original column name) and "x" (sampled value).
stacked = (
    posterior_df.stack()
    .reset_index()
    .drop(columns=["level_0"])
    .rename(columns={"level_1": "param", 0: "x"})
)
g = FacetGrid(
    stacked,
    col="param",
    col_wrap=5,
    sharex=False,
    sharey=False,
)
g.map_dataframe(fcn)
# pyplot.show takes no figure/grid argument (signature: show(*, block=None));
# it renders every open figure, including the one owned by the grid.
show_plot()