In [26]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../')

from main import load_and_prepare_sessions
from analysis.performance_funcs import add_performance_container
from analysis.response_metrics import calculate_signal_response_metrics
from processing.timepoint_analysis import aggregate_signals
from data.mouse import create_mice_dict
from data.data_loading import DataContainer
from tqdm.notebook import tqdm

from collections import defaultdict
import matplotlib.pyplot as plt

# Load the sessions stored under "../../Baseline" (path is relative to this notebook).
# NOTE(review): the keyword names suggest a cached pickle is read and sessions with bad
# signal are dropped — confirm against main.load_and_prepare_sessions.
sessions = load_and_prepare_sessions("../../Baseline", load_from_pickle=True, remove_bad_signal_sessions=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
# Build the mouse lookup from the loaded sessions; later cells iterate it with
# .items()/.values() and read `.sessions` from each value.
mice_dict = create_mice_dict(sessions)

In [28]:
for mouse in mice_dict.values():
    # presumably attaches `mouse.metric_container` (read just below) — confirm in
    # analysis.performance_funcs.add_performance_container
    add_performance_container(mouse)

    # Every session of a mouse points at that mouse's container object (same reference,
    # not a copy), so all sessions of one mouse see identical metric values.
    for session in mouse.sessions:
        session.metric_container = mouse.metric_container

In [29]:
# Gather, for every metric, an (index into `sessions`, metric value) pair per session.
session_metric_order = defaultdict(list)

for idx, session in enumerate(sessions):
    for metric, value in session.metric_container.data.items():
        session_metric_order[metric].append((idx, value))

# For every metric, keep only the session indices, ordered by ascending metric value.
# `sorted` is stable, so sessions with equal values keep their original relative order.
metric_session_order = {}
for metric, pairs in session_metric_order.items():
    ranked_pairs = sorted(pairs, key=lambda pair: pair[1])
    metric_session_order[metric] = tuple(index for index, _ in ranked_pairs)

In [30]:
# Accumulator keyed by mouse id; the next cell stores one flat dict of response metrics per mouse in it.
mouse_responses = {}

In [31]:
def is_relevant_session(brain_regions, event_type, session):
    """Tell whether `session` can contribute data for these regions and this event type.

    Parameters
    ----------
    brain_regions : iterable of str
        Region names to look for (e.g. ``['VS_left', 'VS_right']``). Any number of
        names is accepted; the session qualifies if it recorded at least one of them.
    event_type : str
        Event name passed to ``session.timepoints_container.get_data``.
    session : object
        Must expose ``brain_regions`` (supports ``in``) and
        ``timepoints_container.get_data(event_type)``.

    Returns
    -------
    ``False`` when none of the regions is present in ``session.brain_regions``;
    otherwise whatever ``get_data(event_type)`` returns, so the caller's truthiness
    check decides (an empty/None result means "not relevant").
    """
    # Previously only brain_regions[0] and [1] were inspected: extra regions were
    # ignored and a single-element list raised IndexError on a non-match.
    if not any(region in session.brain_regions for region in brain_regions):
        return False
    return session.timepoints_container.get_data(event_type)


# Event types for which response metrics are computed.
event_types = ['hit', 'mistake', 'miss', 'cor_reject', 'reward_collect']

for mouse_id, mouse in tqdm(mice_dict.items()):
    # One flat mapping per mouse: "<REGION>_<event>_<metric>" -> value
    flattened_responses = {}

    # Each entry pairs the left/right hemisphere channels of one striatal region
    for brain_regions in [['VS_left', 'VS_right'], ['DMS_left', 'DMS_right'], ['DLS_left', 'DLS_right']]:
        for event_type in event_types:
            filtered_sessions = [
                candidate
                for candidate in mouse.sessions
                if is_relevant_session(brain_regions, event_type, candidate)
            ]

            # Nothing recorded for this region/event combination in this mouse
            if not filtered_sessions:
                continue

            _, ys, _, _, interval = aggregate_signals(
                filtered_sessions,
                event_type,
                brain_regions,
                aggregate_by_session=False,
                normalize_baseline=True,
            )
            response_metrics = calculate_signal_response_metrics(ys, interval)

            # Key prefix = region name without the hemisphere suffix (e.g. "VS"), then the event type
            unique_key_prefix = brain_regions[0].split('_')[0]
            unique_key = f"{unique_key_prefix}_{event_type}"

            # Merge this combination's metrics into the flat per-mouse mapping
            flattened_responses.update(
                (f"{unique_key}_{metric_name}", metric_value)
                for metric_name, metric_value in response_metrics.items()
            )

    # Store the finished mapping under the mouse id
    mouse_responses[mouse_id] = flattened_responses

print(mouse_responses.keys())

  0%|          | 0/17 [00:00<?, ?it/s]

dict_keys(['23', '25', '31', '35', '37', '33', '39', '45', '43', '47', '49', '55', '67', '57', '69', '51', '63'])


In [32]:
# Map each mouse id to the data held by that mouse's metric container.
performance_metrics_2 = {
    mouse_id: mouse.metric_container.data
    for mouse_id, mouse in mice_dict.items()
}

In [33]:
import pandas as pd

In [34]:
# Empty response-metric x performance-metric table.
# The previous version indexed by `response_metrics`, a loop variable leaked from an
# earlier cell (un-prefixed names from the last region/event combination only), and used
# mouse ids as columns. Both axes are now derived explicitly from the per-mouse dicts.
response_metric_names = sorted(
    {name for metrics in mouse_responses.values() for name in metrics}
)
# dict.fromkeys de-duplicates while keeping first-seen order
performance_metric_names = list(dict.fromkeys(
    name for metrics in performance_metrics_2.values() for name in metrics
))
correlation_matrix = pd.DataFrame(index=response_metric_names, columns=performance_metric_names)

In [35]:
# Debug dump: print every mouse's flattened response-metric dict in full (one line per mouse).
for d in mouse_responses.values():
    print(d)

{'DMS_hit_slope_up': 0.061020002804434634, 'DMS_hit_slope_down': 0.034095400589834825, 'DMS_hit_maximal_value': 3.254229734220574, 'DMS_hit_peak_timing': 53, 'DMS_hit_auc': 14.624882779680597, 'DMS_mistake_slope_up': 0.01781125792225364, 'DMS_mistake_slope_down': 0.002906540961230918, 'DMS_mistake_maximal_value': 0.22219047932048652, 'DMS_mistake_peak_timing': 16, 'DMS_mistake_auc': -4.725635623793409, 'DMS_miss_slope_up': 0.029465579467498032, 'DMS_miss_slope_down': 0.004312009206620564, 'DMS_miss_maximal_value': 0.23275047410753355, 'DMS_miss_peak_timing': 8, 'DMS_miss_auc': -5.325868747593303, 'DMS_cor_reject_slope_up': 0.02098904042077355, 'DMS_cor_reject_slope_down': 0.001962693078655496, 'DMS_cor_reject_maximal_value': 0.1348594592364527, 'DMS_cor_reject_peak_timing': 8, 'DMS_cor_reject_auc': -0.7334225076101838, 'DMS_reward_collect_slope_up': 0.04246764683402401, 'DMS_reward_collect_slope_down': 0.0171856965884911, 'DMS_reward_collect_maximal_value': 1.8253809559537901, 'DMS_rew

In [37]:
import numpy as np
import scipy.stats

# Assuming mouse_responses and performance_metrics_2 are structured as described:
# mouse_responses = {mouse_id: {'response_metric_name1': value, ...}, ...}
# performance_metrics_2 = {mouse_id: {'performance_metric_name1': value, ...}, ...}

def calculate_correlations(mouse_responses, performance_metrics_2):
    """Pearson-correlate every response metric with every performance metric across mice.

    Parameters
    ----------
    mouse_responses : dict
        ``{mouse_id: {response_metric_name: value, ...}, ...}``
    performance_metrics_2 : dict
        ``{mouse_id: {performance_metric_name: value, ...}, ...}``

    Returns
    -------
    dict
        ``{(response_metric, performance_metric): r}`` where ``r`` is the Pearson
        coefficient over the mice that have a finite value for both metrics.
        Pairs with fewer than two such mice are omitted, because a correlation is
        undefined there and ``scipy.stats.pearsonr`` raises ``ValueError``.
        Empty inputs give an empty dict.
    """
    # Union of response metric names over all mice; sorted so the result order is
    # reproducible (plain set iteration order varies between runs).
    all_response_metric_names = sorted(
        {name for metrics in mouse_responses.values() for name in metrics}
    )

    # Union of performance metric names in first-seen order. Taking them from every
    # mouse (not just the first) also makes an empty input safe.
    performance_metric_names = list(dict.fromkeys(
        name for metrics in performance_metrics_2.values() for name in metrics
    ))

    correlation_results = {}

    for response_metric in all_response_metric_names:
        for performance_metric in performance_metric_names:
            clean_response_values = []
            clean_performance_values = []

            # Collect paired values, keeping a mouse only if both values exist and are finite
            for mouse_id, response_metrics in mouse_responses.items():
                response_value = response_metrics.get(response_metric)
                performance_value = performance_metrics_2.get(mouse_id, {}).get(performance_metric)

                if response_value is None or performance_value is None:
                    continue
                # isfinite rejects NaN as well as +/-inf
                if not (np.isfinite(response_value) and np.isfinite(performance_value)):
                    continue

                clean_response_values.append(response_value)
                clean_performance_values.append(performance_value)

            # The guard must look at the *cleaned* lists: the raw lists can be non-empty
            # while fewer than two usable pairs remain after filtering.
            if len(clean_response_values) < 2:
                continue

            corr, _ = scipy.stats.pearsonr(clean_response_values, clean_performance_values)
            correlation_results[(response_metric, performance_metric)] = corr

    return correlation_results

# Run the correlation analysis over all mice; keys are (response metric, performance metric) tuples.
correlation_results = calculate_correlations(mouse_responses, performance_metrics_2)

# Print one line per metric pair, with the coefficient formatted to four decimals.
for metric_pair, corr_value in correlation_results.items():
    print(f"Correlation between '{metric_pair[0]}' and '{metric_pair[1]}': {corr_value:.4f}")


Correlation between 'DMS_reward_collect_auc' and 'd_prime': 0.4777
Correlation between 'DMS_reward_collect_auc' and 'c_score': -0.5894
Correlation between 'DMS_reward_collect_auc' and 'participation': 0.6329
Correlation between 'DMS_reward_collect_auc' and 'total_hits': 0.6585
Correlation between 'DMS_reward_collect_auc' and 'total_mistakes': 0.4967
Correlation between 'DMS_reward_collect_auc' and 'hit_rate': 0.6192
Correlation between 'DMS_reward_collect_auc' and 'false_alarm_rate': 0.4794
Correlation between 'DLS_miss_slope_up' and 'd_prime': 0.7506
Correlation between 'DLS_miss_slope_up' and 'c_score': -0.6221
Correlation between 'DLS_miss_slope_up' and 'participation': 0.5599
Correlation between 'DLS_miss_slope_up' and 'total_hits': 0.7678
Correlation between 'DLS_miss_slope_up' and 'total_mistakes': 0.0951
Correlation between 'DLS_miss_slope_up' and 'hit_rate': 0.7392
Correlation between 'DLS_miss_slope_up' and 'false_alarm_rate': 0.1098
Correlation between 'VS_cor_reject_slope_up

In [38]:
# One record per (response metric, performance metric) correlation.
data_for_df = [
    {
        'Response Metric': response_name,
        'Performance Metric': performance_name,
        'Correlation': r_value,
    }
    for (response_name, performance_name), r_value in correlation_results.items()
]

# Long-format table of all correlations.
df = pd.DataFrame(data_for_df)

# Wide format: response metrics as rows, performance metrics as columns.
# Pairs without a computed correlation stay NaN (they are not filled in).
pivot_df = df.pivot(index='Response Metric', columns='Performance Metric', values='Correlation')

pivot_df

Performance Metric,c_score,d_prime,false_alarm_rate,hit_rate,participation,total_hits,total_mistakes
Response Metric,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
DLS_cor_reject_auc,-0.130553,-0.268640,0.257666,-0.088948,0.085450,-0.100998,0.267004
DLS_cor_reject_maximal_value,-0.675086,0.904852,0.071944,0.888208,0.642180,0.926417,0.056055
DLS_cor_reject_peak_timing,0.845870,-0.474205,-0.432828,-0.691223,-0.692064,-0.687857,-0.419775
DLS_cor_reject_slope_down,0.570121,-0.101492,-0.364338,-0.304831,-0.401329,-0.297686,-0.360550
DLS_cor_reject_slope_up,-0.728620,0.865108,0.164026,0.913080,0.703153,0.939653,0.147856
...,...,...,...,...,...,...,...
VS_reward_collect_auc,-0.475195,-0.031222,0.382212,0.377537,0.522434,0.488558,0.427669
VS_reward_collect_maximal_value,-0.350650,-0.160024,0.312821,0.219718,0.401957,0.343455,0.362676
VS_reward_collect_peak_timing,0.042861,0.101326,-0.182381,-0.021503,-0.110307,0.001444,-0.198739
VS_reward_collect_slope_down,-0.271590,-0.243788,0.256808,0.114809,0.315851,0.249079,0.306549


In [39]:
# Melt the pivot table back to long format: one row per (response metric, performance metric) pair.
long_df = pd.melt(
    pivot_df.reset_index(),
    id_vars=['Response Metric'],
    var_name='Performance Metric',
    value_name='Correlation',
)

# Rank rows by |correlation|, strongest first, regardless of sign.
abs_correlation_rank = long_df['Correlation'].abs().sort_values(ascending=False)
sorted_df = long_df.reindex(abs_correlation_rank.index)

# Save the full ranked table (the original row index is written as the first CSV column).
sorted_df.to_csv("all_correlations.csv")

In [40]:
# Display the correlations ranked by absolute value (largest first).
sorted_df

Unnamed: 0,Response Metric,Performance Metric,Correlation
379,DLS_cor_reject_slope_up,total_hits,0.939653
376,DLS_cor_reject_maximal_value,total_hits,0.926417
229,DLS_cor_reject_slope_up,hit_rate,0.913080
76,DLS_cor_reject_maximal_value,d_prime,0.904852
226,DLS_cor_reject_maximal_value,hit_rate,0.888208
...,...,...,...
159,DLS_hit_slope_up,false_alarm_rate,0.010861
256,DMS_hit_maximal_value,hit_rate,0.009821
9,DLS_hit_slope_up,c_score,-0.006901
332,DMS_hit_peak_timing,participation,-0.005058
