# How to use SPICE

## Initialization

In [2]:
#!git clone https://github.com/whyhardt/SPICE.git

In [3]:
import pandas as pd
import numpy as np
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt

from spice.estimator import SpiceEstimator
from spice.resources.spice_utils import SpiceConfig
from spice.utils.convert_dataset import convert_dataset
from spice.resources.rnn import BaseRNN

# For custom RNN
import torch
import torch.nn as nn

## Load dataset

In [4]:
# Load your hover data
dataset = convert_dataset(
    file='../data/huang2025/huang2025.csv',
    df_participant_id = "subject_ID",
    df_block ='currentRound',
    df_choice = 'hover_tile_index',
    df_reward = 'score',
    additional_inputs = ['partner_tile_index', 'sample_number']
    )

# structure of dataset:
# dataset has two main attributes: xs -> inputs; ys -> targets (next action)
# shape: (n_participants*n_blocks*n_experiments, n_timesteps, features)
# features are (n_actions * action, n_actions * reward, n_additional_inputs * additional_input, block_number, experiment_id, participant_id)


# in order to set up the participant embedding we have to compute the number of unique participants in our data 
# to get the number of participants n_participants we do:
n_participants = len(dataset.xs[..., -1].unique())
n_actions = dataset.ys.shape[-1]

## Setup model

In [5]:
spice_config = SpiceConfig(
    library_setup={
        'visited_self': [],
        'visited_partner': [],
        'not_visited': [],
    },
    
    memory_state={
        'value_tile': 0,
    },
)

In [12]:
class HoverRNN(BaseRNN):
    """
    Custom RNN for modeling hover behavior.

    CRITICAL: Must match the interface expected by SPICE!
    The RNN should:
    - Take input of shape (batch, seq_len, input_size)
    - Return output of shape (batch, seq_len, output_size)
    - Optionally return hidden states
    """

    def __init__(self, spice_config, n_actions, n_participants, sindy_polynomial_degree = 2, **kwargs):
        super(HoverRNN, self).__init__(
            spice_config=spice_config,
            n_actions=n_actions, 
            n_participants=n_participants, 
            embedding_size=32,
            sindy_polynomial_degree=sindy_polynomial_degree,
            **kwargs,
            )
            
        dropout = 0.1
        
        self.participant_embedding = self.setup_embedding(n_participants, self.embedding_size, dropout=dropout)
        
        # Value learning module (slow updates)
        # Can use recent reward history to modulate learning
        self.setup_module(key_module='visited_self', input_size=self.embedding_size, dropout=dropout)
        self.setup_module(key_module='visited_partner', input_size=self.embedding_size, dropout=dropout)
        self.setup_module(key_module='not_visited', input_size=self.embedding_size, dropout=dropout)
        
    def forward(self, inputs, prev_state=None, batch_first=False):
        """
        Forward pass.

        Args:
            inputs: Tuple containing (actions, rewards, additional_inputs, participant_ids)
            prev_state: Optional previous hidden state
            batch_first: Whether first dimension is batch (True) or timesteps (False)

        Returns:
            logits: (batch, seq_len, n_actions) - Action logits for each tile
            state: Updated hidden state dictionary
        """

        # Initialize inputs, outputs, and timesteps
        spice_signals = self.init_forward_pass(inputs, prev_state, batch_first)

        nan_mask = spice_signals.actions.isnan()
        tiles_visited_partner = spice_signals.additional_inputs[..., 0].nan_to_num(0).long()
        actions_partner = torch.nn.functional.one_hot(tiles_visited_partner, num_classes=self.n_actions)
        actions_partner = torch.where(nan_mask, torch.nan, actions_partner)
        actions_not_visited = 1 - (spice_signals.actions + actions_partner)
        
        # Get participant embeddings
        participant_embedding = self.participant_embedding(spice_signals.participant_ids)
        
        # Main loop: process each timestep
        for timestep in spice_signals.timesteps:
            
            # Update value for tile visited by self
            self.call_module(
                key_module='visited_self',
                key_state='value_tile',
                action_mask=spice_signals.actions[timestep],
                inputs=None,
                participant_embedding=participant_embedding,
                participant_index=spice_signals.participant_ids,
            )
            
            # Update value for tile visited by partner
            self.call_module(
                key_module='visited_partner',
                key_state='value_tile',
                action_mask=actions_partner[timestep],
                inputs=None,
                participant_embedding=participant_embedding,
                participant_index=spice_signals.participant_ids,
            )
            
            self.call_module(
                key_module='not_visited',
                key_state='value_tile',
                action_mask=actions_not_visited[timestep],
                inputs=None,
                participant_embedding=participant_embedding,
                participant_index=spice_signals.participant_ids,
            )
            
            # Apply beta parameters for each tile and compute logits
            spice_signals.logits[timestep] = self.state['value_tile']
        
        # Post-process the forward pass
        spice_signals = self.post_forward_pass(spice_signals, batch_first)

        return spice_signals.logits, self.get_state()

## Training

In [None]:
path_spice = 'spice_huang2025.pkl'

# rapid setup:
# epochs=1000; warmup_steps=100; sindy_weight=0.001;
# recommended setup:
# epochs=4000; warmup_steps=1000; sindy_weight=1;

estimator = SpiceEstimator(
        device = torch.device("cuda"),
        # SPICE parameters
        rnn_class=HoverRNN,
        spice_config=spice_config,
        n_actions=n_actions,
        n_participants=n_participants,
        
        # rnn training parameters
        epochs=1000,
        warmup_steps=100,
        learning_rate=0.01,
        batch_size=512,
        bagging=True,
        
        # sindy fitting parameters
        sindy_weight=0.1,
        sindy_pruning_threshold=0.05,
        sindy_pruning_frequency=1,
        sindy_pruning_terms=1,
        sindy_pruning_patience=100,
        sindy_epochs=1000,
        sindy_l2_lambda=0.0001,
        sindy_library_polynomial_degree=2,
        sindy_ensemble_size=1,
        
        save_path_spice=path_spice,
        verbose=True,
    )

print(f"\nStarting training on {estimator.device}...")
print("=" * 80)
estimator.fit(dataset.xs, dataset.ys)
# estimator.load_spice(args.model)
print("=" * 80)
print("\nTraining complete!")

# Print example SPICE model for first participant
print("\nExample SPICE model (participant 0):")
print("-" * 80)
estimator.print_spice_model(participant_id=0)
print("-" * 80)


Starting training on cuda...

Training the RNN...
Epoch 1/1 --- L(Train): 2.3133764 --- Time: 6.31s; --- Convergence: 1.16e+00
--------------------------------------------------------------------------------
SPICE Model (Coefficients: 9):
visited_self[t+1] = 0.004 1 + 1.016 visited_self[t] + 0.037 visited_self^2 
visited_partner[t+1] = 0.0 1 + 0.999 visited_partner[t] + -0.004 visited_partner^2 
not_visited[t+1] = -0.004 1 + 0.975 not_visited[t] + -0.034 not_visited^2 
--------------------------------------------------------------------------------
Cutoff patience:
visited_self: [0, 0, 0]
visited_partner: [0, 0, 0]
not_visited: [0, 0, 0]
Maximum number of training epochs reached.
Model did not converge yet.

Starting second stage SINDy fitting (threshold=0.05, single model)
Epoch 1/10 --- L(Train): 0.2234384 --- L(Val): 0.0000000 --- Time: 0.15s;
--------------------------------------------------------------------------------
SPICE Model (Coefficients: 9):
visited_self[t+1] = 0.101 1 

## Analysis

In [13]:
# load the model from the saved file

# 1. initialize SpiceEstimator if not done already (training specific parameters are not necessarily needed but can be added for convenience)
estimator = SpiceEstimator(
        rnn_class=HoverRNN,
        spice_config=spice_config,
        n_actions=n_actions,
        n_participants=n_participants,
        sindy_library_polynomial_degree=2,
)

# 2. load the mode from your file
estimator.load_spice(path_spice)

In [14]:
# print a spice model with
estimator.print_spice_model(participant_id=0)

visited_self[t+1] = 0.865 1 + 0.061 visited_self[t] + 0.264 visited_self^2 
visited_partner[t+1] = -0.011 1 + 0.413 visited_partner[t] + -0.02 visited_partner^2 
not_visited[t+1] = -0.389 1 + 0.409 not_visited[t] + -0.253 not_visited^2 


In [23]:
# extract the trained coefficients for further (statistical) analysis 

# module dictionary with each value holding the coefficients for this module in the shape (n_participants, n_experiments, n_ensemble, n_coefficients_module)
# usually it's fine to investigate only the first ensemble member
coefs = estimator.get_sindy_coefficients()

# module dictionary with each value holding the candidate terms (variable names) for this module
candidate_terms = estimator.get_candidate_terms()

participant_id = 0
print('Coefficients of participant', participant_id)
for module in estimator.get_modules():
    coefs_module = np.round(coefs[module][participant_id, 0, 0], 3)
    candidate_terms_module = candidate_terms[module]
    print('\nModule: ', module)
    for term, coef in zip(candidate_terms_module, coefs_module):
        print(term+': ', coef)

Coefficients of participant 0

Module:  visited_self
1:  0.865
visited_self:  0.061
visited_self^2:  0.264

Module:  visited_partner
1:  -0.011
visited_partner:  0.413
visited_partner^2:  -0.02

Module:  not_visited
1:  -0.389
not_visited:  0.409
not_visited^2:  -0.253
