In [None]:
import sys
import os
import matplotlib.pyplot as plt
import warnings
import matplotlib.ticker as ticker
import numpy as np
import torch
from torch import manual_seed

# cuDNN is switched off for the whole kernel session, before any model is built.
# NOTE(review): presumably because the meta-optimization path (hypergradients) needs
# double-backward through the RNN, which PyTorch's cuDNN RNN kernels do not support — confirm.
import torch.backends.cudnn as cudnn
cudnn.enabled = False

from spice import pipeline_rnn_autoreg
from spice.resources.old_rnn import RLRNN_dezfouli2019, RLRNN_eckstein2022 # Get predefined RNN architectures


In [None]:
# Core experiment switches.
dataset = 'eckstein2022'  # 'eckstein2022' or 'dezfouli2019'
epochs = 512
metaopt_type = 'awd'

# AWD hyperparameters
lambda_awd = 1e-1

# iMAML hyperparameters
initial_reg_param = 1e-4
outer_lr = 1e-2

# Checkpoint path: the file name encodes the meta-optimizer and its hyperparameters.
if metaopt_type not in ('awd', 'imaml'):
    raise ValueError('metaopt_type must be either "awd" or "imaml"')
path_model = (
    f'params/{dataset}/AWD_{dataset}_ep{epochs}_lawd-{lambda_awd}_rnn.pkl'
    if metaopt_type == 'awd'
    else f'params/{dataset}/iMAML_{dataset}_ep{epochs}_metalr-{outer_lr}_in-{initial_reg_param}_rnn.pkl'
)

# Data location and per-dataset SPICE settings (split ratio and RNN architecture).
path_data = f'../data/{dataset}/{dataset}.csv'
additional_inputs = None

if dataset not in ('eckstein2022', 'dezfouli2019'):
    raise ValueError('Dataset must be either "eckstein2022" or "dezfouli2019"')
if dataset == 'eckstein2022':
    train_test_ratio, class_rnn = 0.8, RLRNN_eckstein2022
else:
    train_test_ratio, class_rnn = [3, 6, 9], RLRNN_dezfouli2019


In [None]:
# Train the RNN with the settings from the configuration cell. Only the model and the
# (train loss, validation loss, regularization) histories are kept; the middle return
# value is discarded.
model, _, histories = pipeline_rnn_autoreg.main(
    # data, architecture class and checkpoint location
    data=path_data,
    additional_inputs_data=additional_inputs,
    train_test_ratio=train_test_ratio,
    class_rnn=class_rnn,
    model=path_model,
    participant_id=0,

    # optimisation schedule and bookkeeping
    epochs=epochs,
    learning_rate=1e-2,
    scheduler=True,
    dropout=0.25,
    bagging=True,
    checkpoint=False,
    save_checkpoints=True,
    analysis=False,

    # hand-picked size/batching values (-1 values are passed through unchanged)
    embedding_size=32,
    n_steps=-1,
    batch_size=-1,
    sequence_length=-1,

    # meta-optimization of the regularization strength
    metaopt_type=metaopt_type,
    lambda_awd=lambda_awd,
    meta_update_interval=50,
    inner_steps=3,
    outer_lr=outer_lr,
    hypergradient_steps=3,
    initial_reg_param=initial_reg_param,

    # synthetic dataset parameters
    n_sessions=128,
    n_trials=200,
    sigma=0.2,
    beta_reward=3.0,
    alpha_reward=0.25,
    alpha_penalty=0.5,
    forget_rate=0.0,
    confirmation_bias=0.0,
    beta_choice=0.0,
    alpha_choice=1.0,
    counterfactual=False,
    alpha_counterfactual=0.0,
)

In [None]:
# Training diagnostics: loss curves (left) and the regularization strength (right)
# from the histories returned by pipeline_rnn_autoreg.main.
train_loss_history, val_loss_history, reg_history = histories
plot_epochs = np.arange(1, len(train_loss_history) + 1)
# reg_history gets its own x-values so plotting cannot fail if it has a different
# length than the loss histories. NOTE(review): if it is logged per meta-update rather
# than per epoch, the 'Epoch' axis label on the right panel is approximate — confirm.
reg_epochs = np.arange(1, len(reg_history) + 1)

fig, (ax_loss, ax_reg) = plt.subplots(1, 2, figsize=(12, 4))

# Plot 1: Losses
ax_loss.plot(plot_epochs, train_loss_history, label="Train Loss", linewidth=2)
ax_loss.plot(plot_epochs, val_loss_history, label="Validation Loss", linewidth=2)
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()
ax_loss.grid(True, alpha=0.3)

# Plot 2: regularization parameter; label and style depend on the meta-optimizer
if metaopt_type == 'awd':
    ax_reg.plot(reg_epochs, reg_history, label='λ weight decay', linewidth=2)
    ax_reg.set_ylabel('λ weight decay')
elif metaopt_type == 'imaml':
    ax_reg.plot(reg_epochs, reg_history, label="λ (Regularization)", linewidth=2, color='red')
    ax_reg.set_ylabel('Regularization Parameter')
ax_reg.set_xlabel('Epoch')
ax_reg.legend()
ax_reg.grid(True, alpha=0.3)

# Layout and display apply to every metaopt_type, not only the iMAML branch.
fig.tight_layout()
plt.show()

In [None]:
# Choose the SINDy configuration for the selected dataset; any value other than
# 'eckstein2022' falls through to the Dezfouli configuration.
from spice.resources import sindy_utils

if dataset == 'eckstein2022':
    sindy_config = sindy_utils.SindyConfig_eckstein2022
else:
    sindy_config = sindy_utils.SindyConfig_dezfouli2019

In [None]:
### SINDy
from spice import pipeline_sindy

# Run the SINDy pipeline on the trained RNN stored at `path_model`. Three results are
# unpacked (agent, features, loss). Dataset-specific settings come from `sindy_config`,
# which is unpacked after the explicit keywords.
agent_spice, features, sindy_loss = pipeline_sindy.main(
    # inputs
    data=path_data,
    additional_inputs_data=additional_inputs,
    train_test_ratio=train_test_ratio,
    class_rnn=class_rnn,
    model=path_model,
    participant_id=None,
    filter_bad_participants=False,

    # run control
    save=True,
    analysis=True,
    get_loss=False,
    verbose=False,

    # SINDy library and sparse-regression optimizer
    polynomial_degree=3,
    optimizer_type='SR3_weighted_l1',  # other options: 'STLSQ', 'SR3_L1'
    optimizer_alpha=0.1,
    optimizer_threshold=0.05,

    # Optuna hyperparameter search
    use_optuna=True,
    pruning=False,
    optuna_threshold=0.1,
    optuna_n_trials=50,

    # off-policy settings
    n_trials_off_policy=1000,
    n_sessions_off_policy=1,
    n_trials_same_action_off_policy=5,

    # generated training dataset parameters
    n_actions=2,
    sigma=0.2,
    beta_reward=1.0,
    alpha=0.25,
    alpha_penalty=0.25,
    forget_rate=0.0,
    confirmation_bias=0.0,
    beta_choice=1.0,
    alpha_choice=1.0,
    counterfactual=False,
    alpha_counterfactual=0.0,

    **sindy_config,
)

In [None]:
# Model analysis: select the benchmark-model configuration for the chosen dataset and
# load the dataset used for evaluation.
#
# NOTE(review): `benchmarking_eckstein2022`, `benchmarking_dezfouli2019` and
# `convert_dataset` are not imported by any cell above, so on a fresh kernel this cell
# raises NameError. Add their imports to the import cell (TODO: confirm module paths).
if dataset == 'eckstein2022':
    # ------------------- CONFIGURATION ECKSTEIN2022 w/o AGE --------------------
    study = 'eckstein2022'
    # other candidates: 'ApBr', 'ApBrAcfpBcf', 'ApBrAcfpBcfBch', 'ApAnBrBch', 'ApAnBrAcfpAcfnBcfBch'
    models_benchmark = ['ApAnBrBcfBch']
    train_test_ratio = 0.8
    sindy_config = sindy_utils.SindyConfig_eckstein2022
    # Class imported from spice.resources.old_rnn in the import cell; it is the same
    # class the training cell used (`class_rnn`). There is no `rnn` module in scope.
    rnn_class = RLRNN_eckstein2022
    additional_inputs = None
    setup_agent_benchmark = benchmarking_eckstein2022.setup_agent_benchmark
    rl_model = benchmarking_eckstein2022.rl_model
    benchmark_file = f'mcmc_{study}_MODEL.nc'
    model_config_baseline = 'ApBr'
    baseline_file = f'mcmc_{study}_ApBr.nc'

elif dataset == 'dezfouli2019':
    # ------------------------ CONFIGURATION DEZFOULI2019 -----------------------
    study = 'dezfouli2019'
    train_test_ratio = [3, 6, 9]
    models_benchmark = ['PhiChiBetaKappaC']
    sindy_config = sindy_utils.SindyConfig_dezfouli2019
    # Class imported from spice.resources.old_rnn (see the eckstein2022 branch).
    rnn_class = RLRNN_dezfouli2019
    additional_inputs = []
    setup_agent_benchmark = benchmarking_dezfouli2019.setup_agent_gql
    gql_model = benchmarking_dezfouli2019.Dezfouli2019GQL
    benchmark_file = f'gql_{study}_MODEL.pkl'
    model_config_baseline = 'PhiBeta'
    baseline_file = f'gql_{study}_PhiBeta.pkl'

# ------------------------- CONFIGURATION FILE PATHS ------------------------
use_test = True

# Same CSV the training and SINDy cells read: the configuration cell builds
# '../data/<dataset>/<dataset>.csv' relative to the same working directory.
path_data = f'../data/{study}/{study}.csv'
path_model_rnn = path_model
path_model_spice = path_model.replace('_rnn.pkl', '_spice.pkl')
path_model_baseline = None
path_model_benchmark = None

# NOTE: this rebinds `dataset` from the dataset-name string to the first element
# returned by convert_dataset. Cells that compare `dataset` with 'eckstein2022' /
# 'dezfouli2019' only behave as intended again after re-running the configuration cell.
dataset = convert_dataset(path_data, additional_inputs=additional_inputs)[0]
# use these participant_ids if not defined later
participant_ids = dataset.xs[:, 0, -1].unique().cpu().numpy()

