In [5]:
from functools import partial

import matplotlib.pyplot as plt
import pickle
import torch

import sbi
import sbibm
from sbi.inference import SNLE
from sbi.inference import likelihood_estimator_based_potential, MCMCPosterior
from mnle_utils import BernoulliMN, MNLE

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [6]:
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
# NOTE(review): the unpacking below relies on the pickled dict's insertion
# order (parameters first, then data) -- confirm the keys if the file changes.
with open("ddm_training_data_gitlfs.p", "rb") as data_file:
    ddm_training_data = pickle.load(data_file)
theta, x = ddm_training_data.values()

In [None]:
# Choices are encoded in the sign of x: x > 0 -> choice 1, otherwise choice 0.
# Reaction times are the magnitudes |x|.
num_simulations = 100000
theta = theta[:num_simulations]
x = x[:num_simulations]

rts = x.abs()
choices = (x > 0).to(x.dtype)

# The reaction-time model is conditioned on both the parameters and the choice,
# so the choices are appended as extra columns.
# NOTE(review): hstack here assumes theta and choices are 2-D with matching rows.
theta_and_choices = torch.hstack((theta, choices))

In [None]:
# Histogram of the raw training data; the sign of each value encodes the choice
# and its magnitude is the reaction time.
fig, ax = plt.subplots(figsize=(12, 5))
ax.hist(x.numpy(), bins=50, density=True)
ax.set_xlabel("reaction time [s] (sign encodes choice)")
ax.set_ylabel("density")
# Trailing semicolon suppresses the Text(...) repr in the cell output.
ax.set_title("Training data: signed reaction times");

In [None]:
def build_choice_net(batch_theta, batch_choices, num_choices=2, z_score_theta=False, hidden_features: int=10, hidden_layers: int=2):
    """Build the network that models the (binary) choice given the parameters.

    Used as a density-estimator builder for ``SNLE`` (via ``choice_net_builder``);
    presumably SNLE calls it as ``builder(theta_batch, x_batch)`` -- TODO confirm.

    Args:
        batch_theta: batch of parameters. ``batch_theta[0].numel()`` is used as
            the network's input dimension, and the batch provides the z-scoring
            statistics when ``z_score_theta`` is True.
        batch_choices: batch of choices. Not used; accepted only so the function
            matches the ``(theta, x)`` builder signature.
        num_choices: number of discrete choices. Only 2 is supported.
        z_score_theta: if True, prepend a standardizing layer fitted on
            ``batch_theta``.
        hidden_features: hidden units per hidden layer of ``BernoulliMN``.
        hidden_layers: number of hidden layers of ``BernoulliMN``.

    Returns:
        A ``BernoulliMN``, or ``nn.Sequential(standardizer, BernoulliMN)`` when
        ``z_score_theta`` is True.

    Raises:
        AssertionError: if ``num_choices != 2``.
    """
    # Validate before doing any work.
    assert num_choices == 2, "Not implemented for more than two choices."

    dim_parameters = batch_theta[0].numel()

    # Two choices need only a single Bernoulli output.
    choice_net = BernoulliMN(n_input=dim_parameters,
                             n_output=1,
                             n_hidden_layers=hidden_layers,
                             n_hidden_units=hidden_features)

    if z_score_theta:
        # Local imports: only needed on this (optional) path.
        from torch import nn
        from sbi.utils.sbiutils import standardizing_net

        # NOTE(review): nn.Sequential only exposes forward(); if MNLE calls other
        # BernoulliMN methods on the choice net, they are hidden by this wrapper
        # -- confirm before enabling z-scoring.
        choice_net = nn.Sequential(standardizing_net(batch_theta), choice_net)

    return choice_net

choice_net_builder = partial(build_choice_net, z_score_theta=False)

In [None]:
def build_rt_flow(batch_theta, batch_x,
                  num_transforms=2,
                  hidden_features=10,
                  num_bins=5,
                  tail_bound=10.0,
                  **kwargs):
    """Build a neural spline flow for the reaction times.

    Thin wrapper around ``sbi.neural_nets.flow.build_nsf``: ``batch_x`` is passed
    as the modelled variable and ``batch_theta`` as the conditioning variable
    (``batch_y``). In this notebook the trainer is given the parameters plus
    choices as theta and the reaction times as x.

    Args:
        batch_theta: conditioning batch, forwarded as ``batch_y``.
        batch_x: batch of the modelled variable, forwarded as ``batch_x``.
        num_transforms: number of flow transforms.
        hidden_features: hidden units of the flow's networks.
        num_bins: number of spline bins.
        tail_bound: spline tail bound.
        **kwargs: any further options, forwarded unchanged to ``build_nsf``.

    Returns:
        Whatever ``build_nsf`` returns.
    """
    nsf_builder = sbi.neural_nets.flow.build_nsf
    flow = nsf_builder(
        batch_x=batch_x,
        batch_y=batch_theta,
        num_transforms=num_transforms,
        hidden_features=hidden_features,
        num_bins=num_bins,
        tail_bound=tail_bound,
        **kwargs,
    )
    return flow

In [None]:
# Trainer for the reaction-time flow. Reaction times are modelled conditional on
# parameters *and* choices, so `theta_and_choices` is passed in the theta slot.
rt_trainer = SNLE(density_estimator=build_rt_flow)

# Register the training data on the trainer. The fitted flow comes from
# `rt_trainer.train(...)`; the return value here is not a flow, so it is not
# bound to a name (the semicolon suppresses its repr).
rt_trainer.append_simulations(theta_and_choices, rts);

In [None]:
# Fit the reaction-time flow; the returned estimator is bound to `rt_flow`.
rt_flow = rt_trainer.train(
    max_num_epochs=100,
)

In [None]:
# Trainer for the choice network: choices are modelled conditional on theta only.
trainer = SNLE(density_estimator=choice_net_builder)
trainer.append_simulations(theta, choices)
choice_estimator = trainer.train(max_num_epochs=100)

In [None]:
# Combine the trained choice network and the trained reaction-time flow in the
# MNLE wrapper.
mnle = MNLE(choice_estimator, rt_flow)

In [None]:
# Seed torch's global RNG so the ground-truth parameter below is reproducible
# across kernel restarts (and x_o too, if MNLE.sample uses torch's RNG).
torch.manual_seed(0)

prior = sbibm.get_task("ddm").get_prior_dist()
num_trials = 100
# Ground-truth ("observed") parameter: a single draw from the prior.
tho = prior.sample((1,))
# Synthetic observation at `tho`; presumably `num_trials` trials drawn from the
# trained MNLE -- confirm against mnle_utils.MNLE.sample.
x_o = mnle.sample(num_trials, tho)

In [None]:
# Potential function for x_o (MNLE likelihood combined with the prior), plus the
# parameter transform handed to MCMCPosterior.
potential_fn, parameter_transform = likelihood_estimator_based_potential(
    mnle, prior, x_o
)

In [None]:
num_chains = 1
%timeit potential_fn(prior.sample((num_chains,))).shape

In [None]:
# MCMC posterior over theta; the prior serves as proposal.
posterior = MCMCPosterior(
    potential_fn,
    proposal=prior,
    theta_transform=parameter_transform,
)

In [None]:
# Draw a handful of posterior samples with the "slice_np" MCMC method, no thinning.
posterior_samples = posterior.sample(
    (10,),
    method="slice_np",
    thin=1,
)