In [None]:
import cmdstanpy
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.integrate import odeint

# Set the working directory to the upstream (parent) folder so the relative
# 'datasets/' and 'stan_models/' paths used below resolve.
# The guard keeps the cell idempotent: without it, every re-run would move
# one more directory up and break those relative paths.
import os

if not (os.path.isdir('datasets') and os.path.isdir('stan_models')):
    os.chdir('..')

In [None]:
# Observed data; later cells read its 'time' and 'cell_counts' columns.
OBSERVED_COUNTS_CSV = 'datasets/observed_counts.csv'
# Stan program to compile -- change this path to fit a different model.
STAN_MODEL_FILE = 'stan_models/simple_homogeneous_model.stan'

data = pd.read_csv(OBSERVED_COUNTS_CSV)
model = cmdstanpy.CmdStanModel(stan_file=STAN_MODEL_FILE)

In [None]:
# Prepare data for Stan (adjust the keys to match the Stan program's data block).
# The first row is used as the initial condition (t0, y0); the remaining rows
# are the observations the model is fitted to. That positional split is only
# meaningful if rows are ordered by strictly increasing time. If the model uses
# one of Stan's ODE solvers, those solvers also require output times that are
# strictly increasing and greater than t0. So fail early here rather than
# inside the sampler.
assert data['time'].is_monotonic_increasing and data['time'].is_unique, \
    "observation times must be strictly increasing"

stan_data = {
    # Initial time taken from the data (not hard-coded) so it always matches y0.
    't0': data['time'].values[0],
    'y0': data['cell_counts'].values[0],                 # initial observed value
    'y_obs': data['cell_counts'].values[1:].tolist(),    # observations, initial point excluded
    'time_obs': data['time'].values[1:].tolist(),        # observation times, initial point excluded
    'T': len(data) - 1,                                  # number of fitted time points
}

# Fixed seed so re-running the notebook reproduces the same posterior draws.
SEED = 42

# Fit the model, running the 4 chains in parallel. Sampling can still fail if
# the data do not match the Stan program's data block or the priors are
# badly specified for this dataset.
fit = model.sample(
    data=stan_data,
    chains=4,
    parallel_chains=4,
    iter_warmup=300,
    iter_sampling=500,
    seed=SEED,
    show_progress=False,
)

In [None]:
# Posterior draws as a DataFrame; parameters are selected by column name below.
posterior = fit.draws_pd()

# Parameters to inspect -- names must match those declared in the Stan model.
param_names = ['s', 'b', 'd', 'sigma']

# Corner-style grid on a black background: marginal KDEs on the diagonal,
# joint KDEs in the upper triangle, lower triangle hidden.
n_params = len(param_names)
fig, axes = plt.subplots(n_params, n_params, figsize=(12, 10))
for row, row_name in enumerate(param_names):
    for col, col_name in enumerate(param_names):
        ax = axes[row, col]
        ax.set_facecolor('black')
        if row > col:
            ax.axis('off')
            continue
        if row == col:
            sns.kdeplot(x=posterior[row_name], ax=ax, fill=True, color='orange')
            ax.set_xlabel(row_name)
            ax.set_ylabel('Density')
        else:
            sns.kdeplot(
                x=posterior[row_name],
                y=posterior[col_name],
                fill=True,
                cmap='inferno',
                thresh=0,
                levels=10,
                ax=ax,
            )
            ax.set_xlabel(row_name)
            ax.set_ylabel(col_name)
plt.tight_layout()
plt.show()

In [None]:
# Posterior predictive trajectories: integrate the affine ODE
#   dy/dt = s + b*y - d*y
# once per posterior draw, then summarise with the pointwise median and a
# 95% credible band.


def model_ode(y, t, s, b, d):
    """Right-hand side of the simple homogeneous model.

    Parameters
    ----------
    y : ndarray
        Current cell count (odeint passes the state as an array).
    t : float
        Time; unused, but required by ``odeint``'s ``func(y, t, *args)`` signature.
    s : float
        State-independent additive term.
    b, d : float
        Per-capita rates entering with + and - sign respectively
        (presumably birth and death -- confirm against the Stan model).

    Returns
    -------
    ndarray
        dy/dt evaluated at ``y``.
    """
    return s + b * y - d * y


# Start from exactly the initial condition given to Stan, so the predictions
# correspond to the model that was fitted.
t_start = stan_data['t0']
y0_val = stan_data['y0']
time_points = np.linspace(t_start, data['time'].max(), 100)

# One trajectory per draw. Each draw's parameters are passed through `args`
# instead of redefining a closure on every iteration.
y_pred_samples = [
    odeint(model_ode, y0_val, time_points, args=(s, b, d)).flatten()
    for s, b, d in posterior[['s', 'b', 'd']].itertuples(index=False, name=None)
]

# Convert once; rows are draws, columns are the entries of time_points.
pred_draws = np.asarray(y_pred_samples)
pred_median = np.median(pred_draws, axis=0)
pred_lo, pred_hi = np.percentile(pred_draws, [2.5, 97.5], axis=0)

# Plot the median, the credible band and the observations on a log y-axis.
fig, ax = plt.subplots(figsize=(8, 5))
sns.lineplot(x=time_points, y=pred_median, color='black', label='Median Prediction', ax=ax)
ax.fill_between(time_points, pred_lo, pred_hi,
                color='black', alpha=0.3, label='95% Credible Interval')
ax.scatter(data['time'], data['cell_counts'], color='black', label='Observations')
ax.set_xlabel('Time')
# NOTE(review): fixed x-window; observations after t=60 would be clipped.
ax.set_xlim(5, 60)
ax.set_yscale('log')
ax.set_ylabel('Cell counts')
ax.set_title('Simple Homogeneous Model Fit')
ax.legend()
ax.grid()
plt.show()

In [None]:
# Posterior summary table. cmdstanpy returns it as a DataFrame, so leaving it
# as the cell's last expression lets Jupyter render it as a rich table
# (print() would only give plain text).
fit.summary()