In [None]:
from jax import numpy as jnp

from summer2 import CompartmentalModel
from summer2.parameters import Parameter, Function

In [None]:
def build_model():
    """Build the SIR model used throughout this notebook.

    The model is constructed over the time span [0, 100] with compartments
    S, I and R ("I" is passed as the infectious compartment), an initial
    population of 999 susceptible and 1 infectious, a frequency-dependent
    infection flow S -> I and a recovery flow I -> R.

    Two derived outputs are requested:
      * "incidence": the "infection" flow.
      * "incidence_diff10": incidence minus incidence 10 output points earlier.

    Returns:
        The configured ``CompartmentalModel`` (not yet run).
    """
    sir_model = CompartmentalModel([0.0, 100.0], ["S", "I", "R"], ["I"])
    sir_model.set_initial_population({"S": 999.0, "I": 1.0})
    sir_model.add_infection_frequency_flow("infection", Parameter("contact_rate"), "S", "I")
    sir_model.add_transition_flow("recovery", Parameter("recovery_rate"), "I", "R")

    incidence = sir_model.request_output_for_flow("incidence", "infection")

    def lag_diff_func(output, lag):
        """Return ``output[i] - output[i - lag]``, with the first ``lag`` entries set to zero."""
        lagged_diff = output[lag:] - output[:-lag]
        # Pad the front so the result has the same length as ``output``.
        return jnp.insert(lagged_diff, 0, jnp.zeros(lag))

    # The lag is counted in output array entries. It was previously 1, which
    # contradicted the "diff10" output name; 10 matches the 10-unit spacing of
    # data_times, assuming one output point per time unit - TODO confirm.
    diff_lag = 10
    sir_model.request_function_output("incidence_diff10", Function(lag_diff_func, [incidence, diff_lag]))

    return sir_model

sir_model = build_model()

In [None]:
# Baseline parameter values; also reused as the default parameters of the
# Bayesian model wrappers built further down.
parameters = dict(contact_rate=0.3, recovery_rate=0.1)

# Single reference run, then plot the resulting incidence curve.
sir_model.run(parameters)
res = sir_model.get_derived_outputs_df()
res['incidence'].plot()

# Sample from a known distribution

In [None]:
import numpy as np
import pandas as pd
from scipy.stats import truncnorm

def sample_from_truncnorm(mean, std_dev, lower_bound, upper_bound, sample_size, name, random_state=None):
    """Draw samples from a normal distribution truncated to [lower_bound, upper_bound].

    Args:
        mean: Mean of the underlying (untruncated) normal distribution.
        std_dev: Standard deviation of the underlying normal; must be > 0.
        lower_bound: Lower truncation limit, in the same units as ``mean``.
        upper_bound: Upper truncation limit; must exceed ``lower_bound``.
        sample_size: Number of draws.
        name: Column name given to the draws in the returned frame.
        random_state: Optional seed (or random generator) forwarded to
            ``truncnorm.rvs``. ``None`` keeps the previous unseeded behaviour.

    Returns:
        A one-column ``pd.DataFrame`` (column ``name``) with ``sample_size`` rows.

    Raises:
        ValueError: If ``std_dev`` is not positive or the bounds are not ordered.
    """
    if std_dev <= 0:
        raise ValueError(f"std_dev must be positive, got {std_dev}")
    if lower_bound >= upper_bound:
        raise ValueError(f"lower_bound ({lower_bound}) must be below upper_bound ({upper_bound})")

    # scipy expects the truncation limits in standard-deviation units relative to loc.
    a = (lower_bound - mean) / std_dev
    b = (upper_bound - mean) / std_dev
    samples = truncnorm.rvs(a, b, loc=mean, scale=std_dev, size=sample_size, random_state=random_state)

    return pd.DataFrame(samples, columns=[name])

# Fixed seeds so the "known" input distribution is identical on every re-run;
# the two parameters use different seeds so their draws are not identical streams.
N_SAMPLES = 10000
SAMPLING_SEED = 0

samples = {
    "contact_rate": sample_from_truncnorm(0.225, 0.005, 0.2, 0.25, N_SAMPLES, "contact_rate", random_state=SAMPLING_SEED),
    "recovery_rate": sample_from_truncnorm(0.1, 0.005, 0.05, 0.15, N_SAMPLES, "recovery_rate", random_state=SAMPLING_SEED + 1),
}

In [None]:
import seaborn as sns
# Kernel density estimate of the sampled contact rates.
sns.kdeplot(samples["contact_rate"], fill=True)

In [None]:
# Kernel density estimate of the sampled recovery rates.
sns.kdeplot(samples["recovery_rate"], fill=True)

# Run model forward (i.e. feed the samples to the model)

In [None]:
from estival.model import BayesianCompartmentalModel
import estival.priors as esp
import estival.targets as est
from estival.sampling import tools as esamp


# Wide uniform priors and no targets: this wrapper is only used to push the
# pre-drawn samples through the model, not to fit anything.
priors = [
    esp.UniformPrior("contact_rate", [0, 1]),
    esp.UniformPrior("recovery_rate", [0, 1]),
]
targets = []
bcm = BayesianCompartmentalModel(model=sir_model, priors=priors, targets=targets, parameters=parameters)

# One {parameter: value} record per draw. Each entry of `samples` is a
# one-column DataFrame, so `.iloc[i]` alone would yield a one-element Series;
# `.iloc[i, 0]` extracts the scalar the model actually expects.
samples_for_estival = [
    {param_name: float(draws.iloc[i, 0]) for param_name, draws in samples.items()}
    for i in range(len(samples["contact_rate"]))
]

model_runs = esamp.model_results_for_samples(samples_for_estival, bcm)

In [None]:
# Incidence trajectories for all forward runs (legend suppressed: one line per sample).
model_runs.results['incidence'].plot(legend=False)

## Collect the synthetic data and generate likelihood components

In [None]:
# Times at which the synthetic data are collected: 20, 30, ..., 80.
data_times = list(range(20, 81, 10))
len(data_times)

In [None]:
from jax.scipy.stats import gaussian_kde
import jax.numpy as jnp

# When True, every data time after the first is described by the lagged
# incidence difference instead of the raw incidence.
use_diffs = False

# One Gaussian KDE per data time, fitted to the ensemble of model outputs at that time.
likelihood_comps = {}
for i, t in enumerate(data_times):
    output_name = 'incidence_diff10' if (use_diffs and i != 0) else 'incidence'
    ensemble_at_t = model_runs.results[output_name].loc[t]
    likelihood_comps[t] = gaussian_kde(jnp.array(ensemble_at_t))

In [None]:
# Check one likelihood component
import numpy as np
import matplotlib.pyplot as plt

# Overlay the KDE for one data time on the histogram of the model outputs it was fitted to.
t = 40
kde = likelihood_comps[t]
x_values = np.linspace(0, 30, 1000)
pdf_values = kde(x_values)

fig, ax = plt.subplots()
ax.plot(x_values, pdf_values)
model_runs.results['incidence'].loc[t].plot.hist(density=True, bins=50, ax=ax)

# Refit the model using the likelihood components derived from synthetic data

In [None]:
from jax import lax

n_data_points = len(data_times)

# Model output compared at each data time. This mirrors how likelihood_comps
# was built: the first time always uses raw incidence, and with use_diffs the
# remaining n - 1 times use the lagged difference (previously n entries were
# appended here, giving a list one element too long).
fitted_output = (
    ['incidence'] * n_data_points
    if not use_diffs
    else ['incidence'] + ['incidence_diff10'] * (n_data_points - 1)
)

# Flat prior
priors = [
    esp.UniformPrior("contact_rate", [0.1, 0.3]),
    esp.UniformPrior("recovery_rate", [0.01, 0.2])
]

# Define a custom target using the likelihood components
def make_eval_func(t):
    """Return the log-likelihood evaluator for the data time ``t``.

    A factory is used so that each target's evaluator is bound to its own ``t``.
    """
    def eval_func(modelled, obs, parameters, time_weights):
        """Log of the KDE density at ``modelled``, scaled by 1 / n_data_points.

        Only ``modelled`` is used; ``obs``, ``parameters`` and ``time_weights``
        are accepted but ignored.
        """
        likelihood_comp = likelihood_comps[t](modelled)
        # Floor the density to avoid log(0).
        # NOTE(review): 1e-300 underflows to 0 unless JAX runs in 64-bit mode - confirm.
        likelihood_comp = jnp.max(jnp.array([likelihood_comp, jnp.array([1.e-300])]))
        return jnp.log(likelihood_comp) / n_data_points

    return eval_func

# The observed value (0.) is a placeholder: eval_func never reads `obs`.
targets = [
    est.CustomTarget(f"likelihood_comp_{t}", pd.Series([0.], index=[t]), make_eval_func(t), model_key=model_key)
    for t, model_key in zip(data_times, fitted_output)
]

refit_bcm = BayesianCompartmentalModel(model=sir_model, priors=priors, targets=targets, parameters=parameters)

In [None]:
import pymc as pm
from estival.wrappers import pymc as epm

# Sample the refit posterior with DE-Metropolis. A fixed random_seed makes the
# chains reproducible under Restart & Run All.
with pm.Model() as model:
    variables = epm.use_model(refit_bcm)
    idata = pm.sample(
        step=[pm.DEMetropolis(variables)],
        draws=5000,
        tune=1000,
        cores=4,
        chains=4,
        random_seed=0,
    )


In [None]:
import arviz as az

In [None]:
# Trace plots of the sampled parameters stored in idata.
az.plot_trace(idata)

In [None]:
# Compare the refit posterior for contact_rate with the distribution it was generated from.
posterior_sample = idata.posterior.to_dataframe()['contact_rate'].to_list()

fig, ax = plt.subplots()
sns.kdeplot(samples["contact_rate"], fill=True, label="true sample", ax=ax)
sns.kdeplot(posterior_sample, fill=True, label="posterior", ax=ax)
ax.legend()

In [None]:
# Likelihood terms for the posterior draws; the 'logposterior' column is inspected in the following cells.
lls = esamp.likelihood_extras_for_idata(idata, refit_bcm)

In [None]:
# Lowest log-posterior among the posterior draws.
lls['logposterior'].min()

In [None]:
# Distribution of log-posterior values across the posterior draws.
lls['logposterior'].plot.hist()

In [None]:
# Compare the refit posterior for recovery_rate with the distribution it was generated from.
posterior_sample = idata.posterior.to_dataframe()['recovery_rate'].to_list()

fig, ax = plt.subplots()
sns.kdeplot(samples["recovery_rate"], fill=True, label="true sample", ax=ax)
sns.kdeplot(posterior_sample, fill=True, label="posterior", ax=ax)
ax.legend()

In [None]:
# Re-run the model for the posterior draws held in idata.
posterior_model_runs = esamp.model_results_for_samples(idata, refit_bcm)

In [None]:
# Incidence trajectories generated from the refit posterior.
posterior_model_runs.results['incidence'].plot(legend=False)

In [None]:
# Incidence trajectories from the original forward runs, for visual comparison with the posterior runs.
model_runs.results['incidence'].plot(legend=False)

In [None]:
import pandas as pd
import numpy as np

def create_grid(x_range, y_range, step, names):
    """Build a regular 2-D grid of points as a long-format DataFrame.

    Args:
        x_range: ``(lower, upper)`` bounds of the first axis, both inclusive.
        y_range: ``(lower, upper)`` bounds of the second axis, both inclusive.
        step: Spacing between neighbouring grid points on both axes; must be > 0.
        names: Mapping with keys ``"x"`` and ``"y"`` giving the column names.

    Returns:
        ``pd.DataFrame`` with one row per grid point and two columns named
        ``names["x"]`` and ``names["y"]``.

    Raises:
        ValueError: If ``step`` is not positive.
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")

    def _axis_values(bounds):
        """Points lower, lower + step, ... that do not exceed the upper bound."""
        lower, upper = bounds
        # np.arange with a float step and a `+ step` stop can emit an extra
        # point past the upper bound, so the point count is computed explicitly.
        # The small tolerance keeps an upper bound that lies a whole number of
        # steps away despite floating-point rounding.
        n_points = int(np.floor((upper - lower) / step + 1e-9)) + 1
        return lower + step * np.arange(n_points)

    x_values = _axis_values(x_range)
    y_values = _axis_values(y_range)
    X, Y = np.meshgrid(x_values, y_values)
    grid_df = pd.DataFrame({names["x"]: X.ravel(), names["y"]: Y.ravel()})
    return grid_df

# Define ranges and step size
x_range = (0.1, 0.4)
y_range = (0.05, .25)
step = .001

# Create grid DataFrame
grid_df = create_grid(x_range, y_range, step, names={"x": "contact_rate", "y": "recovery_rate"})


In [None]:
# Evaluate the refit model's likelihood terms at every grid point; other cells read the 'logposterior' column.
ll_outputs = esamp.likelihood_extras_for_samples(grid_df, refit_bcm)

In [None]:
# Grid point with the highest log-posterior. The values are read from
# ll_outputs directly, so this cell does not depend on a 'logposterior' column
# that is only added to grid_df in a later cell.
# Assumes ll_outputs shares grid_df's index - TODO confirm.
idx = ll_outputs['logposterior'].idxmax()
# idxmax returns an index label, hence .loc rather than .iloc.
grid_df.assign(logposterior=ll_outputs['logposterior']).loc[idx]

In [None]:
# Imported here because this cell runs before the cell that imports plotly.
import plotly.express as px

# 3-D view of the top of the log-posterior surface (grid points with logposterior > -1).
# The column is attached to a copy, so this cell works in top-to-bottom order
# and leaves grid_df untouched.
grid_with_ll = grid_df.assign(logposterior=ll_outputs['logposterior'])
grid_df_tip = grid_with_ll[grid_with_ll['logposterior'] > -1]
fig = px.scatter_3d(grid_df_tip, x='contact_rate', y='recovery_rate', z='logposterior', opacity=0.5, color='logposterior')
fig.update_traces(mode='markers', marker=dict(size=5))
fig.show()

In [None]:
import plotly.express as px

# Attach the log-posterior to the grid (aligned on index), then plot the full surface in 3-D.
grid_df['logposterior'] = ll_outputs['logposterior']

scatter_options = dict(x='contact_rate', y='recovery_rate', z='logposterior', opacity=0.5, color='logposterior')
fig = px.scatter_3d(grid_df, **scatter_options)
fig.update_traces(mode='markers', marker=dict(size=5))
fig.show()

In [None]:
# 2-D heatmap of the log-posterior over the parameter grid.
fig = px.density_heatmap(
    data_frame=grid_df,
    x='contact_rate',
    y='recovery_rate',
    z='logposterior',
)
fig.show()

In [None]:
# NOTE(review): IPython help query (trailing `?`) left over from interactive exploration;
# it only displays documentation and is not valid plain Python - consider removing this cell.
px.density_contour?
