In [1]:
#!pip install -e SPICE --no-deps

In [2]:
# NOTE(review): on the recorded run this path did not exist (see the `ls` error in the output);
# the import cell below looks for the checkout at ./SPICE relative to the working directory — confirm the intended location.
!ls ~/SPICE/examples/

ls: cannot access '/home/daniel/SPICE/examples/': No such file or directory


In [3]:
#!git clone https://github.com/whyhardt/SPICE.git

In [4]:
#!git clone https://github.com/whyhardt/SPICE.git
import os
import sys

# Report the interpreter the kernel is actually running
# (a shell-level `python` may resolve to a different installation).
print(f"Python {sys.version.split()[0]}")

# 1. Absolute path to the cloned repository.
#    Assumes the folder 'SPICE' sits in the current working directory.
project_root = os.path.abspath("SPICE")

# 2. Decide which directory has to be on sys.path.
#    'src'-layout projects keep the importable package inside '<root>/src'.
src_dir = os.path.join(project_root, "src")
package_path = src_dir if os.path.exists(src_dir) else project_root

# 3. Make the package importable. The membership check keeps the cell
#    idempotent: re-running it does not add the path a second time.
if package_path not in sys.path:
    sys.path.append(package_path)
    print(f"✅ Added {package_path} to system path")

# 4. Now try the import
try:
    from spice.estimator import SpiceEstimator
    print("✅ Success! SpiceEstimator imported.")
except ImportError:
    print("❌ Still failing. Debug info below:")
    # Only list the directory when it exists; otherwise os.listdir would raise
    # FileNotFoundError and hide the ImportError we actually want to see.
    if os.path.isdir(package_path):
        print(f"Contents of {package_path}: {os.listdir(package_path)}")
    else:
        print(f"{package_path} does not exist or is not a directory (cwd: {os.getcwd()})")
    # Bare raise re-raises the active ImportError with its original traceback.
    raise

Python 3.11.13
✅ Added /home/daniel/repositories/SPICE/weinhardt2025/aux/SPICE to system path
✅ Success! SpiceEstimator imported.


In [5]:
import pandas as pd
import numpy as np
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt

from spice.estimator import SpiceEstimator
from spice.resources.spice_utils import SpiceConfig
from spice.utils.convert_dataset import convert_dataset
from spice.resources.rnn import BaseRNN

# For custom RNN
import torch
import torch.nn as nn

In [6]:
# Build the SPICE dataset from the hover-data CSV
dataset = convert_dataset(
    file='../data/huang2025/huang2025.csv',
    df_participant_id="subject_ID",
    df_block='currentRound',
    df_choice='hover_tile_index',
    df_reward='score',
    additional_inputs=['partner_tile_index', 'sample_number'],
)

# Dataset layout:
#   dataset.xs -> inputs, dataset.ys -> targets (next action)
#   shape:    (n_participants*n_blocks*n_experiments, n_timesteps, features)
#   features: (n_actions * action, n_actions * reward,
#              n_additional_inputs * additional_input,
#              block_number, experiment_id, participant_id)

# Number of actions = size of the last axis of the targets.
n_actions = dataset.ys.shape[-1]

# The participant embedding needs the number of distinct participants.
# The participant id is the last feature of xs, so count its unique values.
n_participants = len(dataset.xs[..., -1].unique())

In [7]:
# Each of the three modules gets its own (empty) library entry, and the
# configuration declares a single memory state 'value_tile' with value 0
# (presumably its initial value — confirm against SpiceConfig).
spice_config = SpiceConfig(
    library_setup={module: [] for module in ('visited_self', 'visited_partner', 'not_visited')},
    memory_state={'value_tile': 0},
)

In [8]:
class HoverRNN(BaseRNN):
    """
    Custom SPICE RNN for modeling hover behavior.

    A single memory state, ``value_tile``, is updated at every timestep by
    three modules. Each module is restricted by an action mask to a subset
    of tiles:

    - ``visited_self``: the tile hovered by the participant,
    - ``visited_partner``: the tile hovered by the partner,
    - ``not_visited``: every tile hovered by neither of them.

    After the three updates, ``value_tile`` is written out unchanged as the
    logits of that timestep.

    The class relies on the BaseRNN helpers (``init_forward_pass``,
    ``call_module``, ``post_forward_pass``, ``get_state``) and must keep the
    constructor/forward interface SPICE expects.
    """

    def __init__(self, spice_config, n_actions, n_participants, **kwargs):
        """
        Set up the participant embedding and the three update modules.

        Args:
            spice_config: SPICE configuration; forwarded to BaseRNN.
            n_actions: Number of actions (tiles); forwarded to BaseRNN.
            n_participants: Number of participants; forwarded to BaseRNN and
                used to size the participant embedding.
            **kwargs: Accepted so callers may pass extra arguments; they are
                not used by this class.
        """
        super().__init__(
            spice_config=spice_config,
            n_actions=n_actions,
            n_participants=n_participants,
            embedding_size=32,
            sindy_polynomial_degree=2,
        )

        dropout = 0.1

        self.participant_embedding = self.setup_embedding(n_participants, self.embedding_size, dropout=dropout)

        # One module per tile category. Each module only receives the
        # participant embedding as input (input_size == embedding_size).
        for key_module in ('visited_self', 'visited_partner', 'not_visited'):
            self.setup_module(key_module=key_module, input_size=self.embedding_size, dropout=dropout)

    def forward(self, inputs, prev_state=None, batch_first=False):
        """
        Forward pass.

        Args:
            inputs: Model inputs; passed unchanged to ``init_forward_pass``,
                which exposes them as actions, additional inputs,
                participant ids, timesteps and a logits buffer.
            prev_state: Optional previous state; passed to ``init_forward_pass``.
            batch_first: Layout flag; passed to ``init_forward_pass`` and
                ``post_forward_pass``.

        Returns:
            logits: Logits taken from the object returned by ``post_forward_pass``.
            state: Result of ``self.get_state()`` after the last timestep.
        """

        # Initialize inputs, outputs, and timesteps
        spice_signals = self.init_forward_pass(inputs, prev_state, batch_first)

        # Positions where the participant's own action is NaN.
        nan_mask = spice_signals.actions.isnan()

        # The first additional input is used as the partner's tile index. NaNs
        # are replaced by 0 so that one_hot receives valid integers; the positions
        # flagged in nan_mask are set back to NaN right afterwards.
        # NOTE(review): a NaN partner index at a step where the own action is
        # valid would be treated as tile 0 — confirm this cannot occur.
        tiles_visited_partner = spice_signals.additional_inputs[..., 0].nan_to_num(0).long()
        actions_partner = torch.nn.functional.one_hot(tiles_visited_partner, num_classes=self.n_actions)
        actions_partner = torch.where(nan_mask, torch.nan, actions_partner)

        # Tiles hovered by neither player. When both hover the same tile the raw
        # expression yields -1 for that tile, which is not a valid mask value;
        # clamping keeps it at 0 ("visited") and leaves 0/1 and NaN entries as is.
        actions_not_visited = (1 - (spice_signals.actions + actions_partner)).clamp(min=0)

        # Get participant embeddings
        participant_embedding = self.participant_embedding(spice_signals.participant_ids)

        # Modules and their masks, in the order they are applied at each timestep.
        module_masks = (
            ('visited_self', spice_signals.actions),
            ('visited_partner', actions_partner),
            ('not_visited', actions_not_visited),
        )

        # Main loop: process each timestep
        for timestep in spice_signals.timesteps:

            # All three modules update the same state, one after another.
            for key_module, mask in module_masks:
                self.call_module(
                    key_module=key_module,
                    key_state='value_tile',
                    action_mask=mask[timestep],
                    inputs=None,
                    participant_embedding=participant_embedding,
                    participant_index=spice_signals.participant_ids,
                )

            # The tile values are used directly as logits (no extra scaling).
            spice_signals.logits[timestep] = self.state['value_tile']

        # Post-process the forward pass
        spice_signals = self.post_forward_pass(spice_signals, batch_first)

        return spice_signals.logits, self.get_state()

In [None]:
# File the fitted SPICE model is saved to.
path_spice = 'spice_huang2025.pkl'

estimator = SpiceEstimator(
    # Use the GPU when one is available; otherwise fall back to the CPU so
    # the cell also runs on machines without CUDA.
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),

    # SPICE parameters
    rnn_class=HoverRNN,
    spice_config=spice_config,
    n_actions=n_actions,
    n_participants=n_participants,

    # rnn training parameters
    epochs=1000,  # better: 4000 or 5000
    warmup_steps=1000,
    learning_rate=0.01,
    batch_size=512,
    bagging=True,  # the trailing comma was missing here, which made the cell a SyntaxError

    # sindy fitting parameters
    sindy_weight=1,
    sindy_threshold=0.05,
    sindy_threshold_frequency=1,
    sindy_threshold_terms=1,
    sindy_cutoff_patience=100,
    sindy_epochs=1000,
    sindy_alpha=0.001,
    sindy_library_polynomial_degree=2,

    save_path_spice=path_spice,
)



# Announce the run, then fit the estimator on the full dataset.
print(f"\nStarting training on {estimator.device}...", "=" * 80, sep="\n")
estimator.fit(dataset.xs, dataset.ys)
# estimator.load_spice(args.model)

# Close the training banner and open the section for the example model.
print("=" * 80, "\nTraining complete!", "\nExample SPICE model (participant 0):", "-" * 80, sep="\n")

# Show the SPICE model of the first participant as an example.
estimator.print_spice_model(participant_id=0)
print("-" * 80)


Starting training on cuda...

Training the RNN...

Starting second stage SINDy fitting (threshold=0.05, single model)
Epoch 1/1000 --- L(Train): 0.4788740 --- L(Val): 0.0000000 --- Time: 0.13s;
--------------------------------------------------------------------------------
SPICE Model (Coefficients: 9):
visited_self[t+1] = 0.102 1 + 0.9 visited_self[t] + 0.101 visited_self^2 
visited_partner[t+1] = -0.099 1 + 1.1 visited_partner[t] + -0.099 visited_partner^2 
not_visited[t+1] = -0.101 1 + 0.899 not_visited[t] + -0.1 not_visited^2 
--------------------------------------------------------------------------------
Cutoff patience:
visited_self: [0, 0, 0]
visited_partner: [0, 0, 0]
not_visited: [0, 0, 0]
Epoch 2/1000 --- L(Train): 0.4259432 --- L(Val): 0.0000000 --- Time: 0.09s;
--------------------------------------------------------------------------------
SPICE Model (Coefficients: 9):
visited_self[t+1] = 0.189 1 + 0.801 visited_self[t] + 0.183 visited_self^2 
visited_partner[t+1] = -0

KeyboardInterrupt: 

In [None]:
# Print the SPICE model for participant 1 (the training cell prints participant 0).
estimator.print_spice_model(participant_id=1)

In [None]:
# Inspect the SINDy coefficients the RNN model stores under the 'visited_self' module key.
estimator.rnn_model.sindy_coefficients['visited_self']