In [None]:
import sys

# Make the project package importable from this notebook.
# NOTE(review): absolute, machine-specific path; consider deriving it from the
# notebook's location so the notebook runs outside this workspace.
sys.path.append("/workspaces/BI-LEVEL-SMC")

# Simulation settings; each entry is unpacked below as (n, p_ind, approximation_method, pi_ind).
from paper.simulation.simulation_config import settings

import pandas as pd
import pickle  # NOTE: pickle.load can execute arbitrary code -- only load trusted result files
import seaborn as sns
import numpy as np
import matplotlib.lines as mlines
import matplotlib.pyplot as plt

# Global matplotlib style for all figures in this notebook.
plt.style.use("ggplot")

In [None]:
# Simulation-design constants used by the cells below:
#   p_group -- number of leading entries of each posterior vector that are group indicators
#   n_runs  -- number of replicate SMC runs read from every results file
p_group, n_runs = 5, 10

In [None]:
# For every simulation setting and every replicate run, summarise the SMC output
# into mean marginal posterior inclusion probabilities of the truly selected and
# the non-selected groups / individual predictors. One row per (setting, run).

# 0-based indices of the groups treated as truly selected in the simulation design.
idx_select_group = np.array([2, 3, 4])

records = []
for setting in settings:
    n, p_ind, approximation_method, pi_ind = setting

    # The index arithmetic below assumes p_ind splits evenly into p_group groups;
    # otherwise the "selected predictor" indices would silently be wrong.
    if p_ind % p_group != 0:
        raise ValueError(
            f"p_ind={p_ind} is not a multiple of p_group={p_group}; "
            "cannot locate the selected individual predictors"
        )
    step = p_ind // p_group  # individual predictors per group
    # Indices of the selected individual predictors (unchanged from the original
    # design): p_ind - 2*step - 1, p_ind - step - 1 and p_ind - 1.
    # NOTE(review): hard-coded to three trailing positions; keep in sync with
    # idx_select_group if the simulation design changes.
    idx_select_ind = np.array(
        [p_ind - (step * 2 + 1), p_ind - (step + 1), p_ind - 1]
    )

    filename = f"simulation_results/SMC_output_n_{n}_p_{p_ind}_am_{approximation_method}_prior_{pi_ind}.pkl"
    # NOTE: pickle.load can execute arbitrary code; only load locally generated, trusted files.
    with open(filename, "rb") as file:
        smc = pickle.load(file)

    for run in range(n_runs):
        # Unweighted mean over axis 0 of theta. The first p_group entries are group
        # indicators, the rest individual predictors (per the slicing below).
        # NOTE(review): assumes rows of theta are equally weighted particles -- confirm.
        post_prob = np.mean(smc[run]["output"].X.theta, axis=0)
        group_prob = post_prob[:p_group]
        ind_prob = post_prob[p_group:]

        records.append(
            {
                "n": n,
                "likelihood": approximation_method,
                # Kept so results from different priors are not silently pooled.
                "prior": pi_ind,
                "p_ind": p_ind,
                "p_group": p_group,
                "prob_select_ind": np.mean(ind_prob[idx_select_ind]),
                # np.delete keeps the remaining entries in their original order.
                "prob_no_select_ind": np.mean(np.delete(ind_prob, idx_select_ind)),
                "prob_select_group": np.mean(group_prob[idx_select_group]),
                "prob_no_select_group": np.mean(
                    np.delete(group_prob, idx_select_group)
                ),
            }
        )

result = pd.DataFrame(records)

In [None]:
# Mean marginal posterior inclusion probability versus sample size n, at a fixed
# number of individual predictors. Left panel: truly selected groups/predictors;
# right panel: non-selected ones. Lines are means over runs, bands +/- one SD.
P_IND_FIXED = 50
N_TICKS = list(range(100, 1001, 100)) + list(range(1250, 2501, 250))

data_vs_n = result.loc[result["p_ind"] == P_IND_FIXED].copy()

# Series drawn in every panel: (column suffix, palette, markers); predictors first.
# NOTE(review): list palettes/markers are assigned to `likelihood` levels in
# seaborn's hue order (order of first appearance for string data), while the
# hand-made legend below assumes the first level is ALA and the second LA.
# If `settings` ever lists LA first, colours and legend disagree -- consider a
# dict palette keyed by the actual likelihood names.
series_specs = (
    ("ind", ["darksalmon", "deepskyblue"], ["s", "o"]),
    ("group", ["darkred", "darkblue"], ["<", "d"]),
)

# Panels: (column prefix, y-limits, y-label, legend location).
panel_specs = (
    ("select", (0, 1.02), "Marginal posterior inclusion probability", "center right"),
    ("no_select", (-0.01, 1), "", "center left"),
)

# Manual legend shared by both panels (legend handlers copy these artists).
legend_handles = [
    mlines.Line2D([], [], color=color, marker=marker, ls="-", label=label, markersize=13)
    for color, marker, label in (
        ("darkblue", "d", "LA - groups"),
        ("deepskyblue", "o", "LA - predictors"),
        ("darkred", "<", "ALA - groups"),
        ("darksalmon", "s", "ALA - predictors"),
    )
]

fig, axes = plt.subplots(1, 2, figsize=(30, 10))

for ax, (subset, ylim, ylabel, legend_loc) in zip(axes, panel_specs):
    for level, palette, markers in series_specs:
        sns.lineplot(
            data=data_vs_n,
            x="n",
            y=f"prob_{subset}_{level}",
            hue="likelihood",
            style="likelihood",
            dashes=False,
            palette=palette,
            markers=markers,
            errorbar="sd",
            markersize=13,
            ax=ax,
        )

    ax.set_ylim(*ylim)
    ax.tick_params(axis="both", labelsize=20)
    ax.set_xlabel("$n$", size=20)
    ax.set_ylabel(ylabel, size=20)
    ax.set_xticks(N_TICKS)
    ax.set_xticklabels(N_TICKS, rotation=50)
    # Replaces seaborn's automatic legend with the combined manual one.
    ax.legend(handles=legend_handles, fontsize=20, loc=legend_loc)

fig.subplots_adjust(wspace=0.08, hspace=0)
fig.savefig(
    "model_results/bi_level_SMC_simulation_analyze_n.pdf", dpi=400, bbox_inches="tight"
)
plt.show()

In [None]:
# Mean marginal posterior inclusion probability versus number of individual
# predictors p, at a fixed sample size. Left panel: truly selected
# groups/predictors; right panel: non-selected ones. Lines are means over runs,
# bands +/- one SD.
N_FIXED = 1500
P_TICKS = [10] + list(range(25, 251, 25))

data_vs_p = result.loc[result["n"] == N_FIXED].copy()

# Series drawn in every panel: (column suffix, palette, markers); predictors first.
# NOTE(review): list palettes/markers are assigned to `likelihood` levels in
# seaborn's hue order (order of first appearance for string data), while the
# hand-made legend below assumes the first level is ALA and the second LA.
# If `settings` ever lists LA first, colours and legend disagree -- consider a
# dict palette keyed by the actual likelihood names.
series_specs = (
    ("ind", ["darksalmon", "deepskyblue"], ["s", "o"]),
    ("group", ["darkred", "darkblue"], ["<", "d"]),
)

# Panels: (column prefix, y-limits, y-label, legend location).
panel_specs = (
    ("select", (0, 1.02), "Marginal posterior inclusion probability", "center left"),
    ("no_select", (-0.01, 1), "", "center left"),
)

# Manual legend shared by both panels (legend handlers copy these artists).
legend_handles = [
    mlines.Line2D([], [], color=color, marker=marker, ls="-", label=label, markersize=13)
    for color, marker, label in (
        ("darkblue", "d", "LA - groups"),
        ("deepskyblue", "o", "LA - predictors"),
        ("darkred", "<", "ALA - groups"),
        ("darksalmon", "s", "ALA - predictors"),
    )
]

fig, axes = plt.subplots(1, 2, figsize=(30, 10))

for ax, (subset, ylim, ylabel, legend_loc) in zip(axes, panel_specs):
    for level, palette, markers in series_specs:
        sns.lineplot(
            data=data_vs_p,
            x="p_ind",
            y=f"prob_{subset}_{level}",
            hue="likelihood",
            style="likelihood",
            dashes=False,
            palette=palette,
            markers=markers,
            errorbar="sd",
            markersize=13,
            ax=ax,
        )

    ax.set_ylim(*ylim)
    ax.tick_params(axis="both", labelsize=20)
    ax.set_xlabel("$p$", size=20)
    ax.set_ylabel(ylabel, size=20)
    ax.set_xticks(P_TICKS)
    ax.set_xticklabels(P_TICKS)
    # Replaces seaborn's automatic legend with the combined manual one.
    ax.legend(handles=legend_handles, fontsize=20, loc=legend_loc)

fig.subplots_adjust(wspace=0.08, hspace=0)
fig.savefig(
    "model_results/bi_level_SMC_simulation_analyze_p.pdf", dpi=400, bbox_inches="tight"
)
plt.show()