In [1]:
#!git clone https://github.com/whyhardt/SPICE.git

In [2]:
# !pip install -e SPICE

In [3]:
import pandas as pd
import numpy as np
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt

from spice.estimator import SpiceEstimator
from spice.resources.spice_utils import SpiceConfig
from spice.utils.convert_dataset import convert_dataset
from spice.resources.rnn import BaseRNN

# For custom RNN
import torch
import torch.nn as nn

Let's first load the data with the `convert_dataset` function. We keep the first element of its return value, a `SpiceDataset` object, which we can use right away.

In [4]:
# Read the raw CSV and convert it into SPICE's dataset format.
# Only the first element returned by `convert_dataset` is kept.
dataset = convert_dataset(
    file='../data/augustat2025/augustat2025.csv',
    df_participant_id='participant_id',
    df_choice='choice',
    df_reward='reward',
    additional_inputs=['shown_at_0', 'shown_at_1'],
)[0]

# Layout of the dataset
# ---------------------
# dataset.xs -> inputs, dataset.ys -> targets (next action)
# shape of xs: (n_participants*n_blocks*n_experiments, n_timesteps, features)
# feature axis: (n_actions * action, n_actions * reward,
#                n_additional_inputs * additional_input,
#                block_number, experiment_id, participant_id)

# The participant id sits in the last feature column; counting its distinct
# values gives the number of participants needed for the participant embedding.
participant_column = dataset.xs[..., -1]
n_participants = participant_column.unique().shape[0]

# One target column per action.
n_actions = dataset.ys.shape[-1]

# Whatever is left after the action columns, the reward columns and the three
# trailing id columns are the additional inputs.
n_features = dataset.xs.shape[-1]
n_additional_inputs = n_features - 2 * n_actions - 3

print(f"Shape of dataset: {dataset.xs.shape}")
print(f"Number of participants: {n_participants}")
print(f"Number of actions in dataset: {n_actions}")
print(f"Number of additional inputs: {n_additional_inputs}")

Shape of dataset: torch.Size([277, 300, 9])
Number of participants: 277
Number of actions in dataset: 2
Number of additional inputs: 2


Now we are going to define the configuration for SPICE with a `SpiceConfig` object.

`SpiceConfig` takes the following arguments:
1. `library_setup (dict)`: Defining the variable names of each module.
2. `memory_state (dict)`: Defining the memory state variables and their initial values.
3. `states_in_logit (list)`: Defining which of the memory state variables are used later for the logit computation. This is necessary for some background processes.  

In [5]:
# Input variables of each module, keyed by module name.
# Only the chosen-item module receives an input (the reward); the other two
# modules get no inputs.
library_setup = {
    'value_reward_chosen': ['reward'],
    'value_reward_not_chosen': [],
    'value_reward_not_displayed': [],
}

# Memory state variables and their initial values.
memory_state = {
    'value_reward': 0,
}

spice_config = SpiceConfig(library_setup=library_setup, memory_state=memory_state)

Next we define the SPICE model, a subclass of `BaseRNN` (and thereby of `torch.nn.Module`). Its constructor takes the following required arguments:
1. `spice_config (SpiceConfig)`: the previously defined `SpiceConfig` object.
2. `n_actions (int)`: number of possible actions in your dataset (including non-displayed ones if applicable).
3. `n_participants (int)`: number of participants in your dataset.

The model defined below additionally requires `n_items (int)`: the number of distinct items that can be shown at the two positions.

As usual for a `torch.nn.Module` we have to define at least the `__init__` method and the `forward` method.
The `forward` method gets called when computing a forward pass through the model and takes as inputs `(inputs (SpiceDataset.xs), prev_state (dict, default: None), batch_first (bool, default: False))` and returns `(logits (torch.Tensor, shape: (n_participants*n_blocks*n_experiments, timesteps, n_actions)), updated_state (dict))`. Two necessary method calls inside the forward pass are:
1. `self.init_forward_pass(inputs, prev_state, batch_first) -> SpiceSignals`: returns a `SpiceSignals` object which carries all relevant information already processed.
2. `self.post_forward_pass(SpiceSignals, batch_first) -> SpiceSignals`: does some re-arranging of the logits to adhere to `batch_first`.

In [6]:
class SPICERNN(BaseRNN):
    """Item-based reward-learning RNN for a task that displays two items per trial.

    The `value_reward` memory state is indexed by item. On every timestep three
    sub-modules update it: one for the chosen item (which receives the reward),
    one for the displayed-but-not-chosen item, and one for the items that were
    not displayed. The values of the two displayed items are then read out,
    scaled by an inverse noise temperature computed from the participant
    embedding, and written to the logits of the two positional actions.
    """

    def __init__(self, n_actions, spice_config, n_participants, n_items, **kwargs):
        """Set up the participant embedding, the three value modules and the beta.

        Args:
            n_actions: number of actions, passed on to `BaseRNN`.
            spice_config: `SpiceConfig` object, passed on to `BaseRNN`.
            n_participants: number of participants; also the number of rows of
                the participant embedding.
            n_items: number of distinct items, passed on to `BaseRNN`; used in
                this class (via `self.n_items`) as the one-hot width.
            **kwargs: accepted for call compatibility but not used here.
        """
        # NOTE(review): **kwargs is swallowed and not forwarded to BaseRNN.
        # Confirm against BaseRNN's signature whether options supplied by the
        # estimator should be passed through.
        super().__init__(n_actions=n_actions, spice_config=spice_config, n_participants=n_participants, n_items=n_items, embedding_size=32)

        # participant embedding
        self.participant_embedding = self.setup_embedding(num_embeddings=n_participants, embedding_size=self.embedding_size, dropout=0.)

        # rnn modules
        # reward-based modules; the chosen-item module gets one extra input (the reward)
        self.submodules_rnn['value_reward_chosen'] = self.setup_module(1+self.embedding_size)
        self.submodules_rnn['value_reward_not_chosen'] = self.setup_module(self.embedding_size)
        self.submodules_rnn['value_reward_not_displayed'] = self.setup_module(self.embedding_size)
        # choice-based modules (reward-agnostic; can encode e.g. choice perseverance)
        # ...

        # inverse noise temperatures
        self.betas['value_reward'] = self.setup_constant(embedding_size=self.embedding_size)

    def _one_hot_items(self, item_index, nan_mask):
        """One-hot encode item indices over `self.n_items`, with NaN rows where `nan_mask` is True.

        Args:
            item_index: tensor of 0-based item indices; NaN entries are mapped
                to index 0 before encoding and are expected to be covered by
                `nan_mask`.
            nan_mask: boolean tensor with the same shape as `item_index`.

        Returns:
            Tensor of shape `item_index.shape + (n_items,)`.
        """
        one_hot = torch.nn.functional.one_hot(item_index.nan_to_num(0.).long(), num_classes=self.n_items)
        # the (..., 1) mask broadcasts over the item axis, so no explicit repeat is needed
        return torch.where(nan_mask.unsqueeze(-1), torch.nan, one_hot)

    def transform_signals_to_item_space(self, actions, rewards, shown_at_0, shown_at_1):
        """Transform actions, rewards, and additional inputs from the action space into the item space.

        Args:
            actions: action signal with the two positions on the last axis;
                NaN in `actions[..., 0]` marks timesteps that are masked out.
            rewards: reward signal with the two positions on the last axis.
            shown_at_0: 0-based index of the item displayed at position 0.
            shown_at_1: 0-based index of the item displayed at position 1.

        Returns:
            Tuple `(action_chosen_onehot, action_not_chosen_onehot,
            action_not_displayed_onehot, rewards_onehot, shown_at_0_bool,
            shown_at_1_bool)`, each with a trailing item axis of size `n_items`.
            `rewards_onehot` carries the reward at the chosen item and NaN
            everywhere else.
        """
        nan_mask = actions[..., 0].isnan()

        # True where the action at position 0 has the larger entry
        chose_position_0 = actions.argmax(dim=-1) == 0

        # map the chosen and not chosen actions to the symbol numbers
        action_chosen = torch.where(chose_position_0, shown_at_0, shown_at_1)
        action_not_chosen = torch.where(actions.argmin(dim=-1) == 0, shown_at_0, shown_at_1)

        # compute one hot encoded arrays for action_chosen and action_not_chosen
        action_chosen_onehot = self._one_hot_items(action_chosen, nan_mask)
        action_not_chosen_onehot = self._one_hot_items(action_not_chosen, nan_mask)

        # every item that is neither chosen nor not-chosen was not displayed
        action_not_displayed_onehot = 1 - (action_chosen_onehot + action_not_chosen_onehot)

        # transform now the rewards for each item instead of only the displayed positions
        rewards_onehot = torch.where(action_chosen_onehot == 1, 1, torch.nan)
        rewards_onehot = rewards_onehot * torch.where(chose_position_0, rewards[..., 0], rewards[..., 1]).unsqueeze(-1)

        # transform shown_at_0 and shown_at_1 to one hot encoded arrays
        shown_at_0_bool = self._one_hot_items(shown_at_0, nan_mask)
        shown_at_1_bool = self._one_hot_items(shown_at_1, nan_mask)

        return action_chosen_onehot, action_not_chosen_onehot, action_not_displayed_onehot, rewards_onehot, shown_at_0_bool, shown_at_1_bool

    def transform_values_to_action_space(self, values, shown_at_0, shown_at_1):
        """Transform the state values from the item space back into the action space.

        Args:
            values: item values; assumed to be (batch_size, n_items) — TODO confirm
                against the state layout used by `BaseRNN`.
            shown_at_0: one-hot (or all-NaN) encoding of the item at position 0,
                shape (batch_size, n_items).
            shown_at_1: same for position 1.

        Returns:
            Tensor of shape (batch_size, 2) holding the values of the two
            displayed items; rows whose `shown_at_0` contains NaN are set to 0.
        """
        # Get the indices of the shown items
        shown_at_0_idx = shown_at_0.argmax(dim=-1)  # (batch_size,)
        shown_at_1_idx = shown_at_1.argmax(dim=-1)  # (batch_size,)

        # Use advanced indexing to pick, per row, the value of the shown item
        batch_indices = torch.arange(values.shape[0], device=values.device)
        value_at_0 = values[batch_indices, shown_at_0_idx]  # (batch_size,)
        value_at_1 = values[batch_indices, shown_at_1_idx]  # (batch_size,)

        # Stack them to create action space values
        values_action_space = torch.stack([value_at_0, value_at_1], dim=-1)  # (batch_size, n_actions=2)

        # Mask out invalid timesteps (where shown_at_0 is all NaN)
        valid_mask = ~shown_at_0.sum(dim=-1).isnan()
        values_action_space = torch.where(
            valid_mask.unsqueeze(-1).expand_as(values_action_space),
            values_action_space,
            torch.zeros_like(values_action_space)  # Use 0 instead of NaN
        )

        return values_action_space

    def forward(self, inputs, prev_state=None, batch_first=False):
        """Run the model over all timesteps and return the action logits.

        Args:
            inputs: input tensor (`SpiceDataset.xs`).
            prev_state: optional state dict handed to `init_forward_pass`.
            batch_first: passed through to `init_forward_pass` and
                `post_forward_pass`.

        Returns:
            Tuple `(logits, state)`: the logits after `post_forward_pass` and
            the result of `self.get_state()`.
        """
        spice_signals = self.init_forward_pass(inputs, prev_state, batch_first)

        # transform actions from positional encoding (0: right, 1: left) -> one hot encoded items
        # the additional inputs hold item numbers; subtracting 1 turns them into 0-based indices
        shown_at_0, shown_at_1 = spice_signals.additional_inputs[..., 0]-1, spice_signals.additional_inputs[..., 1]-1
        actions_chosen, actions_not_chosen, actions_not_displayed, rewards, shown_at_0, shown_at_1 = self.transform_signals_to_item_space(spice_signals.actions, spice_signals.rewards, shown_at_0, shown_at_1)

        # time-invariant participant features
        participant_embeddings = self.participant_embedding(spice_signals.participant_ids)
        beta_reward = self.betas['value_reward'](participant_embeddings)

        for timestep in spice_signals.timesteps:

            # update chosen value
            self.call_module(
                key_module='value_reward_chosen',
                key_state='value_reward',
                action_mask=actions_chosen[timestep],
                inputs=rewards[timestep],
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
                activation_rnn=torch.nn.functional.sigmoid,
            )

            # update not chosen value
            self.call_module(
                key_module='value_reward_not_chosen',
                key_state='value_reward',
                action_mask=actions_not_chosen[timestep],
                inputs=None,
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
                activation_rnn=torch.nn.functional.sigmoid,
            )

            # update not displayed values
            self.call_module(
                key_module='value_reward_not_displayed',
                key_state='value_reward',
                action_mask=actions_not_displayed[timestep],
                inputs=None,
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
                activation_rnn=torch.nn.functional.sigmoid,
            )

            # transform logits from item-space to action-space
            spice_signals.logits[timestep] = self.transform_values_to_action_space(self.state['value_reward'], shown_at_0[timestep], shown_at_1[timestep]) * beta_reward

        spice_signals = self.post_forward_pass(spice_signals, batch_first)

        return spice_signals.logits, self.get_state()

Let's now set up the `SpiceEstimator` object and fit it to the data!

In [7]:
# Seed the random number generators so that a Restart & Run All gives
# repeatable training results.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Number of distinct items in the task. It is the one-hot width used by
# SPICERNN, so it must cover every item number found in the `shown_at_0` /
# `shown_at_1` columns (these are shifted by -1 inside SPICERNN.forward).
N_ITEMS = 6

estimator = SpiceEstimator(
        # model parameters
        rnn_class=SPICERNN,
        spice_config=spice_config,
        n_actions=n_actions,  # derived from dataset.ys above instead of hard-coding 2
        n_items=N_ITEMS,
        n_participants=n_participants,
        n_experiments=1,

        # training parameters
        epochs=10,
        l2_weight_decay=0.01,
        sindy_epochs=10,
        sindy_threshold=0.1,
        sindy_threshold_frequency=100,
        sindy_weight=0.01,
        sindy_library_polynomial_degree=2,
        verbose=True,
        save_path_spice='../params/augustat2025/spice_augustat2025.pkl',
    )

print(f"\nStarting training on {estimator.device}...")
print("=" * 80)
estimator.fit(dataset.xs, dataset.ys)
# To reuse previously saved parameters instead of fitting, call
# estimator.load_spice(<path>) here — verify the expected argument first.
print("=" * 80)
print("\nTraining complete!")

# Print example SPICE model for first participant
print("\nExample SPICE model (participant 0):")
print("-" * 80)
estimator.print_spice_model(participant_id=0)
print("-" * 80)


Starting training on cpu...

Training the RNN...
Epoch 10/10 --- L(Train): 0.7108493; Time: 2.33s; Convergence: 2.38e-03
Maximum number of training epochs reached.
Model did not converge yet.
Starting second stage SINDy fitting (threshold=0, single model)
SINDy Stage 2 - Epoch 10/10 --- L(Train): 0.0016816; Time: 0.02s
Second stage SINDy fitting complete!

Refitted SPICE model (participant 0):
--------------------------------------------------------------------------------
value_reward_chosen[t+1] = 0.0234 1 + 0.0012 value_reward_chosen[t] + 0.0021 value_reward_chosen^2 + -0.0026 value_reward_chosen*reward + -0.0018 reward^2 
value_reward_not_chosen[t+1] = 0.0205 1 + -0.0025 value_reward_not_chosen[t] + -0.0054 value_reward_not_chosen^2 
value_reward_not_displayed[t+1] = 0.0169 1 + -0.0273 value_reward_not_displayed[t] + -0.0311 value_reward_not_displayed^2 
beta(value_reward) = -0.0114
--------------------------------------------------------------------------------

RNN training fini