In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../../../')

# Custom modules
from main import load_and_prepare_sessions
from processing.session_sampling import MiceAnalysis
from data.mouse import create_mice_dict
from analysis.timepoint_analysis import sample_signals_and_metrics, collect_sessions_data

# Config and constants
from config import all_event_types, all_brain_regions
import config

# Plotting libraries
import matplotlib.pyplot as plt

# Signal processing and statistical tools
from scipy.signal import savgol_filter
import scipy.stats as stats
from itertools import product

# Utility libraries
import numpy as np
from tqdm.notebook import tqdm
from ipywidgets import FloatProgress
from collections import defaultdict

In [2]:
# for stage in ['stage1', 'stage2', 'stage3', 'stage4']:
#     for phase in ['pre', 'post']:
#         curr_sessions = load_and_prepare_sessions("../../../../Gq-DREADD_CPT_Training_Stages", load_from_pickle=True, remove_bad_signal_sessions=True,
#                                                   pickle_name=f'sessions_{stage}_{phase}')
        #title = stage + ' ' + phase
        # for session in tqdm(curr_sessions):
        #     for brain_reg in session.brain_regions:
        #         pass
                # # Step 3: Create a Plotly figure object for subplots
                # # Adjust rows and cols based on your layout needs
                # fig = make_subplots(rows=1, cols=1)

                # title = f"genotype: {session.genotype}, dose: {session.drug_info['dose']}, brain region: {brain_reg}\
                #     \n{session.trial_id}, {stage}, {phase}"
                # plot_session_events_and_signal(session, brain_reg, fig, row=1, col=1, title_suffix=title)
                # #fig.write_image(f"{session.mouse_id}_{session.trial_id}_{brain_reg}.png")
                # #fig.show()

                # #fig = make_subplots(rows=1, cols=1)

                # #plot_session_events_and_signal(session, brain_reg, fig, row=1, col=1, title_suffix=title, smooth=True)
                # #fig.write_image(f"{session.mouse_id}_{session.trial_id}_{brain_reg}.png")
                # curr_fname = f"{session.mouse_id}_{session.trial_id}_{brain_reg}_{stage}_{phase}.html"
                # print(curr_fname)
                # fig.write_html(curr_fname)
                
    


In [3]:
# Load the prepared sessions stored under the 'sessions_two_stage1_pre' pickle.
first_mice_sessions = load_and_prepare_sessions(
    "../../../../Gq-DREADD_CPT_Training_Stages",
    load_from_pickle=True,
    remove_bad_signal_sessions=True,
    pickle_name='sessions_two_stage1_pre',
)
# One mouse ID per loaded session (duplicates are kept if a mouse has several sessions).
first_mice_ids = [session.mouse_id for session in first_mice_sessions]

In [4]:
# Matplotlib colour name used when plotting each brain region's signal.
brain_reg_to_color = dict(LH='orange', mPFC='cornflowerblue')

In [5]:
def update_genotypes(sessions, mice_gen_dict):
    """
    Overwrite ``session.genotype`` in place using a mouse-ID -> genotype mapping.

    Prints whether every genotype present after the update is one of the
    valid names ('TH-Cre', 'Wildtype') and how many sessions actually changed.

    Parameters:
    - sessions: An iterable of session objects exposing ``mouse_id`` and
      ``genotype`` attributes. Sessions whose mouse ID is not in
      ``mice_gen_dict`` keep their current genotype.
    - mice_gen_dict: A dictionary mapping mouse IDs to genotype labels.
      Short labels ("Cre", "WT") are translated to the canonical names;
      labels that are already canonical ("TH-Cre", "Wildtype") are accepted
      unchanged.

    Returns:
    - The number of sessions whose genotype was changed.

    Raises:
    - KeyError: if ``mice_gen_dict`` contains a label that is neither a short
      label nor a canonical genotype name.
    """
    geno_mapping = {
        "Cre": "TH-Cre",
        "WT": "Wildtype"
    }
    valid_genotypes = set(geno_mapping.values())

    # Translate every label to its canonical name up front so that a typo in
    # the mapping fails loudly before any session has been modified.
    mapped_genotypes = {}
    for mouse_id, label in mice_gen_dict.items():
        if label in geno_mapping:
            mapped_genotypes[mouse_id] = geno_mapping[label]
        elif label in valid_genotypes:
            mapped_genotypes[mouse_id] = label
        else:
            raise KeyError(
                f"Unknown genotype label {label!r} for mouse {mouse_id!r}; "
                f"expected one of {sorted(geno_mapping) + sorted(valid_genotypes)}"
            )

    all_genotypes = set()
    genotype_changes = 0

    for session in sessions:
        original_genotype = session.genotype

        # The mapping is looked up by integer ID; fall back to the raw value
        # when the session's ID cannot be converted instead of crashing.
        try:
            lookup_id = int(session.mouse_id)
        except (TypeError, ValueError):
            lookup_id = session.mouse_id

        if lookup_id in mapped_genotypes:
            session.genotype = mapped_genotypes[lookup_id]
            # Only count sessions whose genotype actually differs afterwards
            if session.genotype != original_genotype:
                genotype_changes += 1

        all_genotypes.add(session.genotype)

    # Print results
    if all_genotypes.issubset(valid_genotypes):
        print(f"Valid genotypes found: {all_genotypes}")
    else:
        print(f"Invalid genotypes found: {all_genotypes}")

    print(f"Genotype changes made: {genotype_changes}")
    return genotype_changes

In [6]:
# Genotype label ("Cre" or "WT") recorded for each integer mouse ID.
# Built from pairs so the original insertion order is kept.
mice_gen_dict = dict([
    (69, "Cre"), (71, "WT"), (73, "Cre"), (75, "WT"),
    (77, "Cre"), (79, "WT"), (85, "WT"), (87, "WT"),
    (135, "WT"), (137, "WT"), (139, "Cre"), (133, "WT"),
    (127, "WT"), (125, "WT"), (129, "Cre"), (131, "WT"),
    (143, "Cre"), (145, "WT"), (147, "WT"), (157, "Cre"),
    (159, "Cre"), (161, "WT"), (171, "Cre"), (173, "Cre"),
])

In [7]:
from collections import defaultdict

# signals_grouped maps (event, brain_reg, genotype) -> {(stage, phase): signals}.
# defaultdict(dict) creates the inner per-key dict on first access, so no
# explicit initialisation is needed when storing a result.
signals_grouped = defaultdict(dict)

stages = ['stage1', 'stage2', 'stage3', 'stage4']
phases = ['pre', 'post']

for stage in stages:
    # Step 1: Load every phase of the current stage first, because the mouse
    # filter below needs both phases at once.
    phase_sessions = {}
    for phase in phases:
        pickle_name = f'sessions_two_{stage}_{phase}'

        sessions = load_and_prepare_sessions(
            "../../../../Gq-DREADD_CPT_Training_Stages",
            load_from_pickle=True,
            remove_bad_signal_sessions=True,
            pickle_name=pickle_name
        )

        # Update genotypes in-place
        update_genotypes(sessions, mice_gen_dict)
        phase_sessions[phase] = sessions

    # Step 2: Keep only mice that have at least one session in both phases
    mouse_ids_pre = {s.mouse_id for s in phase_sessions['pre']}
    mouse_ids_post = {s.mouse_id for s in phase_sessions['post']}
    intersection_mouse_ids = mouse_ids_pre & mouse_ids_post

    print(f"[INFO] {stage}: {len(intersection_mouse_ids)} common mouse IDs across phases.")

    # Step 3: Filter each phase's sessions down to the shared mice
    for phase in phases:
        sessions = phase_sessions[phase]
        filtered_sessions = [s for s in sessions if s.mouse_id in intersection_mouse_ids]

        print(f"[INFO] {stage} - {phase}: {len(filtered_sessions)} sessions after filtering.")

        # Step 4a: Split the filtered sessions by genotype. This depends only
        # on filtered_sessions, so it is done once per (stage, phase) rather
        # than once per (event, brain_reg) pair.
        genotype_groups = defaultdict(list)
        for sess in filtered_sessions:
            genotype_groups[sess.genotype].append(sess)

        # Step 4b: Sample signals for every (event, brain_reg, genotype)
        for event, brain_reg in product(all_event_types, all_brain_regions):
            for genotype, geno_sess_list in genotype_groups.items():
                sampled = sample_signals_and_metrics(
                    # Pass a fresh list each call, as the original per-pair
                    # regrouping did, in case the helper mutates its argument.
                    list(geno_sess_list),
                    brain_reg,
                    'G',
                    event,
                    weight_method='events'
                )
                if sampled:
                    # Only the first element of the returned value is kept
                    all_signals = sampled[0]
                    key = (event, brain_reg, genotype)
                    signals_grouped[key][(stage, phase)] = all_signals

Valid genotypes found: {'Wildtype', 'TH-Cre'}
Genotype changes made: 1
Valid genotypes found: {'TH-Cre', 'Wildtype'}
Genotype changes made: 1
[INFO] stage1: 7 common mouse IDs across phases.
[INFO] stage1 - pre: 7 sessions after filtering.
[INFO] stage1 - post: 7 sessions after filtering.
Valid genotypes found: {'TH-Cre', 'Wildtype'}
Genotype changes made: 0
Valid genotypes found: {'TH-Cre', 'Wildtype'}
Genotype changes made: 2
[INFO] stage2: 13 common mouse IDs across phases.
[INFO] stage2 - pre: 15 sessions after filtering.
[INFO] stage2 - post: 13 sessions after filtering.
Valid genotypes found: {'TH-Cre', 'Wildtype'}
Genotype changes made: 1
Valid genotypes found: {'TH-Cre', 'Wildtype'}
Genotype changes made: 0
[INFO] stage3: 20 common mouse IDs across phases.
[INFO] stage3 - pre: 20 sessions after filtering.
[INFO] stage3 - post: 20 sessions after filtering.
Valid genotypes found: {'TH-Cre', 'Wildtype'}
Genotype changes made: 0
Valid genotypes found: {'TH-Cre', 'Wildtype'}
Genotyp

In [8]:
def preprocess_and_plot_signals(all_signals, event_type, brain_region, smoothing_len=10, title_suffix=None):
    """
    Plot the smoothed mean of ``all_signals`` with a 95% confidence band on a
    new 300-dpi figure and show it.

    NOTE(review): a later cell defines a function with the same name that adds
    an ``ax`` parameter and reads fps from ``config.PLOTTING_CONFIG['cpt']['fps']``;
    once that cell runs it shadows this definition. Confirm which one is current.

    Parameters:
    - all_signals: signals to average over axis 0 (presumably shaped
      (n_signals, n_samples) -- TODO confirm against the caller)
    - event_type: string shown in the plot title
    - brain_region: key into ``brain_reg_to_color``; also shown in the title
    - smoothing_len: length of the moving-average window applied to the mean
    - title_suffix: extra text for the title (None is treated as '')
    """
    # Time axis in seconds, built from the configured interval and frame rate
    start_offset = config.peak_interval_config["interval_start"]
    end_offset = config.peak_interval_config["interval_end"]
    frame_rate = config.PLOTTING_CONFIG['fps']
    time_axis = np.arange(-start_offset, end_offset) / frame_rate

    # Moving average of the across-signal mean
    kernel = np.ones(smoothing_len) / smoothing_len
    smoothed_mean = np.convolve(np.mean(all_signals, axis=0), kernel, 'same')

    # Standard error of the mean, computed from the unsmoothed signals
    sem = np.std(all_signals, axis=0) / np.sqrt(len(all_signals))

    # 95% normal interval centred on the smoothed mean
    band_low, band_high = stats.norm.interval(0.95, loc=smoothed_mean, scale=sem)

    # Y range spans the full extent of the confidence band
    y_range = (band_low.min(), band_high.max())

    # brain_reg_to_color is defined at notebook level
    trace_color = brain_reg_to_color[brain_region]

    title_suffix = '' if title_suffix is None else title_suffix

    plt.figure(dpi=300)
    plt.plot(time_axis, smoothed_mean, color=trace_color, label='Mean Signal')
    plt.fill_between(time_axis, band_low, band_high, color=trace_color, alpha=0.2, label='95% CI')
    plt.ylim(y_range)
    title_name = f'{event_type}, {brain_region}, {title_suffix}, (n = {len(all_signals)})'
    plt.title(title_name)
    plt.xlabel('Time (s)')
    plt.ylabel('z-score')
    plt.legend()
    plt.grid()
    # plt.savefig(title_name)
    plt.show()

In [9]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

def preprocess_and_plot_signals(all_signals, event_type, brain_region, smoothing_len=10, title_suffix=None, ax=None):
    """
    Plot the smoothed mean of ``all_signals`` with a 95% confidence band.

    The mean over axis 0 is smoothed with a moving average of
    ``smoothing_len`` samples. The band is a normal interval centred on the
    smoothed mean whose scale is the (unsmoothed) standard deviation over
    axis 0 divided by sqrt(number of signals).

    Parameters:
    - all_signals: signals to average over axis 0 (presumably shaped
      (n_signals, n_samples) -- TODO confirm against the caller)
    - event_type: string indicating the event type (shown in the title)
    - brain_region: string indicating the brain region; selects the colour
      from ``brain_reg_to_color`` and falls back to 'blue' if absent
    - smoothing_len: integer length of the moving-average window
    - title_suffix: additional string for the plot title (None -> '')
    - ax: matplotlib Axes to draw on; a new 300-dpi figure is created when None

    Returns:
    - The Axes that was drawn on, so a figure created here can still be
      saved, adjusted or closed by the caller (``ax.figure``). End a notebook
      cell with ``;`` to suppress its repr.
    """
    if ax is None:
        _, ax = plt.subplots(dpi=300)

    signals = all_signals

    interval_start = config.peak_interval_config["interval_start"]
    interval_end = config.peak_interval_config["interval_end"]
    fps = config.PLOTTING_CONFIG['cpt']['fps']

    # Time axis in seconds relative to the event
    xs = np.arange(-interval_start, interval_end) / fps

    # Smooth the mean signal. np.convolve's 'same' mode zero-pads, so samples
    # within about smoothing_len / 2 of either end are pulled towards zero.
    ys = np.mean(signals, axis=0)
    window = np.ones(smoothing_len) / smoothing_len
    ys = np.convolve(ys, window, 'same')

    # Standard error of the mean (from the unsmoothed signals)
    std_signal = np.std(signals, axis=0) / np.sqrt(len(signals))

    # 95% confidence interval; the first argument of norm.interval is the
    # confidence level, passed positionally.
    confidence = 0.95
    ci_lower, ci_upper = stats.norm.interval(confidence, loc=ys, scale=std_signal)

    # brain_reg_to_color is defined at notebook level
    color = brain_reg_to_color.get(brain_region, 'blue')  # Default to 'blue' if not found

    if title_suffix is None:
        title_suffix = ''

    ax.plot(xs, ys, color=color, label='Mean Signal')
    ax.fill_between(xs, ci_lower, ci_upper, color=color, alpha=0.2, label='95% CI')

    title_name = f'{event_type}, {brain_region}, {title_suffix}, (n = {len(signals)})'
    ax.set_title(title_name, fontsize=10)
    ax.set_xlabel('Time (s)', fontsize=8)
    ax.set_ylabel('z-score', fontsize=8)
    ax.legend(fontsize=6)
    ax.grid(True)

    return ax

In [10]:
# Render one figure per (event, brain region, genotype) group and Y-sync mode:
#   'row'    -> panels of the same stage share Y limits
#   'global' -> every panel in the figure shares Y limits
# The grid has one row per stage and one column per phase.
plot_stages = ['stage1', 'stage2', 'stage3', 'stage4']
plot_phases = ['pre', 'post']

for sync_y_limits in ['row', 'global']:
    for (event, brain_reg, genotype), stage_phase_signals in signals_grouped.items():
        fig, axes = plt.subplots(
            nrows=len(plot_stages),
            ncols=len(plot_phases),
            figsize=(10, 15),
            dpi=300,
            sharex=True,
            sharey=False,
            squeeze=False,  # always a 2-D array, whatever the grid size
        )
        fig_title = f'Event: {event} | Brain Region: {brain_reg} | Genotype: {genotype}'
        fig.suptitle(fig_title, fontsize=16)

        # (row, col) -> autoscaled (ymin, ymax) of every panel that got data
        panel_ylims = {}

        for i, stage in enumerate(plot_stages):
            for j, phase in enumerate(plot_phases):
                ax = axes[i][j]
                key_stage_phase = (stage, phase)

                if key_stage_phase in stage_phase_signals:
                    all_signals = stage_phase_signals[key_stage_phase]
                    title_suffix = f'{stage}, {phase}'
                    preprocess_and_plot_signals(
                        all_signals,
                        event_type=event,
                        brain_region=brain_reg,
                        smoothing_len=10,
                        title_suffix=title_suffix,
                        ax=ax
                    )
                    panel_ylims[(i, j)] = ax.get_ylim()
                else:
                    # No data for this stage/phase
                    ax.axis('off')
                    ax.set_title(f'{stage}, {phase}\nNo Data', fontsize=10)

        # Decide which panels share limits: all of them, or one group per row
        if sync_y_limits == 'global':
            panel_groups = [list(panel_ylims)]
        else:
            panel_groups = [
                [pos for pos in panel_ylims if pos[0] == row]
                for row in range(len(plot_stages))
            ]

        # Apply the widest autoscaled range of each group to all its panels
        for group in panel_groups:
            if not group:
                continue
            y_min = min(panel_ylims[pos][0] for pos in group)
            y_max = max(panel_ylims[pos][1] for pos in group)
            for row, col in group:
                axes[row][col].set_ylim(y_min, y_max)

        # rect leaves room at the top for the suptitle
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        fig.savefig(f"plot_event_{event}_{brain_reg}_{genotype}_{sync_y_limits}_event_agg.png")
        #plt.show()
        plt.close(fig)

In [8]:
def session_br_event_count(session, br, event):
    """
    Return the number of rows in the signal matrix stored for ``(br, event)``.

    Looks up ``session.signal_info[(br, event)]`` and returns
    ``['signal_matrix'].shape[0]`` of that entry, or 0 when the entry is
    missing or falsy. (Presumably one row per event occurrence -- TODO confirm.)
    """
    entry = session.signal_info.get((br, event))
    if not entry:
        return 0
    return 0 + entry['signal_matrix'].shape[0]


In [9]:
# Step 1: Load each stage/phase pickle and report how many sessions it holds.
# NOTE(review): this rebinds signals_grouped to an empty mapping, and the code
# below that would repopulate it is commented out -- anything grouped into
# signals_grouped by an earlier cell is discarded once this cell runs.
signals_grouped = defaultdict(dict)

for stage, phase in product(['stage1', 'stage2', 'stage3', 'stage4'], ['pre', 'post']):
    pickle_name = f'sessions_two_{stage}_{phase}'
    curr_sessions = load_and_prepare_sessions(
        "../../../../Gq-DREADD_CPT_Training_Stages",
        load_from_pickle=True,
        remove_bad_signal_sessions=True,
        pickle_name=pickle_name
    )
    print(stage, phase, len(curr_sessions))
    # print(stage, phase, [(s.mouse_id, session_br_event_count(s, 'mPFC', 'dispimg')) for s in curr_sessions])
                          #if session_br_event_count(s, 'mPFC', 'dispimg') < 100])

    # # Filter sessions by mouse IDs
    # if stage == 'stage1':
    #     curr_sessions = [s for s in curr_sessions if s.mouse_id in first_mice_ids]

    # for event in all_event_types:
    #     for brain_reg in all_brain_regions:
    #         key = (event, brain_reg)
    #         if sample_signals_and_metrics(curr_sessions, event, brain_reg, weight_method='mice'):
    #             all_signals = sample_signals_and_metrics(curr_sessions, event, brain_reg, weight_method='mice')[0]
    #             signals_grouped[key][(stage, phase)] = all_signals

stage1 pre 8
stage1 post 23
stage2 pre 16
stage2 post 24
stage3 pre 23
stage3 post 21
stage4 pre 7
stage4 post 6
