In [None]:
#---------------------------------------------------------------------------------------------------#
from pathlib import Path
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from matplotlib import cm
import plotly.io as pio
import plotly.express as px
import plotly.subplots as sp
import math
from pprint import pprint
import pickle

from plotly.subplots import make_subplots
from scipy.stats import mode
from scipy.integrate import cumulative_trapezoid
from scipy.signal import correlate
%config Completer.use_jedi = False  # Fixes autocomplete issues
%config InlineBackend.figure_format = 'retina'  # Improves plot resolution

import gc # garbage collector for removing large variables from memory instantly 
import importlib #for force updating changed packages 
from scipy.stats import pearsonr, norm
import seaborn as sns
import pickle
import warnings
warnings.filterwarnings('ignore')

In [None]:

# FIXME: this needs to become iterative. Right now the result set to plot has to be
# specified by hand in the cells below this one, which is highly confusing.
#-------------------------------
# data paths setup
#-------------------------------
# Session directories to analyse; every sub-directory is treated as one recording.
data_dirs = [
    # Path('~/RANCZLAB-NAS/data/ONIX/20250409_Cohort3_rotation/Vestibular_mismatch_day1').expanduser(),
    Path('/Volumes/RanczLab2/20241125_Cohort1_rotation/Visual_mismatch_day4').expanduser(),
    # Path('/Volumes/RanczLab2/20241125_Cohort1_rotation/Visual_mismatch_day3').expanduser(),
    # Path('/Volumes/RanczLab2/20250409_Cohort3_rotation/Visual_mismatch_day4').expanduser(),
]

# Raw recording directories: every sub-directory that is not a '*_processedData' folder.
rawdata_paths = [
    entry
    for data_dir in data_dirs
    for entry in data_dir.iterdir()
    if entry.is_dir() and not entry.name.endswith('_processedData')
]

# For each raw directory derive the processed-data location and the mouse name
# (the part of the directory name before the first '-').
data_paths = []
mouse_name = []
for raw in rawdata_paths:
    data_paths.append(raw.parent / f"{raw.name}_processedData/aligned_data")
    mouse_name.append(raw.name.split('-')[0])

selected_mice2 = ["B6J2780", "B6J2781", "B6J2783", "B6J2782"]  # for cohort 3
selected_mice1 = ['B6J2717', 'B6J2718', 'B6J2719', 'B6J2721', 'B6J2722']  # for cohort 1
#-------------------------------
# load aligned_downsampled data for each data path
#-------------------------------
# Event whose baselined CSV is loaded for every mouse; the file name is
# "<mouse>_<event_name>_baselined_data.csv".
event_name = "No halt"
# event_name = "DrumWithReverseflow block started"

loaded_data = {}  # Maps each processed-data path to a dict of its metadata and columns
n_data_paths = len(data_paths)
for idx, (data_path, mouse) in enumerate(zip(data_paths, mouse_name), start=1):
    print(f"\nProcessing data path {idx}/{n_data_paths}: {data_path}")

    # A directory called 'baselined' is not a recording, so there is no mouse to load.
    if mouse == 'baselined':
        print(f"⚠️ Skipping directory {data_path} as it does not contain a valid mouse name.")
        continue

    csv_file_path = data_path / f"{mouse}_{event_name}_baselined_data.csv"
    # A single missing file must not abort loading of the remaining sessions.
    if not csv_file_path.is_file():
        print(f"⚠️ Skipping {data_path}: file not found: {csv_file_path.name}")
        continue

    aligned_df = pd.read_csv(csv_file_path)
    print(f"✅ Successfully loaded all data files for {data_path.name}")

    session = {
        'mouse_name': mouse,
        'data_path': data_path,
    }
    # Store each CSV column as a separate array next to the metadata keys.
    session.update({column: aligned_df[column].values for column in aligned_df.columns})
    loaded_data[data_path] = session
    print(f"Data loaded for {data_path.name}: {len(aligned_df)} rows, {len(aligned_df.columns)} columns")

    # Drop the DataFrame right away; only the column arrays are kept.
    del aligned_df
    gc.collect()
#-------------------------------
# analysis / plotting configuration
save_grand_avg_with_sem = True  # write the grand averages + SEMs to a CSV file
generate_new_plots = True       # set to False to skip figure generation

# Columns to analyse and to plot. Other columns that have been used here:
# 'Velocity_0X_Baseline', 'z_470_Baseline', 'z_560_Baseline', 'dfF_470', 'dfF_560'
selected_columns = ['Velocity_0X', 'z_470', 'z_560']
columns_to_plot = ['Velocity_0X', 'z_470', 'z_560']

# Echo what was discovered on disk so the selection can be checked by eye.
for heading, listing in (("Processed Data Paths:", data_paths), ("Mouse Name:", mouse_name)):
    print(heading)
    pprint(listing)

In [None]:
#computes mean and sem per mouse + stores grand averages across all mice

def compute_mouse_means_and_grand_average(loaded_data, selected_columns, main_data_dir):
    """
    Compute per-mouse mean/SEM traces and the grand average across mice.

    For every entry of ``loaded_data`` the rows are grouped by ``'Time (s)'`` and the
    mean and SEM of each requested numeric column are computed. The per-mouse mean
    traces are then combined across mice at every time point.

    Parameters:
    loaded_data (dict): Maps a data path to a dict holding ``'mouse_name'`` plus one
        array per data column (``'Time (s)'`` is required, entries without it are skipped).
    selected_columns (list): Column names to analyze; ``'Time (s)'`` itself and
        non-numeric columns are ignored.
    main_data_dir (str/Path): Kept for interface compatibility. This function does not
        write anything to disk.

    Returns:
    tuple: (mean_data_per_mouse, sem_data_per_mouse, grand_averages, grand_sems)
        - mean_data_per_mouse / sem_data_per_mouse: dict mouse name -> DataFrame indexed
          by ``'Time (s)'``. Entries sharing a mouse name overwrite each other.
        - grand_averages: DataFrame (index ``'Time (s)'``, columns sorted by name) with
          the mean of the per-mouse means, NaNs ignored.
        - grand_sems: matching DataFrame with the SEM across mice (sample standard
          deviation, ddof=1, like the per-mouse SEM); 0 where only one mouse contributes,
          NaN where none does.

    Raises:
    ValueError: if no mouse yields any usable data.
    """
    print(f"Processing selected columns: {selected_columns}")

    # Step 1: mean and SEM over repeated time stamps, separately for each mouse
    mean_data_per_mouse = {}
    sem_data_per_mouse = {}

    for data_path, data in loaded_data.items():
        mouse_name = data['mouse_name']
        print(f"Processing mouse: {mouse_name}")

        missing_columns = [col for col in selected_columns if col not in data]
        if missing_columns:
            print(f"⚠️  Missing columns for {mouse_name}: {missing_columns}")

        if 'Time (s)' not in data:
            print(f"⚠️  'Time (s)' column not found for {mouse_name}, skipping...")
            continue

        # Build a frame from the requested columns only instead of copying every
        # loaded column; dict.fromkeys drops duplicates while keeping the order.
        candidate_columns = list(dict.fromkeys(
            col for col in selected_columns if col in data and col != 'Time (s)'
        ))
        df = pd.DataFrame({col: data[col] for col in ['Time (s)', *candidate_columns]})

        numeric_selected = [
            col for col in candidate_columns if pd.api.types.is_numeric_dtype(df[col])
        ]
        if not numeric_selected:
            print(f"⚠️  No numeric columns found for {mouse_name}")
            continue

        grouped = df.groupby('Time (s)')
        mean_data_per_mouse[mouse_name] = grouped[numeric_selected].mean()
        sem_data_per_mouse[mouse_name] = grouped[numeric_selected].sem()

        print(f"✅ Processed {len(numeric_selected)} columns for {mouse_name}")

    if not mean_data_per_mouse:
        raise ValueError(
            "No mouse produced usable data: check selected_columns and that every "
            "entry of loaded_data has a 'Time (s)' column."
        )

    # Step 2: grand average across mice
    print(f"\n📊 Computing grand averages across {len(mean_data_per_mouse)} mice...")

    # Stack the per-mouse means into one frame indexed by (mouse, time) so that the
    # statistics across mice are a single vectorised groupby over the time level.
    stacked = pd.concat(mean_data_per_mouse, names=['mouse', 'Time (s)'])
    by_time = stacked.groupby(level='Time (s)')

    grand_averages = by_time.mean().sort_index(axis=1)
    mice_per_point = by_time.count().sort_index(axis=1)
    # GroupBy.sem uses ddof=1 and yields NaN for a single observation; report 0 there.
    grand_sems = by_time.sem().sort_index(axis=1).mask(mice_per_point == 1, 0.0)

    time_index = grand_averages.index
    print(f"Time points: {len(time_index)} from {time_index.min():.2f}s to {time_index.max():.2f}s")
    print(f"Processed columns: {list(grand_averages.columns)}")

    return mean_data_per_mouse, sem_data_per_mouse, grand_averages, grand_sems

def analyze_mice_data(loaded_data, selected_columns, main_data_dir):
    """
    Run the per-mouse / grand-average computation and print a short summary.

    This is a thin wrapper around ``compute_mouse_means_and_grand_average``. It does
    not write files or draw figures; saving and plotting are left to the caller.

    Parameters:
    loaded_data (dict): Your loaded_data dictionary
    selected_columns (list): List of column names to analyze
    main_data_dir (str/Path): Data directory the analysis refers to (reported in the
        summary and forwarded to the computation)

    Returns:
    dict: with keys 'mean_data_per_mouse', 'sem_data_per_mouse', 'grand_averages'
        and 'grand_sems'
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print("MOUSE DATA ANALYSIS")
    print(banner)

    # Compute means and grand averages
    mean_data_per_mouse, sem_data_per_mouse, grand_averages, grand_sems = compute_mouse_means_and_grand_average(
        loaded_data, selected_columns, main_data_dir
    )

    # Print summary
    print(f"\n📊 ANALYSIS COMPLETE:")
    print(f"   • Number of mice analyzed: {len(mean_data_per_mouse)}")
    print(f"   • Mouse names: {list(mean_data_per_mouse.keys())}")
    print(f"   • Columns processed: {list(grand_averages.columns)}")
    print(f"   • Time range: {grand_averages.index.min():.2f}s to {grand_averages.index.max():.2f}s")
    # Nothing is written here, so do not claim that files were saved.
    print(f"   • Data directory: {main_data_dir} (results are returned, not saved)")

    return {
        'mean_data_per_mouse': mean_data_per_mouse,
        'sem_data_per_mouse': sem_data_per_mouse,
        'grand_averages': grand_averages,
        'grand_sems': grand_sems,
    }

In [None]:
# FIXME: the pickle file is a temporary way to hand results over to the cohort
# comparison further down. Do we need it at all? If so, store the grand averages in
# the same pickle file instead of a separate grand-average CSV.


# Choose the data directory the analysis is performed on.
#---------------------------
main_data_dir = data_dirs[0]  # first (currently only active) entry of data_dirs
# Below: helper that pickles a results dict, followed by the analysis run itself.
#----------------------------
# Save the results to a pickle file.
def save_results(results, filename='results.pkl'):
    """Pickle ``results`` into ``filename`` (overwriting it) and print the destination."""
    with open(filename, mode='wb') as pickle_file:
        pickle.Pickler(pickle_file).dump(results)
    print(f"Results saved to {filename}")
#----------------------------
# Run the analysis on the loaded sessions, then persist the outcome as a pickle
# file (relative path, i.e. in the current working directory).
results_cohort1_vmm4_nohalt = analyze_mice_data(
    loaded_data=loaded_data,
    selected_columns=selected_columns,
    main_data_dir=main_data_dir,
)
save_results(results_cohort1_vmm4_nohalt, filename='results_cohort1_vmm4_nohalt.pkl')

In [None]:
# FIXME: this is not iterative yet; the result set to plot is chosen by hand through
# `results_to_plot` below.

# PLOT MEAN PER MOUSE AND GRAND AVERAGE (time course + pre/post comparison), STORE PLOTS
#--------------------------------
selected_mice = selected_mice1  # also used by the CSV export cell, do not remove from here
results_to_plot = results_cohort1_vmm4_nohalt  # single place to switch the plotted results
#--------------------------------
# PLOT properties
plt.rcParams.update({
    'font.size': 10,              # Set global font size
    'font.family': 'sans-serif',  # Font family (e.g., 'serif', 'sans-serif', 'monospace')
    'font.sans-serif': ['Arial'], # Preferred font
    'axes.titlesize': 10,         # Title font size
    'axes.labelsize': 10,         # Axis label size
    'legend.fontsize': 8,         # Legend text
    'xtick.labelsize': 10,
    'ytick.labelsize': 10
})
# One colour per selected mouse, taken from matplotlib's qualitative 'Set2' colormap
color_palette = plt.cm.Set2.colors
mouse_colors = {mouse: color_palette[i % len(color_palette)] for i, mouse in enumerate(selected_mice)}

# Time windows (s) compared in the pre/post plot: pre is [start, stop), post is [start, stop]
pre_time = (-2, 0)
post_time = (0, 2)


def _numeric_trace(mean_series, sem_series):
    """Return (time, mean, sem) as numeric data with non-numeric / NaN samples removed."""
    mean_num = pd.to_numeric(mean_series, errors='coerce')
    sem_num = pd.to_numeric(sem_series, errors='coerce')
    time_num = pd.to_numeric(mean_num.index, errors='coerce')
    keep = ~(mean_num.isna().to_numpy() | sem_num.isna().to_numpy() | pd.isna(time_num))
    return time_num[keep], mean_num[keep], sem_num[keep]


def _window_mean(series, window, include_stop):
    """Mean of `series` over the index range given by `window`; the stop is optional."""
    start, stop = window
    below_stop = series.index <= stop if include_stop else series.index < stop
    selected = series.loc[(series.index >= start) & below_stop]
    return pd.to_numeric(selected, errors='coerce').mean()


if not generate_new_plots:
    # Honour the configuration flag: nothing below is executed.
    print("Skipping plot generation as per user configuration.")
else:
    per_mouse_means = results_to_plot['mean_data_per_mouse']
    per_mouse_sems = results_to_plot['sem_data_per_mouse']
    grand_averages = results_to_plot['grand_averages']
    grand_sems = results_to_plot['grand_sems']

    # All figures go to the "baselined" sub-directory of the analysed data directory.
    baselined_dir = main_data_dir / "baselined"
    baselined_dir.mkdir(exist_ok=True)

    for column_to_plot in columns_to_plot:
        if column_to_plot not in grand_averages.columns:
            print(f"⚠️ Column {column_to_plot} not found in the grand averages, skipping.")
            continue

        # Only mice that are part of the results and actually have this column.
        plotted_mice = []
        for mouse in selected_mice:
            if mouse not in per_mouse_means:
                continue
            if column_to_plot not in per_mouse_means[mouse].columns:
                print(f"⚠️ Column {column_to_plot} missing for {mouse}, mouse not plotted.")
                continue
            plotted_mice.append(mouse)

        #---------------------------
        # TIME COURSE: mean +/- SEM per mouse and grand average
        fig = plt.figure(figsize=(8, 4))

        for mouse in plotted_mice:
            time_points_clean, mean_data_clean, sem_data_clean = _numeric_trace(
                per_mouse_means[mouse][column_to_plot],
                per_mouse_sems[mouse][column_to_plot],
            )
            plt.plot(time_points_clean, mean_data_clean, label=f'{mouse} Mean', color=mouse_colors[mouse])
            plt.fill_between(time_points_clean, mean_data_clean - sem_data_clean,
                             mean_data_clean + sem_data_clean, color=mouse_colors[mouse], alpha=0.2)

        time_points_clean, grand_mean_clean, grand_sem_clean = _numeric_trace(
            grand_averages[column_to_plot], grand_sems[column_to_plot]
        )
        plt.plot(time_points_clean, grand_mean_clean, label='Grand Average', color='black', linewidth=2)
        plt.fill_between(time_points_clean, grand_mean_clean - grand_sem_clean,
                         grand_mean_clean + grand_sem_clean, color='gray', alpha=0.3)

        plt.xlabel('Time (s)')
        plt.ylabel(column_to_plot)
        plt.title(f'Mean and SEM of {column_to_plot} Over Time')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()

        plot_filename = baselined_dir / f"{column_to_plot}_plot_nohalt.pdf"
        plt.savefig(plot_filename, dpi=150, bbox_inches='tight')
        print(f"Plot saved to: {plot_filename}")

        plt.show()
        plt.close(fig)  # release the figure; one is created per column

        #---------------------------
        # PRE/POST PLOT: window means per mouse, connected by a line
        pre_values = [
            _window_mean(per_mouse_means[mouse][column_to_plot], pre_time, include_stop=False)
            for mouse in plotted_mice
        ]
        post_values = [
            _window_mean(per_mouse_means[mouse][column_to_plot], post_time, include_stop=True)
            for mouse in plotted_mice
        ]

        grand_mean = grand_averages[column_to_plot]
        pre_grand_mean = _window_mean(grand_mean, pre_time, include_stop=False)
        post_grand_mean = _window_mean(grand_mean, post_time, include_stop=True)

        # NOTE(review): the error bar is the time-average of the per-time-point SEM,
        # not the SEM of the per-mouse window means. Confirm this is intended.
        grand_sem = grand_sems[column_to_plot]
        pre_grand_sem = _window_mean(grand_sem, pre_time, include_stop=False)
        post_grand_sem = _window_mean(grand_sem, post_time, include_stop=True)

        fig = plt.figure(figsize=(3, 4))

        for mouse, pre_value, post_value in zip(plotted_mice, pre_values, post_values):
            plt.plot([1, 2], [pre_value, post_value], color=mouse_colors[mouse],
                     marker='o', linewidth=1, label=mouse)

        # Grand average as large black dots with SEM error bars
        plt.plot([1, 2], [pre_grand_mean, post_grand_mean], color='black', marker='o',
                 markersize=8, linewidth=1, label='Grand Avg')
        plt.errorbar([1, 2], [pre_grand_mean, post_grand_mean], yerr=[pre_grand_sem, post_grand_sem],
                     fmt='o', color='black', capsize=5)

        plt.xticks([1, 2], [str(pre_time), str(post_time)])
        plt.title(f'Mean {column_to_plot} Before and After Time 0')
        plt.ylabel(column_to_plot)
        plt.xlim(0.8, 2.2)  # tighter x-axis
        plt.grid(True)

        # Legend: one entry per label, keeping the first handle seen for each
        handles, labels = plt.gca().get_legend_handles_labels()
        unique_labels = {}
        for handle, label in zip(handles, labels):
            unique_labels.setdefault(label, handle)
        plt.legend(unique_labels.values(), unique_labels.keys(), loc='best', fontsize='small')
        plt.tight_layout()

        plot_filename = baselined_dir / f"{column_to_plot}_scatterplot_nohalt.pdf"
        plt.savefig(plot_filename, dpi=150, bbox_inches='tight')
        print(f"Plot saved to: {plot_filename}")

        plt.show()
        plt.close(fig)

In [None]:
# FIXME: ideally only one file should store results and grand averages; they are
# currently spread across pickle and CSV files.

# Save a CSV with grand averages and SEMs
if save_grand_avg_with_sem:
    # Export the results computed in this notebook run. The previously referenced
    # `results_cohort3_vmm3` is never defined here and raised a NameError.
    results_to_export = results_cohort1_vmm4_nohalt

    # Grand averages followed by one '<column>_SEM' column per analysed column
    grand_avg_with_sem = results_to_export['grand_averages'].join(
        results_to_export['grand_sems'].add_suffix('_SEM')
    )

    # NOTE(review): the file name lists `selected_mice`, while the grand average is
    # computed from every mouse in `loaded_data`. Confirm the two sets match.
    mice_str = "_".join(selected_mice)
    csv_filename = main_data_dir / f'grand_averages_with_sem_{mice_str}.csv'

    grand_avg_with_sem.to_csv(csv_filename)
    print(f"Grand averages with SEM saved to: {csv_filename}")

In [None]:
# Read a pickled results object back from disk.
def load_results(filename):
    """Return the object stored in the pickle file ``filename``.

    Unpickling can execute arbitrary code, so only load files you created yourself.
    """
    with open(filename, mode='rb') as pickle_file:
        return pickle.Unpickler(pickle_file).load()

# Load the two pickled result sets that are compared below.
# Note: `results_cohort2` holds the cohort 3 file.
results_cohort1 = load_results(
    filename='/Users/nora/Documents/GitHub/vestibular_vr_pipeline/results_cohort1_vmm4_nohalt.pkl'
)
results_cohort2 = load_results(
    filename='/Users/nora/Documents/GitHub/vestibular_vr_pipeline/results_cohort3_vmm4_nohalt.pkl'
)

In [None]:
#FIXME: IN EXTRACT MEANS, I HAVE TO MANUALLY SPECIFY THE COLUMNS TO EXTRACT, THIS IS NOT IDEAL

#compute correlations between Velocity_0X and z_470, z_560 for each cohort
from scipy.stats import pearsonr, norm

def fisher_z(r):
    """Fisher z-transform of a correlation coefficient: ``arctanh(r)``.

    Accepts scalars or array-likes. ``r = +/-1`` maps to ``+/-inf`` (numpy may emit a
    RuntimeWarning) instead of raising for plain Python floats.
    """
    return np.arctanh(r)


def compare_correlations(r1, n1, r2, n2):
    """Test whether two independent Pearson correlations differ (Fisher z-test).

    Parameters:
    r1, r2 (float): correlation coefficients of the two groups
    n1, n2 (int): number of samples each coefficient is based on (each must be > 3)

    Returns:
    tuple: (z, p) with the z statistic and the two-sided p-value

    Raises:
    ValueError: if either sample size is 3 or smaller, because the standard error
        ``sqrt(1/(n1-3) + 1/(n2-3))`` is undefined there.
    """
    if n1 <= 3 or n2 <= 3:
        raise ValueError(
            f"compare_correlations needs more than 3 samples per group (got n1={n1}, n2={n2})"
        )
    z1 = fisher_z(r1)
    z2 = fisher_z(r2)
    se = np.sqrt(1 / (n1 - 3) + 1 / (n2 - 3))
    z = (z1 - z2) / se
    # The survival function keeps precision in the tail, where 1 - cdf rounds to 0.
    p = 2 * norm.sf(abs(z))
    return z, p

# def extract_means(results, mice, time_window, columns=('Velocity_0X_Baseline', 'z_470_Baseline', 'z_560_Baseline')):
def extract_means(results, mice, time_window, columns=('Velocity_0X', 'z_470', 'z_560')):
    """Average the first three `columns` of each mouse's mean trace over `time_window`.

    Mice that are absent from ``results['mean_data_per_mouse']``, lack any of
    `columns`, or produce a NaN window mean are left out.

    Returns four parallel lists: the window means of ``columns[0]``, ``columns[1]``
    and ``columns[2]``, and the names of the mice that were kept.
    """
    window_start, window_stop = time_window
    per_mouse = results['mean_data_per_mouse']

    v_means = []
    z470_means = []
    z560_means = []
    valid_mice = []

    for mouse in mice:
        if mouse not in per_mouse:
            continue
        frame = per_mouse[mouse]
        if any(col not in frame.columns for col in columns):
            continue

        inside = (frame.index >= window_start) & (frame.index <= window_stop)
        windowed = frame.loc[inside]
        window_means = [windowed[columns[position]].mean() for position in range(3)]

        if pd.isnull(window_means).any():
            continue

        velocity, signal_470, signal_560 = window_means
        v_means.append(velocity)
        z470_means.append(signal_470)
        z560_means.append(signal_560)
        valid_mice.append(mouse)

    return v_means, z470_means, z560_means, valid_mice

In [None]:
#analyse correlations between Velocity_0X and z_470, z_560 for each cohort
def analyze_correlations(
    cohort1, mice1,
    cohort2, mice2,
    time_window=(0, 2),
    plot=True
):
    """Correlate window-averaged velocity with z_470 / z_560 across mice, per cohort.

    For each cohort the per-mouse means over `time_window` are taken from
    ``extract_means``. Pearson correlations between velocity and each signal are
    printed, the two cohorts are compared with ``compare_correlations`` when both
    have more than 3 mice, and optionally a two-panel scatter plot with linear fits
    is drawn.

    Parameters:
    cohort1, cohort2 (dict): results dicts containing 'mean_data_per_mouse'
    mice1, mice2 (list): mouse names to use from each cohort
    time_window (tuple): (start, stop) in seconds, both inclusive
    plot (bool): draw the scatter plots when True

    Returns:
    None. Results are printed and plotted.
    """

    def _pearson_or_nan(x, y):
        """Pearson (r, p), or (nan, nan) when fewer than 2 paired samples exist."""
        if len(x) < 2:
            return float('nan'), float('nan')
        r_value, p_value = pearsonr(x, y)
        return r_value, p_value

    # Extract means
    v1, z470_1, z560_1, ids1 = extract_means(cohort1, mice1, time_window)
    v2, z470_2, z560_2, ids2 = extract_means(cohort2, mice2, time_window)

    for cohort_label, velocities in (("Cohort 1", v1), ("Cohort 2", v2)):
        if len(velocities) < 2:
            print(f"⚠️ {cohort_label}: fewer than 2 valid mice, correlation is undefined.")

    # Compute correlations
    corr1_470, p1_470 = _pearson_or_nan(v1, z470_1)
    corr1_560, p1_560 = _pearson_or_nan(v1, z560_1)
    corr2_470, p2_470 = _pearson_or_nan(v2, z470_2)
    corr2_560, p2_560 = _pearson_or_nan(v2, z560_2)

    print("\n📊 Correlations:")
    print(f"Cohort 1: Velocity ~ z_470: r = {corr1_470:.3f}, p = {p1_470:.3f}")
    print(f"Cohort 1: Velocity ~ z_560: r = {corr1_560:.3f}, p = {p1_560:.3f}")
    print(f"Cohort 2: Velocity ~ z_470: r = {corr2_470:.3f}, p = {p2_470:.3f}")
    print(f"Cohort 2: Velocity ~ z_560: r = {corr2_560:.3f}, p = {p2_560:.3f}")

    # Compare correlation coefficients
    if len(v1) > 3 and len(v2) > 3:
        z_470, p_470 = compare_correlations(corr1_470, len(v1), corr2_470, len(v2))
        z_560, p_560 = compare_correlations(corr1_560, len(v1), corr2_560, len(v2))

        print("\n🔍 Comparison of correlations:")
        print(f"z_470: z = {z_470:.3f}, p = {p_470:.3f}")
        print(f"z_560: z = {z_560:.3f}, p = {p_560:.3f}")
    else:
        print("⚠️ Not enough data (need >3 samples per group) to compare correlation coefficients.")

    # Optional: Plot
    if plot:
        fig, axs = plt.subplots(1, 2, figsize=(8, 4))

        # One panel per signal: (axis, signal name, cohort 1 values, cohort 2 values, colours)
        panels = (
            (axs[0], 'z_470', z470_1, z470_2, 'green', 'orange'),
            (axs[1], 'z_560', z560_1, z560_2, 'darkred', 'red'),
        )
        for ax, signal_name, signal_1, signal_2, color_1, color_2 in panels:
            cohorts = (
                ('Cohort 1', v1, signal_1, ids1, color_1),
                ('Cohort 2', v2, signal_2, ids2, color_2),
            )
            for cohort_label, velocities, signal, mouse_ids, color in cohorts:
                ax.scatter(velocities, signal, color=color, label=cohort_label)

                # Mouse name next to each dot
                for x_value, y_value, mouse in zip(velocities, signal, mouse_ids):
                    ax.text(x_value, y_value, mouse, fontsize=8, color=color, alpha=0.7)

                # Linear fit, only defined for at least two points
                if len(velocities) > 1:
                    slope, intercept = np.polyfit(velocities, signal, 1)
                    ax.plot(velocities, slope * np.array(velocities) + intercept,
                            color=color, linestyle='--', label=f'{cohort_label} Fit')

            ax.set_title(f'Velocity vs {signal_name}')
            ax.set_xlabel('Mean Velocity_0X (m/s)')
            ax.set_ylabel('Mean z-score')
            ax.grid(True)
            # Build the legend last so the fit lines are listed as well.
            ax.legend()

        plt.tight_layout()
        plt.show()


In [None]:
# Compare velocity/signal correlations between the two loaded result sets, using the
# per-cohort mouse lists, averaged over the 0-2 s window.
analyze_correlations(
    cohort1=results_cohort1,
    mice1=selected_mice1,
    cohort2=results_cohort2,
    mice2=selected_mice2,
    time_window=(0, 2),
    plot=True
)

In [None]:
# Grand-average CSV files (with '<column>_SEM' columns) for the session 'Visual_mismatch_day3'.
# Replace these with the actual paths to your files.
vmm3_cohort1 = '/Volumes/RanczLab2/20241125_Cohort1_rotation/Visual_mismatch_day3/grand_averages_with_sem_B6J2717_B6J2718_B6J2719_B6J2720_B6J2721_B6J2722.csv'
vmm3_cohort3 = '/Volumes/RanczLab2/20250409_Cohort3_rotation/Visual_mismatch_day3/grand_averages_with_sem_B6J2780_B6J2781_B6J2783_B6J2782.csv'
# vmm4_cohort1 = '/Volumes/RanczLab2/20241125_Cohort1_rotation/Visual_mismatch_day4/grand_averages_with_sem_B6J2717_B6J2718_B6J2719_B6J2720_B6J2721_B6J2722.csv'
# vmm4_cohort3 = '/Volumes/RanczLab2/20250409_Cohort3_rotation/Visual_mismatch_day4/grand_averages_with_sem_B6J2780_B6J2781_B6J2783.csv'
# ol1_cohort3 = '/Volumes/RanczLab2/20250409_Cohort3_rotation/Open_loop_day1/grand_averages_with_sem_B6J2780_B6J2781_B6J2783.csv'

# Column to compare between cohorts (e.g. 'z_470_Baseline' to plot the z-score
# instead); its SEM is read from '<column>_SEM'.
signal_column = 'dfF_470_Baseline'
sem_column = f'{signal_column}_SEM'

# Load the data
df_cohort1 = pd.read_csv(vmm3_cohort1)
df_cohort3 = pd.read_csv(vmm3_cohort3)

# Fail early with a readable message if a CSV lacks the requested columns.
for csv_path, cohort_df in ((vmm3_cohort1, df_cohort1), (vmm3_cohort3, df_cohort3)):
    missing_columns = [
        col for col in ('Time (s)', signal_column, sem_column) if col not in cohort_df.columns
    ]
    if missing_columns:
        raise KeyError(f"{csv_path} is missing columns: {missing_columns}")

# Grand average +/- SEM of the selected signal for both cohorts
plt.figure(figsize=(10, 6))
plt.minorticks_on()  # Enable minor ticks
plt.grid(which='both', linestyle='--', linewidth=0.5, alpha=0.7)  # Grid for major and minor ticks

cohort_traces = (
    (df_cohort1, 'GRAB5HT3.0', 'green'),
    (df_cohort3, 'mut-GRAB5HT3.0', 'orange'),
)
for cohort_df, cohort_label, cohort_color in cohort_traces:
    plt.plot(cohort_df['Time (s)'], cohort_df[signal_column], label=cohort_label, color=cohort_color)
    plt.fill_between(cohort_df['Time (s)'],
                     cohort_df[signal_column] - cohort_df[sem_column],
                     cohort_df[signal_column] + cohort_df[sem_column],
                     color=cohort_color, alpha=0.2)

# Add a gray shadowed area between seconds 0 and 2
plt.axvspan(0, 2, color='gray', alpha=0.2, label='visual_mismatch')

# Plot customization
# plt.title('OLday1', fontname='Arial', fontsize=10)
plt.xlabel('Time (s)', fontname='Arial', fontsize=10)
plt.ylabel('dfF', fontname='Arial', fontsize=10)
plt.legend(prop={'family': 'Arial', 'size': 10})
# Apply the layout before saving so the stored PDF matches what is shown.
plt.tight_layout()

# Save the plot to the GitHub directory
github_dir = Path.home() / "GitHub" / "plots"
github_dir.mkdir(parents=True, exist_ok=True)
plot_filename_github = github_dir / "mmday3dff470.pdf"
plt.savefig(plot_filename_github, format='pdf', dpi=300, bbox_inches='tight')
print(f"Plot saved to: {plot_filename_github}")

# Show the plot
plt.show()
