In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../../')

from main import load_and_prepare_sessions
from processing.session_sampling import MiceAnalysis
import config
from collections import defaultdict
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt

sessions = load_and_prepare_sessions("../../../Baseline", load_from_pickle=True, remove_bad_signal_sessions=True)
mouse_analyser = MiceAnalysis(sessions)

FileNotFoundError: [Errno 2] No such file or directory: '../Gq-DREADD_CPT_Training_Stages/Stage1/Post'

In [35]:
def pair_wise_idxs(raw_df, actions_attr_dict):
    # Define events of interest including 'Display Image'
    events_of_interest = list(actions_attr_dict.keys())

    # Get raw data and filter for events of interest
    item_df = raw_df[["Item_Name"]]
    filtered_df = item_df[item_df["Item_Name"].isin(events_of_interest)].reset_index()

    # Initialize defaultdict to store the pairs
    event_pairs_dict = defaultdict(list)

    # Iterate through the filtered_df row by row
    for i in range(len(filtered_df) - 1):
        current_event = filtered_df.at[i, "Item_Name"]
        next_event = filtered_df.at[i + 1, "Item_Name"]
        if current_event in events_of_interest and next_event in events_of_interest:
            event_pairs_dict[(current_event, next_event)]\
                .append((filtered_df.at[i, "index"], filtered_df.at[i + 1, "index"]))

    return event_pairs_dict

In [36]:
def get_session_pair_signals(session, brain_region):
    signal_order = []
    all_session_signals = []

    for event in (['hit', 'miss', 'mistake', 'cor_reject']):
        curr_signal_info = session.signal_info.get((brain_region, event))
        if curr_signal_info is None:
            continue
        curr_ranges = curr_signal_info['signal_idx_ranges']

        len_before = sum(len(arr) for arr in all_session_signals)
        all_session_signals.append(curr_signal_info['signal_matrix'])
        len_after = sum(len(arr) for arr in all_session_signals)

        event_list = [event] * len(curr_ranges)
        idx_list = range(len_before, len_after)
        signal_order.extend(zip(curr_ranges, event_list, idx_list))

    _, events, matrix_idx = zip(*sorted(signal_order, key=lambda x: x[0][0]))
    signal_matrix = np.vstack(all_session_signals)[matrix_idx, :]

    pair_dict = defaultdict(list)

    for i in range(len(events) - 1):
        curr_event = events[i]
        next_event = events[i + 1]
        #if next_event == 'hit':
        pair_dict[(curr_event, next_event, brain_region)].append(signal_matrix[i+1])

    for k, v in pair_dict.items():
        pair_dict[k] = np.vstack(v)

    return pair_dict

In [37]:
pair_dict = defaultdict(list)

for session in sessions:
    for brain_region in config.all_brain_regions:
        for event in (['hit', 'miss', 'mistake', 'cor_reject']):
            if session.signal_info.get((brain_region, event)) is None:
                    continue

            pair_dict_part = get_session_pair_signals(session, brain_region)
            if pair_dict_part is None:
                continue
            for k, v in pair_dict_part.items():
                pair_dict[k].append(v)

for k, v in pair_dict.items():
    pair_dict[k] = np.vstack(v)

In [38]:
for k, v in pair_dict.items():
    print(k, len(v))

('miss', 'cor_reject', 'DMS') 3320
('cor_reject', 'miss', 'DMS') 3524
('cor_reject', 'hit', 'DMS') 2350
('hit', 'cor_reject', 'DMS') 2147
('cor_reject', 'cor_reject', 'DMS') 9445
('cor_reject', 'mistake', 'DMS') 952
('mistake', 'cor_reject', 'DMS') 1348
('hit', 'miss', 'DMS') 505
('miss', 'miss', 'DMS') 978
('miss', 'hit', 'DMS') 438
('miss', 'mistake', 'DMS') 280
('hit', 'hit', 'DMS') 372
('hit', 'mistake', 'DMS') 120
('miss', 'cor_reject', 'DLS') 3456
('cor_reject', 'miss', 'DLS') 3672
('cor_reject', 'hit', 'DLS') 1764
('hit', 'cor_reject', 'DLS') 1628
('cor_reject', 'cor_reject', 'DLS') 9112
('cor_reject', 'mistake', 'DLS') 632
('mistake', 'cor_reject', 'DLS') 984
('hit', 'miss', 'DLS') 460
('miss', 'miss', 'DLS') 1048
('miss', 'hit', 'DLS') 424
('miss', 'mistake', 'DLS') 272
('hit', 'hit', 'DLS') 204
('hit', 'mistake', 'DLS') 80
('mistake', 'mistake', 'DLS') 268
('mistake', 'mistake', 'DMS') 364
('miss', 'cor_reject', 'VS') 2786
('cor_reject', 'miss', 'VS') 2917
('cor_reject', 'cor

In [39]:
brain_reg_to_color = {'VS': 'purple',
                      'DMS': 'forestgreen',
                      'DLS': 'C0'}

In [40]:
def preprocess_and_plot_signals(pair_dict, event_type, brain_region, smoothing_len=10):
    # Assuming all_signals is predefined
    event1, event2 = event_type
    signals = pair_dict[event1, event2, brain_region]

    interval_start = config.peak_interval_config["interval_start"]
    interval_end = config.peak_interval_config["interval_end"]
    fps = config.PLOTTING_CONFIG['fps']
    
    xs = np.arange(-interval_start, interval_end) / fps
    
    # Smooth the mean signal
    ys = np.mean(signals, axis=0)
    window = np.ones(smoothing_len) / smoothing_len
    ys = np.convolve(ys, window, 'same')

    # Calculate the standard deviation of the mean
    std_signal = np.std(signals, axis=0) / np.sqrt(len(signals))

    # Use scipy.stats.norm.interval to get the 95% confidence interval
    alpha = 0.95
    ci_lower, ci_upper = stats.norm.interval(alpha, loc=ys, scale=std_signal)

    # The lower and upper bounds
    lb = ci_lower.min()
    ub = ci_upper.max()

    ylim = (lb, ub)
    
    # Assuming brain_reg_to_color is predefined
    color = brain_reg_to_color[brain_region]

    plt.figure(dpi=300)
    plt.plot(xs[100: -100], ys[100: -100], color=color, label='Mean Signal')
    plt.fill_between(xs[100: -100], ci_lower[100: -100], ci_upper[100: -100], color=color, alpha=0.2, label='95% CI')
    plt.ylim(ylim)
    plt.title(f'{event_type}, {brain_region}, (n = {len(signals)})')
    plt.xlabel('Time (s)')
    plt.ylabel('z-score')
    plt.legend()
    plt.grid()
    plt.savefig(f'{event1}_to_{event2}_{brain_region}.png')
    plt.show()

In [41]:
pair_dict.keys()

dict_keys([('miss', 'cor_reject', 'DMS'), ('cor_reject', 'miss', 'DMS'), ('cor_reject', 'hit', 'DMS'), ('hit', 'cor_reject', 'DMS'), ('cor_reject', 'cor_reject', 'DMS'), ('cor_reject', 'mistake', 'DMS'), ('mistake', 'cor_reject', 'DMS'), ('hit', 'miss', 'DMS'), ('miss', 'miss', 'DMS'), ('miss', 'hit', 'DMS'), ('miss', 'mistake', 'DMS'), ('hit', 'hit', 'DMS'), ('hit', 'mistake', 'DMS'), ('miss', 'cor_reject', 'DLS'), ('cor_reject', 'miss', 'DLS'), ('cor_reject', 'hit', 'DLS'), ('hit', 'cor_reject', 'DLS'), ('cor_reject', 'cor_reject', 'DLS'), ('cor_reject', 'mistake', 'DLS'), ('mistake', 'cor_reject', 'DLS'), ('hit', 'miss', 'DLS'), ('miss', 'miss', 'DLS'), ('miss', 'hit', 'DLS'), ('miss', 'mistake', 'DLS'), ('hit', 'hit', 'DLS'), ('hit', 'mistake', 'DLS'), ('mistake', 'mistake', 'DLS'), ('mistake', 'mistake', 'DMS'), ('miss', 'cor_reject', 'VS'), ('cor_reject', 'miss', 'VS'), ('cor_reject', 'cor_reject', 'VS'), ('miss', 'miss', 'VS'), ('cor_reject', 'hit', 'VS'), ('hit', 'cor_reject', 

In [42]:
pair_dict.keys()

dict_keys([('miss', 'cor_reject', 'DMS'), ('cor_reject', 'miss', 'DMS'), ('cor_reject', 'hit', 'DMS'), ('hit', 'cor_reject', 'DMS'), ('cor_reject', 'cor_reject', 'DMS'), ('cor_reject', 'mistake', 'DMS'), ('mistake', 'cor_reject', 'DMS'), ('hit', 'miss', 'DMS'), ('miss', 'miss', 'DMS'), ('miss', 'hit', 'DMS'), ('miss', 'mistake', 'DMS'), ('hit', 'hit', 'DMS'), ('hit', 'mistake', 'DMS'), ('miss', 'cor_reject', 'DLS'), ('cor_reject', 'miss', 'DLS'), ('cor_reject', 'hit', 'DLS'), ('hit', 'cor_reject', 'DLS'), ('cor_reject', 'cor_reject', 'DLS'), ('cor_reject', 'mistake', 'DLS'), ('mistake', 'cor_reject', 'DLS'), ('hit', 'miss', 'DLS'), ('miss', 'miss', 'DLS'), ('miss', 'hit', 'DLS'), ('miss', 'mistake', 'DLS'), ('hit', 'hit', 'DLS'), ('hit', 'mistake', 'DLS'), ('mistake', 'mistake', 'DLS'), ('mistake', 'mistake', 'DMS'), ('miss', 'cor_reject', 'VS'), ('cor_reject', 'miss', 'VS'), ('cor_reject', 'cor_reject', 'VS'), ('miss', 'miss', 'VS'), ('cor_reject', 'hit', 'VS'), ('hit', 'cor_reject', 

In [None]:
for k in sorted(pair_dict.keys(), key=lambda x: x[-2]):
    event1, event2, br = k
    event = (event1, event2)
    preprocess_and_plot_signals(pair_dict, event, br)