In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import os
import logging
# Log at INFO level; each record is prefixed with the emitting file name and line number.
logging.basicConfig(
    level=logging.INFO, 
    format='%(filename)s:%(lineno)d - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Column pair that identifies one session in this notebook: it is the merge key in
# `enrich_with_df_session` and the group-by key in the model comparison cell.
SESSION_KEYS = ["subject_id", "session_date"]

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# 'none' keeps text as text (not outlined paths) when figures are saved as SVG.
plt.rcParams['svg.fonttype'] = 'none'
import seaborn as sns

from aind_analysis_arch_result_access.han_pipeline import get_session_table, get_mle_model_fitting
from aind_analysis_arch_result_access.util.s3 import get_s3_pkl, get_s3_json

import aind_dynamic_foraging_population_analysis

# NOTE: the two helpers below are currently defined in the next cell of this notebook
# rather than imported from the package, hence the commented-out import.
# from aind_dynamic_foraging_population_analysis.model_comparison import (
#     get_all_model_metrics,
#     enrich_with_df_session,
# )

In [None]:
def get_all_model_metrics(
    use_cache=True, cache_path="~/capsule/data/df_model_fitting_all.pkl"
):
    """Get all model metrics from either cache or result access API.

    Parameters
    ----------
    use_cache : bool, optional
        Whether to use cached data, by default True
        If True, it will load data from the cache. If cache does not exist
           or is invalid, it will fetch data from the API.
        If False, it will fetch data from the API and update the cache_path.
    cache_path : str or os.PathLike, optional
        Cache path, by default "~/capsule/data/df_model_fitting_all.pkl".
        A leading "~" is expanded to the user's home directory.

    Returns
    -------
    pd.DataFrame
        Whatever was unpickled from ``cache_path``, or the table returned by
        ``get_mle_model_fitting`` for fits with ``status == "success"``.

    Notes
    -----
    A failure to write the cache (e.g. read-only or missing location) is logged
    as a warning and does not discard the freshly fetched table.
    """
    # Expand "~" once so that loading, directory creation and saving all agree on the path.
    cache_path = os.path.expanduser(cache_path)

    if use_cache:
        try:
            logger.info(f"Trying to load data from cache: {cache_path}...")
            # NOTE(security): unpickling can execute arbitrary code; only point
            # cache_path at files written by this function or another trusted source.
            df_model_fitting = pd.read_pickle(cache_path)
            logger.info(f"{len(df_model_fitting)} rows loaded from cache.")
            return df_model_fitting
        except Exception as e:
            # Deliberately broad: any unreadable cache falls back to a fresh fetch.
            logger.warning(f"Cache not found or invalid: {e}. Fetching from API.")

    # Fetch from result access API
    logger.info("Fetching data from result access API...")
    df_model_fitting = get_mle_model_fitting(
        from_custom_query={"status": "success"},
        if_include_latent_variables=False,
        paginate_settings={"paginate": True, "paginate_batch_size": 5000},
    )

    # Saving is best-effort: the fetch is the expensive part, so never lose its
    # result just because the cache location is missing or not writable.
    try:
        cache_dir = os.path.dirname(cache_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        df_model_fitting.to_pickle(cache_path)
        logger.info(f"{len(df_model_fitting)} rows fetched from API and saved to cache.")
    except OSError as e:
        logger.warning(
            f"{len(df_model_fitting)} rows fetched from API, "
            f"but saving to cache {cache_path} failed: {e}"
        )
    return df_model_fitting


def enrich_with_df_session(df, selected_fields):
    """Enrich any df with session information from get_session_table.

    Parameters
    ----------
    df: pd.DataFrame
        Any dataFrame containing SESSION_KEYS (["subject_id", "session_date"])
    selected_fields: sequence of str
        Fields to merge from session table. Entries that are already part of
        SESSION_KEYS are ignored because those columns are always carried along
        as the merge key.

    Returns
    -------
    pd.DataFrame
        Enriched DataFrame with selected session information. The merge is a
        left join, so every row of ``df`` is kept; rows are multiplied if the
        session table holds more than one row for a session key (a warning is
        logged in that case).
    """
    logger.info("Fetching session table...")
    df_session = get_session_table()

    logger.info("Merging model fitting data with session data...")
    # Accept any sequence (list, tuple, ...) and avoid selecting a key column twice.
    extra_fields = [field for field in selected_fields if field not in SESSION_KEYS]

    # Work on a column subset and convert the date there, so the table returned by
    # get_session_table() is not modified in place (it may be shared or cached by
    # the provider -- TODO confirm).
    df_session_selected = df_session[SESSION_KEYS + extra_fields].assign(
        session_date=lambda session: session["session_date"].astype("str")
    )

    # Merge in session metadata
    df_enriched = df.merge(
        df_session_selected,
        on=SESSION_KEYS,
        how="left",
    )

    # A left join only changes the row count when session keys are duplicated in the
    # session table; surface that instead of silently duplicating rows.
    if len(df_enriched) != len(df):
        logger.warning(
            f"Row count changed from {len(df)} to {len(df_enriched)} after merging: "
            f"the session table has duplicated {SESSION_KEYS} entries."
        )
    return df_enriched

In [None]:
# Load the model-fitting table (from the results-folder cache when available),
# then attach a handful of session-level columns to it.
model_fitting_cache_path = os.path.expanduser("~/capsule/results/df_model_fitting_all.pkl")
session_fields_to_add = [
    "nwb_suffix",
    "curriculum_name",
    "curriculum_version_group",
    "current_stage_actual",
]

df_model_fitting = get_all_model_metrics(
    use_cache=True,
    cache_path=model_fitting_cache_path,
)
df_model_fitting = enrich_with_df_session(
    df_model_fitting,
    selected_fields=session_fields_to_add,
)

In [None]:
# Inspect which columns are available in the (enriched) model-fitting table.
df_model_fitting.columns

## Compare model fit

In [None]:
# `df_model_1_both` is only created by the model-comparison cell further down, so a
# bare reference here raises NameError on a fresh kernel (Restart & Run All).
# Display it only if that cell has already been run; otherwise show nothing.
globals().get("df_model_1_both")

In [None]:
# For each metric, compare the fit of two models session-by-session using scatter plots.
# Each dot is one session: x = metric under model_1, y = metric under model_2.

# model_1 = 'QLearning_L2F1_softmax'  # Hattori
model_1 = 'QLearning_L1F1_CK1_softmax'  # Bari

# model_2 = 'ForagingCompareThreshold'
model_2 = 'QLearning_L2F1_softmax'  # Hattori
# model_2 = 'QLearning_L1F1_CK1_softmax'  # Bari


metrics = [
    "log_likelihood",
    "AIC",
    "BIC",
    "prediction_accuracy_10-CV_test",
]

# --- Pair up sessions (independent of the metric, so done once outside the loop) ---
# Stable sort keeps the original relative order of rows that share a session key.
df_model_1 = df_model_fitting[df_model_fitting["agent_alias"] == model_1].sort_values(
    SESSION_KEYS, kind="stable"
)
df_model_2 = df_model_fitting[df_model_fitting["agent_alias"] == model_2].sort_values(
    SESSION_KEYS, kind="stable"
)

# The comparison needs exactly one fit per (session, model). Extra rows for the same
# session key would break the row-by-row pairing, so keep the first one and say so.
# (Presumably duplicates come from re-fits or from the session-table merge -- verify.)
for alias, df_alias in ((model_1, df_model_1), (model_2, df_model_2)):
    n_duplicated = int(df_alias.duplicated(subset=SESSION_KEYS).sum())
    if n_duplicated:
        logger.warning(
            f"{alias}: {n_duplicated} extra rows share a {SESSION_KEYS} key; "
            "keeping the first row per session."
        )
df_model_1_unique = df_model_1.drop_duplicates(subset=SESSION_KEYS)
df_model_2_unique = df_model_2.drop_duplicates(subset=SESSION_KEYS)

# Sessions where both models are present
sessions_with_both = (
    df_model_1_unique[SESSION_KEYS]
    .merge(df_model_2_unique[SESSION_KEYS], on=SESSION_KEYS, how="inner")
    .sort_values(SESSION_KEYS)
    .reset_index(drop=True)
)

# Left-merging onto the same key table gives both frames the same row order, so
# row i of df_model_1_both and row i of df_model_2_both are always the same session.
df_model_1_both = sessions_with_both.merge(df_model_1_unique, on=SESSION_KEYS, how="left")
df_model_2_both = sessions_with_both.merge(df_model_2_unique, on=SESSION_KEYS, how="left")

# --- One scatter plot per metric ---
for metric in metrics:
    values_1 = pd.to_numeric(df_model_1_both[metric], errors="coerce").to_numpy(
        dtype=float, na_value=np.nan
    )
    values_2 = pd.to_numeric(df_model_2_both[metric], errors="coerce").to_numpy(
        dtype=float, na_value=np.nan
    )

    # Drop sessions where either model lacks a finite value: they cannot be plotted and
    # would otherwise be counted as "<=" because comparisons with NaN are False.
    is_valid = np.isfinite(values_1) & np.isfinite(values_2)
    n_invalid = int((~is_valid).sum())
    if n_invalid:
        logger.info(f"{metric}: skipping {n_invalid} sessions without a finite value in both models.")
    values_1, values_2 = values_1[is_valid], values_2[is_valid]

    if values_1.size == 0:
        logger.warning(f"{metric}: no session has both {model_1} and {model_2}; nothing to plot.")
        continue

    # where model_1 > model_2
    greater_ids = values_1 > values_2

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.scatter(
        values_1[greater_ids],
        values_2[greater_ids],
        alpha=0.5,
        color='red',
        s=1,
        label=f"{model_1} > {model_2}: {greater_ids.sum()} sessions",
    )
    ax.scatter(
        values_1[~greater_ids],
        values_2[~greater_ids],
        alpha=0.5,
        color='blue',
        s=1,
        label=f"{model_1} <= {model_2}: {(~greater_ids).sum()} sessions",
    )

    # add a diagonal (identity) line for reference
    max_value = max(values_1.max(), values_2.max())
    min_value = min(values_1.min(), values_2.min())
    ax.plot([min_value, max_value], [min_value, max_value], color='k', linestyle='--', lw=2)

    ax.grid(True)
    ax.set_xlabel(model_1)
    ax.set_ylabel(model_2)
    ax.set_title(f"Comparison of {model_1} and {model_2} using {metric}")
    ax.legend()
    plt.show()

## Get fitted parameters

In [None]:
# Pick one subject and keep only its rows of the model-fitting table.
# subject_id = '781370'  # uncoupled, no baiting
subject_id = '769884'  # uncoupled, baiting

is_selected_subject = df_model_fitting["subject_id"].eq(subject_id)
df_subject = df_model_fitting.loc[is_selected_subject]

# Number of fitted rows per model (agent_alias) for this subject, ordered by alias.
df_subject.value_counts("agent_alias").sort_index()

In [None]:
# Restrict the selected subject to a single model, ordered by session date.
# Alternatives:
# agent_alias = 'ForagingCompareThreshold'
# agent_alias = 'QLearning_L2F1_softmax'  # Hattori
agent_alias = 'QLearning_L1F1_CK1_softmax'  # Bari

df_subject_agent = (
    df_subject
    .loc[df_subject["agent_alias"] == agent_alias]
    .sort_values(by="session_date")
)
df_subject_agent.head()

In [None]:
# Collect the fitted parameters of every session into {param_name: [value per session]}.
#
# Each entry of the `params` column is a dict of parameter name -> fitted value.
# Expanding those dicts through a DataFrame (one row per session, one column per
# parameter) keeps every list exactly as long as `df_subject_agent`:
#   - a parameter missing from a session becomes NaN instead of shortening its list
#     (which would later break plotting against `session_date`);
#   - a parameter that first appears in a later session no longer raises KeyError;
#   - an empty selection yields an empty dict instead of IndexError on `.iloc[0]`.
fitted_params = pd.DataFrame(
    df_subject_agent['params'].tolist()
).to_dict(orient="list")


In [None]:
# One figure per fitted parameter: its value (y) against session date (x).
session_dates = df_subject_agent['session_date']

for param_name, param_values in fitted_params.items():
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(session_dates, param_values, marker='o')
    ax.set_title(f"Evolution of {param_name} over sessions for subject {subject_id}")
    ax.set_xlabel("Session Date")
    ax.set_ylabel(param_name)
    ax.tick_params(axis="x", labelrotation=45)  # tilt the date labels
    ax.grid(True)
    fig.tight_layout()
    plt.show()

In [None]:

# Scatter every unordered pair of fitted parameters against each other
# (each pair once, no parameter against itself).
subject_param_items = list(fitted_params.items())

for first_pos, (param_name_1, param_values_1) in enumerate(subject_param_items):
    # Only pair with parameters that come later in the dict.
    for param_name_2, param_values_2 in subject_param_items[first_pos + 1:]:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.scatter(param_values_1, param_values_2, alpha=0.5)
        ax.set_title(f"{param_name_1} vs {param_name_2} for subject {subject_id}")
        ax.set_xlabel(param_name_1)
        ax.set_ylabel(param_name_2)
        ax.grid(True)
        fig.tight_layout()
        plt.show()

In [None]:
# All subjects: keep every row fitted with one chosen model,
# ordered by subject and then by session date.
agent_alias = 'ForagingCompareThreshold'
# Alternatives:
# agent_alias = 'QLearning_L2F1_softmax'  # Hattori
# agent_alias = 'QLearning_L1F1_CK1_softmax'  # Bari

is_selected_agent = df_model_fitting["agent_alias"].eq(agent_alias)
df_agent = (
    df_model_fitting
    .loc[is_selected_agent]
    .sort_values(by=["subject_id", "session_date"])
)
df_agent.head()

In [None]:
# Collect the fitted parameters of every row of `df_agent` into
# {param_name: [value per row]}.
#
# Each entry of the `params` column is a dict of parameter name -> fitted value.
# Expanding those dicts through a DataFrame (one row per fit, one column per
# parameter) keeps all lists the same length, which the pairwise scatter plots need:
#   - a parameter missing from a fit becomes NaN instead of shortening its list;
#   - a parameter that first appears in a later row no longer raises KeyError;
#   - an empty selection yields an empty dict instead of IndexError on `.iloc[0]`.
fitted_params_agent = pd.DataFrame(
    df_agent['params'].tolist()
).to_dict(orient="list")


In [None]:
# One histogram (30 bins, with a KDE overlay) per fitted parameter of the chosen model.
for param_name, param_values in fitted_params_agent.items():
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.histplot(param_values, bins=30, kde=True, ax=ax)
    ax.set_title(f"Distribution of {param_name} for {agent_alias}")
    ax.set_xlabel(param_name)
    ax.set_ylabel("Count")
    ax.grid(True)
    fig.tight_layout()
    plt.show()

In [None]:
# Scatter every unordered pair of fitted parameters against each other across all
# rows of the chosen model (each pair once, no parameter against itself).
agent_param_items = list(fitted_params_agent.items())

for first_pos, (param_name_1, param_values_1) in enumerate(agent_param_items):
    # Only pair with parameters that come later in the dict.
    for param_name_2, param_values_2 in agent_param_items[first_pos + 1:]:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.scatter(param_values_1, param_values_2, alpha=0.6, s=2)
        ax.set_title(f"{agent_alias}: {param_name_1} vs {param_name_2}")
        ax.set_xlabel(param_name_1)
        ax.set_ylabel(param_name_2)
        ax.grid(True)
        fig.tight_layout()
        plt.show()
