In [None]:
#!git clone https://github.com/whyhardt/SPICE.git

In [None]:
# !pip install -e SPICE

In [1]:
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# For custom RNN
import torch
import torch.nn as nn

from spice import SpiceEstimator, SpiceConfig, csv_to_dataset, BaseRNN, plot_session

## Load dataset

Let's load the data first with the `csv_to_dataset` method. This method returns a `SpiceDataset` object which we can use right away.

In [2]:
# Build the dataset straight from the CSV file
dataset = csv_to_dataset(
    file='../data/ganesh2024a/ganesh2024a_choice.csv',
    df_participant_id='subjID',
    df_choice='chose_high',
    df_feedback='reward',
    df_block='blocks',
    additional_inputs=['contrast_difference'],
    timeshift_additional_inputs=[-1],
)

# Layout of the dataset:
#   dataset.xs -> inputs, dataset.ys -> targets (next action)
#   shape: (n_participants*n_blocks*n_experiments, n_timesteps, features)
#   features along the last axis:
#     (n_actions * action, n_actions * reward, n_additional_inputs * additional_input,
#      block_number, experiment_id, participant_id)

# The action count is the width of the targets.
n_actions = dataset.ys.shape[-1]

# The participant embedding needs the number of distinct participants; the
# participant id is the last feature, so count its unique values.
n_participants = dataset.xs[..., -1].unique().numel()

# Whatever is left after the action columns, the reward columns and the three
# trailing id columns are the additional inputs.
n_additional_inputs = dataset.xs.shape[-1] - 2 * n_actions - 3

print(f"Shape of dataset: {dataset.xs.shape}")
print(f"Number of participants: {n_participants}")
print(f"Number of actions in dataset: {n_actions}")
print(f"Number of additional inputs: {n_additional_inputs}")

Shape of dataset: torch.Size([1176, 25, 8])
Number of participants: 98
Number of actions in dataset: 2
Number of additional inputs: 1


## SPICE Setup

Now we are going to define the configuration for SPICE with a `SpiceConfig` object.

The `SpiceConfig` takes as arguments 
1. `library_setup (dict)`: Defining the variable names of each module.
2. `memory_state (dict)`: Defining the memory state variables and their initial values.
3. `states_in_logit (list)`: Defining which of the memory state variables are used later for the logit computation. This is necessary for some background processes.  

In [None]:
# Names of the input variables of each module (module name -> input names).
library_setup = {
    'value_reward_chosen': ['contr_diff', 'reward'],
    'value_reward_not_chosen': ['contr_diff'],
    'value_choice': ['contr_diff', 'choice'],
}

# Memory state variables and the value each one starts from.
memory_state = {
    'value_reward': 0.,
    'value_choice': 0.,
}

# NOTE(review): `states_in_logit` is not passed here, so the library default
# applies; confirm that the default covers both memory states.
spice_config = SpiceConfig(
    library_setup=library_setup,
    memory_state=memory_state,
)

And now we are going to define the SPICE model which is a child of the `BaseRNN` and `torch.nn.Module` class and takes as required arguments:
1. `spice_config (SpiceConfig)`: previously defined SpiceConfig object
2. `n_actions (int)`: number of possible actions in your dataset (including non-displayed ones if applicable).
3. `n_participants (int)`: number of participants in your dataset.

As usual for a `torch.nn.Module` we have to define at least the `__init__` method and the `forward` method.
The `forward` method gets called when computing a forward pass through the model and takes as inputs `(inputs (SpiceDataset.xs), prev_state (dict, default: None), batch_first (bool, default: False))` and returns `(logits (torch.Tensor, shape: (n_participants*n_blocks*n_experiments, timesteps, n_actions)), updated_state (dict))`. Two necessary method calls inside the forward pass are:
1. `self.init_forward_pass(inputs, prev_state, batch_first) -> SpiceSignals`: returns a `SpiceSignals` object which carries all relevant information already processed.
2. `self.post_forward_pass(SpiceSignals, batch_first) -> SpiceSignals`: does some re-arranging of the logits to adhere to `batch_first`.

In [None]:
class SPICERNN(BaseRNN):
    """SPICE RNN with a reward-value and a choice-value memory state.

    Three modules are registered: `value_reward_chosen` and
    `value_reward_not_chosen` both write to the `value_reward` state, and
    `value_choice` writes to the `value_choice` state. At every timestep the
    logits are the sum of the two states.
    """

    def __init__(self, spice_config, **kwargs):
        """Set up the participant embedding and the three modules.

        Args:
            spice_config: the `SpiceConfig` describing modules and memory state.
            **kwargs: forwarded unchanged to `BaseRNN.__init__`
                (e.g. `n_actions`, `n_participants`).
        """
        super().__init__(spice_config=spice_config, **kwargs)

        # participant embedding
        self.participant_embedding = self.setup_embedding(
            num_embeddings=self.n_participants,
            embedding_size=self.embedding_size,
            dropout=0.,
        )

        # set up the submodules; input_size = number of entries passed via
        # `inputs` in `forward` + size of the participant embedding
        self.setup_module(key_module='value_reward_chosen', input_size=2+self.embedding_size)
        self.setup_module(key_module='value_reward_not_chosen', input_size=1+self.embedding_size)
        self.setup_module(key_module='value_choice', input_size=2+self.embedding_size)

    def forward(self, inputs, prev_state=None, batch_first=False):
        """Run the model over all timesteps.

        Args:
            inputs: input tensor (`SpiceDataset.xs`).
            prev_state: state dict to continue from; handed to
                `init_forward_pass` as-is. Defaults to None.
            batch_first: handed to `init_forward_pass` / `post_forward_pass`.

        Returns:
            Tuple `(logits, state)` where `state` is `self.get_state()` after
            the last timestep.
        """
        spice_signals = self.init_forward_pass(inputs, prev_state, batch_first)

        # Tile both signals along the last axis so there is one copy per action.
        contr_diffs = spice_signals.additional_inputs.repeat(1, 1, self.n_actions)
        # actions * rewards summed over the last axis: the reward of the chosen
        # action (assumes `actions` is one-hot -- TODO confirm).
        rewards_chosen = (spice_signals.actions * spice_signals.rewards).sum(dim=-1, keepdim=True).repeat(1, 1, self.n_actions)

        # time-invariant participant features
        participant_embeddings = self.participant_embedding(spice_signals.participant_ids)

        for timestep in spice_signals.timesteps:

            # update chosen value
            self.call_module(
                key_module='value_reward_chosen',
                key_state='value_reward',
                action_mask=spice_signals.actions[timestep],
                inputs=(contr_diffs[timestep], rewards_chosen[timestep]),
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
            )

            # update not chosen value (mask is the complement of the action mask)
            self.call_module(
                key_module='value_reward_not_chosen',
                key_state='value_reward',
                action_mask=1-spice_signals.actions[timestep],
                # Explicit one-element tuple: `(x)` without the trailing comma is
                # just `x`, unlike the tuple inputs of the other two modules.
                # Add rewards_chosen[timestep] for counterfactual updating
                # (adjust the config and input_size as well).
                inputs=(contr_diffs[timestep],),
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
            )

            # same for choice values
            self.call_module(
                key_module='value_choice',
                key_state='value_choice',
                action_mask=spice_signals.actions[timestep],
                inputs=(contr_diffs[timestep], spice_signals.actions[timestep]),
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
            )

            # logits of this timestep: sum of both memory states
            spice_signals.logits[timestep] = self.state['value_reward'] + self.state['value_choice']

        spice_signals = self.post_forward_pass(spice_signals, batch_first)

        return spice_signals.logits, self.get_state()

Let's now set up the `SpiceEstimator` object and fit it to the data!

In [None]:
# File the estimator is told to save the SPICE model to (`save_path_spice`).
path_spice = '../params/ganesh2024a/spice_ganesh2024a.pkl'

estimator = SpiceEstimator(
        # model parameters
        rnn_class=SPICERNN,
        spice_config=spice_config,
        # taken from the dataset (see the loading cell) instead of a literal,
        # so model and data cannot drift apart
        n_actions=n_actions,
        n_participants=n_participants,
        
        # rnn training parameters
        epochs=1000,
        warmup_steps=200,
        learning_rate=0.01,
        
        # sindy fitting parameters
        sindy_weight=0.1,
        sindy_pruning_threshold=0.05,
        sindy_pruning_frequency=1,
        sindy_pruning_terms=1,
        sindy_pruning_patience=100,
        sindy_epochs=1000,
        sindy_l2_lambda=0.0001,
        sindy_library_polynomial_degree=2,
        sindy_ensemble_size=1,
        
        # additional generalization parameters
        batch_size=1024,
        bagging=True,
        scheduler=True,
        
        verbose=True,
        save_path_spice=path_spice,
    )

In [None]:
heavy_rule = "=" * 80
light_rule = "-" * 80

print(f"\nStarting training on {estimator.device}...")
print(heavy_rule)
# NOTE(review): the same xs/ys are passed twice; presumably the second pair is
# the validation data, i.e. no held-out split is used here -- confirm.
estimator.fit(dataset.xs, dataset.ys, dataset.xs, dataset.ys)
print(heavy_rule)
print("\nTraining complete!")

# Show the model of the first participant as an example
print("\nExample SPICE model (participant 0):")
print(light_rule)
estimator.print_spice_model(participant_id=0)
print(light_rule)

In [None]:
# Load the SPICE model from `path_spice`, the same path that was handed to the
# estimator as `save_path_spice`.
estimator.load_spice(path_spice)

## GRU for benchmarking

In [3]:
import sys

# Put the directory two levels up on the import path so that the
# `weinhardt2025` package imported below can be resolved.
sys.path.append('../..')
from weinhardt2025.benchmarking.benchmarking_gru import GRU, training, setup_agent_gru

# File the trained GRU parameters are stored in; it is also what gets passed
# to `setup_agent_gru` later on.
path_gru = '../../weinhardt2025/params/ganesh2024a/gru_ganesh2024a.pkl'

In [4]:
epochs = 10000

# Resolve the device once and move the benchmark GRU onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gru = GRU(n_actions=n_actions, additional_inputs=1).to(device)
optimizer = torch.optim.Adam(gru.parameters(), lr=0.01)

# No loss object is built here: `training` only receives the model, the
# optimizer, the datasets and the number of epochs.
# NOTE(review): train and test set are the same dataset, so L(Test) is not a
# held-out estimate.
gru = training(
    gru=gru,
    optimizer=optimizer,
    dataset_train=dataset,
    dataset_test=dataset,
    epochs=epochs,
    )

# Make sure the target folder exists so a long training run is not lost to a
# missing directory when saving.
Path(path_gru).parent.mkdir(parents=True, exist_ok=True)
torch.save(gru.state_dict(), path_gru)
print(f"Trained GRU parameters saved to {path_gru}")

Epoch 1/10000: L(Train): 0.689823567867279; L(Test): 0.6534026861190796
Epoch 2/10000: L(Train): 0.6544772982597351; L(Test): 0.6174305081367493
Epoch 3/10000: L(Train): 0.6187154650688171; L(Test): 0.5790481567382812
Epoch 4/10000: L(Train): 0.5809915661811829; L(Test): 0.5414983034133911
Epoch 5/10000: L(Train): 0.5447678565979004; L(Test): 0.5073610544204712
Epoch 6/10000: L(Train): 0.5105915665626526; L(Test): 0.47670066356658936
Epoch 7/10000: L(Train): 0.4810183644294739; L(Test): 0.45386821031570435
Epoch 8/10000: L(Train): 0.4579528570175171; L(Test): 0.4428096413612366
Epoch 9/10000: L(Train): 0.44717276096343994; L(Test): 0.44129398465156555
Epoch 10/10000: L(Train): 0.44429340958595276; L(Test): 0.44515877962112427
Epoch 11/10000: L(Train): 0.4482787847518921; L(Test): 0.4515182673931122
Epoch 12/10000: L(Train): 0.4546673595905304; L(Test): 0.45708611607551575
Epoch 13/10000: L(Train): 0.45963090658187866; L(Test): 0.45915669202804565
Epoch 14/10000: L(Train): 0.46249815821

KeyboardInterrupt: 

In [None]:
# Build the benchmark agent from the file at `path_gru`, which is where the
# training cell saves the GRU state dict.
gru_agent = setup_agent_gru(path_gru)

## Plot SPICE against benchmark models

In [None]:
# plotting
participant_id = 7

estimator.print_spice_model(participant_id)

agents = {
    # add baseline agent here
    'rnn': estimator.rnn_agent,
    'spice': estimator.spice_agent,
    'gru': gru_agent,
}

# The first axis of `dataset.xs` enumerates sessions
# (n_participants*n_blocks*n_experiments), not participants, so
# `dataset.xs[participant_id]` would pick session 7 rather than a session of
# participant 7. Select by the participant-id column (last feature) instead,
# so the plotted session matches the model printed above.
session_indices = torch.nonzero(dataset.xs[:, 0, -1] == participant_id).flatten()
if session_indices.numel() == 0:
    raise ValueError(f"No session found for participant {participant_id}.")
# Plot the first session of that participant.
session_xs = dataset.xs[int(session_indices[0])]

fig, axs = plot_session(agents, session_xs)
plt.show()