In [None]:
!git clone --branch sindy-reg --single-branch https://github.com/whyhardt/SPICE.git

Cloning into 'SPICE'...
remote: Enumerating objects: 763, done.[K
remote: Counting objects: 100% (272/272), done.[K
remote: Compressing objects: 100% (165/165), done.[K
remote: Total 763 (delta 162), reused 190 (delta 106), pack-reused 491 (from 1)[K
Receiving objects: 100% (763/763), 20.06 MiB | 17.04 MiB/s, done.
Resolving deltas: 100% (411/411), done.


In [None]:
!pip install -e SPICE

Obtaining file:///content/SPICE
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Collecting arviz<0.21.0,>=0.20.0 (from autospice==0.2.0)
  Downloading arviz-0.20.0-py3-none-any.whl.metadata (8.8 kB)
Collecting jax<0.5.0,>=0.4.35 (from autospice==0.2.0)
  Downloading jax-0.4.38-py3-none-any.whl.metadata (22 kB)
Collecting numpy<2.0.0,>=1.21.0 (from autospice==0.2.0)
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting numpyro<0.16.0,>=0.15.0 (from autospice==0.2.0)
  Downloading numpyro-0.15.3-py3-none-any.whl.metadata (36 kB)
Collecting pyro_ppl<2.0.0,>=1.9.0 (from autospice==0.2.0)
  Downloading p

In [None]:
import pandas as pd
import numpy as np
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt

from spice.estimator import SpiceEstimator
from spice.utils.convert_dataset import convert_dataset
from spice.resources.rnn import BaseRNN

# For custom RNN
import torch
import torch.nn as nn

In [14]:
import torch
import torch.nn as nn

class HoverRNN(BaseRNN):
    """
    SPICE RNN for hover behaviour over a set of tiles (one action per tile).

    Memory state: a single value vector 'x_value_tile' with one entry per tile.
    On every trial it is rebuilt from three masked submodule updates:

    - 'x_update_tile_visited_self':    the tile chosen by the participant,
    - 'x_update_tile_visited_partner': the tile chosen by the partner,
    - 'x_update_tile_not_visited':     every tile chosen by neither of them.

    The logits for each trial are the updated tile values scaled by a
    participant-specific beta computed from the participant embedding.
    """

    # Initial value of every entry of the per-tile value vector.
    init_values = {
        'x_value_tile': 0,
    }

    def __init__(self, n_actions, n_participants, **kwargs):
        """
        Args:
            n_actions: Number of tiles (one action per tile).
            n_participants: Number of rows of the participant-embedding table.
            **kwargs: Ignored.
        """
        super().__init__(n_actions=n_actions, n_participants=n_participants, embedding_size=32)

        self.sindy_polynomial_degree = 2

        self.participant_embedding = self.setup_embedding(
            n_participants, self.embedding_size, dropout=0.
        )

        # Participant-specific scaling of the tile values used in forward().
        # Without this entry, `self.betas['x_value_tile']` raises KeyError.
        # NOTE(review): uses BaseRNN.setup_constant as in the SPICE examples --
        # confirm it exists with this signature on the `sindy-reg` branch.
        self.betas['x_value_tile'] = self.setup_constant(embedding_size=self.embedding_size)

        # Value-update modules. They receive no external inputs (inputs=None in
        # forward), so their input size is just the participant embedding.
        self.submodules_rnn['x_update_tile_visited_self'] = self.setup_module(input_size=self.embedding_size)
        self.submodules_rnn['x_update_tile_visited_partner'] = self.setup_module(input_size=self.embedding_size)
        self.submodules_rnn['x_update_tile_not_visited'] = self.setup_module(input_size=self.embedding_size)

        self.setup_sindy_coefficients(polynomial_degree=self.sindy_polynomial_degree)

    def _one_hot_tiles(self, tile_indexes):
        """
        One-hot encode tile indices over `self._n_actions` tiles.

        Args:
            tile_indexes: Tensor of (possibly float-encoded) tile indices of any shape.

        Returns:
            float32 tensor of shape `tile_indexes.shape + (n_actions,)`. Entries
            outside [0, n_actions) -- including NaN -- yield an all-zero row, so a
            missing/padded partner choice updates no tile.
        """
        valid = (tile_indexes >= 0) & (tile_indexes < self._n_actions)
        safe_indexes = torch.where(valid, tile_indexes, torch.zeros_like(tile_indexes)).long()
        one_hot = torch.nn.functional.one_hot(safe_indexes, num_classes=self._n_actions)
        return (one_hot * valid.unsqueeze(-1)).float()

    def _update_tile_values(self, key_module, mask, participant_embedding, participant_ids):
        """Run one value-update submodule on 'x_value_tile', restricted to the tiles in `mask`.

        Returns whatever `self.call_module` returns: (next_value, sindy_loss).
        """
        return self.call_module(
            key_module=key_module,
            key_state='x_value_tile',
            action=mask,
            inputs=None,
            participant_embedding=participant_embedding,
            participant_index=participant_ids,
            activation_rnn=torch.nn.functional.sigmoid,
        )

    def forward(self, inputs, prev_state=None, batch_first=False):
        """
        Forward pass over a batch of trial sequences.

        Args:
            inputs: Model inputs as built from `convert_dataset`. They are split by
                `init_forward_pass` into (actions, rewards, additional_inputs, ...)
                and `ids`.
            prev_state: Optional previous memory state.
            batch_first: Whether the first dimension of `inputs` is the batch
                (True) or the timestep (False).

        Returns:
            logits: Per-tile logits for every timestep (after `post_forward_pass`).
            state: Memory state after the last timestep (`self.get_state()`).
        """
        input_variables, ids, logits, timesteps, sindy_loss_timesteps = self.init_forward_pass(inputs, prev_state, batch_first)
        actions, _, additional_inputs, _ = input_variables

        # NOTE(review): in the SPICE examples participant ids are ids[0] with shape
        # (batch, 1). The 4th entry of input_variables is indexed per timestep, so
        # it cannot serve as a per-batch participant index. Confirm against
        # init_forward_pass on the `sindy-reg` branch.
        participant_ids = ids[0]
        participant_embedding = self.participant_embedding(participant_ids[:, 0].int())

        # additional_inputs follows the list given to convert_dataset:
        # [partner_tile_index, sample_number]. The participant's own tile is
        # already one-hot encoded in `actions`.
        actions_partner = self._one_hot_tiles(additional_inputs[..., 0])

        for timestep, action, action_partner in zip(timesteps, actions, actions_partner):
            # Tiles chosen by neither player this trial (batched mask, same shape as `action`).
            visited = torch.clamp(action + action_partner, min=0, max=1)
            not_visited = 1 - visited

            # All three modules read the same (previous) 'x_value_tile' state; the
            # state is only overwritten after all of them have run.
            next_value_self, sindy_loss_self = self._update_tile_values(
                'x_update_tile_visited_self', action, participant_embedding, participant_ids,
            )
            next_value_partner, sindy_loss_partner = self._update_tile_values(
                'x_update_tile_visited_partner', action_partner, participant_embedding, participant_ids,
            )
            # A single masked call covers every unvisited tile, so this module's
            # SINDy loss is weighted like the other two (no per-tile rescaling).
            next_value_not_visited, sindy_loss_not_visited = self._update_tile_values(
                'x_update_tile_not_visited', not_visited, participant_embedding, participant_ids,
            )
            sindy_loss_timesteps[timestep] = (
                sindy_loss_timesteps[timestep] + sindy_loss_self + sindy_loss_partner + sindy_loss_not_visited
            )

            # NOTE(review): when self and partner choose the same tile, both
            # updates are added to that tile -- confirm this is intended.
            self.state['x_value_tile'] = next_value_self + next_value_partner + next_value_not_visited

            logits[timestep] = self.state['x_value_tile'] * self.betas['x_value_tile'](participant_embedding)

        # NOTE(review): the post-processed sindy_loss_timesteps is not returned;
        # confirm post_forward_pass stores it where the estimator reads it.
        logits, sindy_loss_timesteps = self.post_forward_pass(logits, sindy_loss_timesteps, batch_first)

        return logits, self.get_state()




In [None]:
from spice.estimator import SpiceConfig

# SPICE model structure for HoverRNN: the RNN submodule names (same names as
# registered in HoverRNN.__init__), control parameters, and per-module filter
# and library settings.
CONFIG = SpiceConfig(
    rnn_modules=[
        'x_update_tile_visited_self',
        'x_update_tile_visited_partner',
        'x_update_tile_not_visited',
    ],
    control_parameters=[
        'c_action_self',
        'c_action_partner',
    ],
    # NOTE(review): the not-visited module is filtered on c_action_self == 0
    # only, which also admits tiles visited by the partner -- confirm this matches
    # the intended "visited by neither" condition.
    filter_setup={
        'x_update_tile_visited_self': ['c_action_self', 1, True],
        'x_update_tile_visited_partner': ['c_action_partner', 1, True],
        'x_update_tile_not_visited': ['c_action_self', 0, True],
    },
    # Empty library lists for every module.
    library_setup={
        'x_update_tile_visited_self': [],
        'x_update_tile_visited_partner': [],
        'x_update_tile_not_visited': [],
    },
)

In [12]:
# Load your hover data
dataset = convert_dataset(
    file='SPICYCOLLAB.csv',
    df_participant_id = "subject_ID",
    df_block ='currentRound',
    df_choice = 'hover_tile_index',
    df_reward = 'score',
    additional_inputs = ['partner_tile_index', 'sample_number']
    )[0]

n_participants = len(dataset.xs[..., -1].unique())

In [15]:
# Where the fitted SPICE model is saved.
path_spice = 'spice.pkl'

estimator = SpiceEstimator(
    rnn_class=HoverRNN,
    spice_config=CONFIG,
    n_actions=16,  # one action per tile
    n_participants=n_participants,
    epochs=1,
    bagging=True,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    l1_weight_decay=0,
    learning_rate=1e-2,
    sindy_weight=0.1,
    spice_library_polynomial_degree=2,
    save_path_spice=path_spice,
    sindy_threshold_frequency=32,
    spice_optim_threshold=0.01,
)

print(f"\nStarting training on {estimator.device}...")
print("=" * 80)
# NOTE(review): the test split is the training data itself, so any reported
# test metrics are in-sample.
estimator.fit(dataset.xs, dataset.ys, data_test=dataset.xs, target_test=dataset.ys)
# Alternative to fit(): estimator.load_spice(<path of a saved model>)
print("=" * 80)
print("\nTraining complete!")

# Show the discovered SPICE equations for the first participant.
print("\nExample SPICE model (participant 0):")
print("-" * 80)
estimator.print_spice_model(participant_id=0)
print("-" * 80)
print("-" * 80)

KeyError: 'x_value_reward'