In [95]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../')

from main import load_and_prepare_sessions
from analysis.performance_funcs import add_performance_container
from analysis.response_metrics import calculate_signal_response_metrics
from processing.timepoint_analysis import aggregate_signals
from data.mouse import create_mice_dict
from data.data_loading import DataContainer
from tqdm.notebook import tqdm

from collections import defaultdict
import matplotlib.pyplot as plt

# Load the baseline recording sessions from ../../Baseline.
# NOTE(review): load_from_pickle=True presumably restores a pickled cache; unpickling can
# execute arbitrary code, so only point this at trusted data. remove_bad_signal_sessions=True
# presumably filters out low-quality recordings — confirm in main.load_and_prepare_sessions.
sessions = load_and_prepare_sessions("../../Baseline", load_from_pickle=True, remove_bad_signal_sessions=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [96]:
# Group the loaded sessions per mouse; later cells iterate it as {mouse_id: mouse}.
mice_dict = create_mice_dict(sessions)

In [97]:
# Attach performance metrics to every mouse, then point each of its sessions at
# that same container object (a shared reference, not a copy).
# add_performance_container presumably populates mouse.metric_container, which is
# read immediately afterwards — confirm in analysis.performance_funcs.
for mouse in mice_dict.values():
    add_performance_container(mouse)
    for mouse_session in mouse.sessions:
        mouse_session.metric_container = mouse.metric_container

In [98]:
# Collect (session index, metric value) pairs per metric name. Sessions of the same
# mouse share one metric container, so they carry identical values.
session_metric_order = defaultdict(list)

for idx, session in enumerate(sessions):
    for metric, val in session.metric_container.data.items():
        session_metric_order[metric].append((idx, val))


def _nan_last_key(pair):
    """Sort key for (index, value) pairs: ascending value, NaN values placed last.

    NaN compares False against everything, so sorting on raw values that contain
    NaN yields an arbitrary, partially unsorted order. Pushing NaNs behind all
    other values keeps the remaining values correctly ordered.
    """
    value = pair[-1]
    return (value != value, value)  # value != value is True only for NaN


# For each metric: tuple of indices into `sessions`, ordered by ascending metric value.
metric_session_order = {
    metric: tuple(idx for idx, _ in sorted(pairs, key=_nan_last_key))
    for metric, pairs in session_metric_order.items()
}

In [99]:
# Filled below as {mouse_id: {"<REGION>_<event>_<metric>": value, ...}}.
mouse_responses = dict()

In [100]:
def is_relevant_session(brain_regions, event_type, session):
    """Return a truthy value if `session` should be pooled for this region/event.

    A session qualifies when it recorded at least one of `brain_regions` (e.g.
    the left and right hemisphere of one region) AND its timepoints container
    returns truthy data for `event_type`.

    Args:
        brain_regions: iterable of region names. Any number of entries is
            accepted; previously only the first two were checked.
        event_type: event label passed to ``timepoints_container.get_data``.
        session: object exposing ``brain_regions`` and ``timepoints_container``.

    Returns:
        False if none of the regions were recorded. Otherwise, whatever
        ``get_data(event_type)`` returns (callers use its truthiness).
    """
    recorded = any(region in session.brain_regions for region in brain_regions)
    return recorded and session.timepoints_container.get_data(event_type)


# For every mouse, compute signal-response metrics for each (region pair, event type)
# combination and store them in one flat dict keyed "<REGION>_<event>_<metric>".
for mouse_id, mouse in tqdm(mice_dict.items()):
    flattened_responses = {}

    for brain_regions in [['VS_left', 'VS_right'], ['DMS_left', 'DMS_right'], ['DLS_left', 'DLS_right']]:
        # Region name without the hemisphere suffix, e.g. 'VS_left' -> 'VS'.
        region_prefix = brain_regions[0].split('_')[0]

        for event_type in ['hit', 'mistake', 'miss', 'cor_reject', 'reward_collect']:
            relevant = [s for s in mouse.sessions if is_relevant_session(brain_regions, event_type, s)]
            if not relevant:
                # No session of this mouse covers this region/event: no keys are added.
                continue

            _, ys, _, _, interval = aggregate_signals(relevant, event_type, brain_regions,
                                                      aggregate_by_session=False, normalize_baseline=True)
            response_metrics = calculate_signal_response_metrics(ys, interval)

            flattened_responses.update(
                (f"{region_prefix}_{event_type}_{metric_name}", metric_value)
                for metric_name, metric_value in response_metrics.items()
            )

    mouse_responses[mouse_id] = flattened_responses

print(mouse_responses.keys())

  0%|          | 0/17 [00:00<?, ?it/s]

dict_keys(['23', '25', '31', '35', '37', '33', '39', '45', '43', '47', '49', '55', '67', '57', '69', '51', '63'])


In [101]:
# {mouse_id: {performance_metric_name: value}}, e.g. d_prime, hit_rate, total_hits.
performance_metrics_2 = {
    mouse_id: mouse.metric_container.data
    for mouse_id, mouse in mice_dict.items()
}

In [102]:
import pandas as pd

In [103]:
# Empty correlation table: one row per flattened response metric ("<REGION>_<event>_<metric>",
# gathered from every mouse) and one column per performance metric (union over mice).
# Previously the rows came from the `response_metrics` left over from the last loop
# iteration (hidden state, unprefixed names). The columns were performance_metrics_2.keys(),
# which are mouse ids rather than metric names.
# NOTE(review): not filled in any visible cell; pivot_df holds the computed results.
correlation_matrix = pd.DataFrame(
    index=list(dict.fromkeys(key for responses in mouse_responses.values() for key in responses)),
    columns=list(dict.fromkeys(name for perf in performance_metrics_2.values() for name in perf)),
)

In [104]:
# Dump each mouse's flattened response-metric dict for a quick visual check.
for flat_metrics in mouse_responses.values():
    print(flat_metrics)

{'DMS_hit_slope_up': 0.061020002804434634, 'DMS_hit_slope_down': 0.034095400589834825, 'DMS_hit_maximal_value': 3.254229734220574, 'DMS_hit_peak_timing': 53, 'DMS_hit_auc': 14.624882779680597, 'DMS_mistake_slope_up': 0.01781125792225364, 'DMS_mistake_slope_down': 0.002906540961230918, 'DMS_mistake_maximal_value': 0.22219047932048652, 'DMS_mistake_peak_timing': 16, 'DMS_mistake_auc': -4.725635623793409, 'DMS_miss_slope_up': 0.029465579467498032, 'DMS_miss_slope_down': 0.004312009206620564, 'DMS_miss_maximal_value': 0.23275047410753355, 'DMS_miss_peak_timing': 8, 'DMS_miss_auc': -5.325868747593303, 'DMS_cor_reject_slope_up': 0.02098904042077355, 'DMS_cor_reject_slope_down': 0.001962693078655496, 'DMS_cor_reject_maximal_value': 0.1348594592364527, 'DMS_cor_reject_peak_timing': 8, 'DMS_cor_reject_auc': -0.7334225076101838, 'DMS_reward_collect_slope_up': 0.04246764683402401, 'DMS_reward_collect_slope_down': 0.0171856965884911, 'DMS_reward_collect_maximal_value': 1.8253809559537901, 'DMS_rew

In [105]:
import numpy as np
import scipy.stats

# Assuming mouse_responses and performance_metrics_2 are structured as described:
# mouse_responses = {mouse_id: {'response_metric_name1': value, ...}, ...}
# performance_metrics_2 = {mouse_id: {'performance_metric_name1': value, ...}, ...}

def calculate_correlations(mouse_responses, performance_metrics_2, min_samples=2):
    """Pearson-correlate every response metric with every performance metric across mice.

    Args:
        mouse_responses: {mouse_id: {response_metric_name: value, ...}, ...}
        performance_metrics_2: {mouse_id: {performance_metric_name: value, ...}, ...}
        min_samples: minimum number of mice with finite values for both metrics
            needed to compute a correlation. Values below 2 are raised to 2, the
            minimum ``scipy.stats.pearsonr`` accepts. Pairs with fewer usable mice
            are skipped.

    Returns:
        {(response_metric, performance_metric): (r, p_value)} for every pair with at
        least ``min_samples`` usable mice. Keys follow first-seen metric order.
    """
    # First-seen order instead of a set, so results and printouts are reproducible
    # between runs (set order of strings depends on hash randomization).
    response_metric_names = list(dict.fromkeys(
        name for metrics in mouse_responses.values() for name in metrics))
    # Union over all mice rather than assuming the first mouse has every metric.
    performance_metric_names = list(dict.fromkeys(
        name for metrics in performance_metrics_2.values() for name in metrics))

    min_samples = max(min_samples, 2)
    correlation_results = {}

    for response_metric in response_metric_names:
        for performance_metric in performance_metric_names:
            response_values, performance_values = [], []

            for mouse_id, responses in mouse_responses.items():
                response_value = responses.get(response_metric)
                performance_value = performance_metrics_2.get(mouse_id, {}).get(performance_metric)

                # Use only mice that have both metrics, and only finite (non-NaN/inf) values.
                if response_value is None or performance_value is None:
                    continue
                if not (np.isfinite(response_value) and np.isfinite(performance_value)):
                    continue
                response_values.append(response_value)
                performance_values.append(performance_value)

            # Gate on the *cleaned* data: pearsonr raises ValueError for fewer than 2 points.
            if len(response_values) >= min_samples:
                corr, p_val = scipy.stats.pearsonr(response_values, performance_values)
                correlation_results[(response_metric, performance_metric)] = (corr, p_val)

    return correlation_results

correlation_results = calculate_correlations(mouse_responses, performance_metrics_2)

# Report every (response metric, performance metric) pair as "r, p".
# To list only significant pairs, filter on p_val (e.g. p_val <= 0.05) inside the loop.
for (response_name, performance_name), (corr_value, p_val) in correlation_results.items():
    print(f"Correlation between '{response_name}' and '{performance_name}': {corr_value:.4f}, {p_val:.4f}")


Correlation between 'DLS_hit_slope_up' and 'd_prime': 0.1821, 0.6392
Correlation between 'DLS_hit_slope_up' and 'c_score': -0.0069, 0.9859
Correlation between 'DLS_hit_slope_up' and 'participation': 0.1448, 0.7101
Correlation between 'DLS_hit_slope_up' and 'total_hits': 0.2066, 0.5938
Correlation between 'DLS_hit_slope_up' and 'total_mistakes': 0.0153, 0.9688
Correlation between 'DLS_hit_slope_up' and 'hit_rate': 0.1993, 0.6072
Correlation between 'DLS_hit_slope_up' and 'false_alarm_rate': 0.0109, 0.9779
Correlation between 'DMS_hit_slope_down' and 'd_prime': -0.0842, 0.7948
Correlation between 'DMS_hit_slope_down' and 'c_score': -0.0717, 0.8247
Correlation between 'DMS_hit_slope_down' and 'participation': 0.1282, 0.6912
Correlation between 'DMS_hit_slope_down' and 'total_hits': 0.0991, 0.7592
Correlation between 'DMS_hit_slope_down' and 'total_mistakes': 0.1466, 0.6493
Correlation between 'DMS_hit_slope_down' and 'hit_rate': 0.0265, 0.9349
Correlation between 'DMS_hit_slope_down' and 

In [106]:
# One record per (response metric, performance metric) pair. The 'Correlation'
# cell keeps the raw (r, p_value) tuple; downstream code splits it apart.
data_for_df = [
    {
        'Response Metric': response_metric,
        'Performance Metric': performance_metric,
        'Correlation': r_and_p,
    }
    for (response_metric, performance_metric), r_and_p in correlation_results.items()
]

df = pd.DataFrame(data_for_df)

# Wide table: rows = response metrics, columns = performance metrics, each cell an (r, p) tuple.
# Pairs absent from correlation_results show up as NaN.
pivot_df = df.pivot(index='Response Metric', columns='Performance Metric', values='Correlation')

pivot_df

Performance Metric,c_score,d_prime,false_alarm_rate,hit_rate,participation,total_hits,total_mistakes
Response Metric,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
DLS_cor_reject_auc,"(-0.1305526857964975, 0.7377888543072493)","(-0.2686398756325323, 0.48458420085048654)","(0.2576656234758902, 0.5032686065863509)","(-0.08894773022379772, 0.8199876077087089)","(0.08545027514822928, 0.8269778790577715)","(-0.10099774462338705, 0.7959900418461738)","(0.2670041036366906, 0.4873507607251009)"
DLS_cor_reject_maximal_value,"(-0.6750860750928344, 0.046020541328914207)","(0.9048523442566342, 0.0007962629704496753)","(0.07194397695268082, 0.8540676296635044)","(0.8882076825557823, 0.0013765990822775297)","(0.642179937702616, 0.062199895195251334)","(0.9264174541103793, 0.0003309509441758246)","(0.056054792269510684, 0.8861047032180419)"
DLS_cor_reject_peak_timing,"(0.845869573303773, 0.004057028933124552)","(-0.4742054061084194, 0.19716885957101293)","(-0.4328276092798268, 0.24456538396288277)","(-0.6912232322442297, 0.03919153811748818)","(-0.6920635236734146, 0.03885513076372778)","(-0.687857094226123, 0.0405580085629059)","(-0.4197749887786719, 0.26067033418219265)"
DLS_cor_reject_slope_down,"(0.570120844844306, 0.10898001610609528)","(-0.10149163556673856, 0.7950094813707813)","(-0.36433791606278393, 0.33505461422177607)","(-0.3048306978485709, 0.42510241319912667)","(-0.4013288440124517, 0.28435721062923125)","(-0.29768608807044655, 0.4365761643549959)","(-0.36055019576932634, 0.34048086210919604)"
DLS_cor_reject_slope_up,"(-0.7286196214013355, 0.025975193124643418)","(0.8651081010736126, 0.0025948586074368988)","(0.16402602953742698, 0.6732503212621643)","(0.9130796114770687, 0.0005850319786345286)","(0.7031530102508586, 0.03458933560777781)","(0.939652868421669, 0.00016752342466970548)","(0.14785587420343824, 0.70422401915947)"
...,...,...,...,...,...,...,...
VS_reward_collect_auc,"(-0.47519481981843215, 0.13964047767406876)","(-0.031222373826046906, 0.9273903872337651)","(0.3822121857091181, 0.24603144713502062)","(0.3775365045333805, 0.25234078474524313)","(0.5224336439845266, 0.09921262613635688)","(0.48855766000892437, 0.12730105158727922)","(0.427668543854819, 0.18948895334218888)"
VS_reward_collect_maximal_value,"(-0.35065000736820473, 0.2903914801563543)","(-0.16002421753855658, 0.6383490679904029)","(0.31282079475530106, 0.34894513508449604)","(0.21971753849736358, 0.5162375944539924)","(0.4019572845695585, 0.22040097960887717)","(0.3434550279002469, 0.30108126551567793)","(0.36267550710300095, 0.2730012001061187)"
VS_reward_collect_peak_timing,"(0.04286137910603041, 0.9004233160576227)","(0.10132625977393638, 0.7668953221798089)","(-0.18238113988234111, 0.5914504466923911)","(-0.02150259497324022, 0.9499644505822378)","(-0.11030688422819643, 0.7467950089483592)","(0.001444178937089164, 0.99663765712798)","(-0.1987390896361976, 0.5579925892059941)"
VS_reward_collect_slope_down,"(-0.27159028868754725, 0.41916352115119937)","(-0.2437882473189288, 0.4700483117336502)","(0.2568075256527545, 0.44587971017902467)","(0.11480905074345887, 0.7367705620794913)","(0.3158506985125995, 0.344043819150367)","(0.24907934024000883, 0.460155218884289)","(0.30654929616298787, 0.359204968651656)"


In [107]:
# Long format: one row per (response metric, performance metric) pair, with the
# (r, p) tuple split into |r| and p-value columns.
long_df = (
    pivot_df.reset_index()
    .melt(id_vars=['Response Metric'], var_name='Performance Metric', value_name='Correlation')
    # Pairs missing from correlation_results come back from the pivot as NaN rather
    # than an (r, p) tuple; drop them so the tuple unpacking below cannot fail.
    .dropna(subset=['Correlation'])
    .assign(
        Correlation_abs=lambda d: d['Correlation'].apply(lambda x: abs(x[0])),
        P_value=lambda d: d['Correlation'].apply(lambda x: x[1]),
    )
)

# Most significant pairs first (ascending p-value only). |r| is kept as a column for
# reference but does not take part in the ordering.
sorted_df = long_df.sort_values(by=['P_value'], ascending=True)


In [108]:
# Display the ranked correlation table (sorted by ascending p-value).
sorted_df

Unnamed: 0,Response Metric,Performance Metric,Correlation,Correlation_abs,P_value
379,DLS_cor_reject_slope_up,total_hits,"(0.939652868421669, 0.00016752342466970548)",0.939653,0.000168
376,DLS_cor_reject_maximal_value,total_hits,"(0.9264174541103793, 0.0003309509441758246)",0.926417,0.000331
229,DLS_cor_reject_slope_up,hit_rate,"(0.9130796114770687, 0.0005850319786345286)",0.913080,0.000585
76,DLS_cor_reject_maximal_value,d_prime,"(0.9048523442566342, 0.0007962629704496753)",0.904852,0.000796
438,VS_miss_slope_down,total_hits,"(0.8488372775339689, 0.0009509772064382363)",0.848837,0.000951
...,...,...,...,...,...
256,DMS_hit_maximal_value,hit_rate,"(0.00982081430313117, 0.9758346975369001)",0.009821,0.975835
159,DLS_hit_slope_up,false_alarm_rate,"(0.010860613639043774, 0.9778771142231159)",0.010861,0.977877
9,DLS_hit_slope_up,c_score,"(-0.006901224640076153, 0.9859414985890855)",0.006901,0.985941
332,DMS_hit_peak_timing,participation,"(-0.0050581911589941825, 0.9875525323289267)",0.005058,0.987553


In [109]:
# Export the ranked pairs with the signed Pearson r and its p-value as plain numeric columns.
# The raw (r, p) tuple and the |r| helper are dropped. Dropping the tuple alone (as before)
# left the "correlations" CSV with p-values but no correlation coefficients.
export_df = sorted_df.drop(columns=['Correlation', 'Correlation_abs'])
export_df.insert(2, 'Correlation_r', sorted_df['Correlation'].apply(lambda x: x[0]))
export_df.to_csv('all_correlations_3.csv')