In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import glob
import pickle

import numpy as np
from aind_behavior_gym.dynamic_foraging.task import CoupledBlockTask, UncoupledBlockTask
from aind_dynamic_foraging_models.generative_model import ForagerCollection

In [None]:
from pynwb import NWBHDF5IO

# Directory that holds the locally cached NWB session files. It is read by
# get_nwb_from_local_tmp and by the per-subject session glob further down.
LOCAL_NWB_TMP = "/data/foraging_nwb_bonsai"

def get_nwb_from_local_tmp(session_id):
    """Open the locally cached NWB file of one session and return its contents.

    The file is looked up as ``<LOCAL_NWB_TMP>/<session_id>.nwb`` and opened
    read-only. Replace this function to load NWB files from somewhere else.

    Parameters
    ----------
    session_id : str
        File stem of the session, i.e. the NWB file name without ``.nwb``.

    Returns
    -------
    object
        Whatever ``NWBHDF5IO.read()`` returns for that file.
    """
    nwb_path = f"{LOCAL_NWB_TMP}/{session_id}.nwb"
    # The reader is not closed here; only the object it read is handed back.
    reader = NWBHDF5IO(nwb_path, mode="r")
    return reader.read()


def get_history_from_nwb(nwb):
    """Extract per-trial choice/reward histories from an NWB object.

    #TODO move this to aind-behavior-nwb-util

    Parameters
    ----------
    nwb : object
        Must provide ``nwb.trials.to_dataframe()`` (trial table with the
        columns read below) and a string attribute ``nwb.protocol``.

    Returns
    -------
    tuple
        ``(baiting, choice_history, reward_history, p_reward,
        autowater_offered, random_number)`` where

        - ``baiting`` is False only when the lower-cased protocol name
          contains "without baiting";
        - ``choice_history`` is a NumPy array with 0/1 for responses 0/1 and
          NaN for response 2 (or any other response code);
        - ``reward_history`` is the element-wise OR of the two
          ``rewarded_history*`` columns (pandas Series);
        - ``p_reward`` and ``random_number`` are two-element lists of NumPy
          arrays, the ``L``/``left`` column first, the ``R``/``right`` column
          second;
        - ``autowater_offered`` is a boolean Series, True where either
          ``auto_water*`` column equals 1.
    """
    trials = nwb.trials.to_dataframe()

    # Responses 0 and 1 are kept; response 2 has no choice and becomes NaN.
    response_to_choice = {0: 0, 1: 1, 2: np.nan}
    choice_history = trials["animal_response"].map(response_to_choice).values

    reward_history = trials["rewarded_historyL"] | trials["rewarded_historyR"]
    autowater_offered = trials["auto_waterL"].eq(1) | trials["auto_waterR"].eq(1)

    p_reward = [
        trials[column].values
        for column in ("reward_probabilityL", "reward_probabilityR")
    ]
    random_number = [
        trials[column].values
        for column in ("reward_random_number_left", "reward_random_number_right")
    ]

    baiting = "without baiting" not in nwb.protocol.lower()

    return baiting, choice_history, reward_history, p_reward, autowater_offered, random_number

In [None]:
from logging import BASIC_FORMAT
from socket import AI_CANONNAME


# subject_id = '781370'  # uncoupled, no baiting
# subject_id = '764769'  # uncoupled, baiting
# subject_id = '776293'  # uncoupled, baiting
subject_id = '769884'  # uncoupled, baiting

# Display names, in the same order in which the two fitting results are reported.
model_names = ['CompareToThreshold', 'Hattori']

# Per-session fit metrics. One value is appended per *fitted* session, so
# sessions that are skipped below leave no placeholder in these lists.
nll_ctt = []
nll_hattori = []
aic_ctt = []
aic_hattori = []
bic_ctt = []
bic_hattori = []

for session_name in sorted(glob.glob(f'{LOCAL_NWB_TMP}/{subject_id}_*'), reverse=True):
    print('############################################')
    # Session id = file name without its directory and extension
    session_id = session_name.split('/')[-1].split('.')[0]
    print(session_id)

    nwb = get_nwb_from_local_tmp(session_id=session_id)
    (
        baiting,
        choice_history,
        reward_history,
        _,
        autowater_offered,
        random_number,
    ) = get_history_from_nwb(nwb)

    # Remove NaNs (trials without a choice)
    ignored = np.isnan(choice_history)
    choice_history = choice_history[~ignored]
    reward_history = reward_history[~ignored].to_numpy()

    # handle invalid sessions if there are too few trials
    # -- Skip if len(valid trials) < 50 --
    if len(choice_history) < 50:
        fit_result = {
            "status": "skipped. valid trials < 50",
            "upload_figs_s3": {},
            "upload_pkls_s3": {},
            "upload_record_docDB": {},
        }
        print(f"Skipping session {session_id} due to too few trials n={len(choice_history)}.")
        continue

    # -- Initialize model --
    # forager = ForagerCollection().get_forager(
    #     agent_class_name="ForagerCompareThreshold",
    #     agent_kwargs={
    #         'choice_kernel': "none",
    #     },
    # )
    forager_ctt = ForagerCollection().get_preset_forager("CompareToThreshold")
    fitting_result_ctt, _ = forager_ctt.fit(
        choice_history,
        reward_history,
        clamp_params={
            # "biasL": 0,
            # "softmax_inverse_temperature": 5.0
        },
        DE_kwargs=dict(
            workers=4,
            disp=True,
            seed=np.random.default_rng(42)
        ),
        # k_fold_cross_validation=None
    )

    forager_hattori = ForagerCollection().get_preset_forager("Hattori2019")
    fitting_result_hattori, _ = forager_hattori.fit(
        choice_history,
        reward_history,
        DE_kwargs=dict(
            workers=4,
            disp=True,
            seed=np.random.default_rng(42)
        ),
        # k_fold_cross_validation=None
    )

    # Check fitted parameters
    for model_name, fitting_result in zip(
        model_names, [fitting_result_ctt, fitting_result_hattori]
    ):
        fit_names = fitting_result.fit_settings["fit_names"]
        # The model name is looked up outside the f-string: reusing the same
        # quote character inside an f-string is a SyntaxError before Python 3.12.
        print(f'Model: {model_name}')
        # Report the keys of the result being printed, not always the CTT one.
        print(f'fitting results keys: {fitting_result.keys()}')
        print(f"Num of trials: {len(choice_history)}")
        print(f"Likelihood-Per-Trial: {fitting_result.LPT}")
        print(f"Negative Log-Likelihood: {-1*fitting_result.log_likelihood}")
        print(f"AIC: {fitting_result.AIC}")
        print(f"BIC: {fitting_result.BIC}")
        print(f"Prediction accuracy full dataset: {fitting_result.prediction_accuracy}")
        print(f"Fitted parameters: {fit_names}")
        print(f'Fitted:       {[f"{num:.4f}" for num in fitting_result.x]}\n')

    # append results
    nll_ctt.append(-1 * fitting_result_ctt.log_likelihood)
    nll_hattori.append(-1 * fitting_result_hattori.log_likelihood)
    aic_ctt.append(fitting_result_ctt.AIC)
    aic_hattori.append(fitting_result_hattori.AIC)
    bic_ctt.append(fitting_result_ctt.BIC)
    bic_hattori.append(fitting_result_hattori.BIC)

    # # plot fitted session
    # fig_fitting_ctt, axes_ctt = forager_ctt.plot_fitted_session(if_plot_latent=True)
    # fig_fitting_hattori, axes_hattori = forager_hattori.plot_fitted_session(if_plot_latent=True)

    # fig_fitting_ctt.savefig(f'/results/{session_id}-ctt.png', dpi=150)
    # fig_fitting_hattori.savefig(f'/results/{session_id}-hattori.png', dpi=150)

# convert to numpy arrays
nll_ctt = np.array(nll_ctt)
nll_hattori = np.array(nll_hattori)
aic_ctt = np.array(aic_ctt)
aic_hattori = np.array(aic_hattori)
bic_ctt = np.array(bic_ctt)
bic_hattori = np.array(bic_hattori)

# save data
with open(f'/results/{subject_id}_model_fits.pkl', 'wb') as f:
    pickle.dump({
        'nll_ctt': nll_ctt,
        'nll_hattori': nll_hattori,
        'aic_ctt': aic_ctt,
        'aic_hattori': aic_hattori,
        'bic_ctt': bic_ctt,
        'bic_hattori': bic_hattori,
    }, f)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


def plot_model_comparison(
    hattori_nll, ctt_nll, 
    hattori_aic, ctt_aic, 
    hattori_bic, ctt_bic, 
    subject_id,
    figsize=(20, 6)
):
    """
    Create violin plots comparing model fit metrics between Hattori and CTT models.

    One panel is drawn per metric (NLL, AIC, BIC). Each panel shows one violin
    per model, the individual session values as jittered points, and a gray
    line joining the two values that share the same index (same session).

    Parameters:
    -----------
    hattori_nll, ctt_nll : array-like
        Negative log-likelihood values for each model, one per session
        (flattened to 1-D internally)
    hattori_aic, ctt_aic : array-like
        AIC values for each model, one per session
    hattori_bic, ctt_bic : array-like
        BIC values for each model, one per session
    subject_id : str
        Subject identifier shown in the figure title
    figsize : tuple
        Figure size (width, height)

    Returns:
    --------
    fig, axes
        The figure and its three axes, as returned by ``plt.subplots(1, 3)``.

    Notes:
    ------
    The horizontal jitter of the points is drawn from NumPy's global random
    state, so point positions differ between calls unless that state is seeded.
    The number of sessions is taken from ``hattori_nll``; all inputs are
    expected to have that same length.
    """

    # Ensure arrays are 1D
    hattori_data = [np.array(a).flatten() for a in (hattori_nll, hattori_aic, hattori_bic)]
    ctt_data = [np.array(a).flatten() for a in (ctt_nll, ctt_aic, ctt_bic)]

    n_sessions = len(hattori_data[0])

    # Create figure with subplots
    fig, axes = plt.subplots(1, 3, figsize=figsize)
    fig.suptitle(
        f'Model Comparison for {subject_id}: Q-learning (Hattori) vs Foraging (compare-to-threshold)', 
        fontsize=14)

    # Panel order matches the order of the per-model data lists above
    metrics = ['NLL', 'AIC', 'BIC']

    # Colors for models
    colors = {'Hattori': '#3498db', 'CTT': '#e74c3c'}

    # Plot each metric
    for ax, metric, h_data, c_data in zip(axes, metrics, hattori_data, ctt_data):
        # Create violin plot: Hattori at x=0, CTT at x=1
        parts = ax.violinplot([h_data, c_data], positions=[0, 1], 
                             showmeans=True, showmedians=True, showextrema=True)

        # Customize violin colors
        for pc, color in zip(parts['bodies'], [colors['Hattori'], colors['CTT']]):
            pc.set_facecolor(color)
            pc.set_alpha(0.6)
            pc.set_edgecolor('black')
            pc.set_linewidth(1)

        # Customize other elements
        for partname in ('cbars', 'cmins', 'cmaxes', 'cmedians', 'cmeans'):
            if partname in parts:
                parts[partname].set_edgecolor('black')
                parts[partname].set_linewidth(1.5)

        # Add individual points, jittered around each violin's x position
        x_hattori = np.random.normal(0, 0.04, n_sessions)
        x_ctt = np.random.normal(1, 0.04, n_sessions)

        ax.scatter(x_hattori, h_data, color=colors['Hattori'], 
                  alpha=0.8, s=30, edgecolor='black', linewidth=0.5, zorder=5)
        ax.scatter(x_ctt, c_data, color=colors['CTT'], 
                  alpha=0.8, s=30, edgecolor='black', linewidth=0.5, zorder=5)

        # Add connecting lines between corresponding sessions
        for i in range(n_sessions):
            ax.plot([x_hattori[i], x_ctt[i]], [h_data[i], c_data[i]], 
                   color='gray', alpha=0.3, linewidth=1, zorder=1)

        # Customize axes
        ax.set_ylabel(f'{metric}', fontsize=12, fontweight='bold')
        ax.set_xticks([0, 1])
        ax.set_xticklabels(['Hattori', 'CTT'], fontsize=11)
        ax.grid(True, alpha=0.3, axis='y')

        # Add metric statistics (mean ± population standard deviation)
        h_mean, h_std = np.mean(h_data), np.std(h_data)
        c_mean, c_std = np.mean(c_data), np.std(c_data)

        stats_text = f'Hattori: {h_mean:.2f} ± {h_std:.2f}\n'
        stats_text += f'CTT: {c_mean:.2f} ± {c_std:.2f}'
        ax.text(0.02, 0.98, stats_text, transform=ax.transAxes, 
               verticalalignment='top', fontsize=14,
               bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        # Add x-label
        ax.set_xlabel('Model', fontsize=12, fontweight='bold')

    # Adjust the layout of this figure explicitly rather than of whatever
    # figure pyplot currently considers active.
    fig.tight_layout()

    return fig, axes


def plot_model_comparison_difference(
    hattori_nll, ctt_nll, 
    hattori_aic, ctt_aic, 
    hattori_bic, ctt_bic, 
    subject_id,
    figsize=(20, 6)
):
    """
    Create violin plots showing the distribution of differences between Hattori and CTT models.

    For each metric the per-session difference ``CTT - Hattori`` is plotted.
    All three metrics are "lower is better", so a negative difference means
    CTT fits that session better and a positive one means Hattori does.

    Parameters:
    -----------
    hattori_nll, ctt_nll : array-like
        Negative log-likelihood values for each model, one per session
        (flattened to 1-D internally)
    hattori_aic, ctt_aic : array-like
        AIC values for each model, one per session
    hattori_bic, ctt_bic : array-like
        BIC values for each model, one per session
    subject_id : str
        Subject identifier for the title
    figsize : tuple
        Figure size (width, height)

    Returns:
    --------
    fig, axes
        The figure and its three axes, as returned by ``plt.subplots(1, 3)``.

    Notes:
    ------
    The horizontal jitter of the points is drawn from NumPy's global random
    state, so point positions differ between calls unless that state is seeded.
    """

    # Ensure arrays are 1D
    hattori_nll = np.array(hattori_nll).flatten()
    ctt_nll = np.array(ctt_nll).flatten()
    hattori_aic = np.array(hattori_aic).flatten()
    ctt_aic = np.array(ctt_aic).flatten()
    hattori_bic = np.array(hattori_bic).flatten()
    ctt_bic = np.array(ctt_bic).flatten()

    n_sessions = len(hattori_nll)

    # Calculate differences (CTT - Hattori)
    # Negative values mean CTT is better (lower is better for these metrics)
    diff_nll = ctt_nll - hattori_nll
    diff_aic = ctt_aic - hattori_aic
    diff_bic = ctt_bic - hattori_bic

    # Create figure with subplots
    fig, axes = plt.subplots(1, 3, figsize=figsize, dpi=150)
    fig.suptitle(
        f'Model Comparison for {subject_id}' +
        ', difference = CTT - Hattori (negative values favor CTT)', 
        fontsize=16)

    # Define metrics and their data
    metrics = ['NLL', 'AIC', 'BIC']
    diff_data = [diff_nll, diff_aic, diff_bic]

    # Color for the violin body
    diff_color = '#9b59b6'
    # Point colors: same per-model palette as plot_model_comparison
    # (Hattori blue, CTT red), plus a neutral gray for exact ties.
    model_colors = {'Hattori': '#3498db', 'CTT': '#e74c3c', 'Equal': '#7f8c8d'}

    # Plot each metric
    for ax, metric, diff in zip(axes, metrics, diff_data):
        # Create violin plot for difference
        parts = ax.violinplot([diff], positions=[0], 
                             showmeans=True, showmedians=True, showextrema=True)

        # Customize violin colors
        for pc in parts['bodies']:
            pc.set_facecolor(diff_color)
            pc.set_alpha(0.6)
            pc.set_edgecolor('black')
            pc.set_linewidth(1)

        # Customize other elements
        for partname in ('cbars', 'cmins', 'cmaxes', 'cmedians', 'cmeans'):
            if partname in parts:
                parts[partname].set_edgecolor('black')
                parts[partname].set_linewidth(1.5)

        # Add horizontal line at y=0 (no difference)
        ax.axhline(y=0, color='black', linestyle='--', alpha=0.5, linewidth=2)

        # Add individual points
        x_positions = np.random.normal(0, 0.04, n_sessions)

        # Color each point with the color of the model that fits that session
        # better: diff < 0 -> CTT, diff > 0 -> Hattori, otherwise neutral.
        point_colors = [
            model_colors['CTT'] if d < 0
            else model_colors['Hattori'] if d > 0
            else model_colors['Equal']
            for d in diff
        ]

        ax.scatter(x_positions, diff, color=point_colors, 
                  alpha=0.8, s=40, edgecolor='black', linewidth=0.5, zorder=5)

        # Customize axes
        ax.set_ylabel(f'Δ{metric} (CTT - Hattori)', fontsize=14, fontweight='bold')
        ax.set_xticks([])
        ax.grid(True, alpha=0.3, axis='y')

        # Add metric statistics
        mean_diff = np.mean(diff)
        std_diff = np.std(diff)
        n_hattori_better = np.sum(diff > 0)
        n_ctt_better = np.sum(diff < 0)
        n_equal = np.sum(diff == 0)

        stats_text = f'Mean difference: {mean_diff:.2f} ± {std_diff:.2f}\n'
        stats_text += f'Hattori better: {n_hattori_better}/{n_sessions} sessions\n'
        stats_text += f'CTT better: {n_ctt_better}/{n_sessions} sessions'
        if n_equal > 0:
            stats_text += f'\nEqual: {n_equal}/{n_sessions} sessions'

        ax.text(0.02, 0.98, stats_text, transform=ax.transAxes, 
               verticalalignment='top', fontsize=14,
               bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    # Adjust layout
    fig.tight_layout()

    return fig, axes


# Draw both figures from the same six metric arrays: first the per-model
# distributions, then the per-session differences. `fig`/`axes` end up
# referring to the second (difference) figure.
metric_arrays = (
    nll_hattori, nll_ctt,
    aic_hattori, aic_ctt,
    bic_hattori, bic_ctt,
)
fig, axes = plot_model_comparison(*metric_arrays, subject_id=subject_id)
fig, axes = plot_model_comparison_difference(*metric_arrays, subject_id=subject_id)


# Print summary statistics
print("Model Comparison Summary (Difference = CTT - Hattori, negative favors CTT):")
n_sessions = len(nll_hattori)
print(f"Number of sessions: {n_sessions}")

print("-" * 50)
# (label, Hattori values, CTT values) for every metric, in report order
paired_metrics = [
    ("NLL", nll_hattori, nll_ctt),
    ("AIC", aic_hattori, aic_ctt),
    ("BIC", bic_hattori, bic_ctt),
]

# Mean ± std of each metric, per model
for label, hattori_values, ctt_values in paired_metrics:
    print(f"{label} - Hattori: {np.mean(hattori_values):.2f} ± {np.std(hattori_values):.2f}")
    print(f"{label} - CTT: {np.mean(ctt_values):.2f} ± {np.std(ctt_values):.2f}")

# Mean ± std of the per-session difference (CTT - Hattori)
print()
for label, hattori_values, ctt_values in paired_metrics:
    delta = ctt_values - hattori_values
    print(f"Δ{label}: {np.mean(delta):.2f} ± {np.std(delta):.2f}")

# Number of sessions with a strictly negative difference
print()
print("Sessions where CTT performs better:")
for label, hattori_values, ctt_values in paired_metrics:
    print(f"{label}: {np.sum((ctt_values - hattori_values) < 0)}/{n_sessions}")


In [None]:
# -- Load data --
# Exactly one `session_id` assignment below is active; the commented ones are
# alternative sessions to switch to.
# session_id = '781896_2025-04-10_14-11-57'

# session_id = '781370_2025-02-03_11-09-28'
# session_id = '781370_2025-02-05_11-25-51'
session_id = '781370_2025-03-20_11-12-56'
# session_id = '781370_2025-02-14_11-26-21'
# session_id = '781370_2025-02-17_11-11-23'

# session_id = '784806_2025-04-21_13-13-39'

# session_id = '770527_2025-01-15_11-01-55'

# session_id = '739977_2024-10-03_09-04-34'

# session_id = '786866_2025-04-10_11-24-47'


nwb = get_nwb_from_local_tmp(session_id=session_id)
# The fourth returned element (p_reward) is not needed here, hence `_`.
baiting, choice_history, reward_history, _, autowater_offered, random_number = (
    get_history_from_nwb(nwb)
)

In [None]:
# Remove NaNs (trials without a choice).
# np.asarray keeps this cell re-runnable: after the first run reward_history is
# already a NumPy array, which has no pandas-style .to_numpy() method.
ignored = np.isnan(choice_history)
choice_history = choice_history[~ignored]
reward_history = np.asarray(reward_history)[~ignored]

# -- Skip if len(valid trials) < 50 --
if len(choice_history) < 50:
    fit_result = {
        "status": "skipped. valid trials < 50",
        "upload_figs_s3": {},
        "upload_pkls_s3": {},
        "upload_record_docDB": {},
    }
    # Nothing is fitted for this session, so the cells below that read
    # fitting_result_* / forager_* have no fresh results to show.
    print(f"Skipping session {session_id} due to too few trials n={len(choice_history)}.")
else:
    # -- Initialize model --
    # forager = ForagerCollection().get_forager(
    #     agent_class_name="ForagerCompareThreshold",
    #     agent_kwargs={
    #         'choice_kernel': "none",
    #     },
    # )

    forager_ctt = ForagerCollection().get_preset_forager("CompareToThreshold")
    fitting_result_ctt, _ = forager_ctt.fit(
        choice_history,
        reward_history,
        clamp_params={
            # "biasL": 0,
            # "softmax_inverse_temperature": 5.0
        },
        DE_kwargs=dict(
            workers=4,
            disp=True,
            seed=np.random.default_rng(42)
        ),
        # k_fold_cross_validation=None
    )

    forager_hattori = ForagerCollection().get_preset_forager("Hattori2019")
    fitting_result_hattori, _ = forager_hattori.fit(
        choice_history,
        reward_history,
        DE_kwargs=dict(
            workers=4,
            disp=True,
            seed=np.random.default_rng(42)
        ),
        # k_fold_cross_validation=None
    )

In [None]:
# Check fitted parameters
# Display names, in the same order as the fitting results iterated below.
model_names = ['CompareToThreshold', 'Hattori']
for model_name, fitting_result in zip(
    model_names, [fitting_result_ctt, fitting_result_hattori]
):
    fit_names = fitting_result.fit_settings["fit_names"]
    # The model name is looked up outside the f-string: reusing the same quote
    # character inside an f-string is a SyntaxError before Python 3.12.
    print(f'Model: {model_name}')
    print(f"Num of trials: {len(choice_history)}")
    print(f"Likelihood-Per-Trial: {fitting_result.LPT}")
    print(f"AIC: {fitting_result.AIC}")
    print(f"BIC: {fitting_result.BIC}")
    print(f"Prediction accuracy full dataset: {fitting_result.prediction_accuracy}")
    print(f"Fitted parameters: {fit_names}")
    print(f'Fitted:       {[f"{num:.4f}" for num in fitting_result.x]}\n')

In [None]:
# Plot the fitted session of each model (called with if_plot_latent=True).
# fig_fitting/axes are reassigned on every pass, so after the loop they refer
# to the Hattori2019 figure, exactly as with two sequential assignments.
for _forager in (forager_ctt, forager_hattori):
    fig_fitting, axes = _forager.plot_fitted_session(if_plot_latent=True)