In [45]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from scipy.ndimage import label


In [46]:
# Trial-level behavioural table; only rows with reach_vis_abs_err <= 60 are kept.
behav_df = pd.read_csv('./data/derivatives/behav_df_cleaned_new.csv')
behav_df = behav_df.loc[behav_df['reach_vis_abs_err'].le(60)].reset_index(drop=True)

# Burst-level feature table (one row per burst — presumably; confirm against the file).
burst_df = pd.read_csv('./data/derivatives/burst_features.csv')

# Condition labels per trial, together with the (trial, block, subject) join keys.
lookup_df = behav_df.loc[:, ["coh_cat", "perturb_cat", "trial", "block", "subject"]]

# Left join: every burst row is kept; bursts whose (trial, block, subject) key is
# absent from lookup_df (e.g. trials removed by the filter above) get NaN in
# coh_cat / perturb_cat.
df_burst_behav = pd.merge(
    burst_df,
    lookup_df,
    how='left',
    on=["trial", "block", "subject"],
)


In [47]:
# (start, stop) limits of each epoch; these are passed as `epoch_lims` to
# compute_burst_counts_by_pc_quartile further down.
# NOTE(review): units look like seconds relative to the epoch-locking event — confirm.
vis_epoch_lims = (-1.0, 2.0)
mot_epoch_lims = (-1.0, 1.5)

# (start, stop) limits intended for plotting; not used by the cells shown here.
vis_plot_lims = (-0.25, 1.85)
mot_plot_lims = (-0.7, 1.2)

In [50]:
def compute_burst_counts_by_pc_quartile(
    df_burst_behav, behav_df, epoch,
    window_width=0.2, step_size=0.025,
    epoch_lims=(-1.0, 1.5), n_q=4,
    pc_names=('PC_7', 'PC_8', 'PC_9', 'PC_10'),
    output_dir='./output'
):
    """Count bursts per trial in sliding time windows, split by PC-score quantile.

    For each requested PC column, the bursts in ``df_burst_behav`` are divided
    into ``n_q`` bins using percentile edges of that PC's scores. Within each
    bin, bursts are grouped by (subject, block, trial) and the number of
    bursts whose ``peak_time`` falls in each half-open window
    ``[center - window_width/2, center + window_width/2)`` is counted. One CSV
    per PC is written to ``<output_dir>/<pc>_<epoch>_trial_burst_counts.csv``.

    Parameters
    ----------
    df_burst_behav : pandas.DataFrame
        Burst table. Must contain ``subject``, ``block``, ``trial``,
        ``peak_time`` and the requested ``PC_*`` columns. Every column that is
        neither a ``PC_*`` column nor one of the burst-feature columns is
        treated as trial metadata and copied from the first burst of the trial.
    behav_df : pandas.DataFrame
        Behavioural table with ``subject``, ``block``, ``trial`` columns plus
        the behavioural columns listed in ``extra_behav_cols`` below. It is
        re-indexed locally; the caller's frame is not modified.
    epoch : str
        Label used only in the output file name.
    window_width : float
        Width of each counting window.
    step_size : float
        Spacing between consecutive window centers.
    epoch_lims : tuple of float
        (start, stop) of the epoch; windows are placed so that they lie
        entirely inside these limits.
    n_q : int
        Number of quantile bins (4 gives quartiles). The output column is
        called ``quartile`` regardless of ``n_q``.
    pc_names : sequence of str or None
        PC columns to process. Defaults to the four components that were
        previously hard-coded; ``None`` processes every ``PC_*`` column.
    output_dir : str
        Directory for the CSV files; created if missing.

    Returns
    -------
    None
        Results are written to disk only.

    Notes
    -----
    NOTE(review): a trial only gets a row for a given bin if it has at least
    one burst in that bin, so zero-burst trials are absent rather than
    recorded as zeros. Confirm that downstream averaging accounts for this.
    """
    os.makedirs(output_dir, exist_ok=True)

    half_width = window_width / 2
    last_center = epoch_lims[1] - half_width
    time_centers = np.arange(epoch_lims[0] + window_width / 2,
                             epoch_lims[1] - window_width / 2 + step_size,
                             step_size)
    # With a non-integer step np.arange may emit one element past the intended
    # last center (float round-off in the length computation); drop any such
    # overshoot so no window extends beyond the epoch.
    time_centers = time_centers[time_centers <= last_center + 1e-9]
    time_columns = [f"time_{round(tc, 3)}" for tc in time_centers]
    # Window edges are loop-invariant, so compute them once.
    window_starts = time_centers - half_width
    window_stops = time_centers + half_width

    pcs = [col for col in df_burst_behav.columns if col.startswith("PC_")]
    burst_feature_cols = ['peak_time', 'peak_freq', 'peak_amp_base', 'fwhm_freq', 'fwhm_time']
    metadata_cols = [col for col in df_burst_behav.columns if col not in pcs + burst_feature_cols]
    group_cols = ["subject", "block", "trial"]

    if pc_names is None:
        pc_names = pcs

    # set_index returns a new frame, so the caller's behav_df is left untouched.
    behav_df = behav_df.set_index(group_cols)
    extra_behav_cols = [
        'group',
        'trial_perturb', 'reach_dur', 'reach_rt',
        'trial_directions', 'trial_target', 'aim_target', 'reach_target',
        'aim_real_angle', 'reach_real_angle', 'true_target_angle',
        'reach_vis_angle', 'reach_vis_err', 'reach_vis_abs_err',
        'aim_vis_angle', 'aim_vis_err', 'aim_vis_abs_err'
    ]

    for pc in pc_names:
        print(f"Processing {pc}")
        pc_scores = df_burst_behav[pc]
        step = 100 / n_q
        # nanpercentile: a single NaN score would otherwise turn every bin
        # edge into NaN and silently empty the output.
        q_bins = np.nanpercentile(pc_scores.to_numpy(dtype=float),
                                  np.arange(0, 100 + step, step))
        trial_records = []

        for q in range(n_q):
            lower, upper = q_bins[q], q_bins[q + 1]
            if q == n_q - 1:
                # The top bin is closed on the right; with a strict '<' the
                # burst(s) holding the maximum score would fall in no bin.
                in_bin = (pc_scores >= lower) & (pc_scores <= upper)
            else:
                in_bin = (pc_scores >= lower) & (pc_scores < upper)
            df_q = df_burst_behav[in_bin]

            for key, trial_df in df_q.groupby(group_cols):
                trial_meta = trial_df.iloc[0][metadata_cols].to_dict()
                trial_meta["quartile"] = q + 1

                # Count bursts in every window at once: rows are bursts,
                # columns are windows, each window is half-open [start, stop).
                peak_times = trial_df["peak_time"].values
                in_window = ((peak_times[:, None] >= window_starts) &
                             (peak_times[:, None] < window_stops))
                trial_meta.update(zip(time_columns, in_window.sum(axis=0)))

                # Add behavioral values when the trial exists in behav_df;
                # otherwise these columns end up NaN for this row.
                key = tuple(key)
                if key in behav_df.index:
                    for col in extra_behav_cols:
                        trial_meta[col] = behav_df.at[key, col]

                trial_records.append(trial_meta)

        # Rows are appended bin by bin, so the order matches a per-bin concat.
        pc_df = pd.DataFrame(trial_records)
        pc_df.to_csv(os.path.join(output_dir, f"{pc}_{epoch}_trial_burst_counts.csv"), index=False)


In [51]:
# Split the merged burst table by the value of its 'epoch' column.
df_visual = df_burst_behav.loc[df_burst_behav['epoch'].eq('vis')]
df_motor = df_burst_behav.loc[df_burst_behav['epoch'].eq('mot')]

In [52]:
# Visual epoch: writes one '<PC>_vis_trial_burst_counts.csv' per processed PC.
compute_burst_counts_by_pc_quartile(
    df_burst_behav=df_visual,
    behav_df=behav_df,
    epoch='vis',
    epoch_lims=vis_epoch_lims,
)

Processing PC_7
Processing PC_8
Processing PC_9
Processing PC_10


In [53]:
# Motor epoch: writes one '<PC>_mot_trial_burst_counts.csv' per processed PC.
compute_burst_counts_by_pc_quartile(
    df_burst_behav=df_motor,
    behav_df=behav_df,
    epoch='mot',
    epoch_lims=mot_epoch_lims,
)

Processing PC_7
Processing PC_8
Processing PC_9
Processing PC_10
