In [1]:
#!git clone https://github.com/whyhardt/SPICE.git

In [2]:
# !pip install -e SPICE

In [3]:
import pandas as pd
import numpy as np
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt

from spice.estimator import SpiceEstimator, SpiceDataset
from spice.resources.spice_utils import SpiceConfig
from spice.utils.convert_dataset import convert_dataset
from spice.resources.rnn import BaseRNN

# For custom RNN
import torch
import torch.nn as nn

Let's load the data first with the `convert_dataset` function. It returns a `SpiceDataset` object which we can use right away.

In [4]:
# Load your data
dataset_raw = convert_dataset(
    file='../data/augustat2025/augustat2025.csv',
    df_participant_id='participant_id',
    df_choice='choice',
    df_reward='reward',
    additional_inputs=['shown_at_0', 'shown_at_1'],
    timeshift_additional_inputs=False,
)

# structure of dataset:
# dataset has two main attributes: xs -> inputs; ys -> targets (next action)
# shape: (n_participants*n_blocks*n_experiments, n_timesteps, features)
# features are (n_actions * action, n_actions * reward, n_additional_inputs * additional_input, block_number, experiment_id, participant_id)
n_actions = dataset_raw.ys.shape[-1]
n_meta_features = 3  # block_number, experiment_id, participant_id (trailing columns of xs)
first_additional_input = 2 * n_actions  # additional inputs follow the action and reward columns

# instead of timeshift add the predictor states shown_at_0 and shown_at_1 of the next trial to the inputs:
# drop the last timestep (it has no next trial) and insert the next trial's shown items
# between the current trial's additional inputs and the metadata columns
xs_current = dataset_raw.xs[:, :-1]
ys = dataset_raw.ys[:, :-1]
shown_next = dataset_raw.xs[:, 1:, first_additional_input:-n_meta_features]
xs = torch.concat(
    (xs_current[..., :-n_meta_features], shown_next, xs_current[..., -n_meta_features:]),
    dim=-1,
)
dataset = SpiceDataset(xs, ys)

# in order to set up the participant embedding we have to compute the number of unique participants in our data
# NOTE(review): NaN entries are dropped before counting, on the assumption that padded
# timesteps (if any) are NaN-filled — without this every NaN would count as its own participant
participant_ids = dataset.xs[..., -1]
n_participants = len(participant_ids[~participant_ids.isnan()].unique())

print(f"Shape of dataset: {dataset.xs.shape}")
print(f"Number of participants: {n_participants}")
print(f"Number of actions in dataset: {n_actions}")
print(f"Number of additional inputs: {dataset.xs.shape[-1]-2*n_actions-3}")

Shape of dataset: torch.Size([277, 299, 11])
Number of participants: 277
Number of actions in dataset: 2
Number of additional inputs: 4


Now we are going to define the configuration for SPICE with a `SpiceConfig` object.

The `SpiceConfig` takes as arguments 
1. `library_setup (dict)`: Defining the variable names of each module.
2. `memory_state (dict)`: Defining the memory state variables and their initial values.
3. `states_in_logit (list)`: Defining which of the memory state variables are used later for the logit computation. This is necessary for some background processes.  

In [5]:
# One entry per SPICE module: module name -> list of variable names defined for it.
# Only the "chosen" module lists a variable ('reward'); the other two lists are empty.
library_setup = {
    'value_reward_chosen': ['reward'],
    'value_reward_not_chosen': [],
    'value_reward_not_displayed': [],
}

# Memory state variables and their initial values.
memory_state = {'value_reward': 0.5}

spice_config = SpiceConfig(library_setup=library_setup, memory_state=memory_state)

And now we are going to define the SPICE model, which is a child of the `BaseRNN` and `torch.nn.Module` classes and takes as required arguments:
1. `spice_config (SpiceConfig)`: previously defined SpiceConfig object
2. `n_actions (int)`: number of possible actions in your dataset (including non-displayed ones if applicable).
3. `n_participants (int)`: number of participants in your dataset.

As usual for a `torch.nn.Module` we have to define at least the `__init__` method and the `forward` method.
The `forward` method gets called when computing a forward pass through the model and takes as inputs `(inputs (SpiceDataset.xs), prev_state (dict, default: None), batch_first (bool, default: False))` and returns `(logits (torch.Tensor, shape: (n_participants*n_blocks*n_experiments, timesteps, n_actions)), updated_state (dict))`. Two necessary method calls inside the forward pass are:
1. `self.init_forward_pass(inputs, prev_state, batch_first) -> SpiceSignals`: returns a `SpiceSignals` object which carries all relevant information already processed.
2. `self.post_forward_pass(SpiceSignals, batch_first) -> SpiceSignals`: does some re-arranging of the logits to adhere to `batch_first`.

In [6]:
class SPICERNN(BaseRNN):
    """SPICE RNN that tracks one reward value per item for a two-option choice task.

    Values live in the memory state ``'value_reward'`` (one entry per item). On every
    timestep the items are split into three disjoint groups — chosen, shown but not
    chosen, and not displayed — and each group is passed to its own submodule. The
    logits for a timestep are the values of the two items shown on the NEXT trial.

    The additional inputs are expected in the order
    ``(shown_at_0, shown_at_1, shown_at_0_next, shown_at_1_next)``, i.e. the item
    indices at both positions for the current and for the next trial.
    """

    def __init__(
        self, 
        n_actions, 
        spice_config, 
        n_participants, 
        n_items,
        use_sindy=False, 
        **kwargs):
        """Set up the participant embedding and the three value submodules.

        Args:
            n_actions: number of actions (options displayed per trial).
            spice_config: the ``SpiceConfig`` naming the modules and memory state.
            n_participants: number of participants (size of the embedding table).
            n_items: total number of distinct items that can be displayed.
            use_sindy: forwarded unchanged to ``BaseRNN``.
            **kwargs: accepted for call compatibility and ignored.
        """
        super().__init__(
            n_actions=n_actions, 
            spice_config=spice_config,
            n_participants=n_participants, 
            n_items=n_items, 
            embedding_size=32,
            sindy_ensemble_size=1,
            use_sindy=use_sindy,
            )

        self.participant_embedding = self.setup_embedding(num_embeddings=n_participants, embedding_size=self.embedding_size)

        # Module input sizes: the "chosen" module is configured with one extra input
        # on top of the participant embedding (its library entry lists 'reward');
        # the other two modules are sized for the embedding only.
        self.submodules_rnn['value_reward_chosen'] = self.setup_module(1+self.embedding_size)
        self.submodules_rnn['value_reward_not_chosen'] = self.setup_module(self.embedding_size)
        self.submodules_rnn['value_reward_not_displayed'] = self.setup_module(self.embedding_size)

    def forward(self, inputs, prev_state=None, batch_first=False):
        """Run the model over a batch of sequences.

        Args:
            inputs: input tensor in the ``SpiceDataset.xs`` layout.
            prev_state: state from a previous call, or ``None`` (the default) to
                start from the initial state; passed on to ``init_forward_pass``.
            batch_first: whether ``inputs`` (and the returned logits) are batch-first.

        Returns:
            Tuple ``(logits, state)``: the logits for the two items shown on the next
            trial at every timestep, and the state returned by ``self.get_state()``.
        """

        spice_signals = self.init_forward_pass(inputs, prev_state, batch_first)

        # Shown items (raw item indices). The dataset is NOT time-shifted: columns 0/1
        # hold the items of the CURRENT trial, columns 2/3 those of the NEXT trial.
        shown_at_0_current = spice_signals.additional_inputs[..., 0].long()
        shown_at_1_current = spice_signals.additional_inputs[..., 1].long()
        shown_at_0_next = spice_signals.additional_inputs[..., 2].long()
        shown_at_1_next = spice_signals.additional_inputs[..., 3].long()

        participant_embeddings = self.participant_embedding(spice_signals.participant_ids)

        for timestep in spice_signals.timesteps:

            # Transform input data from action space to item space

            # Determine which action was chosen
            action_idx = spice_signals.actions[timestep].argmax(dim=-1)

            # Map to item indices using current trial's shown items
            # (with exactly two positions, the not-chosen item is the one at the other position)
            item_chosen_idx = torch.where(action_idx == 0, shown_at_0_current[timestep], shown_at_1_current[timestep])
            item_not_chosen_idx = torch.where(action_idx == 1, shown_at_0_current[timestep], shown_at_1_current[timestep])

            # Create one-hot masks; every item belongs to exactly one of the three masks
            item_chosen_onehot = torch.nn.functional.one_hot(item_chosen_idx, num_classes=self.n_items).float()
            item_not_chosen_onehot = torch.nn.functional.one_hot(item_not_chosen_idx, num_classes=self.n_items).float()
            item_not_displayed_onehot = 1 - (item_chosen_onehot + item_not_chosen_onehot)

            # Map rewards from action space to item space
            reward_action = spice_signals.rewards[timestep, :]  # shape: (batch, n_actions)

            # Create reward tensor in item space (batch, n_items).
            # new_zeros inherits dtype and device from the rewards, which scatter_ requires to match.
            reward_item = reward_action.new_zeros((reward_action.shape[0], self.n_items))

            # Scatter rewards to the corresponding items:
            # Item at shown_at_0_current gets reward for action 0
            # Item at shown_at_1_current gets reward for action 1
            reward_item.scatter_(1, shown_at_0_current[timestep].unsqueeze(-1), reward_action[:, 0].unsqueeze(-1))
            reward_item.scatter_(1, shown_at_1_current[timestep].unsqueeze(-1), reward_action[:, 1].unsqueeze(-1))
            
            # Update chosen
            self.call_module(
                key_module='value_reward_chosen',
                key_state='value_reward',
                action_mask=item_chosen_onehot,
                inputs=reward_item,
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
            )

            # Update not chosen
            self.call_module(
                key_module='value_reward_not_chosen',
                key_state='value_reward',
                action_mask=item_not_chosen_onehot,
                inputs=None,
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
            )

            # Update not displayed
            self.call_module(
                key_module='value_reward_not_displayed',
                key_state='value_reward',
                action_mask=item_not_displayed_onehot,
                inputs=None,
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
            )

            # Transform values from item space to action space for NEXT trial (for prediction):
            # read the values of the two items that will be shown on the next trial
            value_at_0 = torch.gather(self.state['value_reward'], 1, shown_at_0_next[timestep].unsqueeze(-1))
            value_at_1 = torch.gather(self.state['value_reward'], 1, shown_at_1_next[timestep].unsqueeze(-1))

            # log action values
            spice_signals.logits[timestep] = torch.concat([value_at_0, value_at_1], dim=-1)

        spice_signals = self.post_forward_pass(spice_signals, batch_first)

        return spice_signals.logits, self.get_state()

Let's now set up the `SpiceEstimator` object and fit it to the data!

In [None]:
estimator = SpiceEstimator(
    # model parameters
    rnn_class=SPICERNN,
    spice_config=spice_config,
    n_actions=n_actions,  # derived from the dataset, like n_participants
    n_items=6,
    n_participants=n_participants,

    # training parameters
    epochs=100,
    learning_rate=0.1,
    l2_rnn=0.00001,
    l2_sindy=0.00001,

    # SINDy parameters
    sindy_epochs=10,
    # sindy_weight=0: SINDy coefficient optimization not activated; not necessary
    # as long as SPICERNN does not learn (alternative value tried: 1)
    sindy_weight=0.,
    sindy_threshold=0.05,
    sindy_threshold_frequency=100,
    sindy_library_polynomial_degree=2,

    verbose=True,
    save_path_spice='../params/augustat2025/spice_augustat2025.pkl',
)

print(f"\nStarting training on {estimator.device}...")
print("=" * 80)
estimator.fit(dataset.xs, dataset.ys)
# NOTE(review): a previously saved model could presumably be loaded here instead of
# fitting, via estimator.load_spice(<path>) — confirm the signature before relying on it
print("=" * 80)
print("\nTraining complete!")

# Print example SPICE model for first participant
print("\nExample SPICE model (participant 0):")
print("-" * 80)
estimator.print_spice_model(participant_id=0)
print("-" * 80)


Starting training on cpu...

Training the RNN...
Epoch 12/100 --- L(Train): 0.5929404; Time: 1.43s; Convergence: 8.60e-03

Let's code up a general RNN

In [None]:
class GRU(torch.nn.Module):
    """Generic GRU baseline that scores every item and reads out the next trial's pair.

    Pipeline: linear projection of all input features -> GRU -> linear readout with
    one value per item. The logits are the values of the two items shown on the
    NEXT trial, because the targets are the next action.

    Expected layout of the feature axis of ``inputs``:
    ``(n_actions * action, n_actions * reward, shown_at_0, shown_at_1,
    shown_at_0_next, shown_at_1_next, ...)``.
    """

    def __init__(self, input_size, n_items, n_actions):
        """
        Args:
            input_size: number of features per timestep in ``inputs``.
            n_items: number of distinct items; also used as the GRU hidden size.
            n_actions: number of actions; locates the item columns in ``inputs``.
        """
        super().__init__()
        
        self.input_size = input_size
        self.gru_features = 32
        self.n_items = n_items
        self.n_actions = n_actions
        
        self.linear_in = torch.nn.Linear(in_features=input_size, out_features=self.gru_features)
        self.gru = torch.nn.GRU(input_size=self.gru_features,hidden_size=n_items, batch_first=True)
        self.linear_out = torch.nn.Linear(in_features=n_items, out_features=n_items)
        
    def forward(self, inputs):
        """Return logits of shape (batch, timesteps, 2) for the next trial's item pair.

        Args:
            inputs: float tensor of shape (batch, timesteps, input_size); NaN entries
                are replaced by 0 before use.
        """
        
        # Item pair of the NEXT trial (already 0-indexed in CSV, no need to subtract 1).
        # The two columns right after actions and rewards hold the CURRENT trial's pair;
        # the next trial's pair follows them. NaNs are mapped to item 0 so the indices
        # stay valid for gather.
        first_next_item = 2*self.n_actions + 2
        next_item_pairs = inputs[..., first_next_item:first_next_item+2].nan_to_num(0).long()
        
        y = self.linear_in(inputs.nan_to_num(0))
        y, _ = self.gru(y)
        y = self.linear_out(y)
        
        item1_values = torch.gather(y, 2, next_item_pairs[..., 0].unsqueeze(-1))
        item2_values = torch.gather(y, 2, next_item_pairs[..., 1].unsqueeze(-1))
        
        # Stack to create logits for the pair
        logits = torch.cat([item1_values, item2_values], dim=-1)
        
        return logits

In [None]:
# Training setup for the GRU baseline: epoch budget, model, loss and optimizer.
num_epochs = 20

model = GRU(input_size=dataset.xs.shape[-1], n_items=6, n_actions=2)

# Cross-entropy over the two logits; Adam with a step size of 1e-3.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
# Full-batch training of the GRU baseline.
model.train()  # nothing below switches to eval mode, so set it once

# Targets and validity mask do not change between epochs, so build them once.
# A timestep is used only if neither its input nor its target is NaN
# (NaN must be tested with .isnan(): `x != nan` is True for every x).
# Without the target check, NaN targets would silently be trained on as class 0.
input_is_nan = dataset.xs[:, :, 0].isnan()
target_is_nan = dataset.ys.isnan().any(dim=-1)
valid_mask = ~(input_is_nan | target_is_nan).reshape(-1)

# Class index of the next action: column 1 of the one-hot target (binary choice).
labels_flat = dataset.ys[..., 1].reshape(-1).nan_to_num(0).long()

for epoch in range(num_epochs):
    optimizer.zero_grad()
    
    # Forward pass
    logits = model(dataset.xs)
    
    # Flatten (batch, timesteps, n_logits) -> (batch*timesteps, n_logits) for the loss
    logits_flat = logits.reshape(-1, logits.shape[-1])
    
    # Compute loss on valid timesteps only
    loss = criterion(logits_flat[valid_mask], labels_flat[valid_mask])
    
    # Backward pass
    loss.backward()
    optimizer.step()
    
    print(f"Epoch {epoch+1}/{num_epochs}: Loss: {loss.item()}", end='\r')