In [None]:
#!git clone https://github.com/whyhardt/SPICE.git

In [None]:
# !pip install -e SPICE

In [1]:
import pandas as pd
import numpy as np
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt

from spice import SpiceEstimator, SpiceDataset, SpiceConfig, BaseRNN, convert_dataset, split_data_along_sessiondim, plot_session

# For custom RNN
import torch
import torch.nn as nn

## Load dataset

Let's load the data first with the `convert_dataset` method. This method returns a `SpiceDataset` object which we can use right away 

In [2]:
# Load your data
dataset = convert_dataset(
    file = '../data/weber2024/weber2024.csv',
    df_participant_id='participant',
    df_experiment_id='experiment',
    df_choice='choice',
    df_reward='reward',
    df_block='block',
    additional_inputs=['laserRotation', 'shieldRotation', 'totalReward'],
    timeshift_additional_inputs=True,
    )

test_sessions = 8, 10, 12

# restructure data to have only two actions (stay, move) instead of three (stay, move_clockwise, move_counter_clockwise)
move = dataset.xs[..., 1:3].sum(dim=-1, keepdim=True)
rewards_move = dataset.xs[..., 4:6].nan_to_num(0).sum(dim=-1, keepdim=True)
move_ys = dataset.ys[..., 1:3].sum(dim=-1, keepdim=True)
# create restructured dataset
xs = torch.concat((dataset.xs[..., :1], move, dataset.xs[..., 3:4], rewards_move, dataset.xs[..., 6:]), dim=-1)
ys = torch.concat((dataset.ys[..., :1], move_ys), dim=-1)
dataset = SpiceDataset(xs, ys)

# structure of dataset:
# dataset has two main attributes: xs -> inputs; ys -> targets (next action)
# shape: (n_participants*n_blocks*n_experiments, n_timesteps, features)
# features are (n_actions * action, n_actions * reward, n_additional_inputs * additional_input, block_number, experiment_id, participant_id)

# in order to set up the participant embedding we have to compute the number of unique participants in our data
# to get the number of participants n_participants we do:
n_actions = dataset.ys.shape[-1]
n_participants = len(dataset.xs[..., -1].unique())
n_experiments = len(dataset.xs[..., -2].unique())

# split data into training and testing data
dataset_train, dataset_test = split_data_along_sessiondim(dataset, list_test_sessions=test_sessions)

print(f"Shape of dataset: {dataset.xs.shape}")
print(f"Number of participants: {n_participants}")
print(f"Number of experiments (baseline vs. infusion): {n_experiments}")
print(f"Number of actions in dataset: {n_actions}")
print(f"Number of additional inputs: {dataset.xs.shape[-1]-2*n_actions-3}")

Shape of dataset: torch.Size([1692, 1078, 10])
Number of participants: 30
Number of experiments (baseline vs. infusion): 2
Number of actions in dataset: 2
Number of additional inputs: 3


## SPICE Setup

Now we are going to define the configuration for SPICE with a `SpiceConfig` object.

The `SpiceConfig` takes as arguments 
1. `library_setup (dict)`: Defining the variable names of each module.
2. `memory_state (dict)`: Defining the memory state variables and their initial values.
3. `states_in_logit (list)`: Defining which of the memory state variables are used later for the logit computation. This is necessary for some background processes.  

In [None]:
spice_config = SpiceConfig(
    library_setup={
        'value_distance': [
            'distance[t]',
            'distance[t-1]',
            'distance[t-2]',
            'distance[t-3]',
            ],
        'value_laser_volatility': [  # rather noise in task language
            'laser_volatility[t]',      # captures noise (stochasticity) and mean jumps (volatility) 
            'laser_volatility[t-1]',
            'laser_volatility[t-2]',
            'laser_volatility[t-3]',
            ],
        'value_reward': [
            'reward[t]',
            'reward[t-1]',
            'reward[t-2]',
            'reward[t-3]',
        ],
        'value_stay': [
            'total_reward',
            'reward',
            'distance',
            'laser_volatility',
            'move_switch',
        ],
        'value_move': [
            'total_reward',
            'reward',
            'distance',
            'laser_volatility',
            'move_switch',
        ],
        # 'value_{move/stay}': [
        #     'reward',
        #     'distance',
        #     'volatility',
        #     'move_switch',
        # ],
        # 'urgency': [
        #     'totalReward',
        #     ],
    },
    
    memory_state=[
            'value_action',
            'value_distance',
            'value_laser_volatility',
            'value_reward',
            'distance[t-1]',
            'distance[t-2]',
            'distance[t-3]',
            # 'laser_rotation[t-1]',
            # 'laser_rotation[t-2]',
            # 'laser_rotation[t-3]',
            'laser_volatility[t-1]',
            'laser_volatility[t-2]',
            'laser_volatility[t-3]',
            'reward[t-1]',
            'reward[t-2]',
            'reward[t-3]',
    ],
    
    states_in_logit=['value_action'],
)

And now we are going to define the SPICE model which is a child of the `BaseRNN` and `torch.nn.Module` class and takes as required arguments:
1. `spice_config (SpiceConfig)`: previously defined SpiceConfig object
2. `n_actions (int)`: number of possible actions in your dataset (including non-displayed ones if applicable).
3. `n_participants (int)`: number of participants in your dataset.

As usual for a `torch.nn.Module` we have to define at least the `__init__` method and the `forward` method.
The `forward` method gets called when computing a forward pass through the model and takes as inputs `(inputs (SpiceDataset.xs), prev_state (dict, default: None), batch_first (bool, default: False))` and returns `(logits (torch.Tensor, shape: (n_participants*n_blocks*n_experiments, timesteps, n_actions)), updated_state (dict))`. Two necessary method calls inside the forward pass are:
1. `self.init_forward_pass(inputs, prev_state, batch_first) -> SpiceSignals`: returns a `SpiceSignals` object which carries all relevant information already processed.
2. `self.post_forward_pass(SpiceSignals, batch_first) -> SpiceSignals`: does some re-arranging of the logits to adhere to `batch_first`.

In [None]:
class SPICERNN(BaseRNN):
    
    def __init__(self, spice_config, **kwargs):
        super().__init__(spice_config=spice_config, **kwargs)
        
        dropout = 0.1
        
        # participant embedding
        self.participant_embedding = self.setup_embedding(num_embeddings=self.n_participants, embedding_size=self.embedding_size, dropout=dropout)
        
        # set up the submodules
        # the inputs to the modules will be: (actual_inputs, participant_embedding, infusion_flag) #+ experiment conditions (sthoch, volatility)
        self.setup_module(key_module='value_distance', input_size=4+self.embedding_size+1, dropout=dropout)
        self.setup_module(key_module='value_laser_volatility', input_size=4+self.embedding_size+1, dropout=dropout)  # could describe diff between stable and volatile state (mean jump) (e.g. attention weights)
        self.setup_module(key_module='value_reward', input_size=4+self.embedding_size+1, dropout=dropout)
        self.setup_module(key_module='value_stay', input_size=5+self.embedding_size+1)
        self.setup_module(key_module='value_move', input_size=5+self.embedding_size+1)
        
    def forward(self, inputs, prev_state, batch_first=False):
        
        spice_signals = self.init_forward_pass(inputs, prev_state, batch_first)
        
        # transform rewards from action dependent (action: (1, 0) -> reward: (0.3, nan)) 
        # to same for all actions (action: (1, 0) -> reward: (0.3, 0.3))
        rewards = spice_signals.rewards.nan_to_num(0).sum(dim=-1, keepdim=True).repeat(1, 1, self.n_actions)
        
        # the additional inputs are as defined in the dataset cell: (laser_rotation, shield_rotation, total_reward)
        laser_rotation = spice_signals.additional_inputs[..., 0][:, :, None].repeat(1, 1, self.n_actions)
        shield_rotation = spice_signals.additional_inputs[..., 1][:, :, None].repeat(1, 1, self.n_actions)
        total_reward = spice_signals.additional_inputs[..., 2][:, :, None].repeat(1, 1, self.n_actions)
        
        # compute additional features (could also be done in dataset cell)
        distance = (laser_rotation - shield_rotation).abs()
        # shield move direction switch (switch between stay, move clockwise, move counter clockwise)
        move_switch = torch.zeros_like(distance)
        move_switch[1:-1] = shield_rotation.diff(dim=0).sign().diff(dim=0).abs()
        # laser volatility
        laser_volatility = torch.zeros_like(distance)
        laser_volatility[1:] = laser_rotation.diff(dim=0)
        
        # time-invariant participant features
        participant_embeddings = self.participant_embedding(spice_signals.participant_ids)
        experiment_embeddings = spice_signals.experiment_ids.reshape(-1, 1)
        
        # masks for action selection
        mask_action_stay = torch.zeros(spice_signals.actions.shape[1:])
        mask_action_stay[:, 0] = 1
        mask_action_move = torch.zeros(spice_signals.actions.shape[1:])
        mask_action_move[:, 1] = 1
        
        for timestep in spice_signals.timesteps:
            
            self.call_module(
                key_module='value_distance',
                key_state='value_distance',
                inputs=(
                    distance[timestep],
                    self.state['distance[t-1]'],
                    self.state['distance[t-2]'],
                    self.state['distance[t-3]'],
                ),
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
                experiment_index=spice_signals.experiment_ids,
                experiment_embedding=experiment_embeddings,
            )
            
            self.call_module(
                key_module='value_laser_volatility',
                key_state='value_laser_volatility',
                inputs=(
                    laser_volatility[timestep],
                    self.state['laser_volatility[t-1]'],
                    self.state['laser_volatility[t-2]'],
                    self.state['laser_volatility[t-3]'],
                ),
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
                experiment_index=spice_signals.experiment_ids,
                experiment_embedding=experiment_embeddings,
            )
            
            self.call_module(
                key_module='value_reward',
                key_state='value_reward',
                inputs=(
                    rewards[timestep],
                    self.state['reward[t-1]'],
                    self.state['reward[t-2]'],
                    self.state['reward[t-3]'],
                ),
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
                experiment_index=spice_signals.experiment_ids,
                experiment_embedding=experiment_embeddings,
            )
            
            self.call_module(
                key_module='value_stay',
                key_state='value_action',
                action_mask=mask_action_stay,
                inputs=(
                    total_reward[timestep],
                    self.state['value_reward'],
                    self.state['value_distance'],
                    self.state['value_laser_volatility'],
                    move_switch[timestep],
                ),
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
                experiment_index=spice_signals.experiment_ids,
                experiment_embedding=experiment_embeddings,
            )
            
            self.call_module(
                key_module='value_move',
                key_state='value_action',
                action_mask=mask_action_move,
                inputs=(
                    total_reward[timestep],
                    self.state['value_reward'],
                    self.state['value_distance'],
                    self.state['value_laser_volatility'],
                    move_switch[timestep],
                ),
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
                experiment_index=spice_signals.experiment_ids,
                experiment_embedding=experiment_embeddings,
            )
            
            # save action values as logits
            spice_signals.logits[timestep] = self.state['value_action']
            
            # update working memory buffer
            self.state['distance[t-1]'], self.state['distance[t-2]'], self.state['distance[t-3]'] = distance[timestep], self.state['distance[t-1]'], self.state['distance[t-2]']
            self.state['reward[t-1]'], self.state['reward[t-2]'], self.state['reward[t-3]'] = rewards[timestep], self.state['reward[t-1]'], self.state['reward[t-2]']
            self.state['laser_volatility[t-1]'], self.state['laser_volatility[t-2]'], self.state['laser_volatility[t-3]'] = laser_volatility[timestep], self.state['laser_volatility[t-1]'], self.state['laser_volatility[t-2]']
            
        spice_signals = self.post_forward_pass(spice_signals, batch_first)
        
        return spice_signals.logits, self.get_state()

Let's setup now the `SpiceEstimator` object and fit it to the data!

In [None]:
path_spice = '../params/weber2024/spice_weber2024.pkl'

estimator = SpiceEstimator(
        # model paramaeters
        rnn_class=SPICERNN,
        spice_config=spice_config,
        n_actions=n_actions,
        n_participants=n_participants,
        n_experiments=n_experiments,
        
        # rnn training parameters
        epochs=2,
        warmup_steps=200,
        learning_rate=0.01,
        
        # sindy fitting parameters
        sindy_weight=0.1,
        sindy_pruning_threshold=0.05,
        sindy_pruning_frequency=1,
        sindy_pruning_terms=1,
        sindy_pruning_patience=100,
        sindy_epochs=1000,
        sindy_l2_lambda=0.0001,
        sindy_library_polynomial_degree=2,
        sindy_ensemble_size=1,
        
        # additional generalization parameters
        batch_size=1024,
        bagging=True,
        scheduler=True,
        
        verbose=True,
        save_path_spice=path_spice,
    )

In [None]:
print(f"\nStarting training on {estimator.device}...")
print("=" * 80)
estimator.fit(dataset_train.xs, dataset_train.ys, dataset_test.xs, dataset_test.ys)
# estimator.load_spice(args.model)
print("=" * 80)
print("\nTraining complete!")

# Print example SPICE model for first participant
print("\nExample SPICE model (participant 0):")
print("-" * 80)
estimator.print_spice_model(participant_id=0)
print("-" * 80)


Starting training on cpu...

Training the RNN...
Epoch 1/2 --- L(Train): 6.7472973 --- L(Val, RNN): 2.5669329 --- L(Val, SINDy): 7.9584112 --- Time: 43.52s; --- Convergence: 1.28e+00; LR: 1.00e-02; Metric: inf; Bad epochs: 0/100
--------------------------------------------------------------------------------
SPICE Model (Coefficients: 119):
value_distance[t+1] = 0.0 1 + 1.001 value_distance[t] + 0.0 distance[t] + 0.001 distance[t-1] + 0.001 distance[t-2] + 0.001 distance[t-3] + -0.0 value_distance^2 + 0.0 value_distance*distance[t] + 0.001 value_distance*distance[t-1] + 0.001 value_distance*distance[t-2] + -0.0 value_distance*distance[t-3] + 0.001 distance[t]^2 + 0.001 distance[t]*distance[t-1] + 0.001 distance[t]*distance[t-2] + 0.0 distance[t]*distance[t-3] + -0.0 distance[t-1]^2 + -0.002 distance[t-1]*distance[t-2] + 0.0 distance[t-1]*distance[t-3] + -0.001 distance[t-2]^2 + -0.002 distance[t-2]*distance[t-3] + -0.003 distance[t-3]^2 
value_laser_volatility[t+1] = -0.009 1 + 1.01 v

KeyboardInterrupt: 

In [None]:
estimator.load_spice(path_spice)

## GRU for benchmarking

In [3]:
import sys

sys.path.append('../..')
from weinhardt2025.benchmarking.benchmarking_gru import GRU, training, setup_agent_gru

path_gru = '../../weinhardt2025/params/weber2024/gru_weber2024.pkl'

In [4]:
epochs = 1000

gru = GRU(n_actions=n_actions, additional_inputs=3).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(gru.parameters(), lr=0.01)

gru = training(
    gru=gru,
    optimizer=optimizer,
    dataset_train=dataset_train,
    dataset_test=dataset_train,
    epochs=epochs,
    batch_size=1024,
    )

torch.save(gru.state_dict(), path_gru)
print("Trained GRU parameters saved to " + path_gru)

Epoch 1/1000: L(Train): 0.5476048588752747; L(Test): 0.5421544313430786
Epoch 2/1000: L(Train): 0.5668220520019531; L(Test): 0.5406689643859863
Epoch 3/1000: L(Train): 0.5455498099327087; L(Test): 0.5279489159584045
Epoch 4/1000: L(Train): 0.5340773463249207; L(Test): 0.5325810313224792
Epoch 5/1000: L(Train): 0.5427414774894714; L(Test): 0.5351895689964294
Epoch 6/1000: L(Train): 0.5421103835105896; L(Test): 0.5297689437866211
Epoch 7/1000: L(Train): 0.5311235785484314; L(Test): 0.5268921852111816
Epoch 8/1000: L(Train): 0.5373024940490723; L(Test): 0.5304366946220398
Epoch 9/1000: L(Train): 0.5316314101219177; L(Test): 0.5325705409049988
Epoch 10/1000: L(Train): 0.5388942360877991; L(Test): 0.5308281183242798
Epoch 11/1000: L(Train): 0.5284050107002258; L(Test): 0.5276982188224792
Epoch 12/1000: L(Train): 0.5302596092224121; L(Test): 0.5275452136993408
Epoch 13/1000: L(Train): 0.5358317494392395; L(Test): 0.5296292901039124
Epoch 14/1000: L(Train): 0.535137414932251; L(Test): 0.52937

KeyboardInterrupt: 

In [None]:
gru_agent = setup_agent_gru(path_gru)

## Plot SPICE against benchmark models

In [None]:
# plotting
participant_id = 7

estimator.print_spice_model(participant_id)

agents = {
    # add baseline agent here
    'rnn': estimator.rnn_agent,
    'spice': estimator.spice_agent,
    'gru': gru_agent,
}

fig, axs = plot_session(agents, dataset.xs[participant_id])
plt.show()