In [1]:
%load_ext autoreload
%autoreload 2
%aimport - pandas

In [None]:
import sys
sys.path.append('../../')
sys.path.append('../../../scripts')

from main import load_and_prepare_sessions
from processing.session_sampling import MiceAnalysis
from processing.plotting_setup import PlottingSetup
from analysis.timepoint_analysis import find_drug_split_x
from itertools import product

import numpy as np
import scipy.stats as stats 
import matplotlib.pyplot as plt

import config

from collections import defaultdict
import pickle



# Load the trial-level sessions from the pickle cache (remove_bad_signal_sessions=True).
# NOTE(review): `sessions` is reassigned again by the per-stage loading cell below.
sessions = load_and_prepare_sessions(
    "../../../trial_Gq-DREADD_CPT",
    load_from_pickle=True,
    remove_bad_signal_sessions=True,
)

# Optional plotting setup (currently disabled):
# plotting_setup = PlottingSetup(**config.PLOTTING_CONFIG['cpt'])
# plotting_setup.apply_plotting_setup_to_sessions(sessions)

In [3]:
# Plot colour for each brain region (looked up by preprocess_and_plot_signals).
# Alternative region -> colour mapping, currently unused:
#   {'VS': 'purple', 'DMS': 'forestgreen', 'DLS': 'C0'}
brain_reg_to_color = dict(LH='orange', mPFC='cornflowerblue')

In [4]:
def update_genotypes(sessions, mice_gen_dict):
    """
    Overwrite ``session.genotype`` in place using a per-mouse genotype table.

    The short codes in ``mice_gen_dict`` ("Cre" / "WT") are translated to the
    full labels ("TH-Cre" / "Wildtype") before being applied. A session is
    matched on ``int(session.mouse_id)``; sessions whose mouse is not in the
    table keep their current genotype.

    Afterwards two lines are printed: whether every genotype seen is one of
    'TH-Cre' / 'Wildtype' (with the set of genotypes), and how many sessions
    actually had their genotype changed.

    Parameters:
    - sessions: iterable of session objects with ``mouse_id`` and ``genotype``.
    - mice_gen_dict: dict of int mouse ID -> "Cre" or "WT" (any other code
      raises KeyError).
    """
    short_to_full = {"Cre": "TH-Cre", "WT": "Wildtype"}
    override_by_mouse = {
        mouse: short_to_full[code] for mouse, code in mice_gen_dict.items()
    }

    allowed = {'TH-Cre', 'Wildtype'}
    seen = set()
    n_changed = 0

    for sess in sessions:
        before = sess.genotype
        mouse_key = int(sess.mouse_id)

        if mouse_key in override_by_mouse:
            sess.genotype = override_by_mouse[mouse_key]
            if sess.genotype != before:
                n_changed += 1

        seen.add(sess.genotype)

    status = "Valid" if seen <= allowed else "Invalid"
    print(f"{status} genotypes found: {seen}")
    print(f"Genotype changes made: {n_changed}")

In [5]:
# Mouse ID (int) -> genotype code. update_genotypes maps "Cre" -> 'TH-Cre' and
# "WT" -> 'Wildtype', and matches keys against int(session.mouse_id).
mice_gen_dict = {
    69: "Cre",  71: "WT",   73: "Cre",  75: "WT",
    77: "Cre",  79: "WT",   85: "WT",   87: "WT",
    135: "WT",  137: "WT",  139: "Cre", 133: "WT",
    127: "WT",  125: "WT",  129: "Cre", 131: "WT",
    143: "Cre", 145: "WT",  147: "WT",  157: "Cre",
    159: "Cre", 161: "WT",  171: "Cre", 173: "Cre",
}

In [6]:
# Apply the genotype overrides to the sessions loaded above (modifies them in place).
update_genotypes(sessions=sessions, mice_gen_dict=mice_gen_dict)

Valid genotypes found: {'Wildtype', 'TH-Cre'}
Genotype changes made: 2


In [10]:
from collections import defaultdict

# all_sessions[stage][phase] -> list of sessions for that stage/phase, restricted to
# mice that appear in BOTH the 'pre' and the 'post' phase of the same stage.
all_sessions = defaultdict(dict)

stages = ['stage1', 'stage2', 'stage3', 'stage4']
phases = ['pre', 'post']

for stage in stages:
    # Step 1: load every phase of this stage and apply the genotype overrides in place.
    phase_sessions = {}
    for phase in phases:
        pickle_name = f'sessions_two_{stage}_{phase}'
        sessions = load_and_prepare_sessions(
            "../../../Gq-DREADD_CPT_Training_Stages",
            load_from_pickle=True,
            remove_bad_signal_sessions=True,
            pickle_name=pickle_name,
        )
        update_genotypes(sessions, mice_gen_dict)
        phase_sessions[phase] = sessions

    # Step 2: partition mice by the phases they were recorded in.
    mouse_ids_pre  = {s.mouse_id for s in phase_sessions['pre']}
    mouse_ids_post = {s.mouse_id for s in phase_sessions['post']}
    intersection_mouse_ids = mouse_ids_pre & mouse_ids_post
    pre_only_ids  = mouse_ids_pre - mouse_ids_post
    post_only_ids = mouse_ids_post - mouse_ids_pre

    sessions_both      = [s for s in phase_sessions['pre'] + phase_sessions['post']
                          if s.mouse_id in intersection_mouse_ids]
    sessions_pre_only  = [s for s in phase_sessions['pre'] if s.mouse_id in pre_only_ids]
    sessions_post_only = [s for s in phase_sessions['post'] if s.mouse_id in post_only_ids]

    # Report the (trial, mouse) pairs in each group.
    for group_label, group in (('both phases', sessions_both),
                               ('pre only phases', sessions_pre_only),
                               ('post only phases', sessions_post_only)):
        for s in group:
            print(f'{stage}, {group_label}:', s.trial_id, s.mouse_id)

    # Step 3: keep only sessions from mice present in both phases.
    for phase in phases:
        sessions = phase_sessions[phase]
        filtered_sessions = [s for s in sessions if s.mouse_id in intersection_mouse_ids]
        all_sessions[stage][phase] = filtered_sessions

Valid genotypes found: {'Wildtype', 'TH-Cre'}
Genotype changes made: 1
Valid genotypes found: {'Wildtype', 'TH-Cre'}
Genotype changes made: 1
stage1, both phases: T1_69.71.73.75 69
stage1, both phases: T1_69.71.73.75 73
stage1, both phases: T1_69.71.73.75 75
stage1, both phases: T2_77.85.79.87 77
stage1, both phases: T2_77.85.79.87 85
stage1, both phases: T2_77.85.79.87 79
stage1, both phases: T2_77.85.79.87 87
stage1, both phases: T3_e.e.e.75 75
stage1, both phases: T4_77.e.e.e 77
stage1, both phases: T5_69.e.e.e 69
stage1, both phases: T6_79.e.e.e 79
stage1, both phases: T7_e.e.73.e 73
stage1, both phases: T8_e.e.e.87 87
stage1, both phases: T9_e.e.e.85 85
stage1, pre only phases: T1_69.71.73.75 71
stage1, post only phases: T10_173.e.e.e 173
stage1, post only phases: T11_e.e.129.e 129
stage1, post only phases: T12_159.e.e.e 159
stage1, post only phases: T13_135.e.e.137 135
stage1, post only phases: T13_135.e.e.137 137
stage1, post only phases: T14_171.e.e.e 171
stage1, post only phas

In [12]:
stages = [f'stage{n}' for n in range(1, 5)]   # 'stage1' .. 'stage4'
phases = ['pre', 'post']

# for stage in stages:
#     for phase in phases:
#         sessions = all_sessions[stage][phase]
#         for session in sessions:
#             for brain_region in session.brain_regions:
#                 threshold_x = find_drug_split_x(session, brain_region)
#                 brain_region = brain_region.split('_')[0]
#                 for event in config.all_event_types:
#                     curr_signal_info = session.signal_info.get((brain_region, event))
#                     if curr_signal_info is None:
#                         continue
#                     signal_idx_ranges = curr_signal_info['signal_idx_ranges']

#                     # find the middle position (0 seconds) of the given signal
#                     thresholds = np.array([((tup[0] + tup[1]) // 2) for tup in signal_idx_ranges])

#                     # if the 0 seconds position is later than threshold x, we know that the signal is past the injection peak
#                     curr_signal_info["is_above_threshold"] = thresholds > threshold_x

In [13]:
# Genotype of each session kept for stage 4, 'pre' phase (one per line).
for session in all_sessions['stage4']['pre']:
    genotype_label = session.genotype
    print(genotype_label)

Wildtype
TH-Cre
TH-Cre
Wildtype
TH-Cre
Wildtype


In [14]:
# Genotype of each session kept for stage 4, 'post' phase (one per line).
stage4_post_sessions = all_sessions['stage4']['post']
for sess in stage4_post_sessions:
    print(sess.genotype)

Wildtype
TH-Cre
TH-Cre
Wildtype
TH-Cre
Wildtype


In [15]:
# Pool each session's signal matrix by (stage, phase, event, brain_region).
# Genotype is deliberately NOT part of the key: preprocess_and_plot_signals and the
# stage-4 plotting loop both unpack 4-tuples.
all_signal_groups = defaultdict(list)

for stage in stages:
    for phase in phases:
        sessions = all_sessions[stage][phase]
        for session in sessions:
            for event in config.all_event_types:
                for brain_region in config.all_brain_regions:
                    curr_signal_info = session.signal_info.get((brain_region, event))
                    if curr_signal_info is None:
                        continue
                    curr_key = (stage, phase, event, brain_region)
                    all_signal_groups[curr_key].append(curr_signal_info['signal_matrix'])

# Stack each group's matrices row-wise once all sessions have been collected.
# NOTE(review): assumes every signal_matrix in a group has the same number of columns.
all_signal_groups = {k: np.vstack(v) for k, v in all_signal_groups.items() if v}
# Drop groups whose stacked matrix has no rows.
all_signals = {k: v for k, v in all_signal_groups.items() if len(v) != 0}


In [16]:
def preprocess_and_plot_signals(key, dict, smoothing_len=10):
    """Plot the smoothed mean trace and a 95% CI band for one pooled signal group.

    Parameters
    ----------
    key : tuple
        ``(stage, phase, event, brain_region)``. Used to look up the signals
        and to label the figure. ``brain_region`` must be a key of the global
        ``brain_reg_to_color``.
    dict : mapping
        Maps such keys to a stacked signal matrix. NOTE(review): rows are
        assumed to be trials and columns time samples, based on the
        ``axis=0`` reductions. The parameter name shadows the builtin but is
        kept for backward compatibility.
    smoothing_len : int, default 10
        Moving-average window (in samples) applied to the mean trace. It is
        clamped to ``[1, n_samples]``.

    Returns
    -------
    None. Prints a warning and returns early when fewer than two trials are
    available or when the CI bounds are not finite / degenerate. The figure
    is built and then closed; saving and showing are currently disabled.
    """
    signals = np.asarray(dict[key])
    stage, phase, event, brain_region = key

    interval_start = config.peak_interval_config["interval_start"]
    interval_end = config.peak_interval_config["interval_end"]
    fps = config.PLOTTING_CONFIG['cpt']['fps']
    xs = np.arange(-interval_start, interval_end) / fps

    n_trials = signals.shape[0]
    if n_trials < 2:
        print(f"[WARNING] Need at least 2 trials for a CI (got {n_trials}) for key: {key}. Skipping plot.")
        return

    # Moving average of the mean trace. np.convolve(..., 'same') zero-pads, so dividing
    # by a plain 1/L kernel drags the first/last ~L/2 points toward 0. Dividing by the
    # number of in-range samples under the window removes that edge bias; interior points
    # are unchanged. The window is clamped so 'same' keeps the trace length.
    mean_trace = np.mean(signals, axis=0)
    width = max(1, min(int(smoothing_len), mean_trace.size))
    kernel = np.ones(width)
    ys = (np.convolve(mean_trace, kernel, 'same')
          / np.convolve(np.ones_like(mean_trace), kernel, 'same'))

    # Standard error of the mean, using the sample standard deviation (ddof=1).
    sem = np.std(signals, axis=0, ddof=1) / np.sqrt(n_trials)

    # Normal-approximation 95% confidence band around the smoothed mean.
    ci_level = 0.95
    ci_lower, ci_upper = stats.norm.interval(ci_level, loc=ys, scale=sem)

    # Sanity check on the band. A zero SEM makes norm.interval return NaN.
    lb = ci_lower.min()
    ub = ci_upper.max()
    if not np.isfinite(lb) or not np.isfinite(ub) or lb == ub:
        print(f"[WARNING] Invalid ylim ({lb}, {ub}) for key: {key}. Skipping plot.")
        return

    color = brain_reg_to_color[brain_region]

    fig, ax = plt.subplots(dpi=300)
    ax.plot(xs, ys, color=color, label='Mean Signal')
    ax.fill_between(xs, ci_lower, ci_upper, color=color, alpha=0.2, label='95% CI')

    # Fixed y-range. The data-driven bounds above are only used as a validity check.
    ax.set_ylim(-0.5, 0.5)

    # Explicit "\n" joins, so the title lines are not prefixed with source-code indentation.
    ax.set_title(f"stage: {stage}, phase: {phase},\n"
                 f"event: {event}, brain region: {brain_region}\n"
                 f"(n={n_trials})")
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('z-score')
    ax.legend()
    ax.grid()
    fig.tight_layout()

    # Saving/showing is disabled. Re-enable as needed, e.g.:
    # fig.savefig(f"{stage}_{phase}_{event}_{brain_region}_stages.pdf")
    # fig.savefig(f"{stage}_{phase}_{event}_{brain_region}_stages.png")
    # plt.show()
    plt.close(fig)

# Draw one averaged-trace figure per (phase, event, brain_region) group of stage 4.
for signal_key in all_signals:
    key_stage, _, _, _ = signal_key
    if key_stage != 'stage4':
        continue
    preprocess_and_plot_signals(signal_key, all_signals, smoothing_len=10)

In [17]:
# Plot colour per genotype (extend with further genotypes if needed).
genotype_colors = dict([
    ('Wildtype', 'blue'),
    ('TH-Cre', 'red'),
])

In [18]:
# # Necessary Imports (ensure these are run)
# import matplotlib.pyplot as plt
# import numpy as np
# import config # Ensure config is imported and accessible
# # import scipy.stats as stats # Optional, needed only if using stats.norm.interval

# # --- Assumed to be defined in previous cells ---
# # all_sessions: The defaultdict(dict) from your loading code.
# # genotype_colors = {'Wildtype': 'blue', 'TH-Cre': 'red'}
# # config: Your configuration module/object.
# # -------------------------------------------------

# # --- Helper Function to Plot a Single Session ---
# def plot_single_session_trace(session, target_event, target_region, xs, smoothing_len=10, genotype_colors=None):
#     """
#     Calculates and plots the average trace + CI for a SINGLE session
#     in its own figure. Includes Mouse ID and Genotype in title.

#     Args:
#         session (object): The session object containing signal_info, genotype, mouse_id etc.
#         target_event (str): The event type to plot.
#         target_region (str): The brain region to plot.
#         xs (np.ndarray): The time axis array.
#         smoothing_len (int): Length of the moving average window.
#         genotype_colors (dict, optional): Mapping from genotype to color.

#     Returns:
#         bool: True if a plot was generated, False otherwise.
#     """
#     # --- Get Session Info ---
#     genotype = getattr(session, 'genotype', 'Unknown')
#     mouse_id = getattr(session, 'mouse_id', 'UnknownMouse')
#     # Create a unique session identifier for titles if session_id isn't directly available
#     session_identifier = getattr(session, 'session_id', f'Mouse_{mouse_id}')
#     color = genotype_colors.get(genotype, 'gray') if genotype_colors else 'gray'

#     # --- Get Signal Info ---
#     # !!! IMPORTANT: Adjust key if channel ('G') is needed: signal_key = (target_region, 'G', target_event)
#     signal_key = (target_region, target_event)
#     curr_signal_info = session.signal_info.get(signal_key)
#     if curr_signal_info is None: return False # Skip if no signal info

#     signal_matrix = curr_signal_info.get('signal_matrix')
#     if signal_matrix is None or not isinstance(signal_matrix, np.ndarray) or signal_matrix.size == 0:
#         return False # Skip if no valid matrix

#     n_trials = signal_matrix.shape[0]
#     if n_trials == 0: return False

#     # --- Calculate Mean ---
#     ys_mean = np.mean(signal_matrix, axis=0)
#     if len(ys_mean) != len(xs):
#          print(f"WARNING: Session {session_identifier} length mismatch. Skipping plot.")
#          return False

#     # --- Calculate CI (approximate using SEM * 1.96) ---
#     ci_lower, ci_upper = ys_mean, ys_mean # Default if n_trials < 2 or SEM is zero
#     ci_calculated = False
#     if n_trials >= 2:
#         # Use ddof=1 for sample standard deviation in SEM calculation
#         ys_sem = np.std(signal_matrix, axis=0, ddof=1) / np.sqrt(n_trials)
#         # Ensure SEM is finite (replace NaN/inf with 0)
#         ys_sem = np.nan_to_num(ys_sem, nan=0.0, posinf=0.0, neginf=0.0)
#         # Check if the CI width will be non-negligible
#         if not np.all(np.isclose(ys_sem * 1.96, 0, atol=1e-8)):
#             ci_lower = ys_mean - 1.96 * ys_sem
#             ci_upper = ys_mean + 1.96 * ys_sem
#             ci_calculated = True
#         # else: CI remains zero width if SEM is zero

#     # --- Smoothing (apply to mean for line plot) ---
#     window = np.ones(smoothing_len) / smoothing_len
#     ys_smooth = ys_mean
#     if len(ys_mean) >= smoothing_len:
#         ys_smooth = np.convolve(ys_mean, window, 'same')

#     # --- Plotting ---
#     plt.figure(dpi=150, figsize=(8, 5)) # New figure for each session

#     # Plot CI region first
#     if ci_calculated:
#         plt.fill_between(xs, ci_lower, ci_upper, color=color, alpha=0.2, linewidth=0, label='95% CI (approx.)')

#     # Plot smoothed mean line
#     plt.plot(xs, ys_smooth, color=color, linewidth=1.2, label=f'Mean (n={n_trials} trials)')

#     # Create title including Mouse ID and Genotype
#     title = f"Mouse: {mouse_id} ({genotype}) - Session: {session_identifier}\n" \
#             f"Stage3 / Post - Event: {target_event}, Region: {target_region}"
#     plt.title(title)
#     plt.xlabel('Time relative to event (s)')
#     plt.ylabel('z-score')
#     plt.ylim(-0.8, 0.8)
#     plt.axhline(0, color='grey', linestyle='--', linewidth=0.7)
#     plt.axvline(0, color='grey', linestyle='--', linewidth=0.7)
#     plt.grid(True, which='both', linestyle=':', linewidth=0.5)
#     plt.legend(loc='upper right')
#     plt.tight_layout()

#     plt.savefig(f"single_mice_stage3_post_{mouse_id}_{genotype}_sim_yscale.pdf")
#     plt.savefig(f"single_mice_stage3_post_{mouse_id}_{genotype}_sim_yscale.png")
#     plt.show() # Show plot immediately

#     return True # Indicate plot was generated


# # --- Main Plotting Execution ---

# # Define target parameters for the plots
# plot_target_event = 'hit'     # !!! REPLACE with your actual event key if different
# plot_target_region = 'mPFC'   # !!! REPLACE with your actual region key if different
# smoothing_len_single = 10     # Adjust smoothing if desired

# print(f"\n--- Generating Individual Plot Per Session for Stage3 / Post ---")
# print(f"Target Event: '{plot_target_event}', Target Region: '{plot_target_region}'")

# # --- Calculate time axis (xs) once ---
# try:
#     interval_start = config.peak_interval_config["interval_start"]
#     interval_end = config.peak_interval_config["interval_end"]
#     fps = config.PLOTTING_CONFIG['cpt']['fps']
#     xs_single = np.arange(-interval_start, interval_end) / fps
#     print(f"Time axis generated: {-interval_start/fps:.2f}s to {interval_end/fps:.2f}s ({len(xs_single)} points)")
# except Exception as e:
#     print(f"ERROR: Could not get plotting parameters from config: {e}. Cannot generate plots.")
#     # Consider adding 'raise' here if you want the script to stop on config error
#     xs_single = None # Ensure xs_single is None if calculation fails

# # --- Check if necessary data exists ---
# if xs_single is not None: # Proceed only if time axis was generated
#     target_stage = 'stage3'
#     target_phase = 'post'

#     if target_stage not in all_sessions:
#         print(f"ERROR: Stage '{target_stage}' not found in all_sessions dictionary.")
#     elif target_phase not in all_sessions[target_stage]:
#         print(f"ERROR: Phase '{target_phase}' not found for Stage '{target_stage}'.")
#     else:
#         sessions_to_plot = all_sessions[target_stage][target_phase]
#         if not sessions_to_plot:
#             print(f"INFO: No sessions found in all_sessions['{target_stage}']['{target_phase}'].")
#         else:
#             print(f"Found {len(sessions_to_plot)} sessions for {target_stage} / {target_phase}. Generating plots...")
#             plots_generated_count = 0
#             # --- Loop through sessions and plot individually ---
#             for session_obj in sessions_to_plot:
#                 # Call the helper function for each session
#                 plotted = plot_single_session_trace(
#                     session=session_obj,
#                     target_event=plot_target_event,
#                     target_region=plot_target_region,
#                     xs=xs_single, # Pass the pre-calculated time axis
#                     smoothing_len=smoothing_len_single,
#                     genotype_colors=genotype_colors # Pass the color dict
#                 )
#                 if plotted:
#                     plots_generated_count += 1
#             print(f"\nGenerated {plots_generated_count} individual session plots.")
# else:
#     print("\nPlot generation skipped due to configuration error.")

In [19]:
# --- Individual session traces per genotype ---
# For one stage/phase/event/region, draw one figure per genotype that overlays the
# smoothed trial-mean trace of every matching session.

# Parameters: adjust as needed
target_stage   = 'stage3'
target_phase   = 'post'
target_event   = 'before_dispimg_hit'   # ← change to your event
target_region  = 'mPFC'                 # ← change to your region
# NOTE(review): currently unused -- session.signal_info is looked up by (region, event) only.
target_channel = 'G'
trace_smoothing_len = 10                # moving-average window, in samples

# Colours come from the `genotype_colors` mapping of the earlier colour cell. It is not
# redefined here: the previous redefinition swapped TH-Cre/Wildtype colours between figures.

# Build the time axis (xs)
interval_start = config.peak_interval_config["interval_start"]
interval_end   = config.peak_interval_config["interval_end"]
fps            = config.PLOTTING_CONFIG['cpt']['fps']
xs = np.arange(-interval_start, interval_end) / fps

# Sessions for the chosen stage & phase (already restricted to mice present in both phases)
sessions_to_plot = all_sessions[target_stage][target_phase]

for genotype in ['TH-Cre', 'Wildtype']:
    fig, ax = plt.subplots(dpi=150, figsize=(8, 5))
    color = genotype_colors[genotype]

    for session in sessions_to_plot:
        if session.genotype != genotype:
            continue

        sig = session.signal_info.get((target_region, target_event))
        if sig is None:
            continue

        mat = sig['signal_matrix']
        if mat.size == 0:
            continue

        # Per-session trial mean, then an edge-normalised moving average. Dividing by the
        # number of in-range samples stops np.convolve's zero padding from pulling the
        # first/last points toward 0. The window is clamped to the trace length.
        ys = np.mean(mat, axis=0)
        kernel = np.ones(max(1, min(trace_smoothing_len, ys.size)))
        ys = (np.convolve(ys, kernel, mode='same')
              / np.convolve(np.ones_like(ys), kernel, mode='same'))

        ax.plot(xs, ys, color=color, alpha=0.3)

    ax.set_title(f"{genotype} — individual session traces\n"
                 f"Stage: {target_stage}, Phase: {target_phase}, "
                 f"Event: {target_event}, Region: {target_region}")
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('z-score')   # plain ASCII hyphen (was U+2011 non-breaking hyphen)
    ax.axhline(0, color='grey', linestyle='--', linewidth=0.7)
    ax.axvline(0, color='grey', linestyle='--', linewidth=0.7)
    ax.grid(True, linestyle=':', linewidth=0.5)
    ax.set_ylim(-1, 1)
    fig.tight_layout()

#plt.show()
plt.close('all')
