In [None]:
#!git clone https://github.com/whyhardt/SPICE.git

In [None]:
# !pip install -e SPICE

In [None]:
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from spice import SpiceEstimator, SpiceConfig, convert_dataset, split_data_along_sessiondim, split_data_along_timedim, BaseRNN, plot_session
from spice.precoded.workingmemory import SpiceModel, CONFIG
from spice.utils.setup_agents import Agent

# For custom RNN
import torch
import torch.nn as nn

Let's first load the data with the `convert_dataset` function. It returns a `SpiceDataset` object that we can use right away.

In [35]:
# Read the raw CSV and convert it into a dataset object.
# The keyword arguments name the CSV columns that hold participant id, choice, reward and block.
dataset = convert_dataset(
    file='../data/dezfouli2019/dezfouli2019.csv',
    df_participant_id='session',
    df_choice='choice',
    df_reward='reward',
    df_block='block',
)

# Hold out sessions 3, 6 and 9 as test data; everything else is used for training.
# Alternative (split every session along the time axis instead):
#   dataset_train, dataset_test = split_data_along_timedim(dataset, 0.8)
dataset_train, dataset_test = split_data_along_sessiondim(
    dataset,
    list_test_sessions=[3, 6, 9],
)

# Layout of the dataset
# ---------------------
# xs -> inputs, ys -> targets (next action)
# shape:    (n_participants*n_blocks*n_experiments, n_timesteps, features)
# features: (n_actions * action, n_actions * reward, n_additional_inputs * additional_input,
#            block_number, experiment_id, participant_id)

# The participant embedding needs the number of distinct participants,
# i.e. the number of unique values in the last feature column.
n_participants = len(torch.unique(dataset.xs[..., -1]))

# The targets hold one entry per action, so their last dimension is the number of actions.
n_actions = dataset.ys.size(-1)

print(f"Shape of dataset: {dataset.xs.shape}")
print(f"Shape of training dataset: {dataset_train.xs.shape}")
print(f"Shape of test dataset: {dataset_test.xs.shape}")
print(f"Number of participants: {n_participants}")
print(f"Number of actions in dataset: {n_actions}")
# Whatever is left after removing actions, rewards and the three trailing id columns
# (block_number, experiment_id, participant_id) are additional inputs.
print(f"Number of additional inputs: {dataset.xs.shape[-1]-2*n_actions-3}")

Shape of dataset: torch.Size([1212, 202, 7])
Shape of training dataset: torch.Size([909, 202, 7])
Shape of test dataset: torch.Size([303, 202, 7])
Number of participants: 101
Number of actions in dataset: 2
Number of additional inputs: 0


Now let's set up the `SpiceEstimator` object and fit it to the data!

In [None]:
# Where the fitted SPICE model is stored. The path is named after the dataset this
# notebook loads (dezfouli2019), so that parameters fitted on a different dataset
# are neither mislabeled nor overwritten.
save_path_spice = '../params/dezfouli2019/spice_dezfouli2019.pkl'
# Create the target folder up front so a missing directory cannot fail the save after training.
Path(save_path_spice).parent.mkdir(parents=True, exist_ok=True)

estimator = SpiceEstimator(
    # model parameters
    rnn_class=SpiceModel,
    spice_config=CONFIG,
    n_actions=n_actions,
    n_participants=n_participants,

    # rnn training parameters
    epochs=1000,
    warmup_steps=200,
    learning_rate=0.01,

    # sindy fitting parameters
    sindy_weight=0.1,
    sindy_pruning_threshold=0.05,
    sindy_pruning_frequency=1,
    sindy_pruning_terms=1,
    sindy_pruning_patience=100,
    sindy_epochs=1000,
    sindy_l2_lambda=0.0001,
    sindy_library_polynomial_degree=2,
    sindy_ensemble_size=1,

    # additional generalization parameters
    batch_size=1024,
    bagging=True,
    scheduler=True,

    verbose=True,
    save_path_spice=save_path_spice,
)

print(f"\nStarting training on {estimator.device}...")
print("=" * 80)
# Fit on the training split; the test split is passed along as the third and fourth argument.
estimator.fit(dataset_train.xs, dataset_train.ys, dataset_test.xs, dataset_test.ys)
# To reuse a previously fitted model instead of training, load it from disk:
# estimator.load_spice(save_path_spice)
print("=" * 80)
print("\nTraining complete!")

# Print example SPICE model for first participant
print("\nExample SPICE model (participant 0):")
print("-" * 80)
estimator.print_spice_model(participant_id=0)
print("-" * 80)

In [None]:
# Plot the RNN and SPICE agents for one participant and print that participant's SPICE model.
participant_id = 7

estimator.print_spice_model(participant_id)

agents = {
    'rnn': estimator.rnn_agent,
    'spice': estimator.spice_agent,
}

# dataset.xs is indexed by session (n_participants*n_blocks*n_experiments rows), not by
# participant, so dataset.xs[participant_id] would not necessarily belong to this participant.
# Look up the sessions whose participant-id feature (last column) matches instead.
is_participant_session = (dataset.xs[..., -1] == participant_id).any(dim=-1)
session_indices = torch.nonzero(is_participant_session).flatten()
if session_indices.numel() == 0:
    raise ValueError(f"No session found for participant {participant_id}")
# Show the first session of the selected participant.
session_index = int(session_indices[0])

fig, axs = plot_session(agents, dataset.xs[session_index])
plt.show()


### GRU for benchmarking

In [36]:
import sys

sys.path.append('../..')
from weinhardt2025.benchmarking.benchmarking_gru import GRU, training

In [37]:
# Train a GRU on the same train/test split as a benchmark model and store its weights.
epochs = 1000

# Weights are stored under the name of the dataset this notebook loads (dezfouli2019),
# so that GRU parameters fitted on a different dataset are neither mislabeled nor overwritten.
gru_save_path = '../../weinhardt2025/params/dezfouli2019/gru_dezfouli2019.pkl'
# Create the target folder before training so a missing directory cannot fail the final save.
Path(gru_save_path).parent.mkdir(parents=True, exist_ok=True)

# Run on the GPU when one is available, otherwise on the CPU.
gru = GRU(input_size=n_actions+1, n_actions=n_actions).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
optimizer = torch.optim.Adam(gru.parameters(), lr=0.01)
# NOTE(review): criterion is not passed to training() below; kept in case later cells use it.
criterion = nn.CrossEntropyLoss()

training(
    gru=gru,
    dataset_train=dataset_train,
    dataset_test=dataset_test,
    optimizer=optimizer,
    epochs=epochs,
)

torch.save(gru.state_dict(), gru_save_path)

Epoch 1/1000: L(Train): 0.7371260523796082; L(Test): 0.7743759751319885
Epoch 2/1000: L(Train): 0.7090870141983032; L(Test): 0.8017184734344482
Epoch 3/1000: L(Train): 0.70463627576828; L(Test): 0.8095449805259705
Epoch 4/1000: L(Train): 0.7071415781974792; L(Test): 0.8021448850631714
Epoch 5/1000: L(Train): 0.7035369277000427; L(Test): 0.786541223526001
Epoch 6/1000: L(Train): 0.7017152905464172; L(Test): 0.7651168704032898
Epoch 7/1000: L(Train): 0.7034543752670288; L(Test): 0.7396465539932251
Epoch 8/1000: L(Train): 0.6985102891921997; L(Test): 0.7157225012779236
Epoch 9/1000: L(Train): 0.6959549188613892; L(Test): 0.6946724057197571
Epoch 10/1000: L(Train): 0.694764256477356; L(Test): 0.6811244487762451
Epoch 11/1000: L(Train): 0.6975914239883423; L(Test): 0.6738269925117493
Epoch 12/1000: L(Train): 0.6951149702072144; L(Test): 0.6708783507347107
Epoch 13/1000: L(Train): 0.6964312791824341; L(Test): 0.6719015836715698
Epoch 14/1000: L(Train): 0.6955560445785522; L(Test): 0.67564588