In [10]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import h5py
import re

from collections import defaultdict
from scipy.stats import ttest_ind

FPS = 60

## Load data

In [11]:
# Path to the .h5 file
file_path = '../syllable_analysis_julia_11_vids/results.h5'

# Initialize a dictionary to hold dfs for each Mouse ID
include_latent_state = True

# Using defaultdict in case there are multiple groups per Mouse ID
dfs = defaultdict(list)

# Regular expression pattern to extract Mouse ID
# Assumes Mouse ID is the number before 'DLC' in the group name
mouse_id_pattern = re.compile(r'_(\d+)DLC_')

with h5py.File(file_path, 'r') as file:
    for group_name in file.keys():
        # Extract Mouse ID using regex
        match = mouse_id_pattern.search(group_name)
        if match:
            mouse_id = match.group(1)
        else:
            # Handle cases where Mouse ID is not found
            print(f"Warning: Mouse ID not found in group name '{group_name}'. Skipping this group.")
            continue  # Skip this group
        
        # Check if Mouse ID already exists in the dictionary (this would indicate an error)
        if mouse_id in dfs:
            raise ValueError(f"Error: Multiple groups found for Mouse ID '{mouse_id}' in the file. Only one group per Mouse ID is expected.")
        
        group = file[group_name]
        
        # List to hold individual DataFrames for each dataset
        df_list = []
        
        for dataset_name in group.keys():
            # Exclude 'latent_state' datasets if the flag is False
            if not include_latent_state and dataset_name.startswith('latent_state'):
                continue
            
            dataset = group[dataset_name][:]
            
            # Check if dataset is at least 1D
            if dataset.ndim == 1:
                # Convert to 2D array with one column
                dataset = dataset.reshape(-1, 1)
            
            # Create column names by appending index to dataset name
            # Example: 'centroid_0', 'centroid_1', ...
            num_cols = dataset.shape[1]
            columns = [f"{dataset_name}_{i}" for i in range(num_cols)]
            
            # Convert dataset to DataFrame
            df = pd.DataFrame(dataset, columns=columns)
            df_list.append(df)
        
        if df_list:
            # Concatenate all DataFrames horizontally (axis=1)
            concatenated_df = pd.concat(df_list, axis=1)
            # Add to the dictionary with the Mouse ID as the key
            dfs[mouse_id] = concatenated_df
        else:
            print(f"Warning: No datasets found in group '{group_name}'.")

## Prepare data

In [12]:
for df in dfs.values():
    df['timestamp'] = df.index / FPS
    df['current_minute'] = df['timestamp'] / 60

In [13]:
syllable_info = pd.read_csv('../syllable_analysis_julia_11_vids/syllable,behavior_group_dim4.csv', index_col=0)
#pd.read_csv('../syllable_analysis_julia/syllable,behavior.csv', index_col=0)
#syllable_info.loc[41] = 'faulty'

syllable_map = syllable_info.to_dict()['behavior']

mouse_info = pd.read_csv('../syllable_analysis_julia_11_vids/syllables_mouseCoh1_info.csv', index_col=0)

In [14]:
for df in dfs.values():
    df['syllable_name'] = df['syllable_0'].map(syllable_map)
    df['syllable_name'] = df['syllable_name'].fillna('misc')  # Replace NaN with 'faulty'

## Split by injection

In [15]:
import numpy as np
import pandas as pd
from collections import defaultdict

def bin_df_by_injection_sliding_window_signed(df, mouse_id, window_size, stride):
    """
    Bins the dataframe using a sliding window approach separately for pre-injection and post-injection periods,
    assigning negative bin numbers to pre-injection windows and positive to post-injection windows.

    Parameters:
    - df (pd.DataFrame): The dataframe containing the data.
    - mouse_id (int or str): The ID of the mouse.
    - window_size (int): Size of the sliding window in minutes.
    - stride (int): Stride length for the sliding window in minutes.

    Returns:
    - dict: A dictionary with bin labels as keys and binned dataframes as values.
    """
    # Retrieve injection times from mouse_info DataFrame
    injection_times = mouse_info.loc[mouse_id][['PreInjStart', 'PreInjEnd', 'PostInjStart', 'PostInjEnd']].to_list()
    pre_inj_start, pre_inj_end, post_inj_start, post_inj_end = injection_times

    # Initialize the dictionary to store binned dataframes
    binned_dfs = {}

    # Initialize bin counters
    bin_counter_pre = -1  # Start from -1 for pre-injection bins
    bin_counter_post = 1  # Start from 1 for post-injection bins

    ### Binning Pre-Injection Data ###
    current_start = pre_inj_start
    current_end = current_start + window_size

    while current_end <= pre_inj_end:
        # Define the window
        window_df = df[(df['current_minute'] >= current_start) & (df['current_minute'] < current_end)]
        bin_label = f'bin_{bin_counter_pre}'
        binned_dfs[bin_label] = window_df

        # Update counters and window positions
        bin_counter_pre -= 1
        current_start += stride
        current_end = current_start + window_size

    ### Binning Post-Injection Data ###
    current_start = post_inj_start
    current_end = current_start + window_size

    while current_end <= post_inj_end:
        # Define the window
        window_df = df[(df['current_minute'] >= current_start) & (df['current_minute'] < current_end)]
        bin_label = f'bin_{bin_counter_post}'
        binned_dfs[bin_label] = window_df

        # Update counters and window positions
        bin_counter_post += 1
        current_start += stride
        current_end = current_start + window_size

    return binned_dfs


In [16]:
# Initialize a defaultdict to store binned dataframes
all_centroids = {}
window_size = 5
stride = 2

for mouse_id, df in dfs.items():
    # Apply the binning function
    binned_dfs = bin_df_by_injection_sliding_window_signed(df, int(mouse_id), window_size, stride)
    
    # Get genotype information
    genotype = mouse_info.loc[int(mouse_id)]['Genotype']  # Adjust column name as needed
    
    for bin_label, bin_df in binned_dfs.items():
        bin_number = int(bin_label.split('_')[1])
        
        # Select columns that start with 'latent'
        latent_cols = [col for col in bin_df.columns if col.startswith('latent')]
        
        # Check if there are latent columns
        if latent_cols:
            # Optionally sort the columns to maintain consistent order
            latent_cols.sort()
            # Compute the mean of latent columns
            latent_mean = bin_df[latent_cols].mean().values
            #all_centroids[(bin_number, genotype)].append(latent_mean)
            all_centroids[(bin_number, genotype)] = latent_mean
        else:
            # Handle bins without latent columns if necessary
            print(f"No latent columns in bin {bin_label} for mouse {mouse_id}")
            continue  # Skip this bin

all_centroids = {k: np.vstack(v) for k, v in all_centroids.items()}

In [17]:
from sklearn.decomposition import PCA  # Import PCA

# Initialize a defaultdict to store 2D binned dataframes
all_centroids_2d = {}

# Extract all centroid vectors and corresponding keys
keys = list(all_centroids.keys())
centroids_4d = np.array([all_centroids[key] for key in keys])

# Apply PCA to reduce dimensions from 4D to 2D
pca = PCA(n_components=2)
centroids_2d = pca.fit_transform(centroids_4d)

# Populate the new 2D centroids dictionary
for i, key in enumerate(keys):
    all_centroids_2d[key] = centroids_2d[i]

ValueError: Found array with dim 3. PCA expected <= 2.

In [70]:
# Define marker styles for genotypes
genotype_markers = {'WT': 'o', 'Het': 's'}  # Customize as needed

# Organize data by genotype
data_by_genotype = defaultdict(list)
bin_numbers_by_genotype = defaultdict(list)

for (bin_number, genotype), coords in all_centroids_2d.items():
    data_by_genotype[genotype].append(coords)
    bin_numbers_by_genotype[genotype].append(bin_number)

In [95]:
from matplotlib.lines import Line2D
from matplotlib.colors import Normalize, TwoSlopeNorm

# Step 4: Define the Plotting Function with TwoSlopeNorm
def plot_smoothed_trajectories_from_centroids(
    data_by_genotype,
    bin_numbers_by_genotype,
    genotype_markers,
    smoothing=True,
    window_size=3
):
    """
    Plots smoothed trajectories for each genotype using a rolling mean.
    
    Parameters:
    - data_by_genotype (dict): Dictionary mapping genotype to list of 2D centroid coordinates.
    - bin_numbers_by_genotype (dict): Dictionary mapping genotype to list of bin numbers.
    - genotype_markers (dict): Dictionary mapping genotypes to marker styles.
    - smoothing (bool): Whether to apply rolling mean smoothing.
    - window_size (int): The window size for the rolling mean.
    """
    # Step 4.1: Determine the Maximum Absolute Bin Number
    max_abs_bin = 0
    for genotype, bin_numbers in bin_numbers_by_genotype.items():
        if bin_numbers:
            min_bin = min(bin_numbers)
            max_bin = max(bin_numbers)
            max_abs_bin = max(max_abs_bin, abs(min_bin), abs(max_bin))
    
    # Ensure that max_abs_bin is at least 1 to avoid division by zero
    max_abs_bin = max(max_abs_bin, 1)
    
    # Step 4.2: Set Up the Colormap and Normalization
    cmap = plt.get_cmap('seismic')  # Using a diverging colormap
    norm = TwoSlopeNorm(vmin=-max_abs_bin, vcenter=-window_size/2, vmax=max_abs_bin)
    sm_color = cm.ScalarMappable(cmap=cmap, norm=norm)
    sm_color.set_array([])
    
    fig, ax = plt.subplots(figsize=(12, 8))
    
    for genotype in data_by_genotype:
        coords_list = data_by_genotype[genotype]
        bin_numbers = bin_numbers_by_genotype[genotype]
        
        # Sort the data by bin_number to maintain chronological order
        sorted_indices = np.argsort(bin_numbers)
        sorted_bin_numbers = np.array(bin_numbers)[sorted_indices]
        sorted_coords = np.array(coords_list)[sorted_indices]
        
        x = sorted_coords[:, 0]
        y = sorted_coords[:, 1]
        
        # Apply rolling mean smoothing if required
        if smoothing and len(x) >= window_size:
            df_coords = pd.DataFrame({'x': x, 'y': y})
            df_coords['x_smooth'] = df_coords['x'].rolling(
                window=window_size, center=True, min_periods=1
            ).mean()
            df_coords['y_smooth'] = df_coords['y'].rolling(
                window=window_size, center=True, min_periods=1
            ).mean()
            x_smooth = df_coords['x_smooth'].values
            y_smooth = df_coords['y_smooth'].values
        else:
            x_smooth, y_smooth = x, y
        
        # Plot trajectory lines with color gradient based on bin numbers
        for i in range(len(x_smooth) - 1):
            ax.plot(
                x_smooth[i:i+2],
                y_smooth[i:i+2],
                color=cmap(norm(sorted_bin_numbers[i])),
                linewidth=2
            )
        
        # Plot markers with colors based on bin numbers
        scatter = ax.scatter(
            x_smooth,
            y_smooth,
            c=sorted_bin_numbers,
            cmap=cmap,
            norm=norm,
            marker=genotype_markers.get(genotype, 'o'),
            s=100,
            edgecolors='k',
            linewidths=0.5,
            alpha=0.9,
            label=genotype
        )
    
    # Step 4.3: Create Custom Legend for Genotypes
    legend_elements = [
        Line2D(
            [0], [0],
            marker=marker,
            color='w',
            label=genotype,
            markerfacecolor='black',
            markersize=12
        )
        for genotype, marker in genotype_markers.items()
    ]
    ax.legend(handles=legend_elements, title='Genotypes', loc='best')
    
    # Step 4.4: Add Colorbar
    cbar = plt.colorbar(sm_color, ax=ax)
    cbar.set_label('Bin Number', fontsize=12)
    
    # Step 4.5: Label Axes and Set Title
    plt.xlabel('Principal Component 1', fontsize=14)
    plt.ylabel('Principal Component 2', fontsize=14)
    plt.title(
        'Smoothed Trajectories from Centroids with Rolling Mean\n'
        'Colored by Bin Number and Styled by Genotype',
        fontsize=16
    )
    
    plt.tight_layout()
    plt.show()

In [None]:
# Call the plotting function
plot_smoothed_trajectories_from_centroids(
    data_by_genotype=data_by_genotype,
    bin_numbers_by_genotype=bin_numbers_by_genotype,
    genotype_markers=genotype_markers,
    smoothing=True,
    window_size=15  # Adjust window_size as needed
)
