In [1]:
#!git clone https://github.com/whyhardt/SPICE.git

In [2]:
# !pip install -e SPICE

In [3]:
import pandas as pd
import numpy as np
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt

from spice.estimator import SpiceEstimator
from spice.resources.spice_utils import SpiceConfig
from spice.utils.convert_dataset import convert_dataset
from spice.resources.rnn import BaseRNN

# For custom RNN
import torch
import torch.nn as nn

Let's load the data first with the `convert_dataset` method. This method returns a `SpiceDataset` object which we can use right away 

In [4]:
# Load your hover data
dataset = convert_dataset(
    file='../data/huang2025/huang2025.csv',
    df_participant_id = "subject_ID",
    df_block ='currentRound',
    df_choice = 'hover_tile_index',
    df_reward = 'score',
    additional_inputs = ['partner_tile_index', 'sample_number']
    )[0]

# structure of dataset:
# dataset has two main attributes: xs -> inputs; ys -> targets (next action)
# shape: (n_participants*n_blocks*n_experiments, n_timesteps, features)
# features are (n_actions * action, n_actions * reward, n_additional_inputs * additional_input, block_number, experiment_id, participant_id)


# in order to set up the participant embedding we have to compute the number of unique participants in our data 
# to get the number of participants n_participants we do:
n_participants = len(dataset.xs[..., -1].unique())

Now we are going to define the configuration for SPICE with a `SpiceConfig` object.

The `SpiceConfig` takes as arguments 
1. `library_setup (dict)`: Defining the variable names of each module.
2. `memory_state (dict)`: Defining the memory state variables and their initial values.
3. `states_in_logit (list)`: Defining which of the memory state variables are used later for the logit computation. This is necessary for some background processes.  

In [5]:
spice_config = SpiceConfig(
    library_setup={
        'visited_self': [],
        'visited_partner': [],
        'not_visited': [],
    },
    
    memory_state={
        'value_tile': 0,
    },
)

And now we are going to define the SPICE model which is a child of the `BaseRNN` and `torch.nn.Module` class and takes as required arguments:
1. `spice_config (SpiceConfig)`: previously defined SpiceConfig object
2. `n_actions (int)`: number of possible actions in your dataset (including non-displayed ones if applicable).
3. `n_participants (int)`: number of participants in your dataset.

As usual for a `torch.nn.Module` we have to define at least the `__init__` method and the `forward` method.
The `forward` method gets called when computing a forward pass through the model and takes as inputs `(inputs, prev_state (SpiceDataset.xs), (dict, default: None), batch_first (bool, default: False))` and returns `(logits (torch.Tensor, shape: (n_participants*n_blocks*n_experiments, timesteps, n_actions)), updated_state (dict))`. Two necessary method calls inside the forward pass are:
1. `self.init_forward_pass(inputs, prev_state, batch_first) -> SpiceSignals`: returns a `SpiceSignals` object which carries all relevant information already processed.
2. `self.post_forward_pass(SpiceSignals, batch_first) -> SpiceSignals`: does some re-arranging of the logits to adhere to `batch_first`.

In [6]:
class HoverRNN(BaseRNN):
    """
    Custom RNN for modeling hover behavior.

    CRITICAL: Must match the interface expected by SPICE!
    The RNN should:
    - Take input of shape (batch, seq_len, input_size)
    - Return output of shape (batch, seq_len, output_size)
    - Optionally return hidden states
    """

    def __init__(self, spice_config, n_actions, n_participants, **kwargs):
        super(HoverRNN, self).__init__(
            spice_config=spice_config,
            n_actions=n_actions, 
            n_participants=n_participants, 
            embedding_size=32,
            sindy_polynomial_degree=2,
            )

        self.participant_embedding = self.setup_embedding(n_participants, self.embedding_size, dropout=0.)
        
        self.betas['value_tile'] = self.setup_constant(embedding_size=self.embedding_size)

        # Value learning module (slow updates)
        # Can use recent reward history to modulate learning
        self.submodules_rnn['visited_self'] = self.setup_module(input_size=self.embedding_size)
        self.submodules_rnn['visited_partner'] = self.setup_module(input_size=self.embedding_size)
        self.submodules_rnn['not_visited'] = self.setup_module(input_size=self.embedding_size)

    def forward(self, inputs, prev_state=None, batch_first=False):
        """
        Forward pass.

        Args:
            inputs: Tuple containing (actions, rewards, additional_inputs, participant_ids)
            prev_state: Optional previous hidden state
            batch_first: Whether first dimension is batch (True) or timesteps (False)

        Returns:
            logits: (batch, seq_len, n_actions) - Action logits for each tile
            state: Updated hidden state dictionary
        """

        # Initialize inputs, outputs, and timesteps
        spice_signals = self.init_forward_pass(inputs, prev_state, batch_first)

        nan_mask = spice_signals.actions.isnan()
        tiles_visited_partner = spice_signals.additional_inputs[..., 0].nan_to_num(0).long()
        actions_partner = torch.nn.functional.one_hot(tiles_visited_partner, num_classes=self.n_actions)
        actions_partner = torch.where(nan_mask, torch.nan, actions_partner)
        actions_not_visited = 1 - (spice_signals.actions + actions_partner)
        
        # Get participant embeddings
        participant_embedding = self.participant_embedding(spice_signals.participant_ids)
        beta = self.betas['value_tile'](participant_embedding)
        
        # Main loop: process each timestep
        for timestep in spice_signals.timesteps:
            
            # Update value for tile visited by self
            self.call_module(
                key_module='visited_self',
                key_state=f'value_tile',
                action_mask=spice_signals.actions[timestep],
                inputs=None,
                participant_embedding=participant_embedding,
                participant_index=spice_signals.participant_ids,
                activation_rnn=torch.nn.functional.sigmoid,
            )
            
            # Update value for tile visited by partner
            self.call_module(
                key_module='visited_partner',
                key_state=f'value_tile',
                action_mask=actions_partner[timestep],
                inputs=None,
                participant_embedding=participant_embedding,
                participant_index=spice_signals.participant_ids,
                activation_rnn=torch.nn.functional.sigmoid,
            )
            
            self.call_module(
                key_module='not_visited',
                key_state='value_tile',
                action_mask=actions_not_visited[timestep],
                inputs=None,
                participant_embedding=participant_embedding,
                participant_index=spice_signals.participant_ids,
                activation_rnn=torch.nn.functional.sigmoid,
            )
            
            # Apply beta parameters for each tile and compute logits
            spice_signals.logits[timestep] = self.state['value_tile'] * beta
        
        # Post-process the forward pass
        spice_signals = self.post_forward_pass(spice_signals, batch_first)

        return spice_signals.logits, self.get_state()

Let's setup now the `SpiceEstimator` object and fit it to the data!

In [7]:
path_spice = '../params/huang2025/spice_huang2025.pkl'

estimator = SpiceEstimator(
        
        # SPICE parameters
        rnn_class=HoverRNN,
        spice_config=spice_config,
        n_actions=16,
        n_participants=n_participants,
        
        # RNN training parameters
        epochs=10,
        learning_rate=1e-2,
        sindy_weight=0.1,
        
        # SINDy training parameters
        sindy_library_polynomial_degree=2,
        sindy_threshold_frequency=100,
        sindy_threshold=0.01,
        sindy_epochs=10,
        
        save_path_spice=path_spice,
    )

print(f"\nStarting training on {estimator.device}...")
print("=" * 80)
# estimator.load_spice(args.model)
estimator.fit(dataset.xs, dataset.ys)
print("=" * 80)
print("\nTraining complete!")

# Print example SPICE model for first participant
print("\nExample SPICE model (participant 0):")
print("-" * 80)
estimator.print_spice_model(participant_id=0)
print("-" * 80)


Starting training on cpu...

Training the RNN...
Epoch 10/10 --- L(Train): 2.8993702; Time: 4.14s; Convergence: 2.80e-02
Maximum number of training epochs reached.
Model did not converge yet.
Starting second stage SINDy fitting (threshold=0, single model)
SINDy Stage 2 - Epoch 10/10 --- L(Train): 0.0018842; Time: 0.37s
Second stage SINDy fitting complete!
Saving SPICE model to ../params/huang2025/spice_huang2025.pkl...

Training complete!

Example SPICE model (participant 0):
--------------------------------------------------------------------------------
visited_self[t+1] = 0.0452 1 + -0.0222 visited_self[t] + -0.0415 visited_self^2 
visited_partner[t+1] = 0.035 1 + -0.0191 visited_partner[t] + -0.0345 visited_partner^2 
not_visited[t+1] = 0.0416 1 + -0.0696 not_visited[t] + -0.0825 not_visited^2 
beta(value_tile) = 1.8157
--------------------------------------------------------------------------------
