# All oscillation analysis

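This notebook computes oscillation measures for every row of the LFP traces table. It produces artifact-masked (modified z-score) and RMS-normalized traces, multitaper power, theta/gamma phase, and pairwise coherence and spectral Granger prediction between brain regions. Results are saved to `./proc/<OUTPUT_PREFIX>_02_full_spectral.pkl`.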

In [None]:
import warnings
# Silence every warning for the whole notebook.
# NOTE(review): this also hides numpy RuntimeWarnings such as divide-by-zero in the
# modified z-score step (a trace whose MAD is 0) — consider narrowing this filter.
warnings.filterwarnings('ignore')

In [None]:
import os
import collections
import itertools
from collections import defaultdict
from itertools import combinations

In [None]:
# os.environ["SPECTRAL_CONNECTIVITY_ENABLE_GPU"] = "true"
# import cupy as cp

In [None]:
# Imports of all used packages and libraries
import numpy as np
import pandas as pd
from scipy import stats
from scipy.stats import mannwhitneyu


In [None]:
import matplotlib
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
import colorsys

In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
from spectral_connectivity import Multitaper, Connectivity
import spectral_connectivity

In [None]:
# Font size for figure text (presumably meant to drive the matplotlib font settings).
FONTSIZE = 20

In [None]:
# Global matplotlib font settings. The size comes from FONTSIZE so it is configured in one place.
font = {'weight': 'medium',
        'size': FONTSIZE}

matplotlib.rc('font', **font)

# Functions

In [None]:
def generate_pairs(lst):
    """
    Builds every unordered pair of distinct positions in a list.

    Pairs keep the input order: element ``i`` is paired only with later elements
    ``j > i``, so ``[a, b, c]`` yields ``[(a, b), (a, c), (b, c)]``.

    Parameters:
    - lst (list): The items to pair up.

    Returns:
    - list: Tuples ``(lst[i], lst[j])`` for all ``i < j``, in that order.
    """
    pairs = []
    for position, first in enumerate(lst):
        for second in lst[position + 1:]:
            pairs.append((first, second))
    return pairs

In [None]:
def update_array_by_mask(array, mask, value=np.nan):
    """
    Returns a copy of an array with the masked elements replaced by a value.

    Parameters:
    - array (array-like): The input values. It is not modified.
    - mask (np.ndarray of bool): Boolean index selecting the elements to replace
      (same shape as ``array`` for element-wise masking).
    - value (scalar, optional): The replacement value. Defaults to ``np.nan``.

    Returns:
    - np.ndarray: A copy of ``array`` with ``array[mask]`` set to ``value``. Integer and
      boolean inputs are promoted to float64 when ``value`` is NaN, because NaN cannot be
      stored in those dtypes (plain assignment would raise).
    """
    # asanyarray keeps ndarray subclasses, and also lets plain lists be masked.
    result = np.asanyarray(array).copy()
    value_is_nan = isinstance(value, (float, np.floating)) and np.isnan(value)
    if value_is_nan and result.dtype.kind in "biu":
        result = result.astype(np.float64)
    result[mask] = value
    return result

## Inputs & Data

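- `OUTPUT_DIR`: directory where processed data is saved.
- `TIME_HALFBANDWIDTH_PRODUCT`, `TIME_WINDOW_DURATION`, `TIME_WINDOW_STEP`, `RESAMPLE_RATE`: multitaper settings used for power, coherence, and Granger.
- `zscore_threshold`: samples with |modified z-score| at or above this value are set to NaN.
- `VOLTAGE_SCALING_VALUE`: multiplier applied to the raw traces.
- `BAND_TO_FREQ`: frequency ranges of the theta and gamma bands.
- `OUTPUT_PREFIX`: filename prefix. The input table is `./proc/<OUTPUT_PREFIX>_01_lfp_traces_and_frames.pkl`, presumably produced by the `01` step of this pipeline.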

In [None]:
# Inputs and required data loading.
# Input variable names use UPPER_SNAKE_CASE; once an input is changed or used for
# processing, the derived variables use lower_snake_case.
OUTPUT_DIR = "./proc"  # where data is saved should always be shown in the inputs

In [None]:
# Multitaper settings shared by the power, coherence and Granger calculations.
TIME_HALFBANDWIDTH_PRODUCT = 2
TIME_WINDOW_DURATION = 1   # window length, presumably in seconds (Multitaper time units) — confirm
TIME_WINDOW_STEP = 0.5     # step between consecutive windows, same units as the duration
RESAMPLE_RATE = 1000       # passed to Multitaper as sampling_frequency

In [None]:
# Samples whose |modified z-score| is at or above this value are replaced with NaN.
zscore_threshold = 4
# Multiplier applied to the raw traces after casting them to float32
# (presumably microvolts per ADC bit — confirm against the recording hardware).
VOLTAGE_SCALING_VALUE = 0.195

In [None]:
# Frequency ranges (presumably Hz) of the oscillation bands of interest.
# NOTE(review): gamma ends at 51 here, but the gamma Butterworth filter further down uses
# [30, 50] — confirm which upper edge is intended.
BAND_TO_FREQ = dict(theta=(4, 12), gamma=(30, 51))

In [None]:
# Filename prefix shared by this notebook's input pickle and its output pickle.
OUTPUT_PREFIX = "rce_pilot_3_comp_omission"

In [None]:
# Load the traces saved by the earlier "_01_" step (judging by the filename) from OUTPUT_DIR.
# pickle can execute code on load, so only read files this pipeline produced.
LFP_TRACES_DF = pd.read_pickle(os.path.join(OUTPUT_DIR, "{}_01_lfp_traces_and_frames.pkl".format(OUTPUT_PREFIX)))

In [None]:
# (rows, columns) of the loaded table.
LFP_TRACES_DF.shape

## Preprocessing

In [None]:
# Every column whose name contains "trace" (e.g. "BLA_lfp_trace", plotted further down).
original_trace_columns = list(filter(lambda name: "trace" in name, LFP_TRACES_DF.columns))

In [None]:
# The trace columns found above.
original_trace_columns

In [None]:
# Cast every trace to float32 and multiply it by VOLTAGE_SCALING_VALUE.
# NOTE: the columns are overwritten in place, so re-running this cell scales the traces a
# second time — reload the input pickle before re-running.
for col in original_trace_columns:
    print(col)
    LFP_TRACES_DF[col] = LFP_TRACES_DF[col].map(
        lambda trace: trace.astype(np.float32) * VOLTAGE_SCALING_VALUE
    )

In [None]:
# Preview the table after scaling the traces.
LFP_TRACES_DF.head()

# Calculating the modified z-score

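$M_i = \dfrac{0.6745\,(x_i - \tilde{x})}{\mathrm{MAD}}$, where $\tilde{x}$ is the median of the trace and MAD is its median absolute deviation.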

In [None]:
# Median absolute deviation of each trace (scipy's default scale=1.0, i.e. the raw MAD),
# stored as "<region>_lfp_MAD".
for col in original_trace_columns:
    print(col)
    brain_region = col.split("_")[0]
    updated_column = f"{brain_region}_lfp_MAD"
    LFP_TRACES_DF[updated_column] = LFP_TRACES_DF[col].map(stats.median_abs_deviation)

In [None]:
for col in original_trace_columns:
    print(col)
    brain_region = col.split("_")[0]
    updated_column = f"{brain_region}_lfp_modified_zscore"
    MAD_column = f"{brain_region}_lfp_MAD"

    # Modified z-score per sample: 0.6745 * (x - median(x)) / MAD. This is a row-wise apply
    # because the trace and its MAD live in different columns of the same row.
    # NOTE(review): a MAD of 0 yields inf/NaN here (warnings are silenced at the top).
    LFP_TRACES_DF[updated_column] = LFP_TRACES_DF.apply(
        lambda row: 0.6745 * (row[col] - np.median(row[col])) / row[MAD_column],
        axis=1,
    )

In [None]:
# Modified z-score column of the last region processed by the loop above.
LFP_TRACES_DF[updated_column]

## Calculating root mean square (RMS) normalization

In [None]:
for col in original_trace_columns:
    print(col)
    brain_region = col.split("_")[0]
    updated_column = f"{brain_region}_lfp_RMS"
    # Despite the column name, this stores the trace *divided by* its root mean square
    # (unit-RMS normalization), as float32.
    LFP_TRACES_DF[updated_column] = LFP_TRACES_DF[col].map(
        lambda trace: (trace / np.sqrt(np.mean(trace ** 2))).astype(np.float32)
    )


In [None]:
# Preview the table with the new MAD, modified z-score and RMS columns.
LFP_TRACES_DF.head()

## Filtering by modified z-score

In [None]:
# Columns holding modified z-scores ("<region>_lfp_modified_zscore").
zscore_columns = list(filter(lambda name: "zscore" in name, LFP_TRACES_DF.columns))

In [None]:
# The modified z-score columns found above.
zscore_columns

In [None]:
# Boolean mask per trace: True where |modified z-score| >= zscore_threshold.
# NaN z-scores compare False, so they are not masked.
for col in zscore_columns:
    print(col)
    brain_region = col.split("_")[0]
    updated_column = f"{brain_region}_lfp_mask"
    LFP_TRACES_DF[updated_column] = LFP_TRACES_DF[col].map(
        lambda zscores: np.abs(zscores) >= zscore_threshold
    )

In [None]:
# Preview the mask column of the last region processed above.
LFP_TRACES_DF[updated_column].head()

In [None]:
# Shape of the first row's mask (one boolean per sample of its z-score array).
LFP_TRACES_DF[updated_column].iloc[0].shape

In [None]:
# Number of samples flagged in the first row's mask.
LFP_TRACES_DF[updated_column].iloc[0].sum()

- Filtering raw traces by zscore

In [None]:
# `col` still holds the last modified z-score column from the masking loop above.
LFP_TRACES_DF[col].head()

In [None]:
# Replace flagged samples with NaN, producing "<region>_lfp_trace_filtered" columns.
for col in original_trace_columns:
    print(col)
    brain_region = col.split("_")[0]
    updated_column = f"{brain_region}_lfp_trace_filtered"
    mask_column = f"{brain_region}_lfp_mask"
    LFP_TRACES_DF[updated_column] = LFP_TRACES_DF.apply(
        lambda row: update_array_by_mask(row[col], row[mask_column]),
        axis=1,
    )

In [None]:
# Preview the NaN-masked trace just created (the last "<region>_lfp_trace_filtered" column).
# `col` would show the unmasked source trace instead.
LFP_TRACES_DF[updated_column].head()

In [None]:
# Number of samples set to NaN in the first row of the last filtered trace column. This
# matches the mask count unless the source trace already contained NaNs. Checking `col`
# (the unmasked source trace) would always report 0.
np.count_nonzero(np.isnan(LFP_TRACES_DF[updated_column].iloc[0]))

- RMS-normalizing the filtered signal

In [None]:
# The NaN-masked trace columns ("<region>_lfp_trace_filtered").
filtered_trace_column = list(filter(lambda name: "lfp_trace_filtered" in name, LFP_TRACES_DF.columns))

In [None]:
for col in filtered_trace_column:
    print(col)
    brain_region = col.split("_")[0]
    updated_column = f"{brain_region}_lfp_RMS_filtered"
    # Unit-RMS normalization of the masked trace. nanmean ignores the NaN samples when
    # computing the RMS; the NaNs themselves stay NaN in the output.
    LFP_TRACES_DF[updated_column] = LFP_TRACES_DF[col].map(
        lambda trace: (trace / np.sqrt(np.nanmean(trace ** 2))).astype(np.float32)
    )

# Power Calculation

- Getting the column names of the filtered, RMS-normalized traces

In [None]:
# Filtered, RMS-normalized traces ("<region>_lfp_RMS_filtered"). These are the inputs for
# the power, coherence and Granger calculations below.
input_columns = [name for name in LFP_TRACES_DF.columns if "filtered" in name and "RMS" in name]

In [None]:
# Trace columns used for power, coherence and Granger.
input_columns

In [None]:
# Round the normalized traces to 3 decimals (as float32). NaN samples stay NaN.
for col in input_columns:
    print(col)
    LFP_TRACES_DF[col] = LFP_TRACES_DF[col].map(
        lambda trace: np.round(trace.astype(np.float32), decimals=3)
    )

- Calculating the power at each frequency for each time window

In [None]:
# First row of the last rounded input column.
LFP_TRACES_DF[col].iloc[0]

In [None]:
def _multitaper_power(trace):
    """
    Multitaper power for a single trace.

    Returns (frequencies, power), where power is ``Connectivity.power()`` with its
    singleton dimensions squeezed out and cast to float32.
    """
    connectivity = Connectivity.from_multitaper(
        Multitaper(
            time_series=trace,
            sampling_frequency=RESAMPLE_RATE,
            time_halfbandwidth_product=TIME_HALFBANDWIDTH_PRODUCT,
            time_window_duration=TIME_WINDOW_DURATION,
            time_window_step=TIME_WINDOW_STEP,
        )
    )
    return connectivity.frequencies, connectivity.power().squeeze().astype(np.float32)


for col in input_columns:
    # e.g. "mPFC_lfp_RMS_filtered" -> "mPFC_RMS_filtered"
    brain_region = col.replace("_lfp", "")
    print(brain_region)

    frequencies_col = f"{brain_region}_power_frequencies"
    power_col = f"{brain_region}_power_all_frequencies_all_windows"

    # Compute every row before touching the DataFrame. A failure then leaves no partial
    # columns, and the Multitaper/Connectivity objects are never stored in the table.
    try:
        results = LFP_TRACES_DF[col].apply(_multitaper_power)
    except Exception as e:
        print(f"Power calculation failed for {col}: {e}")
        continue

    LFP_TRACES_DF[frequencies_col] = results.apply(lambda result: result[0])
    LFP_TRACES_DF[power_col] = results.apply(lambda result: result[1])

- Getting the timestamps of the power

In [None]:
# One timestamp per power window. The slice starts RESAMPLE_RATE // 2 samples in, steps by
# RESAMPLE_RATE // 2 samples, and stops RESAMPLE_RATE // 2 samples before the end. This is
# hardcoded to match TIME_WINDOW_DURATION = 1 and TIME_WINDOW_STEP = 0.5 (window centres),
# assuming lfp_timestamps are sampled at RESAMPLE_RATE.
# NOTE(review): when a trace length is an exact multiple of the step, this slice may give
# one timestamp fewer than the number of power windows — the shape checks below should
# catch it.
LFP_TRACES_DF["power_timestamps"] = LFP_TRACES_DF["lfp_timestamps"].map(
    lambda timestamps: timestamps[RESAMPLE_RATE // 2:-RESAMPLE_RATE // 2:RESAMPLE_RATE // 2]
)

- Making sure that the power timestamps make sense in shape and values

In [None]:
# Shapes of the first few power-timestamp arrays.
LFP_TRACES_DF["power_timestamps"].head().apply(lambda x: x.shape)

In [None]:
# Shape of the first row's power array for the first region (compare with the timestamps above).
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "power_all_frequencies_all_windows" in col][0]].iloc[0].shape

In [None]:
# LFP timestamps of the first row.
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "lfp_timestamps" in col][0]].iloc[0]

In [None]:
# Power timestamps of the first row (every RESAMPLE_RATE // 2-th LFP timestamp).
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "power_timestamps" in col][0]].iloc[0]

- Checking if the right frequencies are being used

In [None]:
# Per-region frequency grids from the power step (expected to be identical across regions).
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "power_frequencies" in col]].head()

In [None]:
# Keep one shared frequency grid for power. Every region uses the same Multitaper settings,
# so the per-region grids should be identical; verify that before the copies are dropped.
power_frequency_columns = [name for name in LFP_TRACES_DF.columns if "power_frequencies" in name]
for other_column in power_frequency_columns[1:]:
    assert all(
        np.array_equal(reference, candidate)
        for reference, candidate in zip(LFP_TRACES_DF[power_frequency_columns[0]], LFP_TRACES_DF[other_column])
    ), f"{other_column} does not match {power_frequency_columns[0]}"
LFP_TRACES_DF["power_calculation_frequencies"] = LFP_TRACES_DF[power_frequency_columns[0]].copy()

- Dropping unnecessary columns

In [None]:
# The per-region frequency columns are redundant once "power_calculation_frequencies" exists.
LFP_TRACES_DF = LFP_TRACES_DF.drop(
    columns=[name for name in LFP_TRACES_DF.columns if "power_frequencies" in name],
    errors="ignore",
)

In [None]:
# Preview the table after the power step.
LFP_TRACES_DF.head()

In [None]:
# mPFC power arrays for the first rows.
LFP_TRACES_DF["mPFC_RMS_filtered_power_all_frequencies_all_windows"].head()

In [None]:
# Shape of the mPFC power array in the row at position 4.
LFP_TRACES_DF["mPFC_RMS_filtered_power_all_frequencies_all_windows"].iloc[4].shape

In [None]:
# The filtered, RMS-normalized mPFC traces that feed the power calculation.
LFP_TRACES_DF["mPFC_lfp_RMS_filtered"].head()

In [None]:
# Scaled BLA trace of the first row.
fig, ax = plt.subplots()
ax.plot(LFP_TRACES_DF["BLA_lfp_trace"].iloc[0])
ax.set(title="BLA LFP trace (row 0, scaled)", xlabel="Sample", ylabel="Scaled amplitude");

In [None]:
# RMS-normalized BLA trace of the first row.
fig, ax = plt.subplots()
ax.plot(LFP_TRACES_DF["BLA_lfp_RMS"].iloc[0])
ax.set(title="BLA LFP, RMS-normalized (row 0)", xlabel="Sample", ylabel="Amplitude / RMS");

In [None]:
# Masked, RMS-normalized BLA trace of the first row (masked samples are NaN, shown as gaps).
fig, ax = plt.subplots()
ax.plot(LFP_TRACES_DF["BLA_lfp_RMS_filtered"].iloc[0])
ax.set(title="BLA LFP, z-score masked + RMS-normalized (row 0)", xlabel="Sample", ylabel="Amplitude / RMS");

## Calculating phase of signals

In [None]:
from scipy.signal import butter, filtfilt, hilbert

- Filtering for theta and gamma

In [None]:
# Unmasked RMS-normalized traces ("<region>_lfp_RMS"), not the NaN-masked *_filtered ones.
# This is presumably because NaNs would spread through the IIR filtering below.
RMS_columns = [name for name in LFP_TRACES_DF.columns if "RMS" in name and "filtered" not in name]

In [None]:
# Band-pass filter settings. The sampling rate must match the one used elsewhere in the
# notebook, so reuse RESAMPLE_RATE instead of repeating the number.
fs = RESAMPLE_RATE
order = 4

In [None]:
# Theta Butterworth band-pass. The band edges come from BAND_TO_FREQ so the filter and the
# band definition cannot drift apart.
freq_band = list(BAND_TO_FREQ["theta"])
b, a = butter(order, freq_band, fs=fs, btype='band')

In [None]:
# Zero-phase (forward-backward) theta band-pass of every RMS-normalized trace, as float32.
# padtype=None disables edge padding, so the first and last samples may carry transients.
for col in RMS_columns:
    print(col)
    brain_region = col.split("_")[0]
    updated_column = f"{brain_region}_theta_band"
    LFP_TRACES_DF[updated_column] = LFP_TRACES_DF[col].map(
        lambda trace: filtfilt(b, a, trace, padtype=None).astype(np.float32)
    )

In [None]:
# Gamma Butterworth band-pass.
# NOTE(review): BAND_TO_FREQ["gamma"] is (30, 51) — confirm whether 50 or 51 is intended.
freq_band = [30, 50]
b, a = butter(N=order, Wn=freq_band, btype="band", fs=fs)

In [None]:
# Zero-phase (forward-backward) gamma band-pass of every RMS-normalized trace, as float32.
# padtype=None disables edge padding, so the first and last samples may carry transients.
for col in RMS_columns:
    print(col)
    brain_region = col.split("_")[0]
    updated_column = f"{brain_region}_gamma_band"
    LFP_TRACES_DF[updated_column] = LFP_TRACES_DF[col].map(
        lambda trace: filtfilt(b, a, trace, padtype=None).astype(np.float32)
    )

- Calculating the phase

In [None]:
# Every column whose name contains "band" (here, the theta/gamma band-passed traces above).
band_columns = list(filter(lambda name: "band" in name, LFP_TRACES_DF.columns))

In [None]:
# Band-passed columns that will each get a phase column.
band_columns

In [None]:
# Instantaneous phase (radians) of each band-passed trace, taken from its analytic signal.
# Here `brain_region` holds "<region>_<band>", e.g. "BLA_theta" -> "BLA_theta_phase".
for col in band_columns:
    print(col)
    brain_region = col.replace("_band", "")
    updated_column = f"{brain_region}_phase"
    print(updated_column)
    LFP_TRACES_DF[updated_column] = LFP_TRACES_DF[col].map(
        lambda trace: np.angle(hilbert(trace), deg=False)
    )

In [None]:
# The last band-passed column processed by the phase loop.
LFP_TRACES_DF[col]

## Coherence Calculation

- Getting the trace column pairs

In [None]:
# All unordered pairs of the filtered, RMS-normalized trace columns, in sorted order.
brain_region_pairs = sorted(generate_pairs(sorted(input_columns)))


In [None]:
# (region_1, region_2) trace-column pairs used for coherence and Granger.
brain_region_pairs

## Coherence Calculation

- Calculating the coherence

In [None]:
def _pairwise_coherence(trace_1, trace_2):
    """
    Multitaper coherence magnitude between two traces of one row.

    Returns (frequencies, coherence), where coherence is entry [..., 0, 1] of
    ``Connectivity.coherence_magnitude()`` (signal 0 = trace_1, signal 1 = trace_2) as float32.
    """
    connectivity = Connectivity.from_multitaper(
        Multitaper(
            time_series=np.array([trace_1, trace_2]).T,
            sampling_frequency=RESAMPLE_RATE,
            time_halfbandwidth_product=TIME_HALFBANDWIDTH_PRODUCT,
            time_window_step=TIME_WINDOW_STEP,
            time_window_duration=TIME_WINDOW_DURATION,
        )
    )
    return connectivity.frequencies, connectivity.coherence_magnitude()[:, :, 0, 1].astype(np.float32)


for region_1, region_2 in brain_region_pairs:
    # e.g. ("BLA_lfp_RMS_filtered", "mPFC_lfp_RMS_filtered") -> "BLA_mPFC"
    pair_base_name = f"{region_1.split('_')[0]}_{region_2.split('_')[0]}"
    print(pair_base_name)

    frequencies_col = f"{pair_base_name}_coherence_frequencies"
    coherence_col = f"{pair_base_name}_coherence_all_frequencies_all_windows"

    # Only the per-row results are kept; the Multitaper/Connectivity objects are released
    # after each row instead of being held for the whole table.
    index_to_frequencies = {}
    index_to_coherence = {}
    try:
        for index, trace_1, trace_2 in zip(LFP_TRACES_DF.index, LFP_TRACES_DF[region_1], LFP_TRACES_DF[region_2]):
            index_to_frequencies[index], index_to_coherence[index] = _pairwise_coherence(trace_1, trace_2)
    except Exception as e:
        print(f"Coherence calculation failed for {pair_base_name}: {e}")
        continue

    row_index = LFP_TRACES_DF.index.to_series()
    LFP_TRACES_DF[frequencies_col] = row_index.map(index_to_frequencies)
    LFP_TRACES_DF[coherence_col] = row_index.map(index_to_coherence)

In [None]:
# for region_1, region_2 in brain_region_pairs:
#     # Define base name for pair
#     pair_base_name = f"{region_1.split('_')[0]}_{region_2.split('_')[0]}"
#     print(pair_base_name)

#     try:
#         # Define column names
#         multitaper_col = f"{pair_base_name}_coherence_multitaper"
#         connectivity_col = f"{pair_base_name}_coherence_connectivity"
#         frequencies_col = f"{pair_base_name}_coherence_frequencies"
#         coherence_col = f"{pair_base_name}_coherence_all_frequencies_all_windows"

#         # Apply Multitaper function
#         LFP_TRACES_DF[multitaper_col] = LFP_TRACES_DF.apply(
#             lambda x: Multitaper(
#                 time_series=np.array([x[region_1], x[region_2]]).T, 
#                 sampling_frequency=RESAMPLE_RATE, 
#                 time_halfbandwidth_product=TIME_HALFBANDWIDTH_PRODUCT, 
#                 time_window_step=TIME_WINDOW_STEP, 
#                 time_window_duration=TIME_WINDOW_DURATION
#             ), 
#             axis=1
#         )

#         # Apply Connectivity function
#         LFP_TRACES_DF[connectivity_col] = LFP_TRACES_DF[multitaper_col].apply(
#             lambda x: Connectivity.from_multitaper(x)
#         )

#         # Apply frequencies and coherence functions
#         LFP_TRACES_DF[frequencies_col] = LFP_TRACES_DF[connectivity_col].apply(
#             lambda x: x.frequencies
#         )
#         LFP_TRACES_DF[coherence_col] = LFP_TRACES_DF[connectivity_col].apply(
#             lambda x: x.coherence_magnitude()[:,:,0,1]
#         )

#         LFP_TRACES_DF[coherence_col] = LFP_TRACES_DF[coherence_col].apply(lambda x: x.astype(np.float32))

#     except Exception as e: 
#         print(e)

#     # Drop temporary columns
#     LFP_TRACES_DF = LFP_TRACES_DF.drop(columns=[multitaper_col, connectivity_col], errors="ignore")

In [None]:
# Preview the table after the coherence step.
LFP_TRACES_DF.head()

- Getting the timestamps of the coherence

In [None]:
# Coherence uses the same Multitaper window settings as power, so the timestamps come from
# the same slice. It starts RESAMPLE_RATE // 2 samples in, steps by RESAMPLE_RATE // 2, and
# stops RESAMPLE_RATE // 2 samples before the end.
# NOTE(review): check the window count against the shape checks below (possible off-by-one
# when a trace length is an exact multiple of the step).
LFP_TRACES_DF["coherence_timestamps"] = LFP_TRACES_DF["lfp_timestamps"].map(
    lambda timestamps: timestamps[RESAMPLE_RATE // 2:-RESAMPLE_RATE // 2:RESAMPLE_RATE // 2]
)


- Making sure that the coherence timestamps make sense in shape and values

In [None]:
# Shapes of the first few coherence-timestamp arrays.
LFP_TRACES_DF["coherence_timestamps"].head().apply(lambda x: x.shape)

In [None]:
# Shape of the first row's coherence array for the first region pair.
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "coherence_all_frequencies_all_windows" in col][0]].iloc[0].shape

In [None]:
# LFP timestamps of the first row.
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "lfp_timestamps" in col][0]].iloc[0]

In [None]:
# Coherence timestamps of the first row.
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "coherence_timestamps" in col][0]].iloc[0]

- Checking if the right frequencies are being used

In [None]:
# Per-pair coherence frequency grids (expected to be identical across pairs).
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "coherence_frequencies" in col]].head()

In [None]:
# Keep one shared frequency grid for coherence. Every pair uses the same Multitaper settings,
# so the per-pair grids should be identical; verify that before the copies are dropped.
coherence_frequency_columns = [name for name in LFP_TRACES_DF.columns if "coherence_frequencies" in name]
for other_column in coherence_frequency_columns[1:]:
    assert all(
        np.array_equal(reference, candidate)
        for reference, candidate in zip(LFP_TRACES_DF[coherence_frequency_columns[0]], LFP_TRACES_DF[other_column])
    ), f"{other_column} does not match {coherence_frequency_columns[0]}"
LFP_TRACES_DF["coherence_calculation_frequencies"] = LFP_TRACES_DF[coherence_frequency_columns[0]].copy()

- Dropping unnecessary columns

In [None]:
# The per-pair frequency columns are redundant once "coherence_calculation_frequencies" exists.
LFP_TRACES_DF = LFP_TRACES_DF.drop(
    columns=[name for name in LFP_TRACES_DF.columns if "coherence_frequencies" in name],
    errors="ignore",
)

In [None]:
# Preview the table after collapsing the coherence frequency columns.
LFP_TRACES_DF.head()

In [None]:
# Checkpoint: save the power and coherence results before the Granger step.
# The same file is written again after the Granger step below.
LFP_TRACES_DF.to_pickle(os.path.join(OUTPUT_DIR, "{}_02_full_spectral.pkl".format(OUTPUT_PREFIX)))

# Calculating spectral Granger prediction

In [None]:
def _pairwise_granger(trace_1, trace_2):
    """
    Pairwise spectral Granger prediction for one row's pair of traces.

    Returns (frequencies, granger[..., 0, 1], granger[..., 1, 0]), with both Granger arrays
    cast to float32. Signal 0 is trace_1 and signal 1 is trace_2.
    """
    connectivity = Connectivity.from_multitaper(
        Multitaper(
            time_series=np.array([trace_1, trace_2]).T,
            sampling_frequency=RESAMPLE_RATE,
            time_halfbandwidth_product=TIME_HALFBANDWIDTH_PRODUCT,
            time_window_step=TIME_WINDOW_STEP,
            time_window_duration=TIME_WINDOW_DURATION,
        )
    )
    # Compute the Granger array once per row and slice both directions out of it.
    granger = connectivity.pairwise_spectral_granger_prediction()
    return (
        connectivity.frequencies,
        granger[:, :, 0, 1].astype(np.float32),
        granger[:, :, 1, 0].astype(np.float32),
    )


for region_1, region_2 in brain_region_pairs:
    region_1_base_name = region_1.split('_')[0]
    region_2_base_name = region_2.split('_')[0]

    pair_base_name = f"{region_1_base_name}_{region_2_base_name}"
    print(pair_base_name)

    frequencies_col = f"{pair_base_name}_granger_frequencies"
    # NOTE(review): [..., 0, 1] is stored as region_1 -> region_2 and [..., 1, 0] as
    # region_2 -> region_1; confirm this matches spectral_connectivity's index convention.
    granger_1_2_col = f"{region_1_base_name}_{region_2_base_name}_granger_all_frequencies_all_windows"
    granger_2_1_col = f"{region_2_base_name}_{region_1_base_name}_granger_all_frequencies_all_windows"

    # Compute every row before touching the DataFrame. A failure then leaves no partial
    # columns, and the Multitaper/Connectivity objects are never stored in the table.
    index_to_frequencies = {}
    index_to_granger_1_2 = {}
    index_to_granger_2_1 = {}
    try:
        for index, trace_1, trace_2 in zip(LFP_TRACES_DF.index, LFP_TRACES_DF[region_1], LFP_TRACES_DF[region_2]):
            (index_to_frequencies[index],
             index_to_granger_1_2[index],
             index_to_granger_2_1[index]) = _pairwise_granger(trace_1, trace_2)
    except Exception as e:
        print(f"Granger calculation failed for {pair_base_name}: {e}")
        continue

    row_index = LFP_TRACES_DF.index.to_series()
    LFP_TRACES_DF[frequencies_col] = row_index.map(index_to_frequencies)
    LFP_TRACES_DF[granger_1_2_col] = row_index.map(index_to_granger_1_2)
    LFP_TRACES_DF[granger_2_1_col] = row_index.map(index_to_granger_2_1)

- Getting the timestamps of the granger

In [None]:
# Granger uses the same Multitaper window settings as power, so the timestamps come from
# the same slice. It starts RESAMPLE_RATE // 2 samples in, steps by RESAMPLE_RATE // 2, and
# stops RESAMPLE_RATE // 2 samples before the end.
# NOTE(review): check the window count against the shape checks below (possible off-by-one
# when a trace length is an exact multiple of the step).
LFP_TRACES_DF["granger_timestamps"] = LFP_TRACES_DF["lfp_timestamps"].map(
    lambda timestamps: timestamps[RESAMPLE_RATE // 2:-RESAMPLE_RATE // 2:RESAMPLE_RATE // 2]
)


- Making sure that the Granger timestamps make sense in shape and values

In [None]:
# Shapes of the first few Granger-timestamp arrays.
LFP_TRACES_DF["granger_timestamps"].head().apply(lambda x: x.shape)

In [None]:
# Shape of the first row's Granger array for the first directed pair.
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "granger_all_frequencies_all_windows" in col][0]].iloc[0].shape

In [None]:
# LFP timestamps of the first row.
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "lfp_timestamps" in col][0]].iloc[0]

In [None]:
# Granger timestamps of the first row.
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "granger_timestamps" in col][0]].iloc[0]

- Checking if the right frequencies are being used

In [None]:
# Per-pair Granger frequency grids (expected to be identical across pairs).
LFP_TRACES_DF[[col for col in LFP_TRACES_DF.columns if "granger_frequencies" in col]].head()

In [None]:
# Keep one shared frequency grid for Granger. Every pair uses the same Multitaper settings,
# so the per-pair grids should be identical; verify that before the copies are dropped.
granger_frequency_columns = [name for name in LFP_TRACES_DF.columns if "granger_frequencies" in name]
for other_column in granger_frequency_columns[1:]:
    assert all(
        np.array_equal(reference, candidate)
        for reference, candidate in zip(LFP_TRACES_DF[granger_frequency_columns[0]], LFP_TRACES_DF[other_column])
    ), f"{other_column} does not match {granger_frequency_columns[0]}"
LFP_TRACES_DF["granger_calculation_frequencies"] = LFP_TRACES_DF[granger_frequency_columns[0]].copy()

- Dropping unnecessary columns

In [None]:
# The per-pair frequency columns are redundant once "granger_calculation_frequencies" exists.
LFP_TRACES_DF = LFP_TRACES_DF.drop(
    columns=[name for name in LFP_TRACES_DF.columns if "granger_frequencies" in name],
    errors="ignore",
)

In [None]:
# Preview the table after the Granger step.
LFP_TRACES_DF.head()

In [None]:
# Full list of columns that will be saved.
LFP_TRACES_DF.columns

In [None]:
# Save the full spectral results (power, coherence, Granger) to OUTPUT_DIR. This overwrites
# the checkpoint written before the Granger step.
LFP_TRACES_DF.to_pickle(os.path.join(OUTPUT_DIR, "{}_02_full_spectral.pkl".format(OUTPUT_PREFIX)))

In [None]:
# Intentional stop so "Run All" halts here once the spectral results have been saved.
raise ValueError("Intentional stop: spectral results have been saved; execution halts here on purpose.")