# All oscillation analysis

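Computes oscillation measures from each recording's LFP traces: continuous-wavelet (CWT) coefficients, theta/gamma phase, spectral Granger prediction, coherence, and multitaper power. Before that, outlier samples (|modified z-score| ≥ `zscore_threshold`) are removed, the traces are RMS-normalized, and the gaps are linearly interpolated.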

In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import os
import collections
import itertools
from collections import defaultdict
from itertools import combinations

In [3]:
os.environ["SPECTRAL_CONNECTIVITY_ENABLE_GPU"] = "true"
import cupy as cp

In [4]:
# Imports of all used packages and libraries
import numpy as np
import pandas as pd
from scipy import stats
from scipy.stats import mannwhitneyu


In [5]:
import matplotlib
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
import colorsys

In [6]:
from sklearn.metrics import confusion_matrix

In [7]:
from spectral_connectivity import Multitaper, Connectivity
import spectral_connectivity

In [8]:
import fcwt

In [9]:
FONTSIZE = 20

In [10]:
font = {'weight' : 'medium',
        'size'   : 20}

matplotlib.rc('font', **font)

# Functions

In [11]:
def generate_pairs(lst):
    """
    Build every unordered pair of items taken from distinct positions of a list.

    Parameters:
    - lst (list): Items to pair up.

    Returns:
    - list: Tuples ``(lst[i], lst[j])`` for every ``i < j``, ordered by ``i`` and then ``j``
      (for example ``[1, 2, 3] -> [(1, 2), (1, 3), (2, 3)]``).
    """
    size = len(lst)
    pairs = []
    for first_pos in range(size):
        for second_pos in range(first_pos + 1, size):
            pairs.append((lst[first_pos], lst[second_pos]))
    return pairs

In [12]:
def update_array_by_mask(array, mask, value=np.nan):
    """
    Return a copy of an array in which the masked positions are overwritten.

    Parameters:
    - array (np.array): Source array; it is left untouched.
    - mask (np.array): Boolean array shaped like `array`; positions that are True get replaced.
    - value (scalar, optional): Replacement written at the masked positions. Defaults to np.nan.

    Returns:
    - np.array: A new array (same dtype as `array`) holding `value` wherever `mask` is True.

    Example:
    >>> update_array_by_mask(np.array([1, 2, 3, 4]), np.array([False, True, False, True]), value=0)
    array([1, 0, 3, 0])
    """
    # Work on a copy so the caller's array is never modified.
    updated = array.copy()
    updated[mask] = value
    return updated

In [13]:
def nan_helper(y):
    """Locate NaNs in a 1-D array and provide a logical-to-positional index converter.

    Input:
        - y, 1d numpy array that may contain NaNs
    Output:
        - nans, boolean array that is True where `y` is NaN
        - index, a function index(logical_indices) -> integer positions where the
          boolean array is True
    Example:
        # linear interpolation of NaNs
        nans, x = nan_helper(y)
        y[nans] = np.interp(x(nans), x(~nans), y[~nans])
    """
    def to_positions(logical_indices):
        # nonzero() returns one index array per axis; only the first is used.
        return logical_indices.nonzero()[0]

    nan_mask = np.isnan(y)
    return nan_mask, to_positions

In [14]:
def interpolate_signal(signal):
    """
    Fill NaN samples of a 1-D signal by linear interpolation between the valid samples.

    NaNs before the first (or after the last) valid sample take the value of that first
    (or last) valid sample, which is how ``np.interp`` extrapolates. The input array is
    not modified.

    Parameters:
        signal (numpy.ndarray): 1-D signal array that may contain NaN values.

    Returns:
        numpy.ndarray: A copy of ``signal`` (same dtype) with NaNs replaced by interpolated
        values. A signal without NaNs (including an empty one) is returned as an unchanged copy.

    Raises:
        ValueError: If ``signal`` is not a numpy ndarray, is not 1-D, or is non-empty but
        contains only NaNs (there is no valid sample to interpolate from).
    """
    if not isinstance(signal, np.ndarray):
        raise ValueError("Input must be a numpy ndarray.")
    if signal.ndim != 1:
        # Positional interpolation only makes sense along a single axis.
        raise ValueError("Input must be a 1-D numpy ndarray.")

    result = signal.copy()
    nan_mask = np.isnan(result)
    if not nan_mask.any():
        # Nothing to fill; this also covers empty input, which np.interp rejects.
        return result
    if nan_mask.all():
        raise ValueError("Cannot interpolate a signal that contains only NaNs.")

    nan_positions = np.flatnonzero(nan_mask)
    valid_positions = np.flatnonzero(~nan_mask)
    result[nan_mask] = np.interp(nan_positions, valid_positions, result[valid_positions])
    return result

## Inputs & Data

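Inputs used by this notebook:

- `OUTPUT_DIR`: directory where processed results are written (`./proc`).
- `TIME_HALFBANDWIDTH_PRODUCT`, `TIME_WINDOW_DURATION`, `TIME_WINDOW_STEP`: multitaper settings (window length and step in seconds) used for Granger, coherence and power.
- `RESAMPLE_RATE`: sampling rate (Hz) of the LFP traces, passed to the multitaper calculations.
- `zscore_threshold`: absolute modified z-score at or above which LFP samples are treated as artifacts and removed.
- `VOLTAGE_SCALING_VALUE`: scale factor applied to the raw LFP values (presumably converting recorder units to µV).
- `BAND_TO_FREQ`: frequency limits (Hz) of the theta and gamma bands.
- `LFP_TRACES_DF`: per-recording LFP traces, timestamps and behavioral event frames, loaded from `./proc/rce_pilot_2_01_lfp_traces_and_frames.pkl`.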

In [15]:
# Inputs and required data loading.
# Input variable names are in ALL_CAPS snake case; whenever an input is modified
# or derived for processing, the resulting variables are in lower snake case.
# NOTE(review): the read_pickle/to_pickle calls in this notebook hard-code "./proc"
# instead of using OUTPUT_DIR.
OUTPUT_DIR = r"./proc" # directory where processed data is saved; always declare it with the inputs

In [16]:
TIME_HALFBANDWIDTH_PRODUCT = 2
TIME_WINDOW_DURATION = 1
TIME_WINDOW_STEP = 0.5
RESAMPLE_RATE=1000

In [17]:
zscore_threshold = 4
VOLTAGE_SCALING_VALUE = 0.195

In [18]:
BAND_TO_FREQ = {"theta": (4,12), "gamma": (30,51)}

In [19]:
LFP_TRACES_DF = pd.read_pickle("./proc/rce_pilot_2_01_lfp_traces_and_frames.pkl")

In [20]:
LFP_TRACES_DF.shape

(61, 23)

## Preprocessing

In [21]:
original_trace_columns = [col for col in LFP_TRACES_DF.columns if "trace" in col]

In [22]:
original_trace_columns

['mPFC_lfp_trace',
 'MD_lfp_trace',
 'LH_lfp_trace',
 'BLA_lfp_trace',
 'vHPC_lfp_trace']

In [23]:
for col in original_trace_columns:
    print(col)
    LFP_TRACES_DF[col] = LFP_TRACES_DF[col].apply(lambda x: x.astype(np.float32) * VOLTAGE_SCALING_VALUE)

mPFC_lfp_trace
MD_lfp_trace
LH_lfp_trace
BLA_lfp_trace
vHPC_lfp_trace


In [24]:
LFP_TRACES_DF.head()

Unnamed: 0,cohort,session_dir,tone_frames,box_1_port_entry_frames,box_2_port_entry_frames,video_name,session_path,recording,current_subject,subject,...,video_timestamps,tone_timestamps,box_1_port_entry_timestamps,box_2_port_entry_timestamps,lfp_timestamps,mPFC_lfp_trace,MD_lfp_trace,LH_lfp_trace,BLA_lfp_trace,vHPC_lfp_trace
0,2,20230612_101430_standard_comp_to_training_D1_s...,"[[980, 1181], [3376, 3575], [5672, 5871], [746...","[[490, 514], [518, 558], [558, 637], [638, 640...","[[33137, 33147], [33665, 33666], [33668, 33669...",20230612_101430_standard_comp_to_training_D1_s...,/scratch/back_up/reward_competition_extention/...,20230612_101430_standard_comp_to_training_D1_s...,1.3,1.3,...,"[-2, 1384, 2770, 4156, 4156, 5542, 6928, 6928,...","[[982229, 1182226], [3382227, 3582224], [56822...","[[491029, 515227], [519426, 558629], [559427, ...","[[33082200, 33090003], [33565003, 33567000], [...","[0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 2...","[95.354996, 82.09499, 97.5, 132.405, 123.825, ...","[46.019997, 49.335, 75.27, 97.89, 77.61, 40.55...","[61.425, 66.104996, 81.899994, 90.479996, 71.5...","[54.6, 54.405, 73.32, 86.189995, 59.085, 19.89...","[55.574997, 79.365, 128.11499, 170.43, 189.344..."
1,2,20230612_101430_standard_comp_to_training_D1_s...,"[[980, 1180], [3376, 3575], [5672, 5871], [746...","[[490, 514], [518, 558], [558, 637], [638, 640...","[[33021, 33027], [33502, 33503], [33504, 33506...",20230612_101430_standard_comp_to_training_D1_s...,/scratch/back_up/reward_competition_extention/...,20230612_101430_standard_comp_to_training_D1_s...,1.3,1.3,...,"[-2, 1384, 2770, 4156, 4156, 5542, 6928, 6928,...","[[982229, 1182226], [3382227, 3582224], [56822...","[[491029, 515227], [519426, 558629], [559427, ...","[[33082200, 33090003], [33565003, 33567000], [...","[0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 2...","[95.354996, 82.09499, 97.5, 132.405, 123.825, ...","[46.019997, 49.335, 75.27, 97.89, 77.61, 40.55...","[61.425, 66.104996, 81.899994, 90.479996, 71.5...","[54.6, 54.405, 73.32, 86.189995, 59.085, 19.89...","[55.574997, 79.365, 128.11499, 170.43, 189.344..."
2,2,20230612_101430_standard_comp_to_training_D1_s...,"[[980, 1181], [3376, 3575], [5672, 5871], [746...","[[490, 514], [518, 558], [558, 637], [638, 640...","[[33137, 33147], [33665, 33666], [33668, 33669...",20230612_101430_standard_comp_to_training_D1_s...,/scratch/back_up/reward_competition_extention/...,20230612_101430_standard_comp_to_training_D1_s...,1.4,1.4,...,"[-2, 1384, 2770, 4156, 4156, 5542, 6928, 6928,...","[[982229, 1182226], [3382227, 3582224], [56822...","[[491029, 515227], [519426, 558629], [559427, ...","[[33082200, 33090003], [33565003, 33567000], [...","[0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 2...","[19.89, 29.445, 33.149998, 37.829998, 45.43499...","[29.445, 28.859999, 25.935, 23.205, 21.449999,...","[28.47, 25.349998, 22.035, 22.814999, 23.00999...","[68.64, 90.284996, 93.795, 71.564995, 90.09, 1...","[62.984997, 86.774994, 104.13, 86.96999, 75.65..."
3,2,20230612_101430_standard_comp_to_training_D1_s...,"[[980, 1180], [3376, 3575], [5672, 5871], [746...","[[490, 514], [518, 558], [558, 637], [638, 640...","[[33021, 33027], [33502, 33503], [33504, 33506...",20230612_101430_standard_comp_to_training_D1_s...,/scratch/back_up/reward_competition_extention/...,20230612_101430_standard_comp_to_training_D1_s...,1.4,1.4,...,"[-2, 1384, 2770, 4156, 4156, 5542, 6928, 6928,...","[[982229, 1182226], [3382227, 3582224], [56822...","[[491029, 515227], [519426, 558629], [559427, ...","[[33082200, 33090003], [33565003, 33567000], [...","[0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 2...","[19.89, 29.445, 33.149998, 37.829998, 45.43499...","[29.445, 28.859999, 25.935, 23.205, 21.449999,...","[28.47, 25.349998, 22.035, 22.814999, 23.00999...","[68.64, 90.284996, 93.795, 71.564995, 90.09, 1...","[62.984997, 86.774994, 104.13, 86.96999, 75.65..."
4,2,20230612_112630_standard_comp_to_training_D1_s...,"[[1125, 1324], [3519, 3720], [5815, 6014], [76...","[[192, 248], [389, 405], [916, 929], [929, 948...","[[33019, 33020], [33246, 33251], [33253, 33255...",20230612_112630_standard_comp_to_training_D1_s...,/scratch/back_up/reward_competition_extention/...,20230612_112630_standard_comp_to_training_D1_s...,1.1,1.1,...,"[1384, 2444, 2769, 4155, 5541, 6708, 6927, 831...","[[1126742, 1326741], [3526740, 3726740], [5826...","[[192745, 249350], [389747, 407142], [917544, ...","[[33037711, 33038706], [33264908, 33270313], [...","[0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 2...","[-4.875, 21.255, 75.465, 113.295, 100.619995, ...","[13.65, 36.66, 53.82, 33.735, 3.12, -15.794999...","[7.995, 32.76, 56.159996, 48.164997, 29.445, 3...","[4.4849997, 8.969999, 19.109999, 26.324999, 16...","[34.32, 50.504997, 44.265, -1.365, -31.784998,..."


# Calculating modified zscore

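$$M_i = \frac{0.6745\,(x_i - \tilde{x})}{\mathrm{MAD}}$$

where $\tilde{x}$ is the median of the trace and $\mathrm{MAD}$ is its median absolute deviation.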

In [25]:
for col in original_trace_columns:
    print(col)
    brain_region = col.split("_")[0]
    updated_column = "{}_lfp_MAD".format(brain_region)
    LFP_TRACES_DF[updated_column] = LFP_TRACES_DF[col].apply(lambda x: stats.median_abs_deviation(x))

mPFC_lfp_trace
MD_lfp_trace
LH_lfp_trace


KeyboardInterrupt: 

In [None]:
for col in original_trace_columns:
    print(col)
    brain_region = col.split("_")[0]
    updated_column = "{}_lfp_modified_zscore".format(brain_region)
    MAD_column = "{}_lfp_MAD".format(brain_region)

    LFP_TRACES_DF[updated_column] = LFP_TRACES_DF.apply(lambda x: 0.6745 * (x[col] - np.median(x[col])) / x[MAD_column], axis=1)

In [None]:
LFP_TRACES_DF[updated_column]

## Calculating root mean square (RMS) normalization

In [None]:
for col in original_trace_columns:
    print(col)
    brain_region = col.split("_")[0]
    updated_column = "{}_lfp_RMS".format(brain_region)
    LFP_TRACES_DF[updated_column] = LFP_TRACES_DF[col].apply(lambda x: (x / np.sqrt(np.mean(x**2))).astype(np.float32))


In [None]:
LFP_TRACES_DF.head()

## Filtering for zscore value

In [None]:
zscore_columns = [col for col in LFP_TRACES_DF.columns if "zscore" in col]

In [None]:
zscore_columns

In [None]:
for col in zscore_columns:
    print(col)
    brain_region = col.split("_")[0]
    updated_column = "{}_lfp_mask".format(brain_region)
    LFP_TRACES_DF[updated_column] = LFP_TRACES_DF[col].apply(lambda x: np.abs(x) >= zscore_threshold)

In [None]:
LFP_TRACES_DF[updated_column].head()

In [None]:
LFP_TRACES_DF[updated_column].iloc[0].shape

In [None]:
sum(LFP_TRACES_DF[updated_column].iloc[0])

- Filtering raw traces by zscore

In [None]:
LFP_TRACES_DF[col].head()

In [None]:
for col in original_trace_columns:
    print(col)
    brain_region = col.split("_")[0]
    updated_column = "{}_lfp_trace_filtered".format(brain_region)    
    mask_column = "{}_lfp_mask".format(brain_region)
    LFP_TRACES_DF[updated_column] = LFP_TRACES_DF.apply(lambda x: update_array_by_mask(x[col], x[mask_column]), axis=1)

In [None]:
LFP_TRACES_DF[col].head()

In [None]:
# Number of samples set to NaN by the z-score mask in the first recording of the last
# filtered column. After the loop above, `col` still names the raw (unfiltered) trace,
# so it cannot be used to check the masking.
np.isnan(LFP_TRACES_DF[updated_column].iloc[0]).sum()

- Calculating RMS of filtered signal

In [None]:
filtered_trace_column = [col for col in LFP_TRACES_DF if "lfp_trace_filtered" in col]

In [None]:
for col in filtered_trace_column:
    print(col)
    brain_region = col.split("_")[0]
    updated_column = "{}_lfp_RMS_filtered".format(brain_region)
    LFP_TRACES_DF[updated_column] = LFP_TRACES_DF[col].apply(lambda x: (x / np.sqrt(np.nanmean(x**2))).astype(np.float32))

- Dropping unnecessary columns

In [None]:
LFP_TRACES_DF = LFP_TRACES_DF.drop(columns=[col for col in LFP_TRACES_DF.columns if "zscore" in col or "MAD" in col], errors="ignore")

In [None]:
LFP_TRACES_DF = LFP_TRACES_DF.drop(columns=[col for col in LFP_TRACES_DF.columns if "RMS" in col and not "filtered" in col], errors="ignore")


In [None]:
LFP_TRACES_DF = LFP_TRACES_DF.drop(columns=[col for col in LFP_TRACES_DF.columns if "trace" in col and not "filtered" in col], errors="ignore")


In [None]:
LFP_TRACES_DF = LFP_TRACES_DF.drop(columns=[col for col in LFP_TRACES_DF.columns if "mask" in col and not "filtered" in col], errors="ignore")


In [None]:
LFP_TRACES_DF.columns

- Interpolating filtered RMS

In [None]:
filtered_RMS_column = [col for col in LFP_TRACES_DF if "lfp_RMS_filtered" in col]

In [None]:
filtered_RMS_column

In [None]:
for col in filtered_RMS_column:
    print(col)
    brain_region = col.split("_")[0]
    updated_column = "{}_lfp_RMS_interpolated".format(brain_region)
    LFP_TRACES_DF[updated_column] = LFP_TRACES_DF[col].apply(lambda x: interpolate_signal(x))

In [None]:
# RMS_columns = [col for col in LFP_TRACES_DF if "RMS" in col and "filtered" not in col]
RMS_columns = [col for col in LFP_TRACES_DF if "interpolated" in col]

In [None]:
RMS_columns

In [None]:
for col in RMS_columns:
    print(col)
    LFP_TRACES_DF[col] = LFP_TRACES_DF[col].apply(lambda x: x.astype(np.float16))

In [None]:
LFP_TRACES_DF[col]

In [None]:
LFP_TRACES_DF = LFP_TRACES_DF.drop(columns=[col for col in LFP_TRACES_DF.columns if "filtered" in col], errors="ignore")

In [None]:
LFP_TRACES_DF.columns

# Calculating the power with CWT

In [None]:
RMS_columns

In [None]:
signal = LFP_TRACES_DF["LH_lfp_RMS_interpolated"].iloc[0][:10000]

# Calculate CWT without plotting...


In [None]:
RMS_columns

In [None]:
#Initialize
fs = 500

f0 = 1 #lowest frequency
# f1 = 51.5 #highest frequency
# fn = 51 #number of frequencies
f1 = 13 #highest frequency
fn = 12 #number of frequencies
nthreads = 8


In [None]:
freqs, out = fcwt.cwt(signal, fs, f0, f1, fn, nthreads=nthreads, norm=True)

In [None]:
freqs

In [None]:
def _cwt_decimated(trace):
    """Run fcwt once on one trace; return (frequencies, coefficients decimated in time).

    The trace is downsampled by 2 (``trace[::2]``) to match ``fs``; this assumes the
    traces are sampled at 2 * fs. ``fcwt.cwt`` returns coefficients shaped
    (fn, n_samples), so the time axis (axis 1) is the one decimated by 50. Slicing
    ``[::50]`` on axis 0 would keep only the first frequency row.
    """
    cwt_frequencies, coefficients = fcwt.cwt(trace[::2], fs, f0, f1, fn, nthreads=nthreads, norm=True)
    return cwt_frequencies, coefficients[:, ::50]


for col in RMS_columns:
    # e.g. "mPFC_lfp_RMS_interpolated" -> "mPFC_RMS_interpolated"
    brain_region = col.replace("_lfp", "")
    print(brain_region)

    frequencies_col = f"{brain_region}_power_theta_frequencies"
    # NOTE(review): fcwt returns complex coefficients and they are stored as-is;
    # presumably |coef|**2 is taken downstream to get power — confirm.
    power_col = f"{brain_region}_power_theta_frequencies_all_windows"

    try:
        # One CWT per recording: the frequencies come from the same call as the coefficients.
        cwt_results = LFP_TRACES_DF[col].apply(_cwt_decimated)
        LFP_TRACES_DF[power_col] = cwt_results.apply(lambda result: result[1])
        LFP_TRACES_DF[frequencies_col] = cwt_results.apply(lambda result: result[0])
    except Exception as e:
        print(e)

In [None]:
#Initialize
fs = 500

f0 = 30 #lowest frequency
# f1 = 51.5 #highest frequency
# fn = 51 #number of frequencies
f1 = 40 #highest frequency
fn = 10 #number of frequencies
nthreads = 8

In [None]:
freqs, out = fcwt.cwt(signal, fs, f0, f1, fn, nthreads=nthreads, norm=True)

In [None]:
freqs

In [None]:
def _cwt_decimated(trace):
    """Run fcwt once on one trace; return (frequencies, coefficients decimated in time).

    The trace is downsampled by 2 (``trace[::2]``) to match ``fs``; this assumes the
    traces are sampled at 2 * fs. ``fcwt.cwt`` returns coefficients shaped
    (fn, n_samples), so the time axis (axis 1) is the one decimated by 50. Slicing
    ``[::50]`` on axis 0 would keep only the first frequency row.
    """
    cwt_frequencies, coefficients = fcwt.cwt(trace[::2], fs, f0, f1, fn, nthreads=nthreads, norm=True)
    return cwt_frequencies, coefficients[:, ::50]


for col in RMS_columns:
    # e.g. "mPFC_lfp_RMS_interpolated" -> "mPFC_RMS_interpolated"
    brain_region = col.replace("_lfp", "")
    print(brain_region)

    # NOTE(review): 30-40 Hz is usually called low gamma rather than beta; the
    # column names are kept unchanged for downstream compatibility.
    frequencies_col = f"{brain_region}_power_beta30_40_frequencies"
    # NOTE(review): fcwt returns complex coefficients and they are stored as-is;
    # presumably |coef|**2 is taken downstream to get power — confirm.
    power_col = f"{brain_region}_power_beta30_40_frequencies_all_windows"

    try:
        # One CWT per recording: the frequencies come from the same call as the coefficients.
        cwt_results = LFP_TRACES_DF[col].apply(_cwt_decimated)
        LFP_TRACES_DF[power_col] = cwt_results.apply(lambda result: result[1])
        LFP_TRACES_DF[frequencies_col] = cwt_results.apply(lambda result: result[0])
    except Exception as e:
        print(e)

In [None]:
#Initialize
fs = 500

f0 = 40 #lowest frequency
# f1 = 51.5 #highest frequency
# fn = 51 #number of frequencies
f1 = 51 #highest frequency
fn = 11 #number of frequencies
nthreads = 8


In [None]:
freqs, out = fcwt.cwt(signal, fs, f0, f1, fn, nthreads=nthreads, norm=True)

In [None]:
freqs

In [None]:
def _cwt_decimated(trace):
    """Run fcwt once on one trace; return (frequencies, coefficients decimated in time).

    The trace is downsampled by 2 (``trace[::2]``) to match ``fs``; this assumes the
    traces are sampled at 2 * fs. ``fcwt.cwt`` returns coefficients shaped
    (fn, n_samples), so the time axis (axis 1) is the one decimated by 50. Slicing
    ``[::50]`` on axis 0 would keep only the first frequency row.
    """
    cwt_frequencies, coefficients = fcwt.cwt(trace[::2], fs, f0, f1, fn, nthreads=nthreads, norm=True)
    return cwt_frequencies, coefficients[:, ::50]


for col in RMS_columns:
    # e.g. "mPFC_lfp_RMS_interpolated" -> "mPFC_RMS_interpolated"
    brain_region = col.replace("_lfp", "")
    print(brain_region)

    # NOTE(review): 40-50 Hz is usually called gamma rather than beta; the column
    # names are kept unchanged for downstream compatibility.
    frequencies_col = f"{brain_region}_power_beta40_50_frequencies"
    # NOTE(review): fcwt returns complex coefficients and they are stored as-is;
    # presumably |coef|**2 is taken downstream to get power — confirm.
    power_col = f"{brain_region}_power_beta40_50_frequencies_all_windows"

    try:
        # One CWT per recording: the frequencies come from the same call as the coefficients.
        cwt_results = LFP_TRACES_DF[col].apply(_cwt_decimated)
        LFP_TRACES_DF[power_col] = cwt_results.apply(lambda result: result[1])
        LFP_TRACES_DF[frequencies_col] = cwt_results.apply(lambda result: result[0])
    except Exception as e:
        print(e)

In [None]:
LFP_TRACES_DF.to_pickle("./proc/rce2_spectral_CWT.pkl")

In [None]:
raise ValueError()

## Calculating phase of signals

In [None]:
from scipy.signal import butter, filtfilt, hilbert

- Filtering for theta and gamma

In [None]:
# Design the band-pass filters at the sampling rate of the traces they are applied to.
# The RMS-interpolated columns are not downsampled in this section (unlike the CWT cells,
# which downsample by 2 and use fs = 500), and the Multitaper cells below treat them as
# sampled at RESAMPLE_RATE. Designing at 500 Hz would put every pass-band at twice the
# requested frequencies (e.g. 8-24 Hz instead of 4-12 Hz).
fs = RESAMPLE_RATE
order = 4

In [None]:
freq_band = [4, 12]
b, a = butter(order, freq_band, fs=fs, btype='band')

In [None]:
for col in RMS_columns:
    print(col)
    brain_region = col.split("_")[0]
    updated_column = "{}_theta_band".format(brain_region)
    LFP_TRACES_DF[updated_column] = LFP_TRACES_DF[col].apply(lambda x: filtfilt(b, a, x, padtype=None).astype(np.float16))

In [None]:
LFP_TRACES_DF[updated_column]

In [None]:
freq_band = [30, 50]
b, a = butter(order, freq_band, fs=fs, btype='band')

In [None]:
for col in RMS_columns:
    print(col)
    brain_region = col.split("_")[0]
    updated_column = "{}_gamma_band".format(brain_region)
    LFP_TRACES_DF[updated_column] = LFP_TRACES_DF[col].apply(lambda x: filtfilt(b, a, x, padtype=None).astype(np.float16))

- Calculating the phase

In [None]:
band_columns = [col for col in LFP_TRACES_DF if "band" in col]

In [None]:
band_columns

In [None]:
for col in band_columns:
    print(col)
    brain_region = col.replace("_band", "")
    updated_column = "{}_phase".format(brain_region)
    print(updated_column)
    LFP_TRACES_DF[updated_column] = LFP_TRACES_DF[col].apply(lambda x: np.angle(hilbert(x), deg=False).astype(np.float16))

In [None]:
LFP_TRACES_DF.columns

# Getting all the pairs

In [None]:
input_columns = [col for col in LFP_TRACES_DF.columns if "interpolated" in col]

In [None]:
input_columns

In [None]:
all_suffixes = set(["_".join(col.split("_")[1:]) for col in input_columns])

In [None]:
all_suffixes

In [None]:
brain_region_pairs = generate_pairs(sorted(list(set([col.split("lfp")[0].strip("_") for col in sorted(input_columns)]))))

In [None]:
brain_region_pairs

In [None]:
for first_region, second_region in brain_region_pairs:
    for suffix in all_suffixes:
        region_1 = "_".join([first_region, suffix])
        region_2 = "_".join([second_region, suffix])
        print(region_1)
        print(region_2)

# Calculating spectral Granger prediction

In [None]:
for first_region, second_region in brain_region_pairs:
    for suffix in all_suffixes:
        suffix_for_name = suffix.replace("lfp", "").strip("_")
        region_1 = "_".join([first_region, suffix])
        region_2 = "_".join([second_region, suffix])
         # Define base name for pair
        pair_base_name = f"{region_1.split('_')[0]}_{region_2.split('_')[0]}_{suffix_for_name}"
        print(pair_base_name)

In [None]:
def _pairwise_granger(row, region_1, region_2):
    """Spectral Granger prediction between two signals of one recording (one DataFrame row).

    Returns ``(frequencies, granger_0_1, granger_1_0)``:
    - ``frequencies``: ``connectivity.frequencies[:62]``.
    - ``granger_0_1`` / ``granger_1_0``: the ``[:, :, 0, 1]`` / ``[:, :, 1, 0]`` entries of
      ``pairwise_spectral_granger_prediction()``, cast to float32 and truncated to the first
      62 frequencies. Signal 0 is ``row[region_1]`` and signal 1 is ``row[region_2]``.
    """
    multitaper = Multitaper(
        time_series=np.array([row[region_1], row[region_2]]).T,
        sampling_frequency=RESAMPLE_RATE,
        time_halfbandwidth_product=TIME_HALFBANDWIDTH_PRODUCT,
        time_window_step=TIME_WINDOW_STEP,
        time_window_duration=TIME_WINDOW_DURATION
    )
    connectivity = Connectivity.from_multitaper(multitaper)
    # The Granger decomposition is the expensive step: compute it once and take both
    # directions from the same result instead of recomputing it per direction.
    granger = connectivity.pairwise_spectral_granger_prediction()
    return (
        connectivity.frequencies[:62],
        granger[:, :, 0, 1].astype(np.float32)[:, :62],
        granger[:, :, 1, 0].astype(np.float32)[:, :62],
    )


for first_region, second_region in brain_region_pairs:
    for suffix in all_suffixes:
        region_1 = "_".join([first_region, suffix])
        region_2 = "_".join([second_region, suffix])
        region_1_base_name = region_1.split('_')[0]
        region_2_base_name = region_2.split('_')[0]
        pair_base_name = f"{region_1_base_name}_{region_2_base_name}"
        print(pair_base_name)

        # Output column names (unchanged from previous versions of this cell)
        frequencies_col = f"{pair_base_name}_granger_frequencies"
        granger_1_2_col = f"{region_1_base_name}_{region_2_base_name}_granger_all_frequencies_all_windows"
        granger_2_1_col = f"{region_2_base_name}_{region_1_base_name}_granger_all_frequencies_all_windows"

        try:
            # One Multitaper -> Connectivity -> Granger pass per recording; the
            # intermediate objects are no longer stored as temporary DataFrame columns.
            granger_results = LFP_TRACES_DF.apply(
                lambda row: _pairwise_granger(row, region_1, region_2),
                axis=1
            )
            LFP_TRACES_DF[frequencies_col] = granger_results.apply(lambda result: result[0])
            LFP_TRACES_DF[granger_1_2_col] = granger_results.apply(lambda result: result[1])
            LFP_TRACES_DF[granger_2_1_col] = granger_results.apply(lambda result: result[2])

        except Exception as e:
            print(e)
        


- Getting the timestamps of the Granger windows

In [None]:
LFP_TRACES_DF["granger_timestamps"] = LFP_TRACES_DF["lfp_timestamps"].apply(lambda x: x[(RESAMPLE_RATE//2):(-RESAMPLE_RATE//2):(RESAMPLE_RATE//2)])


- Making sure that the Granger timestamps make sense in shape and values

In [None]:
LFP_TRACES_DF["granger_timestamps"].head().apply(lambda x: x.shape)

In [None]:
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "granger_all_frequencies_all_windows" in col][0]].iloc[0].shape

In [None]:
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "lfp_timestamps" in col][0]].iloc[0]

In [None]:
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "granger_timestamps" in col][0]].iloc[0]

- Checking if the right frequencies are being used

In [None]:
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "granger_frequencies" in col]].head()

In [None]:
LFP_TRACES_DF["granger_calculation_frequencies"] = LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "granger_frequencies" in col][0]].copy()

- Dropping unnecessary columns

In [None]:
LFP_TRACES_DF = LFP_TRACES_DF.drop(columns=[col for col in LFP_TRACES_DF.columns if "granger_frequencies" in col], errors="ignore")

In [None]:
LFP_TRACES_DF.head()

In [None]:
LFP_TRACES_DF.to_pickle("./proc/rce2_spectral_granger.pkl")

## Coherence Calculation

- Calculating the coherence

In [None]:
for first_region, second_region in brain_region_pairs:
    for suffix in all_suffixes:
        suffix_for_name = suffix.replace("lfp", "").strip("_")
        region_1 = "_".join([first_region, suffix])
        region_2 = "_".join([second_region, suffix])
         # Define base name for pair
        pair_base_name = f"{region_1.split('_')[0]}_{region_2.split('_')[0]}_{suffix_for_name}"
        print(region_1)
        print(region_2)
        print(pair_base_name)

        try:
            # Define column names
            multitaper_col = f"{pair_base_name}_coherence_multitaper"
            connectivity_col = f"{pair_base_name}_coherence_connectivity"
            frequencies_col = f"{pair_base_name}_coherence_frequencies"
            coherence_col = f"{pair_base_name}_coherence_all_frequencies_all_windows"

            # Apply Multitaper function
            LFP_TRACES_DF[multitaper_col] = LFP_TRACES_DF.apply(
                lambda x: Multitaper(
                    time_series=np.array([x[region_1], x[region_2]]).T, 
                    sampling_frequency=RESAMPLE_RATE, 
                    time_halfbandwidth_product=TIME_HALFBANDWIDTH_PRODUCT, 
                    time_window_step=TIME_WINDOW_STEP, 
                    time_window_duration=TIME_WINDOW_DURATION
                ), 
                axis=1
            )

            # Apply Connectivity function
            LFP_TRACES_DF[connectivity_col] = LFP_TRACES_DF[multitaper_col].apply(
                lambda x: Connectivity.from_multitaper(x)
            )

            # Apply frequencies and coherence functions
            LFP_TRACES_DF[frequencies_col] = LFP_TRACES_DF[connectivity_col].apply(
                lambda x: x.frequencies[:62]
            )
            LFP_TRACES_DF[coherence_col] = LFP_TRACES_DF[connectivity_col].apply(
                lambda x: x.coherence_magnitude()[:,:,0,1]
            )

            LFP_TRACES_DF[coherence_col] = LFP_TRACES_DF[coherence_col].apply(lambda x: x[:,:62].astype(np.float32))

        except Exception as e: 
            print(e)

        # Drop temporary columns
        LFP_TRACES_DF = LFP_TRACES_DF.drop(columns=[multitaper_col, connectivity_col], errors="ignore")

- Getting the timestamps of the coherence

In [None]:
LFP_TRACES_DF["coherence_timestamps"] = LFP_TRACES_DF["lfp_timestamps"].apply(lambda x: x[(RESAMPLE_RATE//2):(-RESAMPLE_RATE//2):(RESAMPLE_RATE//2)])


- Making sure that the coherence timestamps make sense in shape and values

In [None]:
LFP_TRACES_DF["coherence_timestamps"].head().apply(lambda x: x.shape)

In [None]:
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "coherence_all_frequencies_all_windows" in col][0]].iloc[0].shape

In [None]:
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "lfp_timestamps" in col][0]].iloc[0]

In [None]:
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "coherence_timestamps" in col][0]].iloc[0]

- Checking if the right frequencies are being used

In [None]:
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "coherence_frequencies" in col]].head()

In [None]:
LFP_TRACES_DF["coherence_calculation_frequencies"] = LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "coherence_frequencies" in col][0]].copy()

- Dropping unnecessary columns

In [None]:
LFP_TRACES_DF = LFP_TRACES_DF.drop(columns=[col for col in LFP_TRACES_DF.columns if "coherence_frequencies" in col], errors="ignore")

In [None]:
LFP_TRACES_DF.head()

In [None]:
LFP_TRACES_DF.to_pickle("./proc/rce2_spectral_coherence.pkl")

In [None]:
raise ValueError()

# Power Calculation

- Getting the column name of all the traces

In [None]:
input_columns = [col for col in LFP_TRACES_DF.columns if "trace" in col or "RMS" in col]

In [None]:
input_columns

In [None]:
for col in input_columns:
    print(col)
    LFP_TRACES_DF[col] = LFP_TRACES_DF[col].apply(lambda x: x.astype(np.float16))

- Calculating the power at each frequency band

In [None]:
LFP_TRACES_DF[col].iloc[0]

In [None]:
for col in input_columns:
    # brain_region = col.split("_")[0]
    brain_region = col.replace("_lfp", "")
    print(brain_region)

    # Define column names
    multitaper_col = f"{brain_region}_power_multitaper"
    connectivity_col = f"{brain_region}_power_connectivity"
    frequencies_col = f"{brain_region}_power_frequencies"
    power_col = f"{brain_region}_power_all_frequencies_all_windows"
    
    try:
        # Apply Multitaper function to the lfp_trace column
        LFP_TRACES_DF[multitaper_col] = LFP_TRACES_DF[col].apply(
            lambda x: Multitaper(
                time_series=x, 
                sampling_frequency=RESAMPLE_RATE, 
                time_halfbandwidth_product=TIME_HALFBANDWIDTH_PRODUCT,
                time_window_duration=TIME_WINDOW_DURATION, 
                time_window_step=TIME_WINDOW_STEP
            )
        )

        # Apply Connectivity function to the multitaper column
        LFP_TRACES_DF[connectivity_col] = LFP_TRACES_DF[multitaper_col].apply(
            lambda x: Connectivity.from_multitaper(x)
        )

        # Apply frequencies and power functions to the connectivity column
        LFP_TRACES_DF[frequencies_col] = LFP_TRACES_DF[connectivity_col].apply(
            lambda x: x.frequencies
        )
        LFP_TRACES_DF[power_col] = LFP_TRACES_DF[connectivity_col].apply(
            lambda x: x.power().squeeze()
        )
        
        LFP_TRACES_DF[power_col] = LFP_TRACES_DF[power_col].apply(lambda x: x.astype(np.float16))
            
        # Removing unnecessary columns
        LFP_TRACES_DF = LFP_TRACES_DF.drop(columns=[multitaper_col, connectivity_col], errors="ignore")
    
    except Exception as e: 
        print(e)

- Getting the timestamps of the power

In [None]:
LFP_TRACES_DF["power_timestamps"] = LFP_TRACES_DF["lfp_timestamps"].apply(lambda x: x[(RESAMPLE_RATE//2):(-RESAMPLE_RATE//2):(RESAMPLE_RATE//2)])
# .iloc[0][500:-500:500].shape

- Making sure that the power timestamps make sense in shape and values

In [None]:
LFP_TRACES_DF["power_timestamps"].head().apply(lambda x: x.shape)

In [None]:
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "power_all_frequencies_all_windows" in col][0]].iloc[0].shape

In [None]:
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "lfp_timestamps" in col][0]].iloc[0]

In [None]:
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "power_timestamps" in col][0]].iloc[0]

- Checking if the right frequencies are being used

In [None]:
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "power_frequencies" in col]].head()

In [None]:
LFP_TRACES_DF["power_calculation_frequencies"] = LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "power_frequencies" in col][0]].copy()

- Dropping unnecessary columns

In [None]:
LFP_TRACES_DF = LFP_TRACES_DF.drop(columns=[col for col in LFP_TRACES_DF.columns if "power_frequencies" in col], errors="ignore")

In [None]:
LFP_TRACES_DF.head()

In [None]:
LFP_TRACES_DF["mPFC_RMS_filtered_power_all_frequencies_all_windows"].head()

In [None]:
LFP_TRACES_DF["mPFC_RMS_filtered_power_all_frequencies_all_windows"].iloc[4].shape

In [None]:
LFP_TRACES_DF["mPFC_lfp_RMS_filtered"].head()

In [None]:
plt.plot(LFP_TRACES_DF["BLA_lfp_trace"].iloc[0])

In [None]:
plt.plot(LFP_TRACES_DF["BLA_lfp_RMS"].iloc[0])

In [None]:
plt.plot(LFP_TRACES_DF["BLA_lfp_RMS"].iloc[0])

In [None]:
LFP_TRACES_DF["BLA_"].iloc[0]

In [None]:
raise ValueError()

In [None]:
plt.plot(LFP_TRACES_DF["BLA_lfp_RMS_filtered"].iloc[0])

In [None]:
LFP_TRACES_DF["BLA_trace_power_all_frequencies_all_windows"].apply(lambda x: np.sum(np.isnan(x[:,3:13])))

In [None]:
LFP_TRACES_DF["BLA_RMS_filtered_power_all_frequencies_all_windows"].apply(lambda x: np.sum(np.isnan(x[:,3:13])))

In [None]:
print("hello")

In [None]:
# LFP_TRACES_DF.to_pickle("./proc/rce2_spectral_granger.pkl")
LFP_TRACES_DF.to_pickle("./proc/rce_pilot_2_02_full_spectral.pkl")
# LFP_TRACES_DF.to_pickle("/blue/npadillacoreano/ryoi360/projects/reward_comp/final_proc/rce_pilot_2_02_spectral_granger.pkl")

## Calculating the averages

In [None]:
LFP_TRACES_DF.columns

In [None]:
# LFP_TRACES_DF.to_pickle("./proc/rce2_spectral_granger.pkl")
LFP_TRACES_DF.to_pickle("./proc/rce_pilot_2_02_full_spectral.pkl")
# LFP_TRACES_DF.to_pickle("/blue/npadillacoreano/ryoi360/projects/reward_comp/final_proc/rce_pilot_2_02_spectral_granger.pkl")

In [None]:
raise ValueError()