In [6]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../../')

from main import load_and_prepare_sessions
from processing.session_sampling import MiceAnalysis
from analysis.timepoint_analysis import find_drug_split_x
from itertools import product

import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt

import config

from collections import defaultdict
import pickle



# Load all sessions for this experiment and wrap them in the per-mouse analyser.
# load_from_pickle=True / remove_bad_signal_sessions=True are forwarded unchanged to
# load_and_prepare_sessions (presumably: reuse a cached pickle, drop flagged sessions).
# NOTE(review): the saved output of this cell shows a FileNotFoundError for
# '../Dual_Sensor_CPT/Males_redo', a path that is not passed here -- presumably resolved
# inside load_and_prepare_sessions / config; check before re-running.
DATA_DIR = "../../../Gq-DREADD-Projection-Spec_CPT"
sessions = load_and_prepare_sessions(
    DATA_DIR,
    load_from_pickle=True,
    remove_bad_signal_sessions=True,
)
mouse_analyser = MiceAnalysis(sessions)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


FileNotFoundError: [Errno 2] No such file or directory: '../Dual_Sensor_CPT/Males_redo'

In [15]:
# Line/fill colour for each brain region. Keys are looked up with the brain_region element
# of each signal-group key when plotting, so every plotted region needs an entry here.
# Alternative mapping used for another region set: VS='purple', DMS='forestgreen', DLS='C0'.
brain_reg_to_color = dict(
    LH='orange',
    mPFC='cornflowerblue',
)

In [16]:
# For every session / recorded region / event, mark each event-locked trace as lying
# after (True) or before-or-at (False) the drug split index from find_drug_split_x.
# The boolean mask is stored in-place as signal_info[(region, event)]["is_above_threshold"].
for session in sessions:
    for full_region in session.brain_regions:
        # Index within the trial after which the injection peak has been passed
        # (computed from the full region name, before the suffix is stripped).
        split_x = find_drug_split_x(session, full_region)

        # signal_info is keyed by the part of the region name before the first '_'.
        region = full_region.split('_')[0]

        for event in config.all_event_types:
            info = session.signal_info.get((region, event))
            if info is None:
                continue

            # Midpoint (integer floor) of each trace's index range, i.e. its 0-second sample.
            centres = np.array([(bounds[0] + bounds[1]) // 2 for bounds in info['signal_idx_ranges']])

            # A trace counts as "above threshold" when its 0-second sample lies past split_x.
            info["is_above_threshold"] = centres > split_x

In [17]:
# Stack the signal matrices of all sessions into one array per group key
# (genotype, dose, event, brain_region, is_above_threshold).
# - Dosed TH-Cre sessions are split on the "is_above_threshold" mask written by the
#   threshold-tagging cell: matching rows go to ...+(True,), the rest to ...+(False,).
# - Every other session contributes its whole matrix to ...+(None,).
signal_chunks = defaultdict(list)

for session in sessions:
    genotype = session.genotype
    dose = session.drug_info['dose']
    split_by_threshold = dose is not None and genotype == 'TH-Cre'

    for event in config.all_event_types:
        for brain_region in config.all_brain_regions:
            curr_signal_info = session.signal_info.get((brain_region, event))
            if curr_signal_info is None:
                continue

            curr_key = (genotype, dose, event, brain_region)
            signal_matrix = curr_signal_info['signal_matrix']

            if split_by_threshold:
                if "is_above_threshold" not in curr_signal_info:
                    # Hidden-state guard: the mask only exists after the tagging cell has run.
                    raise KeyError(
                        f"{curr_key}: 'is_above_threshold' is not set; "
                        "run the threshold-tagging cell before grouping"
                    )
                is_above_threshold = curr_signal_info["is_above_threshold"]
                signal_chunks[curr_key + (True, )].append(signal_matrix[is_above_threshold])
                signal_chunks[curr_key + (False, )].append(signal_matrix[~is_above_threshold])
            else:
                signal_chunks[curr_key + (None, )].append(signal_matrix)

# Lists are never empty: an entry is only created by the .append above. A boolean mask can
# select zero rows, though, so a *stacked* group may be empty. all_signal_groups keeps every
# group; all_signals (the one that is plotted) drops the empty ones.
all_signal_groups = {k: np.vstack(v) for k, v in signal_chunks.items()}
all_signals = {k: v for k, v in all_signal_groups.items() if len(v) != 0}

In [None]:
def preprocess_and_plot_signals(key, dict, smoothing_len=10):
    """Plot the smoothed trial-mean and a 95% confidence band for one signal group, then save it.

    Parameters
    ----------
    key : tuple
        ``(genotype, dose, event, brain_region, is_above_threshold)``. Used to look up the
        group in ``dict`` and to build the title and output filename.
        ``is_above_threshold`` is expected to be ``True``, ``False`` or ``None`` (not split).
    dict : mapping
        Maps ``key`` to the group's stacked trials. Averaging is over axis 0, so this is
        presumably shaped ``(n_trials, n_samples)`` with ``n_samples`` matching the configured
        interval -- TODO confirm. (The name shadows the builtin but is kept for existing callers.)
    smoothing_len : int, default 10
        Moving-average window, in samples, applied to the mean trace. The window is clipped
        to the trace length; values <= 1 disable smoothing.

    Side effects
    ------------
    Saves ``{genotype}_{dose}_{event}_{brain_region}_{threshold_text}_Projection-Spec.png``
    in the current working directory, shows the figure, then closes it.
    """
    signals = np.asarray(dict[key])
    genotype, dose, event, brain_region, is_above_threshold = key

    # Time axis in seconds; x = 0 falls at sample index `interval_start`.
    interval_start = config.peak_interval_config["interval_start"]
    interval_end = config.peak_interval_config["interval_end"]
    fps = config.PLOTTING_CONFIG['fps']
    xs = np.arange(-interval_start, interval_end) / fps

    # Trial-mean trace, optionally smoothed with a moving average.
    ys = signals.mean(axis=0)
    window_len = min(int(smoothing_len), ys.size)
    if window_len > 1:
        window = np.ones(window_len) / window_len
        # 'same' mode zero-pads, which pulled the first/last samples toward 0. Dividing by the
        # convolved ones averages each edge sample over in-range samples only; interior samples
        # are unchanged. Clipping window_len keeps the output exactly as long as ys.
        coverage = np.convolve(np.ones_like(ys), window, 'same')
        ys = np.convolve(ys, window, 'same') / coverage

    # Standard error of the mean at each time point (np.std default, ddof=0).
    sem = signals.std(axis=0) / np.sqrt(len(signals))

    # Normal-approximation CI: mean +/- z * SEM. Computed directly rather than through
    # stats.norm.interval(..., scale=sem), which returns NaN wherever sem == 0 (e.g. a
    # single-trial group) and then made the y-limit call raise.
    confidence = 0.95
    z = stats.norm.ppf(0.5 + confidence / 2)
    ci_lower = ys - z * sem
    ci_upper = ys + z * sem

    color = brain_reg_to_color[brain_region]
    threshold_text = {True: 'above threshold,', False: 'below threshold,'}.get(is_above_threshold, '')

    fig, ax = plt.subplots(dpi=300)
    ax.plot(xs, ys, color=color, label='Mean Signal')
    ax.fill_between(xs, ci_lower, ci_upper, color=color, alpha=0.2, label='95% CI')
    # Fit the y-axis tightly to the CI band (NaN-aware in case the input contains NaNs).
    ax.set_ylim(np.nanmin(ci_lower), np.nanmax(ci_upper))

    # Explicit newlines: the previous triple-quoted title embedded source indentation.
    ax.set_title(
        f"genotype: {genotype}, dose: {dose},\n"
        f"event: {event}, brain region: {brain_region}, {threshold_text}\n"
        f"(n={len(signals)})"
    )
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('z-score')
    ax.legend()
    ax.grid()

    fig.tight_layout()
    # NOTE: threshold_text contains a space and a trailing comma, so those characters also end
    # up in the filename; left unchanged so existing output names stay stable.
    fig.savefig(f"{genotype}_{dose}_{event}_{brain_region}_{threshold_text}_Projection-Spec.png")
    plt.show()
    # Release the figure explicitly; this function is called once per group in a loop.
    plt.close(fig)

# Render and save one figure per non-empty signal group
# (needs all_signals, config and brain_reg_to_color from the cells above).
for group_key in all_signals:
    preprocess_and_plot_signals(group_key, all_signals, smoothing_len=10)

In [19]:
# import pickle

# with open('all_signals_2.pickle', 'wb') as file:
#     # Pickle the dictionary and write it to the file
#     pickle.dump(all_signals, file)