In [1]:
import numpy as np
import pandas as pd
from statsmodels.stats.multitest import multipletests
from importlib import reload
import pickle
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns
import states_analysis_tools as sat
reload(sat)
import h5py
import acme
from acme import ParallelMap
import superlet as slt

# Read in the raw data and set general variables.
# `project_folder` is the folder the raw per-species sessions are read from.
project_folder='attention_paper_dataset'
species_data = sat.read_dfs(folder=project_folder)

# `project` is the folder name handed to sat.save_dfs / sat.read_dfs / sat.save_plot
# in the cells below (preprocessed data and figures).
project='behaviour_paper'

# Colour maps. `colormap_states` (four colours) is used for the combined 'States'
# labels in the figure cells. The two-colour maps are presumably meant for the
# 'States_Speed' and 'States_Accuracy' labels -- TODO confirm, they are not used
# in the cells visible here.
colormap_states = ListedColormap(['#b15c83ff','#519296ff','#43dfcdff','#f394aeff'])
colormap_states_speed = ListedColormap(['#5d5f6cff','#d6d7dcff'])
colormap_states_accuracy = ListedColormap(['#f394aeff','#68d7c9ff'])

# Seaborn "ticks" style without the top and right spines.
sns.set_style("ticks", {
    "axes.spines.top": False,
    "axes.spines.right": False
})

# Update axis-line and tick thickness/length globally via Matplotlib
plt.rcParams.update({
    "axes.linewidth": 1.5,
    "xtick.major.width": 1.2,  # Thickness of major ticks on the x-axis
    "ytick.major.width": 1.2,  # Thickness of major ticks on the y-axis
    "xtick.minor.width": 1.0,  # Thickness of minor ticks on the x-axis
    "ytick.minor.width": 1.0,  # Thickness of minor ticks on the y-axis
    "xtick.major.size": 6,     # Length of major ticks on the x-axis
    "ytick.major.size": 6,     # Length of major ticks on the y-axis
    "xtick.minor.size": 4,     # Length of minor ticks on the x-axis
    "ytick.minor.size": 4,     # Length of minor ticks on the y-axis
})

In [None]:
# DATA PREPROCESSING
# Re-read the raw sessions, keep only the columns needed downstream, drop
# incomplete and outlier trials, and save the cleaned sessions per species
# under `project`.
species_data = sat.read_dfs(folder=project_folder)
selected_metrics = ['Session', 'TrialDuration', 'MorphTarget', 'MorphDistractor', 'Correct', 'Wrong', 'Prec', 'Bias', 'RT', 'SpeedMean', 'PL']

for species in species_data:
    sessions = species_data[species].copy()
    sessions_list = []

    for isesh, sesh in enumerate(sessions):
        # Initial setup: restrict to the selected columns and drop rows with any NaN
        sesh = sesh.loc[:, selected_metrics].copy()
        sesh = sesh.dropna().reset_index(drop=True)
        
        # Remove learning period for humans (cutoff index comes from the sat helper)
        if species == 'human':
            cutoff = sat.find_combined_learning_cutoff(sesh, window=10, threshold=0.5, consecutive=10)
            sesh = sesh.iloc[cutoff:].reset_index(drop=True)
        
        # Remove outlier trials based on duration and path length:
        # keep trials with TrialDuration <= 10 and PL >= 250
        mask = (sesh['TrialDuration'] <= 10) & (sesh['PL'] >= 250)
        sesh = sesh[mask].reset_index(drop=True)
        
        # Flip the sign of RT and Bias
        # NOTE(review): presumably so that larger values mean "better/faster" for
        # every metric -- confirm against the HMM training code.
        sesh['RT'] = -sesh['RT']
        sesh['Bias'] = -sesh['Bias']
        
        # Convert to numeric column by column; columns that cannot be converted are
        # left untouched.
        # NOTE(review): errors='ignore' is deprecated in recent pandas releases.
        sesh = sesh.apply(pd.to_numeric, errors='ignore')
        
        sessions_list.append(sesh)
    
    sat.save_dfs(sessions_list, project, species, savepath=None)

# Read the freshly saved, preprocessed sessions back in
preprocessed_data = sat.read_dfs(folder=project)

In [None]:
#VARIABLES
# Shared settings for the HMM, shuffling and superlet cells below.
preprocessed_data=sat.read_dfs(folder=project)

# Metric groups; `training_metrics` is the concatenation of both groups
accuracy_metrics=['Correct','Bias','Prec']
speed_metrics=['RT','SpeedMean']
training_metrics=accuracy_metrics+speed_metrics

# Shuffling setup: n_iter shuffles spread evenly over n_workers cluster workers
n_iter=1000
n_workers = 100
trials_per_worker = n_iter // n_workers

# Frequency axis for the superlet transform: `steps` linearly spaced frequencies
# between fs/lower (about 6.7) and fs/upper (about 66.7).
lower=150 # longest period -> lowest frequency, fs/lower
upper=15 # shortest period -> highest frequency, fs/upper
steps = 200 # number of frequencies of interest

fs = 1000 
foi = np.linspace(fs/lower, fs/upper, steps)
tpc=fs/foi  # period corresponding to each frequency (fs/foi)
# c1 and ord are passed positionally to slt.superlets in the SUPERLETS cell.
# NOTE(review): `ord` shadows the Python builtin ord(); renaming it needs a
# matching change where it is passed to slt.superlets.
c1 = 5
ord = (1, 10)


In [None]:
#HMM
# Fit the HMM on the unshuffled (shuffle=False) preprocessed sessions of each
# species and pickle the resulting table as 'training_data'.
preprocessed_data = sat.read_dfs(folder=project)
for species in species_data:
    sessions=preprocessed_data[species]
    # The helper returns column names and a value matrix. The leading 1 is its
    # first positional argument; in the shuffled cell that slot receives the
    # iteration index.
    columns, values =sat.prepare_and_train_dataset(1,sessions, accuracy_metrics, speed_metrics, shuffle=False)
    training_data = pd.DataFrame(values, columns=columns)
    sat.save_pickle(training_data,f'training_data',species)

In [None]:
#SUPERLETS
# Superlet transform of the observed state sequences, computed session by session
# and concatenated along axis 1 before pickling as 'power_{metric}'.
metrics_to_slt=['States_Speed','States_Accuracy']  # columns of training_data to transform
for species in species_data:
    for metric in metrics_to_slt:    
        training_data = sat.load_pickle('training_data', species)
        sesh_id=training_data['Session'].values
        power_per_sesh=[]
        for isesh in np.unique(sesh_id):
            sesh=training_data[sesh_id==isesh].copy()
            data=sesh[metric].values
            # A leading axis is added to the 1-D sequence before the transform;
            # fs, foi, c1 and ord come from the VARIABLES cell.
            result = slt.superlets(data[np.newaxis,:], fs, foi, c1, ord)
            power_per_sesh.append(result)
        # Join the per-session results along axis 1
        # NOTE(review): axis 1 is presumably the time/trial axis -- confirm
        # against the shape returned by slt.superlets.
        power_per_sesh=np.concatenate(power_per_sesh,axis=1)

        sat.save_pickle(power_per_sesh,f'power_{metric}',species)  

In [None]:
#HMM - SHUFFLED
# Run sat.prepare_and_train_dataset num_iterations times per species on the ACME
# cluster, then collect the per-iteration results into DataFrames.

for species in species_data:
    original_data = preprocessed_data[species]  # Preprocessed sessions of this species
    num_iterations = trials_per_worker * n_workers

    # Shuffle and train HMM
    # NOTE(review): the cluster is set up once per species but cleaned up only
    # once after the loop -- confirm this is what ACME expects.
    client = acme.esi_cluster_setup(partition="8GBXS", n_workers=n_workers)

    # The range only provides one value per iteration; the remaining arguments
    # are passed alongside it to every call. n_inputs fixes the number of calls.
    with ParallelMap(sat.prepare_and_train_dataset, range(num_iterations), original_data, accuracy_metrics, speed_metrics,n_inputs=num_iterations) as pmap:
        results = pmap.compute()

    sat.save_pickle(results, f'acme_results', species)

acme.cluster_cleanup()

# LOAD HMM RESULTS
# Each entry of `results` is opened as an HDF5 file: 'result_0' holds the column
# names (stored as bytes, hence the decode) and 'result_1' holds the values.
for species in species_data:
    results = sat.load_pickle(f'acme_results', species)

    loaded_data = []
    for result in results:
        with h5py.File(result, 'r') as f:
            columns=list(f['result_0'])
            columns = [col.decode('utf-8')  for col in columns]
            values=list(f['result_1'])
            df = pd.DataFrame(values, columns=columns)
            loaded_data.append(df)

    # One DataFrame per shuffle iteration
    sat.save_pickle(loaded_data, f'shuffled_data', species)

In [None]:
#SUPERLETS - SHUFFLED
# For every species and state sequence, run sat.compute_superlets_per_sesh on each
# shuffled HMM iteration in parallel on the ACME cluster and pickle the results
# as 'results_slt_shuffled_{metric}' (the name the Figure 5 cell loads).
metrics_to_slt=['States_Speed','States_Accuracy']
for species in species_data:
    for metric in metrics_to_slt:  

        # Session labels of the unshuffled training data; passed unchanged to
        # every parallel call.
        data = sat.load_pickle('training_data', species)
        sesh_id=data['Session'].values
        
        # Load the shuffled tables (one DataFrame per shuffle iteration)
        shuffled_data=sat.load_pickle(f'shuffled_data',species)

        # Select the column to transform from every shuffle iteration
        shuffled_slt_input=[]
        for iteration in shuffled_data:   
            shuffled_slt_input.append(iteration[metric].values)
                
        # NOTE(review): the cluster is set up on every (species, metric) pass but
        # cleaned up only once at the end -- confirm this is what ACME expects.
        client = acme.esi_cluster_setup(partition="8GBXS", n_workers=n_workers)

        with ParallelMap(sat.compute_superlets_per_sesh,shuffled_slt_input,sesh_id,
                        fs, foi, n_inputs=trials_per_worker * n_workers) as pmap:
            slt_results = pmap.compute()
        
        # Fix: the file name previously used `metr`, which is not defined in this
        # cell (NameError on a fresh kernel, or a stale value left over from the
        # Figure 5 cell). Use the loop variable so each metric gets its own file.
        sat.save_pickle(slt_results,f'results_slt_shuffled_{metric}',species)  
    
acme.cluster_cleanup()

In [None]:
#FIGURE 1. C: EXAMPLE RUNNING PATHS  
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
def norm_bar(data, map_name, max_value):
    """Build the colormap and matching ScalarMappable used for path colouring.

    `data` is accepted but not used. Colours are normalised over the fixed range
    0.4 .. `max_value`. Returns ``(cmap, mappable)``.
    """
    cmap = plt.get_cmap(map_name)
    return cmap, ScalarMappable(norm=Normalize(vmin=0.4, vmax=max_value), cmap=cmap)
def plot_selected_running_paths(species_data, selected_indices_dict, metric='Prec', color_map='flare_r'):
    """
    Plot example running paths for each species, colour-coded by a per-trial metric.

    Parameters
    ----------
    species_data : dict
        Maps species name to an indexable collection of session DataFrames. The
        example session of each species must provide the columns 'Distance',
        'Location', 'RT', 'Correct' and `metric`, plus
        ``attrs['Metadata']['BoxSize']``.
    selected_indices_dict : any
        Ignored. Kept for interface compatibility; the trial indices are read
        from ``selected_paths_{species}.npy`` in the working directory.
    metric : str
        Column used to colour each path.
    color_map : str
        Name of the colormap passed to `norm_bar`.

    Notes
    -----
    Only correct trials are considered, and the indices loaded from disk refer to
    positions within those correct trials. Species without an index file, or with
    indices that do not fit the session, are reported and skipped. The figure is
    saved through ``sat.save_plot`` and shown.
    """
    # Example session used per species
    example_session = {'human': 1, 'mouse': 10, 'monkey': 3}

    # Set up the figure and colormap
    max_value = 1.0
    cmap, mappable = norm_bar(None, color_map, max_value)
    
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    
    for i, (species, data) in enumerate(species_data.items()):
        # Skip species without a configured example session instead of reusing
        # the session left over from a previous iteration.
        if species not in example_session:
            print(f"No example session configured for {species}")
            continue

        raw_sesh = data[example_session[species]]
        # Read the metadata from the original frame, then work on a cleaned copy
        # so the caller's DataFrame is not modified in place.
        boxsize = raw_sesh.attrs['Metadata']['BoxSize']
        sesh = raw_sesh.dropna().reset_index(drop=True)
        
        # Keep correct trials only. Every per-trial series is filtered and
        # re-indexed the same way so one index addresses the same trial in all of
        # them (previously 'Distance' was left unfiltered and could be misaligned).
        selected = sesh['Correct'].astype(bool)
        precision_data = sesh[metric][selected].reset_index(drop=True)
        paths = sesh['Location'][selected].reset_index(drop=True)
        rt = sesh['RT'][selected].reset_index(drop=True)
        distance = sesh['Distance'][selected].reset_index(drop=True)
        
        # Load the indices for this species.
        # NOTE(review): allow_pickle=True executes pickled content; only load
        # index files from a trusted location.
        try:
            species_indices = np.load(f'selected_paths_{species}.npy', allow_pickle=True)
        except (OSError, ValueError):
            print(f"No selected paths found for {species}")
            continue

        try:
            selected_paths = [paths[idx] for idx in species_indices]
            selected_precision = [precision_data[idx] for idx in species_indices]
            selected_rt = [rt[idx] for idx in species_indices]
            selected_distance = [distance[idx] for idx in species_indices]
        except (KeyError, IndexError) as err:
            print(f"Selected path indices for {species} do not match the session: {err!r}")
            continue

        if not selected_paths:
            print(f"No selected paths found for {species}")
            continue
            
        # Print diagnostic information
        print(f"{species}: plotting {len(selected_paths)} selected paths")
        print(f"{species} precision range: {min(selected_precision):.3f} to {max(selected_precision):.3f}")
        
        # Plot paths
        for ipath, (path, prec, rt_idx, dist) in enumerate(zip(selected_paths, selected_precision, selected_rt, selected_distance)):
            # Find the sample whose first coordinate is nearest to dist + boxsize / 2
            x_near_targ = np.argmin(abs(path[:, 0] - dist - boxsize / 2))
            
            # Mouse paths are shifted to start at the origin; other species are
            # plotted in their stored coordinates.
            if species == 'mouse':
                y = path[:x_near_targ, 0] - path[:, 0][0]
                x = path[:x_near_targ, 1] - path[:, 1][0]
            else:
                y = path[:x_near_targ, 0]
                x = path[:x_near_targ, 1]
            
            # Plot path with normalized color
            normalized_value = mappable.norm(prec)
            axs[i].plot(x, y, color=cmap(normalized_value))
            
            # Mark the sample at the reaction-time index with an open circle
            axs[i].plot(
                x[rt_idx],
                y[rt_idx],
                marker='o',
                zorder=3,
                markerfacecolor='None',
                markeredgecolor='k'
            )
        
        # Set plot limits and title
        axs[i].set_xlim([-300, 300])
        axs[i].set_ylim([-20, 600] if species == 'mouse' else [-20, 450])
        axs[i].set_title(f"{species.capitalize()} (n={len(selected_paths)})")
    
    # Add colorbar
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    fig.colorbar(mappable, cax=cbar_ax, label=metric)
    
    plt.tight_layout(rect=[0, 0, 0.9, 1])
    sat.save_plot('cross_species',project, filename=f"example_paths_{metric}.svg")
    plt.show()

# Usage:
plot_selected_running_paths(species_data, None)  # None for selected_indices_dict as we're loading from files

In [None]:
#FIGURE 2. A: METRICS CORRELATION MATRIX
# Correlation between the training metrics, computed on the trials of all
# species pooled together.
all_species_data = []

for species in species_data:
    data = sat.load_pickle('training_data', species)
    all_species_data.append(data[training_metrics])

# Combine data for all species
combined_data = pd.concat(all_species_data, ignore_index=True)

# Compute correlation matrix on the combined data (pandas default method)
combined_corr_matrix = combined_data.corr()

# Plot the correlation heatmap for combined data
plt.figure(figsize=(6, 6))  # Adjust size as needed
sns.heatmap(combined_corr_matrix, annot=False, cmap="BuPu", vmin=-0.1, vmax=1)
plt.title('Correlation Heatmap - Combined Species')
plt.tight_layout()

# Save the figure as SVG
sat.save_plot('cross_species', project, filename="metrics_correlation_heatmap_combined.svg")

plt.show()

In [None]:
#FIGURE 2. B:  HIT RATE vs. RT
# One scatter plot per species: 'RT' on the horizontal axis, 'Correct' on the
# vertical axis, points coloured by HMM state.
sesh_counts=[]  # NOTE(review): not used in this cell
for i, species in enumerate(['human','monkey','mouse']):
    data=sat.load_pickle(f'training_data',species)
    data['RT']=data['RT']  # no-op; the sign flip (*-1) is disabled
    plt.figure(figsize=(3, 3))
    # NOTE(review): the names are swapped relative to the axes -- `y` ('RT') is
    # plotted on the x-axis and `x` ('Correct') on the y-axis.
    x='Correct'
    y='RT'    
    # Reference lines through the origin, drawn behind the points
    plt.axhline(0, color='black', linewidth=0.8, zorder=-1)
    plt.axvline(0, color='black', linewidth=0.8, zorder=-1)
    sns.scatterplot(data=data,x=y,y=x,hue=data['States'],palette=colormap_states, alpha=1, legend=False)
    plt.xlim([-6.0,6])
    plt.ylim([-5.5,2.5])

    plt.title(f'{x} vs {y}, {species}')
    sat.save_plot('cross_species', project, filename=f"cloud_hi_vs_rt_{species}.svg")
    plt.show()
    

In [None]:
#FIGURE 2. C: STATE MEANS
# One panel per species: mean of each training metric within each of the four
# HMM states, with the standard deviation as error bars.
fig, axs = plt.subplots(1, 3, figsize=(9, 3))
for idx, species in enumerate(species_data):
    data = sat.load_pickle('training_data', species)
    states = data['States']

    for state in range(4):
        # Mean and standard deviation over the trials assigned to this state
        mean_values = np.mean(data[training_metrics][states == state], axis=0)
        error_values = np.std(data[training_metrics][states == state], axis=0)
        x_values = range(len(mean_values))  

        # Use errorbar instead of plot to include error bars
        axs[idx].errorbar(x_values, mean_values, yerr=error_values, label=f'State {state}', fmt='-o',elinewidth=1,
            capsize=3,
            linewidth=3,  # Thicker line
            markersize=7, color=colormap_states.colors[state])
        
    axs[idx].set_ylim([-2.5, 2.5])
    # x_values from the last state is reused: one tick per training metric
    axs[idx].set_xticks(x_values)
    axs[idx].set_xticklabels(training_metrics)
    axs[idx].set_title(species)
    axs[idx].axhline(0, color='black')

plt.tight_layout()
sat.save_plot('cross_species',project,filename=f"means_species.svg")
plt.show()

In [None]:
#FIGURE 3. A: STATES TIMESERIES
# Plot the state time series of each species with the sat helper and save one
# SVG per species.
for i, species in enumerate(['human','monkey','mouse']):
    data=sat.load_pickle(f'training_data',species)
    sat.plot_states_timeseries(data,colormap_states, species)
    sat.save_plot('cross_species', project, filename=f"example_sessions_{species}.svg")

In [None]:
#FIGURE 3. B: STATES PROPORTION
# For each species, compute the overall proportion of trials spent in each state
# (all sessions pooled; there is no per-session grouping here).
sesh_counts=[]
for i, species in enumerate(['human','monkey','mouse']):
    data=sat.load_pickle(f'training_data',species)
    state_session_counts = data['States'].value_counts(normalize=True).sort_index()
    sesh_counts.append(state_session_counts)
# One row per species (in the order above), one column per state
all_sp_counts = pd.concat(sesh_counts,axis=1).T
# Stacked bar plot of the state proportions, one bar per species
plt.figure(figsize=(6, 5))
all_sp_counts.plot(kind='bar', stacked=True, colormap=colormap_states, ax=plt.gca())
plt.gca().set_xticklabels([])
plt.legend(title='State', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
sat.save_plot('cross_species', project, filename=f"state_proportion.svg")
plt.show()

In [None]:
#FIGURE 3. C: EARLY VS. LATE STATE PROPORTION
# List of species to process
# NOTE(review): this rebinds `species_data` from the dict of raw sessions read in
# the first cell to a plain list of names; cells that index
# species_data[species] will not work after this point.
species_data = ['human', 'monkey', 'mouse']

# List of state types to process
state_types = ['States', 'States_Accuracy', 'States_Speed']

# Loop through each state type to create a separate figure for each
for state_type in state_types:
    fig, axs = plt.subplots(3, 1, figsize=(3, 8), sharex=True)  # Create subplots with shared x-axis

    for idx, species in enumerate(species_data):
        data = sat.load_pickle(f'training_data', species)

        early = []
        late = []
        for isesh in np.unique(data['Session']):
            sesh = data[data['Session'] == isesh]
            # State proportions in the first half of the session
            bin1 = sesh[:len(sesh) // 2][state_type].value_counts().sort_index()
            bin1 = bin1 / bin1.sum()

            # State proportions in the second half of the session
            bin2 = sesh[len(sesh) // 2:][state_type].value_counts().sort_index()
            bin2 = bin2 / bin2.sum()

            early.append(bin1)
            late.append(bin2)

        # One row per session, one column per state; states absent from a half
        # count as proportion 0
        early_df = pd.DataFrame(early).fillna(0)
        late_df = pd.DataFrame(late).fillna(0)

        # Mean proportions, shifted so that the smallest early/late mean across
        # all states sits at 0 (display baseline only)
        average_early = early_df.mean()-np.min([early_df.mean(),late_df.mean()])
        average_late = late_df.mean()-np.min([early_df.mean(),late_df.mean()])

        # Calculate SEM across sessions for error bars
        sem_early = early_df.std() / np.sqrt(early_df.count())
        sem_late = late_df.std() / np.sqrt(late_df.count())

        # One early -> late line per state on the current axis
        for i in average_early.index:  # Loop over existing states
            axs[idx].errorbar(
                [0, 1],
                [average_early[i], average_late[i]],
                yerr=[sem_early[i], sem_late[i]],
                linestyle='-',    # Line with markers
                marker='o',       # Circle markers
                elinewidth=1,     # Error bar line width
                capsize=3,        # Error bar cap size
                linewidth=3,      # Thicker line
                markersize=7,     # Larger markers
                color=sns.color_palette(colormap_states.colors)[int(i)], 
                label=f'State {int(i)}'
            )

        axs[idx].set_ylim([-0.05,0.30])
        axs[idx].set_xlim([-0.2,1.2])

    plt.tight_layout()
    plt.suptitle(f'{state_type}', y=1.02)  # Set a super title for each figure
    
    sat.save_plot('cross_species', project, filename=f"early_late_{state_type}.svg")
    plt.show()

In [None]:
#FIGURE 3. C: STATS
# Permutation test for the early-vs-late difference in state proportions: the
# observed difference of every state is compared with the differences obtained
# from the shuffled HMM data sets.


def _early_late_mean_proportions(frame, state_column):
    """Return the across-session mean state proportions of both session halves.

    Each session in `frame` (grouped by its 'Session' column) is split at
    len // 2. Within each half the proportion of every value of `state_column`
    is computed; states absent from a half count as 0. Returns
    ``(early_mean, late_mean)``, two Series indexed by state.
    """
    early_rows, late_rows = [], []
    for session_id in np.unique(frame['Session']):
        session = frame[frame['Session'] == session_id]
        half = len(session) // 2
        for rows, part in ((early_rows, session[:half]), (late_rows, session[half:])):
            counts = part[state_column].value_counts().sort_index()
            rows.append(counts / counts.sum())

    early_mean = pd.DataFrame(early_rows).fillna(0).mean()
    late_mean = pd.DataFrame(late_rows).fillna(0).mean()
    return early_mean, late_mean


# List of species to process
species_data = ['human', 'monkey', 'mouse']

# p-values per state measure and species
all_species_p_values = {
    'States': {},
    'States_Speed': {},
    'States_Accuracy': {}
}

for species in species_data:
    # Load data and shuffled data for the species
    data = sat.load_pickle(f'training_data', species)
    shuffled = sat.load_pickle(f'shuffled_data', species)  # List of shuffled DataFrames

    for state_type in ['States', 'States_Speed', 'States_Accuracy']:
        average_early, average_late = _early_late_mean_proportions(data, state_type)

        # Test every state seen in either half. A state missing from one half
        # has proportion 0 there; it no longer raises a KeyError.
        observed_states = average_early.index.union(average_late.index)
        observed_diffs = (
            average_early.reindex(observed_states, fill_value=0)
            - average_late.reindex(observed_states, fill_value=0)
        )

        # Early-minus-late difference of every state for each shuffled data set
        perm_diffs = {i: [] for i in observed_states}
        for shuffled_df in shuffled:
            average_early_perm, average_late_perm = _early_late_mean_proportions(shuffled_df, state_type)
            # Align on the observed states; a state that never occurs in a
            # shuffled data set contributes a difference of 0.
            diff_perm = (
                average_early_perm.reindex(observed_states, fill_value=0)
                - average_late_perm.reindex(observed_states, fill_value=0)
            )
            for i in observed_states:
                perm_diffs[i].append(diff_perm[i])

        # Two-sided p-value: fraction of shuffles whose absolute difference is
        # at least as large as the observed one
        p_values = {}
        for i in observed_states:
            extreme_count = np.sum(np.abs(perm_diffs[i]) >= np.abs(observed_diffs[i]))
            p_values[i] = extreme_count / len(perm_diffs[i])

        # Store p-values for this species and state type
        all_species_p_values[state_type][species] = p_values

# Print p-values for all species and state types
for state_type, species_p_vals in all_species_p_values.items():
    print(f'P-values for {state_type}:')
    for species, p_vals in species_p_vals.items():
        print(f'  {species}:')
        for state, p_val in p_vals.items():
            print(f'    State {state}: p-value = {p_val:.3f}')


In [None]:
#FIGURE 4. A: DWELL TIMES
# One figure per state type, one panel per species: distribution of dwell times
# (run lengths, as returned by sat.rle_with_delays) for every state.
# List of species to process
species_data = ['human', 'monkey', 'mouse']

# Define state types to process
state_types = ['States', 'States_Accuracy', 'States_Speed']

# Loop through each state type to create a separate figure for each
for state_type in state_types:
    fig, axs = plt.subplots(3, 1, figsize=(4, 11), sharex=True)  # Create subplots for each species

    for i, species in enumerate(species_data):
        data = sat.load_pickle(f'training_data', species)
        state_durations = data['TrialDuration']  # NOTE(review): not used in this cell
        states = data[state_type]  # Use the current state type

        in_seconds = False  # NOTE(review): not used in this cell
        # z holds the dwell times and `values` the state of each run (see the
        # indexing z[values == j] below); p and d are not used here
        z, p, d, values = sat.rle_with_delays(states)

        # NOTE(review): the three settings below are not used in this cell
        time_format = 'trials'
        upper_dwell = 35
        upper_delay = 100

        # Prepare data for plotting: one (state label, dwell time) row per run
        plot_data = []
        unique_states = np.unique(values)  # Identify the unique states present in the data

        for j in unique_states:
            plot_data.extend([(str(j), val) for val in z[values == j]])

        df = pd.DataFrame(plot_data, columns=['State', 'Dwell Time'])

        # Individual runs as jittered points, with a box plot drawn on the same axis
        sns.stripplot(
            x='State', y='Dwell Time', data=df, 
            palette=colormap_states.colors, size=4, alpha=0.3, 
            jitter=0.15, edgecolor='black', linewidth=0.5, ax=axs[i]
        )
        sns.boxplot(
            x='State', y='Dwell Time', data=df, 
            palette=colormap_states.colors, showfliers=False, width=0.6, ax=axs[i]
        )

        # Set the y-limit and y-ticks divisible by 10
        axs[i].set_ylim([0, 40])
        axs[i].set_yticks(range(0, 45, 10))  # Ticks at 0, 10, 20, 30, 40
        axs[i].set_xlim([-0.5, max(unique_states) + 0.5])
        
    # Adjust layout and remove excess spines
    plt.tight_layout()
    sns.despine()
    
    # Save the figure as SVG
    sat.save_plot('cross_species', project, filename=f"dwell_times_{state_type}.svg")

    plt.show()
    plt.close()

In [None]:
#FIGURE 4. A: STATS
from scipy.stats import kruskal
import scikit_posthocs as sp

# Compare dwell times between species: Kruskal-Wallis test followed, when
# significant, by Dunn's post-hoc test with Bonferroni correction.
in_seconds = False
if in_seconds:
    time_format = 'seconds'
else:
    time_format = 'trials'

# Collect dwell times for each species
species_dwell_times = {}

for species in species_data:
    data = sat.load_pickle('training_data', species)
    states = data['States']
    state_durations = data['TrialDuration']
    # z holds the dwell times; p, d and values are not used here
    z, p, d, values = sat.rle_in_seconds(states, state_durations, time_seconds=in_seconds)

    # Store dwell times for each species
    species_dwell_times[species] = z

# Print average dwell times across all states per species
average_dwell_times = {species: np.mean(dwell_times) for species, dwell_times in species_dwell_times.items()}
print(f"Average dwell times {time_format}:", average_dwell_times)

# Perform Kruskal-Wallis test across species (one sample per species)
dwell_time_values = list(species_dwell_times.values())
h_stat, p_value = kruskal(*dwell_time_values)

print(f"Kruskal-Wallis H-statistic: {round(h_stat, 2)}, p-value: {round(p_value, 2)}")

# Perform post-hoc Dunn's test with Bonferroni correction if Kruskal-Wallis is significant
if p_value < 0.05:
    species_names = list(species_dwell_times.keys())
    posthoc = sp.posthoc_dunn(dwell_time_values, p_adjust='bonferroni')

    # Set the species names as indices and columns for the post-hoc result
    # (same order as dwell_time_values, both come from species_dwell_times)
    posthoc.index = species_names
    posthoc.columns = species_names
    
    print("Post-hoc pairwise comparisons (Dunn's test with Bonferroni correction):\n",round(posthoc, 2) )


In [None]:
#FIGURE 4. B: TRANSITION MATRICES
# One annotated heatmap per species of the state transition matrix returned by
# sat.create_transition_matrix (called with minus_diag=True, window=1).

for species in species_data:
    data = sat.load_pickle(f'training_data', species)
    transision_matrix = sat.create_transition_matrix(data['States'].values, minus_diag=True, window=1)
    fig, ax = plt.subplots(ncols=1, figsize=(4, 4))

    # Plot the heatmap with annotations, cap the colour range, and format numbers to 2 decimals
    sns.heatmap(
        transision_matrix, 
        ax=ax, 
        annot=True, 
        fmt=".2f",        # Format numbers to 2 decimal places
        cmap='Greys', 
        vmax=0.8, 
        cbar=False        # Disable colorbar if not needed
    )
    # Remove ticks and labels
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel("")
    ax.set_ylabel("")

    # Save and display the plot
    sat.save_plot('cross_species', project, filename=f"transition_matrix_{species}.svg")
    plt.show()

In [None]:
#FIGURE 4. B: STATS
# Compare each observed transition-matrix entry with a null distribution built
# from block-shuffled state sequences, with FDR correction across entries.
# NOTE(review): the shuffling is not seeded here, so p-values may vary between
# runs unless sat.shuffle_blocks seeds internally -- confirm.

p_values = {}  # NOTE(review): not filled in this cell

for species in species_data:
    # Load original data and compute the transition matrix
    data = sat.load_pickle('training_data', species)
    original_states = data['States'].values
    original_transition_matrix = sat.create_transition_matrix(original_states, minus_diag=True, window=1)
    
    # Compute transition matrices for 1000 shuffles using block shuffling
    # (the count is hard-coded here rather than taken from n_iter)
    shuffled_transmats = []
    for _ in range(1000):
        shuffled_states = sat.shuffle_blocks(original_states)
        shuffled_transition_matrix = sat.create_transition_matrix(shuffled_states, minus_diag=True, window=1)
        shuffled_transmats.append(shuffled_transition_matrix)
    mean_shuffled_transmat = np.mean(shuffled_transmats, axis=0)
    
    # Compute effect sizes and p-values
    n_states = original_transition_matrix.shape[0]
    p_matrix = np.zeros((n_states, n_states))
    effect_size_matrix = np.zeros((n_states, n_states))
    
    for i in range(n_states):
        for j in range(n_states):
            # Null distribution of this entry across shuffles
            shuffled_vals = np.array([mat[i, j] for mat in shuffled_transmats])
            mean_shuffled = np.mean(shuffled_vals)
            std_shuffled = np.std(shuffled_vals)
            
            # Two-sided p-value: fraction of shuffles at least as far from the
            # shuffled mean as the observed value
            original_val = original_transition_matrix[i, j]
            p_matrix[i, j] = np.mean(
                np.abs(shuffled_vals - mean_shuffled) >= 
                np.abs(original_val - mean_shuffled)
            )

            # Standardised effect: (observed - shuffled mean) / shuffled std.
            # NOTE(review): labelled "Cohen's d" in the plot, but as computed this
            # is a z-score against the null distribution.
            if std_shuffled > 0:
                effect_size_matrix[i, j] = (original_val - mean_shuffled) / std_shuffled
            else:
                effect_size_matrix[i, j] = 0
    
    # Apply Benjamini-Hochberg FDR correction to the p-values of all entries
    flat_p = p_matrix.flatten()
    rejected, p_corrected, _, _ = multipletests(flat_p, alpha=0.05, method='fdr_bh')
    significant = rejected.reshape(p_matrix.shape)
    # Create annotations for effect sizes with significance markers
    effect_size_annot = np.empty_like(effect_size_matrix, dtype=object)
    for i in range(n_states):
        for j in range(n_states):
            effect_size_annot[i, j] = f"{effect_size_matrix[i,j]:.1f}" + ("*" if significant[i,j] else "")

    # Plot the heatmaps
    fig, axes = plt.subplots(ncols=3, figsize=(15, 5))
    
    # Original transition matrix
    sns.heatmap(original_transition_matrix, ax=axes[0], annot=True, fmt='.2f', cmap='Blues',vmax=0.6)
    axes[0].set_title(f'Original Transition Matrix\n{species}')
    
    # Mean shuffled transition matrix
    sns.heatmap(mean_shuffled_transmat, ax=axes[1], annot=True, fmt='.2f', cmap='Blues',vmax=0.6)
    axes[1].set_title(f'Mean Shuffled Matrix\n{species}')
    
    # Effect size matrix with significance markers
    # Using diverging colormap centered at 0 for effect sizes
    sns.heatmap(effect_size_matrix, ax=axes[2], 
                annot=effect_size_annot, fmt='', 
                cmap='RdBu_r', center=0,
                vmin=-3, vmax=3)  # Colour range clipped to [-3, 3]
    axes[2].set_title(f'Effect Sizes (Cohen\'s d)\n* = significant, {species}')
    
    plt.tight_layout()
    plt.show()

In [None]:
#FIGURE 5. A-B: POWER SPECTRUM OF STATE SEQUENCE
# One figure per state sequence, one panel per species: observed average power
# against the mean and 95% band of the shuffled data.
metrics_to_slt=['States_Speed','States_Accuracy']  # state sequences to plot
for metr in metrics_to_slt:   
    fig, axs = plt.subplots(3, 1, figsize=(5, 10), sharex=True)
    for i, species in enumerate(species_data):
        # Observed power, averaged over axis 1
        # NOTE(review): axis 1 is presumably time/trials, leaving one value per
        # frequency -- confirm against the shape stored by the SUPERLETS cell.
        power = sat.load_pickle(f'power_{metr}', species)
        avg_power_fr = np.mean(power, axis=1)

        # Each shuffled result is opened as an HDF5 file whose 'result_0' dataset
        # is averaged over axis 1 in the same way
        slt_results = sat.load_pickle(f'results_slt_shuffled_{metr}', species)
        means = []
        for result in slt_results:
            with h5py.File(result, 'r') as f:
                data = f['result_0']
                means.append(np.mean(data, axis=1))

        means = np.array(means)
        mean_shuffled_power = np.mean(means, axis=0)
        # Despite their names, q5 and q95 hold the 2.5th and 97.5th percentiles
        # across shuffles (a 95% band)
        q5, q95 = np.percentile(means, [2.5, 97.5], axis=0)
        
        # The x-axis is fs/foi, i.e. the period of each frequency
        axs[i].plot(fs/foi, avg_power_fr, label='Actual Power', color='black')
        axs[i].plot(fs/foi, mean_shuffled_power, color='gray', linestyle='--', label='Shuffled Mean')
        axs[i].fill_between(fs/foi, q5, q95, color='gray', alpha=0.3, label='95 CI')
        axs[i].set_xlim(15, 100)
        axs[i].set_ylim(0.003,0.0155)

    plt.tight_layout()
    plt.legend()  # legend is drawn on the current (last) axis only
    sat.save_plot('cross_species', project, filename=f"avg_pw_fr_line_states_{metr}.svg")
    plt.show()

In [None]:
#FIGURE 6. DIFFICULTY EFFECT
# For every state transition, collect the session-normalised difficulty in a
# +-10 trial window around the transition and plot, per trial position, the
# percentage of transitions with above-average difficulty against a null band.

# Define the species to process
species_data = ['human', 'monkey', 'mouse']
metric='States_Speed'  # state sequence whose transitions are analysed
num_permutations = 1000  # Set the number of permutations

# Initialize a dictionary to store difficulty windows for each species and transition type
transition_diffs_all_species = {
    species: {(from_state, to_state): [] for from_state in range(4) for to_state in range(4) if from_state != to_state} 
    for species in species_data
}

# Process the data for each species
for species in species_data:
    data = sat.load_pickle(f'training_data', species)
    states = data[metric]
    correct = data['Correct']  # Performance data; NOTE(review): not used in this cell
    # values[idx] is the state of run idx and p[idx] its position (see their use
    # in the loop below); z is not used here
    z, p, values = sat.rle(states)
    all_morphs = []

    for isesh in data['Session'].unique():
        this_sesh = data['Session'] == isesh
        session_data = data[this_sesh]
        morphs = session_data['MorphTarget'].apply(sat.fold_morph_values)
        
        # Normalize difficulty per session by subtracting the session mean
        mean_difficulty = morphs.mean()
        all_morphs.extend(morphs - mean_difficulty)

    # Iterate through each state change position
    # NOTE(review): runs are computed over the concatenated sessions, so a window
    # can span a session boundary -- confirm this is intended.
    for idx in range(1, len(p)):
        from_state = values[idx - 1]
        to_state = values[idx]
        pos = p[idx]

        # Skip self-transitions
        if from_state == to_state:
            continue

        # Define the window around the state change (+-10 trials)
        start = pos - 10
        end = pos + 10

        # Ensure that the window is fully within the range of all_morphs
        if start < 0 or end >= len(all_morphs):
            continue  # Skip this transition if the window is out of bounds

        # Get the morph difficulties within the window (21 values, end inclusive)
        window_difficulties = all_morphs[start:end + 1]
        
        # Store the difficulty window based on transition type (from_state -> to_state)
        transition_diffs_all_species[species][(from_state, to_state)].append(window_difficulties)

# Set up the plot
num_species = len(species_data)
# NOTE(review): the grid has exactly two rows; if more than two transition types
# are non-empty for a species (e.g. with a four-state sequence), axes[row, col]
# below goes out of range.
num_transitions = 2  # Assuming 2 transitions: slow->fast and fast->slow
fig, axes = plt.subplots(num_transitions, num_species, figsize=(8, 5))
fig.subplots_adjust(hspace=0.1, wspace=0.2)

# Plot all transitions for each species and calculate the null distribution
for col, species in enumerate(species_data):
    transition_diffs = transition_diffs_all_species[species]
    
    # Add species name at the top of each column
    axes[0, col].set_title(f'{species.capitalize()}', fontsize=14, pad=20)
    
    row = 0
    for transition, diffs in transition_diffs.items():
        if diffs:  # Only plot if there are transitions of that type
            # Convert list of transitions (each transition is a list of 21 trials) into a numpy array
            diffs_array = np.array(diffs)
            
            # Percentage of transitions at or above the session mean for each trial position
            percent_positive_per_trial = np.mean(diffs_array >= 0, axis=0) * 100
            
            # Null distribution computed by the sat helper
            null_distribution = sat.compute_null_distribution(diffs_array, num_permutations)
            
            # 97.5th and 2.5th percentiles of the null distribution along axis 1
            # NOTE(review): assumes axis 1 indexes permutations -- confirm.
            upper_threshold = np.percentile(null_distribution, 97.5 , axis=1)
            lower_threshold = np.percentile(null_distribution, 2.5, axis=1)
            
            trials = np.arange(-10, 11)  # Define trial range (-10 to +10 around transition)

            # Plot in the appropriate subplot
            ax = axes[row, col]
            
            # Plot observed percentages
            ax.plot(trials, percent_positive_per_trial, color='r', label="Observed")
            
            # Plot null distribution upper and lower thresholds
            ax.fill_between(trials, lower_threshold, upper_threshold, color='gray', alpha=0.3, label="95% Null")
            
            # Horizontal line at the mean observed percentage
            ax.axhline(np.mean(percent_positive_per_trial), color='black', linestyle='--', label="Mean Observed")
            ax.axvline(0, color='black', linestyle='-')  # Transition line

            # Set limits
            ax.set_yticks([0, 50, 100])  # Set y-ticks
            ax.set_ylim([0, 100])  # Set y-axis limits
            ax.set_xlim([-10, 10])  # Set x-axis limits

            row += 1  # Move to the next row for plotting

# Adjust plot appearance
for ax in axes.flat:
    ax.label_outer()  # Show labels on the outer axes only

plt.tight_layout()

sat.save_plot('cross_species', project, filename=f"difficulty_psth_{metric}.svg")
plt.show()


In [None]:
#FIGURE S.1. METRICS DISTRIBUTION
# Grid of histograms (rows = species, columns = metrics); every histogram is
# split by trial outcome via the `hue` argument.

# Metrics plotted directly; 'Correct' and 'Wrong' are only needed to derive 'Outcome'
selected_metrics = ['Bias', 'Prec', 'RT', 'SpeedMean']
plot_columns = ['Outcome'] + selected_metrics

# Grid dimensions follow the data instead of being hard-coded
n_species = len(species_data)
n_metrics = len(plot_columns)

data = {}

# Concatenate sessions data for each species
for species in species_data:
    sesh_list = []
    for sesh in species_data[species]:
        # Rows with a NaN in ANY column are dropped before the column selection
        sesh = sesh.dropna().reset_index()
        # .copy() so the column assignments below write to an independent frame
        sesh = sesh.loc[:, selected_metrics + ['Correct', 'Wrong']].copy()
        sesh['RT'] = sesh['RT'] / 60  # Convert reaction time (NOTE(review): presumably frames at 60 Hz -> seconds; confirm)
        # Outcome code: 2 * Correct + Wrong (assumes both columns are 0/1 flags - TODO confirm)
        sesh['Outcome'] = sesh['Correct'] * 2 + sesh['Wrong']
        sesh_list.append(sesh)
    # ignore_index: every session restarts its index at 0, and duplicated index
    # labels in the concatenated frame can break seaborn's hue handling
    data[species] = pd.concat(sesh_list, ignore_index=True)

fig, axes = plt.subplots(n_species, n_metrics,
                         figsize=(n_metrics * 2, n_species * 2), squeeze=False)

# Plot the distributions
for isp, species in enumerate(species_data):
    for icol, column in enumerate(plot_columns):
        ax = axes[isp, icol]

        # Plot histograms split by 'Outcome'; the legend is drawn once, in the top-left panel
        sns.histplot(data[species], x=column, hue='Outcome', kde=True, ax=ax,
                     legend=(isp == 0 and icol == 0))

        # Show x-axis label only for the bottom row
        ax.set_xlabel(column if isp == n_species - 1 else '')

        # Show y-axis label (the species name) only for the first column
        ax.set_ylabel(species if icol == 0 else '')

# Only add a legend to the last panel when there are labelled artists to show;
# an unconditional call would draw an empty legend box when no handles are returned.
handles, labels = ax.get_legend_handles_labels()
if handles:
    ax.legend(handles, labels, loc='upper right')

plt.tight_layout()

sat.save_plot('cross_species', project, filename=f"metrics_distribution_by_outcome.svg")
plt.show()


In [None]:
#FIG S.2. METRICS CORRELATIONS
# One annotated correlation heatmap per species, computed on the raw metric columns.
selected_metrics = ['Correct_Raw','Bias_Raw', 'Prec_Raw', 'RT_Raw', 'SpeedMean_Raw']

# Plot heatmaps of correlation for each species
for species in species_data:
    # Exactly one figure per species (the previous extra plt.figure() call left
    # an empty figure behind for every species)
    fig, ax = plt.subplots(ncols=1, figsize=(4, 4))
    data = sat.load_pickle('training_data', species)
    # Pairwise correlation between the selected metrics (pandas default method)
    corr_matrix = data[selected_metrics].corr()
    sns.heatmap(corr_matrix, ax=ax, annot=True, fmt=".2f", cmap="BuPu", vmin=-0.1, vmax=1, cbar=False)

    # Remove ticks and labels
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel("")
    ax.set_ylabel("")
    plt.tight_layout()

    # Save while this species' figure is still the current figure
    sat.save_plot('cross_species', project, filename=f"metrics_correlation_heatmap_{species}.svg")

plt.show()


In [None]:
#FIGURE S.3. METRICS SCATTERPLOTS
# Per species: a 2x5 grid of scatterplots, one per metric pair, coloured by state.

# The 10 (x, y) metric pairs, in plotting order
metric_pairs = [
    ('Correct', 'Bias'), ('Correct', 'Prec'), ('Bias', 'Prec'),
    ('RT', 'Correct'), ('SpeedMean', 'Correct'),
    ('RT', 'Bias'), ('SpeedMean', 'Bias'),
    ('RT', 'Prec'), ('SpeedMean', 'Prec'),
    ('RT', 'SpeedMean'),
]

for species in ['human', 'monkey', 'mouse']:
    data = sat.load_pickle('training_data', species)

    # Axis captions; metrics without an entry keep their column name
    label_map = {'Correct': 'Hit rate', 'Prec': 'Precision', 'SpeedMean': 'Speed'}

    # 2x5 grid, flattened so panels can be addressed by a single index
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    axes = axes.flatten()

    for idx, pair in enumerate(metric_pairs):
        metric_x, metric_y = pair
        ax = axes[idx]
        sns.scatterplot(
            data=data,
            x=metric_x,
            y=metric_y,
            hue=data['States'],
            palette=colormap_states,
            alpha=0.8,
            legend=False,
            ax=ax,
        )
        x_caption = label_map.get(metric_x, metric_x)
        y_caption = label_map.get(metric_y, metric_y)
        ax.set_xlabel(x_caption, fontsize=18)
        ax.set_ylabel(y_caption, fontsize=18)
        ax.tick_params(axis='both', labelsize=14)

    # The layout rectangle keeps the top 5% of the figure free (no title is currently drawn there)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    sat.save_plot('cross_species', project, filename=f"organized_metric_clouds_{species}.svg")
    plt.show()


In [None]:
#FIGURE S.4. EFFECT OF SMOOTHING WINDOWSIZE
# For every species and every smoothing window, the states are re-fitted
# n_runs times. For each run the median dwell time per state is taken, and
# the runs are then summarised by their median, 5th and 95th percentile.
# Results are cached in a pickle so species that are already done are skipped.
preprocessed_data = sat.read_dfs(folder=project)

# Define window sizes and run settings
windows = np.arange(2, 20, 1)
n_runs = 20
n_states = 4      # state labels 0..3 are expected in 'States'
max_states = 2    # state labels 0..1 are expected in 'States_Accuracy' and 'States_Speed'
accuracy_metrics = ['Correct', 'Bias', 'Prec']
speed_metrics = ['RT', 'SpeedMean']

# One cache file for both reading and writing. Previously the cache was read
# from 'species_results1.pkl' but written to 'species_results.pkl', so earlier
# results were never picked up and every species was recomputed.
results_path = 'species_results.pkl'

try:
    with open(results_path, 'rb') as f:
        # NOTE: pickle.load can execute arbitrary code - only load files you created yourself
        species_results = pickle.load(f)
except FileNotFoundError:
    species_results = {}


def _median_dwell_by_state(durations, labels, n_labels):
    """Median of `durations` per state label 0..n_labels-1 (NaN when a label never occurs)."""
    return [np.median(durations[labels == state]) if state in labels else np.nan
            for state in range(n_labels)]


for species in preprocessed_data:
    if species in species_results:
        continue  # Skip already processed species
    sessions = preprocessed_data[species]

    # One list per summary statistic and state set; each list gets one array per window.
    # Suffixes: '' = 'States', '_acc' = 'States_Accuracy', '_sp' = 'States_Speed'.
    summary = {f'{stat}_dwell{suffix}': []
               for suffix in ('', '_acc', '_sp')
               for stat in ('median', 'perc_5', 'perc_95')}

    for window in windows:
        # NOTE(review): `window` is not passed to prepare_and_train_dataset below
        # (its first positional argument is the literal 1). Unless that function
        # reads `window` from the notebook's global scope, every window size
        # repeats the same computation - please confirm the intended argument.
        runs = {'': [], '_acc': [], '_sp': []}

        for run in range(n_runs):
            columns, values = prepare_and_train_dataset(1,sessions, accuracy_metrics, speed_metrics, shuffle=False)
            training_data = pd.DataFrame(values, columns=columns)

            # From rle_with_delays, the first return value is used as the run
            # durations and the last one as the state label of each run
            z, _, _, run_states = sat.rle_with_delays(training_data['States'])
            z_acc, _, _, run_states_acc = sat.rle_with_delays(training_data['States_Accuracy'])
            z_sp, _, _, run_states_sp = sat.rle_with_delays(training_data['States_Speed'])

            # Store dwell times for this run, using np.nan for missing states
            runs[''].append(_median_dwell_by_state(z, run_states, n_states))
            runs['_acc'].append(_median_dwell_by_state(z_acc, run_states_acc, max_states))
            runs['_sp'].append(_median_dwell_by_state(z_sp, run_states_sp, max_states))

        # Calculate median and percentiles across runs for each state
        for suffix, per_run in runs.items():
            per_run = np.array(per_run)  # shape: (n_runs, number of states)
            summary[f'median_dwell{suffix}'].append(np.nanmedian(per_run, axis=0))
            summary[f'perc_5_dwell{suffix}'].append(np.nanpercentile(per_run, 5, axis=0))
            summary[f'perc_95_dwell{suffix}'].append(np.nanpercentile(per_run, 95, axis=0))

    # Store results in the dictionary
    species_results[species] = summary

    # Save the updated results after each species to prevent data loss
    with open(results_path, 'wb') as f:
        pickle.dump(species_results, f)


In [None]:
#FIGURE S.4 EFFECT OF SMOOTHING WINDOWSIZE PLOT
# Per species: three panels (all states, accuracy states, speed states) showing
# the median dwell time against window size, with a shaded 5th-95th percentile band.

with open('species_results.pkl', 'rb') as f:
    species_results = pickle.load(f)
windows = np.arange(2, 20, 1)
values_acc = [0,1]
values_sp = [0,1]

for species, results in species_results.items():
    fig, axs = plt.subplots(1, 3, figsize=(10, 3))

    # (panel index, result keys, state labels, colormap) for each of the three panels
    panel_specs = [
        (0, ('median_dwell', 'perc_5_dwell', 'perc_95_dwell'),
         [0,1,2,3], colormap_states),
        (1, ('median_dwell_acc', 'perc_5_dwell_acc', 'perc_95_dwell_acc'),
         np.unique(values_acc), colormap_states_accuracy),
        (2, ('median_dwell_sp', 'perc_5_dwell_sp', 'perc_95_dwell_sp'),
         np.unique(values_sp), colormap_states_speed),
    ]

    for panel, (key_median, key_low, key_high), panel_states, cmap in panel_specs:
        per_window_median = results[key_median]
        per_window_low = results[key_low]
        per_window_high = results[key_high]

        for i, state in enumerate(panel_states):
            # Each stored entry holds one value per state; pick this state's value for every window
            centre = [entry[i] for entry in per_window_median]
            lower = [entry[i] for entry in per_window_low]
            upper = [entry[i] for entry in per_window_high]
            colour = cmap(int(state))
            axs[panel].plot(windows, centre, '-o', label=f'State {state}', color=colour)
            axs[panel].fill_between(windows, lower, upper, color=colour, alpha=0.2)

    # Titles and fixed y-ranges per panel
    axs[0].set_title(f'All States ({species})')
    axs[1].set_title(f'States Accuracy ({species})')
    axs[2].set_title(f'States Speed ({species})')
    axs[0].set_ylim(0, 25)
    axs[1].set_ylim(0, 45)
    axs[2].set_ylim(0, 45)

    for ax in axs:
        ax.set_xlabel('Window Size')
        ax.set_ylabel('Median Dwell Time')
        ax.legend()

    plt.tight_layout()
    sat.save_plot('cross_species', project, filename=f"smoothing_win_dwell_time_{species}.svg")
    plt.show()


In [None]:
#FIGURE S.5. PSTH OF STATE OCCURRENCE
# Per species: a num_states x num_states grid. Panel (row, column) shows the
# PSTH returned by sat.psth_state_occurrence for that state pair, normalised
# by the largest PSTH value found across all panels of the species.

window = 15
gap=1
num_states = 4

# One figure per species
for species in species_data:
    fig, axs = plt.subplots(num_states, num_states, figsize=(9, 9))  # Adjust size as needed
    fig.suptitle(f'State PSTH {species}')

    data = sat.load_pickle('training_data', species)
    states = data['States']

    # Run-length encoding of the state sequence (computed once per species)
    z, p, v = sat.rle(states)

    # Compute every PSTH exactly once and reuse it for both the global
    # maximum and the plot (previously each PSTH was computed twice)
    all_psth = {
        (state_1, state_2): sat.psth_state_occurrence(states, z, p, v, state_1, state_2, window, gap=gap)
        for state_1 in range(num_states)
        for state_2 in range(num_states)
    }

    global_max = np.max([np.max(arr) for arr in all_psth.values()])
    # Guard against an all-zero species: dividing by 0 would turn every bar into NaN
    scale = global_max if global_max > 0 else 1

    # x positions of the bars: one per trial offset from -window to +window
    offsets = np.arange(-window, window + 1)

    for (state_1, state_2), psth in all_psth.items():
        psth_norm = np.asarray(psth) / scale

        # Select the appropriate subplot
        ax = axs[state_1, state_2]

        # Bars use the column state's colour, the line at offset 0 the row state's colour
        ax.bar(offsets, psth_norm, color=colormap_states.colors[state_2], width=1)
        ax.axvline(0, linewidth=3, color=colormap_states.colors[state_1])
        ax.set_ylim([0, 1])

    # Adjust layout to prevent overlap
    plt.tight_layout()
    # Saving is currently disabled for this figure:
    #sat.save_plot('cross_species', project, filename=f"state_occurrence_psth_{species}_state_start.svg")
    plt.show()

In [None]:
#FIGURE S.6. TRIAL DURATIONS
# One histogram (with KDE) of 'TrialDuration' per species, side by side with a shared y-axis.
n_species = len(species_data)

# Compact single-row figure, one panel per species
fig, axes = plt.subplots(1, n_species, figsize=(7, 3), sharey=True)

for i, species in enumerate(species_data):
    panel = axes[i]

    # Trial durations of this species; no trials are filtered out
    data = sat.load_pickle('training_data', species)
    state_durations = data['TrialDuration']

    # Histogram in the species' own panel, coloured with the i-th default cycle colour
    sns.histplot(state_durations, kde=True, bins=30, ax=panel, color='C' + str(i))

    # Only the 0-10 range is displayed; longer trials are binned but lie outside the view
    panel.set_xlim(0, 10)
    panel.set_xlabel('Trial Duration (sec)')

# The y-axis is shared, so it is labelled once on the left-most panel
axes[0].set_ylabel('Trial count')

# Tighten the layout, then save and display the figure
plt.tight_layout()
sat.save_plot('cross_species', project, filename=f"hist_trial_duration.svg")
plt.show()
