In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../')

import numpy as np
from processing.timepoint_analysis import get_signal_around_timepoint
from analysis.response_metrics import calculate_signal_response_metrics
from data.mouse import create_mice_dict
from analysis.performance_funcs import add_performance_container

from main import load_and_prepare_sessions
from config import *


# Load the sessions stored under ../../Baseline, asking the loader to use its
# pickle path and to drop sessions it flags as having a bad signal.
sessions = load_and_prepare_sessions(
    "../../Baseline",
    load_from_pickle=True,
    remove_bad_signal_sessions=True,
)



In [2]:
# Read both interval bounds out of peak_interval_config in one step.
interval_start, interval_end = (
    peak_interval_config[bound_name]
    for bound_name in ("interval_start", "interval_end")
)

In [3]:
def find_start_end_idxs(event_type):
    """Translate an event type's time interval into integer indices.

    The ``(start_time, end_time)`` pair stored for ``event_type`` in
    ``attr_interval_dict`` is scaled by ``PLOTTING_CONFIG['fps']``, shifted by
    the notebook-level ``interval_start`` and truncated with ``int``.

    Args:
        event_type: Key used to look up the interval in ``attr_interval_dict``.

    Returns:
        A ``(start_event_idx, end_event_idx)`` tuple of ints.
    """
    frames_per_second = PLOTTING_CONFIG['fps']

    def to_index(time_value):
        # NOTE(review): presumably time_value is in seconds and interval_start
        # is an offset in samples -- confirm against the config module.
        return int(time_value * frames_per_second + interval_start)

    interval_begin, interval_finish = attr_interval_dict[event_type]
    return to_index(interval_begin), to_index(interval_finish)

In [4]:
# Build the per-mouse lookup and attach a performance container to each entry.
mice_dict = create_mice_dict(sessions)
for _mouse_key, mouse in mice_dict.items():
    add_performance_container(mouse)

In [5]:
# Single-session example: one event type and one brain region of the first session.
curr_event_type, curr_brain_region = 'hit', 'DMS_right'

xs, all_ys = get_signal_around_timepoint(
    sessions[0],
    curr_event_type,
    curr_brain_region,
)

In [6]:
# Per-trace response metrics for the example session selected above.

# Number of samples taken on each side of the interval start to estimate the baseline.
baseline_half_width = 7

# The interval indices and the hemisphere-free region name are identical for
# every trace, so they are computed once rather than on each iteration.
start_event_idx, end_event_idx = find_start_end_idxs(curr_event_type)
curr_brain_region_uni = curr_brain_region.split('_')[0]
baseline_window = slice(start_event_idx - baseline_half_width,
                        start_event_idx + baseline_half_width)

for ys in all_ys:
    # Subtract into a new array instead of using `-=`: the in-place form
    # modifies the arrays handed back by get_signal_around_timepoint, which
    # may alias data held by the session (not verifiable from this notebook)
    # and would fail on integer-typed input.
    ys = ys - np.mean(ys[baseline_window])

    curr_response_metrics = calculate_signal_response_metrics(ys, (start_event_idx, end_event_idx))
    # Prefix every metric name with "<region>_<event type>_".
    curr_response_metrics = {f'{curr_brain_region_uni}_{curr_event_type}_{k}': v for k, v in curr_response_metrics.items()}
    print(curr_response_metrics)
    print(mice_dict[sessions[0].mouse_id].metric_container.data)

{'DMS_hit_slope_up': 0.19989641053773574, 'DMS_hit_slope_down': 0.04620011722591361, 'DMS_hit_maximal_value': 5.697205908802397, 'DMS_hit_peak_timing': 29, 'DMS_hit_auc': 117.9288359789883}
{'d_prime': 0.7062379183870319, 'c_score': 1.3446454338708465, 'participation': 57.0, 'total_hits': 36.0, 'total_mistakes': 21.0, 'hit_rate': 0.16071428571428573, 'false_alarm_rate': 0.04477611940298507}
{'DMS_hit_slope_up': 0.1137453028567767, 'DMS_hit_slope_down': 0.06899061506606964, 'DMS_hit_maximal_value': 5.69835292026002, 'DMS_hit_peak_timing': 52, 'DMS_hit_auc': 89.93219799013724}
{'d_prime': 0.7062379183870319, 'c_score': 1.3446454338708465, 'participation': 57.0, 'total_hits': 36.0, 'total_mistakes': 21.0, 'hit_rate': 0.16071428571428573, 'false_alarm_rate': 0.04477611940298507}
{'DMS_hit_slope_up': 0.02777208417576901, 'DMS_hit_slope_down': 0.033487139377628275, 'DMS_hit_maximal_value': 1.8183954970985385, 'DMS_hit_peak_timing': 69, 'DMS_hit_auc': 8.792607495106282}
{'d_prime': 0.70623791

In [7]:
from collections import defaultdict
from tqdm.notebook import tqdm

In [18]:
# Collect (behavioural metric, response metric) coordinate pairs over all
# sessions, event types and brain regions.
# Keys have the form "<region>_<event type>_<response metric>_<behavioural metric>".
all_coords = defaultdict(list)

# Number of samples taken on each side of the interval start to estimate the baseline.
baseline_half_width = 7

for session in tqdm(sessions):
    # The behavioural metrics belong to the mouse, not to an individual trace,
    # so one lookup per session is enough.
    # NOTE(review): every trace of a mouse is therefore paired with the same
    # metric value; the resulting points are not independent observations.
    mouse_metrics = mice_dict[session.mouse_id].metric_container.data

    for curr_event_type in attr_interval_dict.keys():
        # Both the "no timepoints" check and the interval indices depend only
        # on the event type, so they are evaluated before the region loop.
        if len(session.timepoints_container.data.get(curr_event_type, [])) == 0:
            continue

        start_event_idx, end_event_idx = find_start_end_idxs(curr_event_type)
        baseline_window = slice(start_event_idx - baseline_half_width,
                                start_event_idx + baseline_half_width)

        for curr_brain_region in ['DMS', 'DLS', 'VS']:
            # Use the left hemisphere when it was recorded, otherwise the right one.
            if f'{curr_brain_region}_left' in session.brain_regions:
                region_key = f'{curr_brain_region}_left'
            elif f'{curr_brain_region}_right' in session.brain_regions:
                region_key = f'{curr_brain_region}_right'
            else:
                continue

            xs, all_ys = get_signal_around_timepoint(session, curr_event_type, region_key)

            for ys in all_ys:
                # Out-of-place subtraction: `ys -= ...` would modify the arrays
                # returned by get_signal_around_timepoint, which may alias data
                # held by the session (not verifiable from this notebook).
                ys = ys - np.mean(ys[baseline_window])

                curr_response_metrics = calculate_signal_response_metrics(ys, (start_event_idx, end_event_idx))
                # The region names iterated here carry no hemisphere suffix,
                # so they can be used as the key prefix directly.
                curr_response_metrics = {f'{curr_brain_region}_{curr_event_type}_{k}': v for k, v in curr_response_metrics.items()}

                # Metric-outer / response-inner order is kept so the insertion
                # order of the keys in all_coords stays the same.
                for metric, metric_value in mouse_metrics.items():
                    for response, response_value in curr_response_metrics.items():
                        all_coords[response + "_" + metric].append((metric_value, response_value))

                                                                 

  0%|          | 0/33 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [9]:
# Show which "<region>_<event type>_<response metric>_<behavioural metric>"
# pairings were collected in all_coords.
all_coords.keys()

dict_keys(['DMS_hit_slope_up_d_prime', 'DMS_hit_slope_down_d_prime', 'DMS_hit_maximal_value_d_prime', 'DMS_hit_peak_timing_d_prime', 'DMS_hit_auc_d_prime', 'DMS_hit_slope_up_c_score', 'DMS_hit_slope_down_c_score', 'DMS_hit_maximal_value_c_score', 'DMS_hit_peak_timing_c_score', 'DMS_hit_auc_c_score', 'DMS_hit_slope_up_participation', 'DMS_hit_slope_down_participation', 'DMS_hit_maximal_value_participation', 'DMS_hit_peak_timing_participation', 'DMS_hit_auc_participation', 'DMS_hit_slope_up_total_hits', 'DMS_hit_slope_down_total_hits', 'DMS_hit_maximal_value_total_hits', 'DMS_hit_peak_timing_total_hits', 'DMS_hit_auc_total_hits', 'DMS_hit_slope_up_total_mistakes', 'DMS_hit_slope_down_total_mistakes', 'DMS_hit_maximal_value_total_mistakes', 'DMS_hit_peak_timing_total_mistakes', 'DMS_hit_auc_total_mistakes', 'DMS_hit_slope_up_hit_rate', 'DMS_hit_slope_down_hit_rate', 'DMS_hit_maximal_value_hit_rate', 'DMS_hit_peak_timing_hit_rate', 'DMS_hit_auc_hit_rate', 'DMS_hit_slope_up_false_alarm_rate

In [10]:
import pandas as pd
from scipy.stats import pearsonr
import numpy as np

# Pearson correlation between the behavioural metric and the response metric
# for every pairing collected in all_coords, sorted by p-value.
# NOTE(review): many pairings are tested and no multiple-comparison correction
# is applied, so the smallest p-values should be read with that in mind.
result_columns = ["Metric Pair", "Pearson Correlation", "P-Value", "Number of Coordinates"]

# One dict per pairing; the DataFrame is built once after the loop.
data = []

for metric_pair, values in all_coords.items():
    # Each entry of `values` is a (metric value, response value) tuple.
    # Infinite values are turned into NaN so that a single dropna() removes
    # every unusable row; chaining avoids in-place mutation.
    df_values = (
        pd.DataFrame(values, columns=['Metric Values', 'Response Values'])
        .replace([np.inf, -np.inf], np.nan)
        .dropna()
    )

    # pearsonr needs at least two points.
    if df_values.shape[0] > 1:
        correlation, p_value = pearsonr(df_values['Metric Values'], df_values['Response Values'])

        data.append({
            "Metric Pair": metric_pair,
            "Pearson Correlation": correlation,
            "P-Value": p_value,
            "Number of Coordinates": len(df_values)
        })

# Passing the columns explicitly keeps the frame well-formed when no pairing
# qualified (e.g. all_coords is empty because the collection cell was
# interrupted); without it set_index would raise KeyError on "Metric Pair".
df = pd.DataFrame(data, columns=result_columns).set_index("Metric Pair")

# Most significant pairings first.
df_sorted = df.sort_values(by="P-Value", ascending=True)

# Display the sorted DataFrame
df_sorted


Unnamed: 0_level_0,Pearson Correlation,P-Value,Number of Coordinates
Metric Pair,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
DLS_cor_reject_maximal_value_false_alarm_rate,-0.220260,4.994553e-43,3805
DLS_cor_reject_maximal_value_total_mistakes,-0.217623,5.072863e-42,3805
VS_before_dispimg_hit_maximal_value_total_mistakes,0.351741,2.631637e-24,786
VS_before_dispimg_hit_maximal_value_false_alarm_rate,0.326479,5.604246e-21,786
DLS_before_dispimg_mistake_maximal_value_false_alarm_rate,-0.492151,1.692389e-20,313
...,...,...,...
VS_mistake_slope_down_participation,0.000903,9.839532e-01,499
DMS_reward_collect_slope_down_false_alarm_rate,-0.000703,9.839749e-01,818
DLS_before_dispimg_mistake_peak_timing_false_alarm_rate,-0.000841,9.881821e-01,313
DLS_before_dispimg_mistake_peak_timing_c_score,0.000448,9.936997e-01,313


In [11]:
import matplotlib.pyplot as plt

In [None]:
# One scatter plot per metric pairing collected in all_coords.
for pair_name, coords in all_coords.items():
    # all_coords is a defaultdict, so a plain lookup elsewhere can leave an
    # empty list behind; unpacking it would call scatter() without arguments.
    if not coords:
        continue

    # Tuples were appended as (metric value, response value).
    metric_values, response_values = zip(*coords)

    fig, ax = plt.subplots()
    ax.scatter(metric_values, response_values)
    ax.set_title(pair_name)
    ax.set_xlabel("Metric value")
    ax.set_ylabel("Response value")
    plt.show()
    # Release the figure explicitly; this loop can create a large number of them.
    plt.close(fig)

In [29]:
# Remove the row limit so the whole table is rendered (this changes the
# option for the rest of the kernel session).
pd.options.display.max_rows = None

# Render the full, p-value-sorted correlation table.
display(df_sorted)


Unnamed: 0_level_0,Pearson Correlation,P-Value,Number of Coordinates
Metric Pair,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
DLS_cor_reject_maximal_value_false_alarm_rate,-0.22026,4.994553e-43,3805
DLS_cor_reject_maximal_value_total_mistakes,-0.217623,5.072863e-42,3805
VS_hit_maximal_value_total_mistakes,0.320514,3.0912679999999996e-20,786
DLS_mistake_maximal_value_false_alarm_rate,-0.476458,3.8318089999999996e-19,313
DLS_mistake_maximal_value_total_mistakes,-0.474262,5.854673999999999e-19,313
VS_reward_collect_maximal_value_total_mistakes,0.303916,3.239629e-18,784
VS_hit_maximal_value_false_alarm_rate,0.293988,3.904935e-17,786
VS_reward_collect_peak_timing_participation,-0.2932,5.231153e-17,784
DLS_cor_reject_maximal_value_d_prime,0.130641,5.936047e-16,3805
VS_reward_collect_maximal_value_false_alarm_rate,0.276574,3.118954e-15,784
