In [None]:
import keras
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pandas as pd
import pickle
import sbibm
import torch

# Global matplotlib style: no top/right spines, frameless legends, large fonts.
mpl.rcParams.update(
    {
        "axes.spines.top": False,
        "axes.spines.right": False,
        "axes.labelsize": "medium",
        "legend.frameon": False,
        "legend.fontsize": 14,
        "font.size": 20,
    }
)

# Fixed seed so that the figures are reproducible.
# (A fresh seed could be drawn with torch.randint(1000000, (1,)).)
seed = 223174
torch.manual_seed(seed)
print(f"seed {seed}")

## Load DDM task from `sbibm`

In [None]:
# DDM benchmark task; its prior and simulator are used by all cells below.
task = sbibm.get_task("ddm")
prior, simulator = (
    task.get_prior_dist(),
    task.get_simulator(seed=seed),  # Passing the seed to Julia.
)

### Load pretrained LANs

In [None]:
# Pretrained LANs for the 4-param ddm: one trained on KDE likelihoods,
# one stored as the "analytic" variant.
lan_kde_path = "../sbibm/algorithms/lan/lan_pretrained/model_final_ddm.h5"
lan_ana_path = "../sbibm/algorithms/lan/lan_pretrained/model_final_ddm_analytic.h5"
# Load both networks for inference only (compile=False), KDE network first.
lan_kde, lan_ana = (
    keras.models.load_model(model_path, compile=False)
    for model_path in (lan_kde_path, lan_ana_path)
)

def lan_likelihood(theta, data, net, ll_lower_bound):
    """Return log likelihood summed over all trials in data,
        given a batch of parameters theta.

    Every (parameter, trial) pair is scored by ``net``. A single-trial value
    is replaced by ``ll_lower_bound`` when it falls below that bound or when
    the reaction time does not exceed the last parameter column; the
    single-trial values are then summed over trials.

    Args
        theta: batch of parameters, shape (num_parameters, num_dims). Column 1
            is the boundary separation (halved before it is passed to the
            LAN); the last column is compared against the reaction times
            (presumably the non-decision time). The input is not modified.
        data: iid reaction times and choices encoded as negative and positive
            reaction times, shape (num_trials, 1).
        net: lan keras model; only ``predict_on_batch`` is used.
        ll_lower_bound: lower bound of single trial log likelihood.

    Returns
        llsum: tensor of shape (num_parameters,) with the log likelihood of
            all trials for each parameter.

    Raises
        AssertionError: if data does not have shape (num_trials, 1).
    """

    # Work on a private float32 copy so the in-place rescaling below never
    # touches the caller's tensor. (torch.tensor(tensor) would copy as well,
    # but it emits a copy-construct UserWarning.)
    theta = torch.as_tensor(theta, dtype=torch.float32).detach().clone()
    data = torch.as_tensor(data)
    ll_lower_bound = float(ll_lower_bound)

    # Convert to positive rts.
    rts = data.abs()
    num_trials = rts.numel()
    num_parameters = theta.shape[0]
    assert rts.shape == torch.Size([num_trials, 1]), (
        f"data must have shape (num_trials, 1), got {tuple(rts.shape)}"
    )

    # Convert DDM boundary separation to symmetric boundary size.
    theta[:, 1] *= 0.5

    # Code down -1 up +1.
    cs = torch.ones_like(rts)
    cs[data < 0] *= -1

    # Build the (trial, parameter) grid in trial-major order: theta is tiled
    # once per trial while every trial is repeated once per parameter, so
    # row t * num_parameters + p pairs trial t with parameter p.
    theta_repeated = theta.repeat(num_trials, 1)
    rts_repeated = torch.repeat_interleave(rts, num_parameters, dim=0)
    cs_repeated = torch.repeat_interleave(cs, num_parameters, dim=0)

    # stack everything for the LAN net.
    theta_x_stack = torch.cat((theta_repeated, rts_repeated, cs_repeated), dim=1)
    ll_each_trial = torch.tensor(
        net.predict_on_batch(theta_x_stack.numpy()),
        dtype=torch.float32,
    ).reshape(num_trials, num_parameters)

    # Keep the network output only where the reaction time exceeds the last
    # parameter column and the output is above the bound; (num_trials, 1)
    # broadcasts against (num_parameters,) to (num_trials, num_parameters).
    keep = (rts > theta[:, -1]) & (ll_each_trial > ll_lower_bound)
    floor = torch.full_like(ll_each_trial, ll_lower_bound)

    # Sum across trials.
    llsum = torch.where(keep, ll_each_trial, floor).sum(0)

    return llsum

## Likelihood comparison for single example

##### Sample example parameter from prior

To create a figure showing the likelihood over the entire data space for a fixed parameter combination, we sample a single parameter combination from the prior and evaluate the synthetic likelihoods over a large range of reaction times and for both choices, while holding the parameters fixed.

In [None]:
# Lower bound on single-trial likelihoods (its log is passed to the LANs below).
l_lower_bound = 1e-7
# One parameter combination drawn from the prior serves as the likelihood example.
theta_o = prior.sample((1,))
theta_o

In [None]:
# Load pretrained NLE model
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so this
# file must come from a trusted source.
with open("../sbibm/algorithms/lan/nle_pretrained/mm_315_2.p", "rb") as fh: 
    nle = pickle.load(fh)

In [None]:
# Evaluate all likelihoods on one shared grid of signed reaction times in
# [-test_tmax, test_tmax] (negative values encode one choice, positive the other).

# RT range
test_tmax = 5
# Number of test points
ntest = 1000

test_rts = torch.linspace(-test_tmax, test_tmax, ntest)

# NLE takes positive reaction times plus a 0/1 choice. Both are derived from
# test_rts so that NLE is evaluated at exactly the x positions it is later
# plotted at; a separately constructed grid has a slightly different spacing.
rs = test_rts.abs().reshape(-1, 1)
cs = torch.ones_like(rs)
cs[test_rts.reshape(-1, 1) < 0] = 0

# get NLE synthetic likelihood for each data point with fixed theta.
lps_nle = torch.tensor([nle.log_prob(r.reshape(-1, 1), c.reshape(-1, 1), theta_o) for r, c in zip(rs, cs)])

# from analytical likelihood
lps_true = torch.tensor([task.get_log_likelihood(theta_o, d.reshape(-1, 1), l_lower_bound=l_lower_bound)
                     for d in test_rts])

# and from both LANs. Each call scores a single data point, because
# lan_likelihood sums over all trials it is given.
lps_lanana = torch.tensor([lan_likelihood(theta_o,
                                   d.reshape(-1, 1),
                                   net=lan_ana,
                                   ll_lower_bound=np.log(l_lower_bound))
                    for d in test_rts])
lps_lankde = torch.tensor([lan_likelihood(theta_o,
                                   d.reshape(-1, 1),
                                   net=lan_kde,
                                   ll_lower_bound=np.log(l_lower_bound))
                    for d in test_rts])

## Systematic Likelihood comparison

Next we do a systematic comparison based on the Huber loss (the LAN training loss) and the mean squared error (MSE) between analytical and synthetic log likelihoods of LAN and NLE.

To mimic the inference setup, we sample an observation from the simulator using parameters drawn from the prior and evaluate the likelihood of this observation for a large batch of parameters (e.g., `1000`) sampled from the prior. The Huber loss and the MSE are then calculated across this batch, giving a single number each. We repeat this procedure for many observations (e.g., `100`) and then show boxplots over the resulting 100 numbers.

In [None]:
# Define losses.
def huberloss(y, yhat, delta=1.0):
    """Return the mean Huber loss between y and yhat.

    The loss of a single element with absolute difference d is
    ``0.5 * d**2`` for ``d <= delta`` and ``delta * (d - 0.5 * delta)``
    otherwise, so it is continuous at ``d == delta``.

    Args
        y: tensor (or array-like) of values.
        yhat: tensor (or array-like) broadcastable to y without changing the
            number of elements.
        delta: threshold between the quadratic and the linear regime.

    Returns
        Mean loss over all elements as a numpy float64. NaN inputs propagate
        to the result.

    Raises
        ValueError: if broadcasting y against yhat changes the number of
            elements (e.g. shapes (N,) and (N, 1)).
    """
    # Compute in float64 on detached tensors; the result is a plain number.
    y = torch.as_tensor(y).detach().double()
    yhat = torch.as_tensor(yhat).detach().double()

    diff = (y - yhat).abs()
    if diff.numel() != y.numel():
        raise ValueError(
            f"shapes {tuple(y.shape)} and {tuple(yhat.shape)} do not match element-wise"
        )

    quadratic = 0.5 * diff**2
    # Linear branch is delta * (d - delta / 2), not delta / 2 + d, so both
    # branches agree at d == delta.
    linear = delta * (diff - 0.5 * delta)
    err = torch.where(diff <= delta, quadratic, linear)
    return np.float64(err.mean().item())

def mse(y, yhat):
    """Mean squared error between two tensors, returned as a 0-dim tensor."""
    residual = y - yhat
    return torch.square(residual).mean()

In [None]:
# mimic the MCMC setting: xo is fixed, thetas are scored with loss
num_observations = 100
num_thetas_per_observation = 1000

# first sample observations xo
xos = simulator(prior.sample((num_observations,)))

# Column labels for the error tables; their order must match the order of
# the entries appended to `errors` below (LAN-KDE first, then NLE).
labels = [
#     "LAN-ANA", 
    "LAN", 
    "NLE",
]
# One entry per observation; each entry holds four metric rows
# (Huber on log L, Huber on L, MSE on log L, MSE on L) with one value per label.
errors = []
for xoi in xos:
    
    # Sample test thetas from prior.
    test_thetas = prior.sample((num_thetas_per_observation,))
    xoi = xoi.reshape(-1, 1)
    # Extract positive RTs and choices for mixed model.
    rsi = abs(xoi)
    csi = torch.ones_like(rsi)
    # Negative signed RTs are coded as choice 0, all others stay 1.
    csi[xoi < 0] = 0

    # Evaluate
    lps_nle_i = nle.log_prob(rsi, csi, test_thetas).squeeze()    
    lps_true_i = task.get_log_likelihood(test_thetas, data=xoi.reshape(1, -1), l_lower_bound=l_lower_bound)
    # NOTE(review): lps_lanana_i is only referenced by the commented-out
    # LAN-ANA entries below, so this evaluation is currently unused.
    lps_lanana_i = lan_likelihood(test_thetas, xoi, lan_ana, np.log(l_lower_bound))
    lps_lankde_i = lan_likelihood(test_thetas, xoi, lan_kde, np.log(l_lower_bound))
    

    # Score
    errors.append([
            # Huber loss between log likelihoods.
            [
#                 huberloss(lps_lanana_i, lps_true_i),
                huberloss(lps_lankde_i, lps_true_i),
                huberloss(lps_nle_i, lps_true_i),
            ],
            # Huber loss between likelihoods.
            [
#                 huberloss(lps_lanana_i.exp(), lps_true_i.exp()),
                huberloss(lps_lankde_i.exp(), lps_true_i.exp()),
                huberloss(lps_nle_i.exp(), lps_true_i.exp()),
            ],            
            # MSE between log likelihoods.
            [
#                 mse(lps_lanana_i, lps_true_i),
                mse(lps_lankde_i, lps_true_i),
                mse(lps_nle_i, lps_true_i),
            ], 
            # MSE between likelihoods.
            [
#                 mse(lps_lanana_i.exp(), lps_true_i.exp()),
                mse(lps_lankde_i.exp(), lps_true_i.exp()),
                mse(lps_nle_i.exp(), lps_true_i.exp()),
            ],         
        ])
# Stack into an array indexed as [observation, metric, method].
errors = np.array(errors)

In [None]:
# One DataFrame per metric row of `errors` (axis 1), in the order
# Huber on log L, Huber on L, MSE on log L, MSE on L; columns are the methods.
dfhuber_log, dfhuber, dfmse_log, dfmse = (
    pd.DataFrame(data=np.array(errors)[:, metric_idx, :], columns=labels)
    for metric_idx in range(4)
)

## Estimate evaluation times

In [None]:
import time
# Vary size of theta (number of MCMC chains in parallel)
num_chains = [10]
# Vary size of data (number of trials)
num_trials = [100]
reps = 100

# Timings indexed as [chain setting, trial setting, repetition].
lan_rts = np.zeros((len(num_chains), len(num_trials), reps))
nle_rts = np.zeros_like(lan_rts)

# Draw the largest batch of parameters and trials once; the loops below slice
# the first nc parameters / nt trials from them.
thetas = prior.sample((max(num_chains),))
xs = simulator(prior.sample((1,)).repeat(max(num_trials), 1))

for ii in range(reps):
    for jj, nc in enumerate(num_chains):
        for kk, nt in enumerate(num_trials):
            # LAN timing. perf_counter is a monotonic high-resolution clock,
            # which suits millisecond-scale intervals better than time.time().
            tic = time.perf_counter()
            lan_likelihood(thetas[:nc,], xs[:nt,], net=lan_kde, ll_lower_bound=np.log(l_lower_bound))
            lan_rts[jj, kk, ii] = time.perf_counter() - tic

            # NLE timing; the conversion to positive RTs and 0/1 choices is
            # done outside the timed region.
            rts = abs(xs[:nt])
            cs = torch.ones_like(rts)
            cs[xs[:nt] < 0] = 0
            tic = time.perf_counter()
            nle.log_prob(rts, cs, thetas[:nc])
            nle_rts[jj, kk, ii] = time.perf_counter() - tic

# convert to ms
lan_rts *= 1000
nle_rts *= 1000

In [None]:
# Mean and standard deviation of the NLE evaluation times over all repetitions
# (in ms, after the conversion in the timing cell).
nle_rts.mean(), nle_rts.std()

In [None]:
# Mean and standard deviation of the LAN evaluation times over all repetitions
# (in ms, after the conversion in the timing cell).
lan_rts.mean(), lan_rts.std()

In [None]:
# Flattened evaluation times (ms) per method. The default RangeIndex is used
# instead of range(reps) so the frame can still be built when several chain or
# trial settings are timed and the flattened arrays hold more than reps entries.
df_rt = pd.DataFrame(data={"LAN-KDE": lan_rts.reshape(-1), "NLE": nle_rts.reshape(-1)})

## Results figure 1

- likelihood examples

- likelihood accuracy

- number of simulations

- evaluation time

In [None]:
# Figure 1: 2 x 3 grid. Left column: likelihood example (top: L, bottom: log L).
# Top middle/right: Huber loss and MSE boxplots on log likelihoods.
# Bottom middle/right: evaluation time and training budget bar plots.
fig, ax = plt.subplots(2, 3, sharex=False, figsize=(18, 8),  
                       gridspec_kw=dict(wspace=0.25, hspace=0.2, width_ratios=[0.6, .2, .2]))

mpl.rcParams["legend.fontsize"] = 12
mpl.rcParams["font.size"] = 15
# Box colors, in the same order as the DataFrame columns (LAN, NLE).
colors = [
#     "C1", 
    "C2", 
    "C3"
]
grid = True
showfliers = True

# NOTE(review): `labels` is reassigned here; it previously held the column
# names used to build the error DataFrames, so re-running that cell after this
# one would pick up this three-entry list instead.
labels = ["Analytical", 
#           "LAN-ANA", 
          "LAN", 
          "NLE", 
         ]

# Panel A (top): likelihoods on a linear scale over signed reaction times.
plt.sca(ax[0, 0])
plt.plot(test_rts, lps_true.exp(), label="Analytical L");
# plt.plot(test_rts, lps_lanana.exp(), label="LAN-ANA");
plt.plot(test_rts, lps_lankde.exp(), label="LAN-KDE", ls="-", c="C2");
plt.plot(test_rts, lps_nle.exp(), label="NLE", ls="-", c="C3");
plt.ylabel(r"$L(x | \theta)$");
# The positional labels replace the label= values given to the lines above.
plt.legend(labels)
# Ticks without labels; the axes below carries the tick labels.
plt.xticks([-4, -2, 0, 2, 4], [])
# plt.yticks([0, .4, .8, 1.2], [0, .4, .8, 1.2])
plt.axvline(0, color="k", lw=1)
# Arrows just above the highest analytical likelihood mark the two choice sides.
y = max(lps_true.exp())+.1
plt.arrow(0, y, 0.3, 0., width=0.03, color="k", alpha=0.5)
plt.text(0.25, 1.05 * y, s="up")
plt.arrow(0, y, -0.3, 0., width=0.03, color="k", alpha=0.5)
plt.text(-.94, 1.05*y, s="down")
# plt.suptitle(fr"v={theta_o[0, 0]:.2f}, a={theta_o[0, 1]:.2f}, w={theta_o[0, 2]:.2f}, $\tau$={theta_o[0, 3]:.2f}");

# Panel A (bottom): the same curves as log likelihoods.
plt.sca(ax[1, 0])
plt.plot(test_rts, lps_true)
# plt.plot(test_rts, lps_lanana)
plt.plot(test_rts, lps_lankde, ls="-", c="C2")
plt.plot(test_rts, lps_nle, ls="-", c="C3")
# plt.legend(labels)
plt.xlabel("$x$: reaction time [s]")
plt.ylabel(r"$\log L(x | \theta)$");
# Negative tick positions are labelled with their absolute value.
plt.xticks([-4, -2, 0, 2, 4], [4, 2, 0, 2, 4])
plt.axvline(0, color="k", lw=1)



# Panel B: Huber loss on log likelihoods, one box per method.
plt.sca(ax[0, 1])
# box_widths is reused for the bar widths below and by the posterior figure cell.
box_widths = [0.3] * len(colors)
bdict = dfhuber_log.boxplot(ax=ax[0, 1], patch_artist=True, return_type="dict", 
                        medianprops={"color": "k"}, grid=grid, 
                           notch=True, 
                           widths=box_widths, 
                           showfliers=showfliers,
                           )
plt.ylabel("Huber loss");
for i,box in enumerate(bdict['boxes']):
    box.set_color(colors[i])
plt.yticks(np.linspace(0, 0.4, 3));
# The y-range is clipped to [0, 0.4]; values outside it are not visible.
plt.ylim(0, .4)

# Panel C: MSE on log likelihoods, one box per method.
plt.sca(ax[0, 2])
bdict = dfmse_log.boxplot(ax=ax[0, 2], patch_artist=True, return_type="dict", 
                          medianprops={"color": "k"}, grid=grid, 
                          notch=True, 
                          widths=box_widths,
                          showfliers=showfliers,
                         )
plt.ylabel(r"MSE");
for i,box in enumerate(bdict['boxes']):
    box.set_color(colors[i])
plt.yticks(np.linspace(0, 2.5, 3));


# Panel E: training budget per method (hard-coded values) on a log scale.
ddd = pd.DataFrame({'method': ["LAN", "NLE"], 'training budget': [15e9, 1e5]})
plt.sca(ax[1, 2])
ddd.plot.bar(x="method", y="training budget", color=["C2", "C3"], ax=ax[1, 2], 
             rot=0, width=box_widths[0])
plt.xlabel('')
plt.ylabel("training budget")
# An empty label sequence leaves the legend without entries.
plt.legend("")
plt.yscale("log")
plt.yticks(np.logspace(5, 10, 2), [r"$10^5$", 
#                                    r"$10^6$",r"$10^7$",r"$10^8$",r"$10^9$", 
                                   r"$10^{10}$"])
# plt.grid()

# Panel D: mean evaluation time per method; error bars are std / sqrt(reps).
ddd = pd.DataFrame({'method': ["LAN", "NLE"], 'rt': [lan_rts.mean(), nle_rts.mean()]})
plt.sca(ax[1, 1])
ddd.plot.bar(x="method", y="rt", color=["C2", "C3"], ax=ax[1, 1], 
             rot=0, width=box_widths[0], 
            yerr=[lan_rts.std()/np.sqrt(reps), nle_rts.std()/np.sqrt(reps)])
plt.xlabel('')
plt.ylabel("evaluation time [ms]")
plt.legend("")
plt.yticks(np.linspace(0, 6, 4), np.linspace(0, 6, 4))

# Panel letters, positioned in figure coordinates.
weight = "bold"
fontsize = 20
y1 = 0.9
x1 = 0.075
fig.text(x1, y1, "A", fontsize=fontsize, fontweight=weight)
fig.text(x1 + 0.46, y1, "B", fontsize=fontsize, fontweight=weight)
fig.text(x1 + 0.64, y1, "C", fontsize=fontsize, fontweight=weight)
fig.text(x1 + .46, y1 - 0.42, "D", fontsize=fontsize, fontweight=weight)
fig.text(x1 + 0.64, y1 - 0.42, "E", fontsize=fontsize, fontweight=weight);
plt.tight_layout()
fig.savefig("LAN-NLE-likelihood-comparison.png", dpi=300, bbox_inches="tight");

## Figure 2: posterior space

- posterior example
- posterior metrics, timings should be according to likelihood evaluation timings.

In [None]:
import sys
sys.path.append("../../results/benchmarking_sbi/")
from utils import compile_df

# from sbibm.utils import compile_df
from sbibm.utils.io import get_tensor_from_csv, get_ndarray_from_csv
import pandas as pd
from sbi.analysis import pairplot

In [None]:
# Benchmark result folders, one multirun per algorithm and observation range.
multirun_dirs = [
    "../../results/benchmarking_sbi/multirun/2021-10-01/09-59-40/",  # LAN 1-10
    "../../results/benchmarking_sbi/multirun/2021-10-01/18-35-25/",  # NLE 1-10
    "../../results/benchmarking_sbi/multirun/2021-09-29/21-43-22/",  # LAN 11-20
    "../../results/benchmarking_sbi/multirun/2021-10-06/07-58-18/",  # NLE 11-20
    "../../results/benchmarking_sbi/multirun/2021-09-29/22-30-49/",  # LAN 21-30
    "../../results/benchmarking_sbi/multirun/2021-09-30/09-16-33/",  # NLE 21-30
    "../../results/benchmarking_sbi/multirun/2021-09-30/09-49-49/",  # NLE 31-40
    "../../results/benchmarking_sbi/multirun/2021-09-30/11-48-19/",  # LAN 31-40
    "../../results/benchmarking_sbi/multirun/2021-10-05/11-42-01/",  # NLE 41-50
    "../../results/benchmarking_sbi/multirun/2021-10-05/12-35-48/",  # LAN 41-50
    "../../results/benchmarking_sbi/multirun/2021-10-05/14-59-22/",  # NLE 51-60
    "../../results/benchmarking_sbi/multirun/2021-10-05/15-25-27/",  # NLE 61-70
    "../../results/benchmarking_sbi/multirun/2021-10-05/16-40-14/",  # NLE 71-80
    "../../results/benchmarking_sbi/multirun/2021-10-05/17-05-47/",  # NLE 81-90
    "../../results/benchmarking_sbi/multirun/2021-10-05/17-31-34/",  # NLE 91-100
    "../../results/benchmarking_sbi/multirun/2021-10-05/17-57-20/",  # LAN 51-60
    "../../results/benchmarking_sbi/multirun/2021-10-05/18-43-08/",  # LAN 61-70
    "../../results/benchmarking_sbi/multirun/2021-10-05/19-29-10/",  # LAN 71-80
    "../../results/benchmarking_sbi/multirun/2021-10-05/20-14-40/",  # LAN 81-90
    "../../results/benchmarking_sbi/multirun/2021-10-05/20-59-55/",  # LAN 91-100
]

# Compile every run in the listed order and stack the results into one frame.
df = pd.concat([compile_df(run_dir) for run_dir in multirun_dirs])

#### Load posterior samples

In [None]:
# Observation whose posterior is shown, the algorithm names as stored in the
# results frame, and the metric columns used for the boxplots.
obs = 24
labels = ["LAN-KDE", "NLE"]
cols = ["MEANERR", "VARERR", "C2ST", "RT"]

# Result folders of the chosen observation, grouped by algorithm in the order
# of `labels`.
df_obs = df[df.num_observation == obs]
paths = []
for alg in labels:
    paths.extend(df_obs.loc[df_obs.algorithm == alg, "path"].values)

# Reference posterior samples first, then the samples of every selected run.
ddm_task = sbibm.get_task("ddm")
ss = [ddm_task.get_reference_posterior_samples(obs)] + [
    get_ndarray_from_csv(path + "/posterior_samples.csv.bz2") for path in paths
]

print(ddm_task.get_true_parameters(obs))

In [None]:
# Figure 2: posterior pairplot for one observation (left) and boxplots of the
# posterior metrics per algorithm (right).
# Relies on `box_widths` and `showfliers` defined in the Figure 1 cell.
fig = plt.figure(figsize=(18, 8))
outer_grid = fig.add_gridspec(1, 2, wspace=0.15, hspace=0, width_ratios=[.6, .4])
# Plotting settings
mpl.rcParams["font.size"] = 16
notch = True

# posterior samples
num_plots = 4
inner_grid = outer_grid[0, 0].subgridspec(num_plots, num_plots, wspace=0.09, hspace=0)
ax1 = inner_grid.subplots()  # Create all subplots for the inner grid.
pairplot(ss,
         points=sbibm.get_task("ddm").get_true_parameters(obs),
         limits=[[-2, 2], [0.5, 2.0], [.3, .7], [.2, 1.]],
         samples_colors=["C0", "C2", "C3"],
         diag="kde",
         upper="contour",
         kde_offdiag=dict(bw_method="scott", bins=30),
         contour_upper=dict(levels=[0.1], percentile=False),
         points_offdiag=dict(marker="+", markersize=10),
         points_colors=["k"],
         fig=fig,
         axes=ax1,
         labels=[r"$v$", r"$a$", r"$w$", r"$\tau$"],
        );
plt.sca(ax1[0, 0])
# The legend is attached to the first diagonal panel and moved below it.
plt.legend(["Analytical", "LAN", "NLE", r"Ground truth $\theta$"],
           bbox_to_anchor=(-.1, -2.2),
           loc=2,
          fontsize=16)


# posterior metrics
inner_grid = outer_grid[0, 1].subgridspec(2, 2, wspace=.7, hspace=.3, )
ax2 = inner_grid.subplots()  # Create all subplots for the inner grid.

# One boxplot panel per metric column, grouped by algorithm. With
# return_type="both" the result maps each column to an (axes, artists) pair.
bdict = df.boxplot(ax=ax2, column=cols, by=["algorithm"], rot=0,
                grid=True,
                fontsize=14.0,
                patch_artist=True,
                widths=box_widths,
                return_type="both", medianprops={"color": "k"},
                notch=notch,
                showfliers=showfliers,);

colors = ["C2", "C3"]

# y ticks per metric panel, in the order of `cols`; the first and last tick
# also define the visible y-range.
ticks = [
    np.linspace(0, 2., 3),
    np.linspace(0.0, .005, 3),
    np.linspace(0.5, 0.9, 3),
    np.linspace(20, 60, 3),
]

for a, t in zip(ax2.reshape(-1), ticks):
    a.set_yticks(t)
    a.set_ylim(t[0], t[-1])


# Color the boxes per algorithm. `.items()` is used because
# `Series.iteritems()` no longer exists in pandas >= 2.0.
for row_key, (axi, row) in bdict.items():
    for i, box in enumerate(row['boxes']):
        box.set_color(colors[i])

col_labels = ["mean error", "variance error", "C2ST", "time [min]"]

# Replace the automatic titles and x labels with one y label per panel.
for i, a in enumerate(ax2.reshape(-1)):
    a.set_title("")
    a.set_ylabel(col_labels[i])
    a.set_xlabel("")
    a.set_xticklabels(["LAN", "NLE"], fontsize=mpl.rcParams["font.size"])
plt.suptitle("")

# Panel letters, positioned in figure coordinates.
weight = "bold"
fontsize = 20
y1 = 0.9
x1 = 0.09
fig.text(x1, y1, "A", fontsize=fontsize, fontweight=weight)
fig.text(x1 + 0.46, y1, "B", fontsize=fontsize, fontweight=weight)
fig.text(x1 + 0.64, y1, "C", fontsize=fontsize, fontweight=weight)
fig.text(x1 + 0.46, y1 - 0.42, "D", fontsize=fontsize, fontweight=weight)
fig.text(x1 + 0.64, y1 - 0.42, "E", fontsize=fontsize, fontweight=weight)

plt.tight_layout();
fig.savefig("LAN-NLE-posterior-comparison.png", dpi=300, bbox_inches="tight");

## Visualize Likelihood for 10 Benchmark observations
We can also visualize the single trial likelihoods for the 10 random observations used in the benchmark.

In [None]:
# Log likelihoods of 1000 prior samples for each of the benchmark observations
# 1..10: analytical, LAN (KDE network) and NLE, one list entry per observation.
ll_true, ll_lan, ll_nle = [], [], []

for ii in range(1, 11):

    xo = task.get_observation(ii).reshape(-1, 1)
    test_thetas = prior.sample((1000,))

    # Mixed-model inputs: positive RTs plus a 0/1 choice, where negative
    # signed RTs are coded as choice 0.
    rs = abs(xo)
    cs = torch.where(xo < 0, torch.zeros_like(rs), torch.ones_like(rs))

    ll_true.append(
        task.get_log_likelihood(
            test_thetas, data=xo.reshape(1, -1), l_lower_bound=l_lower_bound
        )
    )
    ll_lan.append(
        lan_likelihood(
            test_thetas, xo, net=lan_kde, ll_lower_bound=np.log(l_lower_bound)
        )
    )
    ll_nle.append(nle.log_prob(rs, cs, test_thetas))

In [None]:
# Scatter synthetic against analytical log likelihoods for the first nine
# benchmark observations, one panel per observation.
# Spine settings are read when axes are created, so set them before subplots.
mpl.rcParams["axes.spines.right"] = False
mpl.rcParams["axes.spines.top"] = False
fig, ax = plt.subplots(3, 3, figsize=(18, 12), sharey=True, sharex=True)
s = 7
alpha = 1.0
for ii in range(0, 9):
    # Fill the 3 x 3 grid row by row.
    plt.sca(ax[ii // 3, ii % 3])

    plt.title(f"Observation {ii+1}")

    # Every artist carries its own label, so the legend does not depend on the
    # order in which matplotlib collects lines and scatter collections.
    plt.scatter(ll_true[ii], ll_lan[ii], alpha=alpha, color="C2", s=s, label="LAN")
    plt.scatter(ll_true[ii], ll_nle[ii], alpha=alpha, color="C3", s=s, label="NLE")
    plt.plot(ll_true[ii], ll_true[ii], "k", label="Identity")
    if not ii:
        # Legend only in the first panel.
        plt.legend(frameon=False, fontsize=12)
    if ii in [0, 3, 6]:
        # Left column.
        plt.ylabel("synthetic log L")
    if ii > 5:
        # Bottom row.
        plt.xlabel("analytic log L")