In [1]:
import os
import pandas as pd
import warnings
import numpy as np

import session_processing_helper as helper
import utils

# Silence only the pandas FutureWarning emitted when DataFrameGroupBy.apply
# operates on the grouping columns; other FutureWarnings are still shown.
warnings.filterwarnings(
    'ignore',
    category=FutureWarning,
    message='.*DataFrameGroupBy.apply operated on the grouping columns.*',
)

# Configuration
# The data root can be overridden with the BEHAVIOR_DATA_DIR environment
# variable so the notebook is not tied to one machine's home directory.
# When the variable is unset the original location is used unchanged.
data_dir = os.environ.get('BEHAVIOR_DATA_DIR', '/Users/rebekahzhang/data/behavior_data')
exp = 'exp2'  # used both as the sub-folder name and in the session-log file name
data_folder = os.path.join(data_dir, exp)

# Load sessions_training_<exp>.csv from the experiment folder via utils.load_session_log.
sessions_training = utils.load_session_log(data_folder, f'sessions_training_{exp}.csv')

In [2]:
# Pick one session by row position in the training log (55 is a hard-coded position).
session_info = sessions_training.iloc[55]
# Load that session's processed events and its trials table from the paths utils builds.
events = utils.load_data(utils.generate_events_processed_path(data_folder, session_info))
trials = utils.load_data(utils.generate_trials_path(data_folder, session_info))
# Derive per-trial data from the events grouped by session_trial_num.
# NOTE(review): helper.get_trial_data_df is defined outside this notebook; presumably
# it returns one row per trial keyed by session_trial_num — confirm.
trials_data = helper.get_trial_data_df(events.groupby('session_trial_num'))
# pd.merge defaults to an inner join, so only trial numbers present in both tables are kept.
trials_analyzed = pd.merge(trials, trials_data, on='session_trial_num')

In [3]:
# Compute the trial feature table with the helper module's get_trial_features.
trials_with_features = helper.get_trial_features(trials_analyzed, events)

In [5]:
# List the feature table's column labels (the output below is a pandas Index of names).
trials_with_features.keys()

Index(['session_trial_num', 'block_trial_num', 'block_num', 'start_time',
       'end_time', 'bg_drawn', 'bg_length', 'bg_repeats', 'num_bg_licks',
       'miss_trial', 'time_waited', 'reward', 'num_consumption_lick',
       'num_pump', 'first_lick', 'second_lick', 'third_lick', 'fourth_lick',
       'good_trial', 'previous_trial_bg_repeats', 'previous_trial_time_waited',
       'previous_trial_reward', 'previous_trial_miss_trial',
       'bg_repeats_rolling_mean_5', 'bg_repeats_rolling_mean_10',
       'time_waited_rolling_mean_5', 'time_waited_rolling_mean_10',
       'trial_fraction_in_session', 'trial_fraction_in_block',
       'block_fraction_in_session', 'rewarded_streak', 'unrewarded_streak',
       'reward_rate_since_block_start', 'time_since_last_reward_in_block',
       'cumulative_reward_in_block', 'reward_rate_past_1min_in_block',
       'reward_rate_past_5min_in_block', 'reward_rate_past_10min_in_block',
       'cumulative_reward'],
      dtype='object')

In [None]:
def get_previous_trial_performance(trials, rolling_windows=(5, 10)):
    """Add previous-trial (lagged) features and trailing rolling means.

    Rows are used in their existing order; every feature for a row is built
    only from rows above it.

    Lagged features: for each of ``bg_repeats``, ``time_waited``, ``reward``
    and ``miss_trial`` that exists in ``trials``, a ``previous_trial_<col>``
    column holds the value from the preceding row. Missing values are then
    replaced by 0, which covers the first row as well as any NaN carried over
    from the preceding row.

    Rolling means: for ``bg_repeats`` and ``time_waited`` (when present) and
    each window ``w``, ``<metric>_rolling_mean_<w>`` holds the mean of up to
    ``w`` rows ending at the preceding row (fewer near the top, because
    ``min_periods=1``). The first row is filled with 0.

    Parameters
    ----------
    trials : pandas.DataFrame
        Trial table. Columns that are absent are skipped silently.
    rolling_windows : iterable of int, default (5, 10)
        Window lengths, in rows. Any iterable is accepted, including a
        one-shot generator.

    Returns
    -------
    pandas.DataFrame
        A copy of ``trials`` with the new columns appended. The input frame
        is not modified.
    """
    trials = trials.copy()

    # Materialise the windows once: the loop below walks them once per metric,
    # so a one-shot iterable would otherwise be exhausted after the first metric.
    windows = tuple(rolling_windows)

    # Lagged features
    lagged_columns = ['bg_repeats', 'time_waited', 'reward', 'miss_trial']
    for col in lagged_columns:
        if col in trials.columns:
            trials[f'previous_trial_{col}'] = trials[col].shift(1).fillna(0)

    # Rolling averages (shifted by one row so the current trial is excluded)
    rolling_metrics = ['bg_repeats', 'time_waited']
    for metric in rolling_metrics:
        if metric in trials.columns:
            for window in windows:
                trials[f'{metric}_rolling_mean_{window}'] = (
                    trials[metric].rolling(window=window, min_periods=1).mean().shift(1, fill_value=0)
                )
    return trials

In [None]:
def get_trial_progress(trials):
    """Append columns locating each trial within its session and its block.

    Adds, in this order:

    - ``trial_fraction_in_session``: ``(session_trial_num + 1)`` divided by
      the number of rows in ``trials``.
    - ``trial_fraction_in_block``: ``(block_trial_num + 1)`` divided by the
      number of rows sharing the trial's ``block_num``.
    - ``block_fraction_in_session``: ``(block_num + 1)`` divided by the
      number of distinct ``block_num`` values.

    NOTE(review): the ``+ 1`` suggests the three counters are 0-based —
    confirm against the data.

    Returns a copy; the input frame is not modified.
    """
    progress = trials.copy()

    block_ids = progress['block_num']
    total_trials = len(progress)
    trials_per_block = block_ids.value_counts()
    total_blocks = block_ids.nunique()

    progress['trial_fraction_in_session'] = (progress['session_trial_num'] + 1) / total_trials
    # Look up each row's block size by its block id.
    progress['trial_fraction_in_block'] = (progress['block_trial_num'] + 1) / block_ids.map(trials_per_block)
    progress['block_fraction_in_session'] = (block_ids + 1) / total_blocks

    return progress

In [None]:
def get_rewarded_streak(trials):
    """Append the lengths of the reward / no-reward runs preceding each trial.

    A trial counts as rewarded when its ``reward`` is greater than 0; a
    missing ``reward`` is treated as 0.

    - ``rewarded_streak``: number of consecutive rewarded trials immediately
      before the current row.
    - ``unrewarded_streak``: number of consecutive unrewarded trials
      immediately before the current row.

    Both are 0 on the first row. Returns a copy; the input is not modified.
    """
    got_reward = trials['reward'].fillna(0) > 0
    no_reward = ~got_reward

    # Every unrewarded trial starts a new group, so the running count of
    # unrewarded trials labels each run of rewarded trials (and vice versa).
    # Cumulative sums inside those groups give the run length up to and
    # including the current row.
    hits_in_run = got_reward.groupby(no_reward.cumsum()).cumsum()
    misses_in_run = no_reward.groupby(got_reward.cumsum()).cumsum()

    result = trials.copy()
    # Shift by one row so each value describes only the trials before it.
    result['rewarded_streak'] = hits_in_run.shift(1, fill_value=0)
    result['unrewarded_streak'] = misses_in_run.shift(1, fill_value=0)
    return result

In [None]:
def get_block_reward_metrics(trials, events):
    """
    Return a copy of ``trials`` with reward-history columns added.

    Rewards are taken from ``events`` rows whose ``key`` is 'consumption' and
    whose ``reward_size`` is present and positive. For each trial, only
    rewards in the same ``block_num``, at or after the block's earliest
    ``start_time`` and strictly before the trial's own ``start_time``, count.

    Columns added, in this order:
    - reward_rate_since_block_start: summed reward size divided by the time
      elapsed since the block start (0.0 when no time has elapsed)
    - time_since_last_reward_in_block: trial start minus the latest earlier
      reward time in the block (NaN when there is none)
    - cumulative_reward_in_block: summed reward size so far in the block
    - reward_rate_past_{1,5,10}min_in_block: summed reward size inside the
      trailing 60 / 300 / 600 second window, clipped at the block start,
      divided by the clipped window length
    - cumulative_reward: running sum of the trials' own ``reward`` column
      (missing treated as 0) over all preceding rows, ignoring blocks

    Trials whose ``block_num`` is missing are skipped by the groupby and keep
    the initial values.
    """
    window_lengths = [60, 300, 600]  # seconds
    window_columns = {w: f'reward_rate_past_{w//60}min_in_block' for w in window_lengths}

    # Keep only consumption events that carry a positive reward size.
    is_consumption = events['key'] == 'consumption'
    has_size = events['reward_size'].notna()
    is_positive = events['reward_size'] > 0
    reward_events = events[is_consumption & has_size & is_positive]

    result = trials.copy()

    # Defaults for trials with no qualifying reward history.
    result['reward_rate_since_block_start'] = 0.0
    result['time_since_last_reward_in_block'] = np.nan
    result['cumulative_reward_in_block'] = 0.0
    for column in window_columns.values():
        result[column] = 0.0

    for block_id, block_rows in result.groupby('block_num'):
        block_open = block_rows['start_time'].min()

        belongs = (
            (reward_events['block_num'] == block_id) &
            (reward_events['session_time'] >= block_open)
        )
        ordered = reward_events[belongs].sort_values('session_time')
        reward_times = ordered['session_time'].values
        reward_sizes = ordered['reward_size'].values

        for label, trial_start in block_rows['start_time'].items():
            # Rewards delivered strictly before this trial began.
            earlier = reward_times < trial_start
            total_so_far = reward_sizes[earlier].sum()
            result.at[label, 'cumulative_reward_in_block'] = total_so_far

            since_open = trial_start - block_open
            if since_open > 0:
                result.at[label, 'reward_rate_since_block_start'] = total_so_far / since_open

            if earlier.any():
                result.at[label, 'time_since_last_reward_in_block'] = (
                    trial_start - reward_times[earlier].max()
                )

            for length, column in window_columns.items():
                # The window never reaches back past the start of the block.
                window_open = max(trial_start - length, block_open)
                span = trial_start - window_open
                if span > 0:
                    in_window = (reward_times >= window_open) & (reward_times < trial_start)
                    result.at[label, column] = reward_sizes[in_window].sum() / span

    # Session-wide running total, shifted so the current trial is excluded.
    result['cumulative_reward'] = result['reward'].fillna(0).cumsum().shift(fill_value=0)
    return result

In [None]:
def get_trial_features(trials_analyzed, events):
    """Run every feature step defined in this notebook and return the result.

    The steps run in a fixed order, each receiving the previous step's output:
    previous-trial performance, trial progress, reward streaks, and finally
    the block reward metrics (the only step that also reads ``events``).
    """
    features = get_previous_trial_performance(trials_analyzed)
    features = get_trial_progress(features)
    features = get_rewarded_streak(features)
    return get_block_reward_metrics(features, events)

In [None]:
# Recompute the features with the functions defined in this notebook. This rebinds
# trials_with_features, replacing the helper-module result assigned in an earlier cell.
trials_with_features = get_trial_features(trials_analyzed, events)

In [None]:
# Display the feature table as the cell's output.
trials_with_features