In [1]:
#!git clone https://github.com/whyhardt/SPICE.git

In [1]:
# !pip install -e SPICE

[31mERROR: SPICE is not a valid editable requirement. It should either be a path to a local project or a VCS URL (beginning with bzr+http, bzr+https, bzr+ssh, bzr+sftp, bzr+ftp, bzr+lp, bzr+file, git+http, git+https, git+ssh, git+git, git+file, hg+file, hg+http, hg+https, hg+ssh, hg+static-http, svn+ssh, svn+http, svn+https, svn+svn, svn+file).[0m[31m
[0m

In [1]:
import pandas as pd
import numpy as np
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt

from spice.estimator import SpiceEstimator
from spice.resources.spice_utils import SpiceConfig
from spice.utils.convert_dataset import convert_dataset
from spice.resources.rnn import BaseRNN

# For custom RNN
import torch
import torch.nn as nn

  from pkg_resources import DistributionNotFound


Let's load the data first with the `convert_dataset` method. This method returns a `SpiceDataset` object which we can use right away 

In [5]:
import pandas as pd

df = pd.read_csv('../data/bustamante2023/bustamante2023.csv')
df['decision'] = df['decision'].replace({'stay': 0, 'exit': 1})
df['decision_duration'] = df['decision'].replace({'stay': 2, 'exit': 8.333})
df['harvest_duration'] = (df['choice_duration'] + df['harvest_duration'])/1000
df['travel_duration'] = df['travel_duration']/1000
df.to_csv('../data/bustamante2023/bustamante2023_processed.csv', index=False)

In [6]:
# Load your data
dataset = convert_dataset(
    file = '../data/bustamante2023/bustamante2023_processed.csv',
    df_participant_id='subject_id',
    df_choice='decision',
    df_reward='reward',
    df_block='overall_round',
    additional_inputs=['decision_duration'],
    timeshift_additional_inputs=False,
    )

# structure of dataset:
# dataset has two main attributes: xs -> inputs; ys -> targets (next action)
# shape: (n_participants*n_blocks*n_experiments, n_timesteps, features)
# features are (n_actions * action, n_actions * reward, n_additional_inputs * additional_input, block_number, experiment_id, participant_id)

# in order to set up the participant embedding we have to compute the number of unique participants in our data 
# to get the number of participants n_participants we do:
n_participants = len(dataset.xs[..., -1].unique())

print(f"Shape of dataset: {dataset.xs.shape}")
print(f"Number of participants: {n_participants}")
n_actions = dataset.ys.shape[-1]
print(f"Number of actions in dataset: {n_actions}")
print(f"Number of additional inputs: {dataset.xs.shape[-1]-2*n_actions-3}")

Shape of dataset: torch.Size([4296, 112, 8])
Number of participants: 537
Number of actions in dataset: 2
Number of additional inputs: 1


In [7]:
dataset.xs[0, :10, :]

tensor([[1.0000, 0.0000, 0.7286,    nan, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.6610,    nan, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000,    nan, 0.0492, 1.0000, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.7242,    nan, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.6174,    nan, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.5456,    nan, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000,    nan, 0.0492, 1.0000, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.8390,    nan, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.6964,    nan, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000,    nan, 0.0492, 1.0000, 0.0000, 0.0000, 0.0000]])

Now we are going to define the configuration for SPICE with a `SpiceConfig` object.

The `SpiceConfig` takes as arguments 
1. `library_setup (dict)`: Defining the variable names of each module.
2. `memory_state (dict)`: Defining the memory state variables and their initial values.
3. `states_in_logit (list)`: Defining which of the memory state variables are used later for the logit computation. This is necessary for some background processes.  

In [8]:
spice_config = SpiceConfig(
    library_setup={
        'value_stay': ['reward', 'decision_duration', 'value_stay'],
        'value_exit': ['reward', 'decision_duration', 'value_exit'],
    },
    
    memory_state={
            'value': 0.5,
        }
)

And now we are going to define the SPICE model which is a child of the `BaseRNN` and `torch.nn.Module` class and takes as required arguments:
1. `spice_config (SpiceConfig)`: previously defined SpiceConfig object
2. `n_actions (int)`: number of possible actions in your dataset (including non-displayed ones if applicable).
3. `n_participants (int)`: number of participants in your dataset.

As usual for a `torch.nn.Module` we have to define at least the `__init__` method and the `forward` method.
The `forward` method gets called when computing a forward pass through the model and takes as inputs `(inputs (SpiceDataset.xs), prev_state (dict, default: None), batch_first (bool, default: False))` and returns `(logits (torch.Tensor, shape: (n_participants*n_blocks*n_experiments, timesteps, n_actions)), updated_state (dict))`. Two necessary method calls inside the forward pass are:
1. `self.init_forward_pass(inputs, prev_state, batch_first) -> SpiceSignals`: returns a `SpiceSignals` object which carries all relevant information already processed.
2. `self.post_forward_pass(SpiceSignals, batch_first) -> SpiceSignals`: does some re-arranging of the logits to adhere to `batch_first`.

In [9]:
class SPICERNN(BaseRNN):
    
    def __init__(self, spice_config, n_participants, **kwargs):
        super().__init__(n_actions=2, spice_config=spice_config, n_participants=n_participants, embedding_size=32)
        
        # participant embedding
        self.participant_embedding = self.setup_embedding(num_embeddings=n_participants, embedding_size=self.embedding_size, dropout=0.)
        
        # scaling factor (inverse noise temperature) for each participant for the values which are handled by an hard-coded equation
        self.betas['value'] = self.setup_constant(embedding_size=self.embedding_size)
        
        # set up the submodules
        self.submodules_rnn['value_stay'] = self.setup_module(input_size=3+self.embedding_size)
        self.submodules_rnn['value_exit'] = self.setup_module(input_size=3+self.embedding_size)
        
    def forward(self, inputs, prev_state, batch_first=False):
        
        spice_signals = self.init_forward_pass(inputs, prev_state, batch_first)
        
        decision_duration = spice_signals.additional_inputs[..., 0].unsqueeze(-1).repeat(1, 1, self.n_actions)
        decision_duration = torch.where(decision_duration==0, 0.240, decision_duration)
        rewards_chosen = (spice_signals.actions * spice_signals.rewards.nan_to_num(0)).sum(dim=-1, keepdim=True).repeat(1, 1, self.n_actions)
        
        # time-invariant participant features
        participant_embeddings = self.participant_embedding(spice_signals.participant_ids)
        beta_reward = self.betas['value'](participant_embeddings)
        mask_stay = torch.tensor((1,0)).reshape(1, 1, self.n_actions).repeat(rewards_chosen.shape[0], rewards_chosen.shape[1], 1)
        mask_exit = torch.tensor((0,1)).reshape(1, 1, self.n_actions).repeat(rewards_chosen.shape[0], rewards_chosen.shape[1], 1)
        
        for timestep in spice_signals.timesteps:
            
            value_stay = self.state['value'][..., 0][:, None].repeat(1, 2)
            value_exit = self.state['value'][..., 1][:, None].repeat(1, 2)
            
            # update chosen value
            self.call_module(
                key_module='value_stay',
                key_state='value',
                action_mask=mask_stay[timestep],
                inputs=(rewards_chosen[timestep], decision_duration[timestep], value_exit),
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
                activation_rnn=torch.nn.functional.sigmoid,
            )
            
            # update not chosen value
            self.call_module(
                key_module='value_exit',
                key_state='value',
                action_mask=mask_exit[timestep],
                inputs=(rewards_chosen[timestep], decision_duration[timestep], value_stay),
                participant_index=spice_signals.participant_ids,
                participant_embedding=participant_embeddings,
                activation_rnn=torch.nn.functional.sigmoid,
            )
            
            # transform logits from item-space to action-space
            spice_signals.logits[timestep] = self.state['value'] * beta_reward
            
        spice_signals = self.post_forward_pass(spice_signals, batch_first)
        
        return spice_signals.logits, self.get_state()

Let's setup now the `SpiceEstimator` object and fit it to the data!

In [15]:
estimator = SpiceEstimator(
        # model paramaeters
        rnn_class=SPICERNN,
        spice_config=spice_config,
        n_actions=2,
        n_participants=n_participants,
        n_experiments=1,
        
        # training parameters
        epochs=1000,
        l2_weight_decay=0.01,
        sindy_epochs=5000, # could also be 1000,
        sindy_threshold=0.1,
        sindy_threshold_frequency=100,
        sindy_weight=0.01,
        sindy_library_polynomial_degree=2,
        verbose=True,
        save_path_spice='../params/bustamante2023/spice_bustamante2023.pkl',
    )

print(f"\nStarting training on {estimator.device}...")
print("=" * 80)
estimator.fit(dataset.xs, dataset.ys)
# estimator.load_spice(args.model)
print("=" * 80)
print("\nTraining complete!")

# Print example SPICE model for first participant
print("\nExample SPICE model (participant 0):")
print("-" * 80)
estimator.print_spice_model(participant_id=0)
print("-" * 80)


Starting training on cpu...

Training the RNN...
Epoch 599/1000 --- L(Train): 0.3720970; Time: 2.77s; Convergence: 2.06e-03
SPICE model before 600 epochs:
value_stay[t+1] = 0.0751 1 + 0.8881 value_stay[t] + 0.0253 reward + -0.0135 decision_duration + 1.0612 value_stay[t] + -0.1852 value_stay^2 + -0.1907 value_stay*reward + -0.1897 value_stay*decision_duration + -0.0922 value_stay*value_stay + 0.2586 reward^2 + 0.0452 reward*decision_duration + 0.0211 reward*value_stay + 0.2366 decision_duration^2 + -0.0123 decision_duration*value_stay + 0.0403 value_stay^2 
value_exit[t+1] = 0.0056 1 + 0.9997 value_exit[t] + 0.005 decision_duration + 1.0058 value_exit[t] + -0.0108 value_exit^2 + -0.0123 value_exit*reward + -0.0129 value_exit*value_exit + 0.0039 reward^2 + -0.0014 reward*decision_duration + 0.0101 reward*value_exit + -0.0038 decision_duration^2 + -0.0095 value_exit^2 
beta(value) = 24.9759

SPICE model after 600 epochs:
value_stay[t+1] = 0.0751 1 + 0.8881 value_stay[t] + 0.0253 reward 

In [43]:
df['last_reward'] = df.groupby('subject_id')['last_reward'].fillna(method='bfill')
df['participant_id'] = pd.factorize(df['subject_id'])[0]
exit_df = df[df['decision'] == 1]
mean_exit_threshold = exit_df.groupby(['participant_id', 'subject_id'])['last_reward'].mean().reset_index()
mean_exit_threshold['relative_optimal'] = mean_exit_threshold['last_reward'] - 6.78 #from Bustamante et al. Table S7, experiment 1
mean_exit_threshold['over_harvester'] = np.where(mean_exit_threshold['relative_optimal'] <= 0, 1, 0)
print(mean_exit_threshold)
overharvesters = mean_exit_threshold[mean_exit_threshold['over_harvester'] == 1]['participant_id'].unique()
underharvesters = mean_exit_threshold[mean_exit_threshold['over_harvester'] == 0]['participant_id'].unique()


     participant_id                subject_id  last_reward  relative_optimal  \
0                 0  08aiu2bm6t15qij5826jxz50     8.241283          1.461283   
1                 1  09j932f828pn7h7bozp9mpnl     5.344722         -1.435278   
2                 2  0ax9htcbhfi3ncsospqzwjx2     8.945571          2.165571   
3                 3  0e6zivqly335lojgb4c6606t     7.554696          0.774696   
4                 4  0fawro1pivqnh4lem4ayf4o0     3.089682         -3.690318   
..              ...                       ...          ...               ...   
532             532  zevrdzmkwvzsihx3hwzpudkr     9.071613          2.291613   
533             533  zhgkeunzm8bshh2zg2kywfjv    10.476446          3.696446   
534             534  zhyb93wwd8tvuvagdkrfh108     9.012543          2.232543   
535             535  zn94ngriwmlw6fdmvx1onez4     5.058637         -1.721363   
536             536  zpsyqyrcmklbww1o9f2r5hku     6.024730         -0.755270   

     over_harvester  
0                

In [45]:
print('OVERHARVESTERS') 
for p in overharvesters:
    print('Participant number', p)
    estimator.print_spice_model(participant_id=p)

OVERHARVESTERS
Participant number 1
value_stay[t+1] = 0.188 1 + 0.7851 value_stay[t] + 0.3444 reward + 0.209 decision_duration + 1.0 value_stay[t] + -0.3207 value_stay*decision_duration + -0.4346 value_stay*value_stay + -0.2226 reward^2 + 0.3536 reward*decision_duration + 0.1708 decision_duration^2 + -0.331 value_stay^2 
value_exit[t+1] = 0
beta(value) = 46.0193
Participant number 4
value_stay[t+1] = 0.1576 1 + 0.8503 value_stay[t] + 0.4281 reward + 0.1835 decision_duration + 1.0 value_stay[t] + -0.2565 value_stay*reward + -0.406 value_stay*decision_duration + -0.3052 value_stay*value_stay + -0.2531 reward^2 + 0.4188 reward*decision_duration + 0.1683 decision_duration^2 + -0.2212 value_stay^2 
value_exit[t+1] = 0
beta(value) = 41.1108
Participant number 7
value_stay[t+1] = 0.1782 1 + 0.8371 value_stay[t] + 0.4067 reward + 0.19 decision_duration + 1.0 value_stay[t] + -0.1224 value_stay^2 + -0.2607 value_stay*decision_duration + -0.3784 value_stay*value_stay + -0.3138 reward^2 + 0.4092 r

In [46]:
print('UNDERHARVESTERS') 
for p in underharvesters:
    print('Participant number', p)
    estimator.print_spice_model(participant_id=p)

UNDERHARVESTERS
Participant number 0
value_stay[t+1] = 0.2411 1 + 0.7746 value_stay[t] + 0.239 reward + 0.2075 decision_duration + 1.0 value_stay[t] + -0.2934 value_stay*decision_duration + -0.4747 value_stay*value_stay + 0.241 reward*decision_duration + -0.172 reward*value_stay + 0.1384 decision_duration^2 + -0.339 value_stay^2 
value_exit[t+1] = 0
beta(value) = 27.7513
Participant number 2
value_stay[t+1] = 0.254 1 + 1.0 value_stay[t] + -0.2787 reward + 1.0 value_stay[t] + -0.2553 value_stay^2 + -0.2022 value_stay*reward + -0.1779 value_stay*value_stay + 0.6467 reward^2 + -0.4428 reward*value_stay 
value_exit[t+1] = 0
beta(value) = 66.9040
Participant number 3
value_stay[t+1] = 0.1424 1 + 1.0 value_stay[t] + 0.1433 decision_duration + 1.0 value_stay[t] + -0.2246 value_stay^2 + -0.2348 value_stay*decision_duration + -0.3834 value_stay*value_stay + 0.2932 reward^2 + -0.2295 reward*value_stay + 0.1167 decision_duration^2 
value_exit[t+1] = 0
beta(value) = 63.2399
Participant number 5
va