In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#!which python

In [None]:
from model import model_sim, param_gen
import numpy as np
import torch
import json
import os

In [None]:
# Load every canonical environment definition ("*_canon.json") from disk.
# Files are read in sorted (lexicographic) filename order: os.listdir() returns
# entries in an arbitrary, filesystem-dependent order, and envList is passed to
# model_sim as-is below, so an unsorted listing could make runs differ across
# machines.
path = 'data/environments'
json_files = sorted(file for file in os.listdir(path) if file.endswith('_canon.json'))
envList = []
for file in json_files:
    # `with` guarantees each file handle is closed after parsing.
    with open(os.path.join(path, file)) as f:
        envList.append(json.load(f))

assert envList, f"no *_canon.json environment files found in {path!r}"

# Seed both RNG sources before generating parameters and simulating.
# NOTE(review): presumably param_gen / model_sim draw from these RNGs — confirm
# in model.py. torch and numpy deliberately keep their original, distinct seeds.
TORCH_SEED = 1
NUMPY_SEED = 42

torch.manual_seed(TORCH_SEED)
np.random.seed(NUMPY_SEED)

# The positional arguments below are passed through unchanged.
# NOTE(review): their meaning is defined in model.py; consider promoting them
# to named constants once confirmed.
pars = param_gen(4, 20, hom=True, models=3)
simulations = model_sim(pars, envList, 8, 15, payoff=True)

In [None]:
# Preview the first five rows of the simulated dataset.
simulations.head(5)

In [None]:
# Distinct values of the `env` column, in ascending order.
np.sort(simulations["env"].unique())

In [None]:
# Build the (theta, x) training pairs for NLE.
# theta column order must match the order of the prior components:
#   env, trial, lambda, beta, tau, eps_soc
THETA_COLUMNS = ['env', 'trial', 'lambda', 'beta', 'tau', 'eps_soc']
theta = torch.tensor(simulations[THETA_COLUMNS].to_numpy(), dtype=torch.float32)

# Divide env and trial by their column max so they land in [0, 1], matching the
# BoxUniform(0, 1) prior components (this assumes both columns are non-negative).
# A non-positive max (e.g. a single env/trial with value 0) would give 0/0 = NaN
# or a sign flip, so such a column is left unscaled instead.
# `theta_scale` is kept so posterior samples can be mapped back to raw units.
theta_scale = theta[:, :2].max(dim=0).values
theta_scale = torch.where(theta_scale > 0, theta_scale, torch.ones_like(theta_scale))
theta[:, :2] = theta[:, :2] / theta_scale

# Observed data: one `choice` value per row, shaped (n_rows, 1).
# NOTE(review): `choice` looks discrete; sbi's MNLE targets discrete/mixed data
# and may suit it better than NLE's continuous density estimator — confirm.
x = torch.tensor(simulations['choice'].to_numpy(), dtype=torch.float32).unsqueeze(1)

In [None]:
from sbi.utils import BoxUniform, MultipleIndependent, mcmc_transform
from torch.distributions import LogNormal, Exponential

# Number of distinct values in the `env` column.
n_env = len(simulations.env.unique())


def _unit_interval():
    """Uniform prior on [0, 1], used for the max-normalized env/trial columns."""
    return BoxUniform(torch.tensor([0.0]), torch.tensor([1.0]))


# One independent prior component per theta column, in theta's column order.
prior_components = [
    _unit_interval(),                                       # env
    _unit_interval(),                                       # trial
    LogNormal(torch.tensor([-0.75]), torch.tensor([0.5])),  # lambda
    LogNormal(torch.tensor([-0.75]), torch.tensor([0.5])),  # beta
    LogNormal(torch.tensor([-4.5]), torch.tensor([0.9])),   # tau
    Exponential(torch.tensor([2.0])),                       # eps_soc
]
prior = MultipleIndependent(prior_components, validate_args=False)

# sbi's MCMC parameter transform for this prior.
# NOTE(review): not referenced by the later cells shown here.
prior_transform = mcmc_transform(prior)

# Sanity draw from the prior; the samples are shown as the cell's output.
num_simulations = 1000
prior.sample((num_simulations,))

In [None]:
from sbi.inference import NLE

# Neural likelihood estimation trainer using the joint prior defined above.
trainer = NLE(prior=prior)

In [None]:
# Register the (theta, x) training pairs with the trainer.
# sbi documents append_simulations as returning the inference object itself
# (so calls can be chained); `estimator` therefore refers to the trainer.
estimator = trainer.append_simulations(theta=theta, x=x)

In [None]:
# Fit the likelihood density estimator on the simulations appended above.
density_estimator = estimator.train()

In [None]:
# Wrap the trained likelihood estimator and the prior into a posterior object.
posterior = trainer.build_posterior(density_estimator=density_estimator)

# Print sbi's textual description of the posterior object.
print(posterior)

In [None]:
# Condition on the first 20 observed choices and draw 10 posterior samples.
# NOTE(review): these rows may come from different env/trial conditions (and
# possibly different simulated parameter sets), yet they are passed together as
# one batch of observations — confirm this is the intended inference target.
x_observed = x[:20, :]
posterior_sample = posterior.sample((10,), x=x_observed)

In [None]:
from sbi.analysis import pairplot

# Axis labels, in the same order as the theta columns.
theta_labels = ["env", "trial", r"$\lambda$", r"$\beta$", r"$\tau$", r"$\epsilon_{soc}$"]

# Assigning to `_` keeps the returned figure objects out of the cell's text output.
_ = pairplot(posterior_sample, labels=theta_labels)