In [30]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../')

from main import load_and_prepare_sessions
from analysis.performance_funcs import add_performance_container
from analysis.timepoint_analysis import collect_signals, aggregate_signals
from data.mouse import create_mice_dict
from analysis.response_metrics import calculate_signal_response_metrics
from scipy.stats import ttest_ind
from utils import count_session_events

from collections import defaultdict
import matplotlib.pyplot as plt

# Load the recording sessions from the "Baseline" directory two levels up.
# NOTE(review): the flag semantics live in main.load_and_prepare_sessions —
# presumably load_from_pickle reads a previously cached pickle and
# remove_bad_signal_sessions drops low-quality recordings; confirm there.
sessions = load_and_prepare_sessions("../../Baseline", load_from_pickle=True, remove_bad_signal_sessions=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [31]:
from processing.session_sampling import MiceAnalysis
# Wrap the loaded sessions in a MiceAnalysis object; the t-test loop further
# down calls its sample_response_metrics method.
mouse_analyser = MiceAnalysis(sessions)

In [32]:
from scipy.stats import ttest_ind
from tqdm.notebook import tqdm
import numpy as np
import numpy as np


def is_relevant_session(brain_regions, event_type, idx):
    """Tell whether session ``idx`` is usable for a region pair and event type.

    Reads the notebook-level ``sessions`` list.

    Args:
        brain_regions: indexable with at least two entries; only the first
            two are examined.
        event_type: key passed to ``event_idxs_container.get_data``.
        idx: position of the session inside ``sessions``.

    Returns:
        ``False`` when neither of the two regions is listed in the session's
        ``brain_regions``; otherwise whatever ``get_data(event_type)`` returns
        (used for its truthiness).
    """
    session = sessions[idx]
    recorded_regions = session.brain_regions
    has_region = brain_regions[0] in recorded_regions or brain_regions[1] in recorded_regions
    if not has_region:
        return False
    return session.event_idxs_container.get_data(event_type)

results = []
n = 1000

# For every (behavioural metric, brain region, event type) combination, draw
# the low/high response-metric samples and compare them with an independent
# two-sample t-test, one test per response metric.
for metric in ('c_score', 'd_prime', 'participation'):
    for brain_region in ['VS', 'DMS', 'DLS']:
        for event_type in ['hit', 'mistake', 'miss', 'cor_reject', 'reward_collect']:
            lo_responses, hi_responses = mouse_analyser.sample_response_metrics(
                metric, brain_region, event_type, n=n
            )
            for response_metric, lo_values in lo_responses.items():
                hi_values = hi_responses[response_metric]
                t_statistic, p_value = ttest_ind(lo_values, hi_values, nan_policy='omit')

                # Comparisons without a defined p-value are not recorded.
                if np.isnan(p_value):
                    continue

                results.append({
                    'key': (metric, brain_region, event_type, response_metric),
                    'T-Statistic': t_statistic,
                    'P-Value': p_value,
                    # Capped at the requested sample size; smaller when either
                    # group returned fewer values.
                    'n': min(n, len(lo_values), len(hi_values)),
                })

# Most significant comparisons first.
sorted_results = sorted(results, key=lambda entry: entry['P-Value'])

In [34]:
# Number of comparisons kept, i.e. those whose t-test produced a non-NaN p-value.
len(sorted_results)

225

In [33]:
import pandas as pd
# Print the sorted results

# Significance threshold applied to the raw p-values.
# NOTE(review): no multiple-comparison correction is applied here even though
# a few hundred tests are run — consider Bonferroni/FDR before interpreting.
ALPHA = 0.05

# For each factor appearing in a comparison key (behavioural metric, brain
# region, event type, response metric), collect the p-values and t-statistics
# of every significant comparison that factor took part in.
all_ps = defaultdict(list)
all_ts = defaultdict(list)

for sorted_result in sorted_results:
    if sorted_result['P-Value'] > ALPHA:
        continue
    curr_key = sorted_result['key']
    p_val = sorted_result['P-Value']
    t_stat = sorted_result['T-Statistic']
    # Deliberately not called `n`: that name holds the sampling size used by
    # the t-test cell, and overwriting it here would silently change the
    # sample size for any cell run afterwards.
    n_used = sorted_result['n']

    for factor in curr_key:
        all_ps[factor].append(p_val)
        all_ts[factor].append(t_stat)

    print(curr_key)
    print("T-Statistic:", t_stat)
    print("P-Value:", p_val)
    print("N:", n_used)
    print('')

from statistics import geometric_mean
def mean(values, use_geometric_mean=False, use_abs=False):
    """Average a sequence of numbers.

    Args:
        values: sized sequence of numbers.
        use_geometric_mean: when true, use ``statistics.geometric_mean``;
            otherwise the arithmetic mean ``sum(values) / len(values)``.
        use_abs: when true, return the absolute value of the average (the
            average itself is taken first, then ``abs`` is applied).

    Returns:
        The arithmetic or geometric mean, optionally made non-negative.
    """
    average = geometric_mean(values) if use_geometric_mean else sum(values) / len(values)
    return abs(average) if use_abs else average

# Rank factors by the geometric mean of their significant p-values, smallest first.
sorted_all_ps = sorted(
    all_ps.items(),
    key=lambda item: mean(item[1], use_geometric_mean=True),
)

# Rank factors by the magnitude of their arithmetic-mean t-statistic, largest first.
sorted_all_ts = sorted(
    all_ts.items(),
    key=lambda item: mean(item[1], use_abs=True),
    reverse=True,
)

# Table of p-value rankings; the mean is rendered in scientific notation.
low_p_rows = [
    (name, f'{mean(p_values, use_geometric_mean=True):.3e}')
    for name, p_values in sorted_all_ps
]
df_low_p = pd.DataFrame(low_p_rows, columns=['Factor', 'Mean (Geometric)'])

# Table of t-statistic rankings; the signed mean is shown although the order
# above is by absolute value.
high_t_rows = [
    (name, round(mean(t_values), 3))
    for name, t_values in sorted_all_ts
]
df_high_t = pd.DataFrame(high_t_rows, columns=['Factor', 'Mean'])

# Rich display of both tables.
print("Factors that matter the most for a low p-value:")
display(df_low_p)

print("\nFactors that matter the most for a high t-stat:")
display(df_high_t)



# Re-label each result for export; the per-comparison sample count is written
# under the 'N events' column.
sorted_results_list = [
    {
        'key': entry['key'],
        'T-Statistic': entry['T-Statistic'],
        'P-Value': entry['P-Value'],
        'N events': entry['n'],
    }
    for entry in sorted_results
]

# One row per comparison, already ordered by ascending p-value.
df_sorted_results = pd.DataFrame(sorted_results_list)

# Persist next to the notebook, without the index column.
df_sorted_results.to_csv('sorted_p_values.csv', index=False)



('c_score', 'VS', 'cor_reject', 'auc')
T-Statistic: 27.134720473490507
P-Value: 2.6458910408888578e-138
N: 1000

('d_prime', 'VS', 'cor_reject', 'auc')
T-Statistic: -23.877856082567142
P-Value: 4.5859368750197934e-111
N: 1000

('c_score', 'VS', 'cor_reject', 'maximal_value')
T-Statistic: 20.41841355740805
P-Value: 2.561721512596133e-84
N: 1000

('c_score', 'DMS', 'cor_reject', 'auc')
T-Statistic: 19.472901464529773
P-Value: 1.800529700616829e-77
N: 1000

('d_prime', 'DMS', 'cor_reject', 'auc')
T-Statistic: -16.622712173239943
P-Value: 3.2282143895617886e-58
N: 1000

('d_prime', 'VS', 'cor_reject', 'maximal_value')
T-Statistic: -15.804362405781156
P-Value: 4.170788517769568e-53
N: 1000

('participation', 'VS', 'cor_reject', 'maximal_value')
T-Statistic: -15.337859403127805
P-Value: 8.93870536428155e-50
N: 836

('c_score', 'DMS', 'cor_reject', 'maximal_value')
T-Statistic: 15.110408204885513
P-Value: 6.2841979214463284e-49
N: 1000

('c_score', 'VS', 'miss', 'auc')
T-Statistic: 14.7657731

Unnamed: 0,Factor,Mean (Geometric)
0,cor_reject,1.494e-27
1,auc,7.214e-24
2,maximal_value,2.091e-17
3,VS,3.564e-17
4,c_score,6.954e-16
5,d_prime,4.212e-13
6,DMS,5.642e-13
7,miss,6.142e-12
8,participation,2.627e-08
9,reward_collect,6.186e-06



Factors that matter the most for a high t-stat:


Unnamed: 0,Factor,Mean
0,c_score,3.891
1,d_prime,-3.513
2,participation,-1.914
3,slope_up,-1.605
4,maximal_value,-1.514
5,DMS,1.334
6,peak_timing,1.321
7,miss,-1.279
8,VS,-1.164
9,DLS,-1.127
