In [None]:
#!git clone https://github.com/whyhardt/SPICE.git

In [None]:
# !pip install -e SPICE

In [1]:
import pandas as pd
import numpy as np
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt

from spice import SpiceEstimator, SpiceConfig, convert_dataset, BaseRNN

# For custom RNN
import torch
import torch.nn as nn

Let's load the data first with the `convert_dataset` method. This method returns a `SpiceDataset` object which we can use right away 

In [2]:
# Load your data
file = '../data/bustamante2023/bustamante2023_processed.csv'
dataset = convert_dataset(
    file = file,
    df_participant_id='subject_id',
    df_choice='decision',
    df_reward='reward',
    df_block='overall_round',
    additional_inputs=['harvest_duration', 'travel_duration'],
    timeshift_additional_inputs=False,
    )

# structure of dataset:
# dataset has two main attributes: xs -> inputs; ys -> targets (next action)
# shape: (n_participants*n_blocks*n_experiments, n_timesteps, features)
# features are (n_actions * action, n_actions * reward, n_additional_inputs * additional_input, block_number, experiment_id, participant_id)

# in order to set up the participant embedding we have to compute the number of unique participants in our data 
# to get the number of participants n_participants we do:
n_participants = len(dataset.xs[..., -1].unique())

print(f"Shape of dataset: {dataset.xs.shape}")
print(f"Number of participants: {n_participants}")
n_actions = dataset.ys.shape[-1]
print(f"Number of actions in dataset: {n_actions}")
print(f"Number of additional inputs: {dataset.xs.shape[-1]-2*n_actions-3}")

Shape of dataset: torch.Size([1997, 112, 9])
Number of participants: 250
Number of actions in dataset: 2
Number of additional inputs: 2


In [3]:
# inspect dataset
dataset.xs[0, :10, :]

tensor([[1.0000, 0.0000, 0.7561,    nan, 2.0000, 8.3333, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.6859,    nan, 2.0000, 8.3333, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000,    nan, 0.0510, 2.0000, 8.3333, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.7515,    nan, 2.0000, 8.3333, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.6407,    nan, 2.0000, 8.3333, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.5662,    nan, 2.0000, 8.3333, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000,    nan, 0.0510, 2.0000, 8.3333, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.8706,    nan, 2.0000, 8.3333, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.7226,    nan, 2.0000, 8.3333, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000,    nan, 0.0510, 2.0000, 8.3333, 0.0000, 0.0000, 0.0000]])

Now we are going to define the configuration for SPICE with a `SpiceConfig` object.

The `SpiceConfig` takes as arguments 
1. `library_setup (dict)`: Defining the variable names of each module.
2. `memory_state (dict)`: Defining the memory state variables and their initial values.
3. `states_in_logit (list)`: Defining which of the memory state variables are used later for the logit computation. This is necessary for some background processes.  

In [4]:
spice_config = SpiceConfig(
    library_setup={
        'value_stay': ['reward', 'harvest_duration'],
        'value_exit': ['travel_duration'],
    },
    
    memory_state={
            'value': 0.,
        }
)

And now we are going to define the SPICE model which is a child of the `BaseRNN` and `torch.nn.Module` class and takes as required arguments:
1. `spice_config (SpiceConfig)`: previously defined SpiceConfig object
2. `n_actions (int)`: number of possible actions in your dataset (including non-displayed ones if applicable).
3. `n_participants (int)`: number of participants in your dataset.

As usual for a `torch.nn.Module` we have to define at least the `__init__` method and the `forward` method.
The `forward` method gets called when computing a forward pass through the model and takes as inputs `(inputs (SpiceDataset.xs), prev_state (dict, default: None), batch_first (bool, default: False))` and returns `(logits (torch.Tensor, shape: (n_participants*n_blocks*n_experiments, timesteps, n_actions)), updated_state (dict))`. Two necessary method calls inside the forward pass are:
1. `self.init_forward_pass(inputs, prev_state, batch_first) -> SpiceSignals`: returns a `SpiceSignals` object which carries all relevant information already processed.
2. `self.post_forward_pass(SpiceSignals, batch_first) -> SpiceSignals`: does some re-arranging of the logits to adhere to `batch_first`.

In [5]:
class SPICERNN(BaseRNN):
    
    def __init__(self, spice_config, **kwargs):
        super().__init__(spice_config=spice_config, **kwargs)
        
        # participant embedding
        self.participant_embedding = self.setup_embedding(num_embeddings=n_participants, embedding_size=self.embedding_size, dropout=0.)
        
        # set up the submodules
        self.setup_module(key_module='value_stay', input_size=2+self.embedding_size)
        self.setup_module(key_module='value_exit', input_size=1+self.embedding_size)
        
    def forward(self, inputs, prev_state, batch_first=False):
        
        spice_signals = self.init_forward_pass(inputs, prev_state, batch_first)
        
        harvest_duration = spice_signals.additional_inputs[..., 0].unsqueeze(-1).repeat(1, 1, self.n_actions)
        travel_duration = spice_signals.additional_inputs[..., 1].unsqueeze(-1).repeat(1, 1, self.n_actions)
        rewards_chosen = (spice_signals.actions * spice_signals.rewards).sum(dim=-1, keepdim=True).repeat(1, 1, self.n_actions)
        
        # time-invariant participant features
        participant_embeddings = self.participant_embedding(spice_signals.participant_ids)
        mask_stay = torch.tensor((1,0)).reshape(1, self.n_actions).repeat(rewards_chosen.shape[1], 1)
        mask_exit = torch.tensor((0,1)).reshape(1, self.n_actions).repeat(rewards_chosen.shape[1], 1)
        
        for timestep in spice_signals.timesteps:
            
            # update chosen value
            self.call_module(
                key_module='value_stay',
                key_state='value',
                action_mask=mask_stay,
                inputs=(
                    rewards_chosen[timestep], 
                    harvest_duration[timestep], 
                    ),
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
            )
            
            # update not chosen value
            self.call_module(
                key_module='value_exit',
                key_state='value',
                action_mask=mask_exit,
                inputs=(
                    travel_duration[timestep], 
                    ),
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
            )
            
            # transform logits from item-space to action-space
            spice_signals.logits[timestep] = self.state['value']
            
        spice_signals = self.post_forward_pass(spice_signals, batch_first)
        
        return spice_signals.logits, self.get_state()

Let's setup now the `SpiceEstimator` object and fit it to the data!

In [6]:
estimator = SpiceEstimator(
        # model paramaeters
        rnn_class=SPICERNN,
        spice_config=spice_config,
        n_actions=2,
        n_participants=n_participants,
        n_experiments=1,
        
        # rnn training parameters
        epochs=1000,
        warmup_steps=200,
        learning_rate=0.01,
        
        # sindy fitting parameters
        sindy_weight=0.1,
        sindy_threshold=0.05,
        sindy_threshold_frequency=1,
        sindy_threshold_terms=1,
        sindy_cutoff_patience=100,
        sindy_epochs=1000,
        sindy_alpha=0.0001,
        sindy_library_polynomial_degree=2,
        sindy_ensemble_size=1,
        
        # additional generalization parameters
        batch_size=1024,
        bagging=True,
        scheduler=True,
        
        verbose=True,
        save_path_spice='../params/bustamante2023/spice_bustamante2023.pkl',
    )

print(f"\nStarting training on {estimator.device}...")
print("=" * 80)
estimator.fit(dataset.xs, dataset.ys, dataset.xs, dataset.ys)
# estimator.load_spice(args.model)
print("=" * 80)
print("\nTraining complete!")

# Print example SPICE model for first participant
print("\nExample SPICE model (participant 0):")
print("-" * 80)
estimator.print_spice_model(participant_id=0)
print("-" * 80)


Starting training on cpu...

Training the RNN...
Epoch 1/1000 --- L(Train): 0.5092973 --- L(Val, RNN): 0.3817756 --- L(Val, SINDy): 3.8043301 --- Time: 1.18s; --- Convergence: 8.09e-01; LR: 1.00e-02; Metric: inf; Bad epochs: 0/100
--------------------------------------------------------------------------------
SPICE Model (Coefficients: 16):
value_stay[t+1] = 0.002 1 + 1.001 value_stay[t] + 0.002 reward + -0.0 harvest_duration + -0.0 value_stay^2 + 0.001 value_stay*reward + -0.001 value_stay*harvest_duration + 0.002 reward^2 + 0.004 reward*harvest_duration + -0.001 harvest_duration^2 
value_exit[t+1] = 0.011 1 + 1.011 value_exit[t] + -0.01 travel_duration + 0.01 value_exit^2 + -0.009 value_exit*travel_duration + -0.009 travel_duration^2 
--------------------------------------------------------------------------------
Cutoff patience:
value_stay: 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
value_exit: 0, 0, 0, 0, 0, 0
Epoch 2/1000 --- L(Train): 0.3793685 --- L(Val, RNN): 0.3679760 --- L(Val, SINDy): 

In [None]:
df = pd.read_csv(file)
df['last_reward'] = df.groupby('subject_id')['last_reward'].fillna(method='bfill')
df['participant_id'] = pd.factorize(df['subject_id'])[0]
exit_df = df[df['decision'] == 1]
mean_exit_threshold = exit_df.groupby(['participant_id', 'subject_id'])['last_reward'].mean().reset_index()
mean_exit_threshold['relative_optimal'] = mean_exit_threshold['last_reward'] - 6.78 #from Bustamante et al. Table S7, experiment 1
mean_exit_threshold['over_harvester'] = np.where(mean_exit_threshold['relative_optimal'] <= 0, 1, 0)
print(mean_exit_threshold)
overharvesters = mean_exit_threshold[mean_exit_threshold['over_harvester'] == 1]['participant_id'].unique()
underharvesters = mean_exit_threshold[mean_exit_threshold['over_harvester'] == 0]['participant_id'].unique()

     participant_id                subject_id  last_reward  relative_optimal  \
0                 0  08aiu2bm6t15qij5826jxz50     8.241283          1.461283   
1                 1  09j932f828pn7h7bozp9mpnl     5.344722         -1.435278   
2                 2  0ax9htcbhfi3ncsospqzwjx2     8.945571          2.165571   
3                 3  0e6zivqly335lojgb4c6606t     7.554696          0.774696   
4                 4  0fawro1pivqnh4lem4ayf4o0     3.089682         -3.690318   
..              ...                       ...          ...               ...   
245             245  fzllq0yp08zefacpmy7dqq0u     9.041866          2.261866   
246             246  g2n23l2w8uf3brbllm4sbrcx     8.106489          1.326489   
247             247  g9wqksieqbldodjoyuci048q     6.765186         -0.014814   
248             248  garkh3hmuozi9loxpee54z20     7.720383          0.940383   
249             249  gd6af6pqeo2d0x2dcumwirmy     3.908357         -2.871643   

     over_harvester  
0                

In [8]:
print('OVERHARVESTERS') 
for p in overharvesters:
    print('Participant number', p)
    estimator.print_spice_model(participant_id=p)

OVERHARVESTERS
Participant number 1
value_stay[t+1] = 1.0 value_stay[t] 
value_exit[t+1] = 1.0 value_exit[t] 
Participant number 4
value_stay[t+1] = 1.0 value_stay[t] 
value_exit[t+1] = 1.0 value_exit[t] 
Participant number 7
value_stay[t+1] = 1.0 value_stay[t] 
value_exit[t+1] = 1.0 value_exit[t] 
Participant number 12
value_stay[t+1] = 1.0 value_stay[t] 
value_exit[t+1] = 1.0 value_exit[t] 
Participant number 13
value_stay[t+1] = 1.0 value_stay[t] 
value_exit[t+1] = 1.0 value_exit[t] 
Participant number 14
value_stay[t+1] = 1.0 value_stay[t] 
value_exit[t+1] = 1.0 value_exit[t] 
Participant number 20
value_stay[t+1] = 0.875 value_stay[t] + -0.062 harvest_duration 
value_exit[t+1] = 1.0 value_exit[t] 
Participant number 23
value_stay[t+1] = 1.0 value_stay[t] 
value_exit[t+1] = 1.0 value_exit[t] 
Participant number 24
value_stay[t+1] = 1.0 value_stay[t] 
value_exit[t+1] = 1.0 value_exit[t] 
Participant number 25
value_stay[t+1] = 1.0 value_stay[t] 
value_exit[t+1] = 1.0 value_exit[t] 


In [9]:
print('UNDERHARVESTERS') 
for p in underharvesters:
    print('Participant number', p)
    estimator.print_spice_model(participant_id=p)

UNDERHARVESTERS
Participant number 0
value_stay[t+1] = -0.16 1 + 0.379 value_stay[t] + -0.16 harvest_duration 
value_exit[t+1] = 1.0 value_exit[t] 
Participant number 2
value_stay[t+1] = 0.239 1 + 0.273 value_stay[t] + 0.24 harvest_duration 
value_exit[t+1] = 1.0 value_exit[t] 
Participant number 3
value_stay[t+1] = 1.0 value_stay[t] 
value_exit[t+1] = 1.0 value_exit[t] 
Participant number 5
value_stay[t+1] = 1.0 value_stay[t] 
value_exit[t+1] = 1.0 value_exit[t] 
Participant number 6
value_stay[t+1] = -0.237 1 + 0.042 value_stay[t] + -0.237 harvest_duration 
value_exit[t+1] = 1.0 value_exit[t] 
Participant number 8
value_stay[t+1] = 1.0 value_stay[t] 
value_exit[t+1] = 1.0 value_exit[t] 
Participant number 9
value_stay[t+1] = -0.13 1 + 0.628 value_stay[t] + -0.129 harvest_duration 
value_exit[t+1] = 1.0 value_exit[t] 
Participant number 10
value_stay[t+1] = 0.137 1 + 0.569 value_stay[t] + 0.138 harvest_duration 
value_exit[t+1] = 1.0 value_exit[t] 
Participant number 11
value_stay[t+1

## Benchmarking

### MVT Model by Constantino et al. (2015)

In [10]:
import sys

sys.path.append("../..")
from weinhardt2025.benchmarking.benchmarking_bustamante2023 import MarginalValueTheoremModel

mvt = MarginalValueTheoremModel(
    n_participants=n_participants,
    depletion=None,  # if None: learn value; else: fix to given value;
    baseline_gain=None,  # if None: learn value; else: fix to given value;
    )

In [11]:
# benchmark training
epochs = 1000
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=mvt.parameters(), lr=0.01)

for epoch in range(epochs):
    
    random_index = torch.randint(len(dataset.xs), (len(dataset.xs), 1))[:, 0]
    
    mask = ~torch.isnan(dataset.xs[random_index, :, 0]).reshape(-1)
    
    logits, state = mvt(inputs=dataset.xs[random_index], batch_first=True)
    
    # compute loss
    loss = criterion(
        logits.reshape(-1, mvt.n_actions)[mask],
        dataset.ys.argmax(dim=-1, keepdim=True).long().reshape(-1)[mask], 
        )
    
    # backprop
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print(f"Epoch {epoch+1}/{epochs} --- Loss: {loss.item():.5f}")

Epoch 1/1000 --- Loss: 0.72331
Epoch 2/1000 --- Loss: 0.71865
Epoch 3/1000 --- Loss: 0.71365
Epoch 4/1000 --- Loss: 0.70874
Epoch 5/1000 --- Loss: 0.70381
Epoch 6/1000 --- Loss: 0.69908
Epoch 7/1000 --- Loss: 0.69476
Epoch 8/1000 --- Loss: 0.69014
Epoch 9/1000 --- Loss: 0.68553
Epoch 10/1000 --- Loss: 0.68091
Epoch 11/1000 --- Loss: 0.67673
Epoch 12/1000 --- Loss: 0.67237
Epoch 13/1000 --- Loss: 0.66785
Epoch 14/1000 --- Loss: 0.66333
Epoch 15/1000 --- Loss: 0.65945
Epoch 16/1000 --- Loss: 0.65527
Epoch 17/1000 --- Loss: 0.65087
Epoch 18/1000 --- Loss: 0.64707
Epoch 19/1000 --- Loss: 0.64334
Epoch 20/1000 --- Loss: 0.63907
Epoch 21/1000 --- Loss: 0.63501
Epoch 22/1000 --- Loss: 0.63115
Epoch 23/1000 --- Loss: 0.62727
Epoch 24/1000 --- Loss: 0.62378
Epoch 25/1000 --- Loss: 0.61989
Epoch 26/1000 --- Loss: 0.61617
Epoch 27/1000 --- Loss: 0.61258
Epoch 28/1000 --- Loss: 0.60865
Epoch 29/1000 --- Loss: 0.60535
Epoch 30/1000 --- Loss: 0.60167
Epoch 31/1000 --- Loss: 0.59840
Epoch 32/1000 ---

In [12]:
print("Fitted parameters:")
print("\nAlpha")
print(mvt.alpha_env)
print("\nBeta")
print(mvt.beta)
print("\nC")
print(mvt.c)
print("\nBaseline Gain")
print(mvt.baseline_gain)
print("\nDepletion")
print(mvt.depletion)

Fitted parameters:

Alpha
tensor([0.1216, 0.0686, 0.1614, 0.0934, 0.0554, 0.0959, 0.0936, 0.0669, 0.5242,
        0.1769, 0.2570, 0.1074, 0.0808, 0.2032, 0.0811, 0.2943, 0.1229, 0.0838,
        0.1248, 0.0769, 0.1100, 0.1705, 0.2785, 0.0743, 0.0964, 0.0763, 0.0746,
        0.0816, 0.1129, 0.2516, 0.1096, 0.1177, 0.0824, 0.5659, 0.0924, 0.0818,
        0.0265, 0.1639, 0.1019, 0.1195, 0.0851, 0.4803, 0.1283, 0.1025, 0.0625,
        0.0579, 0.0807, 0.4586, 0.1150, 0.4680, 0.0829, 0.0659, 0.1560, 0.1225,
        0.0914, 0.1224, 0.1121, 0.1102, 0.2327, 0.0417, 0.1890, 0.0930, 0.0874,
        0.1368, 0.1928, 0.2621, 0.0763, 0.1173, 0.0676, 0.0788, 0.1362, 0.0498,
        0.0698, 0.4696, 0.0845, 0.5709, 0.1174, 0.0819, 0.0799, 0.5079, 0.0850,
        0.0955, 0.0629, 0.2425, 0.0798, 0.1497, 0.0971, 0.0971, 0.3654, 0.0969,
        0.0742, 0.1245, 0.0737, 0.0723, 0.1288, 0.0846, 0.1343, 0.0518, 0.0962,
        0.0699, 0.2615, 0.0785, 0.1222, 0.0757, 0.0676, 0.0836, 0.1100, 0.1044,
        0.2302

### GRU Model

In [13]:
class GRU(torch.nn.Module):
    
    def __init__(self, input_size, n_actions):
        super().__init__()
        
        self.input_size = input_size
        self.gru_features = 32
        self.n_actions = n_actions
        
        self.linear_in = torch.nn.Linear(in_features=input_size, out_features=self.gru_features)
        self.gru = torch.nn.GRU(input_size=self.gru_features,hidden_size=n_actions, batch_first=True)
        self.linear_out = torch.nn.Linear(in_features=n_actions, out_features=n_actions)
        
    def forward(self, inputs):
        
        y = self.linear_in(inputs.nan_to_num(0))
        y, _ = self.gru(y)
        logits = self.linear_out(y)
        
        return logits

In [14]:
epochs = 1000

gru = GRU(dataset.xs.shape[-1], 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(gru.parameters(), lr=0.01)

for epoch in range(epochs):
    
    random_index = torch.randint(len(dataset.xs), (len(dataset.xs), 1))[:, 0]
    logits, state = mvt(inputs=dataset.xs[random_index], batch_first=True)
    
    mask = ~torch.isnan(dataset.xs[random_index, :, 0]).reshape(-1)
    
    # compute loss
    loss = criterion(
        logits.reshape(-1, mvt.n_actions)[mask],
        dataset.ys.argmax(dim=-1, keepdim=True).long().reshape(-1)[mask], 
        )
    
    # backprop
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print(f"Epoch {epoch+1}/{epochs} --- Loss: {loss.item():.5f}")

Epoch 1/1000 --- Loss: 0.34180
Epoch 2/1000 --- Loss: 0.34218
Epoch 3/1000 --- Loss: 0.34284
Epoch 4/1000 --- Loss: 0.34271
Epoch 5/1000 --- Loss: 0.34323
Epoch 6/1000 --- Loss: 0.34310
Epoch 7/1000 --- Loss: 0.34271
Epoch 8/1000 --- Loss: 0.34328
Epoch 9/1000 --- Loss: 0.34264
Epoch 10/1000 --- Loss: 0.34249
Epoch 11/1000 --- Loss: 0.34319
Epoch 12/1000 --- Loss: 0.34258
Epoch 13/1000 --- Loss: 0.34264
Epoch 14/1000 --- Loss: 0.34278
Epoch 15/1000 --- Loss: 0.34253
Epoch 16/1000 --- Loss: 0.34292
Epoch 17/1000 --- Loss: 0.34200
Epoch 18/1000 --- Loss: 0.34345
Epoch 19/1000 --- Loss: 0.34315
Epoch 20/1000 --- Loss: 0.34302
Epoch 21/1000 --- Loss: 0.34299
Epoch 22/1000 --- Loss: 0.34327
Epoch 23/1000 --- Loss: 0.34208
Epoch 24/1000 --- Loss: 0.34224
Epoch 25/1000 --- Loss: 0.34330
Epoch 26/1000 --- Loss: 0.34241
Epoch 27/1000 --- Loss: 0.34289
Epoch 28/1000 --- Loss: 0.34309
Epoch 29/1000 --- Loss: 0.34224
Epoch 30/1000 --- Loss: 0.34246
Epoch 31/1000 --- Loss: 0.34273
Epoch 32/1000 ---