In [30]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../../')

from main import load_and_prepare_sessions
from processing.session_sampling import MiceAnalysis
from analysis.timepoint_analysis import find_drug_split_x
from itertools import product

import numpy as np
import scipy.stats as stats 
import matplotlib.pyplot as plt

import config

from collections import defaultdict
import pickle



sessions = load_and_prepare_sessions("../../../trial_Gq-DREADD_CPT", load_from_pickle=True, remove_bad_signal_sessions=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [31]:
# brain_reg_to_color = {'VS': 'purple',
#                       'DMS': 'forestgreen',
#                       'DLS': 'C0'}

brain_reg_to_color = {'LH': 'orange',
                      'mPFC': 'cornflowerblue'}

In [32]:
def update_genotypes(sessions, mice_gen_dict):
    """
    Updates the genotypes for a list of sessions based on the provided genotype mapping.
    Prints whether the genotypes are valid ('TH-Cre', 'Wildtype') 
    and the number of genotype changes made.

    Parameters:
    - sessions: A list of session objects to process.
    - mice_gen_dict: A dictionary mapping mouse IDs to new genotypes.
    """
    geno_mapping = {
        "Cre": "TH-Cre",
        "WT": "Wildtype"
    }
    # Map mice_gen_dict to use TH-Cre and Wildtype
    mapped_genotypes = {k: geno_mapping[v] for k, v in mice_gen_dict.items()}
    
    # Initialize counters and trackers
    valid_genotypes = {'TH-Cre', 'Wildtype'}
    all_genotypes = set()
    genotype_changes = 0

    for session in sessions:
        original_genotype = session.genotype
        int_id = int(session.mouse_id)
        
        if int_id in mapped_genotypes:
            session.genotype = mapped_genotypes[int_id]
            # Count changes if the genotype was updated
            if session.genotype != original_genotype:
                genotype_changes += 1
        
        all_genotypes.add(session.genotype)
    
    # Print results
    if all_genotypes.issubset(valid_genotypes):
        print(f"Valid genotypes found: {all_genotypes}")
    else:
        print(f"Invalid genotypes found: {all_genotypes}")
    
    print(f"Genotype changes made: {genotype_changes}")

In [33]:
mice_gen_dict = {
    69: "Cre",
    71: "WT",
    73: "Cre",
    75: "WT",
    77: "Cre",
    79: "WT",
    85: "WT",
    87: "WT",
    135: "WT",
    137: "WT",
    139: "Cre",
    133: "WT",
    127: "WT",
    125: "WT",
    129: "Cre",
    131: "WT",
    143: "Cre",
    145: "WT",
    147: "WT",
    157: "Cre",
    159: "Cre",
    161: "WT",
    171: "Cre",
    173: "Cre"
}

In [34]:
update_genotypes(sessions, mice_gen_dict)

Valid genotypes found: {'Wildtype', 'TH-Cre'}
Genotype changes made: 2


In [35]:
from collections import defaultdict

# Instead of defaultdict(dict), we need a nested structure keyed by (event, brain_reg, genotype).
# The value remains a dict keyed by (stage, phase).
all_sessions = defaultdict(dict)

stages = ['stage1', 'stage2', 'stage3', 'stage4']
phases = ['pre', 'post']

for stage in stages:
    # Temporary storage for sessions in each phase
    phase_sessions = {}
    
    # Step 1: Load all sessions for the current stage across all phases
    for phase in phases:
        pickle_name = f'sessions_two_{stage}_{phase}'
        
        # Load and prepare sessions
        sessions = load_and_prepare_sessions(
            "../../../Gq-DREADD_CPT_Training_Stages",
            load_from_pickle=True,
            remove_bad_signal_sessions=True,
            pickle_name=pickle_name
        )
        
        # Update genotypes in-place
        update_genotypes(sessions, mice_gen_dict)
        phase_sessions[phase] = sessions
    
    # Step 2: Compute the intersection of mouse IDs across phases
    mouse_ids_pre = set(s.mouse_id for s in phase_sessions['pre'])
    mouse_ids_post = set(s.mouse_id for s in phase_sessions['post'])
    intersection_mouse_ids = mouse_ids_pre.intersection(mouse_ids_post)

    print(f"[INFO] {stage}: {len(intersection_mouse_ids)} common mouse IDs across phases.")
    
    # Step 3: Iterate over each phase and filter sessions by the intersection of mouse IDs
    for phase in phases:
        sessions = phase_sessions[phase]
        filtered_sessions = [s for s in sessions if s.mouse_id in intersection_mouse_ids]
        all_sessions[stage][phase] = filtered_sessions

Valid genotypes found: {'Wildtype', 'TH-Cre'}
Genotype changes made: 1
Valid genotypes found: {'Wildtype', 'TH-Cre'}
Genotype changes made: 1
[INFO] stage1: 7 common mouse IDs across phases.
Valid genotypes found: {'Wildtype', 'TH-Cre'}
Genotype changes made: 0
Valid genotypes found: {'Wildtype', 'TH-Cre'}
Genotype changes made: 2
[INFO] stage2: 13 common mouse IDs across phases.
Valid genotypes found: {'Wildtype', 'TH-Cre'}
Genotype changes made: 1
Valid genotypes found: {'Wildtype', 'TH-Cre'}
Genotype changes made: 0
[INFO] stage3: 20 common mouse IDs across phases.
Valid genotypes found: {'Wildtype', 'TH-Cre'}
Genotype changes made: 0
Valid genotypes found: {'Wildtype', 'TH-Cre'}
Genotype changes made: 0
[INFO] stage4: 6 common mouse IDs across phases.


In [36]:
stages = ['stage1', 'stage2', 'stage3', 'stage4']
phases = ['pre', 'post']

for stage in stages:
    for phase in phases:
        sessions = all_sessions[stage][phase]
        for session in sessions:
            for brain_region in session.brain_regions:
                threshold_x = find_drug_split_x(session, brain_region)
                brain_region = brain_region.split('_')[0]
                for event in config.all_event_types:
                    curr_signal_info = session.signal_info.get((brain_region, event))
                    if curr_signal_info is None:
                        continue
                    signal_idx_ranges = curr_signal_info['signal_idx_ranges']

                    # find the middle position (0 seconds) of the given signal
                    thresholds = np.array([((tup[0] + tup[1]) // 2) for tup in signal_idx_ranges])

                    # if the 0 seconds position is later than threshold x, we know that the signal is past the injection peak
                    curr_signal_info["is_above_threshold"] = thresholds > threshold_x

In [38]:
all_signal_groups = defaultdict(list)

for stage in stages:
    for phase in phases:
        sessions = all_sessions[stage][phase]
        print(session.trial_id)
        for session in sessions:
            genotype = session.genotype
            dose = (session.drug_info['name'], session.drug_info['dose'])
            for event in config.all_event_types:
                for brain_region in config.all_brain_regions:
                    curr_signal_info = session.signal_info.get((brain_region, event))
                    if curr_signal_info is None:
                        continue

                    curr_key = (stage, phase, genotype, dose, event, brain_region)

                    if session.drug_info['dose'] is not None and genotype == 'TH-Cre':
                        is_above_threshold = curr_signal_info["is_above_threshold"]
                        all_signal_groups[curr_key + (True, )].append(
                            curr_signal_info['signal_matrix'][is_above_threshold]
                        )
                        all_signal_groups[curr_key + (False, )].append(
                            curr_signal_info['signal_matrix'][~is_above_threshold]
                        )
                    else:
                        all_signal_groups[curr_key + (None, )].append(
                            curr_signal_info['signal_matrix']
                        )

# Do the final conversion **after** all loops finish
all_signal_groups = {k: np.vstack(v) for k, v in all_signal_groups.items() if len(v) != 0}
all_signals = {k: v for k, v in all_signal_groups.items() if len(v) != 0}


T11_e.e.e.75
T2_77.85.79.87
T9_e.e.e.85
T13_e.e.e.171
T29_e.e.e.173
T14_e.129.159.e
T29_129.e.e.e
T5_e.e.e.75


In [28]:
def preprocess_and_plot_signals(key, dict, smoothing_len=10):
    # Assuming all_signals is predefined
    # signals = all_signals[(event_type, brain_region)]
    signals = dict[key]
    stage, phase, genotype, dose, event, brain_region, is_above_threshold = key

    interval_start = config.peak_interval_config["interval_start"]
    interval_end = config.peak_interval_config["interval_end"]
    fps = config.PLOTTING_CONFIG['cpt']['fps']
    
    xs = np.arange(-interval_start, interval_end) / fps
    
    # Smooth the mean signal
    ys = np.mean(signals, axis=0)
    window = np.ones(smoothing_len) / smoothing_len
    ys = np.convolve(ys, window, 'same')

    # Calculate the standard deviation of the mean
    std_signal = np.std(signals, axis=0) / np.sqrt(len(signals))

    # Use scipy.stats.norm.interval to get the 95% confidence interval
    alpha = 0.95
    ci_lower, ci_upper = stats.norm.interval(alpha, loc=ys, scale=std_signal)

    # The lower and upper bounds
    lb = ci_lower.min()
    ub = ci_upper.max()

    # Check if ylim is valid
    if not np.isfinite(lb) or not np.isfinite(ub) or lb == ub:
        print(f"[WARNING] Invalid ylim ({lb}, {ub}) for key: {key}. Skipping plot.")
        return

    ylim = (lb, ub)
    
    # Assuming brain_reg_to_color is predefined
    color = brain_reg_to_color[brain_region]

    plt.figure(dpi=300)
    plt.plot(xs, ys, color=color, label='Mean Signal')
    plt.fill_between(xs, ci_lower, ci_upper, color=color, alpha=0.2, label='95% CI')
    
    # if brain_region == 'LH':
    #     plt.ylim(-.3, 0.5)
    # else:
    #     plt.ylim(-0.25, 0.35)
    plt.ylim(ylim)
    if is_above_threshold == True:
        threshold_text = 'above threshold'
    elif is_above_threshold == False:
        threshold_text = 'below threshold' 
    else:
        threshold_text = '' 

    plt.title(f"""stage: {stage}, phase: {phase},
              genotype: {genotype}, dose: {dose},
              event: {event}, brain region: {brain_region}, {threshold_text}
              (n={len(signals)})""")
    plt.xlabel('Time (s)')
    plt.ylabel('z-score')
    plt.legend()
    plt.grid()
    
    plt.tight_layout()
    plt.savefig(f"{stage}_{phase}_{genotype}_{dose}_{event}_{brain_region}_{threshold_text}_stages.pdf")
    plt.close()
    # plt.show()

    # Save the figure locall
    # plt.show()

# Example usage
# Assuming sessions, config, all_signals, brain_reg_to_color are defined
for k in all_signals.keys():
    preprocess_and_plot_signals(k, all_signals, smoothing_len=10)

In [29]:
# import pickle

# with open('all_signals_2.pickle', 'wb') as file:
#     # Pickle the dictionary and write it to the file
#     pickle.dump(all_signals, file)