In [None]:
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import balanced_accuracy_score
from sklearn.linear_model import Lasso
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt
import numpy as np
import copy
import pickle
import matplotlib.cm as cm

from sklearn.svm import LinearSVC  # used below; not imported in the cell's import block

# Sweep the history length (n_timesteps) and, for each stimulus i, estimate how well an
# L1-regularised linear SVM predicts whether a reversal follows the stimulus onset.
n_splits = 10
results_dict = {}  # (i, n_timesteps) -> (mean, std) balanced accuracy across splits
best_models = {}   # (i, n_timesteps) -> (model, X_train, X_test, y_train, y_test, C) of the best split
T = 5              # fixed row offset before onset used to pick the feature row (see NOTE below)
param_grid = {'C': [0.001, 0.01, 0.1, 1, 10, 100]}
n_timesteps_range = range(1, 15, 1)

# prep_FB_inputs does not depend on the stimulus index i, so build its output once per
# history length instead of once per (i, n_timesteps).
X_stim_by_n_timesteps = {}

fig, ax = plt.subplots()
for i, (resampled_onset, exp_onset) in enumerate(zip(resampled_onsets, exp_onsets)):
    score_means = []
    score_stds = []
    all_scores_per_i = {}
    n_timesteps_list = []

    # Targets depend only on the stimulus, not on the history length.
    Y_latency = latency_to_reversal(exp_behaviors, exp_onset, max_latency=6 * durations[i] + 1) / 6  # nan where no reversal
    no_rev_at_onset = Y_latency != 0
    Y_prob = ~np.isnan(Y_latency[no_rev_at_onset])  # classify: did a reversal happen (y/n)?

    for n_timesteps in n_timesteps_range:
        print(f"\nProcessing i={i}, n_timesteps={n_timesteps}")
        if n_timesteps not in X_stim_by_n_timesteps:
            X_stim_by_n_timesteps[n_timesteps] = prep_FB_inputs(
                resampled_vel, resampled_acc, resampled_curve, resampled_rev, resampled_turn,
                T=n_timesteps)  # shape (n_tracks, n_frames, n_features)
        X_stim = X_stim_by_n_timesteps[n_timesteps]

        # NOTE(review): the row is taken at resampled_onset - T with T fixed at 5 while the
        # window length n_timesteps varies; prep_FB_inputs_features_only offsets by the window
        # length instead. Confirm which alignment is intended.
        X_stim_all = X_stim[:, resampled_onset - T, :]
        # Boolean indexing already returns a copy, so the cached X_stim is never mutated.
        X_prob = X_stim_all[no_rev_at_onset]

        balanced_accuracy_scores = []
        models = []
        for split_idx in range(n_splits):
            X_train, X_test, y_train, y_test = train_test_split(
                X_prob, Y_prob, test_size=test_size, random_state=split_idx)

            grid = GridSearchCV(LinearSVC(penalty='l1', dual=False, max_iter=10000),
                                param_grid, cv=5, scoring='balanced_accuracy')
            grid.fit(X_train, y_train)
            best_model = grid.best_estimator_

            score = balanced_accuracy_score(y_test, best_model.predict(X_test))
            balanced_accuracy_scores.append(score)
            models.append((best_model, X_train, X_test, y_train, y_test, grid.best_params_['C']))

        accuracy_mean = np.mean(balanced_accuracy_scores)
        accuracy_std = np.std(balanced_accuracy_scores)

        n_timesteps_list.append(n_timesteps)
        all_scores_per_i[(i, n_timesteps)] = balanced_accuracy_scores
        score_means.append(accuracy_mean)
        score_stds.append(accuracy_std)
        results_dict[(i, n_timesteps)] = (accuracy_mean, accuracy_std)
        best_models[(i, n_timesteps)] = models[int(np.argmax(balanced_accuracy_scores))]

    ax.errorbar(n_timesteps_list, score_means, yerr=score_stds,
                fmt='-o', capsize=5, label=f"stim {i}")

ax.set_title(f"{neuron}; Reversal prediction; Mean ± STD balanced accuracy")
ax.set_xlabel("max time delay (frames)")
ax.set_ylabel("balanced accuracy")
ax.legend()
ax.grid(True)
fig.tight_layout()
# plt.savefig(f"r2_plot_i{i}.png")
# plt.close()

In [None]:

### Inputs for the LSTM / feature models
# Three reference behaviour levels taken from z_norm: its minimum, the midpoint
# between minimum and maximum (computed as half-sums), and its maximum.
beh_map = [
    z_norm.min(),
    z_norm.min() / 2 + z_norm.max() / 2,
    z_norm.max(),
]

def resample_fps(feature_arr, target_fps,  original_fps):
    """Linearly interpolate a 1-D signal onto a different frame rate.

    The output has ``int(len(feature_arr) * target_fps / original_fps)`` samples,
    evenly spaced over the same span (first to last input sample).
    """
    n_in = len(feature_arr)
    n_out = int(n_in * (target_fps / original_fps))
    query_points = np.linspace(0, n_in - 1, n_out)
    return np.interp(query_points, np.arange(n_in), feature_arr)

def resample_2d(inferred_phases_all_shifted, target_fps,  original_fps):
    """Resample every row (track) of a 2-D array to a new frame rate.

    Each row is passed through ``resample_fps`` independently and the resampled
    rows are stacked back into a (n_tracks, n_resampled_frames) array.
    """
    resampled_rows = [
        resample_fps(inferred_phases_all_shifted[row, :], target_fps, original_fps)
        for row in range(inferred_phases_all_shifted.shape[0])
    ]
    return np.vstack(resampled_rows)


def feature_all_to_resampled(inferred_phases_all, n_tracks):
    """Reshape a flat per-frame feature into (n_tracks, frames) and NaN-pad it.

    The data is placed after 5 leading NaN frames and followed by 16 trailing
    NaN frames, so the result has shape (n_tracks, n_frames + 21).
    Prints the per-track frame count as a side effect.
    """
    per_track = inferred_phases_all.reshape(n_tracks, -1)
    n_timesteps = per_track.shape[1]
    print("n_timesteps", n_timesteps)

    # ask Bennet about the 8 extra time steps missing (trailing pad was 8, now 8 + 8)
    lead_pad, trail_pad = 5, 16
    padded = np.full((n_tracks, lead_pad + n_timesteps + trail_pad), np.nan)
    padded[:, lead_pad:-trail_pad] = per_track
    return padded

def flatten_and_remove_nans(resampled_features):
    """Return a flattened (row-major) 1-D copy of ``resampled_features``.

    NOTE(review): despite the name, NaNs are NOT removed; the earlier variant that
    trimmed the NaN padding columns ([:, 5:-8]) is no longer used.
    """
    return resampled_features.ravel().copy()


# n_tracks  = resampled_vel.shape[0]



# print(phase_resampled.shape)
# print(radii_resampled.shape)


def prep_FB_inputs_donut_only(inferred_phases_all, inferred_rad_all, resampled_onset, n_tracks):
    """Phase/radius ("donut") features for every track at one frame.

    Both inputs are reshaped to (n_tracks, frames) and NaN-padded by
    ``feature_all_to_resampled``; the padded frame ``resampled_onset`` is returned.
    NOTE(review): assumes resampled_onset is expressed in the padded frame axis
    (5 leading NaN frames) — TODO confirm.

    Returns
    -------
    X_donut_stim : ndarray, shape (n_tracks, 2)
        Columns are (phase, radius).
    feature_names : list of str
        ``["phase", "radius"]``.
    """
    phase_resampled = feature_all_to_resampled(inferred_phases_all, n_tracks)
    radii_resampled = feature_all_to_resampled(inferred_rad_all, n_tracks)

    # np.concatenate needs a *sequence* of arrays; passing them as two positional
    # arguments plus axis= raised TypeError. Stack on a new last axis instead:
    # (n_tracks, n_padded_frames, 2).
    X_donut = np.stack([phase_resampled, radii_resampled], axis=2)
    X_donut_stim = X_donut[:, resampled_onset, :]

    feature_names = ["phase", "radius"]
    return X_donut_stim, feature_names
    



def prep_FB_inputs_features_only(resampled_vel, resampled_acc, resampled_curve, resampled_rev,
                                 resampled_turn, resampled_onset, T=5, lstm_lag=None):
    """Time-lagged behavioural features for every track, taken at one row near the onset.

    For each track (``resampled_vel[k]`` etc.) the four channels velocity,
    acceleration, curvature and behaviour code are stacked column-wise, windowed
    with ``create_X_sequences`` (defined elsewhere in the notebook) and flattened
    row-major, so each row reads ``vel, accel, curv, beh`` for the oldest frame of
    the window, then the next frame, and so on (matching ``feature_names``).

    Parameters
    ----------
    resampled_vel, resampled_acc, resampled_curve : per-track sequences
    resampled_rev, resampled_turn :
        Combined as ``rev + 2 * turn`` and z-scored with the mean/std of the
        notebook-global ``z``.
    resampled_onset : int
        Stimulus-onset row index.
    T : int, default 5
        Window length (frames of history). ``T <= 0`` uses only the current frame
        (previously this branch crashed when unpacking a 2-D shape).
    lstm_lag : int or None, default None
        Offset subtracted from ``resampled_onset`` to choose the returned row.
        ``None`` keeps the original behaviour of offsetting by ``T``.

    Returns
    -------
    X_stim_features : ndarray, shape (n_tracks, 4 * window)
        window = T, presumably, per create_X_sequences' (n_frames, T, 4) output — TODO confirm.
    feature_names : ndarray of str, one name per feature column.
    """
    behavior_input = np.array(resampled_rev+2*resampled_turn, dtype=np.float64)
    behavior_input -= np.array(z).mean()
    behavior_input /= np.array(z).std()

    X_all_LSTM = []
    for new_worm_idx in range(len(resampled_vel)):
        X_new_worm = np.stack([resampled_vel[new_worm_idx], resampled_acc[new_worm_idx],
                               resampled_curve[new_worm_idx], behavior_input[new_worm_idx]], axis=1)
        X_new_tensor = torch.tensor(X_new_worm, dtype=torch.float32)

        if T > 0:
            # per original note: (n_frames, time delay, 4 features)
            X_new_seq1 = create_X_sequences(X_new_tensor, T).numpy()
        else:
            # No history: treat each frame as a length-1 window so the
            # (n_frames, delay, n_features) unpacking below still applies.
            X_new_seq1 = X_new_tensor.numpy()[:, None, :]

        n_frames, delay, n_features = X_new_seq1.shape
        # Row-major flatten: f1..f4 of the oldest frame first, then the next frame, ...
        X_all_LSTM.append(X_new_seq1.reshape((n_frames, delay * n_features)))
    X_all_LSTM = np.array(X_all_LSTM)  # (n_tracks, n_frames, n_features)

    n_lags = max(T, 1)
    feature_names = np.array([[f"vel_t-{n_lags - t - 1}", f"accel_t-{n_lags - t - 1}",
                               f"curv_t-{n_lags - t - 1}", f"beh_t-{n_lags - t - 1}"]
                              for t in range(n_lags)]).flatten()

    row_offset = T if lstm_lag is None else lstm_lag
    X_stim_features = X_all_LSTM[:, resampled_onset - row_offset, :]

    return X_stim_features, feature_names


    
def prep_FB_inputs_feature_and_donut(resampled_vel, resampled_acc, resampled_curve, resampled_rev,
                                     resampled_turn, inferred_phases_all, inferred_rad_all,
                                     resampled_onset, n_tracks, T=5, lstm_lag=None):
    """Behavioural lag features and donut (phase, radius) features, side by side.

    Combines ``prep_FB_inputs_features_only`` (shape (n_tracks, n_lag_features))
    and ``prep_FB_inputs_donut_only`` (shape (n_tracks, 2)) along the feature axis.
    The previous version called np.concatenate with two positional arrays
    (TypeError), added a spurious third axis, and added an ndarray to a list.

    Parameters
    ----------
    lstm_lag : int or None, default None
        Forwarded to ``prep_FB_inputs_features_only`` only when given, so the
        default call is unchanged.

    Returns
    -------
    X_stim_all : ndarray, shape (n_tracks, n_lag_features + 2)
    feature_names_all : ndarray of str, lag feature names followed by "phase", "radius"
    """
    extra_kwargs = {} if lstm_lag is None else {"lstm_lag": lstm_lag}
    X_stim_features, feature_names = prep_FB_inputs_features_only(
        resampled_vel, resampled_acc, resampled_curve, resampled_rev, resampled_turn,
        resampled_onset, T=T, **extra_kwargs)
    X_donut_stim, donut_names = prep_FB_inputs_donut_only(
        inferred_phases_all, inferred_rad_all, resampled_onset, n_tracks)

    X_stim_all = np.concatenate([X_stim_features, X_donut_stim], axis=1)
    feature_names_all = np.array(list(feature_names) + list(donut_names))
    return X_stim_all, feature_names_all





# NOTE(review): the expression inside this loop builds a tuple of the concatenated
# inferred phases / radii for stimulus i and then discards it; nothing is assigned.
# Presumably it was meant to set inferred_phases_all / inferred_rad_all (used by a
# later cell). TODO confirm.
# Side effect: after the loop, resampled_onset / exp_onset hold the LAST stimulus's values.
for i, (resampled_onset, exp_onset) in enumerate(zip(resampled_onsets, exp_onsets)):
    
    np.concatenate(inferred_phases[idxs[i]]), np.concatenate(inferred_rad[idxs[i]])


In [None]:
def prep_FB_inputs_features_only(resampled_vel, resampled_acc, resampled_curve, resampled_rev,
                                 resampled_turn, resampled_onset, T=5, lstm_lag=None):
    """Time-lagged behavioural features for every track, taken at one row near the onset.

    For each track (``resampled_vel[k]`` etc.) the four channels velocity,
    acceleration, curvature and behaviour code are stacked column-wise, windowed
    with ``create_X_sequences`` (defined elsewhere in the notebook) and flattened
    row-major, so each row reads ``vel, accel, curv, beh`` for the oldest frame of
    the window, then the next frame, and so on (matching ``feature_names``).

    Parameters
    ----------
    resampled_vel, resampled_acc, resampled_curve : per-track sequences
    resampled_rev, resampled_turn :
        Combined as ``rev + 2 * turn`` and z-scored with the mean/std of the
        notebook-global ``z``.
    resampled_onset : int
        Stimulus-onset row index.
    T : int, default 5
        Window length (frames of history). ``T <= 0`` uses only the current frame
        (previously this branch crashed when unpacking a 2-D shape).
    lstm_lag : int or None, default None
        Offset subtracted from ``resampled_onset`` to choose the returned row.
        ``None`` keeps the original behaviour of offsetting by ``T``.

    Returns
    -------
    X_stim_features : ndarray, shape (n_tracks, 4 * window)
        window = T, presumably, per create_X_sequences' (n_frames, T, 4) output — TODO confirm.
    feature_names : ndarray of str, one name per feature column.
    """
    behavior_input = np.array(resampled_rev+2*resampled_turn, dtype=np.float64)
    behavior_input -= np.array(z).mean()
    behavior_input /= np.array(z).std()

    X_all_LSTM = []
    for new_worm_idx in range(len(resampled_vel)):
        X_new_worm = np.stack([resampled_vel[new_worm_idx], resampled_acc[new_worm_idx],
                               resampled_curve[new_worm_idx], behavior_input[new_worm_idx]], axis=1)
        X_new_tensor = torch.tensor(X_new_worm, dtype=torch.float32)

        if T > 0:
            # per original note: (n_frames, time delay, 4 features)
            X_new_seq1 = create_X_sequences(X_new_tensor, T).numpy()
        else:
            # No history: treat each frame as a length-1 window so the
            # (n_frames, delay, n_features) unpacking below still applies.
            X_new_seq1 = X_new_tensor.numpy()[:, None, :]

        n_frames, delay, n_features = X_new_seq1.shape
        # Row-major flatten: f1..f4 of the oldest frame first, then the next frame, ...
        X_all_LSTM.append(X_new_seq1.reshape((n_frames, delay * n_features)))
    X_all_LSTM = np.array(X_all_LSTM)  # (n_tracks, n_frames, n_features)

    n_lags = max(T, 1)
    feature_names = np.array([[f"vel_t-{n_lags - t - 1}", f"accel_t-{n_lags - t - 1}",
                               f"curv_t-{n_lags - t - 1}", f"beh_t-{n_lags - t - 1}"]
                              for t in range(n_lags)]).flatten()

    row_offset = T if lstm_lag is None else lstm_lag
    X_stim_features = X_all_LSTM[:, resampled_onset - row_offset, :]

    return X_stim_features, feature_names



def prep_FB_inputs(resampled_vel, resampled_acc, resampled_curve, resampled_rev, resampled_turn, T=5):
    """Time-lagged behavioural feature vectors for every track and every windowed frame.

    For each track the channels velocity, acceleration, curvature and the
    behaviour code ``rev + 2 * turn`` (z-scored with the mean/std of the
    notebook-global ``z``) are stacked, windowed with ``create_X_sequences``
    (defined elsewhere) and flattened row-major: ``vel, accel, curv, beh`` of the
    oldest frame first, then the next frame, and so on.

    Parameters
    ----------
    T : int, default 5
        Window length in frames. ``T <= 0`` uses only the current frame
        (previously this branch crashed when unpacking a 2-D shape).

    Returns
    -------
    ndarray, shape (n_tracks, n_frames_out, 4 * window)
        n_frames_out / window are whatever create_X_sequences produces
        (presumably window = T) — TODO confirm.
    """
    behavior_input = np.array(resampled_rev+2*resampled_turn, dtype=np.float64)
    behavior_input -= np.array(z).mean()
    behavior_input /= np.array(z).std()

    X_all_LSTM = []
    for new_worm_idx in range(len(resampled_vel)):
        X_new_worm = np.stack([resampled_vel[new_worm_idx], resampled_acc[new_worm_idx],
                               resampled_curve[new_worm_idx], behavior_input[new_worm_idx]], axis=1)
        X_new_tensor = torch.tensor(X_new_worm, dtype=torch.float32)

        if T > 0:
            # per original note: (n_frames, time delay, 4 features)
            X_new_seq1 = create_X_sequences(X_new_tensor, T).numpy()
        else:
            # No history: one length-1 window per frame keeps the 3-D layout.
            X_new_seq1 = X_new_tensor.numpy()[:, None, :]

        n_frames, delay, n_features = X_new_seq1.shape
        X_all_LSTM.append(X_new_seq1.reshape((n_frames, delay * n_features)))

    return np.array(X_all_LSTM)  # (n_tracks, n_frames, n_features)



In [None]:
# Visualise the best per-stimulus linear-SVC model at a fixed history length.
n_timesteps = 11
lstm_lag = 5  # row offset before onset passed to the feature builder

for i, resampled_onset in enumerate(resampled_onsets):
    # Use THIS stimulus's onset; a single stale resampled_onset used to be reused for every i.
    model, X_train, X_test, y_train, y_test, c = best_models[(i, n_timesteps)]

    X_stim_features, feature_names = prep_FB_inputs_features_only(
        resampled_vel, resampled_acc, resampled_curve, resampled_rev, resampled_turn,
        resampled_onset, n_timesteps, lstm_lag=lstm_lag)  # (n_tracks, n_features)
    X_donut_stim, donut_names = prep_FB_inputs_donut_only(
        inferred_phases_all, inferred_rad_all, resampled_onset, n_tracks)
    X_stim_all = np.concatenate([X_stim_features, X_donut_stim], axis=1)
    feature_names_all = np.array(feature_names.tolist() + donut_names)

    # NOTE(review): if best_models came from the sweep cell, the models were trained on
    # behavioural lag features only (4 * n_timesteps columns), but feature_names_all also
    # lists the two donut features, so it is 2 entries longer than model.coef_[0].
    # TODO confirm the names line up with the coefficients.
    model_label = f"{neuron} stim: {i}; linear SVC; donut and features"
    visualize_model_classification(model, model_label, feature_names_all, X_train, X_test,
                                   y_train, y_test, n_timesteps, feature_names_ordered=False,
                                   coeffs=model.coef_[0])

plt.show()


In [None]:
import os

from matplotlib.backends.backend_pdf import PdfPages

# Per-recording ethogram (annotated vs rSLDS-inferred states) plus selected neuron traces.
# The output folder can be overridden with the RSLDS_PDF_DIR environment variable.
pdf_save_dir = os.environ.get("RSLDS_PDF_DIR", "/Users/friederikebuck/Downloads/worm notes/rslds_all_dates/")
os.makedirs(pdf_save_dir, exist_ok=True)
save_pdf = False  # set True to write every figure into all_dates.pdf

# Each neuron group gets its own pair of axes below the two state rasters.
neuron_groups = [
    (["RID", "RIB", "AVB"], ["darkred", "crimson", "purple"]),
    (["RIM", "AIB", "AVA"], ["dodgerblue", "darkmagenta", "navy"]),
    (["OLQ", "URY", "OLL", "RIV"], ["olivedrab", "seagreen", "darkslategray", "purple"]),
]
n_sections = 1 + len(neuron_groups)  # state rasters + one section per group

exp_dates = get_exp_dates()
all_figs = []
for exp_i, (date, z_w, q_z_w, labels, traces) in enumerate(zip(exp_dates, z, q_z, full_neural_labels, full_traces)):
    fig, axs = plt.subplots(2 * n_sections, 1, figsize=(18, 6 * n_sections))

    # True (annotated) and rSLDS-inferred discrete states as colour strips.
    axs[0].imshow(z_w[None, :], aspect="auto", cmap=cmap, alpha=0.3, vmin=0, vmax=len(palette))
    axs[1].imshow(q_z_w[None, :], aspect="auto", cmap=cmap, alpha=0.3, vmin=0, vmax=len(palette))
    axs[0].set_yticks([])
    axs[1].set_yticks([])
    axs[0].set_title(f"{date}; Beh")
    axs[1].set_title("Inferred by rSLDS")
    axs[1].set_xticks([])

    for group_i, (neurons, colors) in enumerate(neuron_groups, start=1):
        neuron_to_color = dict(zip(neurons, colors))
        plot_states_and_neurons(neurons, neuron_to_color, z_w, q_z_w, fig=fig,
                                axs=axs[2 * group_i:2 * group_i + 2])

    all_figs.append(fig)

if save_pdf:
    with PdfPages(os.path.join(pdf_save_dir, "all_dates.pdf")) as pdf:
        for fig in all_figs:
            pdf.savefig(fig)

In [None]:
import numpy as np
from itertools import combinations
import itertools
import random
import matplotlib.pyplot as plt
import seaborn as sns

    
# 1. Extract feature means for each occurrence
from collections import defaultdict
# NOTE(review): this module-level dict is not read anywhere in this cell;
# get_means_from_occurances builds its own local list with the same name.
occurrence_means = defaultdict(list)
from scipy.stats import ttest_ind



def get_pairwise_distances(f1_means, f2_means, n_permutations = 1000 ):
    """Sample random (f1 occurrence, f2 occurrence) pairs and return their feature differences.

    Pairs (i, j) are drawn uniformly, with replacement, from all
    ``len(f1_means) * len(f2_means)`` combinations using the ``random`` module.
    Pass the same array twice to get within-group differences.

    Previously ``f2_means[i]`` was used instead of ``f2_means[j]``, and the sample
    index was drawn from ``range(len(f2_means))`` but used on the full product
    list, so only pairs with i == 0 were ever produced.

    Parameters
    ----------
    f1_means, f2_means : array-like, shape (n_occurrences, n_features)
    n_permutations : int
        Number of pairs to sample.

    Returns
    -------
    ndarray, shape (n_permutations, n_features)
        Signed differences ``f1_means[i] - f2_means[j]``.

    Raises
    ------
    ValueError
        If either group has no occurrences.
    """
    f1_means = np.asarray(f1_means)
    f2_means = np.asarray(f2_means)
    n1, n2 = len(f1_means), len(f2_means)
    if n1 == 0 or n2 == 0:
        raise ValueError("get_pairwise_distances needs at least one occurrence in each group")

    # Independent uniform i and j == uniform over the product of pairs, without
    # materialising all n1 * n2 combinations.
    i_idx = np.array([random.randrange(n1) for _ in range(n_permutations)], dtype=int)
    j_idx = np.array([random.randrange(n2) for _ in range(n_permutations)], dtype=int)
    return f1_means[i_idx] - f2_means[j_idx]


def get_means_from_occurances(match_start_idx, match_end_idx, features):
    """NaN-ignoring mean of ``features`` over each [start, end) occurrence.

    Parameters
    ----------
    match_start_idx, match_end_idx : sequences of int
        Half-open frame ranges, one per occurrence.
    features : array-like, shape (n_frames, n_features)

    Returns
    -------
    ndarray, shape (n_occurrences, n_features)
        One row per occurrence; an empty (0, n_features) array when there are no
        occurrences (``np.vstack`` of an empty list used to raise).
        An occurrence whose slice is all-NaN (or empty) yields NaN with a RuntimeWarning.
    """
    features = np.asarray(features)
    # ndarray has no .nanmean method; use the module-level np.nanmean.
    occurrence_means = [np.nanmean(features[start:end], axis=0)
                        for start, end in zip(match_start_idx, match_end_idx)]
    if not occurrence_means:
        return np.empty((0,) + features.shape[1:])
    return np.vstack(occurrence_means)




def find_runs(x):
    """Find start indices, end indices, and values of runs in a 1D array.

    A run is a maximal stretch of equal consecutive values. ``end_idx`` is
    exclusive, so ``x[start_idx[k]:end_idx[k]]`` is run k. Both True and False
    runs are reported for boolean input; filter on ``values`` if needed.

    Accepts any 1-D array-like (lists included) of any dtype that supports ``!=``
    (bool, int, unsigned, float, str). NaNs never compare equal, so each NaN is
    its own run.

    Returns
    -------
    start_idx, end_idx : ndarray of int
    values : ndarray, same dtype as ``x``
    """
    x = np.asarray(x)
    n = len(x)
    if n == 0:
        return np.array([], dtype=int), np.array([], dtype=int), np.array([], dtype=x.dtype)

    # A run starts at index 0 and wherever a value differs from its predecessor.
    # Comparing with != (instead of np.diff against x[0] - 1) also works for
    # bool/str dtypes and does not split runs of identical infinities.
    is_start = np.empty(n, dtype=bool)
    is_start[0] = True
    is_start[1:] = x[1:] != x[:-1]

    start_idx = np.flatnonzero(is_start)
    end_idx = np.append(start_idx[1:], n)
    values = x[start_idx]
    return start_idx, end_idx, values
def filter_runs_for_duration(start_idx, end_idx, thresh):
    """Keep only runs strictly longer than ``thresh`` frames.

    ``start_idx`` / ``end_idx`` are parallel index arrays (end exclusive); the
    surviving pairs are returned in their original order.
    """
    long_enough = (end_idx - start_idx) > thresh
    return start_idx[long_enough], end_idx[long_enough]


def plot_feature_avg_diff_heatmap(f1_means, f2_means, feature_labels=None, title="Feature Differences"):
    """One-column heatmap of mean(f1) - mean(f2) per feature, largest |difference| first.

    ``feature_labels`` defaults to ``feat_0, feat_1, ...``. Draws into a new
    pyplot figure and calls ``plt.show()``.
    """
    diff = np.mean(f1_means, axis=0) - np.mean(f2_means, axis=0)
    order = np.argsort(np.abs(diff))[::-1]
    ranked_diff = diff[order]

    labels = feature_labels
    if labels is None:
        labels = [f"feat_{k}" for k in range(len(diff))]
    ranked_labels = np.array(labels)[order]

    plt.figure(figsize=(10, 1 + 0.2 * len(ranked_diff)))
    sns.heatmap(ranked_diff[:, None], annot=True, fmt=".2f", cmap="vlag",
                yticklabels=ranked_labels, xticklabels=["Mean Diff (match - miss)"], cbar=True)
    plt.title(title)
    plt.tight_layout()
    plt.show()

def plot_feature_diff_summary(f1_means, f2_means, feature_labels=None, title="Feature Differences"):
    """Heatmap and horizontal bar chart of mean(f1) - mean(f2) per feature.

    Features are ordered by decreasing |difference|; ``feature_labels`` defaults
    to ``feat_0, feat_1, ...``. Creates a 1x2 figure and calls ``plt.show()``.
    """
    diff = np.mean(f1_means, axis=0) - np.mean(f2_means, axis=0)
    order = np.argsort(np.abs(diff))[::-1]
    ranked_diff = diff[order]

    if feature_labels is None:
        feature_labels = [f"feat_{k}" for k in range(len(diff))]
    ranked_labels = np.array(feature_labels)[order]

    n_ranked = len(ranked_diff)
    positions = np.arange(n_ranked)
    fig, (heat_ax, bar_ax) = plt.subplots(1, 2, figsize=(14, 0.4 * n_ranked + 2),
                                          gridspec_kw={'width_ratios': [1, 2]})

    sns.heatmap(ranked_diff[:, None], annot=True, fmt=".2f", cmap="vlag",
                yticklabels=ranked_labels, xticklabels=["Diff"], cbar=True, ax=heat_ax)
    heat_ax.set_title("Heatmap")

    bar_ax.barh(positions, ranked_diff, color='skyblue', edgecolor='k')
    bar_ax.set_yticks(positions)
    bar_ax.set_yticklabels(ranked_labels)
    bar_ax.invert_yaxis()  # largest |difference| at the top
    bar_ax.set_xlabel("Mean Difference (F1 - F2)")
    bar_ax.set_title("Bar Plot")

    plt.suptitle(title)
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.show()


def _per_feature_distance(dists):
    """Return one distance value per feature.

    1-D input (already one value per feature) is returned unchanged. 2-D input of
    per-pair signed differences, shape (n_pairs, n_features) as produced by
    get_pairwise_distances, is reduced to the mean absolute difference per feature.
    """
    dists = np.asarray(dists)
    if dists.ndim > 1:
        return np.abs(dists).mean(axis=0)
    return dists


def plot_pairwise_distance_heatmaps(within_dists, between_dists, feature_labels=None, title="Pairwise Distance Summary"):
    """Heatmap of within-group, between-group and (between - within) distances per feature.

    Rows are sorted by decreasing (between - within). Each distance argument may
    be a per-feature vector or a (n_pairs, n_features) array of signed pairwise
    differences; the latter previously crashed and is now collapsed to a mean
    absolute distance per feature. Calls ``plt.show()``.
    """
    within_dists = _per_feature_distance(within_dists)
    between_dists = _per_feature_distance(between_dists)
    diff_dists = between_dists - within_dists  # shape (n_features,)

    sort_idx = np.argsort(diff_dists)[::-1]
    within_sorted = within_dists[sort_idx]
    between_sorted = between_dists[sort_idx]
    diff_sorted = diff_dists[sort_idx]

    if feature_labels is None:
        feature_labels = [f"feat_{i}" for i in range(len(diff_dists))]
    labels_sorted = np.array(feature_labels)[sort_idx]

    data_matrix = np.vstack([within_sorted, between_sorted, diff_sorted]).T  # (n_features, 3)

    plt.figure(figsize=(8, 0.5 * len(labels_sorted) + 2))
    sns.heatmap(data_matrix, annot=True, fmt=".2f", cmap="coolwarm",
                xticklabels=["Within", "Between", "Diff (B - W)"],
                yticklabels=labels_sorted, cbar=True)
    plt.title(title)
    plt.tight_layout()
    plt.show()
    
## 1) Compare runs where the rSLDS state matches the annotated behaviour ("match")
##    with runs where it predicts a different state ("miss").
duration_thresh = 1  # keep only runs strictly longer than this many frames
beh_states = [0, 1, 2]
beh_labels = ["accel", "vel", "curvature"]  # NOTE(review): confirm order matches beh_features columns
n_permutations = 1000

# z / q_z hold one state sequence per recording (they are zipped per date elsewhere),
# so compare on one concatenated frame axis; `z == beh_state` on a list is just False.
# NOTE(review): assumes neural_features / beh_features are (n_frames, n_features) on
# this same concatenated frame axis — TODO confirm.
z_flat = np.concatenate(z)
q_z_flat = np.concatenate(q_z)


def _true_runs(mask):
    """Start/end (exclusive) indices of the runs where boolean ``mask`` is True."""
    starts, ends, values = find_runs(np.asarray(mask))
    is_true = values.astype(bool)  # find_runs also reports the False runs
    return starts[is_true], ends[is_true]


def _plot_run_durations(starts, ends, title):
    """Histogram of run lengths in frames, on a fresh figure."""
    _, run_ax = plt.subplots()
    run_ax.hist(ends - starts, bins=50)
    run_ax.set_title(title)


def _compare_groups(f1_means, f2_means, labels, title):
    """Mean-difference and pairwise-distance plots for two groups of occurrence means."""
    plot_feature_avg_diff_heatmap(f1_means, f2_means, feature_labels=labels, title=title)
    plot_feature_diff_summary(f1_means, f2_means, feature_labels=labels, title=title)
    # get_pairwise_distances returns per-pair signed differences; collapse to one mean
    # absolute distance per feature, which the heatmap plot expects.
    within = np.abs(get_pairwise_distances(f1_means, f1_means, n_permutations=n_permutations)).mean(axis=0)
    between = np.abs(get_pairwise_distances(f1_means, f2_means, n_permutations=n_permutations)).mean(axis=0)
    plot_pairwise_distance_heatmaps(within, between, feature_labels=labels,
                                    title=f"{title}: pairwise distances")


for beh_state in beh_states:
    is_beh = z_flat == beh_state

    # Frames annotated as beh_state AND inferred as beh_state.
    match_start, match_end = _true_runs(is_beh & (q_z_flat == beh_state))
    _plot_run_durations(match_start, match_end, f"{beh_state} match run durations")
    match_start, match_end = filter_runs_for_duration(match_start, match_end, duration_thresh)
    _plot_run_durations(match_start, match_end, f"{beh_state} match run durations filtered")

    for rsdls_state in np.setdiff1d(beh_states, [beh_state]):
        # Frames annotated as beh_state but inferred as rsdls_state.
        miss_start, miss_end = _true_runs(is_beh & (q_z_flat == rsdls_state))
        _plot_run_durations(miss_start, miss_end, f"{beh_state} miss {rsdls_state} run durations")
        miss_start, miss_end = filter_runs_for_duration(miss_start, miss_end, duration_thresh)
        _plot_run_durations(miss_start, miss_end, f"{beh_state} miss {rsdls_state} run durations filtered")

        if len(match_start) == 0 or len(miss_start) == 0:
            print(f"Skipping beh {beh_state} vs rSLDS {rsdls_state}: "
                  f"no runs longer than {duration_thresh} frames")
            continue

        match_neural_means = get_means_from_occurances(match_start, match_end, neural_features)
        miss_neural_means = get_means_from_occurances(miss_start, miss_end, neural_features)
        match_beh_means = get_means_from_occurances(match_start, match_end, beh_features)
        miss_beh_means = get_means_from_occurances(miss_start, miss_end, beh_features)

        comparison = f"beh {beh_state}: match vs miss {rsdls_state}"
        _compare_groups(match_neural_means, miss_neural_means, neural_labels, f"{comparison} (neural)")
        _compare_groups(match_beh_means, miss_beh_means, beh_labels, f"{comparison} (behaviour)")

# TODO: plot heatmaps of distances sorted by ratio / difference of distances.
# TODO: for each occurrence also compare the max of each feature and the mean of the
#       first x frames against a control.

## 3)

    

In [None]:
# TODO: color ethograms by classification outcome (false negatives, false positives, etc.)



## TODO: try removing neurons?


In [None]:
traces, neural_labels, behavior_classification, mask = load_all_data_but_pretend_its_all_one_worm()

# Frames kept per recording; must match T in load_all_data_but_pretend_its_all_one_worm.
frames_per_worm = 1599

velocity = np.array([beh["velocity"][0:frames_per_worm] for beh in full_beh_data])
acceleration = np.array([beh["acceleration"][1:frames_per_worm + 1] for beh in full_beh_data])  # offset by one frame
head_curvature = np.array([beh["head_angle"][0:frames_per_worm] for beh in full_beh_data])
worm_curvature = np.array([beh["worm_curvature"][0:frames_per_worm] for beh in full_beh_data])
pumping = np.array([beh["pumping"][0:frames_per_worm] for beh in full_beh_data])

q_z_all = np.concatenate(q_z)
z_all = np.concatenate(z)

print(traces.shape)
print(neural_labels.shape)
print(acceleration.shape)
print(q_z_all.shape)

n_sections = 2  # state rasters (axs[0:2]) + one neuron group (axs[2:4])
neurons = ["RID", "RIB", "AVB"]
colors = ["darkred", "crimson", "purple"]
neuron_to_color = dict(zip(neurons, colors))

for i, date in enumerate(exp_dates):
    # Slice this recording out of the concatenated arrays BEFORE plotting it.
    window = slice(frames_per_worm * i, frames_per_worm * (i + 1))
    z_w = z_all[window]
    q_z_w = q_z_all[window]  # inferred states (was sliced from z_all)
    traces_w = traces[window]

    fig, axs = plt.subplots(2 * n_sections, 1, figsize=(18, 6 * n_sections))
    axs[0].imshow(z_w[None, :], aspect="auto", cmap=cmap, alpha=0.3, vmin=0, vmax=len(palette))
    axs[1].imshow(q_z_w[None, :], aspect="auto", cmap=cmap, alpha=0.3, vmin=0, vmax=len(palette))
    axs[0].set_yticks([])
    axs[1].set_yticks([])
    axs[0].set_title(f"{date}; Beh")
    axs[1].set_title("Inferred by rSLDS")
    axs[1].set_xticks([])

    # traces columns are ordered by neural_labels (see load_all_data_but_pretend_its_all_one_worm).
    plot_states_and_neurons(neurons, neuron_to_color, z_w, q_z_w, traces_w, neural_labels,
                            fig=fig, axs=axs[2:4])
    
    

In [None]:

    #set up behavior classification dict
    behavior_classification = dict()
    behavior_classification["is_turn"] = np.zeros(T*len(full_traces))
    behavior_classification["is_pause"] = np.zeros(T*len(full_traces))
    behavior_classification["is_rev"] = np.zeros(T*len(full_traces))
    behavior_classification["is_fwd"] = np.zeros(T*len(full_traces))
    behavior_classification["is_revturn"] = np.zeros(T*len(full_traces))
    behavior_classification["is_purerev"] = np.zeros(T*len(full_traces))
    behavior_classification["is_pureturn"] = np.zeros(T*len(full_traces))
    behavior_classification["is_rev_of_rev_turn"] = np.zeros(T*len(full_traces))
    behavior_classification["is_turn_of_rev_turn"] = np.zeros(T*len(full_traces))

    #fill it in
    w=0
    for bc in full_beh_classification:
        for key in bc.keys():
            behavior_classification[key][w*T:(w+1)*T] = bc[key][1:(T+1)] # shifting by one bc of the trace
        w+=1 #update worm index


In [None]:

def load_all_data_but_pretend_its_all_one_worm():
    """Concatenate every recording into one long pseudo-worm.

    Neuron columns are matched by label across recordings, so a neuron that was
    not recorded in a given worm is NaN for that worm's frames.

    Returns
    -------
    traces : ndarray, shape (T * n_worms, n_neurons)
        First T = 1599 frames of each worm, stacked (a couple of worms have 15 extra
        frames, which are truncated).
    neural_labels : ndarray of str
        Sorted union of all neuron labels; the column order of ``traces``.
    behavior_classification : dict of str -> ndarray of length T * n_worms
        Per-frame behaviour flags taken from frames 1..T of each worm (shifted by
        one frame relative to the traces).
    mask : ndarray of int, same shape as ``traces``
        1 where ``traces`` has data, 0 where it is NaN.
    """
    # Thing to ponder: adding columns of NaNs between worms would probably help the rSLDS learn better.
    # NOTE(review): full_beh_data is loaded here but not returned.
    full_traces, full_neural_labels, full_beh_classification, full_beh_data = load_all_data()

    # Every neuron recorded in at least one worm, sorted; map label -> column once
    # instead of searching the label array for every neuron of every worm.
    neural_labels = np.sort(list(set().union(*(set(nl) for nl in full_neural_labels))))
    label_to_col = {label: col for col, label in enumerate(neural_labels)}

    T = 1599
    n_rows = T * len(full_traces)

    traces = np.full((n_rows, neural_labels.shape[0]), np.nan)
    for w, (tr, labels) in enumerate(zip(full_traces, full_neural_labels)):
        rows = slice(w * T, (w + 1) * T)
        for i in range(tr.shape[1]):
            traces[rows, label_to_col[labels[i]]] = tr[0:T, i]

    behavior_keys = (
        "is_turn", "is_pause", "is_rev", "is_fwd", "is_revturn",
        "is_purerev", "is_pureturn", "is_rev_of_rev_turn", "is_turn_of_rev_turn",
    )
    behavior_classification = {key: np.zeros(n_rows) for key in behavior_keys}

    for w, bc in enumerate(full_beh_classification):
        rows = slice(w * T, (w + 1) * T)
        for key in bc.keys():
            if key not in behavior_classification:
                # Unexpected flag: give it its own zero-initialised column rather than raising KeyError.
                behavior_classification[key] = np.zeros(n_rows)
            behavior_classification[key][rows] = bc[key][1:(T + 1)]  # shifting by one bc of the trace

    mask = (~np.isnan(traces)).astype(int)
    return traces, neural_labels, behavior_classification, mask