# All oscillation analysis

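Combines the per-trial spectral (power, coherence, Granger) and SLEAP tracking tables of the `rce_pilot_3` alone/competition and long-competition recordings, splits every trial into competitive, non-competitive and alone-rewarded segments, and compares the mean spectra and thorax velocity of those segments across trial outcomes (win / lose / rewarded).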

In [193]:
# Imports of all used packages and libraries
import sys
import os
import git
import glob
from collections import defaultdict

In [194]:
git_repo = git.Repo(".", search_parent_directories=True)
git_root = git_repo.git.rev_parse("--show-toplevel")

In [None]:
git_root

In [196]:
sys.path.insert(0, os.path.join(git_root, 'src'))

In [197]:
import warnings
warnings.filterwarnings('ignore')

In [198]:
import os
import collections
import itertools
from collections import defaultdict
from itertools import combinations

In [199]:
# Imports of all used packages and libraries
import numpy as np
import pandas as pd
from scipy import stats
from scipy.stats import mannwhitneyu
# import seaborn as sns



In [200]:
import matplotlib
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import colorsys

In [201]:
from sklearn.metrics import confusion_matrix

In [202]:
# from spectral_connectivity import Multitaper, Connectivity
# import spectral_connectivity

In [203]:
import utilities.helper
import trodes.read_exported

In [204]:
FONTSIZE = 20

In [205]:
# Global matplotlib font settings. The size is taken from FONTSIZE (defined in the
# cell above) so the notebook has a single place to change it.
font = {'weight' : 'medium',
        'size'   : FONTSIZE}

matplotlib.rc('font', **font)

In [206]:
# Define a function to vertically stack arrays (one row per input array)
def stack_arrays(arrays):
    """
    Stack a sequence of arrays row-wise into a single 2-D array.

    Used as a groupby aggregation so that the per-row spectra of a group end up
    as the rows of one matrix.

    Parameters:
        arrays (sequence of array-like): Arrays to stack; passed to np.vstack unchanged.

    Returns:
        numpy.ndarray: The result of np.vstack(arrays).
    """
    stacked = np.vstack(arrays)
    return stacked

In [207]:
# BAND_TO_FREQ_PLOT = {'theta': (4, 12), 'gamma': (30, 50)}
BAND_TO_FREQ_PLOT = {'theta': (4, 12)}
BAND_TO_FREQ_COLOR = {'theta': "#FFAF00", 'beta': "blue", 'gamma': "green"}

## Inputs & Data

In [208]:
EPHYS_SAMPLE_RATE = 20000

In [209]:
# GOOD_SUBJECTS = ["3.1", "3.3", "3.4", "4.2", "4.3", "5.2", "5.3"]
# GOOD_SUBJECTS = ["3.1", "4.2", "4.3"]
# GOOD_SUBJECTS = ["3.1", "3.3", "3.4", "4.2", "4.3"]
GOOD_SUBJECTS = ["3.1", "3.3", "3.4", "4.2", "4.3", "5.2", "5.3"]
# GOOD_SUBJECTS = ["3.1", "3.3", "3.4", "4.2", "4.3"]


Explanation of each input and where it comes from.

In [210]:
# Inputs and required data loading
# Input variable names are in all-caps snake case.
# Whenever an input changes or is used for processing,
# the variables are written in all-lowercase snake case.
OUTPUT_DIR = r"./proc/" # where data is saved should always be shown in the inputs
os.makedirs(OUTPUT_DIR, exist_ok=True)
OUTPUT_PREFIXES = ["rce_pilot_3_alone_comp", "rce_pilot_3_long_comp"]
OUTPUT_PREFIX = "rce_pilot_3_combined"

In [211]:
# TRIAL_LABELS_DF = pd.read_excel("/blue/npadillacoreano/ryoi360/projects/reward_comp/repos/reward_comp_ext/results/2024_06_26_sleap_clustering/data/rce_pilot_3_alone_comp_per_video_trial_labels.xlsx")
# Load the per-trial spectral/SLEAP table of every prefix in OUTPUT_PREFIXES and
# concatenate them into one frame. The files are read from OUTPUT_DIR so the input
# location is controlled by the constant declared in the inputs cell.
# NOTE: pd.read_pickle can execute arbitrary code while unpickling; only load
# files produced by this project's own pipeline.
TRIALS_AND_SPECTRAL_DF = pd.concat(
    [
        pd.read_pickle(os.path.join(OUTPUT_DIR, "{}_10_per_trial_spectral_bans_sleap.pkl".format(prefix)))
        for prefix in OUTPUT_PREFIXES
    ]
)

In [212]:
FULL_LFP_TRACES_PKL = "{}_12_per_cluster_spectral_bans_sleap.pkl".format(OUTPUT_PREFIX)

## Outputs

Describe each output that the notebook creates. 

- Is it a plot or is it data?

- How valuable is the output and why is it valuable or useful?

## Functions 

In [None]:
def combine_dicts(dicts):
    """
    Merge list-valued dictionaries key by key.

    Every dictionary in `dicts` maps keys to lists. For each key, the lists found
    in the individual dictionaries are concatenated in the order the dictionaries
    appear, so the result holds one combined list per key.

    Parameters:
        dicts (list of dict): Dictionaries whose values are lists.

    Returns:
        dict: A plain dictionary containing every key seen in `dicts`, each mapped
        to the concatenation of the lists stored under that key.
    """
    merged = {}
    for mapping in dicts:
        for name in mapping:
            # setdefault creates the target list the first time a key is seen
            merged.setdefault(name, []).extend(mapping[name])

    return merged


# Example usage
list_of_dicts = [
    {'a': [1, 2], 'b': [3, 4]},
    {'a': [5], 'b': [6, 7]},
    {'a': [8, 9], 'c': [10]}
]

combined_dict = combine_dicts(list_of_dicts)
print(combined_dict)

In [None]:
def find_consecutive_ranges(numbers, min_length=1):
    """
    Locate runs of identical consecutive values and report them per value.

    The sequence is scanned once from left to right. Whenever the value changes
    (or the sequence ends), the run that just finished is recorded if it contains
    at least `min_length` elements.

    Parameters:
        numbers (list): Sequence of values to scan for runs.
        min_length (int): Minimum number of elements a run must have to be reported.

    Returns:
        dict: Maps each value to a list of (start, end) index tuples, both ends
              inclusive, one tuple per qualifying run of that value. Values without
              a qualifying run are absent. An empty input yields an empty dict.
    """
    found = {}
    total = len(numbers)
    if total == 0:
        return found

    def _record(value, first, last):
        # Store the run [first, last] when it is long enough
        if (last - first + 1) >= min_length:
            found.setdefault(value, []).append((first, last))

    run_value = numbers[0]
    run_start = 0
    for position in range(1, total):
        item = numbers[position]
        if item != run_value:
            _record(run_value, run_start, position - 1)
            run_value = item
            run_start = position

    # The final run is never closed by a value change, so close it here
    _record(run_value, run_start, total - 1)

    return found

# Example usage:
numbers = [1, 1, 2, 2, 2, 3, 3, 3, 3, 2, 2, 1, 1]
print(find_consecutive_ranges(numbers, min_length=3))

In [None]:
def update_tuples_in_dict(original_dict, reference_list):
    """
    Translate index tuples stored in a dictionary into the values they point at.

    Each value of `original_dict` is a list of tuples of indices. Every index is
    replaced by `reference_list[index]`; keys, list order and tuple order are kept.

    Parameters:
        original_dict (dict): Maps keys to lists of tuples of indices into `reference_list`.
        reference_list (list): Sequence the indices refer to.

    Returns:
        dict: New dictionary with the same keys, whose tuples hold the looked-up
              elements of `reference_list` instead of the indices.
    """
    def _lookup(index_tuple):
        # Resolve every index of one tuple against the reference sequence
        return tuple(reference_list[position] for position in index_tuple)

    return {
        label: [_lookup(index_tuple) for index_tuple in index_tuples]
        for label, index_tuples in original_dict.items()
    }

# Example usage:
original_dict = {
    'a': [(0, 1), (2, 3)],
    'b': [(1, 3), (0, 2)]
}
reference_list = ['alpha', 'beta', 'gamma', 'delta']

updated_dict = update_tuples_in_dict(original_dict, reference_list)
print(updated_dict)

In [216]:
# def find_indices_within_ranges(ranges_dict, values):
#     """
#     Creates a dictionary mapping keys to sorted indices of values that fall within specified ranges.
    
#     Parameters:
#         ranges_dict (dict): A dictionary with keys and values as lists of tuples representing ranges.
#         values (list): A list of values to check against the ranges.
        
#     Returns:
#         dict: A dictionary where each key maps to a sorted list of indices for values within the ranges.
#     """
#     result_dict = {}
#     for key, ranges in ranges_dict.items():
#         matched_indices = []
#         for index, value in enumerate(values):
#             if any(start <= value <= end for start, end in ranges):
#                 matched_indices.append(index)
#         result_dict[key] = sorted(matched_indices)
#     return list(result_dict.items())

# # Example usage:
# ranges_dict = {
#     'range1': [(1, 5), (10, 15)],
#     'range2': [(0, 2), (4, 8)]
# }
# values = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

# result = find_indices_within_ranges(ranges_dict, values)
# print(result)


In [None]:
def update_tuples_in_list(original_list, reference_list):
    """
    Translate a list of index tuples into tuples of the referenced values.

    Every tuple in `original_list` contains indices into `reference_list`. Each
    index is replaced by the element found at that position; the order of the
    tuples and of the entries inside each tuple is preserved.

    Parameters:
        original_list (list): List of tuples of indices into `reference_list`.
        reference_list (list): Sequence the indices refer to.

    Returns:
        list: New list of tuples holding the looked-up elements of `reference_list`.
    """
    translated = []
    for index_tuple in original_list:
        resolved = [reference_list[position] for position in index_tuple]
        translated.append(tuple(resolved))

    return translated

# Example usage:
original_list = [
    (0, 1), (2, 3),
    (1, 3), (0, 2)
]
reference_list = ['alpha', 'beta', 'gamma', 'delta']

updated_list = update_tuples_in_list(original_list, reference_list)
print(updated_list)

In [None]:
def find_indices_within_ranges(ranges_list, values):
    """
    Return the positions of the values that lie inside at least one range.

    A value matches a range (start, end) when start <= value <= end, i.e. both
    ends are inclusive.

    Parameters:
        ranges_list (list): List of (start, end) tuples.
        values (list): Values to test against the ranges.

    Returns:
        list: Sorted positions (indices into `values`) of the matching values.
    """
    def _in_any_range(candidate):
        # True as soon as one range contains the candidate
        for lower, upper in ranges_list:
            if lower <= candidate <= upper:
                return True
        return False

    hits = [position for position, candidate in enumerate(values) if _in_any_range(candidate)]

    return sorted(hits)

# Example usage:
ranges_list = [(1, 5), (10, 15), (0, 2), (4, 8)]
values = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

result = find_indices_within_ranges(ranges_list, values)
print(result)

In [219]:
OUTCOME_TO_COLOR = {"both_rewarded": "green", 
"novel_lose": "#e67073",
"novel_win": "#93a5da",
"lose": "#951a1d",
"alone_rewarded": "#0499af",
"win": "#3853a3",
"omission": "orange",
"tie": "green"}

In [220]:
comp_id_to_color = {'competitive_1': "#281640",
'competitive_2': "#43246a",
'competitive_3': "#8e7ca6",
'no_comp_4': "#2f3600",
'no_comp_5': "#535f00",
'no_comp_6': "#768800",
'no_comp_7': "#9fac4d",
'no_comp_8': "#c8cf99",
'competitive': "#43246A",
'no_comp': "#768800",
'win': "#0045A6",
'win_competitive': "#003074",
'win_no_comp': "#4d7dc1",
'lose': "#792910",
'lose_competitive': "#551d0b",
'lose_no_comp': "#a16958",
'rewarded': "#FFAF00"
}

In [221]:
to_keep_columns = ['trial_label',
'tone_start_frame',
'reward_start',
'reward_dispensed',
'tone_stop_frame',
'condition',
'competition_closeness',
'get_reward_frame',
'out_reward_frame',
'notes',
'box_1_port_entry_frames',
'box_2_port_entry_frames',
'video_name',
'tone_start_timestamp',
'tone_stop_timestamp',
'box_1_port_entry_timestamps',
'box_2_port_entry_timestamps',
'current_subject',
'session_dir',
'experiment',
'sleap_name',
'video_id',
'agent',
'all_subjects',
'cohort',
'first_timestamp',
'last_timestamp',
'recording',
'session_path',
'subject',
'baseline_start_timestamp',
'post_trial_end_timestamp',]

## Processing

Describe what is done to the data here and how inputs are manipulated to generate outputs. 

In [222]:
# As much code and as many cells as required
# includes EDA and playing with data
# GO HAM!

# Ideally functions are defined here first and then data is processed using the functions

# function names are short and in snake case all lowercase
# a function name should be unique but does not have to describe the function
# doc strings describe functions not function names




## Renaming the trial labels

In [None]:
TRIALS_AND_SPECTRAL_DF.head()

In [224]:
comp_closeness_dict = {'Subj 1 blocking Subj 2': "competitive",
'Subj 2 Only': "no_comp",
'Subj 2 blocking Subj 1': "competitive",
'Subj 1 then Subj 2': "competitive", 
'Subj 1 Only': "no_comp",
'Subj 2 then Subj 1': "competitive",
'Close Call': "competitive",
'After trial': "no_comp"}

In [225]:
# cluster_to_competitiveness = {"0": "no_comp", "1": "competitive", "2": "competitive", "3": "no_comp", "4": "competitive", "5": "no_comp", "6": "no_comp", "7": "no_comp"}
# cluster_to_comp_id = {"0": "no_comp_8", "1": "competitive_3", "2": "competitive_1", "3": "no_comp_6", "4": "competitive_2", "5": "no_comp_7", "6": "no_comp_5", "7": "no_comp_4"}
# comp_id = {"no_comp_8", "competitive_3", "competitive_1", "no_comp_6", "competitive_2", "no_comp_7", "no_comp_5", "no_comp_4"}


Win base color
#0045A6 
Win competitive color
#003074
Win no comp color
#4d7dc1

Lose base color
#792910
Lose competitive color
#551d0b
Lose no comp color
#a16958

In [226]:
TRIALS_AND_SPECTRAL_DF["current_subject"] = TRIALS_AND_SPECTRAL_DF["current_subject"].apply(lambda x: str(x).strip().lower())

In [227]:
TRIALS_AND_SPECTRAL_DF = TRIALS_AND_SPECTRAL_DF[TRIALS_AND_SPECTRAL_DF["current_subject"].isin(GOOD_SUBJECTS)]

In [None]:
TRIALS_AND_SPECTRAL_DF["current_subject"].unique()

In [229]:
# Drop tied trials. The comparison is made on a stripped, lower-cased view of
# "condition" because the column itself is only normalized in a later cell;
# comparing the raw values would let entries such as "Tie" or " tie" slip through.
_condition_normalized = TRIALS_AND_SPECTRAL_DF["condition"].astype(str).str.strip().str.lower()
TRIALS_AND_SPECTRAL_DF = TRIALS_AND_SPECTRAL_DF[_condition_normalized != "tie"].reset_index(drop=True)

In [230]:
# Drop placeholder ("temp") trials. As the "condition" column is only stripped and
# lower-cased in a later cell, the comparison is made on a normalized view so that
# variants such as "Temp" or "temp " are removed here as well.
_condition_normalized = TRIALS_AND_SPECTRAL_DF["condition"].astype(str).str.strip().str.lower()
TRIALS_AND_SPECTRAL_DF = TRIALS_AND_SPECTRAL_DF[_condition_normalized != "temp"].reset_index(drop=True)

In [None]:
TRIALS_AND_SPECTRAL_DF["condition"].unique()

In [232]:
TRIALS_AND_SPECTRAL_DF["condition"] = TRIALS_AND_SPECTRAL_DF["condition"].apply(lambda x: str(x).strip().lower())

In [233]:
# Label every row from the current subject's point of view: "win" when the condition
# equals current_subject, "lose" when it equals agent, otherwise the condition itself
# is kept as the label.
# NOTE(review): "condition" and "current_subject" are stripped/lower-cased in earlier
# cells, but "agent" is compared as stored — confirm agent IDs are already clean
# strings, otherwise the "lose" branch never matches.
TRIALS_AND_SPECTRAL_DF["trial_label"] = TRIALS_AND_SPECTRAL_DF.apply(lambda x: "win" if x["current_subject"] == x["condition"]  else ("lose" if x["agent"] == x["condition"] else x["condition"]), axis=1)
                                                                        

In [None]:
TRIALS_AND_SPECTRAL_DF["trial_label"].unique()

In [235]:
TRIALS_AND_SPECTRAL_DF["competition_closeness"] = TRIALS_AND_SPECTRAL_DF["competition_closeness"].map(comp_closeness_dict).fillna("rewarded")

In [None]:
TRIALS_AND_SPECTRAL_DF["trial_label"].unique()

In [None]:
TRIALS_AND_SPECTRAL_DF["competition_closeness"].unique()

## Making separate rows for each cluster

In [None]:
TRIALS_AND_SPECTRAL_DF["competitiveness_frame_ranges_dict"]

In [None]:
TRIALS_AND_SPECTRAL_DF[TRIALS_AND_SPECTRAL_DF["trial_label"] != "rewarded"]["competitiveness_frame_ranges_dict"].iloc[0]

In [None]:
TRIALS_AND_SPECTRAL_DF["tone_start_frame"].head()

In [None]:
TRIALS_AND_SPECTRAL_DF["tone_stop_frame"]

- Making each dictionary into a list so that we can explode it

In [242]:
dict_col = [col for col in TRIALS_AND_SPECTRAL_DF.columns if "dict" in col and "competitiveness" in col]


In [None]:
dict_col

In [None]:
TRIALS_AND_SPECTRAL_DF[TRIALS_AND_SPECTRAL_DF["trial_label"] != "rewarded"]["competitiveness_index_ranges_dict"].iloc[0]

In [245]:
for col in dict_col:
    TRIALS_AND_SPECTRAL_DF[col] = TRIALS_AND_SPECTRAL_DF.apply(lambda x: list(x[col].items()) if isinstance(x[col], dict) else [("rewarded", [(x["tone_start_frame"], x["tone_stop_frame"])])], axis=1)

In [None]:
TRIALS_AND_SPECTRAL_DF["competitiveness_frame_ranges_dict"].head()

In [None]:
TRIALS_AND_SPECTRAL_DF[TRIALS_AND_SPECTRAL_DF["trial_label"] != "rewarded"]["competitiveness_frame_ranges_dict"].head()

- Exploding each row so that competitive and non-competitive clusters are on different rows

In [248]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF = TRIALS_AND_SPECTRAL_DF.explode(column=dict_col)

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_frame_ranges_dict"].head()

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF[cluster_exploded_TRIALS_AND_SPECTRAL_DF["trial_label"] != "rewarded"]["competitiveness_frame_ranges_dict"].head()

- Making a new column based on the key which is the competitiveness of the group

In [251]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_label"] = cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_frame_ranges_dict"].apply(lambda x: str(x[0]))

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_label"].iloc[0]

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_label"].head()

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF[cluster_exploded_TRIALS_AND_SPECTRAL_DF["trial_label"] != "rewarded"]["competitiveness_label"].head()

- Filtering rows that don't have matching frame and trial competitiveness

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_label"].unique()

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["competition_closeness"].unique()

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF.shape

In [258]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF = cluster_exploded_TRIALS_AND_SPECTRAL_DF[cluster_exploded_TRIALS_AND_SPECTRAL_DF["competition_closeness"] == cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_label"]].reset_index(drop=True)

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF.shape

- Combining the competitiveness and the trial outcome labels

In [260]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["trial_and_competitiveness_label"] = cluster_exploded_TRIALS_AND_SPECTRAL_DF.apply(lambda x: "_".join([x["trial_label"], x["competitiveness_label"]]) if  x["trial_label"] != x["competitiveness_label"] else x["trial_label"], axis=1)

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["trial_and_competitiveness_label"].unique()

- Making a new column that is the range of frames

In [262]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_frame_ranges"] = cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_frame_ranges_dict"].apply(lambda x: x[1])

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_frame_ranges"].head()

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF[cluster_exploded_TRIALS_AND_SPECTRAL_DF["trial_label"] != "rewarded"]["competitiveness_frame_ranges"].head()

- Filtering for ranges that are only during the trial

In [None]:
def adjust_frame_range(frame_range, event_range):
    """
    Clip a frame range to the part that lies inside an event range.

    Parameters:
        frame_range (tuple): (start, end) of the frame range.
        event_range (tuple): (start, end) of the event range.

    Returns:
        tuple or None: (max of the two starts, min of the two ends) when the frame
        range is not entirely before or entirely after the event range, otherwise None.
    """
    (lo, hi), (event_lo, event_hi) = frame_range, event_range

    # The frame range ends before the event starts, or starts after it ends
    disjoint = hi < event_lo or lo > event_hi
    if disjoint:
        return None

    return (max(lo, event_lo), min(hi, event_hi))

# Example Usage:
frame_range = (10, 20)
event_range = (15, 25)
result = adjust_frame_range(frame_range, event_range)
print(result)  # Output will be (15, 20)


In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["tone_start_frame"]

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["tone_stop_frame"]

In [268]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_frame_ranges"] = cluster_exploded_TRIALS_AND_SPECTRAL_DF.apply(lambda x: [adjust_frame_range((start, stop), (x["tone_start_frame"], x["tone_stop_frame"])) for start, stop in x["competitiveness_frame_ranges"] if adjust_frame_range((start, stop), (x["tone_start_frame"], x["tone_stop_frame"]))], axis=1)

- Filtering for rows that still have at least one frame range overlapping the trial

In [269]:
# Keep only rows that still have at least one frame range left after clipping to the trial.
# .copy() makes the result an independent frame, so the column assignments in the
# following cells do not write to a slice of the previous frame (the resulting
# SettingWithCopyWarning would be hidden by the global warnings filter).
cluster_exploded_TRIALS_AND_SPECTRAL_DF = cluster_exploded_TRIALS_AND_SPECTRAL_DF[cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_frame_ranges"].apply(lambda x: len(x) >= 1)].copy()

- Filtering for ranges that are 1 second consecutive

In [270]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_frame_ranges"] = cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_frame_ranges"].apply(lambda x: [(start, stop) for start, stop in x if stop - start >= 20])

- Calculating the total number of frames for each trial

In [271]:
# Total number of frames covered by the remaining ranges of each row, computed as the
# sum of (stop - start) over the row's ranges.
# NOTE(review): if these ranges are inclusive on both ends (as the tuples returned by
# find_consecutive_ranges are), every range is undercounted by one frame — confirm
# against the notebook that produced competitiveness_frame_ranges_dict.
cluster_exploded_TRIALS_AND_SPECTRAL_DF["frame_ranges_sum_frames"] = cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_frame_ranges"].apply(lambda x: sum([stop - start for start, stop in x]))

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["frame_ranges_sum_frames"]

- Filtering for rows where at least half the frames are there

In [273]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF = cluster_exploded_TRIALS_AND_SPECTRAL_DF[cluster_exploded_TRIALS_AND_SPECTRAL_DF["frame_ranges_sum_frames"] >= 100]

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_frame_ranges"]

- Filtering for rows that have at least one consecutive second

In [275]:
# Drop rows whose list of frame ranges became empty after the length filter.
# .copy() detaches the result from the frame it was filtered from, so the columns
# added in the following cells are written to an independent frame rather than a slice.
cluster_exploded_TRIALS_AND_SPECTRAL_DF = cluster_exploded_TRIALS_AND_SPECTRAL_DF[cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_frame_ranges"].apply(lambda x: len(x) >= 1)].copy()

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_frame_ranges"]

In [None]:
for col in cluster_exploded_TRIALS_AND_SPECTRAL_DF.columns:
    print(col)

## Adding the velocity

In [278]:
alone_trials_df = pd.read_pickle("../2024_06_26_sleap_clustering/proc/alone_trials_sleap.pkl")

In [None]:
alone_trials_df.columns

In [None]:
alone_trials_df["subject_thorax_velocity"].apply(lambda x: x.shape)

In [281]:
alone_trials_df["trial_subject_thorax_velocity"] = alone_trials_df["subject_thorax_velocity"].apply(lambda x: x[199:399])

In [282]:
alone_trials_df["trial_frame_index"] = alone_trials_df["frame_index"].apply(lambda x: x[199:399])

In [283]:
alone_trial_to_velocity = {(row['session_dir'], row['current_subject'], row['tone_start_timestamp']): row['trial_subject_thorax_velocity'] for index, row in alone_trials_df.iterrows()}

In [284]:
alone_trial_to_frame_index = {(row['session_dir'], row['current_subject'], row['tone_start_timestamp']): row['trial_frame_index'] for index, row in alone_trials_df.iterrows()}

In [285]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["trial_subject_thorax_velocity"] = cluster_exploded_TRIALS_AND_SPECTRAL_DF.apply(lambda row: alone_trial_to_velocity[(row['session_dir'], row['current_subject'], row['tone_start_timestamp'])] if (row['session_dir'], row['current_subject'], row['tone_start_timestamp']) in alone_trial_to_velocity else row["trial_subject_thorax_velocity"], axis=1)

In [286]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["trial_frame_index"] = cluster_exploded_TRIALS_AND_SPECTRAL_DF.apply(lambda row: alone_trial_to_frame_index[(row['session_dir'], row['current_subject'], row['tone_start_timestamp'])] if (row['session_dir'], row['current_subject'], row['tone_start_timestamp']) in alone_trial_to_frame_index else row["trial_frame_index"], axis=1)

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["trial_subject_thorax_velocity"].iloc[:5].apply(lambda x: x.shape)

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["trial_frame_index"].iloc[:5].apply(lambda x: x.shape)

In [289]:
# Fill in trial_video_timestamps where the cell holds no array.
# pd.isna() returns a plain bool only for a scalar cell (for an array-valued cell it
# returns an array), so a bool result is used here as the signal that the cell is a
# scalar such as NaN. In that case the timestamps are looked up from video_timestamps
# at the positions in trial_frame_index; array-valued cells are kept unchanged.
cluster_exploded_TRIALS_AND_SPECTRAL_DF["trial_video_timestamps"] = cluster_exploded_TRIALS_AND_SPECTRAL_DF.apply(lambda x: x["video_timestamps"][x["trial_frame_index"]] if isinstance(pd.isna(x["trial_video_timestamps"]), bool) else x["trial_video_timestamps"], axis=1)

- Getting the timestamps for each frame

In [290]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_timestamp_ranges"] = cluster_exploded_TRIALS_AND_SPECTRAL_DF.apply(lambda x: update_tuples_in_list(x["competitiveness_frame_ranges"], x["video_timestamps"]), axis=1)

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF[cluster_exploded_TRIALS_AND_SPECTRAL_DF["trial_label"] != "rewarded"]["competitiveness_timestamp_ranges"]

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["competitiveness_timestamp_ranges"]

- Getting all the spectral timestamps that are in the ranges

In [293]:
# timestamps_col = [col for col in cluster_exploded_TRIALS_AND_SPECTRAL_DF.columns if "timestamps" in col and "video" not in col and "trial" in col]
timestamps_col = [col for col in cluster_exploded_TRIALS_AND_SPECTRAL_DF.columns if "timestamps" in col and "trial" in col]

In [None]:
timestamps_col

- Getting all the indexes that fit within each range

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["trial_lfp_timestamps"]

In [None]:
for col in timestamps_col:
    base_col = col.replace("trial_", "").replace("timestamps", "index")
    print(base_col)
    cluster_exploded_TRIALS_AND_SPECTRAL_DF["cluster_filtered_{}".format(base_col)] = cluster_exploded_TRIALS_AND_SPECTRAL_DF.apply(lambda x: find_indices_within_ranges(x["competitiveness_timestamp_ranges"], x[col]), axis=1)

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["trial_mPFC_theta_band"].iloc[0].shape

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["cluster_filtered_{}".format(base_col)].iloc[0][:10]

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["cluster_filtered_{}".format(base_col)].iloc[0][-10:]

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF.head()

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF.tail()

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF[~cluster_exploded_TRIALS_AND_SPECTRAL_DF["video_name"].apply(lambda x: ".2" in x)]

In [111]:
# cluster_exploded_TRIALS_AND_SPECTRAL_DF = cluster_exploded_TRIALS_AND_SPECTRAL_DF[~cluster_exploded_TRIALS_AND_SPECTRAL_DF["video_name"].apply(lambda x: ".2" in x)]
# cluster_exploded_TRIALS_AND_SPECTRAL_DF = 
cluster_exploded_TRIALS_AND_SPECTRAL_DF = cluster_exploded_TRIALS_AND_SPECTRAL_DF[~cluster_exploded_TRIALS_AND_SPECTRAL_DF.apply(lambda x: (".2" in x["video_name"]) and (x["condition"] != "rewarded"), axis=1)]

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF.groupby(["video_name", "competitiveness_label", "current_subject"]).count().tail(n=25)

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["trial_and_competitiveness_label"].shape

In [114]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF = cluster_exploded_TRIALS_AND_SPECTRAL_DF[~cluster_exploded_TRIALS_AND_SPECTRAL_DF["trial_and_competitiveness_label"].apply(lambda x: "temp" in x)]

- Filtering for rows where the subject is correct for the rewarded trials

In [115]:
START_STOP_FRAME_DF = pd.read_excel("../2024_06_26_sleap_clustering/data/rce_pilot_3_alone_comp_alone_trials_start_stop_video_frame.xlsx")

In [116]:
START_STOP_FRAME_DF["video_name"] = START_STOP_FRAME_DF["file_path"].apply(lambda x: ".".join(os.path.basename(x).split(".")[:2]))


In [117]:
START_STOP_FRAME_DF["current_subject"] = START_STOP_FRAME_DF["tracked_subject"].astype(str)

In [118]:
video_to_current_subject = cluster_exploded_TRIALS_AND_SPECTRAL_DF.drop_duplicates(["video_name", "current_subject"])[["video_name", "current_subject"]].copy()

In [119]:
alone_video_to_subject = pd.merge(video_to_current_subject, START_STOP_FRAME_DF, on=["video_name", "current_subject"])[["video_name", "current_subject"]]

In [None]:
alone_video_to_subject

In [None]:
video_to_current_subject[video_to_current_subject["video_name"].str.contains("long")]

In [122]:
video_to_current_subject = pd.concat([alone_video_to_subject, video_to_current_subject[video_to_current_subject["video_name"].str.contains("long")]])

In [None]:
video_to_current_subject

In [124]:
list_of_tuples = list(zip(video_to_current_subject['video_name'], video_to_current_subject['current_subject']))

In [125]:
# Keep every row whose condition is not "rewarded"; keep "rewarded" rows only when
# their (video_name, current_subject) pair is one of the pairs collected in
# list_of_tuples. The pairs are put in a set so the per-row membership test is a
# hash lookup instead of a scan of the whole list.
_valid_video_subject_pairs = set(list_of_tuples)
cluster_exploded_TRIALS_AND_SPECTRAL_DF = cluster_exploded_TRIALS_AND_SPECTRAL_DF[(cluster_exploded_TRIALS_AND_SPECTRAL_DF["condition"] != "rewarded") | (cluster_exploded_TRIALS_AND_SPECTRAL_DF.apply(lambda x: (x['video_name'], x['current_subject']) in _valid_video_subject_pairs, axis=1))]

# Filtering for power

In [126]:
# Keep only rows with at least one spectral window inside the cluster's time ranges.
# .copy() detaches the result from the frame it was filtered from, so the many
# cluster_all_* / cluster_mean_* columns added below are written to an independent
# frame rather than to a slice.
cluster_exploded_TRIALS_AND_SPECTRAL_DF = cluster_exploded_TRIALS_AND_SPECTRAL_DF[cluster_exploded_TRIALS_AND_SPECTRAL_DF["cluster_filtered_power_index"].apply(lambda x: len(x)) >= 1].copy()

In [None]:
list(cluster_exploded_TRIALS_AND_SPECTRAL_DF.columns)

In [128]:
power_columns = [col for col in cluster_exploded_TRIALS_AND_SPECTRAL_DF if "trial" in col and ("power" in col or "granger" in col or "coherence" in col or "trial_subject_thorax_velocity" in col) and "agent" not in col]

In [None]:
power_columns

- Getting the timestamps of all the clusters

In [None]:
# Restrict every per-trial array to the samples that fall inside the cluster's time ranges.
# Spectral columns are selected with cluster_filtered_power_index (positions in the
# power timestamps). The thorax velocity is sliced per video frame together with
# trial_frame_index earlier in the notebook, so it is selected with
# cluster_filtered_video_index (positions in trial_video_timestamps) instead; indexing
# it with the power-window positions would pick unrelated frames. If the video index
# column is not present, the power index is used as before.
for col in power_columns:
    print(col)
    use_video_index = ("velocity" in col) and ("cluster_filtered_video_index" in cluster_exploded_TRIALS_AND_SPECTRAL_DF.columns)
    index_col = "cluster_filtered_video_index" if use_video_index else "cluster_filtered_power_index"
    cluster_exploded_TRIALS_AND_SPECTRAL_DF["cluster_all_{}".format(col)] = cluster_exploded_TRIALS_AND_SPECTRAL_DF.apply(lambda x, index_col=index_col: x[col][x[index_col]], axis=1)

In [131]:
cluster_all_columns = [col for col in cluster_exploded_TRIALS_AND_SPECTRAL_DF if "cluster_all" in col]

In [None]:
cluster_all_columns

- Aggregating all the values within a given trial

In [None]:
for col in cluster_all_columns:
    updated_column = col.replace("cluster_all", "cluster_mean")
    if "gamma" in col or "theta" in col:
        cluster_exploded_TRIALS_AND_SPECTRAL_DF[updated_column] = cluster_exploded_TRIALS_AND_SPECTRAL_DF.apply(lambda x: np.nanmean(x[col]), axis=1)
    else:
        cluster_exploded_TRIALS_AND_SPECTRAL_DF[updated_column] = cluster_exploded_TRIALS_AND_SPECTRAL_DF.apply(lambda x: np.nanmean(x[col], axis=0), axis=1)
    print(updated_column)
    print(cluster_exploded_TRIALS_AND_SPECTRAL_DF[updated_column].iloc[0])
        

In [134]:
cluster_mean_columns = [col for col in cluster_exploded_TRIALS_AND_SPECTRAL_DF if "cluster_mean" in col and "all_frequencies_all_windows" in col]

In [None]:
cluster_mean_columns

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF

In [None]:
low_freq = 0
high_freq = 51
current_frequencies = range(low_freq, high_freq)

# Plot the mean spectrum (+/- SEM across rows) of every trial/competitiveness label,
# one figure per spectral measure, and save each figure under
# ./proc/cluster_spectra_plots/<measure type>/.
for col in cluster_mean_columns:
    if "velocity" in col:
        continue
    if "all_frequencies" not in col:
        continue
    print(col)

    fig, ax = plt.subplots()
    plt.xlim(low_freq, high_freq-1)
    plt.xticks(np.arange(low_freq, high_freq, 5))

    # One row per label, holding a (n_rows, n_frequencies) matrix of the stacked spectra
    stacked_df = cluster_exploded_TRIALS_AND_SPECTRAL_DF.groupby("trial_and_competitiveness_label").agg({col: stack_arrays})
    stacked_df = stacked_df.reset_index()
    stacked_df = stacked_df[~stacked_df["trial_and_competitiveness_label"].str.contains("tie")]

    stacked_df["color"] = stacked_df["trial_and_competitiveness_label"].map(comp_id_to_color)

    # Fixed y-axis limits per measure type so figures of the same type are comparable
    if "power" in col:
        plt.ylim((10**-3.2,10**-1.2))
        ax.set_yscale('log')
    elif "coherence" in col:
        plt.ylim(0.3, 0.85)
        plt.yticks(np.arange(0.3, 0.85, 0.1))
    elif "granger" in col:
        plt.ylim(0.2, 0.8)
        plt.yticks(np.arange(0.2, 0.85, 0.1))
    else:
        # Unknown measure type: keep matplotlib's automatic limits
        plt.ylim()

    # Iterating through each trial type
    for index, row in stacked_df.iterrows():
        print(row["trial_and_competitiveness_label"])
        print(row[col].shape)
        #### DURING TRIAL ####
        mean_power = np.nanmean(row[col], axis=0)[low_freq: high_freq]

        sem_power = np.nanstd(row[col], axis=0) / np.sqrt(row[col].shape[0])
        sem_power = sem_power[low_freq: high_freq]
        print(sem_power)

        plt.fill_between(
            current_frequencies,
            mean_power - sem_power,
            mean_power + sem_power,
            alpha=0.3,
            color=row["color"],
        )

        plt.plot(
            current_frequencies,
            mean_power,
            label="{}".format(row["trial_and_competitiveness_label"]),
            linewidth=3,
            color=row["color"],
        )

    # Build the title from the column name. The longer "..._trial_and_post_" prefix is
    # removed before the shorter "..._trial_" prefix; in the opposite order the shorter
    # replacement would consume part of the longer prefix and leave "and_post_" behind.
    measure_name = (
        col.replace("cluster_mean_trial_and_post_", "")
        .replace("cluster_mean_trial_", "")
        .replace("RMS_filtered_", "")
        .replace("all_frequencies_all_windows", "")
        .replace("power", "")
        .replace("granger", "")
        .replace("coherence", "")
        .replace("_", " ")
        .replace("  ", " ")
        .strip()
    )
    if "granger" in col:
        # Granger columns name two regions; show them as "<source> to <target>"
        title = measure_name.replace(" ", " to ").strip()
    else:
        title = measure_name

    plt.title(title)
    plt.xlabel("Frequency (Hz)", fontsize=25)

    if "power" in col:
        # plt.ylabel("Normalized Power (a.u.)")
        output_dir = "./proc/cluster_spectra_plots/power"
    elif "coherence" in col:
        # plt.ylabel("Coherence")
        output_dir = "./proc/cluster_spectra_plots/coherence"
    elif "granger" in col:
        # plt.ylabel("Granger's Causality")
        output_dir = "./proc/cluster_spectra_plots/granger"
    else:
        # Without this branch output_dir would be undefined on the first iteration
        # (NameError) or silently reuse the directory of the previous column.
        output_dir = "./proc/cluster_spectra_plots/other"

    os.makedirs(output_dir, exist_ok=True)

    # plt.legend(fontsize=10)

    # Hide top and right borders
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Leave bottom and left spines
    ax.spines['bottom'].set_visible(True)
    ax.spines['left'].set_visible(True)
    ax.spines['bottom'].set_linewidth(4)
    ax.spines['left'].set_linewidth(4)

    ax.tick_params(length=8, width=4)
    plt.yticks(fontsize=25)
    plt.xticks(fontsize=25)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "cluster_{}.png".format(col.replace("cluster_mean_trial_and_post_", "").replace("RMS_filtered_", "").replace("all_frequencies_all_windows", "").strip("_").strip())), transparent=True)
    plt.show()
    # Release the figure; otherwise every figure created in this loop stays in memory
    plt.close(fig)


In [138]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF.to_pickle("./proc/cluster_exploded_TRIALS_AND_SPECTRAL_DF.pkl")

In [139]:
export_cluster_exploded_TRIALS_AND_SPECTRAL_DF = cluster_exploded_TRIALS_AND_SPECTRAL_DF.copy()

In [None]:
plt.hist(export_cluster_exploded_TRIALS_AND_SPECTRAL_DF["cluster_mean_trial_subject_thorax_velocity"])

In [141]:
# Z-score of the per-row mean thorax velocity, used in the next cell to drop
# high-velocity outliers. nan_policy="omit" computes the mean/std from the non-NaN
# rows only; with the default ("propagate") a single NaN velocity turns every
# z-score into NaN, and the following `<= 2` filter would then remove all rows.
# Rows whose velocity is NaN still get a NaN z-score and are dropped by that filter.
export_cluster_exploded_TRIALS_AND_SPECTRAL_DF["velocity_zscore"] = stats.zscore(export_cluster_exploded_TRIALS_AND_SPECTRAL_DF["cluster_mean_trial_subject_thorax_velocity"], nan_policy="omit")

In [142]:
export_cluster_exploded_TRIALS_AND_SPECTRAL_DF = export_cluster_exploded_TRIALS_AND_SPECTRAL_DF[export_cluster_exploded_TRIALS_AND_SPECTRAL_DF["velocity_zscore"] <= 2]

In [None]:
plt.hist(export_cluster_exploded_TRIALS_AND_SPECTRAL_DF["cluster_mean_trial_subject_thorax_velocity"])

In [None]:
export_cluster_exploded_TRIALS_AND_SPECTRAL_DF

In [None]:
for col in export_cluster_exploded_TRIALS_AND_SPECTRAL_DF.columns:
    print(col)

In [146]:
# Move the original row index into a regular column named "index_num".
# rename() must be given columns=...; a bare mapping is applied to the row labels,
# which leaves the "index" column unrenamed and makes "index_num" silently missing
# from the column selection and CSV export below.
export_cluster_exploded_TRIALS_AND_SPECTRAL_DF = export_cluster_exploded_TRIALS_AND_SPECTRAL_DF.reset_index().rename(columns={"index": "index_num"})

In [147]:
# Keep only the scalar per-cluster means (no full spectra) plus the identifying columns.
# .copy() makes the selection an independent frame, so the "bar_ticks" column added in
# the plotting section is not written to a slice of the wider frame.
export_cluster_exploded_TRIALS_AND_SPECTRAL_DF = export_cluster_exploded_TRIALS_AND_SPECTRAL_DF[[col for col in export_cluster_exploded_TRIALS_AND_SPECTRAL_DF if ("cluster_mean_trial" in col and "all_frequencies" not in col) or (col in ["trial_and_competitiveness_label", "current_subject", "index_num"])]].copy()

In [148]:
export_cluster_exploded_TRIALS_AND_SPECTRAL_DF.to_csv("./proc/export_cluster_exploded_TRIALS_AND_SPECTRAL_DF.csv")

In [None]:
export_cluster_exploded_TRIALS_AND_SPECTRAL_DF[["trial_and_competitiveness_label", "cluster_mean_trial_subject_thorax_velocity"]].groupby(["trial_and_competitiveness_label"]).mean()

In [150]:
import seaborn as sns


## Plotting

In [None]:
export_cluster_exploded_TRIALS_AND_SPECTRAL_DF["trial_and_competitiveness_label"].unique()

In [152]:
# Two-line x-tick text for each trial/competitiveness label.
label_to_ticks = {
    "rewarded": "alone\nrewarded",
    "win_no_comp": "win\nnon-comp",
    "lose_no_comp": "lose\nnon-comp",
    "win_competitive": "win\ncomp",
    "lose_competitive": "lose\ncomp",
}

In [153]:
export_cluster_exploded_TRIALS_AND_SPECTRAL_DF["bar_ticks"] = export_cluster_exploded_TRIALS_AND_SPECTRAL_DF["trial_and_competitiveness_label"].map(label_to_ticks)

In [None]:
col

In [None]:
# Draw one violin + box figure per cluster-mean metric column of the export frame,
# grouped on the x-axis by the "bar_ticks" category labels. Figures are only shown;
# the savefig call near the end is commented out.
for col in [col for col in export_cluster_exploded_TRIALS_AND_SPECTRAL_DF if ("cluster_mean_trial" in col and "all_frequencies" not in col)]:
    print(col)
    # Global styling is re-applied on every iteration; it does not depend on `col`.
    sns.set_style('white', {'axes.linewidth': 0.5})
    plt.rcParams['xtick.major.size'] = 20
    plt.rcParams['xtick.major.width'] = 4
    plt.rcParams['xtick.bottom'] = True
    plt.rcParams['ytick.left'] = True

    # New figure for this metric (violin layer first, box layer on top)
    fig, ax = plt.subplots(figsize=(6.4, 4.8))

    ax = sns.violinplot(data=export_cluster_exploded_TRIALS_AND_SPECTRAL_DF.sort_values(["trial_and_competitiveness_label"]), x='bar_ticks', y=col, hue='trial_and_competitiveness_label',
        palette=comp_id_to_color,
                        inner=None, linewidth=0, saturation=1)
    
    sns.boxplot(x='bar_ticks', y=col, data=export_cluster_exploded_TRIALS_AND_SPECTRAL_DF.sort_values(["trial_and_competitiveness_label"]), color='white',            # hue='competitiveness_grouping',    palette=comp_to_color, 
                width=0.2,
                # boxprops={'zorder': 2}, 
                ax=ax, fill=True, linecolor="black", linewidth=2)

    # Replace the hue legend with an empty, frameless one
    plt.legend([],[], frameon=False)

        # Hide top and right borders
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Leave bottom and left spines
    ax.spines['bottom'].set_visible(True)
    ax.spines['left'].set_visible(True)
    ax.spines['bottom'].set_linewidth(3)
    ax.spines['left'].set_linewidth(3)

    ax.tick_params(axis='both', which='major', labelsize=23, length=10, width=3, color='black')
    ax.tick_params(axis='both', which='minor', labelsize=23, length=10, width=3, color='black')



    # Build the title by stripping the bookkeeping parts of the column name.
    # NOTE(review): the "cluster_mean_trial_and_post_" replace can never match, because
    # the preceding replace has already removed "cluster_mean_trial_" from the string.
    if "granger" in col:
        # For Granger columns the remaining words are joined with " to ".
        title = "{}".format(col.replace("cluster_mean_trial_", "").replace("cluster_mean_trial_and_post_", "").replace("RMS_filtered_", "").replace("all_frequencies_all_windows", "").replace("power", "").replace("granger", "")).replace("coherence", "").replace("_", " ").replace("  ", " ").strip().replace(" ", " to ").strip()
    else:
        title = "{}".format(col.replace("cluster_mean_trial_", "").replace("cluster_mean_trial_and_post_", "").replace("RMS_filtered_", "").replace("all_frequencies_all_windows", "").replace("power", "").replace("granger", "")).replace("coherence", "").replace("_", " ").replace("  ", " ").strip()

    # Fixed y-ranges: [0, 1] for coherence/Granger, otherwise a band-specific range.
    if "coherence" in col or "granger" in col:
        plt.ylim(0, 1)

    elif "theta" in col:
        plt.ylim(0, 0.04)
    elif "gamma" in col:
        plt.ylim(0, 0.004)

    # y-axis label by metric family; columns matching none keep seaborn's default label.
    if "coherence" in col:
        plt.ylabel('Coherence', size=30)

    elif "granger" in col:
        plt.ylabel('Granger Causality', size=30)

    elif "power" in col: 
        plt.ylabel('Normalized Power (a.u.)', size=30)
    
    plt.tick_params(
        axis='x',          # changes apply to the x-axis
        which='both',      # both major and minor ticks are affected
        bottom=False,      # ticks along the bottom edge are off
        top=False,         # ticks along the top edge are off
        labelbottom=False) # labels along the bottom edge are off

    # The x tick labels were switched off just above; the rotation is set regardless.
    plt.xticks(rotation=90)
    plt.title(title.replace("theta", "Theta").replace("gamma", "Gamma"))
    plt.tight_layout()
    # plt.savefig("./proc/spectra_bar/{}_bar.png".format(col.replace("cluster_mean_trial_", "")))
    plt.show()
    






- Plotting with the predicted values

In [156]:
# Load every per-metric prediction table and give its generic value columns
# metric-specific names, so the tables can later be merged side by side.
all_predicted_df = []
# sorted(): glob gives no ordering guarantee, and the first table becomes the merge
# base, so a fixed order keeps the joined frame's columns reproducible across runs.
for file_path in sorted(glob.glob("../../output/predicted/*")):
    current_df = pd.read_csv(file_path, index_col=0)
    # The metric name is read from the first row of "col_name"
    # (assumes one metric per file -- TODO confirm against the exporting script).
    metric_name = current_df["col_name"].iloc[0]
    current_df = current_df.rename(columns={
        "predicted_value": "predicted_" + metric_name,
        "original_value": "original_" + metric_name,
    })
    current_df = current_df.drop(columns=["col_name"], errors="ignore")
    all_predicted_df.append(current_df)


In [None]:
current_df

In [158]:
# Combine the per-metric prediction tables into one wide frame keyed on "index_num".
# pd.merge defaults to an inner join, so only rows present in every table survive.
joined_df = all_predicted_df[0]
for current_df in all_predicted_df[1:]:
    joined_df = pd.merge(joined_df, current_df, on="index_num", suffixes=("", "_drop"))
    # Columns present in both tables come back from the right-hand table with the
    # "_drop" suffix. Match the suffix only, so a column that merely contains "_drop"
    # somewhere in its name is not discarded by accident.
    joined_df = joined_df.drop(columns=[col for col in joined_df if col.endswith("_drop")])

In [None]:
joined_df

- Plotting the emmeans

In [160]:
# Gather the emmeans tables whose file names contain "with_CI" and stack them.
all_emmeans_df = []
for file_path in glob.glob("../../output/emmeans_csv/*with_CI*"):
    # Drop the optional "col_name" column and any row without an "emmean" value.
    current_df = (
        pd.read_csv(file_path, index_col=0)
        .drop(columns=["col_name"], errors="ignore")
        .dropna(subset=["emmean"])
    )
    all_emmeans_df.append(current_df)
combined_emmeans_df = pd.concat(all_emmeans_df)


In [161]:
combined_emmeans_df["color"] = combined_emmeans_df["trial_and_competitiveness_label"].map(comp_id_to_color)

In [None]:
# Estimated marginal means for vHPC theta power: one marker with an error bar per
# trial/competitiveness label, coloured by the precomputed "color" column.
current_df = combined_emmeans_df[combined_emmeans_df["spectra_metric"] == "cluster_mean_trial_vHPC_power_theta"]
for index, row in current_df.sort_values(by=["trial_and_competitiveness_label"]).iterrows():
    # The bar half-width is taken from the upper bound only (upper.CL - emmean),
    # i.e. the interval is drawn as symmetric around the mean.
    plt.errorbar(row["trial_and_competitiveness_label"], row["emmean"], yerr=row["upper.CL"] - row["emmean"], fmt='o', color=row["color"], elinewidth=4, capsize=10, markeredgewidth=4)
    plt.scatter(row["trial_and_competitiveness_label"], row["emmean"], s=100, marker="o", color=row["color"])

# Axis styling does not depend on the row, so it is applied once after the markers
# are drawn rather than on every iteration.
plt.tick_params(
    axis='x',          # changes apply to the x-axis
    which='both',      # both major and minor ticks are affected
    bottom=False,      # ticks along the bottom edge are off
    top=False,         # ticks along the top edge are off
    labelbottom=False) # labels along the bottom edge are off

plt.ylabel('Estimated Marginal Means', size=20)


In [None]:
0.021322 - 0.026051

		# 0.001951	6.243603	0.021322	


In [None]:
col

In [165]:
from matplotlib.ticker import FormatStrFormatter

In [None]:

# Violin + box figure for every model-predicted metric column of `joined_df`,
# saved to ./proc/spectra_bar/<metric family>/<metric>_bar.png and shown inline.
joined_df["bar_ticks"] = joined_df["trial_and_competitiveness_label"].map(label_to_ticks)

# Styling and the label-sorted frame are identical for every figure, so they are
# prepared once instead of on every iteration.
sns.set_style('white', {'axes.linewidth': 0.5})
plt.rcParams['xtick.major.size'] = 20
plt.rcParams['xtick.major.width'] = 4
plt.rcParams['xtick.bottom'] = True
plt.rcParams['ytick.left'] = True
sorted_joined_df = joined_df.sort_values(["trial_and_competitiveness_label"])

for col in [col for col in joined_df if ("cluster_mean_trial" in col and "all_frequencies" not in col and "predicted" in col)]:
    if "timestamps" in col:
        continue

    print(col)

    # Metric family, checked in the order coherence -> granger -> power. It selects the
    # output sub-directory; columns matching none go to "other" instead of failing with
    # a NameError or silently reusing the previous column's directory.
    sub_dir = next((family for family in ("coherence", "granger", "power") if family in col), "other")

    fig, ax = plt.subplots(figsize=(6.4, 4.8))

    # Coloured violins give the distribution per label; a narrow white box plot with a
    # red median line is drawn on top of each violin.
    ax = sns.violinplot(data=sorted_joined_df, x='bar_ticks', y=col, hue='trial_and_competitiveness_label',
                        palette=comp_id_to_color, inner=None, linewidth=0, saturation=1, ax=ax)

    sns.boxplot(x='bar_ticks', y=col, data=sorted_joined_df, color='white',
                width=0.2,
                ax=ax, fill=True, linecolor="black", linewidth=2, medianprops=dict(color="red", alpha=1))

    # Replace the hue legend with an empty, frameless one
    plt.legend([], [], frameon=False)

    # Hide top and right borders
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Leave bottom and left spines
    ax.spines['bottom'].set_visible(True)
    ax.spines['left'].set_visible(True)
    ax.spines['bottom'].set_linewidth(3)
    ax.spines['left'].set_linewidth(3)

    ax.tick_params(axis='both', which='major', labelsize=23, length=10, width=3, color='black')
    ax.tick_params(axis='both', which='minor', labelsize=23, length=10, width=3, color='black')

    # Title: column name without the bookkeeping prefix and metric keywords
    # (the same expression applies to every metric family).
    title = col.replace("predicted_cluster_mean_trial_", "").replace("power", "").replace("granger", "").replace("coherence", "").replace("_", " ").replace("  ", " ").replace(".", " ")

    # Fixed y-ranges for coherence and Granger; power scales with the data.
    if "coherence" in col:
        plt.ylim(0.2, 1)
        plt.yticks([0.2, 0.4, 0.6, 0.8, 1])
    if "granger" in col:
        plt.ylim(0.2, 0.8)
        plt.yticks([0.2, 0.4, 0.6, 0.8])
    elif "power" in col:
        # 25% headroom above the largest predicted value
        plt.ylim(0, np.round(joined_df[col].max() * 1.25, decimals=3))
        if "theta" in col:
            ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
        elif "gamma" in col:
            ax.yaxis.set_major_formatter(FormatStrFormatter('%.3f'))

    plt.tick_params(
        axis='x',          # changes apply to the x-axis
        which='both',      # both major and minor ticks are affected
        bottom=False,      # ticks along the bottom edge are off
        top=False,         # ticks along the top edge are off
        labelbottom=False) # labels along the bottom edge are off

    plt.xticks(visible=False)

    # Axis labels are left blank on these panels
    ax.set(xlabel=None)
    ax.set(ylabel=None)

    plt.xticks(rotation=90)
    plt.title(title.replace("theta", "Theta").replace("gamma", "Gamma").replace("predicted", "").strip(), pad=50)
    plt.tight_layout()

    # Create the target directory on demand so savefig cannot fail on a fresh checkout.
    output_dir = os.path.join("./proc/spectra_bar", sub_dir)
    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(os.path.join(output_dir, "{}_bar.png".format(col.replace("cluster_mean_trial_", "").replace("predicted_", ""))))
    plt.show()
    # break
    # break
    






In [307]:

# joined_df["bar_ticks"] = joined_df["trial_and_competitiveness_label"].map(label_to_ticks)

# for col in [col for col in joined_df if ("cluster_mean_trial" in col and "all_frequencies" not in col and "predicted" in col)]:
#     if "timestamps" in col:
#         continue
    
#     print(col)
#     sns.set_style('white', {'axes.linewidth': 0.5})
#     plt.rcParams['xtick.major.size'] = 20
#     plt.rcParams['xtick.major.width'] = 4
#     plt.rcParams['xtick.bottom'] = True
#     plt.rcParams['ytick.left'] = True

#     # Plot the transition matrix using only matplotlib
#     fig, ax = plt.subplots(figsize=(6.4, 4.8))

#     ax = sns.violinplot(data=joined_df.sort_values(["trial_and_competitiveness_label"]), x='bar_ticks', y=col, hue='trial_and_competitiveness_label',
#         palette=comp_id_to_color,
#                         inner=None, linewidth=0, saturation=1)
    
#     # ax.tick_params(bottom=False)
    
#     # means = joined_df.groupby('bar_ticks')[col].mean()
#     # sems = joined_df.groupby('bar_ticks')[col].sem()

#     # Plot the means with SEM bars
#     # plt.errorbar(means.index, means, yerr=sems, fmt='o', color='black')

#     sns.boxplot(x='bar_ticks', y=col, data=joined_df.sort_values(["trial_and_competitiveness_label"]), color='white',            # hue='competitiveness_grouping',    palette=comp_to_color, 
#                 width=0.2,
#                 # boxprops={'zorder': 2}, 
#                 ax=ax, fill=True, linecolor="black", linewidth=2)

#     plt.legend([],[], frameon=False)

#         # Hide top and right borders
#     ax.spines['top'].set_visible(False)
#     ax.spines['right'].set_visible(False)

#     # Leave bottom and left spines
#     ax.spines['bottom'].set_visible(True)
#     ax.spines['left'].set_visible(True)
#     ax.spines['bottom'].set_linewidth(3)
#     ax.spines['left'].set_linewidth(3)

#     ax.tick_params(axis='both', which='major', labelsize=23, length=10, width=3, color='black')
#     ax.tick_params(axis='both', which='minor', labelsize=23, length=10, width=3, color='black')

#     if "granger" in col:
#         # title = "{}".format(col.replace(".", "-").replace("cluster_mean_trial_", "").replace("cluster_mean_trial_and_post_", "").replace("RMS_filtered_", "").replace("all_frequencies_all_windows", "").replace("power", "").replace("granger", "")).replace("coherence", "").replace("_", " ").replace("  ", " ").strip().replace(" ", " to ").replace("predicted", "").strip()
#         title = col.replace("predicted_cluster_mean_trial_", "").replace("power", "").replace("granger", "").replace("coherence", "").replace("_", " ").replace("  ", " ").replace(".", " ")

#     else:
#         # title = "{}".format(col.replace(".", "-").replace("cluster_mean_trial_", "").replace("cluster_mean_trial_and_post_", "").replace("RMS_filtered_", "").replace("all_frequencies_all_windows", "").replace("power", "").replace("granger", "")).replace("coherence", "").replace("_", " ").replace("  ", " ").replace("predicted", "").strip()
#         title = col.replace("predicted_cluster_mean_trial_", "").replace("power", "").replace("granger", "").replace("coherence", "").replace("_", " ").replace("  ", " ").replace(".", " ")

#     if "coherence" in col or "granger" in col:
#         plt.ylim(0, 1)

#     elif "theta" in col:
#         plt.ylim(0, 0.04)
#     elif "gamma" in col:
#         plt.ylim(0, 0.004)

#     if "coherence" in col:
#         plt.ylabel('Coherence', size=25)

#     elif "granger" in col:
#         plt.ylabel('Granger Causality', size=25)

#     elif "power" in col: 
#         plt.ylabel('Normalized Power (a.u.)', size=25)
    
#     plt.tick_params(
#         axis='x',          # changes apply to the x-axis
#         which='both',      # both major and minor ticks are affected
#         bottom=False,      # ticks along the bottom edge are off
#         top=False,         # ticks along the top edge are off
#         labelbottom=False) # labels along the bottom edge are off
    
#     plt.xticks(visible=False)

#     ax.set(xlabel=None)

#     plt.xticks(rotation=90)
#     plt.title(title.replace("theta", "Theta").replace("gamma", "Gamma").replace("predicted", "").strip())
#     plt.tight_layout()
#     plt.savefig("./proc/spectra_bar/{}_bar.png".format(col.replace("cluster_mean_trial_", "").replace("predicted_", "")))
#     plt.show()
#     # break

In [None]:
# One estimated-marginal-means figure per predicted metric column: a marker with an
# error bar for each trial/competitiveness label, saved to
# ./proc/emmeans/<metric family>/<metric>_emmeans.png and shown inline.
for col in [col for col in joined_df if ("cluster_mean_trial" in col and "all_frequencies" not in col and "predicted" in col)]:
    if "timestamps" in col:
        continue

    print(col)

    # Metric family, checked in the order coherence -> granger -> power. It selects the
    # output sub-directory; columns matching none go to "other" instead of failing with
    # a NameError or silently reusing the previous column's directory.
    sub_dir = next((family for family in ("coherence", "granger", "power") if family in col), "other")

    fig, ax = plt.subplots(figsize=(6.4, 4.8))

    # The emmeans rows are keyed by the metric name without the "predicted_" prefix.
    current_df = combined_emmeans_df[combined_emmeans_df["spectra_metric"] == col.replace("predicted_", "")]
    for index, row in current_df.sort_values(by=["trial_and_competitiveness_label"]).iterrows():
        # The bar half-width is taken from the upper bound only (upper.CL - emmean),
        # i.e. the interval is drawn as symmetric around the mean.
        ax.errorbar(row["trial_and_competitiveness_label"], row["emmean"], yerr=row["upper.CL"] - row["emmean"], fmt='o', color=row["color"], elinewidth=4, capsize=10, markeredgewidth=4)
        ax.scatter(row["trial_and_competitiveness_label"], row["emmean"], s=100, marker="o", color=row["color"])

    # Axis styling does not depend on the row, so it is applied once per figure.
    ax.tick_params(
        axis='x',          # changes apply to the x-axis
        which='both',      # both major and minor ticks are affected
        bottom=False,      # ticks along the bottom edge are off
        top=False,         # ticks along the top edge are off
        labelbottom=False) # labels along the bottom edge are off

    # Hide top and right borders
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Leave bottom and left spines
    ax.spines['bottom'].set_visible(True)
    ax.spines['left'].set_visible(True)
    ax.spines['bottom'].set_linewidth(3)
    ax.spines['left'].set_linewidth(3)

    # Title: column name without the bookkeeping prefix and metric keywords
    # (the same expression applies to every metric family).
    title = col.replace("predicted_cluster_mean_trial_", "").replace("power", "").replace("granger", "").replace("coherence", "").replace("_", " ").replace("  ", " ").replace(".", " ")

    # Tick precision per metric family and band
    if "coherence" in col or "granger" in col:
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    elif "power" in col:
        if "theta" in col:
            ax.yaxis.set_major_formatter(FormatStrFormatter('%.3f'))
        elif "gamma" in col:
            ax.yaxis.set_major_formatter(FormatStrFormatter('%.4f'))

    plt.title(title.replace("theta", "Theta").replace("gamma", "Gamma").replace("predicted", "").strip(), pad=50)
    plt.tight_layout()

    # Create the target directory on demand so savefig cannot fail on a fresh checkout.
    output_dir = os.path.join("./proc/emmeans", sub_dir)
    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(os.path.join(output_dir, "{}_emmeans.png".format(col.replace("cluster_mean_trial_", "").replace("predicted_", ""))))

    plt.show()


- Getting all the significant values

In [None]:
# Stack every table in the emmeans output directory whose path does not contain
# "with_CI" into a single frame.
all_significance_df = []
for file_path in glob.glob("../../output_lmer_velocity/emmeans_csv/*"):
    if "with_CI" not in file_path:
        current_df = pd.read_csv(file_path, index_col=0)
        all_significance_df.append(current_df)
combined_significance_df = pd.concat(all_significance_df)


In [None]:
raise ValueError()

In [None]:
combined_significance_df["contrast"].unique()

In [482]:
# Pairwise contrasts kept for reporting.
# Currently excluded: 'rewarded - lose_no_comp', 'win_no_comp - lose_competitive',
# 'lose_no_comp - win_competitive'.
valid_comparisons = [
    "rewarded - win_competitive",
    "rewarded - lose_competitive",
    "win_no_comp - lose_no_comp",
    "win_no_comp - win_competitive",
    "lose_no_comp - lose_competitive",
    "win_competitive - lose_competitive",
    "rewarded - win_no_comp",
]

In [483]:
# Keep only the contrasts with p <= 0.001.
# NOTE(review): because this filter runs before the star annotation, every remaining
# row satisfies the strictest threshold used by get_signficance_stars -- confirm that
# the "**" and "*" levels are intentionally excluded from this table.
combined_significance_df = combined_significance_df[combined_significance_df["p.value"] <= 0.001]

In [484]:
combined_significance_df = combined_significance_df[combined_significance_df["contrast"].isin(valid_comparisons)]

In [485]:
def get_signficance_stars(p_value):
    """
    Convert a p-value into a significance-star annotation.

    Args:
        p_value: The p-value to annotate.

    Returns:
        "***" if p_value <= 0.001, "**" if p_value <= 0.01, "*" if p_value <= 0.05,
        and an empty string otherwise.
    """
    # Cut-offs run from strictest to loosest; the first one satisfied wins.
    for cutoff, stars in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if p_value <= cutoff:
            return stars
    return ""

In [486]:
# Annotate each contrast with the star string for its p-value.
combined_significance_df["signficance_stars"] = combined_significance_df["p.value"].apply(get_signficance_stars)

In [487]:
combined_significance_df["metric"] = combined_significance_df["spectra_metric"].apply(lambda x: x.split("_")[-2])
combined_significance_df["band"] = combined_significance_df["spectra_metric"].apply(lambda x: x.split("_")[-1])

In [488]:
combined_significance_df["contrast"] = combined_significance_df["contrast"].apply(lambda x: sorted([word.strip() for word in x.split("-")]))

In [489]:
combined_significance_df["label_1"] = combined_significance_df["contrast"].apply(lambda x: x[0])
combined_significance_df["label_2"] = combined_significance_df["contrast"].apply(lambda x: x[1])

In [490]:
# Flag contrasts whose two labels start with the same first "_"-separated token
# (e.g. "win_no_comp" vs "win_competitive"). "contrast" holds a sorted two-element
# list of labels at this point, so x[0] and x[1] are the two sides of the contrast.
combined_significance_df["is_comp_comparison"] = combined_significance_df["contrast"].apply(lambda x: x[0].split("_")[0] == x[1].split("_")[0])

In [491]:
combined_significance_df = combined_significance_df.sort_values(["metric", "band", "spectra_metric", "label_1", "label_2"])

In [None]:
combined_significance_df[~combined_significance_df["is_comp_comparison"]].tail(n=50)

In [None]:
combined_significance_df[combined_significance_df["is_comp_comparison"]]

In [None]:
# combined_significance_df[combined_significance_df["metric"] == "coherence"]

combined_significance_df[(combined_significance_df["metric"] == "coherence") & (combined_significance_df["spectra_metric"].str.contains("mPFC"))]

In [None]:
combined_significance_df[(combined_significance_df["metric"] == "coherence") & (combined_significance_df["contrast"].apply(lambda x: "lose_no_comp" in x))]

In [None]:
combined_significance_df[(combined_significance_df["metric"] == "granger") & (combined_significance_df["contrast"].apply(lambda x: "lose_no_comp" in x))]

In [None]:
combined_significance_df["contrast"]

In [179]:
# import matplotlib.pyplot as plt
# import numpy as np

# # Some example data to display
# x = np.linspace(0, 2 * np.pi, 400)
# y = np.sin(x ** 2)





# joined_df["bar_ticks"] = joined_df["trial_and_competitiveness_label"].map(label_to_ticks)

# for col in [col for col in joined_df if ("cluster_mean_trial" in col and "all_frequencies" not in col and "predicted" in col)]:
#     print(col)
#     sns.set_style('white', {'axes.linewidth': 0.5})
#     plt.rcParams['xtick.major.size'] = 20
#     plt.rcParams['xtick.major.width'] = 4
#     plt.rcParams['xtick.bottom'] = True
#     plt.rcParams['ytick.left'] = True

#     # Plot the transition matrix using only matplotlib
#     # fig, ax = plt.subplots(figsize=(6.4, 4.8))

#     fig, ax = plt.subplots(2)
#     # fig.suptitle('Vertically stacked subplots')
#     # ax[0].plot(x, y)
#     # ax[1].plot(x, -y)

#     sns.violinplot(ax=ax[0], data=joined_df.sort_values(["trial_and_competitiveness_label"]), x='bar_ticks', y=col, hue='trial_and_competitiveness_label',
#         palette=comp_id_to_color,
#                         inner=None, linewidth=0, saturation=1)

#     sns.boxplot(ax=ax[0], x='bar_ticks', y=col, data=joined_df.sort_values(["trial_and_competitiveness_label"]), color='white',            # hue='competitiveness_grouping',    palette=comp_to_color, 
#                 width=0.2,
#                 # boxprops={'zorder': 2}, 
#                 fill=True, linecolor="black", linewidth=2)



#         # Hide top and right borders
#     ax[0].spines['top'].set_visible(False)
#     ax[0].spines['right'].set_visible(False)

#     # Leave bottom and left spines
#     ax[0].spines['bottom'].set_visible(True)
#     ax[0].spines['left'].set_visible(True)
#     ax[0].spines['bottom'].set_linewidth(3)
#     ax[0].spines['left'].set_linewidth(3)

#     ax[0].tick_params(axis='both', which='major', labelsize=23, length=10, width=3, color='black')
#     ax[0].tick_params(axis='both', which='minor', labelsize=23, length=10, width=3, color='black')



#     if "granger" in col:
#         # title = "{}".format(col.replace(".", "-").replace("cluster_mean_trial_", "").replace("cluster_mean_trial_and_post_", "").replace("RMS_filtered_", "").replace("all_frequencies_all_windows", "").replace("power", "").replace("granger", "")).replace("coherence", "").replace("_", " ").replace("  ", " ").strip().replace(" ", " to ").replace("predicted", "").strip()
#         title = col.replace("predicted_cluster_mean_trial_", "").replace("power", "").replace("granger", "").replace("coherence", "").replace("_", " ").replace("  ", " ").replace(".", " ")

#     else:
#         # title = "{}".format(col.replace(".", "-").replace("cluster_mean_trial_", "").replace("cluster_mean_trial_and_post_", "").replace("RMS_filtered_", "").replace("all_frequencies_all_windows", "").replace("power", "").replace("granger", "")).replace("coherence", "").replace("_", " ").replace("  ", " ").replace("predicted", "").strip()
#         title = col.replace("predicted_cluster_mean_trial_", "").replace("power", "").replace("granger", "").replace("coherence", "").replace("_", " ").replace("  ", " ").replace(".", " ")

#     if "coherence" in col or "granger" in col:
#         plt.ylim(0, 1)

#     elif "theta" in col:
#         plt.ylim(0, 0.04)
#     elif "gamma" in col:
#         plt.ylim(0, 0.004)

#     # if "coherence" in col:
#     #     plt.ylabel('Coherence', size=30)

#     # elif "granger" in col:
#     #     plt.ylabel('Granger Causality', size=30)

#     # elif "power" in col: 
#     #     plt.ylabel('Normalized Power\n(a.u.)', size=30)
    
#     plt.tick_params(
#         axis='x',          # changes apply to the x-axis
#         which='both',      # both major and minor ticks are affected
#         bottom=False,      # ticks along the bottom edge are off
#         top=False,         # ticks along the top edge are off
#         labelbottom=False) # labels along the bottom edge are off

#     plt.xticks(rotation=90)
#     plt.title(title.replace("theta", "Theta").replace("gamma", "Gamma").replace("predicted", "").strip())

#     plt.legend([],[], frameon=False)

#     plt.tight_layout()
#     # plt.savefig("./proc/spectra_bar/{}_bar.png".format(col.replace("cluster_mean_trial_", "").replace("predicted_", "")))
#     plt.show()
#     raise ValueError()






In [None]:
current_df

In [181]:
# Rebuild the wide prediction frame from scratch, keyed on "index_num".
# pd.merge defaults to an inner join, so only rows present in every table survive.
joined_df = all_predicted_df[0]
for current_df in all_predicted_df[1:]:
    joined_df = pd.merge(joined_df, current_df, on="index_num", suffixes=("", "_drop"))
    # Columns present in both tables come back from the right-hand table with the
    # "_drop" suffix. Match the suffix only, so a column that merely contains "_drop"
    # somewhere in its name is not discarded by accident.
    joined_df = joined_df.drop(columns=[col for col in joined_df if col.endswith("_drop")])

In [None]:
joined_df

In [None]:
raise ValueError()

In [None]:
joined_df

In [None]:

# Per-subject view of the predicted metrics: one violin + box figure per predicted
# metric column, grouped on the x-axis by "subject_id". Figures are only shown; the
# savefig call near the end is commented out.
joined_df["bar_ticks"] = joined_df["trial_and_competitiveness_label"].map(label_to_ticks)

for col in [col for col in joined_df if ("cluster_mean_trial" in col and "all_frequencies" not in col and "predicted" in col)]:
    print(col)
    # Global styling is re-applied on every iteration; it does not depend on `col`.
    sns.set_style('white', {'axes.linewidth': 0.5})
    plt.rcParams['xtick.major.size'] = 20
    plt.rcParams['xtick.major.width'] = 4
    plt.rcParams['xtick.bottom'] = True
    plt.rcParams['ytick.left'] = True

    # New figure for this metric (violin layer first, box layer on top)
    fig, ax = plt.subplots(figsize=(6.4, 4.8))

    ax = sns.violinplot(data=joined_df.sort_values(["trial_and_competitiveness_label"]), x='subject_id', y=col, 
                        inner=None, linewidth=0, saturation=1)
    
    # means = joined_df.groupby('bar_ticks')[col].mean()
    # sems = joined_df.groupby('bar_ticks')[col].sem()

    # Plot the means with SEM bars
    # plt.errorbar(means.index, means, yerr=sems, fmt='o', color='black')

    sns.boxplot(x='subject_id', y=col, data=joined_df.sort_values(["trial_and_competitiveness_label"]), color='white',            # hue='competitiveness_grouping',    palette=comp_to_color, 
                width=0.2,
                # boxprops={'zorder': 2}, 
                ax=ax, fill=True, linecolor="black", linewidth=2)

    # Replace any legend with an empty, frameless one
    plt.legend([],[], frameon=False)

        # Hide top and right borders
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Leave bottom and left spines
    ax.spines['bottom'].set_visible(True)
    ax.spines['left'].set_visible(True)
    ax.spines['bottom'].set_linewidth(3)
    ax.spines['left'].set_linewidth(3)

    ax.tick_params(axis='both', which='major', labelsize=23, length=10, width=3, color='black')
    ax.tick_params(axis='both', which='minor', labelsize=23, length=10, width=3, color='black')



    # Build the title by stripping the bookkeeping parts of the column name.
    # NOTE(review): in the Granger branch the " " -> " to " substitution runs before
    # "predicted" is removed, so a leading "predicted" leaves a stray "to" in the title.
    if "granger" in col:
        title = "{}".format(col.replace("cluster_mean_trial_", "").replace("cluster_mean_trial_and_post_", "").replace("RMS_filtered_", "").replace("all_frequencies_all_windows", "").replace("power", "").replace("granger", "")).replace("coherence", "").replace("_", " ").replace("  ", " ").strip().replace(" ", " to ").replace("predicted", "").strip()
    else:
        title = "{}".format(col.replace("cluster_mean_trial_", "").replace("cluster_mean_trial_and_post_", "").replace("RMS_filtered_", "").replace("all_frequencies_all_windows", "").replace("power", "").replace("granger", "")).replace("coherence", "").replace("_", " ").replace("  ", " ").replace("predicted", "").strip()

    # Fixed y-ranges: [0, 1] for coherence/Granger, otherwise a band-specific range.
    if "coherence" in col or "granger" in col:
        plt.ylim(0, 1)

    elif "theta" in col:
        plt.ylim(0, 0.04)
    elif "gamma" in col:
        plt.ylim(0, 0.004)

    # y-axis label by metric family; columns matching none keep seaborn's default label.
    if "coherence" in col:
        plt.ylabel('Coherence', size=30)

    elif "granger" in col:
        plt.ylabel('Granger Causality', size=30)

    elif "power" in col: 
        plt.ylabel('Normalized Power (a.u.)', size=30)
    
    # plt.tick_params(
    #     axis='x',          # changes apply to the x-axis
    #     which='both',      # both major and minor ticks are affected
    #     bottom=False,      # ticks along the bottom edge are off
    #     top=False,         # ticks along the top edge are off
    #     labelbottom=False) # labels along the bottom edge are off

    # Unlike the category plots, x tick labels stay enabled here and are drawn vertically.
    plt.xticks(rotation=90)
    plt.title(title.replace("theta", "Theta").replace("gamma", "Gamma").replace("predicted", "").strip())
    plt.tight_layout()
    # plt.savefig("./proc/spectra_bar/{}_bar.png".format(col.replace("cluster_mean_trial_", "").replace("predicted_", "")))
    plt.show()
    






In [None]:
raise ValueError()

In [None]:
joined_df

In [None]:
for df in all_predicted_df:
    print(df.shape)

In [None]:
# NOTE(review): all_predicted_df is built as a Python list of DataFrames, so indexing it
# with the string "cur" raises TypeError. This looks like an unfinished scratch lookup;
# confirm the intended expression or remove the cell.
all_predicted_df["cur"]

In [None]:
all_predicted_df[0].head()

In [None]:
all_predicted_df[1]

In [None]:
raise ValueError(312)

In [None]:
raise ValueError()

In [None]:
# Violin + box figure of "mean_distance" per trial label / competitiveness grouping,
# drawn from `distance_mean_df`. The figure is only shown; savefig is commented out.
sns.set_style('white', {'axes.linewidth': 0.5})
plt.rcParams['xtick.major.size'] = 20
plt.rcParams['xtick.major.width'] = 4
plt.rcParams['xtick.bottom'] = True
plt.rcParams['ytick.left'] = True

# New figure (violin layer first, box layer on top)
fig, ax = plt.subplots(figsize=(6.4, 4.8))

ax = sns.violinplot(data=distance_mean_df.sort_values(["trial_label", "competitiveness_grouping"]), x='trial_label_and_competitiveness_grouping', y='mean_distance', hue='competitiveness_grouping',
    palette=comp_id_to_color,
                    inner=None, linewidth=0, saturation=1)

sns.boxplot(x='trial_label_and_competitiveness_grouping', y='mean_distance', data=distance_mean_df.sort_values(["trial_label", "competitiveness_grouping"]), color='white',            # hue='competitiveness_grouping',    palette=comp_to_color, 
            width=0.1,
            # boxprops={'zorder': 2}, 
            ax=ax, fill=True, linecolor="black",)

# Replace the hue legend with an empty, frameless one
plt.legend([],[], frameon=False)

    # Hide top and right borders
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Leave bottom and left spines
ax.spines['bottom'].set_visible(True)
ax.spines['left'].set_visible(True)
ax.spines['bottom'].set_linewidth(3)
ax.spines['left'].set_linewidth(3)

ax.tick_params(axis='both', which='major', labelsize=23, length=10, width=3, color='black')
ax.tick_params(axis='both', which='minor', labelsize=23, length=10, width=3, color='black')

# Axis labels: y-axis label set, x-axis label cleared (no title is set on this figure)
plt.ylabel('Distance (cm)', size=30)
plt.xlabel('')

plt.tight_layout()
# plt.savefig("./proc/distance_violin_plot.png")
plt.show()



In [None]:
raise ValueError()

## Looking at cluster related stuff

- Parsing velocity during the relevant cluster

In [None]:
for col in export_cluster_exploded_TRIALS_AND_SPECTRAL_DF.columns:
    print(col)

In [None]:
export_cluster_exploded_TRIALS_AND_SPECTRAL_DF["cluster_mean_trial_BLA-to-LH_granger_gamma"]

In [None]:
raise ValueError()

- Code for verifying the spectra

In [None]:
# low_freq = 0
# high_freq = 51
# current_frequencies = range(low_freq, high_freq)

# # Iterating through each brain region
# for col in cluster_mean_columns:
#     if "all_frequencies" not in col:
#         continue
#     print(col)

#     fig, ax = plt.subplots()
#     plt.xlim(low_freq, high_freq-1) 
#     plt.xticks(np.arange(low_freq, high_freq-1, 5))

#     stacked_df = cluster_exploded_TRIALS_AND_SPECTRAL_DF.groupby("trial_and_competitiveness_label").agg({col: stack_arrays})
#     stacked_df = stacked_df.reset_index()
#     stacked_df = stacked_df[~stacked_df["trial_and_competitiveness_label"].str.contains("tie")]

#     stacked_df["color"] = stacked_df["trial_and_competitiveness_label"].map(comp_id_to_color)
    
#     if "power" in col:
#         # plt.ylim(0,0.01)
#         # plt.yscale("log")
#         # plt.ylim((10**-3.5,10**-0.5))
#         # plt.set_ylim(auto=True)
#         plt.ylim()
#         ax.set_yscale('log')
#         # plt.ylim(0, max_value)
        
#     else:
#         # pass
#         plt.ylim(0, 0.8)
    
#     # Iterating through each trial type
#     counter = 0
#     for index, row in stacked_df.iloc[:100].iterrows():
#         print(row["trial_and_competitiveness_label"])
#         print(row[col].shape)
#         for trace in row[col]: 
#             plt.plot(current_frequencies, trace[low_freq: high_freq], color=row["color"], alpha=0.5)
#             counter += 1
#             if counter >= 100:
#                 break
    
#     if "granger" in col:
#         title = "{}".format(col.replace("cluster_mean_trial_", "").replace("cluster_mean_trial_and_post_", "").replace("RMS_filtered_", "").replace("all_frequencies_all_windows", "").replace("power", "").replace("granger", "")).replace("coherence", "").replace("_", " ").replace("  ", " ").strip().replace(" ", " to ").strip()
#     else:
#         title = "{}".format(col.replace("cluster_mean_trial_", "").replace("cluster_mean_trial_and_post_", "").replace("RMS_filtered_", "").replace("all_frequencies_all_windows", "").replace("power", "").replace("granger", "")).replace("coherence", "").replace("_", " ").replace("  ", " ").strip()

#     plt.title(title)
#     plt.xlabel("Frequency (Hz)")
    
#     if "power" in col:
#         # plt.ylabel("Normalized Power (a.u.)")
#         output_dir = "./proc/cluster_spectra_plots/power"

#     elif "coherence" in col:
#         # plt.ylabel("Coherence")
#         output_dir = "./proc/cluster_spectra_plots/coherence"

#     elif "granger" in col:
#         # plt.ylabel("Granger's Causality")
#         output_dir = "./proc/cluster_spectra_plots/granger"
    
#     os.makedirs(output_dir, exist_ok=True)

#     # plt.legend(fontsize=10)

#     # Hide top and right borders
#     ax.spines['top'].set_visible(False)
#     ax.spines['right'].set_visible(False)

#     # Leave bottom and left spines
#     ax.spines['bottom'].set_visible(True)
#     ax.spines['left'].set_visible(True)
#     ax.spines['bottom'].set_linewidth(2)
#     ax.spines['left'].set_linewidth(2)

#     plt.tight_layout()    
#     plt.show()
#     break

In [101]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF.to_pickle("./proc/cluster_exploded_TRIALS_AND_SPECTRAL_DF.pkl")

# Plotting averages

In [102]:
GROUPING = "trial_and_competitiveness_label"

In [103]:
# Per-band (theta or gamma) "cluster_mean" columns; the
# "all_frequencies_all_windows" columns are excluded.
theta_gamma_mean_columns = [
    column_name
    for column_name in cluster_exploded_TRIALS_AND_SPECTRAL_DF.columns
    if "cluster_mean" in column_name
    and "all_frequencies_all_windows" not in column_name
    and any(band in column_name for band in ("gamma", "theta"))
]

In [None]:
theta_gamma_mean_columns

In [None]:
col

In [107]:
# for col in theta_gamma_mean_columns:
#     fig, ax = plt.subplots()
#     mean_df = pd.DataFrame(cluster_exploded_TRIALS_AND_SPECTRAL_DF[[GROUPING, col]].groupby(GROUPING).mean()[col]).rename(columns={col: "mean"})
#     sem_df = pd.DataFrame(cluster_exploded_TRIALS_AND_SPECTRAL_DF[[GROUPING, col]].groupby(GROUPING).agg("sem")[col]).rename(columns={col: "sem"})
#     merged_df = pd.merge(mean_df, sem_df, on=GROUPING).reset_index().sort_values([GROUPING])
#     merged_df["measurement"] = col
#     for index, row in merged_df.iterrows():
#         if "tie" in row[GROUPING]:
#             continue
#         plt.bar("\n".join(row[GROUPING].split("_")).replace("competitive", "comp"), row["mean"], color=comp_id_to_color[row[GROUPING]])
#         plt.errorbar("\n".join(row[GROUPING].split("_")).replace("competitive", "comp"), row["mean"], yerr=row["sem"], capsize=20, ecolor = "black")

#     ax.tick_params(axis='x', rotation=90)
#     brain_region = col.replace("cluster_mean_trial_", "").replace("cluster_mean_trial_and_post", "").replace("coherence", "").replace("granger", "").replace("power", "").replace("__", "_").replace("_", " ").strip()
#     plt.title(brain_region.replace("gamma", "").replace("theta", "").strip())
    
#     if "power" in col:
#         plt.ylim()
#     else:
#         plt.ylim(0, 0.8)


#     plt.tight_layout()
#     # plt.savefig("./proc/cluster_average_spectra/{}_averages.png".format(col.replace("cluster_mean_trial_", "")))
#     plt.close()
#     # plt.show()
#     # for group in cluster_exploded_TRIALS_AND_SPECTRAL_DF[GROUPING].unique():
#     #     current_df = cluster_exploded_TRIALS_AND_SPECTRAL_DF[cluster_exploded_TRIALS_AND_SPECTRAL_DF[GROUPING] == group]
    

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["trial_and_competitiveness_label"].unique()

In [112]:
# Export the identifying columns plus every theta/gamma mean column to Excel.
cluster_exploded_TRIALS_AND_SPECTRAL_DF.loc[
    :,
    ["trial_and_competitiveness_label", "current_subject", "recording", "video_name"]
    + theta_gamma_mean_columns,
].to_excel("./proc/competitiveness_trials_and_spectral_mean.xlsx")

In [113]:
# Export the identifying columns plus every theta/gamma mean column to CSV.
cluster_exploded_TRIALS_AND_SPECTRAL_DF.loc[
    :,
    ["trial_and_competitiveness_label", "current_subject", "recording", "video_name"]
    + theta_gamma_mean_columns,
].to_csv("./proc/competitiveness_trials_and_spectral_mean.csv")

# Plotting power bar graphs

In [116]:
    # One evenly spaced x position per row of merged_df.  The previous
    # np.arange(0, 1, len(merged_df)) passed the row count as the *step*, which
    # yields a single position (and divides by zero for an empty frame).
    # NOTE(review): merged_df is only assigned in later cells -- confirm the
    # intended execution order of this scratch cell.
    x_pos = np.linspace(0, 1, len(merged_df))

In [None]:
x_pos

In [180]:
# Bar plot (mean with SEM error bar) per group for every theta/gamma *power*
# column; each figure is written to disk and closed without being displayed.

# Create the output folder up front so plt.savefig cannot fail on a missing path.
os.makedirs("./proc/cluster_average_spectra/power", exist_ok=True)

for col in theta_gamma_mean_columns:
    if "power" not in col:
        continue

    fig, ax = plt.subplots(figsize=(8,4))

    # Mean and standard error of the mean of `col` for each GROUPING value.
    merged_df = (
        cluster_exploded_TRIALS_AND_SPECTRAL_DF
        .groupby(GROUPING)[col]
        .agg(["mean", "sem"])
        .reset_index()
        .sort_values([GROUPING])
    )
    merged_df["measurement"] = col
    x_pos = np.linspace(0,0.5,len(merged_df)) * 0.9
    bar_width = 0.5 / len(merged_df)

    # Hard-coded display order over the alphabetically sorted groups; requires
    # at least five groups.
    # NOTE(review): presumably win_competitive, win_no_comp, rewarded,
    # lose_competitive, lose_no_comp -- verify against the actual labels.
    ordered_df = merged_df.iloc[[3,4,2,0,1]].reset_index(drop=True)

    for index, row in ordered_df.iterrows():
        ax.bar(x_pos[index], row["mean"], color=comp_id_to_color[row[GROUPING]], width=bar_width)
        ax.errorbar(x_pos[index], row["mean"], yerr=row["sem"], capsize=20, ecolor="black", elinewidth=3, markeredgewidth=3)

    # The x axis carries no ticks or labels; groups are identified by colour only.
    ax.tick_params(
        axis='x',          # changes apply to the x-axis
        which='both',      # both major and minor ticks are affected
        bottom=False,      # ticks along the bottom edge are off
        top=False,         # ticks along the top edge are off
        labelbottom=False) # labels along the bottom edge are off

    # Fixed y ranges so that figures of the same band are directly comparable.
    if "theta" in col:
        ax.set_ylim(0, 0.04)
    elif "gamma" in col:
        ax.set_ylim(0, 0.004)
    else:
        ax.set_ylim(0, 0.8)
    ax.tick_params(length=10, width=4)
    ax.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
    plt.yticks(fontsize=30)

    # Hide top and right borders
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Leave bottom and left spines
    ax.spines['bottom'].set_visible(True)
    ax.spines['left'].set_visible(True)
    ax.spines['bottom'].set_linewidth(4)
    ax.spines['left'].set_linewidth(4)

    fig.tight_layout()
    fig.savefig("./proc/cluster_average_spectra/power/{}_averages.png".format(col.replace("cluster_mean_trial_", "")), transparent=True)
    # Close explicitly so figures do not accumulate across loop iterations.
    plt.close(fig)
    

# Plotting coherence plots

In [236]:
# Bar plot (mean with SEM error bar) per group for every theta/gamma *coherence*
# column; each figure is written to disk and closed without being displayed.

# Create the output folder up front so plt.savefig cannot fail on a missing path.
os.makedirs("./proc/cluster_average_spectra/coherence", exist_ok=True)

for col in theta_gamma_mean_columns:
    if "coherence" not in col:
        continue

    fig, ax = plt.subplots(figsize=(8,4))

    # Mean and standard error of the mean of `col` for each GROUPING value.
    merged_df = (
        cluster_exploded_TRIALS_AND_SPECTRAL_DF
        .groupby(GROUPING)[col]
        .agg(["mean", "sem"])
        .reset_index()
        .sort_values([GROUPING])
    )
    merged_df["measurement"] = col
    x_pos = np.linspace(0,0.5,len(merged_df)) * 0.9
    bar_width = 0.5 / len(merged_df)

    # Hard-coded display order over the alphabetically sorted groups; requires
    # at least five groups.
    # NOTE(review): presumably win_competitive, win_no_comp, rewarded,
    # lose_competitive, lose_no_comp -- verify against the actual labels.
    ordered_df = merged_df.iloc[[3,4,2,0,1]].reset_index(drop=True)

    for index, row in ordered_df.iterrows():
        ax.bar(x_pos[index], row["mean"], color=comp_id_to_color[row[GROUPING]], width=bar_width)
        ax.errorbar(x_pos[index], row["mean"], yerr=row["sem"], capsize=20, ecolor="black", elinewidth=3, markeredgewidth=3)

    # The x axis carries no ticks or labels; groups are identified by colour only.
    ax.tick_params(
        axis='x',          # changes apply to the x-axis
        which='both',      # both major and minor ticks are affected
        bottom=False,      # ticks along the bottom edge are off
        top=False,         # ticks along the top edge are off
        labelbottom=False) # labels along the bottom edge are off

    # Same y range for every band (the former theta/gamma/else branches were identical).
    ax.set_ylim(0, 0.8)
    ax.tick_params(length=10, width=4)
    plt.yticks(fontsize=30)

    # Hide top and right borders
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Leave bottom and left spines
    ax.spines['bottom'].set_visible(True)
    ax.spines['left'].set_visible(True)
    ax.spines['bottom'].set_linewidth(4)
    ax.spines['left'].set_linewidth(4)

    fig.tight_layout()
    fig.savefig("./proc/cluster_average_spectra/coherence/{}_averages.png".format(col.replace("cluster_mean_trial_", "")), transparent=True)
    # Close explicitly so figures do not accumulate across loop iterations.
    plt.close(fig)
    

# Plotting Granger averages

In [235]:
# Bar plot (mean with SEM error bar) per group for every theta/gamma *granger*
# column; each figure is written to disk and closed without being displayed.

# Create the output folder up front so plt.savefig cannot fail on a missing path.
os.makedirs("./proc/cluster_average_spectra/granger", exist_ok=True)

for col in theta_gamma_mean_columns:
    if "granger" not in col:
        continue

    fig, ax = plt.subplots(figsize=(8,4))

    # Mean and standard error of the mean of `col` for each GROUPING value.
    merged_df = (
        cluster_exploded_TRIALS_AND_SPECTRAL_DF
        .groupby(GROUPING)[col]
        .agg(["mean", "sem"])
        .reset_index()
        .sort_values([GROUPING])
    )
    merged_df["measurement"] = col
    x_pos = np.linspace(0,0.5,len(merged_df)) * 0.9
    bar_width = 0.5 / len(merged_df)

    # Hard-coded display order over the alphabetically sorted groups; requires
    # at least five groups.
    # NOTE(review): presumably win_competitive, win_no_comp, rewarded,
    # lose_competitive, lose_no_comp -- verify against the actual labels.
    ordered_df = merged_df.iloc[[3,4,2,0,1]].reset_index(drop=True)

    for index, row in ordered_df.iterrows():
        ax.bar(x_pos[index], row["mean"], color=comp_id_to_color[row[GROUPING]], width=bar_width)
        ax.errorbar(x_pos[index], row["mean"], yerr=row["sem"], capsize=20, ecolor="black", elinewidth=3, markeredgewidth=3)

    # The x axis carries no ticks or labels; groups are identified by colour only.
    ax.tick_params(
        axis='x',          # changes apply to the x-axis
        which='both',      # both major and minor ticks are affected
        bottom=False,      # ticks along the bottom edge are off
        top=False,         # ticks along the top edge are off
        labelbottom=False) # labels along the bottom edge are off

    # Same y range for every band (the former theta/gamma/else branches were identical).
    ax.set_ylim(0, 0.8)
    ax.tick_params(length=10, width=4)
    plt.yticks(fontsize=30)

    # Hide top and right borders
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Leave bottom and left spines
    ax.spines['bottom'].set_visible(True)
    ax.spines['left'].set_visible(True)
    ax.spines['bottom'].set_linewidth(4)
    ax.spines['left'].set_linewidth(4)

    fig.tight_layout()
    fig.savefig("./proc/cluster_average_spectra/granger/{}_averages.png".format(col.replace("cluster_mean_trial_", "")), transparent=True)
    # Close explicitly so figures do not accumulate across loop iterations.
    plt.close(fig)
    

In [None]:
merged_df

# Checking significance

In [None]:
# Read every emmeans coherence CSV and tag its rows with the brain region and
# band that are encoded in the file name.
all_significance_csv = []
for file_path in glob.glob("./significance_output/emmeans_coherence/*csv"):
    current_df = pd.read_csv(file_path, index_col=0)

    # Strip the fixed name fragments; what remains is "<region> <band>".
    stripped_name = os.path.basename(file_path).replace(".csv", "")
    stripped_name = stripped_name.replace("cluster_mean_trial_", "")
    stripped_name = stripped_name.replace("_emmeans", "")
    file_name_parts = stripped_name.replace("_coherence_", " ").split(" ")

    print(file_name_parts)
    current_df["brain_region"] = file_name_parts[0]
    current_df["band"] = file_name_parts[1]

    all_significance_csv.append(current_df)

In [196]:
all_significance_df = pd.concat(all_significance_csv)

In [198]:
all_significance_df["p.value"] = all_significance_df["p.value"].astype(float)

In [None]:
all_significance_df["contrast"].unique()

In [204]:
# Contrast labels kept when filtering the significance table.
good_contrasts = [
    'rewarded - win_competitive',
    'rewarded - win_no_comp',
    'win_competitive - win_no_comp',
    'win_competitive - lose_competitive',
    'win_no_comp - lose_no_comp',
    'lose_competitive - lose_no_comp',
]

In [None]:
all_significance_df[(all_significance_df["p.value"] <= 0.001) & (all_significance_df["contrast"].isin(good_contrasts)) & (~all_significance_df["brain_region"].str.contains("mPFC"))].sort_values(by=["band", "brain_region", "contrast"])

In [None]:
all_significance_csv[0]

# OLD Stuff

In [None]:
# Columns whose name contains "granger", "mean" and a theta or gamma band tag.
granger_col = [
    column_name
    for column_name in cluster_exploded_TRIALS_AND_SPECTRAL_DF.columns
    if "granger" in column_name
    and "mean" in column_name
    and ("theta" in column_name or "gamma" in column_name)
]

In [None]:
granger_col

In [None]:
# Legacy bar plots: mean with SEM error bar per group for each Granger column.

# Create the output folder up front so plt.savefig cannot fail on a missing path.
os.makedirs("./proc/cluster_spectra_plots/mean_granger", exist_ok=True)

for col in granger_col:
    fig, ax = plt.subplots()

    # Mean and standard error of the mean of `col` for each GROUPING value.
    merged_df = (
        cluster_exploded_TRIALS_AND_SPECTRAL_DF
        .groupby(GROUPING)[col]
        .agg(["mean", "sem"])
        .reset_index()
        .sort_values([GROUPING])
    )
    merged_df["measurement"] = col

    for index, row in merged_df.iterrows():
        # One word of the group label per line, with "competitive" shortened.
        bar_label = "\n".join(row[GROUPING].split("_")).replace("competitive", "comp")
        ax.bar(bar_label, row["mean"], color=comp_id_to_color[row[GROUPING]])
        ax.errorbar(bar_label, row["mean"], yerr=row["sem"], capsize=20, ecolor="black")

    ax.tick_params(axis='x', rotation=90)
    brain_region = col.replace("cluster_mean_trial_and_post", "").replace("coherence", "").replace("__", "_").replace("_", " ").strip()
    ax.set_title(brain_region.replace("gamma", "").replace("theta", "").strip())
    ax.set_ylim(0, 0.6)
    fig.tight_layout()
    fig.savefig("./proc/cluster_spectra_plots/mean_granger/{} granger.png".format(brain_region))
    # Show, then close, so open figures do not pile up over the loop.
    plt.show()
    plt.close(fig)
    

In [None]:
raise ValueError()

# LDA

In [None]:
# A column is kept when its name contains "cluster_filtered", OR when it is a
# theta/gamma column for "trial_and_post" or "cluster" data that is neither a
# "phase" nor a "band" column.  The grouping is written out explicitly because
# `and` binds tighter than `or`.
power_columns = [
    column_name
    for column_name in TRIALS_AND_SPECTRAL_DF.columns
    if "cluster_filtered" in column_name
    or (
        ("gamma" in column_name or "theta" in column_name)
        and ("trial_and_post" in column_name or "cluster" in column_name)
        and "phase" not in column_name
        and "band" not in column_name
    )
]

In [None]:
power_columns

- Making a separate row for each cluster

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF = TRIALS_AND_SPECTRAL_DF[to_keep_columns + power_columns].explode(["cluster_filtered_power_timestamps"])

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["cluster_filtered_granger_timestamps"]

- Making independent columns for cluster name and indexes

In [None]:
# Each exploded entry is indexable: element 0 becomes the cluster id column and
# element 1 the cluster indexes column.
exploded_cluster_entries = cluster_exploded_TRIALS_AND_SPECTRAL_DF["cluster_filtered_power_timestamps"]
cluster_exploded_TRIALS_AND_SPECTRAL_DF["power_cluster_id"] = exploded_cluster_entries.map(lambda entry: entry[0])
cluster_exploded_TRIALS_AND_SPECTRAL_DF["power_cluster_indexes"] = exploded_cluster_entries.map(lambda entry: entry[1])

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF

In [None]:
# Theta/gamma columns for "trial_and_post" or "cluster" data, excluding "phase"
# and "band" columns.  (An earlier granger-only selection in this cell was
# overwritten immediately by this assignment and has been removed as dead code.)
power_columns = [
    col
    for col in TRIALS_AND_SPECTRAL_DF.columns
    if ("gamma" in col or "theta" in col)
    and ("trial_and_post" in col or "cluster" in col)
    and "phase" not in col
    and "band" not in col
]

In [None]:
power_columns

- Getting the timestamps of all the clusters

In [None]:
# For every selected column, keep only the entries addressed by the row's
# "power_cluster_indexes" and store them in a new "cluster_all_<column>" column.
for col in power_columns:
    print(col)
    cluster_all_column = "cluster_all_{}".format(col)
    cluster_exploded_TRIALS_AND_SPECTRAL_DF[cluster_all_column] = cluster_exploded_TRIALS_AND_SPECTRAL_DF.apply(
        lambda row: row[col][row["power_cluster_indexes"]],
        axis=1,
    )

In [None]:
cluster_all_columns = [col for col in cluster_exploded_TRIALS_AND_SPECTRAL_DF if "cluster_all" in col]

In [None]:
cluster_all_columns

- Aggregating all the values within a given trial

In [None]:
# Collapse every "cluster_all_*" column into a NaN-ignoring "cluster_mean_*" column.
for col in cluster_all_columns:
    updated_column = col.replace("cluster_all", "cluster_mean")

    # Theta/gamma columns are averaged over all elements (axis=None is
    # np.nanmean's default); every other column is averaged along axis 0 only.
    is_band_column = "gamma" in col or "theta" in col
    mean_axis = None if is_band_column else 0

    cluster_exploded_TRIALS_AND_SPECTRAL_DF[updated_column] = cluster_exploded_TRIALS_AND_SPECTRAL_DF.apply(
        lambda row: np.nanmean(row[col], axis=mean_axis),
        axis=1,
    )
    print(updated_column)
    print(cluster_exploded_TRIALS_AND_SPECTRAL_DF[updated_column].iloc[0])
        

In [None]:
cluster_mean_columns = [col for col in cluster_exploded_TRIALS_AND_SPECTRAL_DF if "cluster_mean" in col]

In [None]:
cluster_mean_columns = [col for col in cluster_mean_columns if "mPFC" in col and "power" not in col and "coherence" in col]

In [None]:
cluster_mean_columns

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF = cluster_exploded_TRIALS_AND_SPECTRAL_DF[~cluster_exploded_TRIALS_AND_SPECTRAL_DF["power_cluster_id"].str.contains("tie")]

In [None]:
# cluster_exploded_TRIALS_AND_SPECTRAL_DF = cluster_exploded_TRIALS_AND_SPECTRAL_DF[cluster_exploded_TRIALS_AND_SPECTRAL_DF["power_cluster_id"].isin(["lose_no_comp", "lose_competitive"])].reset_index(drop=True)

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["power_cluster_id"]

In [None]:
# GROUPING = "competition_closeness"


# Column used to label and group the rows.  Only one assignment is active;
# the alternatives are kept as comments instead of dead stores.
# GROUPING = "trial_label"
# GROUPING = "competition_closeness"
GROUPING = "power_cluster_id"





In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF["factorized_{}".format(GROUPING)], unique = pd.factorize(cluster_exploded_TRIALS_AND_SPECTRAL_DF[GROUPING])


In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF = cluster_exploded_TRIALS_AND_SPECTRAL_DF.dropna(subset=cluster_mean_columns)

In [None]:
network_array = cluster_exploded_TRIALS_AND_SPECTRAL_DF[cluster_mean_columns].values

In [None]:
cluster_exploded_TRIALS_AND_SPECTRAL_DF[GROUPING].unique()

# Reducing Dimensions

In [None]:
network_array

In [None]:
from sklearn.preprocessing import StandardScaler
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.decomposition import PCA
from umap import UMAP


In [None]:
scaler = StandardScaler()
reduced_dimension_network_array = scaler.fit_transform(network_array)





In [None]:
# lda = LinearDiscriminantAnalysis()
# lda = LinearDiscriminantAnalysis(n_components=2)
# reduced_dimension_network_array = lda.fit_transform(reduced_dimension_network_array, cluster_exploded_TRIALS_AND_SPECTRAL_DF[GROUPING])
# reduced_dimension_network_array = lda.fit_transform(reduced_dimension_network_array)

In [None]:
# lda = PCA(n_components=2)
# reduced_dimension_network_array = lda.fit_transform(reduced_dimension_network_array)

In [None]:
# Supervised 2-D UMAP embedding of the cluster-mean features, using the
# factorized GROUPING codes as the target.
# The features are standardized inside this cell: previously the scaled array
# produced by the StandardScaler cell was never used and UMAP was fitted on the
# raw `network_array`.  Scaling here also keeps the cell idempotent on re-run.
scaled_network_array = StandardScaler().fit_transform(network_array)

reduced_dimension_network_array = UMAP(
    n_components=2,
    random_state=42,  # fixed seed for a reproducible embedding
).fit_transform(
    scaled_network_array,
    y=cluster_exploded_TRIALS_AND_SPECTRAL_DF["factorized_{}".format(GROUPING)].values,
)


In [None]:
reduced_dimension_network_array.shape

In [None]:
# reduced_dimension_network_array = lda.fit_transform(reduced_dimension_network_array, cluster_exploded_TRIALS_AND_SPECTRAL_DF[GROUPING])


In [None]:
reduced_dimension_network_array.shape

In [None]:
GROUPING = "current_subject"
# GROUPING = "video_name"
# GROUPING = "power_cluster_id"


In [None]:
reduced_dimension_network_array

In [None]:
# Scatter plot of the 2-D embedding, one colour per GROUPING value.
plt.figure(figsize=(5,5))
for group in cluster_exploded_TRIALS_AND_SPECTRAL_DF[GROUPING].unique():
    # Select embedding rows with a positional boolean mask.  The DataFrame index
    # comes from explode() followed by row filtering, so its labels are neither
    # unique nor contiguous and must not be used as row positions.
    # Assumes the DataFrame rows still correspond one-to-one to the rows of
    # `network_array` that produced the embedding.
    group_mask = (cluster_exploded_TRIALS_AND_SPECTRAL_DF[GROUPING] == group).to_numpy()
    plt.scatter(
        reduced_dimension_network_array[group_mask, 0],
        reduced_dimension_network_array[group_mask, 1],
        label=group,
        alpha=0.8,
    )

# NOTE(review): title, axis labels and file names still say "LDA" although the
# active reduction cell uses UMAP -- confirm the intended wording.
plt.title("LDA of LFP features", fontsize=20)
plt.xlabel("LD1", fontsize=20)
plt.ylabel("LD2", fontsize=20)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
# plt.legend()
plt.tight_layout()

os.makedirs("./proc/network", exist_ok=True)

plt.savefig("./proc/network/rf_LDA_outcome_labeled.png")
plt.savefig("./proc/network/rf_LDA_outcome_labeled.eps")
plt.show()

In [None]:
# Histogram of the first embedding dimension, one colour per GROUPING value.
plt.figure(figsize=(5,5))
for group in cluster_exploded_TRIALS_AND_SPECTRAL_DF[GROUPING].unique():
    # Positional boolean mask: the DataFrame index comes from explode() followed
    # by row filtering, so its labels cannot be used as row positions.
    group_mask = (cluster_exploded_TRIALS_AND_SPECTRAL_DF[GROUPING] == group).to_numpy()
    # NOTE(review): comp_id_to_color must contain a key for every GROUPING value;
    # confirm this holds for the GROUPING that is active when this cell runs.
    plt.hist(
        reduced_dimension_network_array[group_mask, 0],
        color=comp_id_to_color[group],
        label=group,
        alpha=0.8,
    )

plt.title("LDA of LFP features", fontsize=20)
plt.xlabel("LD1", fontsize=20)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
# plt.legend()
plt.tight_layout()

os.makedirs("./proc/network", exist_ok=True)

# Distinct file names: this cell previously wrote to the same paths as the
# scatter plot cell and silently overwrote that figure.
plt.savefig("./proc/network/rf_LDA_outcome_labeled_hist.png")
plt.savefig("./proc/network/rf_LDA_outcome_labeled_hist.eps")
plt.show()

In [None]:
raise ValueError()

In [None]:
# Scatter plot of the 2-D embedding, one colour per subject.
plt.figure(figsize=(5,5))
for group in cluster_exploded_TRIALS_AND_SPECTRAL_DF["current_subject"].unique():
    # Positional boolean mask: the DataFrame index comes from explode() followed
    # by row filtering, so its labels cannot be used as row positions.
    subject_mask = (cluster_exploded_TRIALS_AND_SPECTRAL_DF["current_subject"] == group).to_numpy()
    plt.scatter(
        reduced_dimension_network_array[subject_mask, 0],
        reduced_dimension_network_array[subject_mask, 1],
        label=group,
    )

plt.title("LDA of LFP features", fontsize=20)
plt.xlabel("LD1", fontsize=20)
plt.ylabel("LD2", fontsize=20)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.legend()
plt.tight_layout()

# Make sure the output folder exists even if the earlier plotting cells were skipped.
os.makedirs("./proc/network", exist_ok=True)

plt.savefig("./proc/network/rf_LDA_subject_labeled.png")
plt.savefig("./proc/network/rf_LDA_subject_labeled.eps")
plt.show()

In [None]:
# Get the coefficients (loadings) of each feature
# NOTE(review): in the visible cells `lda` is only assigned inside commented-out
# code (the LDA/PCA alternatives to UMAP); confirm that a fitted model exposing
# `scalings_` exists before this cell is run.
loadings = lda.scalings_

# Calculate importance for each feature for each discriminant
# The importance is calculated as the square of each coefficient
importance = np.square(loadings)



In [None]:
lda_importance = pd.DataFrame(importance)
lda_importance["feature"] = feature_columns

In [None]:
important_features = []

In [None]:
lda_importance.sort_values(0, ascending=False)

In [None]:
important_features += lda_importance.sort_values(0, ascending=False).head(n=4)["feature"].to_list()

In [None]:
important_features += lda_importance.sort_values(1, ascending=False).head(n=4)["feature"].to_list()

In [None]:
filtered_lda_importance = lda_importance[lda_importance["feature"].isin(important_features)]

In [None]:
filtered_lda_importance = filtered_lda_importance.sort_values(0, ascending=False)

In [None]:
filtered_lda_importance["feature"] = filtered_lda_importance["feature"].apply(lambda x: x.replace("_band_", "\n").replace("_", " "))

In [None]:
# Paired bar chart of feature importance: column 0 of filtered_lda_importance on
# the left y axis and column 1 on a twin right y axis.
list1 = filtered_lda_importance[0]
list2 = filtered_lda_importance[1]
indices = range(len(list1))

# Creating the plot
fig, ax1 = plt.subplots(figsize=(6,6))

# Plotting the first list
color1 = '#15616F'
# FFAF00; teal #15616F
ax1.bar(indices, list1, width=0.4, align='center', color=color1)
ax1.set_ylabel('LD1 Scalings', color=color1, fontsize=20)
ax1.tick_params(axis='y', labelcolor=color1, rotation=90)

# Creating a second Y-axis for the second list; its bars are shifted by one
# bar width so the two series sit side by side.
ax2 = ax1.twinx()
color2 = 'tab:red'
ax2.bar([i + 0.4 for i in indices], list2, width=0.4, align='center', color=color2)
ax2.set_ylabel('LD2 Scalings', color=color2, fontsize=20)
ax2.tick_params(axis='y', labelcolor=color2, rotation=90)

# Adjusting the X-axis to show labels correctly
ax1.set_xticks(indices, filtered_lda_importance['feature'].values, rotation = 90, fontsize=15)

plt.title('LDA feature importance', fontsize=20)
plt.tight_layout()

# This cell can be reached without the earlier plotting cells having run, so
# make sure the output folder exists before saving.
os.makedirs("./proc/network", exist_ok=True)
plt.savefig("./proc/network/lda_feature_importance.png")
plt.savefig("./proc/network/lda_feature_importance.eps")

plt.show()

In [None]:
raise ValueError()

- Grouping all the rows with the same video and subject together

In [None]:
list(TRIALS_AND_SPECTRAL_DF.columns)

In [None]:
explode_columns

In [None]:
# Additional per-trial columns that are aggregated into lists together with
# `explode_columns` when grouping by video and subject.
other_explode_columns = [
    "tone_stop_frame",
    "condition",
    "competition_closeness",
    "notes",
    "10s_before_tone_frame",
    "10s_after_tone_frame",
    'cluster_index_ranges_dict',
    'cluster_times',
    'cluster_times_ranges_dict',
    'cluster_timestamps_ranges_dict',
    'trial_cluster_times_ranges_dict',
    'trial_cluster_timestamps_ranges_dict',
]

In [None]:
filter_columns

In [None]:
# Build the aggregation spec, then collapse the trial table to one row per
# (video_name, current_subject) pair.

# Per-trial columns are gathered into a Python list per group.
agg_dict = dict.fromkeys(
    (
        col
        for col in explode_columns + other_explode_columns
        if col not in groupby_columns and col != "tone_start_frame"
    ),
    list,
)

# The remaining columns (and "tone_start_frame") keep the group's first value.
agg_dict.update(
    dict.fromkeys(
        (
            col
            for col in filter_columns + ["tone_start_frame"]
            if col not in groupby_columns and col not in other_explode_columns
        ),
        'first',
    )
)

# NOTE(review): the group keys are hard-coded here while `groupby_columns` is
# used for the exclusions above -- confirm the two stay in sync.
video_TRIALS_AND_SPECTRAL_DF = (
    TRIALS_AND_SPECTRAL_DF
    .groupby(["video_name", "current_subject"])
    .agg(agg_dict)
    .reset_index()
)


In [None]:
video_TRIALS_AND_SPECTRAL_DF.columns

In [None]:
video_TRIALS_AND_SPECTRAL_DF.head()

- Combining all the dictionaries together

In [None]:
# Every column whose name contains "dict" holds a list of dicts after the
# groupby; pass each list through combine_dicts.
dict_columns = [name for name in video_TRIALS_AND_SPECTRAL_DF.columns if "dict" in name]
for col in dict_columns:
    video_TRIALS_AND_SPECTRAL_DF[col] = video_TRIALS_AND_SPECTRAL_DF[col].apply(combine_dicts)

In [None]:
video_TRIALS_AND_SPECTRAL_DF.head()

In [None]:
video_TRIALS_AND_SPECTRAL_DF.to_pickle("./proc/{}_cluster_ranges.pkl".format(OUTPUT_PREFIX))

# Filtering the SLEAP poses to the trial periods

In [None]:
raise ValueError()

In [None]:
TRIALS_AND_SPECTRAL_DF

In [None]:
# 1-based frame number for every entry of "video_timestamps"; np.arange builds
# the integer array directly instead of going through range -> list -> array.
TRIALS_AND_SPECTRAL_DF["video_frame"] = TRIALS_AND_SPECTRAL_DF["video_timestamps"].apply(
    lambda timestamps: np.arange(1, len(timestamps) + 1)
)

In [None]:
TRIALS_AND_SPECTRAL_DF["video_frame"].head().apply(lambda x: x.shape)

In [None]:
# Literal list of column names that is only displayed as the cell output; it is not assigned to a variable.
# NOTE(review): looks like a pasted reference of the SLEAP/behavior column names — confirm it is still current.
['subject_thorax_to_agent_thorax',
 'nose_to_reward_port_sum',
 'nose_to_reward_port_diff',
 'thorax_velocity_sum',
 'thorax_velocity_diff',
 'to_reward_port_angle_sum',
 'to_reward_port_angle_diff',
 'subject_nose_to_reward_port',
 'subject_thorax_velocity',
 'subject_to_reward_port_angle',
 'agent_nose_to_reward_port',
 'agent_thorax_velocity',
 'agent_to_reward_port_angle',
 'closebool_subject_nose_to_reward_port',
 'closebool_agent_nose_to_reward_port',
 'movingbool_subject_thorax_velocity',
 'movingbool_agent_thorax_velocity',
 'manual_cluster_id',
 'standard_embedding_x',
 'standard_embedding_y',
 'kmeans_cluster',
 'subject_locations',
 'agent_locations',
 'subject_thorax',
 'subject_nose',
 'subject_tail_base',
 'agent_thorax',
 'agent_nose',
 'agent_tail_base']

In [None]:
# sleap_columns = [col for col in TRIALS_AND_SPECTRAL_DF.columns if "locations" in col or "velocity" in col or "to_reward_port" in col or "video_frame" in col]

In [None]:
# Substrings that mark a column as SLEAP pose / behavior data to be restricted to the trial window.
# "tail_base" is listed next to the original "tailbase" because the column names shown in this
# notebook are spelled `subject_tail_base` / `agent_tail_base`, which "tailbase" never matches.
SLEAP_COLUMN_KEYWORDS = (
    "thorax",
    "nose",
    "reward_port",
    "standard_embedding",
    "cluster",
    "frame_index",
    "locations",
    "tailbase",
    "tail_base",
)

# Timestamp columns are excluded: they are the axis the data columns get filtered against.
sleap_columns = [
    col
    for col in TRIALS_AND_SPECTRAL_DF.columns
    if any(keyword in col for keyword in SLEAP_COLUMN_KEYWORDS) and "timestamp" not in col
]


In [None]:
# Display the selected SLEAP column names.
sleap_columns

In [None]:
# For every SLEAP column, print its name followed by the shape of the first row's value.
for col in sleap_columns:
    print(col)
    row_shapes = TRIALS_AND_SPECTRAL_DF[col].apply(lambda item: item.shape)
    print(row_shapes.iloc[0])

In [None]:
# Restrict every SLEAP column to the samples between "tone_start_timestamp" and
# "post_trial_end_timestamp", storing the result as "trial_and_post_<column>".
# NOTE(review): assumes filter_by_timestamp_range returns (timestamps, items) — confirm in utilities.helper.
# NOTE(review): within the visible notebook "cluster_timestamp" is only assigned in a later cell — confirm it exists here.
updated_timestamp_col = "trial_and_post_video_timestamps"
sorted_sleap_columns = sorted(sleap_columns)

for col in sorted_sleap_columns:
    updated_item_col = "trial_and_post_{}".format(col)
    print(updated_item_col)
    TRIALS_AND_SPECTRAL_DF[updated_item_col] = TRIALS_AND_SPECTRAL_DF.apply(
        lambda x: utilities.helper.filter_by_timestamp_range(
            start=x["tone_start_timestamp"],
            stop=x["post_trial_end_timestamp"],
            timestamps=x["cluster_timestamp"],
            items=x[col],
        )[1],
        axis=1,
    )

# The timestamp column is element [0] of the same call. Guarding on a non-empty column list
# (and naming the reference column explicitly) prevents this step from silently reusing a stale
# `col` / `updated_timestamp_col` left by another cell, or raising NameError on a fresh kernel.
if sorted_sleap_columns:
    timestamp_reference_col = sorted_sleap_columns[-1]
    TRIALS_AND_SPECTRAL_DF[updated_timestamp_col] = TRIALS_AND_SPECTRAL_DF.apply(
        lambda x: utilities.helper.filter_by_timestamp_range(
            start=x["tone_start_timestamp"],
            stop=x["post_trial_end_timestamp"],
            timestamps=x["cluster_timestamp"],
            items=x[timestamp_reference_col],
        )[0],
        axis=1,
    )

In [None]:
# Preview the frame after the "trial_and_post_" SLEAP columns were added.
TRIALS_AND_SPECTRAL_DF.head()

In [None]:
# Shape of the first row's "trial_and_post_frame_index" value.
TRIALS_AND_SPECTRAL_DF["trial_and_post_frame_index"].iloc[0].shape

In [None]:
# Remove the unfiltered SLEAP columns and "cluster_timestamp"; labels that are absent are ignored.
TRIALS_AND_SPECTRAL_DF = TRIALS_AND_SPECTRAL_DF.drop(
    columns=[*sleap_columns, "cluster_timestamp"],
    errors="ignore",
)

In [None]:
# Print every column name of the frame on its own line.
for col in list(TRIALS_AND_SPECTRAL_DF.columns):
    print(col)

- Filtering coherence

In [None]:
# Select the coherence data columns, leaving out the timestamp and "calculation" columns.
coherence_columns = [
    col
    for col in TRIALS_AND_SPECTRAL_DF.columns
    if "coherence" in col
    and not ("timestamps" in col or "calculation" in col)
]

In [None]:
# Display the selected coherence column names.
coherence_columns

In [None]:
# Restrict every coherence column to the samples between "tone_start_timestamp" and
# "post_trial_end_timestamp", storing the result as "trial_and_post_<column>".
# NOTE(review): assumes filter_by_timestamp_range returns (timestamps, items) — confirm in utilities.helper.
updated_timestamp_col = "trial_and_post_coherence_timestamps"
sorted_coherence_columns = sorted(coherence_columns)

for col in sorted_coherence_columns:
    updated_item_col = "trial_and_post_{}".format(col)
    print(updated_item_col)
    TRIALS_AND_SPECTRAL_DF[updated_item_col] = TRIALS_AND_SPECTRAL_DF.apply(
        lambda x: utilities.helper.filter_by_timestamp_range(
            start=x["tone_start_timestamp"],
            stop=x["post_trial_end_timestamp"],
            timestamps=x["coherence_timestamps"],
            items=x[col],
        )[1],
        axis=1,
    )

# The timestamp column is element [0] of the same call. Guarding on a non-empty column list
# (and naming the reference column explicitly) prevents this step from silently reusing a stale
# `col` / `updated_timestamp_col` left by another cell, or raising NameError on a fresh kernel.
if sorted_coherence_columns:
    timestamp_reference_col = sorted_coherence_columns[-1]
    TRIALS_AND_SPECTRAL_DF[updated_timestamp_col] = TRIALS_AND_SPECTRAL_DF.apply(
        lambda x: utilities.helper.filter_by_timestamp_range(
            start=x["tone_start_timestamp"],
            stop=x["post_trial_end_timestamp"],
            timestamps=x["coherence_timestamps"],
            items=x[timestamp_reference_col],
        )[0],
        axis=1,
    )

In [None]:
# Remove the unfiltered coherence columns and "coherence_timestamps"; absent labels are ignored.
TRIALS_AND_SPECTRAL_DF = TRIALS_AND_SPECTRAL_DF.drop(
    columns=[*coherence_columns, "coherence_timestamps"],
    errors="ignore",
)

- Filtering Grangers

In [None]:
# Select the granger data columns, leaving out the timestamp and "calculation" columns.
granger_columns = [
    col
    for col in TRIALS_AND_SPECTRAL_DF.columns
    if "granger" in col
    and not ("timestamps" in col or "calculation" in col)
]

In [None]:
# Display the selected granger column names.
granger_columns

In [None]:
# Restrict every granger column to the samples between "tone_start_timestamp" and
# "post_trial_end_timestamp", storing the result as "trial_and_post_<column>".
# NOTE(review): assumes filter_by_timestamp_range returns (timestamps, items) — confirm in utilities.helper.
updated_timestamp_col = "trial_and_post_granger_timestamps"
sorted_granger_columns = sorted(granger_columns)

for col in sorted_granger_columns:
    updated_item_col = "trial_and_post_{}".format(col)
    print(updated_item_col)
    TRIALS_AND_SPECTRAL_DF[updated_item_col] = TRIALS_AND_SPECTRAL_DF.apply(
        lambda x: utilities.helper.filter_by_timestamp_range(
            start=x["tone_start_timestamp"],
            stop=x["post_trial_end_timestamp"],
            timestamps=x["granger_timestamps"],
            items=x[col],
        )[1],
        axis=1,
    )

# The timestamp column is element [0] of the same call. Guarding on a non-empty column list
# (and naming the reference column explicitly) prevents this step from silently reusing a stale
# `col` / `updated_timestamp_col` left by another cell, or raising NameError on a fresh kernel.
if sorted_granger_columns:
    timestamp_reference_col = sorted_granger_columns[-1]
    TRIALS_AND_SPECTRAL_DF[updated_timestamp_col] = TRIALS_AND_SPECTRAL_DF.apply(
        lambda x: utilities.helper.filter_by_timestamp_range(
            start=x["tone_start_timestamp"],
            stop=x["post_trial_end_timestamp"],
            timestamps=x["granger_timestamps"],
            items=x[timestamp_reference_col],
        )[0],
        axis=1,
    )

In [None]:
# Remove the unfiltered granger columns and "granger_timestamps"; absent labels are ignored.
TRIALS_AND_SPECTRAL_DF = TRIALS_AND_SPECTRAL_DF.drop(
    columns=[*granger_columns, "granger_timestamps"],
    errors="ignore",
)

In [None]:
# Checkpoint the frame to OUTPUT_DIR/FULL_LFP_TRACES_PKL (the same path is written again by later cells).
full_lfp_traces_path = os.path.join(OUTPUT_DIR, FULL_LFP_TRACES_PKL)
TRIALS_AND_SPECTRAL_DF.to_pickle(full_lfp_traces_path)

- Filtering power

In [None]:
# Preview the frame before the power columns are filtered.
TRIALS_AND_SPECTRAL_DF.head()

In [None]:
# Select the power data columns, leaving out the timestamp and "calculation" columns.
power_columns = [
    col
    for col in TRIALS_AND_SPECTRAL_DF.columns
    if "power" in col
    and not ("timestamps" in col or "calculation" in col)
]

In [None]:
# Display the selected power column names.
power_columns

In [None]:
# Restrict every power column to the samples between "tone_start_timestamp" and
# "post_trial_end_timestamp", storing the result as "trial_and_post_<column>".
# NOTE(review): assumes filter_by_timestamp_range returns (timestamps, items) — confirm in utilities.helper.
updated_timestamp_col = "trial_and_post_power_timestamps"
sorted_power_columns = sorted(power_columns)

for col in sorted_power_columns:
    updated_item_col = "trial_and_post_{}".format(col)
    print(updated_item_col)
    TRIALS_AND_SPECTRAL_DF[updated_item_col] = TRIALS_AND_SPECTRAL_DF.apply(
        lambda x: utilities.helper.filter_by_timestamp_range(
            start=x["tone_start_timestamp"],
            stop=x["post_trial_end_timestamp"],
            timestamps=x["power_timestamps"],
            items=x[col],
        )[1],
        axis=1,
    )

# The timestamp column is element [0] of the same call. Guarding on a non-empty column list
# (and naming the reference column explicitly) prevents this step from silently reusing a stale
# `col` / `updated_timestamp_col` left by another cell, or raising NameError on a fresh kernel.
if sorted_power_columns:
    timestamp_reference_col = sorted_power_columns[-1]
    TRIALS_AND_SPECTRAL_DF[updated_timestamp_col] = TRIALS_AND_SPECTRAL_DF.apply(
        lambda x: utilities.helper.filter_by_timestamp_range(
            start=x["tone_start_timestamp"],
            stop=x["post_trial_end_timestamp"],
            timestamps=x["power_timestamps"],
            items=x[timestamp_reference_col],
        )[0],
        axis=1,
    )


In [None]:
# Remove the unfiltered power columns and "power_timestamps"; absent labels are ignored.
TRIALS_AND_SPECTRAL_DF = TRIALS_AND_SPECTRAL_DF.drop(
    columns=[*power_columns, "power_timestamps"],
    errors="ignore",
)

In [None]:
# Print every column name of the frame on its own line.
for col in TRIALS_AND_SPECTRAL_DF.columns:
    print(col)

# Filtering out phase

In [None]:
# Select the columns whose name contains "trace", leaving out timestamp and "calculation" columns.
lfp_columns = [
    col
    for col in TRIALS_AND_SPECTRAL_DF.columns
    if "trace" in col
    and not ("timestamps" in col or "calculation" in col)
]

In [None]:
# Select the columns whose name contains "band", leaving out timestamp and "calculation" columns.
# NOTE(review): this also matches any already-filtered "trial_and_post_" column containing "band" — confirm none exist.
band_columns = [
    col
    for col in TRIALS_AND_SPECTRAL_DF.columns
    if "band" in col
    and not ("timestamps" in col or "calculation" in col)
]

In [None]:
# Select the columns whose name contains "phase", leaving out timestamp and "calculation" columns.
phase_columns = [
    col
    for col in TRIALS_AND_SPECTRAL_DF.columns
    if "phase" in col
    and not ("timestamps" in col or "calculation" in col)
]

In [None]:
# Combine the phase, band and trace columns so they can be filtered together.
# dict.fromkeys drops duplicates while keeping first-seen order, so a column matched by more
# than one pattern is handled once and re-running this cell does not keep growing the list.
phase_columns = list(dict.fromkeys(phase_columns + band_columns + lfp_columns))

In [None]:
# Display the combined phase / band / trace column names.
phase_columns

In [None]:
# Restrict every phase / band / trace column to the samples between "tone_start_timestamp" and
# "post_trial_end_timestamp", storing the result as "trial_and_post_<column>".
# NOTE(review): assumes filter_by_timestamp_range returns (timestamps, items) — confirm in utilities.helper.
updated_timestamp_col = "trial_and_post_lfp_timestamps"
sorted_phase_columns = sorted(phase_columns)

for col in sorted_phase_columns:
    updated_item_col = "trial_and_post_{}".format(col)
    print(updated_item_col)
    TRIALS_AND_SPECTRAL_DF[updated_item_col] = TRIALS_AND_SPECTRAL_DF.apply(
        lambda x: utilities.helper.filter_by_timestamp_range(
            start=x["tone_start_timestamp"],
            stop=x["post_trial_end_timestamp"],
            timestamps=x["lfp_timestamps"],
            items=x[col],
        )[1],
        axis=1,
    )

# The timestamp column is element [0] of the same call. Guarding on a non-empty column list
# (and naming the reference column explicitly) prevents this step from silently reusing a stale
# `col` / `updated_timestamp_col` left by another cell, or raising NameError on a fresh kernel.
if sorted_phase_columns:
    timestamp_reference_col = sorted_phase_columns[-1]
    TRIALS_AND_SPECTRAL_DF[updated_timestamp_col] = TRIALS_AND_SPECTRAL_DF.apply(
        lambda x: utilities.helper.filter_by_timestamp_range(
            start=x["tone_start_timestamp"],
            stop=x["post_trial_end_timestamp"],
            timestamps=x["lfp_timestamps"],
            items=x[timestamp_reference_col],
        )[0],
        axis=1,
    )


In [None]:
# Display the filtered LFP timestamps column.
TRIALS_AND_SPECTRAL_DF["trial_and_post_lfp_timestamps"]

In [None]:
# Remove the unfiltered phase / band / trace columns and "lfp_timestamps"; absent labels are ignored.
TRIALS_AND_SPECTRAL_DF = TRIALS_AND_SPECTRAL_DF.drop(
    columns=[*phase_columns, "lfp_timestamps"],
    errors="ignore",
)

In [None]:
# Checkpoint the frame to OUTPUT_DIR/FULL_LFP_TRACES_PKL, overwriting any earlier checkpoint at that path.
full_lfp_traces_path = os.path.join(OUTPUT_DIR, FULL_LFP_TRACES_PKL)
TRIALS_AND_SPECTRAL_DF.to_pickle(full_lfp_traces_path)

In [None]:
# Print every column name of the frame on its own line.
for col in TRIALS_AND_SPECTRAL_DF.columns:
    print(col)

In [None]:
# Shape of the first row's "trial_and_post_kmeans_cluster" value.
TRIALS_AND_SPECTRAL_DF["trial_and_post_kmeans_cluster"].iloc[0].shape

In [None]:
# NOTE(review): message-less raise; presumably a deliberate stop so "Run All" halts here — confirm before removing.
raise ValueError()

# Filtering for rows that are in the video

In [None]:
# Flag the rows whose first "trial_video_frame" entry lies within [start_frame, stop_frame].
TRIALS_AND_SPECTRAL_DF["in_video"] = TRIALS_AND_SPECTRAL_DF.apply(
    lambda row: (row["start_frame"] <= row["trial_video_frame"][0])
    and (row["trial_video_frame"][0] <= row["stop_frame"]),
    axis=1,
)

In [None]:
# Keep only the rows flagged "in_video"; reset_index() stores the previous index as a column.
TRIALS_AND_SPECTRAL_DF = (
    TRIALS_AND_SPECTRAL_DF
    .loc[TRIALS_AND_SPECTRAL_DF["in_video"]]
    .reset_index()
)

In [None]:
# Display the frame after the "in_video" filter.
TRIALS_AND_SPECTRAL_DF

In [None]:
# Checkpoint the frame to OUTPUT_DIR/FULL_LFP_TRACES_PKL, overwriting any earlier checkpoint at that path.
full_lfp_traces_path = os.path.join(OUTPUT_DIR, FULL_LFP_TRACES_PKL)
TRIALS_AND_SPECTRAL_DF.to_pickle(full_lfp_traces_path)

In [None]:
# NOTE(review): message-less raise; presumably a deliberate stop so "Run All" halts here — confirm before removing.
raise ValueError()

# Filtering out spikes

In [None]:
# NOTE(review): message-less raise; presumably a deliberate stop so "Run All" halts here — confirm before removing.
raise ValueError()

In [None]:
# Keep the spike times between "tone_start_timestamp" and "baseline_stop_timestamp", cast to int.
# NOTE(review): the baseline window starts at tone onset here — confirm that is the intended start.
TRIALS_AND_SPECTRAL_DF["baseline_spike_times"] = TRIALS_AND_SPECTRAL_DF.apply(
    lambda row: filter_spike_times(
        row["spike_times"],
        start=row["tone_start_timestamp"],
        stop=row["baseline_stop_timestamp"],
    ).astype(int),
    axis=1,
)


In [None]:
# Keep the spike times between "tone_start_timestamp" and "tone_stop_timestamp", cast to int.
TRIALS_AND_SPECTRAL_DF["trial_spike_times"] = TRIALS_AND_SPECTRAL_DF.apply(
    lambda row: filter_spike_times(
        row["spike_times"],
        start=row["tone_start_timestamp"],
        stop=row["tone_stop_timestamp"],
    ).astype(int),
    axis=1,
)


In [None]:
# Preview the tone start timestamps used as the window start.
TRIALS_AND_SPECTRAL_DF["tone_start_timestamp"].head()

In [None]:
# Preview the baseline stop timestamps used as the baseline window end.
TRIALS_AND_SPECTRAL_DF["baseline_stop_timestamp"].head()

In [None]:
# Inspect the first row's baseline spike times.
TRIALS_AND_SPECTRAL_DF["baseline_spike_times"].iloc[0]

In [None]:
# Restrict the averaged firing rates to the window from "tone_start_timestamp" to
# "baseline_stop_timestamp". The rate matrix is transposed before filtering and the result is
# transposed back, matching how "trial_neuron_average_fr" is stored; previously the baseline
# result was left in the transposed orientation.
# NOTE(review): the baseline window starts at tone onset here — confirm that is the intended start.
TRIALS_AND_SPECTRAL_DF["baseline_neuron_average_fr"] = TRIALS_AND_SPECTRAL_DF.apply(
    lambda x: utilities.helper.filter_by_timestamp_range(
        start=x["tone_start_timestamp"],
        stop=x["baseline_stop_timestamp"],
        timestamps=x["neuron_average_timestamps"],
        items=x["neuron_average_fr"].T,
    )[1].T,
    axis=1,
)
# Element [0] of the same call holds the timestamps that survived the filter.
TRIALS_AND_SPECTRAL_DF["baseline_neuron_average_timestamp"] = TRIALS_AND_SPECTRAL_DF.apply(
    lambda x: utilities.helper.filter_by_timestamp_range(
        start=x["tone_start_timestamp"],
        stop=x["baseline_stop_timestamp"],
        timestamps=x["neuron_average_timestamps"],
        items=x["neuron_average_fr"].T,
    )[0],
    axis=1,
)

In [None]:
# Restrict the averaged firing rates to the window from "tone_start_timestamp" to
# "tone_stop_timestamp"; the matrix is transposed for filtering and transposed back afterwards.
TRIALS_AND_SPECTRAL_DF["trial_neuron_average_fr"] = TRIALS_AND_SPECTRAL_DF.apply(
    lambda row: utilities.helper.filter_by_timestamp_range(
        start=row["tone_start_timestamp"],
        stop=row["tone_stop_timestamp"],
        timestamps=row["neuron_average_timestamps"],
        items=row["neuron_average_fr"].T,
    )[1].T,
    axis=1,
)
# Element [0] of the same call is stored as the matching timestamps.
TRIALS_AND_SPECTRAL_DF["trial_neuron_average_timestamp"] = TRIALS_AND_SPECTRAL_DF.apply(
    lambda row: utilities.helper.filter_by_timestamp_range(
        start=row["tone_start_timestamp"],
        stop=row["tone_stop_timestamp"],
        timestamps=row["neuron_average_timestamps"],
        items=row["neuron_average_fr"].T,
    )[0],
    axis=1,
)

In [None]:
# Remove the unfiltered spike and firing-rate columns; labels that are absent are ignored.
spike_source_columns = [
    "spike_clusters",
    "spike_times",
    "neuron_average_fr",
    "neuron_average_timestamps",
]
TRIALS_AND_SPECTRAL_DF = TRIALS_AND_SPECTRAL_DF.drop(columns=spike_source_columns, errors="ignore")

In [None]:
# Shape of the first row's "trial_neuron_average_fr" value.
TRIALS_AND_SPECTRAL_DF["trial_neuron_average_fr"].iloc[0].shape

In [None]:
# Shape of the first row's "trial_neuron_average_timestamp" value.
TRIALS_AND_SPECTRAL_DF["trial_neuron_average_timestamp"].iloc[0].shape

# OLD Stuff

## Getting the ranges of each cluster

- Getting the index range

In [None]:
# Show all column names as a plain list.
list(TRIALS_AND_SPECTRAL_DF.columns)

In [None]:
# Inspect the first row's "trial_and_post_kmeans_cluster" value.
TRIALS_AND_SPECTRAL_DF["trial_and_post_kmeans_cluster"].iloc[0]

In [None]:
# Look up every k-means cluster label (converted to str) in cluster_to_comp_id, element by element.
TRIALS_AND_SPECTRAL_DF["trial_and_post_comp_id"] = TRIALS_AND_SPECTRAL_DF["trial_and_post_kmeans_cluster"].apply(
    lambda cluster_labels: np.vectorize(cluster_to_comp_id.get)(cluster_labels.astype(str))
)

In [None]:
# Look up every k-means cluster label (converted to str) in cluster_to_competitiveness, element by element.
TRIALS_AND_SPECTRAL_DF["trial_and_post_competitiveness"] = TRIALS_AND_SPECTRAL_DF["trial_and_post_kmeans_cluster"].apply(
    lambda cluster_labels: np.vectorize(cluster_to_competitiveness.get)(cluster_labels.astype(str))
)

In [None]:
# Display the per-row competitiveness labels.
TRIALS_AND_SPECTRAL_DF["trial_and_post_competitiveness"] 

In [None]:
# Shape of the first row's "trial_and_post_competitiveness" value.
TRIALS_AND_SPECTRAL_DF["trial_and_post_competitiveness"].iloc[0].shape

In [None]:
# TRIALS_AND_SPECTRAL_DF["cluster_index_ranges_dict"] = TRIALS_AND_SPECTRAL_DF["trial_and_post_comp_id"].apply(lambda x: find_consecutive_ranges(x, min_length=20))

# Run find_consecutive_ranges (min_length=20) on the first 200 competitiveness labels of each row.
TRIALS_AND_SPECTRAL_DF["cluster_index_ranges_dict"] = TRIALS_AND_SPECTRAL_DF["trial_and_post_competitiveness"].apply(
    lambda labels: find_consecutive_ranges(labels[:200], min_length=20)
)


In [None]:
# Inspect the first row's index ranges.
TRIALS_AND_SPECTRAL_DF["cluster_index_ranges_dict"].iloc[0]

In [None]:
# Preview the frame after "cluster_index_ranges_dict" was added.
TRIALS_AND_SPECTRAL_DF.head()

In [None]:
# Index each row's "video_timestamps" with its "trial_and_post_frame_index" to get one timestamp per cluster frame.
TRIALS_AND_SPECTRAL_DF["cluster_timestamp"] = TRIALS_AND_SPECTRAL_DF.apply(
    lambda row: row["video_timestamps"][row["trial_and_post_frame_index"]],
    axis=1,
)


- Calculating the times in milliseconds of each cluster frame

In [None]:
# Convert each cluster timestamp to whole milliseconds elapsed since "first_timestamp".
# The divisor is derived from EPHYS_SAMPLE_RATE (20000 in this notebook, i.e. 20 per millisecond)
# instead of a hard-coded 20, so it follows the configured sample rate.
# NOTE(review): presumes timestamps are counted in ephys samples — confirm.
samples_per_millisecond = EPHYS_SAMPLE_RATE // 1000
TRIALS_AND_SPECTRAL_DF["cluster_times"] = TRIALS_AND_SPECTRAL_DF.apply(
    lambda x: (np.array(x["cluster_timestamp"]) - x["first_timestamp"]) // samples_per_millisecond,
    axis=1,
)


- Converting the cluster index ranges into time (ms) ranges and timestamp ranges, using the per-frame cluster times and timestamps

In [None]:
# Pass each row's index ranges and its "cluster_times" to update_tuples_in_dict.
TRIALS_AND_SPECTRAL_DF["cluster_times_ranges_dict"] = TRIALS_AND_SPECTRAL_DF.apply(
    lambda row: update_tuples_in_dict(
        row["cluster_index_ranges_dict"],
        row["cluster_times"],
    ),
    axis=1,
)

In [None]:
# Pass each row's index ranges and its "cluster_timestamp" values to update_tuples_in_dict.
TRIALS_AND_SPECTRAL_DF["cluster_timestamps_ranges_dict"] = TRIALS_AND_SPECTRAL_DF.apply(
    lambda row: update_tuples_in_dict(
        row["cluster_index_ranges_dict"],
        row["cluster_timestamp"],
    ),
    axis=1,
)

- Combining the win and loss label with the cluster

In [None]:
# Prefix every key of the time-range dictionary with the row's trial label ("<label>_<key>").
TRIALS_AND_SPECTRAL_DF["trial_cluster_times_ranges_dict"] = TRIALS_AND_SPECTRAL_DF.apply(
    lambda row: dict(
        ("{}_{}".format(row["trial_label"], key), ranges)
        for key, ranges in row["cluster_times_ranges_dict"].items()
    ),
    axis=1,
)


In [None]:
# Prefix every key of the timestamp-range dictionary with the row's trial label ("<label>_<key>").
TRIALS_AND_SPECTRAL_DF["trial_cluster_timestamps_ranges_dict"] = TRIALS_AND_SPECTRAL_DF.apply(
    lambda row: dict(
        ("{}_{}".format(row["trial_label"], key), ranges)
        for key, ranges in row["cluster_timestamps_ranges_dict"].items()
    ),
    axis=1,
)


In [None]:
# Inspect the first row's label-prefixed timestamp ranges.
TRIALS_AND_SPECTRAL_DF["trial_cluster_timestamps_ranges_dict"].iloc[0]

In [None]:
# List the columns of the frame.
TRIALS_AND_SPECTRAL_DF.columns

In [None]:
# Inspect the first row's filtered coherence timestamps.
TRIALS_AND_SPECTRAL_DF["trial_and_post_coherence_timestamps"].iloc[0]