In [None]:
import numpy as np

In [None]:
import matplotlib
import matplotlib.pyplot as plt

In [None]:
import stan
import asyncio
import nest_asyncio
# Patch asyncio so that event loops may be nested.
# NOTE(review): presumably needed because `stan` runs its own event loop inside the
# loop Jupyter is already running -- confirm against the pystan documentation.
nest_asyncio.apply()

## Nonlinear model example

Following [Gelman+2015](http://www.stat.columbia.edu/~gelman/research/published/stan_jebs_2.pdf).

Fitting $y = a_1 e^{-b_1 x} + a_2 e^{-b_2 x}$, with samples
$y_i = \left( a_1 e^{-b_1 x_i} + a_2 e^{-b_2 x_i} \right) \cdot \epsilon_i$ for $i = 1, \ldots, n$ with $\log \epsilon_i \sim N(0,\sigma^2)$.

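We constrain $b_2 > b_1$ so that the two exponential components can be told apart (without the constraint their labels could be swapped freely). In the Stan model this is enforced by declaring `log_b` as `ordered[2]`.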

In [None]:
# Stan program for the two-component exponential decay
#     y = a[1] * exp(-b[1] * x) + a[2] * exp(-b[2] * x)
# with multiplicative lognormal noise of log-scale width `sigma`.
# Amplitudes and rates are sampled on the log scale; `ordered[2] log_b`
# enforces log_b[1] < log_b[2], i.e. b[1] < b[2].
# Assignments use `=`: the legacy `<-` operator is deprecated and is rejected
# by recent Stan compilers.
stan_code = '''
data {
    int N;
    vector [N] x;
    vector [N] y;
}

parameters {
    vector[2] log_a;
    ordered[2] log_b;
    real<lower=0> sigma;
}

transformed parameters {
    vector<lower=0>[2] a;
    vector<lower=0>[2] b;
    a = exp( log_a );
    b = exp( log_b );
}

model {
    vector[N] ypred;
    ypred = a[1] * exp(-b[1]*x) + a[2] * exp(-b[2]*x);
    y ~ lognormal(log(ypred), sigma);
}
'''

## Create fake data

In [None]:
# Ground-truth values used to generate the synthetic data set:
# amplitudes `a`, decay rates `b`, and the log-scale noise width `sigma`.
# NOTE(review): `b` is listed in decreasing order, whereas the Stan model's
# `ordered[2] log_b` yields b[1] < b[2]; the fitted components would then come
# out in the opposite order to the one listed here -- confirm this is intended.
params = dict(
    a = [ 0.8, 1 ],
    b = [ 2, 0.1 ],
    sigma = 0.2,
)

In [None]:
# Sample grid for the synthetic data: N evenly spaced points on [0, 10].
N = 1000
x = np.linspace( 0, 10, N )

# Noise-free curve: the SUM of two exponential decays, each evaluated on x,
#     y = a_1 * exp(-b_1 * x) + a_2 * exp(-b_2 * x)
# which is the model stated in the notebook's introduction and fitted by the
# Stan program.
ypred = (
    params['a'][0] * np.exp( - params['b'][0] * x )
    + params['a'][1] * np.exp( - params['b'][1] * x )
)

In [None]:
# Draw the log-scale noise terms, one per sample: Normal(0, sigma).
# The generator is seeded so that restarting the kernel and re-running the
# notebook reproduces the same fake data; the seed value itself is arbitrary.
rng = np.random.default_rng( 1 )
error = rng.normal( 0, params['sigma'], N )

In [None]:
# Apply the noise multiplicatively: each sample is the noise-free value times exp(error).
y_fake = np.multiply( ypred, np.exp( error ) )

In [None]:
# Scatter the noisy fake samples and overlay the noise-free curve in red,
# using a logarithmic y axis.
fig, ax = plt.subplots()

ax.scatter( x, y_fake )
ax.plot( x, ypred, color = 'r' )

ax.set_xlabel( 'x' )
ax.set_ylabel( 'y' )

ax.set_yscale( 'log' )

## Fit the data

In [None]:
# Inputs for the Stan program; the keys match the names declared in its `data` block.
data = dict(
    N = N,
    x = x,
    y = y_fake,
)

In [None]:
%%capture
# Build the model from the Stan program and the fake data, with a fixed seed.
posterior = stan.build(
    stan_code,
    data = data,
    random_seed = 1,
)

In [None]:
%%capture
# Sample from the built model: 4 chains, 1000 samples requested per chain.
fit = posterior.sample(
    num_chains = 4,
    num_samples = 1000,
)

In [None]:
# Convert the fit to a frame; the plotting cell below indexes it by
# parameter column names such as 'a.1' and 'b.2'.
fit_df = fit.to_frame()

## Plot fit

In [None]:
# Histogram the draws of each amplitude and rate in log10 space, one panel per
# parameter, restricted to the 5th-95th percentile range of the draws.
# No default axes is created on the figure: subplot_mosaic adds its own panels,
# and an extra full-figure axes would show through behind them.
fig = plt.figure()

ax_dict = fig.subplot_mosaic(
    [
        [ 'a.1', 'a.2' ],
        [ 'b.1', 'b.2' ],
    ]
)

for ax_key, ax in ax_dict.items():

    # Take the logarithm once and reuse it for both the bin edges and the histogram.
    log_values = np.log10( fit_df[ax_key] )

    value_min, value_max = np.nanpercentile( log_values, [ 5, 95 ] )
    bins = np.linspace( value_min, value_max, 512 )

    ax.hist(
        log_values,
        bins
    )

    # Label each panel so it is clear which parameter it shows.
    ax.set_xlabel( 'log10( {} )'.format( ax_key ) )

# Keep the per-panel labels from colliding with neighbouring panels.
fig.tight_layout()