In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.lines import Line2D 
import seaborn as sns
from scipy.ndimage import gaussian_filter1d
import pickle as pkl
from mne import read_epochs

In [2]:
# Input tables: the behavioural table and the burst feature table, both read from CSV.
behav_csv_path = '/home/qmoreau/Documents/Beta_bursts/Behavioral/behav_df_cleaned_new.csv'
burst_csv_path = '/home/qmoreau/Documents/Beta_bursts/Burst files/burst_features.csv'

behav_df = pd.read_csv(behav_csv_path)
burst_df = pd.read_csv(burst_csv_path)

In [3]:

def calculate_burst_rates(behav_df, df_burst_behav, epoch, group, time_bin=0.05, time_buffer=0.125,
                          baseline_bins=(12, 19)):
    """
    Compute baseline-normalised burst-rate time courses per (block, coh_cat).

    For every subject belonging to `group`, the bursts of the requested `epoch`
    are split by (block, coh_cat). Within each split, bursts are counted in
    consecutive half-open bins [t, t + time_bin), the counts are divided by
    `time_bin`, smoothed with a Gaussian kernel (sigma = 1 bin) and expressed as
    percent change relative to the mean of the baseline bins.

    Note: counts are pooled over all burst rows of a split; they are not divided
    by the number of trials.

    Parameters:
    - behav_df: DataFrame with at least the columns 'group' and 'subject'
    - df_burst_behav: DataFrame with at least the columns 'epoch', 'subject',
      'block', 'coh_cat' and 'peak_time'
    - epoch: value of the 'epoch' column to keep
    - group: value of the 'group' column selecting the subjects
    - time_bin: bin width, in the units of 'peak_time'
    - time_buffer: extension applied to both ends of the [-1, 2] time range
    - baseline_bins: (start, stop) bin indices (stop exclusive) whose mean
      smoothed rate is used as baseline. The default (12, 19) is only meaningful
      for the default `time_bin` / `time_buffer`; adjust it when those change.

    Returns:
    - burst_rates_dict: dict mapping (block, coh_cat) to a 2-D array with one
      row per subject that has bursts for that combination (subjects in order
      of first appearance in `behav_df`) and one column per time bin. A row is
      all NaN when its baseline mean is zero, because percent change is
      undefined in that case.
    """
    # Subjects of the requested group, in order of first appearance.
    group_subjects = behav_df.loc[behav_df['group'] == group, 'subject'].unique()

    # Left edges of the time bins; shaped as a column for broadcasting below.
    time_range = np.arange(-1 - time_buffer, 2 + time_buffer, time_bin)
    bin_starts = time_range[:, np.newaxis]
    baseline_start, baseline_stop = baseline_bins

    epoch_bursts = df_burst_behav[df_burst_behav['epoch'] == epoch]

    rates_by_condition = {}
    for subject in group_subjects:
        subject_df = epoch_bursts[epoch_bursts['subject'] == subject]

        # `condition_df` deliberately does not reuse the name of the `group` parameter.
        for (block, coh_cat), condition_df in subject_df.groupby(['block', 'coh_cat']):
            peak_times = condition_df['peak_time'].to_numpy()

            # (n_bins, n_bursts) membership matrix: burst j falls in bin i when
            # bin_start_i <= peak_time_j < bin_start_i + time_bin.
            in_bin = (peak_times >= bin_starts) & (peak_times < bin_starts + time_bin)
            rates = in_bin.sum(axis=1) / time_bin

            smoothed = gaussian_filter1d(rates, sigma=1)

            # Percent change from baseline; undefined (NaN) when the baseline is zero,
            # instead of the inf/NaN mixture a division by zero would produce.
            baseline = np.mean(smoothed[baseline_start:baseline_stop])
            if baseline == 0:
                normalised = np.full_like(smoothed, np.nan)
            else:
                normalised = (smoothed - baseline) / baseline * 100

            rates_by_condition.setdefault((block, coh_cat), []).append(normalised)

    # Stack the per-subject rows of every (block, coh_cat) key into one array.
    return {key: np.array(rows) for key, rows in rates_by_condition.items()}

In [4]:

# Components to export; each name must be a column of burst_df.
PC_list = ['PC_7']
# , 'PC_2', 'PC_3', 'PC_4', 'PC_5', 'PC_6', 'PC_7', 'PC_8', 'PC_9', 'PC_10',
#         'PC_11', 'PC_12', 'PC_13', 'PC_14', 'PC_15', 'PC_16', 'PC_17', 'PC_18', 'PC_19', 'PC_20']

for PC in PC_list:
    # Quartile cut points (25th, 50th, 75th percentile) of the component score,
    # taken over every row of burst_df.
    thresh = np.percentile(burst_df[PC], [25, 50, 75])

    # Rows shared by all four quartiles: 'mot' epoch, peak between 0.25 and 0.75
    # (both ends inclusive).
    mot_window_mask = burst_df['peak_time'].between(0.25, 0.75) & (burst_df['epoch'] == 'mot')
    pc_score = burst_df[PC]
    trial_keys = ['subject', 'block', 'trial']

    # Split the selected bursts into score quartiles: Q0 <= 25th percentile,
    # Q1 and Q2 are half-open (lower, upper], Q3 > 75th percentile.
    filtered_bursts_Q0 = burst_df[mot_window_mask & (pc_score <= thresh[0])]
    filtered_bursts_Q1 = burst_df[mot_window_mask & (pc_score > thresh[0]) & (pc_score <= thresh[1])]
    filtered_bursts_Q2 = burst_df[mot_window_mask & (pc_score > thresh[1]) & (pc_score <= thresh[2])]
    filtered_bursts_Q3 = burst_df[mot_window_mask & (pc_score > thresh[2])]

    # Number of bursts per (subject, block, trial) in each quartile.
    burst_counts_Q0, burst_counts_Q1, burst_counts_Q2, burst_counts_Q3 = (
        quartile_bursts.groupby(trial_keys).size().reset_index(name='burst_count')
        for quartile_bursts in (filtered_bursts_Q0, filtered_bursts_Q1,
                                filtered_bursts_Q2, filtered_bursts_Q3)
    )

    # Attach the counts to the behavioural table. The left merge keeps every
    # behavioural row; rows without a matching count get NaN in 'burst_count'.
    behav_df_Q0, behav_df_Q1, behav_df_Q2, behav_df_Q3 = (
        behav_df.merge(quartile_counts, on=trial_keys, how='left')
        for quartile_counts in (burst_counts_Q0, burst_counts_Q1,
                                burst_counts_Q2, burst_counts_Q3)
    )

    # One CSV per quartile, written to the component's output folder.
    for quartile, quartile_table in enumerate((behav_df_Q0, behav_df_Q1, behav_df_Q2, behav_df_Q3)):
        quartile_table.to_csv(f'/home/qmoreau/Documents/Beta_bursts/Burst files/Motor_PCs/{PC}_Motor/behav_df_Q{quartile}.csv', index=False)




In [5]:
# Define sliding window parameters
window_size = 0.2  # window length in seconds (200 ms)
step_size = 0.1    # distance between consecutive window starts in seconds (100 ms)


def _format_window_edge(value):
    """Format a window edge for file names: up to 3 decimals, no trailing zeros, never '-0'."""
    text = f"{value:.3f}".rstrip('0').rstrip('.')
    return '0' if text == '-0' else text


PC_list_mot = ['PC_7']
for PC in PC_list_mot:
    # Quartile cut points (25th, 50th, 75th percentile) of the component score,
    # taken over every row of burst_df.
    thresh = np.percentile(burst_df[PC], [25, 50, 75])

    # Masks that do not depend on the time window are built once per component.
    pc_score = burst_df[PC]
    quartile_masks = [
        pc_score <= thresh[0],
        (pc_score > thresh[0]) & (pc_score <= thresh[1]),
        (pc_score > thresh[1]) & (pc_score <= thresh[2]),
        pc_score > thresh[2],
    ]
    is_mot_epoch = burst_df['epoch'] == 'mot'
    peak_time = burst_df['peak_time']

    # Windows cover [start_time, end_time]; the last one ends at or before end_time.
    start_time = -1
    end_time = 1.5

    # Window edges are derived from an integer index rather than by repeatedly
    # adding step_size: accumulated floating-point error makes the edges drift
    # (e.g. a start of -1.4e-16 instead of 0) and makes the loop's end test
    # depend on rounding luck. The small epsilon absorbs division round-off.
    n_windows = int(np.floor((end_time - start_time - window_size) / step_size + 1e-9)) + 1

    for window_index in range(n_windows):
        window_start = round(start_time + window_index * step_size, 6)
        window_end = round(window_start + window_size, 6)

        # Half-open window [window_start, window_end) within the 'mot' epoch.
        in_window = is_mot_epoch & (peak_time >= window_start) & (peak_time < window_end)

        window_start_str = _format_window_edge(window_start)
        window_end_str = _format_window_edge(window_end)

        # Process each quartile
        for q, quartile_mask in enumerate(quartile_masks):
            filtered_bursts = burst_df[in_window & quartile_mask]

            # Count bursts per (subject, block, trial) in this window and quartile.
            burst_counts = filtered_bursts.groupby(['subject', 'block', 'trial']).size().reset_index(name='burst_count')

            # Left merge keeps every behavioural row; rows without a matching
            # count get NaN in 'burst_count'.
            behav_df_window = behav_df.merge(burst_counts, on=['subject', 'block', 'trial'], how='left')

            behav_df_window.to_csv(f'/home/qmoreau/Documents/Beta_bursts/Burst files/Motor_PCs_sliding_window/{PC}_Motor_SW/behav_df_Q{q}_window_{window_start_str}_{window_end_str}.csv', index=False)


In [6]:
# Keep 'mot'-epoch bursts whose peak lies between 0.75 and 1 (both ends inclusive).
overall_bursts = burst_df[
    burst_df['peak_time'].between(0.75, 1) & (burst_df['epoch'] == 'mot')
]

# Number of selected bursts per (subject, block, trial).
overall_bursts_count = (
    overall_bursts
    .groupby(['subject', 'block', 'trial'])
    .size()
    .reset_index(name='burst_count')
)

# Left merge keeps every behavioural row; rows without a matching count get NaN
# in 'burst_count'.
behav_df_overall = behav_df.merge(
    overall_bursts_count,
    how='left',
    on=['subject', 'block', 'trial'],
)

behav_df_overall.to_csv('/home/qmoreau/Documents/Beta_bursts/Burst files/Motor_PCs/overall_bursts.csv', index=False)
