In [9]:
import pandas as pd
from pathlib import Path
import pickle
from matplotlib import pyplot as plt
from typing import List

In [10]:
def get_ta_summary_df(df):
    """
    Summarise the `tate`, `tale` and `tahe` columns of `df` per value of `n`.

    For each of the three columns the result holds the group mean
    (`<name>_mean`), the group standard error of the mean (`<name>_se`) and
    the band `mean +/- 1.96 * se` (`<name>_upper`, `<name>_lower`).

    Returns a new DataFrame with one row per distinct `n`; `n` is a regular
    column, not the index.
    """
    effects = ("tate", "tale", "tahe")
    z_mult = 1.96  # multiplier applied to the standard error for the band

    # Named aggregations, ordered mean/se per effect.
    named_aggs = {}
    for effect in effects:
        named_aggs[f"{effect}_mean"] = (effect, "mean")
        named_aggs[f"{effect}_se"] = (effect, "sem")

    summary = df.groupby("n").agg(**named_aggs).reset_index()

    # Append the band columns after all mean/se columns, upper before lower.
    for effect in effects:
        half_width = z_mult * summary[f"{effect}_se"]
        summary[f"{effect}_upper"] = summary[f"{effect}_mean"] + half_width
        summary[f"{effect}_lower"] = summary[f"{effect}_mean"] - half_width

    return summary

def get_tau_hat_summary_df(df):
    """
    Summarise the difference-in-means estimates (`tau_hat_dm`) per value of `n`.

    Returns a new DataFrame with one row per distinct `n` and the columns
    `n`, `tau_hat_dm_mean`, `tau_hat_dm_se`, `tau_hat_dm_upper` and
    `tau_hat_dm_lower`, where the last two are `mean +/- 1.96 * se`.
    """
    mean_col = "tau_hat_dm_mean"
    se_col = "tau_hat_dm_se"

    per_n = (
        df.groupby("n")
        .agg(**{mean_col: ("tau_hat_dm", "mean"), se_col: ("tau_hat_dm", "sem")})
        .reset_index()
    )

    # Half-width of the band around the mean.
    margin = 1.96 * per_n[se_col]
    return per_n.assign(
        tau_hat_dm_upper=per_n[mean_col] + margin,
        tau_hat_dm_lower=per_n[mean_col] - margin,
    )

In [11]:
def _alpha_suffix(adapt_type_dir, adapt_params):
    """Build the parameter part of a result file name for one adaptation type."""
    if adapt_type_dir == "step":
        return (
            f"a0pr-{adapt_params['a0pr']:.2f}"
            f"_a0po-{adapt_params['a0po']:.1f}"
            f"_a1pr-{adapt_params['a1pr']:.1f}"
            f"_a1po-{adapt_params['a1po']:.1f}"
            f"_lt-200"
        )
    return f'a0-{adapt_params["a0"]}_a1-{adapt_params["a1"]}'


def _load_pickle(path):
    """Unpickle and return the object stored at `path`."""
    # NOTE: pickle can execute arbitrary code on load; only point this at
    # result files produced by your own simulation runs.
    with open(path, "rb") as f:
        return pickle.load(f)


def _plot_mean_with_band(ax, summary, prefix, label):
    """Draw `<prefix>_mean` against `n` plus the shaded lower/upper band."""
    ax.plot(summary["n"], summary[f"{prefix}_mean"], label=label)
    ax.fill_between(
        summary["n"],
        summary[f"{prefix}_lower"],
        summary[f"{prefix}_upper"],
        alpha=0.2,
    )


def make_plot(
    dec_quality_dir_name: str,
    adapt_type_dir_names: dict,
    res_dir: Path,
    fig_dir: Path,
    effect_type: str = "tate",
) -> None:
    """
    Plot effects across simulated experiments, one panel per adaptation type.

    Parameters
    ----------
    dec_quality_dir_name:
        Sub-directory of `res_dir` holding the results to plot; also used as
        the prefix of the saved figure's file name.
    adapt_type_dir_names:
        Mapping from adaptation-type directory name to the parameter dict
        used to build the result file names. The "step" type needs the keys
        "a0pr", "a0po", "a1pr", "a1po"; every other type needs "a0", "a1".
    res_dir:
        Root directory of the pickled results.
    fig_dir:
        Directory the PNG is written to (created if missing).
    effect_type:
        "tate" plots the difference-in-means estimate together with the
        `tate` series from the `p_ta` and `m_ta` files; "tale-tahe" plots
        the `tale` and `tahe` series from those two files.

    Raises
    ------
    ValueError
        If `effect_type` is not one of the two supported values, or if
        `adapt_type_dir_names` is empty.

    Notes
    -----
    A panel whose three result files are not all present is left empty and a
    warning naming the missing files is emitted.
    """
    import warnings

    # Fail fast, before any figure is created or file is read.
    if effect_type not in ("tate", "tale-tahe"):
        raise ValueError(f"Unknown effect type: {effect_type}")

    n_panels = len(adapt_type_dir_names)
    if n_panels == 0:
        raise ValueError("adapt_type_dir_names must contain at least one entry")

    # squeeze=False keeps `axs` two-dimensional even for a single panel;
    # without it a lone Axes object is returned and indexing it fails.
    fig, axs = plt.subplots(
        1,
        n_panels,
        figsize=(4 * n_panels, 4),
        sharex=True,
        sharey=True,
        squeeze=False,
    )

    legend_drawn = False
    for ax, (adapt_type_dir, adapt_params) in zip(
        axs[0], adapt_type_dir_names.items()
    ):
        alpha_str = _alpha_suffix(adapt_type_dir, adapt_params)
        data_dir = res_dir / dec_quality_dir_name / adapt_type_dir

        tau_hat_dm_path = data_dir / f"tau_hat_dm_ns-200_nz-10_{alpha_str}.pkl"
        p_ta_path = data_dir / f"p_ta_ns-200_nz-10_{alpha_str}.pkl"
        m_ta_path = data_dir / f"m_ta_ns-200_nz-10_{alpha_str}.pkl"

        missing = [
            path
            for path in (tau_hat_dm_path, p_ta_path, m_ta_path)
            if not path.exists()
        ]
        if missing:
            # Keep the best-effort skip, but say why the panel stays blank.
            warnings.warn(
                f"Skipping '{adapt_type_dir}': missing result file(s) "
                + ", ".join(str(path) for path in missing),
                stacklevel=2,
            )
            continue

        tau_hat_dm_sum = get_tau_hat_summary_df(_load_pickle(tau_hat_dm_path))
        p_ta_sum = get_ta_summary_df(_load_pickle(p_ta_path))
        m_ta_sum = get_ta_summary_df(_load_pickle(m_ta_path))

        # Series order is kept fixed so colours are consistent across panels.
        if effect_type == "tate":
            series = [
                (tau_hat_dm_sum, "tau_hat_dm", r"$\hat{\tau}_{DM}$"),
                (p_ta_sum, "tate", r"$\tau_{P-TATE}$"),
                (m_ta_sum, "tate", r"$\tau_{M-TATE}$"),
            ]
        else:  # "tale-tahe" (validated above)
            series = [
                (p_ta_sum, "tale", r"$\tau_{P-TALE}$"),
                (p_ta_sum, "tahe", r"$\tau_{P-TAHE}$"),
                (m_ta_sum, "tale", r"$\tau_{M-TALE}$"),
                (m_ta_sum, "tahe", r"$\tau_{M-TAHE}$"),
            ]
        for summary, prefix, label in series:
            _plot_mean_with_band(ax, summary, prefix, label)

        ax.set_title(f"{adapt_type_dir}", fontsize=18)
        ax.set_xlabel("N", fontsize=16)
        ax.set_ylabel("Value", fontsize=16)
        ax.tick_params(axis="both", which="major", labelsize=14)

        # Put the legend on the first panel that has data, so it is not lost
        # when the leading panel(s) were skipped.
        if not legend_drawn:
            ax.legend(fontsize=14)
            legend_drawn = True

    fig_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(
        fig_dir / f"{dec_quality_dir_name}_{effect_type}.png",
        bbox_inches="tight",
        dpi=300,
    )

In [12]:
# Input (pickled results) and output (figures) locations.
res_dir = Path("res")
fig_dir = Path("figs")

# Adaptation type -> parameters that make_plot turns into result file names.
adapt_type_dirs = {
    "static": {"a0": 0.01, "a1": 100.0},
    "linear": {"a0": 0.001, "a1": 0.01},
    "exponential": {"a0": 0.01, "a1": 0.1},
    "step": {"a0pr": 0.01, "a0po": 0.5, "a1pr": 0.5, "a1po": 5},
}

# (results sub-directory, effect type) for each figure, in the order produced.
plot_requests = [
    ("phys-0.1-0.2-0.9_mdl-0.4-0.2-0.6", "tate"),
    ("phys-0.4-0.2-0.6_mdl-0.1-0.2-0.9", "tate"),
    ("phys-0.1-0.2-0.9_mdl-0.4-0.2-0.6", "tale-tahe"),
]
for dec_quality_dir, requested_effect in plot_requests:
    make_plot(
        dec_quality_dir,
        adapt_type_dirs,
        res_dir,
        fig_dir,
        effect_type=requested_effect,
    )