In [None]:
#!git clone https://github.com/whyhardt/SPICE.git

In [None]:
# !pip install -e SPICE

In [None]:
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# For custom RNN
import torch
import torch.nn as nn

from spice.estimator import SpiceEstimator
from spice.resources.spice_utils import SpiceConfig, SpiceDataset
from spice.utils.convert_dataset import convert_dataset
from spice.resources.rnn import BaseRNN

First, let's load the data with the `convert_dataset` function. It returns a `SpiceDataset` object that we can use right away.

In [None]:
# Load the raw interaction CSV and convert it into a SpiceDataset.
# Column mapping: ID1 -> participant id, SigAct_ID1 -> choice, Grooming_ID2 -> reward;
# ID2, SigAct_ID2 and Grooming_ID1 are passed through as additional inputs (in that order).
DATA_PATH = '../data/hwang2025/hwang2025.csv'

dataset = convert_dataset(
    file=DATA_PATH,
    df_participant_id='ID1',
    df_choice='SigAct_ID1',
    df_reward='Grooming_ID2',
    additional_inputs=['ID2', 'SigAct_ID2', 'Grooming_ID1'],
    timeshift_additional_inputs=False,
)

# structure of dataset:
# dataset has two main attributes: xs -> inputs; ys -> targets (next action)
# shape: (n_participants*n_blocks*n_experiments, n_timesteps, features)
# features are (n_actions * action, n_actions * reward, n_additional_inputs * additional_input, block_number, experiment_id, participant_id)

# The participant embedding needs the number of unique participants (last feature).
# Padded timesteps may hold NaN (assumption: the GRU baseline further down calls
# nan_to_num on xs). torch.unique does not merge NaNs, so every padded entry would be
# counted as an extra "participant". Drop NaNs before counting.
participant_id_column = dataset.xs[..., -1]
n_participants = len(participant_id_column[~torch.isnan(participant_id_column)].unique())

# ys is one column per action; the remaining features follow the layout described above.
n_actions = dataset.ys.shape[-1]
n_additional_inputs = dataset.xs.shape[-1] - 2 * n_actions - 3

print(f"Shape of dataset: {dataset.xs.shape}")
print(f"Number of participants: {n_participants}")
print(f"Number of actions in dataset: {n_actions}")
print(f"Number of additional inputs: {n_additional_inputs}")

RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 496 but got size 612 for tensor number 1 in the list.

### Dataset description

In [None]:
# Display the input tensor's shape: (sessions, timesteps, features).
dataset.xs.shape # shape -> (n_participants: 41, timesteps: 496, features: 16)

# For comparison, the layout of a usual RL task with five options [A]-[E]:
# normal RL exp:    [A] [B] [C] [D] [E]
# choice:           [x] [ ] [ ] [ ] [ ]
# reward:           [1] [ ] [ ] [ ] [ ]    (partial feedback: only the chosen option's reward is observed)
# reward:           [1] [0] [1] [1] [0]    (full feedback: every option's reward is observed)

# features: (action0, action1, action2, action3, action4, reward0, reward1, reward2, reward3, reward4, 'ID2', 'SigAct_ID2', 'Grooming_ID1', block number, experiment id, ID1)
# in your case: (x, x, x, x, x, -, -, -, -, -, x, x, x, -, -, x)    -> x: keep; -: ignore
# NOTE(review): the mask has one entry per feature listed above (16 in total).
# If the n_actions printed when loading the data is 6 rather than 5, the tensor
# has 6 action and 6 reward columns (18 features), and both lines need updating. TODO confirm.

Now we define the SPICE configuration with a `SpiceConfig` object.

`SpiceConfig` takes the following arguments:
1. `library_setup (dict)`: maps each module name to the names of its input variables.
2. `memory_state (dict)`: defines the memory state variables and their initial values.
3. `states_in_logit (list)`: specifies which memory state variables are later used to compute the logits. Some background processes require this.  

In [None]:
# Candidate input signals shared by all value modules, listed in the same order as the
# `inputs` tuple passed to `call_module` in SPICERNN.forward.
SIGNAL_NAMES = [
    'sig_action', 'sig_grooming', 'sig_non_contact', 'sig_contact', 'sig_scratch', 'sig_waiting',
    'prev_action', 'prev_grooming', 'prev_non_contact', 'prev_contact', 'prev_scratch', 'prev_waiting',
]

# One value module per action. These keys must match the module names registered in
# SPICERNN (`self.submodules_rnn[...]`) and used as `key_module` in its forward pass.
VALUE_MODULES = [
    'value_action', 'value_grooming', 'value_non_contact',
    'value_contact', 'value_scratch', 'value_waiting',
]

spice_config = SpiceConfig(
    # Each module gets its own copy of the signal list, so no two modules share one list object.
    library_setup={module: list(SIGNAL_NAMES) for module in VALUE_MODULES},
    # A single state vector 'values' (initialised to 0). Each module writes its own entry
    # of it through an action mask.
    memory_state={
        'values': 0,
    },
)

Next, we define the SPICE model. It is a child of the `BaseRNN` and `torch.nn.Module` classes and takes the following required arguments:
1. `spice_config (SpiceConfig)`: previously defined SpiceConfig object
2. `n_actions (int)`: number of possible actions in your dataset (including non-displayed ones if applicable).
3. `n_participants (int)`: number of participants in your dataset.

As usual for a `torch.nn.Module` we have to define at least the `__init__` method and the `forward` method.
The `forward` method gets called when computing a forward pass through the model and takes as inputs `(inputs (SpiceDataset.xs), prev_state (dict, default: None), batch_first (bool, default: False))` and returns `(logits (torch.Tensor, shape: (n_participants*n_blocks*n_experiments, timesteps, n_actions)), updated_state (dict))`. Two necessary method calls inside the forward pass are:
1. `self.init_forward_pass(inputs, prev_state, batch_first) -> SpiceSignals`: returns a `SpiceSignals` object which carries all relevant information already processed.
2. `self.post_forward_pass(SpiceSignals, batch_first) -> SpiceSignals`: rearranges the logits to match the `batch_first` layout.

In [None]:
class SPICERNN(BaseRNN):
    """SPICE model that predicts ape 1's next signal action (SigAct_ID1).

    A single memory state ``'values'`` holds one value per action. At every timestep, each
    module in ``VALUE_MODULES`` rewrites exactly one entry of that state, selected by a
    one-hot ``action_mask``. The resulting values are stored as that timestep's logits.
    """

    # Value modules in action-index order: module i updates entry i of the 'values' state.
    # These names must match the keys of SpiceConfig.library_setup.
    VALUE_MODULES = (
        'value_action',
        'value_grooming',
        'value_non_contact',
        'value_contact',
        'value_scratch',
        'value_waiting',
    )
    # Width of one ape's participant embedding; modules receive two of them (ape 1 + ape 2).
    EMBEDDING_SIZE = 32

    def __init__(self, n_actions, spice_config, n_participants, n_items, **kwargs):
        """Register the participant embedding and one RNN module per action.

        Args:
            n_actions: number of actions, forwarded to BaseRNN.
            spice_config: SpiceConfig describing module libraries and memory state.
            n_participants: number of rows in the participant embedding table.
            n_items: forwarded to BaseRNN unchanged.
            **kwargs: accepted and ignored.
        """
        super().__init__(n_actions=n_actions, spice_config=spice_config, n_participants=n_participants, n_items=n_items, embedding_size=self.EMBEDDING_SIZE)

        # One embedding table for both apes. ID1 and ID2 come from the same pool of animals
        # (the planned ID1<->ID2 block reversal swaps them), so a shared table is used.
        self.participant_embedding = self.setup_embedding(num_embeddings=n_participants, embedding_size=self.embedding_size, dropout=0.)

        # Every module sees the same 12 signals (6 sig_* + 6 prev_*, see _signal_inputs)
        # plus the concatenated embeddings of ape 1 and ape 2.
        n_signals = 2 * len(self.VALUE_MODULES)
        for key_module in self.VALUE_MODULES:
            self.submodules_rnn[key_module] = self.setup_module(n_signals + self.embedding_size * 2)

    @staticmethod
    def _signal_inputs(partner_signal, own_action, categories):
        """Build the 12 per-timestep input signals: six sig_* followed by six prev_*.

        Interpretation (TODO confirm with the author): sig_<k> indicates that ape 2's
        SigAct_ID2 equals category k at this step, and prev_<k> is ape 1's own one-hot
        action k at this step.

        Args:
            partner_signal: SigAct_ID2 values for one timestep, assumed shape (batch,).
            own_action: ape 1's one-hot action for one timestep, assumed (batch, n_actions).
            categories: ``torch.arange(n_values)`` on the right device.

        Returns:
            Tuple of 12 tensors, each repeated to (batch, n_values). The assumption is that
            call_module expects inputs shaped like the 'values' state. TODO confirm.
        """
        n_values = categories.shape[0]
        # Comparing against the category ids (rather than one_hot) maps NaN padding to an
        # all-zero row instead of raising.
        partner_onehot = (partner_signal.unsqueeze(-1) == categories).to(own_action.dtype)
        own_onehot = own_action[..., :n_values].nan_to_num(0)
        sig = [partner_onehot[:, k:k + 1].repeat(1, n_values) for k in range(n_values)]
        prev = [own_onehot[:, k:k + 1].repeat(1, n_values) for k in range(n_values)]
        return tuple(sig + prev)

    def forward(self, inputs, prev_state=None, batch_first=False):
        """Run the model over all timesteps.

        Args:
            inputs: input tensor in the SpiceDataset.xs layout.
            prev_state: optional memory state to resume from; None starts fresh.
            batch_first: layout flag forwarded to init_forward_pass / post_forward_pass.

        Returns:
            ``(logits, state)``, where ``state`` is ``self.get_state()`` after the last timestep.
        """
        spice_signals = self.init_forward_pass(inputs, prev_state, batch_first)

        n_values = len(self.VALUE_MODULES)
        # actions is indexed [timestep] below, so dim 1 is the batch.
        batch_size = spice_signals.actions.shape[1]
        device = spice_signals.actions.device

        # module_masks[i] is a (batch, n_values) one-hot mask restricting module i to entry i.
        # Built once, on the data's device, instead of on CPU at every call.
        eye = torch.eye(n_values, dtype=torch.long, device=device)
        module_masks = [eye[i].reshape(1, -1).repeat(batch_size, 1) for i in range(n_values)]
        categories = torch.arange(n_values, device=device)

        # additional_inputs follow convert_dataset's order ['ID2', 'SigAct_ID2', 'Grooming_ID1'].
        partner_ids = spice_signals.additional_inputs[..., 0]
        partner_signals = spice_signals.additional_inputs[..., 1]

        # Ape 1 is the session's participant.
        embeddings_self = self.participant_embedding(spice_signals.participant_ids)

        for timestep in spice_signals.timesteps:
            # Ape 2 may change between timesteps within a session, so embed it per step.
            # IDs are stored as floats (NaN on padding), and embeddings need integer indices.
            # NOTE(review): ID2 is read straight from the CSV. Confirm it uses the same
            # 0..n_participants-1 index space as participant_ids.
            embeddings_partner = self.participant_embedding(partner_ids[timestep].nan_to_num(0).long())
            participant_embeddings = torch.concat((embeddings_self, embeddings_partner), dim=-1)

            signal_inputs = self._signal_inputs(partner_signals[timestep], spice_signals.actions[timestep], categories)

            for key_module, action_mask in zip(self.VALUE_MODULES, module_masks):
                self.call_module(
                    key_module=key_module,
                    key_state='values',
                    action_mask=action_mask,
                    inputs=signal_inputs,
                    participant_index=spice_signals.participant_ids,
                    participant_embedding=participant_embeddings,
                )

            spice_signals.logits[timestep] = self.state['values']

        spice_signals = self.post_forward_pass(spice_signals, batch_first)

        return spice_signals.logits, self.get_state()

Now let's set up the `SpiceEstimator` object and fit it to the data!

In [None]:
# Parameters are saved under this dataset's name, so other studies' files are not overwritten.
SAVE_PATH_SPICE = Path('../params/hwang2025/spice_hwang2025.pkl')
SAVE_PATH_SPICE.parent.mkdir(parents=True, exist_ok=True)

estimator = SpiceEstimator(
    # model parameters
    rnn_class=SPICERNN,
    spice_config=spice_config,
    # Must equal the dataset's action count (computed when loading the data) and the
    # number of value modules in SPICERNN. A hard-coded value here drifts out of sync.
    n_actions=n_actions,
    n_items=6,
    n_participants=n_participants,
    n_experiments=1,

    # training parameters
    epochs=10,
    l2_weight_decay=0.01,
    sindy_epochs=10,
    sindy_threshold=0.1,
    # NOTE(review): this exceeds both `epochs` and `sindy_epochs`. If it counts epochs,
    # thresholding may never run. Confirm against SpiceEstimator.
    sindy_threshold_frequency=100,
    sindy_weight=0.01,
    sindy_library_polynomial_degree=2,
    verbose=True,
    save_path_spice=str(SAVE_PATH_SPICE),
)

print(f"\nStarting training on {estimator.device}...")
print("=" * 80)
estimator.fit(dataset.xs, dataset.ys)
print("=" * 80)
print("\nTraining complete!")

# Print example SPICE model for first participant
print("\nExample SPICE model (participant 0):")
print("-" * 80)
estimator.print_spice_model(participant_id=0)
print("-" * 80)

For comparison, let's code up a general-purpose RNN (a GRU) as a baseline.

In [None]:
class GRU(torch.nn.Module):
    """Generic GRU baseline that maps per-timestep features to per-action logits.

    Architecture: Linear(input_size -> 32) -> GRU(32 -> hidden_size) -> Linear(hidden_size -> n_actions).
    NaN entries in the input (e.g. padding) are replaced by 0 before the first layer.
    """

    def __init__(self, input_size, n_actions, hidden_size=None):
        """
        Args:
            input_size: number of features per timestep (last dim of the input).
            n_actions: number of output logits per timestep.
            hidden_size: GRU hidden size. Defaults to n_actions, which reproduces the
                original architecture. Pass a larger value to avoid squeezing the 32-d
                projection through an n_actions-wide recurrent state.
        """
        super().__init__()

        self.input_size = input_size
        self.gru_features = 32
        self.n_actions = n_actions
        self.hidden_size = n_actions if hidden_size is None else hidden_size

        self.linear_in = torch.nn.Linear(in_features=input_size, out_features=self.gru_features)
        self.gru = torch.nn.GRU(input_size=self.gru_features, hidden_size=self.hidden_size, batch_first=True)
        self.linear_out = torch.nn.Linear(in_features=self.hidden_size, out_features=n_actions)

    def forward(self, inputs):
        """Map batch-first inputs (batch, timesteps, input_size) to logits (batch, timesteps, n_actions)."""
        projected = self.linear_in(inputs.nan_to_num(0))
        hidden, _ = self.gru(projected)
        return self.linear_out(hidden)

In [None]:
# Training configuration for the GRU baseline.
num_epochs = 1000
torch.manual_seed(0)  # reproducible weight initialisation

# One logit per action, using n_actions computed from the dataset when it was loaded,
# so the output width always matches the one-hot targets in dataset.ys.
model = GRU(dataset.xs.shape[-1], n_actions)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [None]:
LOG_EVERY = 100  # print the loss every LOG_EVERY epochs instead of every epoch

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    # Forward pass: logits of shape (sessions, timesteps, n_outputs).
    logits = model(dataset.xs)

    # Flatten sessions x timesteps, sizing by the model's real output width.
    n_outputs = logits.shape[-1]
    logits_flat = logits.reshape(-1, n_outputs)
    targets_flat = dataset.ys.reshape(-1, dataset.ys.shape[-1]).nan_to_num(0)

    # A valid target row is one-hot (sums to 1). NaN padding (-> 0 above) and negative
    # padding both fail this test. A `!= torch.nan` comparison cannot be used as a mask,
    # because NaN never compares equal to anything, so it is always True.
    valid_mask = targets_flat.sum(dim=-1) > 0
    # CrossEntropyLoss expects class indices; for one-hot targets that is the argmax.
    labels_flat = targets_flat.argmax(dim=-1)

    # Compute loss on valid timesteps only
    loss = criterion(logits_flat[valid_mask], labels_flat[valid_mask])

    # Backward pass
    loss.backward()
    optimizer.step()

    if epoch == 0 or (epoch + 1) % LOG_EVERY == 0:
        print(f"Epoch {epoch+1}/{num_epochs}: Loss: {loss.item()}")

### TODO

1. **Implement logic to ignore trials where SigAct_ID1 == 5 (waiting).** 

-> There is nothing to predict on these trials.

-> Whenever SigAct_ID1[t+1] == 5, don't let the RNN predict (e.g. exclude that step from the loss), because there is actually nothing to predict.

2. **Add reversed blocks to the CSV file (swap ID1 <-> ID2 together with all paired columns) to double the amount of predictable data:**

ID1,Dominan0 rank_ID1,ID2,Dominan0 rank_ID2,SigAct_ID1,SigAct_ID2,interaction_id,community_id,Grooming_ID1,Grooming_ID2

Original block:

13,13,6,6,1.0,5.0,1,0,1,0

13,13,6,6,2.0,5.0,1,0,0,0

13,13,6,6,0.0,5.0,1,0,0,0

Add reversed block:

6,6,13,13,5.0,1.0,1,0,0,1

6,6,13,13,5.0,2.0,1,0,0,0

6,6,13,13,5.0,0.0,1,0,0,0
