In [16]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../../../')

# Custom modules
from main import load_and_prepare_sessions
from processing.session_sampling import MiceAnalysis
from data.mouse import create_mice_dict
from analysis.timepoint_analysis import sample_signals_and_metrics, collect_sessions_data

# Config and constants
from config import all_event_types, all_brain_regions
import config

# Plotting libraries
import matplotlib.pyplot as plt

# Signal processing and statistical tools
from scipy.signal import savgol_filter
import scipy.stats as stats
from itertools import product

# Utility libraries
import numpy as np
from tqdm.notebook import tqdm
from ipywidgets import FloatProgress
from collections import defaultdict

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [17]:
# for stage in ['stage1', 'stage2', 'stage3', 'stage4']:
#     for phase in ['pre', 'post']:
#         curr_sessions = load_and_prepare_sessions("../../../../Gq-DREADD_CPT_Training_Stages", load_from_pickle=True, remove_bad_signal_sessions=True,
#                                                   pickle_name=f'sessions_{stage}_{phase}')
        #title = stage + ' ' + phase
        # for session in tqdm(curr_sessions):
        #     for brain_reg in session.brain_regions:
        #         pass
                # # Step 3: Create a Plotly figure object for subplots
                # # Adjust rows and cols based on your layout needs
                # fig = make_subplots(rows=1, cols=1)

                # title = f"genotype: {session.genotype}, dose: {session.drug_info['dose']}, brain region: {brain_reg}\
                #     \n{session.trial_id}, {stage}, {phase}"
                # plot_session_events_and_signal(session, brain_reg, fig, row=1, col=1, title_suffix=title)
                # #fig.write_image(f"{session.mouse_id}_{session.trial_id}_{brain_reg}.png")
                # #fig.show()

                # #fig = make_subplots(rows=1, cols=1)

                # #plot_session_events_and_signal(session, brain_reg, fig, row=1, col=1, title_suffix=title, smooth=True)
                # #fig.write_image(f"{session.mouse_id}_{session.trial_id}_{brain_reg}.png")
                # curr_fname = f"{session.mouse_id}_{session.trial_id}_{brain_reg}_{stage}_{phase}.html"
                # print(curr_fname)
                # fig.write_html(curr_fname)
                
    


In [18]:
# Sessions of stage 1 / pre (bad-signal sessions removed) and the IDs of the mice they belong to.
first_mice_sessions = load_and_prepare_sessions(
    "../../../../Gq-DREADD_CPT_Training_Stages",
    load_from_pickle=True,
    remove_bad_signal_sessions=True,
    pickle_name='sessions_two_stage1_pre',
)
first_mice_ids = [sess.mouse_id for sess in first_mice_sessions]

In [19]:
# Plot colour used for each brain region's traces and CI bands.
brain_reg_to_color = dict(LH='orange', mPFC='cornflowerblue')

In [20]:
def update_genotypes(sessions, mice_gen_dict):
    """
    Overwrite each session's genotype from a per-mouse lookup table, then report.

    Short labels in ``mice_gen_dict`` are translated ("Cre" -> "TH-Cre",
    "WT" -> "Wildtype"); any other label raises KeyError. Every session whose
    ``int(mouse_id)`` is a key of the table gets its ``genotype`` attribute
    replaced in place; all other sessions are left untouched.

    Two lines are printed: the set of genotypes seen across all sessions,
    prefixed "Valid" when it only contains 'TH-Cre' / 'Wildtype' and "Invalid"
    otherwise, followed by the number of sessions whose genotype changed.

    Parameters:
    - sessions: iterable of session objects exposing ``mouse_id``
      (int-convertible, otherwise ValueError) and ``genotype``.
    - mice_gen_dict: dict mapping integer mouse IDs to "Cre" or "WT".
    """
    short_to_long = {"Cre": "TH-Cre", "WT": "Wildtype"}
    long_by_mouse = {mouse: short_to_long[label] for mouse, label in mice_gen_dict.items()}

    seen_genotypes = set()
    n_changed = 0
    for sess in sessions:
        mouse = int(sess.mouse_id)
        if mouse in long_by_mouse:
            previous = sess.genotype
            sess.genotype = long_by_mouse[mouse]
            # Compare what the session now reports, not the value we assigned.
            if sess.genotype != previous:
                n_changed += 1
        seen_genotypes.add(sess.genotype)

    if seen_genotypes <= {'TH-Cre', 'Wildtype'}:
        print(f"Valid genotypes found: {seen_genotypes}")
    else:
        print(f"Invalid genotypes found: {seen_genotypes}")
    print(f"Genotype changes made: {n_changed}")

In [21]:
# Ground-truth genotype per mouse: integer mouse ID -> "Cre" / "WT".
# update_genotypes translates these short labels to "TH-Cre" / "Wildtype".
mice_gen_dict = dict([
    (69, "Cre"), (71, "WT"), (73, "Cre"), (75, "WT"),
    (77, "Cre"), (79, "WT"), (85, "WT"), (87, "WT"),
    (135, "WT"), (137, "WT"), (139, "Cre"), (133, "WT"),
    (127, "WT"), (125, "WT"), (129, "Cre"), (131, "WT"),
    (143, "Cre"), (145, "WT"), (147, "WT"), (157, "Cre"),
    (159, "Cre"), (161, "WT"), (171, "Cre"), (173, "Cre"),
])

In [22]:
from collections import defaultdict

def create_signals_grouped(weight_method, subset_first_stage=False,
                           data_dir="../../../../Gq-DREADD_CPT_Training_Stages",
                           stages=('stage1', 'stage2', 'stage3', 'stage4'),
                           phases=('pre', 'post')):
    """
    Load pickled sessions for every (stage, phase) and collect sampled signals.

    For each stage, the sessions of every phase are loaded from the pickle
    ``sessions_two_<stage>_<phase>``. With ``subset_first_stage`` they are
    restricted to mice present in the first (stage, phase). Their genotypes are
    refreshed in place via ``update_genotypes(sessions, mice_gen_dict)``.
    Only mice present in *every* phase of the stage are kept. Signals are then
    sampled with ``sample_signals_and_metrics`` for every
    (event, brain region) pair.

    A stage in which no mouse is present in every phase (including a phase
    with no sessions at all) is reported and skipped.

    Parameters
    ----------
    weight_method : str
        Passed through unchanged to ``sample_signals_and_metrics``.
    subset_first_stage : bool
        If True, keep only mice whose sessions appear in the first
        (stage, phase), i.e. stage1/pre with the defaults.
    data_dir : str
        Directory passed to ``load_and_prepare_sessions``.
    stages, phases : sequence of str
        Stages and phases to load, in order.

    Returns
    -------
    defaultdict(dict)
        ``{(event, brain_region): {(stage, phase): signals}}``. Here ``signals``
        is element 0 of whatever ``sample_signals_and_metrics`` returned, stored
        only when that return value was truthy.
    """
    signals_grouped = defaultdict(dict)
    reference_stage_phase = (stages[0], phases[0])
    first_stage_mouse_ids = None

    for stage in stages:
        # Step 1: load every phase of this stage.
        phase_sessions = {}
        for phase in phases:
            sessions = load_and_prepare_sessions(
                data_dir,
                load_from_pickle=True,
                remove_bad_signal_sessions=True,
                pickle_name=f'sessions_two_{stage}_{phase}'
            )
            print(stage, phase, [session.mouse_id for session in sessions])

            if subset_first_stage:
                if (stage, phase) == reference_stage_phase:
                    first_stage_mouse_ids = {session.mouse_id for session in sessions}
                else:
                    sessions = [session for session in sessions if session.mouse_id in first_stage_mouse_ids]

            if sessions:
                update_genotypes(sessions, mice_gen_dict)
                phase_sessions[phase] = sessions

        # Step 2: keep only mice recorded in every phase. A phase without
        # sessions contributes an empty set rather than raising KeyError.
        mouse_id_sets = [{s.mouse_id for s in phase_sessions.get(phase, [])} for phase in phases]
        intersection_mouse_ids = set.intersection(*mouse_id_sets) if mouse_id_sets else set()
        print(f"[INFO] {stage}: {len(intersection_mouse_ids)} common mouse IDs across phases.")
        if not intersection_mouse_ids:
            print(f"[INFO] {stage}: skipped (no mouse present in every phase).")
            continue

        # Step 3: sample signals for each phase from the common mice only.
        for phase in phases:
            filtered_sessions = [s for s in phase_sessions[phase] if s.mouse_id in intersection_mouse_ids]
            print(f"[INFO] {stage} - {phase}: {len(filtered_sessions)} sessions after filtering.")

            for event in all_event_types:
                for brain_reg in all_brain_regions:
                    # NOTE(review): 'G' is the third positional argument of
                    # sample_signals_and_metrics — presumably the signal channel; confirm.
                    sampled = sample_signals_and_metrics(
                        filtered_sessions,
                        brain_reg,
                        'G',
                        event,
                        weight_method=weight_method
                    )
                    if sampled:
                        signals_grouped[(event, brain_reg)][(stage, phase)] = sampled[0]
    return signals_grouped

In [23]:
def preprocess_and_plot_signals(all_signals, event_type, brain_region, smoothing_len=10, title_suffix=None):
    """
    Plot the smoothed mean of ``all_signals`` with a 95% normal CI in a new figure and show it.

    NOTE(review): a later cell redefines ``preprocess_and_plot_signals`` with an
    extra ``ax`` parameter, so after a top-to-bottom run this version is
    shadowed. It also reads ``config.PLOTTING_CONFIG['fps']``, whereas the later
    version reads ``config.PLOTTING_CONFIG['cpt']['fps']``. Confirm which key
    the current config provides.

    Parameters
    ----------
    all_signals : array-like
        Averaged over axis 0. Presumably shaped (n_signals, n_timepoints) with
        n_timepoints == interval_start + interval_end — TODO confirm.
    event_type, brain_region : str
        Used in the title. ``brain_region`` also selects the colour from
        ``brain_reg_to_color``, falling back to 'blue'.
    smoothing_len : int
        Length of the moving-average window applied to the mean trace.
    title_suffix : str or None
        Extra text for the title ('' if None).
    """
    signals = all_signals

    interval_start = config.peak_interval_config["interval_start"]
    interval_end = config.peak_interval_config["interval_end"]
    fps = config.PLOTTING_CONFIG['fps']

    xs = np.arange(-interval_start, interval_end) / fps

    # Moving-average smoothing of the mean trace. np.convolve zero-pads, so the
    # first/last few samples are pulled towards 0.
    ys = np.mean(signals, axis=0)
    window = np.ones(smoothing_len) / smoothing_len
    ys = np.convolve(ys, window, 'same')

    # Standard error of the mean at each time point.
    std_signal = np.std(signals, axis=0) / np.sqrt(len(signals))

    # 95% normal CI as ys + z * SEM. This equals
    # stats.norm.interval(alpha, loc=ys, scale=std_signal) wherever the SEM is
    # positive. Where the SEM is 0 (e.g. a single signal) it stays finite with
    # zero width, whereas norm.interval returns NaN and plt.ylim(NaN, NaN) raises.
    alpha = 0.95
    ci_lower = ys + stats.norm.ppf((1 - alpha) / 2) * std_signal
    ci_upper = ys + stats.norm.ppf((1 + alpha) / 2) * std_signal

    # The y-range spans the whole CI band.
    ylim = (np.nanmin(ci_lower), np.nanmax(ci_upper))

    color = brain_reg_to_color.get(brain_region, 'blue')

    if title_suffix is None:
        title_suffix = ''
    plt.figure(dpi=300)
    plt.plot(xs, ys, color=color, label='Mean Signal')
    plt.fill_between(xs, ci_lower, ci_upper, color=color, alpha=0.2, label='95% CI')
    plt.ylim(ylim)
    title_name = f'{event_type}, {brain_region}, {title_suffix}, (n = {len(signals)})'
    plt.title(title_name)
    plt.xlabel('Time (s)')
    plt.ylabel('z-score')
    plt.legend()
    plt.grid()
    # plt.savefig(title_name)
    plt.show()

In [24]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

def preprocess_and_plot_signals(all_signals, event_type, brain_region, smoothing_len=10, title_suffix=None, ax=None):
    """
    Plot the smoothed mean of ``all_signals`` with a 95% normal CI on the given Axes.

    Parameters:
    - all_signals: array-like of signals, averaged over axis 0 (presumably
      shaped (n_signals, n_timepoints) — TODO confirm)
    - event_type: string indicating the event type (used in the title)
    - brain_region: string indicating the brain region. It is used in the title
      and selects the colour from ``brain_reg_to_color`` (default 'blue').
    - smoothing_len: length of the moving-average window applied to the mean
    - title_suffix: additional string for the plot title ('' if None)
    - ax: matplotlib Axes object to plot on; a new 300-dpi figure is created if None

    Returns:
    - The Axes that was drawn on.

    The mean trace is the first line drawn on ``ax`` and the CI band is the
    first collection. Callers in this notebook read ``ax.lines[0]`` and
    ``ax.collections[0]`` back to compute y-limits.
    """
    if ax is None:
        _, ax = plt.subplots(dpi=300)

    signals = all_signals

    interval_start = config.peak_interval_config["interval_start"]
    interval_end = config.peak_interval_config["interval_end"]
    fps = config.PLOTTING_CONFIG['cpt']['fps']

    xs = np.arange(-interval_start, interval_end) / fps

    # Moving-average smoothing of the mean trace. np.convolve zero-pads, so the
    # first/last few samples are pulled towards 0.
    ys = np.mean(signals, axis=0)
    window = np.ones(smoothing_len) / smoothing_len
    ys = np.convolve(ys, window, 'same')

    # Standard error of the mean at each time point.
    std_signal = np.std(signals, axis=0) / np.sqrt(len(signals))

    # 95% normal CI as ys + z * SEM. This equals
    # stats.norm.interval(alpha, loc=ys, scale=std_signal) wherever the SEM is
    # positive. Where the SEM is 0 (e.g. a single signal) it stays finite with
    # zero width instead of NaN, which would poison y-limit computations that
    # read the band back.
    alpha = 0.95
    ci_lower = ys + stats.norm.ppf((1 - alpha) / 2) * std_signal
    ci_upper = ys + stats.norm.ppf((1 + alpha) / 2) * std_signal

    color = brain_reg_to_color.get(brain_region, 'blue')  # Default to 'blue' if not found

    if title_suffix is None:
        title_suffix = ''

    ax.plot(xs, ys, color=color, label='Mean Signal')
    ax.fill_between(xs, ci_lower, ci_upper, color=color, alpha=0.2, label='95% CI')

    title_name = f'{event_type}, {brain_region}, {title_suffix}, (n = {len(signals)})'
    ax.set_title(title_name, fontsize=10)
    ax.set_xlabel('Time (s)', fontsize=8)
    ax.set_ylabel('z-score', fontsize=8)
    ax.legend(fontsize=6)
    ax.grid(True)
    return ax

In [26]:
from itertools import product
import os
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import numpy as np

def _panel_y_bounds(ax, use_mean_trace):
    """Return (min, max) of the finite y-data drawn on ``ax``, or None if none is finite.

    With ``use_mean_trace`` (and a line present) the first line, i.e. the mean
    trace, is used. Otherwise all polygons of the first collection (the CI band)
    are used. Falls back to ``ax.get_ylim()`` when neither artist exists.
    """
    if use_mean_trace and ax.lines:
        y_values = np.asarray(ax.lines[0].get_ydata(), dtype=float)
    elif ax.collections and len(ax.collections[0].get_paths()) > 0:
        # fill_between may split the band into several polygons (e.g. around
        # NaN gaps), so take the vertices of all of them, not just the first.
        y_values = np.concatenate([path.vertices[:, 1] for path in ax.collections[0].get_paths()])
    else:
        return ax.get_ylim()
    finite = y_values[np.isfinite(y_values)]
    if finite.size == 0:
        return None
    return finite.min(), finite.max()


def extract_y_bounds(signals_grouped, sync_y_limits, sync_scope):
    """Collect per-panel y-data bounds for every (event, brain_reg) figure.

    Each available (stage, phase) panel (stages 1-3, phases pre/post) is drawn
    into a throwaway figure with ``preprocess_and_plot_signals``, and its data
    range is read back. The upper-left panel (stage1/pre) contributes the range
    of its mean trace only; every other panel contributes the range of its CI
    band. Panels without finite data contribute nothing.

    Parameters
    ----------
    signals_grouped : mapping
        ``{(event, brain_reg): {(stage, phase): signals}}``.
    sync_y_limits : str
        ``"row"`` additionally collects bounds per stage row. Any other value
        collects only the global bounds.
    sync_scope : str
        ``"per_br"`` keys the collected bounds by brain region. Any other value
        pools all brain regions together.

    Returns
    -------
    (agg_global, agg_row)
        agg_global: ``{"mins": [...], "maxs": [...]}``, or for ``"per_br"``
        ``{brain_reg: {"mins": [...], "maxs": [...]}}``.
        agg_row: ``None`` unless ``sync_y_limits == "row"``. In that case it is
        ``{row: {"mins": [...], "maxs": [...]}}`` for rows 0-2, nested under
        the brain region for ``"per_br"``.
    """
    stages = ['stage1', 'stage2', 'stage3']
    phases = ['pre', 'post']
    per_br = sync_scope == "per_br"
    track_rows = sync_y_limits == "row"

    def new_row_bounds():
        return {row: {"mins": [], "maxs": []} for row in range(len(stages))}

    agg_global = {} if per_br else {"mins": [], "maxs": []}
    if track_rows:
        agg_row = {} if per_br else new_row_bounds()
    else:
        agg_row = None

    for (event, brain_reg), stage_phase_signals in signals_grouped.items():
        # The figure is never rendered, only its data is read back, so the
        # default size/dpi is enough.
        fig, axes = plt.subplots(nrows=len(stages), ncols=len(phases), squeeze=False)
        try:
            for i, stage in enumerate(stages):
                for j, phase in enumerate(phases):
                    key = (stage, phase)
                    if key not in stage_phase_signals:
                        continue
                    ax = axes[i][j]
                    preprocess_and_plot_signals(
                        stage_phase_signals[key],
                        event_type=event,
                        brain_region=brain_reg,
                        smoothing_len=10,
                        title_suffix=f'{stage}, {phase}',
                        ax=ax
                    )
                    bounds = _panel_y_bounds(ax, use_mean_trace=(i == 0 and j == 0))
                    if bounds is None:
                        continue
                    data_min, data_max = bounds

                    if per_br:
                        global_bucket = agg_global.setdefault(brain_reg, {"mins": [], "maxs": []})
                        row_buckets = agg_row.setdefault(brain_reg, new_row_bounds()) if track_rows else None
                    else:
                        global_bucket = agg_global
                        row_buckets = agg_row
                    global_bucket["mins"].append(data_min)
                    global_bucket["maxs"].append(data_max)
                    if track_rows:
                        row_buckets[i]["mins"].append(data_min)
                        row_buckets[i]["maxs"].append(data_max)
        finally:
            # Release the figure even if plotting raised.
            plt.close(fig)
    return agg_global, agg_row

def _padded_range(mins, maxs, pad_frac=0.02):
    """Return ``(min(mins) - m, max(maxs) + m)`` with ``m = pad_frac * span``, or None if either list is empty."""
    if not mins or not maxs:
        return None
    lo = min(mins)
    hi = max(maxs)
    margin = (hi - lo) * pad_frac
    return (lo - margin, hi + margin)


def compute_aggregated_limits(agg_global, agg_row, sync_y_limits, sync_scope, pad_frac=0.02):
    """Turn the bounds collected by extract_y_bounds into padded y-limits.

    Each limit spans the smallest collected minimum to the largest collected
    maximum, widened on both sides by ``pad_frac`` (default 2%) of that span.
    A group with no collected values gets ``None``.

    Parameters
    ----------
    agg_global, agg_row :
        Aggregates shaped as returned by ``extract_y_bounds``.
    sync_y_limits : {"global", "row"}
    sync_scope : str
        ``"per_br"`` if the aggregates are keyed by brain region.
    pad_frac : float
        Fraction of the data span added as margin on each side.

    Returns
    -------
    (limits_global, limits_row)
        For ``"global"``: ``(limits, None)``, where ``limits`` is a
        ``(low, high)`` tuple, ``None`` when no data was collected, or
        ``{brain_reg: (low, high)}`` for ``"per_br"``.
        For ``"row"``: ``(None, {row: (low, high) | None})``, nested under the
        brain region for ``"per_br"``.

    Raises
    ------
    ValueError
        If ``sync_y_limits`` is neither ``"global"`` nor ``"row"``.
    """
    per_br = sync_scope == "per_br"

    if sync_y_limits == "global":
        if per_br:
            limits_global = {
                br: _padded_range(bounds["mins"], bounds["maxs"], pad_frac)
                for br, bounds in agg_global.items()
            }
        else:
            limits_global = _padded_range(agg_global["mins"], agg_global["maxs"], pad_frac)
        return limits_global, None

    if sync_y_limits == "row":
        def rows_to_limits(rows):
            return {row: _padded_range(bounds["mins"], bounds["maxs"], pad_frac)
                    for row, bounds in rows.items()}

        if per_br:
            limits_row = {br: rows_to_limits(rows) for br, rows in agg_row.items()}
        else:
            limits_row = rows_to_limits(agg_row)
        return None, limits_row

    raise ValueError(f"sync_y_limits must be 'global' or 'row', got {sync_y_limits!r}")

def replot_figures(signals_grouped, limits_global, limits_row, sync_y_limits, sync_scope, folder_name):
    """
    Draw, y-limit and save one 3x2 (stage x phase) figure per (event, brain region).

    Panels are drawn with ``preprocess_and_plot_signals``. Their y-limits then
    come from ``limits_global`` when ``sync_y_limits == "global"``, or from
    ``limits_row`` indexed by stage row when it is ``"row"``. With
    ``sync_scope == "per_br"`` both are first looked up by brain region.
    Panels whose (stage, phase) is missing are hidden and titled "No Data".
    Each figure is saved in PNG and PDF format as
    ``<folder_name>/plot_event_<event>_brain_region_<brain_reg>_mice_agg`` and
    then closed.
    """
    per_br = sync_scope == "per_br"

    def target_ylim(brain_reg, row):
        # (low, high) to apply to this panel, or None to keep autoscaled limits.
        if sync_y_limits == "global":
            limits = limits_global
            if per_br:
                if brain_reg not in limits_global:
                    return None
                limits = limits_global[brain_reg]
            low, high = limits
            return low, high
        if sync_y_limits == "row":
            if per_br:
                if brain_reg not in limits_row or limits_row[brain_reg][row] is None:
                    return None
                low, high = limits_row[brain_reg][row]
                return low, high
            if limits_row[row] is None:
                return None
            low, high = limits_row[row]
            return low, high
        return None

    for (event, brain_reg), panels in signals_grouped.items():
        fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(10, 15), dpi=300, sharex=True, sharey=False)
        fig.suptitle(f'Event: {event} | Brain Region: {brain_reg}', fontsize=16)

        for row, stage in enumerate(['stage1', 'stage2', 'stage3']):
            for col, phase in enumerate(['pre', 'post']):
                ax = axes[row][col]
                if (stage, phase) not in panels:
                    ax.axis('off')
                    ax.set_title(f'{stage}, {phase}\nNo Data', fontsize=10)
                    continue
                preprocess_and_plot_signals(
                    panels[(stage, phase)],
                    event_type=event,
                    brain_region=brain_reg,
                    smoothing_len=10,
                    title_suffix=f'{stage}, {phase}',
                    ax=ax,
                )
                ylim = target_ylim(brain_reg, row)
                if ylim is not None:
                    ax.set_ylim(ylim[0], ylim[1])

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle
        base_path = os.path.join(folder_name, f'plot_event_{event}_brain_region_{brain_reg}_mice_agg')
        for fmt in ('png', 'pdf'):
            fig.savefig(f'{base_path}.{fmt}', format=fmt)
        plt.close(fig)

# --- Main Loop ---
# Render every combination of plotting flags into its own output folder.
sync_y_limits_options = ['global', 'row']
agg_type_options = ['events']
subset_first_stage_options = [False, True]
sync_scope_options = ['per_br', 'all']

# The data flags (agg_type, subset_first_stage) are the outermost factors of
# the product, so consecutive combinations share the same loaded signals.
# create_signals_grouped, which reloads every session pickle, therefore runs
# once per data configuration instead of once per combination.
# NOTE(review): if the sampling in sample_signals_and_metrics is stochastic,
# all sync/scope variants of one configuration now show the same sample.
flag_combinations = list(product(agg_type_options, subset_first_stage_options,
                                 sync_y_limits_options, sync_scope_options))
total = len(flag_combinations)

loaded_data_key = None
for agg_type, subset_first_stage, sync_y_limits, sync_scope in tqdm(flag_combinations, total=total):
    folder_name = f"plots_sync-{sync_y_limits}_agg-{agg_type}_subset-{subset_first_stage}_scope-{sync_scope}_no_first_ci_3_stages"
    os.makedirs(folder_name, exist_ok=True)

    data_key = (agg_type, subset_first_stage)
    if data_key != loaded_data_key:
        signals_grouped = create_signals_grouped(agg_type, subset_first_stage)
        loaded_data_key = data_key

    agg_global, agg_row = extract_y_bounds(signals_grouped, sync_y_limits, sync_scope)
    limits_global, limits_row = compute_aggregated_limits(agg_global, agg_row, sync_y_limits, sync_scope)
    replot_figures(signals_grouped, limits_global, limits_row, sync_y_limits, sync_scope, folder_name)

  0%|          | 0/8 [00:00<?, ?it/s]

stage1 pre ['69', '71', '73', '75', '77', '85', '79', '87']
Valid genotypes found: {'TH-Cre', 'Wildtype'}
Genotype changes made: 1
stage1 post ['75', '77', '69', '79', '73', '87', '85', '173', '129', '159', '135', '137', '171', '133', '127', '143', '131', '161', '145', '147', '139', '157', '125']
Valid genotypes found: {'TH-Cre', 'Wildtype'}
Genotype changes made: 1
[INFO] stage1: 7 common mouse IDs across phases.
[INFO] stage1 - pre: 7 sessions after filtering.
[INFO] stage1 - post: 7 sessions after filtering.
stage2 pre ['79', '85', '75', '77', '69', '87', '173', '133', '127', '143', '131', '125', '143', '131', '157', '171']
Valid genotypes found: {'TH-Cre', 'Wildtype'}
Genotype changes made: 0
stage2 post ['79', '71', '73', '71', '73', '87', '69', '85', '75', '147', '139', '145', '157', '161', '159', '125', '171', '127', '143', '131', '133', '135', '137', '173']
Valid genotypes found: {'TH-Cre', 'Wildtype'}
Genotype changes made: 2
[INFO] stage2: 13 common mouse IDs across phases.
[

In [None]:
from itertools import product
import os

# Variant of the plotting loop in which y-limits are synchronised within each
# (event, brain region) figure only, not across figures. Figures are shown
# inline; set `save_figures = True` to also create per-combination folders and
# write PNG/PDF files into them.
sync_y_limits_options = ['global', 'row']
agg_type_options = ['mice', 'events']
subset_first_stage_options = [True, False]
save_figures = False

# Validate up front so a bad option fails before any slow session loading and
# before any figure is opened.
if any(option not in ('global', 'row') for option in sync_y_limits_options):
    raise ValueError("sync_y_limits must be either 'global' or 'row'")

# Only (agg_type, subset_first_stage) determine the loaded signals, so they form
# the outer loop. Each data configuration is loaded once and reused for every
# sync_y_limits option.
for agg_type, subset_first_stage in product(agg_type_options, subset_first_stage_options):
    signals_grouped = create_signals_grouped(agg_type, subset_first_stage)

    for sync_y_limits in sync_y_limits_options:
        folder_name = f"plots_sync-{sync_y_limits}_agg-{agg_type}_subset-{subset_first_stage}_no_first_ci_3_stages"
        if save_figures:
            os.makedirs(folder_name, exist_ok=True)

        for (event, brain_reg), stage_phase_signals in signals_grouped.items():
            # 3 rows (stage1-3) x 2 columns (pre/post)
            fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(10, 15), dpi=300, sharex=True, sharey=False)
            fig.suptitle(f'Event: {event} | Brain Region: {brain_reg}', fontsize=16)

            # (data_min, data_max) of every plotted panel, grouped by row.
            row_bounds = {row: [] for row in range(3)}

            for i, stage in enumerate(['stage1', 'stage2', 'stage3']):
                for j, phase in enumerate(['pre', 'post']):
                    ax = axes[i][j]
                    key = (stage, phase)
                    if key not in stage_phase_signals:
                        ax.axis('off')  # Hide subplot if no data
                        ax.set_title(f'{stage}, {phase}\nNo Data', fontsize=10)
                        continue

                    preprocess_and_plot_signals(
                        stage_phase_signals[key],
                        event_type=event,
                        brain_region=brain_reg,
                        smoothing_len=10,
                        title_suffix=f'{stage}, {phase}',
                        ax=ax
                    )

                    # Upper-left panel: bound by the mean trace only. All other
                    # panels: bound by the CI band, using every polygon of it.
                    if i == 0 and j == 0 and ax.lines:
                        y_values = np.asarray(ax.lines[0].get_ydata(), dtype=float)
                    elif ax.collections and len(ax.collections[0].get_paths()) > 0:
                        y_values = np.concatenate([p.vertices[:, 1] for p in ax.collections[0].get_paths()])
                    else:
                        # Fallback in case the CI collection is missing
                        y_values = np.asarray(ax.get_ylim(), dtype=float)
                    y_values = y_values[np.isfinite(y_values)]
                    if y_values.size:
                        row_bounds[i].append((y_values.min(), y_values.max()))

            # Each entry pairs the axes to update with the bounds that drive them.
            if sync_y_limits == 'global':
                sync_groups = [(axes.flat, [b for bounds in row_bounds.values() for b in bounds])]
            else:  # 'row'
                sync_groups = [(axes[row], row_bounds[row]) for row in range(3)]

            for group_axes, bounds in sync_groups:
                if not bounds:
                    continue
                group_min = min(b[0] for b in bounds)
                group_max = max(b[1] for b in bounds)
                margin = (group_max - group_min) * 0.02  # 2% padding on each side
                for ax in group_axes:
                    if ax.has_data():
                        ax.set_ylim(group_min - margin, group_max + margin)

            plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Adjust layout to accommodate the suptitle
            if save_figures:
                plot_filename = os.path.join(folder_name, f'plot_event_{event}_brain_region_{brain_reg}_mice_agg')
                fig.savefig(f'{plot_filename}.png', format='png')
                fig.savefig(f'{plot_filename}.pdf', format='pdf')
            plt.show()
            plt.close(fig)  # Close the figure to free memory

In [None]:
def session_br_event_count(session, br, event):
    """
    Number of rows in the signal matrix stored for ``(br, event)`` on ``session``.

    Looks up ``session.signal_info[(br, event)]``. When the entry is missing or
    falsy (e.g. None or an empty dict) the count is 0. Otherwise it is
    ``entry['signal_matrix'].shape[0]``.
    """
    entry = session.signal_info.get((br, event))
    return entry['signal_matrix'].shape[0] if entry else 0


In [None]:
# Load each (stage, phase) session pickle and print how many sessions remain
# after bad-signal sessions are removed.
# NOTE: `signals_grouped` is reset here to an empty defaultdict(dict) but this
# cell does not fill it, so it overwrites any value left by earlier cells.
# The grouping logic lives in create_signals_grouped.
signals_grouped = defaultdict(dict)

for stage, phase in product(['stage1', 'stage2', 'stage3', 'stage4'], ['pre', 'post']):
    pickle_name = '_'.join(('sessions_two', stage, phase))
    curr_sessions = load_and_prepare_sessions(
        "../../../../Gq-DREADD_CPT_Training_Stages",
        pickle_name=pickle_name,
        load_from_pickle=True,
        remove_bad_signal_sessions=True,
    )
    print(stage, phase, len(curr_sessions))