In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr
import h5py
from scripts.sample_db import SampleDB

def compute_correlation_matrix(traces, window_size):
    """Return the mean sliding-window Pearson correlation between every pair of traces.

    For every window start ``k`` in ``0 .. trace_length - window_size`` the
    Pearson correlation between ``traces[i, k:k+window_size]`` and
    ``traces[j, k:k+window_size]`` is computed. Entry ``(i, j)`` of the result
    is the mean of those per-window correlations.

    Windows in which either trace is exactly constant (or contains NaN) have
    no defined correlation; ``scipy.stats.pearsonr`` returns NaN for them.
    Such windows are skipped rather than allowed to turn the whole average
    into NaN. An entry is NaN only if no window is defined for that pair.

    Parameters
    ----------
    traces : array_like, shape (num_traces, trace_length)
        One trace per row.
    window_size : int
        Number of samples per window; must satisfy
        ``2 <= window_size <= trace_length``.

    Returns
    -------
    numpy.ndarray, shape (num_traces, num_traces)
        Symmetric matrix of mean windowed correlations, values in [-1, 1].

    Raises
    ------
    ValueError
        If ``traces`` is not 2-D or ``window_size`` is out of range.
    """
    data = np.asarray(traces, dtype=float)
    if data.ndim != 2:
        raise ValueError(f"traces must be 2-D, got shape {data.shape}")
    num_traces, trace_length = data.shape
    # Pearson correlation needs at least 2 samples, and at least one window must fit.
    if not 2 <= window_size <= trace_length:
        raise ValueError(
            f"window_size must be between 2 and {trace_length}, got {window_size}"
        )
    num_windows = trace_length - window_size + 1

    corr_sum = np.zeros((num_traces, num_traces))
    corr_count = np.zeros((num_traces, num_traces), dtype=int)

    # One vectorised pass per window: correlate all trace pairs at once with a
    # matrix product instead of calling pearsonr for every (i, j, window).
    for start in range(num_windows):
        window = data[:, start:start + window_size]
        centered = window - window.mean(axis=1, keepdims=True)
        norms = np.linalg.norm(centered, axis=1)
        # Exactly-constant rows have no defined correlation (pearsonr -> NaN).
        # Detect them by value equality, since floating-point centering may
        # leave tiny non-zero residues.
        constant = np.all(window == window[:, :1], axis=1)
        with np.errstate(divide='ignore', invalid='ignore'):
            unit = centered / norms[:, None]
        unit[constant] = np.nan
        # Dot products of unit-norm centred rows are the Pearson r values;
        # clip rounding overshoot exactly like pearsonr does.
        window_corr = np.clip(unit @ unit.T, -1.0, 1.0)
        defined = ~np.isnan(window_corr)
        corr_sum[defined] += window_corr[defined]
        corr_count += defined

    # Guarantee exact symmetry regardless of BLAS rounding order.
    corr_sum = 0.5 * (corr_sum + corr_sum.T)
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(corr_count > 0, corr_sum / corr_count, np.nan)

def plot_correlation_matrix(correlation_matrix, odors, trials, title):
    """Show the trace-correlation matrix as a heatmap labelled by odor and trial.

    Parameters
    ----------
    correlation_matrix : array_like, shape (n, n)
        Matrix to plot; the colour scale is fixed to [-1, 1] and centred on 0.
    odors, trials : sequence
        Odor and trial identifier for each row/column, in matrix order.
        ``bytes`` values (as h5py returns for string datasets) are decoded,
        so labels read ``Odor X`` rather than ``Odor b'X'``.
    title : str
        Figure title.

    Returns
    -------
    matplotlib.axes.Axes
        The heatmap axes. The figure is also displayed via ``plt.show()``.
    """
    def _as_text(value):
        # h5py reads HDF5 string data as bytes / numpy.bytes_; decode for display.
        return value.decode() if isinstance(value, bytes) else value

    # Build the "Odor X / Trial Y" labels once and reuse them for both axes.
    labels = [f'Odor {_as_text(o)}\nTrial {_as_text(t)}' for o, t in zip(odors, trials)]

    fig, ax = plt.subplots(figsize=(12, 10))
    sns.heatmap(correlation_matrix, cmap='coolwarm', vmin=-1, vmax=1, center=0,
                square=True, annot=False, cbar=True, ax=ax)

    # Heatmap cells span [i, i+1), so place each tick at the cell centre.
    ticks = np.arange(len(labels)) + 0.5
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(labels, rotation=45, ha='right')
    ax.set_yticklabels(labels, rotation=0)

    ax.set_title(title)
    fig.tight_layout()
    plt.show()
    return ax

# Main script
if __name__ == "__main__":
    # Load the sample database (hard-coded network-share path; adjust for your setup)
    db_path = r'\\tungsten-nas.fmi.ch\tungsten\scratch\gfriedri\montruth\sample_db.csv'
    sample_db = SampleDB()
    sample_db.load(db_path)

    # Loading experiment
    sample_id = '20220427_RM0008_126hpf_fP3_f3'
    exp = sample_db.get_sample(sample_id)
    print(exp.sample.id)

    # Path to the HDF5 file
    hdf5_file_path = f'{exp.paths.trials_path}/traces/{exp.sample.id}_fluorescence_data.h5'

    # Load data from HDF5 file. `[()]` reads each dataset fully into a NumPy
    # array, so the data stays usable after the file is closed.
    with h5py.File(hdf5_file_path, 'r') as f:
        exp_grp = f[exp.sample.id]
        print(exp_grp.keys())
        traces = exp_grp['dff_traces'][()]  # Use dF/F traces
        odors = exp_grp['odor'][()]
        trials = exp_grp['trial_nr'][()]

    # Sort traces by odor, then by trial within each odor
    # (np.lexsort treats the LAST key as the primary sort key).
    sort_indices = np.lexsort((trials, odors))
    traces = traces[sort_indices]
    odors = odors[sort_indices]
    trials = trials[sort_indices]

    # User input for window size. Re-prompt until we get an integer in
    # [2, trace_length]: Pearson correlation needs >= 2 samples, and the
    # window must fit inside a trace. Invalid values would otherwise crash
    # or yield NaNs only after the (slow) correlation step.
    trace_length = traces.shape[1]
    while True:
        raw_window = input("Enter the window size for correlation calculation: ")
        try:
            window_size = int(raw_window)
        except ValueError:
            print(f"{raw_window!r} is not an integer.")
            continue
        if 2 <= window_size <= trace_length:
            break
        print(f"Window size must be between 2 and {trace_length}.")

    # Compute correlation matrix
    correlation_matrix = compute_correlation_matrix(traces, window_size)

    # Plot correlation matrix
    plot_correlation_matrix(correlation_matrix, odors, trials, f"Trace Correlation Matrix (Window Size: {window_size})")

20220427_RM0008_126hpf_fP3_f3
<KeysViewHDF5 ['cell_mapping', 'deconvolved_spikes', 'dff_traces', 'lm_plane_centroids', 'lm_plane_labels', 'odor', 'plane_nr', 'raw_traces', 'trial_nr']>
