In [1]:
#!git clone https://github.com/whyhardt/SPICE.git

In [2]:
# !pip install -e SPICE

In [None]:
import pandas as pd
import numpy as np
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt

from spice import SpiceEstimator, SpiceConfig, csv_to_dataset, BaseRNN, cross_entropy_loss, mse_loss, plot_session, split_data_along_sessiondim

# For custom RNN
import torch
import torch.nn as nn

## Load dataset

Let's load the data first with the `csv_to_dataset` method. This method returns a `SpiceDataset` object which we can use right away 

In [None]:
dataset = csv_to_dataset(
    file = '../data/ganesh2024a/ganesh2024a_slider.csv',
    df_participant_id='subjID',
    df_choice='chose_high',
    df_feedback='reward',
    df_block='blocks',
    additional_inputs=['contrast_difference', 'slider'],
    # timeshift_additional_inputs=[0, 1],
    )
# -> dataset has input values dataset.xs (shape: participants x experiments x blocks, timesteps, features) and target values dataset.ys (shape: participants x experiments x blocks, timesteps, next action)
# the features in dataset.xs are (action, feedback per action, additional inputs, block id, experiment id, participant id)
# action, feedback and next action are one-hot-encoded
# CHECK IF NECESSARY: move slider one timestep forward to make it the target

# dataset characteristics
n_participants = dataset.xs[..., -1].unique().shape[0]
n_actions = dataset.ys.shape[-1]
n_additional_inputs = dataset.xs.shape[-1] - 2*n_actions - 3

# set slider values as target
# replace target values in dataset.ys (next action) with slider values 
dataset.ys = torch.concat((
    dataset.xs[:, 1:, n_actions*2+1].unsqueeze(-1),
    torch.full((dataset.xs.shape[0], 1, 1), fill_value=torch.nan)
    ), dim=1)

# cut off last timestep from dataset.xs for each block in dim=0
for index_block, block in enumerate(dataset.xs):
    n_trials = (~torch.isnan(dataset.xs[index_block, :, 0])).sum()
    dataset.xs[index_block] = torch.concat((
        dataset.xs[index_block, :n_trials-1], 
        torch.full((dataset.xs.shape[1]-n_trials+1, dataset.xs.shape[-1]), 
                   fill_value=torch.nan)))

dataset_train, dataset_test = split_data_along_sessiondim(dataset, list_test_sessions=[3,5])

print(f"Shape of dataset: {dataset.xs.shape}")
print(f"Number of participants: {n_participants}")
print(f"Number of actions in dataset: {n_actions}")
print(f"Number of additional inputs: {n_additional_inputs}")

Shape of dataset: torch.Size([1176, 25, 9])
Number of participants: 98
Number of actions in dataset: 2
Number of additional inputs: 2


## SPICE Setup

Now we are going to define the configuration for SPICE with a `SpiceConfig` object.

The `SpiceConfig` takes as arguments 
1. `library_setup (dict)`: Defining the variable names of each module.
2. `memory_state (dict)`: Defining the memory state variables and their initial values.
3. `states_in_logit (list)`: Defining which of the memory state variables are used later for the logit computation. This is necessary for some background processes.  

In [None]:
spice_config = SpiceConfig(
    library_setup={
        'value_reward_chosen': ['contr_diff', 'reward'],
        'value_reward_not_chosen': ['contr_diff'],
        'value_choice': ['contr_diff', 'choice'],
        'value_slider': ['value_slider[t]', 'value_reward_0', 'value_reward_1', 'value_choice_0', 'value_choice_1'],
    },
    
    memory_state={
            'value_reward': 0.,
            'value_choice': 0.,
            'value_slider': 0.5,
        }
)

And now we are going to define the SPICE model which is a child of the `BaseRNN` and `torch.nn.Module` class and takes as required arguments:
1. `spice_config (SpiceConfig)`: previously defined SpiceConfig object
2. `n_actions (int)`: number of possible actions in your dataset (including non-displayed ones if applicable).
3. `n_participants (int)`: number of participants in your dataset.

As usual for a `torch.nn.Module` we have to define at least the `__init__` method and the `forward` method.
The `forward` method gets called when computing a forward pass through the model and takes as inputs `(inputs (SpiceDataset.xs), prev_state (dict, default: None), batch_first (bool, default: False))` and returns `(logits (torch.Tensor, shape: (n_participants*n_blocks*n_experiments, timesteps, n_actions)), updated_state (dict))`. Two necessary method calls inside the forward pass are:
1. `self.init_forward_pass(inputs, prev_state, batch_first) -> SpiceSignals`: returns a `SpiceSignals` object which carries all relevant information already processed.
2. `self.post_forward_pass(SpiceSignals, batch_first) -> SpiceSignals`: does some re-arranging of the logits to adhere to `batch_first`.

In [None]:
class SPICERNN(BaseRNN):
    
    def __init__(self, spice_config, **kwargs):
        super().__init__(spice_config=spice_config, **kwargs)
        
        # participant embedding
        self.participant_embedding = self.setup_embedding(num_embeddings=self.n_participants, embedding_size=self.embedding_size, dropout=0.)
        
        # set up the submodules
        self.setup_module(key_module='value_reward_chosen', input_size=2+self.embedding_size)
        self.setup_module(key_module='value_reward_not_chosen', input_size=1+self.embedding_size)
        self.setup_module(key_module='value_choice', input_size=2+self.embedding_size)
        self.setup_module(key_module='value_slider', input_size=5+self.embedding_size)
        
    def forward(self, inputs, prev_state, batch_first=False):
        
        spice_signals = self.init_forward_pass(inputs, prev_state, batch_first)
        # outputs carry next slider value
        outputs = torch.zeros((*spice_signals.actions.shape[:-1], 1))
        
        contr_diffs = spice_signals.additional_inputs[..., 0].unsqueeze(-1).repeat(1, 1, self.n_actions)
        rewards_chosen = (spice_signals.actions * spice_signals.rewards).sum(dim=-1, keepdim=True).repeat(1, 1, self.n_actions)
        
        # time-invariant participant features
        participant_embeddings = self.participant_embedding(spice_signals.participant_ids)
        
        for timestep in spice_signals.timesteps:
            
            # update chosen value
            self.call_module(
                key_module='value_reward_chosen',
                key_state='value_reward',
                action_mask=spice_signals.actions[timestep],
                inputs=(
                    contr_diffs[timestep], 
                    rewards_chosen[timestep],
                    ),
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
            )
            
            # update not chosen value
            self.call_module(
                key_module='value_reward_not_chosen',
                key_state='value_reward',
                action_mask=1-spice_signals.actions[timestep],
                inputs=(
                    contr_diffs[timestep],
                    ),  # add input rewards_chosen[timestep] for counterfactual updating (adjust in config as well)
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
            )
            
            # same for choice values
            self.call_module(
                key_module='value_choice',
                key_state='value_choice',
                action_mask=spice_signals.actions[timestep],
                inputs=(
                    contr_diffs[timestep], 
                    spice_signals.actions[timestep],
                    ),
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
            )
            
            # predict slider value
            self.call_module(
                key_module='value_slider',
                key_state='value_slider',
                # action_mask=spice_signals.actions[timestep],
                inputs=(
                    self.state['value_slider'],
                    self.state['value_reward'][:, 0].unsqueeze(-1).repeat(1, self.n_actions),
                    self.state['value_reward'][:, 1].unsqueeze(-1).repeat(1, self.n_actions),
                    self.state['value_choice'][:, 0].unsqueeze(-1).repeat(1, self.n_actions),
                    self.state['value_choice'][:, 1].unsqueeze(-1).repeat(1, self.n_actions),
                    ),
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
                activation_rnn=torch.nn.functional.sigmoid,  # slider gives probality -> range: [0,1]
            )
            
            outputs[timestep] = self.state['value_slider'][..., :1]

        spice_signals.logits = outputs
        spice_signals = self.post_forward_pass(spice_signals, batch_first)
        
        return spice_signals.logits, self.get_state()

Let's setup now the `SpiceEstimator` object and fit it to the data!

In [14]:
estimator = SpiceEstimator(
    rnn_class=SPICERNN,
    spice_config=spice_config,
    n_participants=n_participants,
    
    epochs=1000,
    epochs_confidence=0,
    warmup_steps=500,
    # loss_fn=cross_entropy_loss,
    loss_fn=mse_loss,
    
    sindy_epochs=1000,
    sindy_weight=0.1,
    sindy_l2_lambda=0.0001,
    sindy_pruning_patience=100,
    sindy_confidence_threshold=0.05,
    sindy_library_polynomial_degree=2,
    
    # save_path_spice=path_spice,
    verbose=True,
)

In [15]:
print(f"\nStarting training on {estimator.device}...")
print("=" * 80)
estimator.fit(dataset.xs, dataset.ys, dataset.xs, dataset.ys)
# estimator.load_spice(args.model)
print("=" * 80)
print("\nTraining complete!")

# Print example SPICE model for first participant
print("\nExample SPICE model (participant 0):")
print("-" * 80)
estimator.print_spice_model(participant_id=0)
print("-" * 80)


Starting training on cpu...

SPICE Training Configuration:
	SPICE joint training: active
	Confidence-based SINDy coefficient filtering: deactived
	SINDy-only finetuning: active

Stage 1: SPICE joint training
Epoch 1/1000 --- L(Train): 0.1065140 --- L(Val, RNN): 0.0000000 --- L(Val, SINDy): 0.0000000 --- Time: 0.82s; --- Convergence: 1.00e+00
--------------------------------------------------------------------------------
SPICE Model (Coefficients: 47):
value_reward_chosen[t+1]     = -0.007 1 + 0.999 value_reward_chosen[t] + 0.002 contr_diff + -0.003 reward + 0.001 value_reward_chosen^2 + 0.0 value_reward_chosen*contr_diff + -0.002 value_reward_chosen*reward + 0.0 contr_diff^2 + -0.001 contr_diff*reward + -0.006 reward^2 
value_reward_not_chosen[t+1] = -0.008 1 + 1.006 value_reward_not_chosen[t] + 0.001 contr_diff + -0.006 value_reward_not_chosen^2 + -0.001 value_reward_not_chosen*contr_diff + -0.003 contr_diff^2 
value_choice[t+1]            = 0.008 1 + 0.994 value_choice[t] + 0.006 c

In [None]:
estimator.load_spice(path_spice)

## GRU for benchmarking

In [16]:
import sys

sys.path.append('../..')
from weinhardt2025.benchmarking.benchmarking_gru import GRU, training, setup_agent_gru

path_gru = '../../weinhardt2025/params/ganesh2024a/gru_ganesh2024a.pkl'

In [17]:
class GRU(torch.nn.Module):
    
    def __init__(self, n_actions, additional_inputs: int = 0, hidden_size: int = 16, **kwargs):
        super().__init__()
        
        self.gru_features = hidden_size
        self.n_actions = n_actions
        self.additional_inputs = additional_inputs
        
        self.linear_in = torch.nn.Linear(in_features=n_actions*2+additional_inputs, out_features=hidden_size)
        self.dropout = torch.nn.Dropout(0.1)
        self.gru = torch.nn.GRU(input_size=hidden_size, hidden_size=hidden_size, batch_first=True)
        self.linear_out = torch.nn.Linear(in_features=hidden_size, out_features=1)
        
    def forward(self, inputs: torch.Tensor, state=None):
        
        if state is not None and len(inputs.shape) == 3:
            state = state.reshape(1, 1, self.gru_features)
        
        y = self.linear_in(inputs[..., :-3].nan_to_num(0))
        y = self.dropout(y)
        y, state = self.gru(y, state)
        y = self.dropout(y)
        y = self.linear_out(y)
        return y, state

In [18]:
epochs = 1000

gru = GRU(n_actions=n_actions, additional_inputs=n_additional_inputs).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
# criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(gru.parameters(), lr=0.01)

gru = training(
    gru=gru,
    optimizer=optimizer,
    dataset_train=dataset,
    dataset_test=dataset,
    epochs=1000,
    # criterion=cross_entropy_loss,
    criterion=mse_loss,
)

torch.save(gru.state_dict(), path_gru)
print("Trained GRU parameters saved to " + path_gru)

Epoch 1/1000: L(Train): 0.8310936689376831; L(Test): 0.558139443397522
Epoch 2/1000: L(Train): 0.5620357394218445; L(Test): 0.35051628947257996
Epoch 3/1000: L(Train): 0.3558593988418579; L(Test): 0.19603556394577026
Epoch 4/1000: L(Train): 0.2017137110233307; L(Test): 0.09831967949867249
Epoch 5/1000: L(Train): 0.10628300160169601; L(Test): 0.06589443236589432
Epoch 6/1000: L(Train): 0.07547271996736526; L(Test): 0.0928729772567749
Epoch 7/1000: L(Train): 0.10310812294483185; L(Test): 0.13365864753723145
Epoch 8/1000: L(Train): 0.14521406590938568; L(Test): 0.14614365994930267
Epoch 9/1000: L(Train): 0.1575026512145996; L(Test): 0.13044753670692444
Epoch 10/1000: L(Train): 0.1432136744260788; L(Test): 0.10345059633255005
Epoch 11/1000: L(Train): 0.11524749547243118; L(Test): 0.080243781208992
Epoch 12/1000: L(Train): 0.09080776572227478; L(Test): 0.06828930228948593
Epoch 13/1000: L(Train): 0.0781652182340622; L(Test): 0.06765039265155792
Epoch 14/1000: L(Train): 0.07682143151760101; 

In [None]:
gru_agent = setup_agent_gru(path_gru)

## Plot SPICE against benchmark models

In [None]:
# plotting
participant_id = 7

estimator.print_spice_model(participant_id)

agents = {
    # add baseline agent here
    'rnn': estimator.rnn_agent,
    'spice': estimator.spice_agent,
    'gru': gru_agent,
}

fig, axs = plot_session(agents, dataset.xs[participant_id])
plt.show()

In [None]:
# TODO: make plot for slider values

participant_id = 0

participant_mask = dataset.xs[:, 0, -1] == participant_id

xs_participant = dataset.xs[participant_mask]
slider_trajectory = xs_participant[..., n_actions*2+1]

for t in range(slider_trajectory.shape[1]):
    pass