# Data Preparation Tutorial

This tutorial demonstrates how to prepare and process data for use with SPICE. We'll cover:

1. Basic data format requirements
2. Converting raw data to SPICE format
3. Creating synthetic datasets
4. Working with different data types
5. Common data preprocessing steps

## Prerequisites

Before starting this tutorial, make sure you have:
- SPICE installed
- Required dependencies (pandas, numpy, etc.)
- A basic understanding of how reinforcement learning data is structured (sessions, trials, choices, rewards)

In [8]:
import sys
import os
import pandas as pd
import numpy as np
from spice.resources.bandits import create_dataset, BanditsDrift, get_update_dynamics
from spice.resources.rnn_utils import DatasetRNN
from spice.utils.plotting import plot_session

## 1. Basic Data Format Requirements

SPICE expects data in a specific format for training and analysis. The basic requirements are:

- Data should be in CSV format
- Column names can be customized by setting `df_participant_id`, `df_block`, `df_experiment_id`, `df_choice` and `df_reward`.
- Additional inputs can be given as a list of strings (`additional_inputs`) corresponding to column names
- Required columns:
  - `df_participant_id` (default: `'session'`): Unique identifier for each experimental session/participant
  - `df_choice` (default: `'choice'`): The action taken by the participant (0-indexed)
  - `df_reward` (default: `'reward'`): The reward received for the action

Let's look at an example of properly formatted data:

In [9]:
# Create a sample dataset
sample_data = {
    'session': [1, 1, 1, 2, 2, 2],
    'choice': [0, 1, 0, 1, 0, 1],
    'reward': [1, 0, 1, 0, 1, 0],
    'rt': [0.5, 0.6, 0.4, 0.7, 0.5, 0.6]
}

df = pd.DataFrame(sample_data)
print("Sample data format:")
display(df)

Sample data format:


| | session | choice | reward | rt |
|---|---|---|---|---|
| 0 | 1 | 0 | 1 | 0.5 |
| 1 | 1 | 1 | 0 | 0.6 |
| 2 | 1 | 0 | 1 | 0.4 |
| 3 | 2 | 1 | 0 | 0.7 |
| 4 | 2 | 0 | 1 | 0.5 |
| 5 | 2 | 1 | 0 | 0.6 |


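Let's save it as a CSV file. Passing `index=False` keeps the DataFrame index out of the file, so no extra unnamed column appears when the file is read back.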

In [10]:
df.to_csv('sample_data.csv', index=False)
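
Before handing a file to SPICE, it can help to check it against the requirements listed above. The sketch below is a minimal example using only pandas. The helper `check_format` is hypothetical and not part of SPICE, and the 0-index check is a heuristic (it only looks at the smallest choice value).

```python
import pandas as pd


def check_format(df, participant_col="session", choice_col="choice", reward_col="reward"):
    """Hypothetical helper (not part of SPICE): return a list of format problems."""
    required = [participant_col, choice_col, reward_col]

    # All required columns must be present before anything else can be checked.
    missing = [col for col in required if col not in df.columns]
    if missing:
        return [f"missing columns: {missing}"]

    problems = []
    if df[required].isna().any().any():
        problems.append("required columns contain missing values")

    choices = df[choice_col].dropna()
    if not pd.api.types.is_numeric_dtype(choices):
        problems.append("choices are not numeric")
    elif not (choices == choices.astype(int)).all():
        problems.append("choices are not whole numbers")
    elif choices.min() != 0:
        # Heuristic: 0-indexed actions should normally include action 0.
        problems.append(f"smallest choice is {choices.min()}, expected 0")
    return problems


good = pd.DataFrame({"session": [1, 1, 2], "choice": [0, 1, 0], "reward": [1, 0, 1]})
one_indexed = good.assign(choice=good["choice"] + 1)

print(check_format(good))         # no problems found
print(check_format(one_indexed))  # flags the 1-indexed choices
```

An empty list means none of these checks failed. It does not guarantee that SPICE will accept the file.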

## 2. Converting Experimental Data to SPICE Format

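Often, your raw data might not be in the exact format SPICE expects. `convert_dataset` reads a CSV file with the columns described above and returns, among other things, a dataset object whose tensors can be inspected via `dataset.xs` and `dataset.ys`: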

In [11]:
from spice.utils.convert_dataset import convert_dataset

dataset, experiment_list, df, dynamics = convert_dataset(file='sample_data.csv')
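
If the raw file uses different column names or 1-indexed actions, one option is to normalize it with pandas before calling `convert_dataset`. The sketch below assumes a hypothetical raw export with columns `subject`, `action` and `outcome`; these names and values are made up for illustration. The `df_*` settings described in section 1 are an alternative to renaming.

```python
import pandas as pd

# Hypothetical raw export: column names and 1-indexed actions are illustrative only.
raw = pd.DataFrame({
    "subject": [101, 101, 102, 102],
    "action": [1, 2, 2, 1],
    "outcome": [1, 0, 0, 1],
})

# Map the raw column names onto the default names from section 1.
tidy = raw.rename(columns={"subject": "session", "action": "choice", "outcome": "reward"})

# Shift the actions so that the smallest action becomes 0 (choices are 0-indexed).
tidy["choice"] = tidy["choice"] - tidy["choice"].min()

print(tidy)
```

The resulting DataFrame can then be written with `to_csv(..., index=False)` as in section 1.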

In [12]:
# Jupyter displays only the last expression of a cell,
# so the output below is dataset.ys.shape.
dataset.xs.shape
dataset.ys.shape

torch.Size([2, 3, 2])
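
The shape above has three axes. A plausible reading, assumed here rather than taken from the SPICE documentation, is (sessions, trials, actions): the sample file has 2 sessions of 3 trials each and 2 possible actions. The NumPy sketch below builds a one-hot array with that layout from the sample choices. It only illustrates the layout and is not how `convert_dataset` is implemented.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    "session": [1, 1, 1, 2, 2, 2],
    "choice": [0, 1, 0, 1, 0, 1],
})

n_actions = int(df["choice"].max()) + 1

# One (n_trials, n_actions) one-hot block per session, stacked along a new first axis.
# Assumes every session has the same number of trials.
one_hot = np.stack([
    np.eye(n_actions)[group["choice"].to_numpy()]
    for _, group in df.groupby("session", sort=False)
])

print(one_hot.shape)  # (2, 3, 2)
```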

## 3. Creating Synthetic Datasets

SPICE provides utilities to create synthetic datasets for testing and validation. Here's how to create a synthetic dataset using a simple bandit task:

In [13]:
from spice.resources.bandits import AgentQ

# Create a simple Q-learning agent
agent = AgentQ(
    beta_reward=1.0,
    alpha_reward=0.5,
    alpha_penalty=0.5
)

# Create environment
environment = BanditsDrift(sigma=0.2)

# Generate synthetic data
n_sessions = 2
n_trials = 10

dataset, experiments, _ = create_dataset(
    agent=agent,
    environment=environment,
    n_trials=n_trials,
    n_sessions=n_sessions,
    verbose=False
)

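# Convert to DataFrame, one block of rows per session
synthetic_data = []
for i in range(len(dataset)):
    experiment = dataset.xs[i].numpy()
    session_data = pd.DataFrame({
        'session': [i] * n_trials,
        # columns 0-1 are read as a one-hot choice; argmax recovers the action index
        'choice': np.argmax(experiment[:, :2], axis=1),
        # columns 2-3 are read as per-action rewards; the row-wise maximum is used as the trial's reward
        'reward': np.max(experiment[:, 2:4], axis=1)
    })
    synthetic_data.append(session_data)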

synthetic_df = pd.concat(synthetic_data, ignore_index=True)
print("Synthetic dataset:")
display(synthetic_df)

Creating dataset...


100%|██████████| 2/2 [00:00<00:00, 981.24it/s]

Synthetic dataset:

Unnamed: 0,session,choice,reward
0,0,1,1.0
1,0,0,1.0
2,0,0,0.0
3,0,0,1.0
4,0,1,1.0
5,0,1,0.0
6,0,0,0.0
7,0,1,0.0
8,0,0,1.0
9,0,0,0.0
