In [None]:
from cedne import utils
import os
import json
import matplotlib.pyplot as plt
import numpy as np
import matplotlib
import matplotlib.cm as cm
import pywt
import tensorly as tl
from tensorly.decomposition import parafac
from tensorly.decomposition import tucker
from tensorly.decomposition import partial_tucker
from tensorly.tenalg import multi_mode_dot
from tensorly import kruskal_to_tensor
from sklearn.preprocessing import StandardScaler
from scipy.ndimage import gaussian_filter1d

In [None]:
def calculate_cross_corr(n1, n2, window=(0, -1, 1)):
    """Pearson correlation of two traces restricted to a strided slice.

    Parameters
    ----------
    n1, n2 : sequence or 1-D array
        Traces of equal length.
    window : tuple (start, stop, step)
        Slice applied to both traces before correlating. Note that the
        default stop of -1 excludes the final sample.

    Returns
    -------
    float
        The off-diagonal entry of ``np.corrcoef`` for the two sliced traces.
    """
    start, stop, step = window
    segment = slice(start, stop, step)
    first_part = n1[segment]
    second_part = n2[segment]
    correlation_matrix = np.corrcoef(first_part, second_part)
    return correlation_matrix[0, 1]

In [None]:
def wavelet_denoising(signal, wavelet='db4', level=4, threshold_factor=1.):
    """Denoise a 1-D signal by soft-thresholding its wavelet detail coefficients.

    Uses the universal (VisuShrink) threshold
    ``threshold_factor * sigma * sqrt(2 * ln(N))``. The noise level ``sigma``
    is the median absolute detail coefficient at the finest scale divided by
    0.6745.

    Parameters
    ----------
    signal : array_like
        1-D input signal of length N.
    wavelet : str
        PyWavelets wavelet name.
    level : int
        Decomposition level passed to ``pywt.wavedec``.
    threshold_factor : float
        Multiplier applied to the universal threshold.

    Returns
    -------
    ndarray
        Denoised signal with the same length as ``signal``.
    """
    signal = np.asarray(signal, dtype=float)
    n_samples = len(signal)
    coeffs = pywt.wavedec(signal, wavelet, mode='symmetric', level=level)

    # Noise is estimated from the finest-scale details (coeffs[-1]). Using
    # coeffs[-level] would pick the coarsest detail band instead; that band
    # still carries signal and inflates the threshold.
    sigma = np.median(np.abs(coeffs[-1])) / 0.6745
    threshold = threshold_factor * sigma * np.sqrt(2 * np.log(n_samples))

    # Threshold only the detail bands; the approximation coeffs[0] is kept as-is.
    coeffs[1:] = [pywt.threshold(c, threshold, mode='soft') for c in coeffs[1:]]
    denoised_signal = pywt.waverec(coeffs, wavelet, mode='symmetric')

    # waverec returns one extra sample for odd-length inputs; trim to the input length.
    return denoised_signal[:n_samples]

In [None]:
def simpleaxis(axes, every=False, outward=False):
    """Apply a minimal "open box" style to one or more matplotlib Axes.

    Hides the top and right spines, keeps ticks on the bottom/left only and
    clears the axes title (so call this before ``set_title``).

    Parameters
    ----------
    axes : Axes, list/tuple of Axes, or ndarray of Axes (any shape)
        Axes to restyle. Arrays such as those returned by
        ``plt.subplots(nrows, ncols)`` are flattened, so 2-D grids work.
    every : bool
        Also hide the bottom and left spines.
    outward : bool
        Offset the bottom and left spines outward by 10 points.
    """
    if isinstance(axes, np.ndarray):
        # .flat handles 1-D and multi-dimensional grids of Axes alike.
        axes_list = list(axes.flat)
    elif isinstance(axes, (list, tuple)):
        axes_list = list(axes)
    else:
        axes_list = [axes]
    for ax in axes_list:
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        if outward:
            ax.spines['bottom'].set_position(('outward', 10))
            ax.spines['left'].set_position(('outward', 10))
        if every:
            ax.spines['bottom'].set_visible(False)
            ax.spines['left'].set_visible(False)
        ax.get_xaxis().tick_bottom()
        ax.get_yaxis().tick_left()
        ax.set_title('')


In [None]:
# Directory with the Atanas et al. (2023) NeuroPAL control recordings (one JSON per worm).
# NOTE(review): absolute, machine-specific path; consider making it configurable.
DATA_DIR = '/Users/sahilmoza/Documents/Postdoc/Yun Zhang/data/SteveFlavell-NeuroPAL-Cell/Control/'

jsons = {}
# Sort for a deterministic worm order (os.listdir order is arbitrary) and skip
# non-JSON entries such as macOS '.DS_Store'.
for js in sorted(os.listdir(DATA_DIR)):
    if not js.endswith('.json'):
        continue
    with open(os.path.join(DATA_DIR, js), 'r') as f:
        jsons['Atanas et al (2023) ' + js] = json.load(f)

In [None]:
# Build the cedne worm model for the 'atanas' dataset and take its 'Neutral' network;
# its neuron names are matched against the recorded labels further down.
w = utils.makeWorm('atanas')
nn = w.networks['Neutral']

In [None]:
# Synaptic weight matrix of the network (in this notebook it is only displayed later).
weight_mat = utils.loadSynapticWeights(nn)

In [None]:
# For every recording, map each confidently identified neuron label to its row in
# that recording's 'trace_array'.
# NOTE(review): the 'labeled' keys are assumed to be 1-based ROI indices (the datasets
# are exported from Julia), hence the offset below. Set LABEL_INDEX_OFFSET = 0 if your
# export uses 0-based keys; the bounds check catches an obvious mismatch.
LABEL_INDEX_OFFSET = 1
measuredNeurons = {}
neuron_labels = []
for js, p in jsons.items():
    sortedKeys = sorted(int(x) for x in p['labeled'])
    # Drop uncertain identifications (labels containing '?').
    labelledNeurons = {
        p['labeled'][str(x)]['label']: x - LABEL_INDEX_OFFSET
        for x in sortedKeys
        if '?' not in p['labeled'][str(x)]['label']
    }
    n_rows = len(p['trace_array'])
    out_of_range = [lab for lab, row in labelledNeurons.items() if not 0 <= row < n_rows]
    if out_of_range:
        raise ValueError(
            f"{js}: labelled ROI indices out of range for trace_array "
            f"({out_of_range[:5]}); check LABEL_INDEX_OFFSET"
        )
    # Keep the ROI row itself. Enumerating a set of labels would give arbitrary,
    # hash-seed dependent rows that do not correspond to the labelled neuron.
    measuredNeurons[js] = labelledNeurons
    neuron_labels += measuredNeurons[js].keys()
neuron_labels = sorted(set(neuron_labels))

In [None]:
# Reference recording used for single-worm exploration and for the common trace length.
database = 'Atanas et al (2023) 2022-06-14-01.json'
# Read the length from the first trace row (rows are assumed to share one frame count)
# instead of from an arbitrary label that might not be measured in this recording.
num_timepoints = len(jsons[database]['trace_array'][0])

In [None]:
# Attach each trace of the reference recording to the matching network neuron as its
# 'amplitude' property; neurons not measured in this recording are left untouched.
for neuron in nn.neurons:
    if neuron in measuredNeurons[database]: 
        nn.neurons[neuron].set_property('amplitude', jsons[database]['trace_array'][measuredNeurons[database][neuron]])

In [None]:
def return_window_correlation(database, window_sizes=range(5,50,10), num_samples=100, rng=None):
    """Average pairwise neuron correlation over randomly placed time windows.

    For each window size, ``num_samples`` window start points are drawn from
    ``[0, num_timepoints - window_size)``. For every pair of neurons measured in
    ``database``, the Pearson correlation inside each window is averaged.

    Relies on the notebook globals ``jsons``, ``measuredNeurons``,
    ``neuron_labels``, ``num_timepoints`` and ``calculate_cross_corr``.

    Parameters
    ----------
    database : str
        Key into ``jsons`` / ``measuredNeurons``.
    window_sizes : iterable of int
        Window lengths (samples); each must be smaller than ``num_timepoints``.
    num_samples : int
        Number of random windows per size (must be positive).
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of window start points. Defaults to the global ``np.random``
        state; pass a seeded generator for reproducible results.

    Returns
    -------
    dict[int, ndarray]
        Window size -> symmetric ``(len(neuron_labels), len(neuron_labels))``
        matrix of mean correlations. The diagonal and pairs not both measured
        in ``database`` are NaN.
    """
    if num_samples <= 0:
        raise ValueError("num_samples must be positive")
    if rng is None:
        draw_starts = np.random.randint
    elif hasattr(rng, 'integers'):
        draw_starts = rng.integers
    else:
        draw_starts = rng.randint

    traces = jsons[database]['trace_array']
    rows = measuredNeurons[database]
    n_labels = len(neuron_labels)
    by_window = {}
    for window_size in window_sizes:
        steps = sorted(draw_starts(0, num_timepoints - window_size, num_samples))
        num_steps = len(steps)
        corr_window = np.full((n_labels, n_labels), np.nan)
        for i, n1 in enumerate(neuron_labels):
            if n1 not in rows:
                continue
            trace1 = traces[rows[n1]]
            for j in range(i + 1, n_labels):
                n2 = neuron_labels[j]
                if n2 not in rows:
                    continue
                trace2 = traces[rows[n2]]
                total = 0.
                for tstart in steps:
                    total += calculate_cross_corr(trace1, trace2, window=(tstart, tstart + window_size, 1))
                corr_window[i, j] = total / num_steps
                # Mirror so both np.tril and np.triu views of the matrix are populated.
                corr_window[j, i] = corr_window[i, j]
        by_window[window_size] = corr_window
    return by_window

In [None]:
# window_sizes=range(1,250,25)
# by_database = {}
# for database in jsons.keys():
#     by_database[database] = return_window_correlation(database=database, window_sizes=window_sizes)

In [None]:
# avg_by_database = np.nanmean(np.stack([by_database[database][window_sizes[5]] for database in list(jsons.keys())]), axis=0)

In [None]:
# f, ax = plt.subplots(figsize=(24,24))
# cbar = ax.pcolor(avg_by_database, vmin=-1, vmax=1, cmap='PuOr')
# ax.set_yticks(np.arange(len(neuron_labels))+0.5, neuron_labels)
# ax.set_xticks(np.arange(len(neuron_labels))+0.5, neuron_labels, rotation=45, ha='right')
# plt.colorbar(cbar)
# plt.show()

In [None]:
# f, ax = plt.subplots(figsize=(24,24))
# cbar = ax.pcolor(np.tril(avg_by_database,k=-1), vmin=-1, vmax=1, cmap='PuOr')
# ax.set_yticks(np.arange(len(neuron_labels))+0.5, neuron_labels)
# ax.set_xticks(np.arange(len(neuron_labels))+0.5, neuron_labels, rotation=45, ha='right')
# plt.colorbar(cbar)
# plt.show()

In [None]:
# For a sweep of |correlation| thresholds, the fraction of strongly correlated neuron
# pairs that are also directly connected in the network.
# NOTE: `avg_by_database` is only defined in the commented-out cell above; run it first.
for corr_thres in np.linspace(0, 1, 10):
    connected_syn = set()
    connected_all = set()
    # return_window_correlation fills the upper triangle (i < j), so take pairs from
    # np.triu; np.tril(k=-1) of an upper-filled matrix is all-NaN and no pair passes.
    upper = np.triu(avg_by_database, k=1)
    with np.errstate(invalid='ignore'):
        xind, yind = np.where(np.abs(upper) > corr_thres)
    for x, y in zip(xind, yind):
        pair = tuple(sorted([neuron_labels[x], neuron_labels[y]]))
        connected_all.add(pair)
        pre, post = nn.neurons[neuron_labels[x]], nn.neurons[neuron_labels[y]]
        if (pre, post, 0) in nn.connections.connections or (post, pre, 0) in nn.connections.connections:
            connected_syn.add(pair)
    if connected_all:
        # Both counts are over unique pairs.
        print(corr_thres, len(connected_syn) / len(connected_all))

In [None]:
# corr, minLength = {x:[] for x in range(1,7)},  []
# for x in range(np.shape(avg_by_database)[0]):
#     for y in range(x+1,np.shape(avg_by_database)[1]):
#         minLength.append(np.min([utils.nx.shortest_path_length(nn, nn.neurons[neuron_labels[x]], nn.neurons[neuron_labels[y]]), utils.nx.shortest_path_length(nn, nn.neurons[neuron_labels[y]], nn.neurons[neuron_labels[x]])]))
#         corr[minLength[-1]].append(avg_by_database[x,y])

# f, ax = plt.subplots()
# for k in corr.keys():
#     ax.scatter(np.nanmean(corr[k]), k)
# plt.show()

In [None]:
# corr_thres = 0.25
# avg_by_database [np.where(np.abs(avg_by_database)<corr_thres)] = np.nan
# f, ax = plt.subplots(figsize=(24,24))
# cbar = ax.pcolor(avg_by_database, vmin=-1, vmax=1, cmap='PuOr')
# ax.set_yticks(np.arange(len(neuron_labels))+0.5, neuron_labels)
# ax.set_xticks(np.arange(len(neuron_labels))+0.5, neuron_labels, rotation=45, ha='right')
# plt.colorbar(cbar)
# plt.show()

In [None]:
# np.where(avg_by_database>corr_thres)

In [None]:
# avg_corr = []
# for window_size in window_sizes:
#     avg_corr_by_window = []
#     for database in jsons.keys():
#         avg_corr_by_window.append(np.nanmean(np.abs(np.tril(by_database[database][window_size],k=-1))))
#     avg_corr.append(np.mean(avg_corr_by_window))

In [None]:
# avg_corr

In [None]:
# plt.plot(np.array(window_sizes)*jsons[database]['avg_timestep'], avg_corr)
# plt.xlabel("Time (mins)")
# plt.show()

In [None]:
# print(window_sizes[np.argmax(avg_corr)])

In [None]:
# print(window_sizes[np.argmax(avg_corr)])

In [None]:
# Recording names (a live view of jsons' keys; re-bound to a list further down).
worms = jsons.keys()

In [None]:
# Behavioural variables: keys of the first recording whose value is a list with one
# entry per frame (length equal to 'max_t').
behav_var = []
for worm in list(worms)[:1]:
    for key in jsons[worm].keys():
        if isinstance(jsons[worm][key], list):
            if len(jsons[worm][key]) == jsons[worm]['max_t']:
                behav_var.append(key)

In [None]:
# Derived behaviours that are computed and stored in `jsons` further down (head-bob
# frequency and acceleration). Guarded so re-running this cell adds no duplicates.
for derived in ('head_frequency', 'acceleration'):
    if derived not in behav_var:
        behav_var.append(derived)

In [None]:
# Count in how many recordings each neuron label was identified and keep only the
# neurons present in at least `nthres` recordings.
ncounts = {neuron: 0 for neuron in neuron_labels}
nthres = 8
# Distinct loop names so the global `database` (the reference recording chosen above)
# is not overwritten by the last recording in `jsons`.
for recording in jsons:
    for label in measuredNeurons[recording]:
        ncounts[label] += 1

print(ncounts, len(ncounts))
popns = [n for n in ncounts if ncounts[n] < nthres]
# Filter with a comprehension instead of popping through a throw-away list
# (which also assigned IPython's `_` output variable).
ncounts = {n: count for n, count in ncounts.items() if count >= nthres}

In [None]:
# Neuron labels kept after the count filter, plus their indices into nlabels reordered
# by cell type: sensory, then interneuron, then motorneuron, each following the
# nn.neurons iteration order. Labels of any other type are not included in nindices.
nlabels = list(ncounts.keys())
type_list = ['sensory', 'interneuron', 'motorneuron']
nindices = []
for ty in type_list:
    for n in nn.neurons:
        if nn.neurons[n].type == ty and n in nlabels:
            nindices.append(list(nlabels).index(n))

In [None]:
threshold_factor = 1.47  # multiplier on the universal wavelet threshold
rank = 20                # CP rank used to fit / impute the worm x neuron x time tensor
method = 'parafac'
random_state = 0         # seed for the random CP initialisation (reproducibility)

worms = list(jsons.keys())
# worm x neuron x time. Neurons not identified in a recording are missing data:
# they start at 0 and get mask == 0, so the masked CP fit ignores (and imputes) them
# instead of fitting uninitialised memory.
full_tensor = np.zeros((len(worms), len(nlabels), num_timepoints))
mask = np.zeros_like(full_tensor)
for i, worm in enumerate(worms):
    for j, n in enumerate(nlabels):
        if n in measuredNeurons[worm]:
            trace = jsons[worm]['trace_array'][measuredNeurons[worm][n]][:num_timepoints]
            denoised = wavelet_denoising(trace, wavelet='db6', level=2, threshold_factor=threshold_factor)
            # waverec may return one extra sample for odd lengths, and a shorter recording
            # only fills (and unmasks) the frames it actually has.
            n_frames = len(trace)
            full_tensor[i, j, :n_frames] = denoised[:n_frames]
            mask[i, j, :n_frames] = 1

if method == 'parafac':
    weights, factors = parafac(full_tensor, rank=rank, init='random', n_iter_max=100,
                               mask=mask, random_state=random_state)
    imputed_tensor = kruskal_to_tensor((weights, factors))
else:
    raise ValueError(f"Unsupported decomposition method: {method!r}")
    #best_core, best_factors = tucker(imputed_tensor, rank=rank)

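## A window of about 100 points (~1 minute) appears sufficient to capture the maximal correlations. Each trace is therefore delay-embedded into a Hankel matrix (rows = window start times, columns = lags within the window), and the embeddings are stacked into a 4-D tensor (worm × neuron × window start × lag) that holds the delay embedding in a new dimension. The window length is set by `window_size` below.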

In [None]:
def create_hankel_matrix(tensor, k):
    """Delay-embed the last axis of ``tensor`` into Hankel (sliding-window) form.

    For the usual ``(worms, neurons, T)`` input, the result has shape
    ``(worms, neurons, T - k + 1, k)`` with
    ``out[w, n, t, :] == tensor[w, n, t:t + k]``. Inputs with any number of
    leading axes are accepted; the embedding is always along the last axis.

    Parameters
    ----------
    tensor : array_like
        Array whose last axis is time.
    k : int
        Embedding window length, ``1 <= k <= T``.

    Returns
    -------
    ndarray of float64
        A new, writable array (not a view of ``tensor``).

    Raises
    ------
    ValueError
        If ``tensor`` has no axes or ``k`` is outside ``[1, T]``.
    """
    tensor = np.asarray(tensor)
    if tensor.ndim == 0:
        raise ValueError("tensor must have at least one (time) axis")
    n_time = tensor.shape[-1]
    if not 1 <= k <= n_time:
        raise ValueError(f"window length k={k} must be between 1 and the time length {n_time}")
    windows = np.lib.stride_tricks.sliding_window_view(tensor, k, axis=-1)
    # sliding_window_view returns a read-only view; copy it into a float64 array.
    return np.array(windows, dtype=np.float64)


def moving_window_average(vector, window_size, step_size):
    """Means of consecutive, fully contained windows of a 1-D sequence.

    Windows start at 0, step_size, 2*step_size, ... and only windows that fit
    entirely inside ``vector`` are used. The result therefore has
    ``(len(vector) - window_size) // step_size + 1`` entries, or none when the
    window is longer than the input.
    """
    last_start = len(vector) - window_size
    window_starts = range(0, last_start + 1, step_size)
    return np.array([np.mean(vector[s:s + window_size]) for s in window_starts])

def explained_variance(true_tensor, core, factors):
    """Fraction of the squared Frobenius norm of ``true_tensor`` captured by a Tucker model.

    Computed as ``1 - ||X - X_hat||^2 / ||X||^2``, where ``X_hat`` is the
    reconstruction ``multi_mode_dot(core, factors)``. For orthonormal factors
    with the projected core (as returned by ``tucker``), this equals
    ``||X_hat||^2 / ||X||^2``; the residual form also stays correct for
    non-orthogonal factors. The data are not mean-centred, so this is a
    fraction of energy rather than a statistical variance.

    Returns NaN when ``true_tensor`` is all zeros.
    """
    reconstructed = tl.tenalg.multi_mode_dot(core, factors)
    total_energy = tl.norm(true_tensor, 2) ** 2
    if total_energy == 0:
        return float('nan')
    residual_energy = tl.norm(true_tensor - reconstructed, 2) ** 2
    return 1 - residual_energy / total_energy

In [None]:
# Inspect the worm name left over from the tensor-building loop (display only).
worm

In [None]:
# Shape of the CP reconstruction: (worms, neurons, time).
imputed_tensor.shape

In [None]:
# Hankel embedding window length in samples (100 was used previously).
window_size = 25 #100 

# Tucker ranks for the 4-D Hankel tensor (worm, neuron, window start, lag).
# Only the neuron mode is compressed: the worm, window-start and lag ranks equal
# their full mode sizes.
t_rank = window_size #10 
n_rank = 15 #10 
d_rank = num_timepoints - window_size +1
w_rank = len(jsons.keys())

# Define a grid of ranks
rank_grid = {
    'w_rank': w_rank,
    'n_rank': n_rank,
    'd_rank': d_rank,
    't_rank': t_rank
}

# Containers for per-worm decompositions; the per-worm loop below is commented out,
# so these stay empty.
core_worms = {}
factor_worms = {}
#for worm, wormname in zip(imputed_tensor, jsons.keys()):
#    hankel_transformed_tensor = create_hankel_matrix(worm, window_size)
hankel_transformed_tensor = create_hankel_matrix(imputed_tensor, window_size)
X = tl.tensor(hankel_transformed_tensor)

In [None]:
# Pooled Tucker decomposition of the Hankel tensor with the ranks above, and the
# fraction of the tensor's squared norm captured by the reconstruction.
ranks = [
        rank_grid['w_rank'], 
        rank_grid['n_rank'], 
        rank_grid['d_rank'],
        rank_grid['t_rank']
        ]
best_core, best_factors = tucker(X, rank=ranks)
#core_worms[wormname] = best_core
#factor_worms[wormname] = best_factors
explained_vars = explained_variance(hankel_transformed_tensor, best_core, best_factors)
print(explained_vars)

In [None]:
# Neuron factor transposed: (n_rank components, neurons).
best_factors[1].T.shape

In [None]:
# Plot the first `n_show` columns of one Tucker factor matrix:
#   0 = worm, 1 = neuron, 2 = window start, 3 = lag.
factor_num = 1
n_show = 10
components = best_factors[factor_num].T[:n_show]
if factor_num in (0, 1):
    # One wide figure per component.
    for j, fac in enumerate(components):
        f, ax = plt.subplots(figsize=(24, 3), layout='constrained')
        ax.plot(fac)
        if factor_num == 0:
            ax.set_xticks(np.arange(len(fac)), np.arange(len(fac)) + 1)
        else:
            ax.set_xticks(range(len(nlabels)), nlabels, rotation=45)
        simpleaxis(ax)
        plt.show()
else:
    # Stack the components in one figure, one row each.
    f, ax = plt.subplots(figsize=(2 * 10, 10), layout='constrained', nrows=len(components), squeeze=False)
    ax = ax[:, 0]
    for j, fac in enumerate(components):
        if factor_num == 2:
            ticks = np.arange(0, len(fac), 10)
            ax[j].plot(fac)
            ax[j].set_xticks(ticks, ticks, rotation=45)
        else:
            avg_dt = jsons[database]['avg_timestep']
            ax[j].plot(np.arange(window_size) * avg_dt, fac)
            ax[j].set_xticks(np.linspace(0, len(fac) * avg_dt, 2))
    simpleaxis(ax)
    plt.show()

In [None]:
# Project the core onto every factor except the neuron mode (skip=1); the result has
# shape (worms, n_rank neuron components, window starts, lags).
projected_data = tl.tenalg.multi_mode_dot(best_core, best_factors, skip=1)

In [None]:
# Heatmap of the neuron factor: rows follow nlabels, columns are the n_rank components.
f, ax = plt.subplots(figsize=(6,12), layout='constrained')
cbar = ax.pcolor(best_factors[1])
ax.set_xticks(np.arange(n_rank)+0.5, np.arange(n_rank)+1)
ax.set_yticks(np.arange(len(nlabels))+0.5, nlabels)
simpleaxis(ax)
f.colorbar(cbar)
plt.show()

In [None]:
# Find the |loading| at the `cutoff` quantile of all neuron-factor loadings
# (e.g. cutoff = 0.75 -> 75% of |loadings| lie at or below `thres`).
cutoff = 0.75
# np.quantile is exact. A binned-CDF lookup returns the left edge of the crossing bin
# (biased low) and ignores loadings beyond the bin range.
thres = np.quantile(np.abs(best_factors[1].ravel()), cutoff)
print(thres)


In [None]:
# Cumulative histogram of all |neuron loadings| with the chosen quantile threshold.
# NOTE(review): loadings above 0.35 fall outside the bins and are not counted, so the
# curve is normalised over the binned values only.
# thres = 0.145
# cutoff = 0.9
f, ax = plt.subplots(figsize=(1.5,1.5))
ax.hist(np.abs(best_factors[1].ravel()), bins=np.linspace(0,0.35,30), color='gray', cumulative=True, density=True)
ax.axvline(thres, color='red', ls='--')
ax.axhline(cutoff, color='red', ls='--')
ax.set_ylabel("Cumulative Density")
ax.set_xlabel("Component weight")
simpleaxis(ax)
plt.show()

In [None]:
# Per-component histogram of |loadings| with the global threshold (one small figure each).
for fac in best_factors[1].T:
    f, ax = plt.subplots(figsize=(1,1), layout='constrained')
    ax.hist(np.abs(fac), bins=np.linspace(0,0.3,25), color='gray')
    ax.axvline(thres, color='red', ls='--')
    simpleaxis(ax)
    plt.show()

In [None]:
# For each neuron component, the labels whose |loading| exceeds `thres`.
nlabels_by_factor = []
for fac in best_factors[1].T:
    nlabels_by_factor.append([nlabels[j[0]] for j in np.argwhere(np.abs(fac)>thres)])

In [None]:
# Print component index, number of selected neurons and their labels.
for i,j in (enumerate(nlabels_by_factor)):
    print(f"{i}: {len(j)}: {j}")

In [None]:
# Build one cedne subnetwork per component from its selected neurons.
subnet_arr = {}
for j in range(len(nlabels_by_factor)):
    print(j)
    subnet_arr[j] = nn.subnetwork(neuron_names=nlabels_by_factor[j], name=f"component_{j}")

In [None]:
# Neuron-factor heatmap (same figure as the earlier one), redrawn next to the
# component neuron lists for reference.
f, ax = plt.subplots(figsize=(6,12), layout='constrained')
cbar = ax.pcolor(best_factors[1])
ax.set_xticks(np.arange(n_rank)+0.5, np.arange(n_rank)+1)
ax.set_yticks(np.arange(len(nlabels))+0.5, nlabels)
simpleaxis(ax)
f.colorbar(cbar)
plt.show()

In [None]:
# Save a spiral plot of each component's subnetwork. The output folder is created
# first because a missing directory would make saving fail.
spiral_dir = './tucker-decomposition-hankel'
os.makedirs(spiral_dir, exist_ok=True)
for j in range(len(nlabels_by_factor)):
    print(j)
    utils.plot_spiral(subnet_arr[j], figsize=(5,5), save=f'{spiral_dir}/spiral_neuron_component_{window_size}-{j}.png')

In [None]:
# Map each component's Hankel-domain projection back to a time series by diagonal
# averaging. Entry [t, l] of the (window start, lag) plane belongs to frame t + l, and
# each frame is averaged over all windows that contain it. A plain overlap-add sum
# would weight middle frames up to window_size times more than the edges.
n_worms_p, n_comp_p, n_starts, n_lags = projected_data.shape
projected_convolved = np.zeros((n_worms_p, n_comp_p, n_starts + n_lags - 1))
overlap_counts = np.zeros(n_starts + n_lags - 1)
for start in range(n_starts):
    projected_convolved[:, :, start:start + n_lags] += projected_data[:, :, start, :]
    overlap_counts[start:start + n_lags] += 1
projected_convolved /= overlap_counts

In [None]:
# def normalized_cross_correlation(activity, behavior):
#     # Compute raw cross-correlation
#     raw_corr = np.correlate(activity, behavior, mode='full')
    
#     # Lengths of signals
#     n = len(activity)
#     m = len(behavior)
    
#     # Precompute sums of squares for normalization
#     activity_energy = np.sum(activity ** 2)
#     behavior_energy= np.sum(behavior ** 2)
    
#     # Normalize each lag
#     norm_corr = raw_corr / np.sqrt(activity_energy * behavior_energy)
    
#     # Remove any NaN values (if behavior_energy is zero in some regions)
#     norm_corr = np.nan_to_num(norm_corr)
    
#     return norm_corr
from scipy.signal import correlate
def normalized_cross_correlation(activity, behavior, lag_min=-50, lag_max=50, demean=False):
    """Cross-correlation normalised per lag by the energy of the overlapping segments.

    Uses ``np.correlate(activity, behavior, 'full')``, whose value at lag ``k``
    is ``sum_n activity[n + k] * behavior[n]``. A positive lag therefore pairs
    later activity with earlier behaviour (behaviour leads activity). Each lag
    is divided by ``sqrt(sum(a_seg**2) * sum(b_seg**2))`` over exactly the
    samples that were multiplied, so every value lies in [-1, 1].

    Parameters
    ----------
    activity, behavior : array_like
        1-D signals (lengths may differ).
    lag_min, lag_max : int
        Inclusive lag range to return.
    demean : bool
        Subtract each signal's mean first (Pearson-like); off by default.

    Returns
    -------
    (list of float, ndarray of int)
        Normalised correlation for each returned lag (0 where an overlap has
        zero energy), and the corresponding lags.
    """
    activity = np.asarray(activity, dtype=float)
    behavior = np.asarray(behavior, dtype=float)
    if demean:
        activity = activity - activity.mean()
        behavior = behavior - behavior.mean()
    n_a, n_b = len(activity), len(behavior)

    raw_corr = np.correlate(activity, behavior, mode='full')
    # Output index i of np.correlate 'full' corresponds to lag i - (n_b - 1).
    lags = np.arange(-(n_b - 1), n_a)

    valid_lag_indices = (lags >= lag_min) & (lags <= lag_max)
    restricted_corr = raw_corr[valid_lag_indices]
    restricted_lags = lags[valid_lag_indices]

    norm_corr = []
    for lag, value in zip(restricted_lags, restricted_corr):
        if lag >= 0:
            # activity[n + lag] is paired with behavior[n]
            overlap_len = min(n_a - lag, n_b)
            overlap_activity = activity[lag:lag + overlap_len]
            overlap_behavior = behavior[:overlap_len]
        else:
            # activity[n] is paired with behavior[n - lag]
            overlap_len = min(n_a, n_b + lag)
            overlap_activity = activity[:overlap_len]
            overlap_behavior = behavior[-lag:-lag + overlap_len]

        norm_factor = np.sqrt(np.sum(overlap_activity ** 2) * np.sum(overlap_behavior ** 2))
        if norm_factor > 0:
            norm_corr.append(value / norm_factor)
        else:
            norm_corr.append(0)
    return norm_corr, restricted_lags

In [None]:
### Rate of change of sensory stream? Velocity calculation? Successive feedforward loop motifs. Acceleration?

In [None]:
def calculate_moving_window_zero_crossing_frequency_same_length(time_array, direction_array, window_size):
    """
    Calculate zero-crossing frequency with a moving window, returning an array of the same length.

    A crossing is counted between consecutive (time-sorted) samples whose sign
    differs. A full head oscillation therefore contributes two crossings, and a
    sample that is exactly 0 produces a sign change on each side. For each
    sample i, crossings are counted in the half-open index window
    [i - h, i + h), where h = int(window_size / 2 / mean sample spacing). The
    count is divided by the time spanned by that window.

    Parameters:
    - time_array: 1D array of time values.
    - direction_array: 1D array of head direction values (at least as long as time_array;
      only the first len(time_array) entries are used).
    - window_size: Window size in the units of time_array.

    Returns:
    - A NumPy array of zero-crossing rates (crossings per time unit), aligned with the
      input order of time_array. Windows with a non-positive time span give 0.
    """
    time_array = np.asarray(time_array, dtype=float)
    n_samples = len(time_array)
    if n_samples < 2:
        return np.zeros(n_samples, dtype=float)

    order = np.argsort(time_array, kind='stable')
    times = time_array[order]
    directions = np.asarray(direction_array)[order]

    mean_dt = np.mean(np.diff(times))
    if not np.isfinite(mean_dt) or mean_dt <= 0:
        return np.zeros(n_samples, dtype=float)

    crossings = np.diff(np.sign(directions)) != 0
    crossings = np.insert(crossings, 0, False)  # pad to the original length
    # crossing_cumsum[m] = number of crossings in crossings[:m]
    crossing_cumsum = np.concatenate(([0], np.cumsum(crossings)))

    half_window = max(int(window_size / 2 / mean_dt), 0)  # seconds -> samples
    idx = np.arange(n_samples)
    start_idx = np.maximum(0, idx - half_window)
    end_idx = np.minimum(n_samples, idx + half_window)

    counts = crossing_cumsum[end_idx] - crossing_cumsum[start_idx]
    durations = times[end_idx - 1] - times[start_idx]
    freq_sorted = np.zeros(n_samples, dtype=float)
    positive = durations > 0
    freq_sorted[positive] = counts[positive] / durations[positive]

    # Undo the time sort so output[i] belongs to input sample i.
    frequencies = np.empty(n_samples, dtype=float)
    frequencies[order] = freq_sorted
    return frequencies

def calculate_acceleration(velocity_array, time_array, window_size=5):
    """
    Calculate acceleration from velocity data and smooth it with a centred moving average.

    Acceleration is the backward difference dv/dt on time-sorted samples. The first
    sample copies the second (it has no predecessor). Repeated time stamps give NaN,
    which then spreads over the smoothing window. Smoothing uses
    np.convolve(..., mode='same') with a box kernel, so values within half a window
    of either end are pulled toward zero by the implicit zero padding.

    Parameters:
    - velocity_array: 1D array of velocity values (at least as long as time_array).
    - time_array: 1D array of time values (integer or float).
    - window_size: Length in samples of the box filter; clamped to [1, len(time_array)]
      so the output length always matches the input.

    Returns:
    - A float NumPy array of smoothed acceleration values, aligned with the input order.
    """
    time_array = np.asarray(time_array, dtype=float)  # float so NaN can mark dt == 0
    velocity_array = np.asarray(velocity_array, dtype=float)
    n_samples = len(time_array)
    if n_samples == 0:
        return np.zeros(0)

    # Sort by time
    order = np.argsort(time_array, kind='stable')
    times = time_array[order]
    velocities = velocity_array[order]

    dv = np.diff(velocities, prepend=velocities[0])
    dt = np.diff(times, prepend=times[0])
    dt[dt == 0] = np.nan  # avoid division by zero
    acceleration = dv / dt
    # Index 0 is 0/NaN by construction; borrow the next value.
    acceleration[0] = acceleration[1] if n_samples > 1 else 0

    width = max(1, min(int(window_size), n_samples))
    smoothed_sorted = np.convolve(acceleration, np.ones(width) / width, mode='same')

    # Undo the time sort so output[i] belongs to input sample i.
    smoothed = np.empty(n_samples)
    smoothed[order] = smoothed_sorted
    return smoothed

In [None]:
# Visual check: head curvature (gray) with its zero-crossing rate (orange), per worm.
# The rate is computed only over the first projected_convolved.shape[2] frames.
for j, worm in enumerate(worms):
    beh = 'head_curvature'
    time = np.linspace(0, jsons[worm]['max_t'], len(jsons[worm][beh]))[:projected_convolved.shape[2]]
    head_freq = calculate_moving_window_zero_crossing_frequency_same_length(time, np.array(jsons[worm][beh]), 25)
    plt.plot(jsons[worm][beh], color='gray')
    plt.plot(head_freq, color='orange')
    plt.show()

In [None]:
# Store the derived behaviours 'head_frequency' and 'acceleration' in each recording's
# dict, truncated to the frames covered by projected_convolved.
# NOTE(review): `time` spans 0..max_t, and max_t equals the frame count (see behav_var),
# so window_size (the Hankel window in samples) acts as a window in frames here.
for j, worm in enumerate(worms):
    time = np.linspace(0, jsons[worm]['max_t'], len(jsons[worm]['head_curvature']))[:projected_convolved.shape[2]]
    jsons[worm]['head_frequency'] = calculate_moving_window_zero_crossing_frequency_same_length(time, np.array(jsons[worm]['head_curvature'])[:projected_convolved.shape[2]], window_size)

    time = np.linspace(0, jsons[worm]['max_t'], len(jsons[worm]['velocity']))[:projected_convolved.shape[2]]
    jsons[worm]['acceleration'] = calculate_acceleration(np.array(jsons[worm]['velocity'])[:projected_convolved.shape[2]], time, window_size)

In [None]:
# For every recording: plot all component time courses, then, for each behaviour, find
# the component (and lag within the default +/-50 frames) with the largest
# |normalised cross-correlation| and overlay it on the behaviour trace. Components whose
# best |score| exceeds `match_thres` are tallied per behaviour in `behav_activity`.
match_thres = 0.7
n_frames_c = projected_convolved.shape[2]
n_components_c = projected_convolved.shape[1]
behav_activity = {b: [] for b in behav_var}
for j, worm in enumerate(worms):
    f, ax = plt.subplots(figsize=(12,(1+len(behav_var))*2), layout='constrained', nrows=len(behav_var)+1)
    for ind in range(n_components_c):
        ax[0].plot(projected_convolved[j,ind,:])
    simpleaxis(ax[0])
    mapping = {}
    for k, beh in enumerate(behav_var):
        behavior = np.array(jsons[worm][beh][:n_frames_c])
        time = np.linspace(0, jsons[worm]['max_t'], len(jsons[worm][beh]))[:n_frames_c]
        ax[k+1].plot(time, behavior, color='#4A90E2')
        ax1 = ax[k+1].twinx()

        # Reset every best-* value per behaviour so nothing leaks from the previous one.
        best_match, best_lag, best_corr, best_score = None, 0, np.nan, -np.inf
        for ind in range(n_components_c):
            normalized_corr, all_lags = normalized_cross_correlation(projected_convolved[j,ind,:], behavior)
            peak = int(np.argmax(np.abs(normalized_corr)))
            score = np.abs(normalized_corr[peak])
            if score > best_score:
                best_match, best_lag, best_corr, best_score = ind, all_lags[peak], normalized_corr[peak], score
        mapping[k] = (best_match, best_lag, best_corr)

        if best_match is not None:
            # Trim to the behaviour's length in case that trace is shorter.
            ax1.plot(time, projected_convolved[j, best_match, :len(time)], color='#F5A623')
        simpleaxis(ax[k+1])
        ax[k+1].set_title(f"{beh}, {mapping[k]}")
        if best_score > match_thres:
            behav_activity[beh].append(best_match)
    plt.show()

In [None]:
from collections import Counter
# Per behaviour: how often each component was the above-threshold best match across worms.
f, ax = plt.subplots(ncols=len(behav_var), figsize=((len(behav_var))*2, 2), layout='constrained', sharex=True, sharey=True, squeeze=False)
ax = ax[0]  # 1-D array of Axes even when there is a single behaviour
for k, beh in enumerate(behav_var):
    c = Counter(behav_activity[beh])
    if c:  # zip(*{}.items()) cannot be unpacked when no component matched
        comp_ids, counts = zip(*c.items())
        ax[k].bar(comp_ids, counts)
    simpleaxis(ax[k])
    ax[k].set_title(beh)
plt.show()

In [None]:
# Worm-averaged projection at two window-start indices (ind1, ind2), side by side.
# NOTE(review): with the skip=1 `projected_data` the rows are neuron components, not
# neurons, so the nlabels y-ticks only match if `projected_data` holds a skip=2
# projection (worms, neurons, ...). Confirm which one is intended.
ind1 = 2
ind2 = 4
f, ax = plt.subplots(figsize=(4,12), layout='constrained', ncols=2, sharex=True, sharey=True)
ax[0].pcolor(np.mean(projected_data[:,:,ind1,:], axis=0))
ax[1].pcolor(np.mean(projected_data[:,:,ind2,:], axis=0))
simpleaxis(ax)
ax[0].set_yticks(np.arange(len(nlabels))+0.5, nlabels)
plt.show()

In [None]:
# Inspect the `neuron` name left over from an earlier loop (display only).
neuron

In [None]:
# Labels identified in the reference recording, sorted.
sorted(measuredNeurons[database].keys())

In [None]:
# Compare one neuron's first-difference trace (purple) with velocity (gray, twin axis)
# in the reference recording. A moving average with window 1 and step 1 returns the
# diff unchanged.
neuron = 'AWCL'
f, ax = plt.subplots(figsize=(12,2))
ax1 = ax.twinx()
ax1.plot(jsons[database]['velocity'][:-1], color='gray')
ax.plot(moving_window_average(np.diff(jsons[database]['trace_array'][measuredNeurons[database][neuron]],1), 1,1), color='purple')
plt.show()

In [None]:
# Correlation between velocity and the 10-sample moving average of the neuron's
# first difference (velocity[:-10] makes the lengths match).
np.corrcoef(jsons[database]['velocity'][:-10], moving_window_average(np.diff(jsons[database]['trace_array'][measuredNeurons[database][neuron]],1), 10,1))

In [None]:
# Labels of rows whose worm-averaged projection exceeds 0.5 in magnitude at window
# start `ind` (first 50 lags) and `ind + 1` (lags from 50 on).
# NOTE(review): `ind` is not set in this cell; it is whatever an earlier loop left
# behind. This also overwrites the loading threshold `thres` computed earlier.
thres = 0.5
mpro = np.mean(projected_data[:,:,ind,:50], axis=0)
print([nlabels[n] for n in sorted(set(np.where(np.abs(mpro)>thres)[0]))])
mpro = np.mean(projected_data[:,:,ind+1,50:], axis=0)
print([nlabels[n] for n in sorted(set(np.where(np.abs(mpro)>thres)[0]))])

In [None]:
# Inspect the connections of neuron RMDL.
nn.neurons['RMDL'].get_connections()

In [None]:
# Correlation between components 4 and 3 of factor `factor_num` with a 69-sample offset.
ind = 69
np.corrcoef(best_factors[factor_num].T[4][ind:], best_factors[factor_num].T[3][:-ind])[0,1]

In [None]:
comp = 0
step = 50
trace_grid_space = 1
worm_index = 0  # recording to display (index into `worms`)
worm_name = worms[worm_index]
nlabs_newind = [list(nlabels)[i] for i in nindices]
# Project every mode except the window-start mode (skip=2), giving
# (worms, neurons, window-start core components, lags). It is kept under its own name
# so the skip=1 `projected_data` used by other cells is not overwritten.
projected_by_neuron = tl.tenalg.multi_mode_dot(best_core, best_factors, skip=2)
# One worm, neurons reordered by type (nindices), lag index `comp`: a 2-D image.
proj_mat = projected_by_neuron[worm_index][nindices, :, comp]
vm = np.max(np.abs(proj_mat))
real_time = np.linspace(0, jsons[worm_name]['max_t'], proj_mat.shape[1])

f = plt.figure(figsize=(12,12), layout='constrained')
gs = matplotlib.gridspec.GridSpec(len(behav_var)+1, 1, figure=f, height_ratios=[30] + [1]*len(behav_var))
ax = f.add_subplot(gs[:trace_grid_space, 0])
ax.imshow(proj_mat, extent=(0, jsons[worm_name]['max_t'], 0, len(nindices)), cmap='PuOr', vmin=-vm, vmax=vm, aspect='auto', origin='lower')
ax.grid(True, axis='x', which='both')
simpleaxis(ax)
ax.set_xticks(real_time[::step])
# One tick per displayed neuron row, so the tick count matches the labels.
ax.set_yticks(np.arange(len(nindices))+0.5)
ax.set_yticklabels(nlabs_newind)
for k, beh in enumerate(behav_var):
    ax1 = f.add_subplot(gs[trace_grid_space+k:trace_grid_space+k+1, 0], sharex=ax)
    time = np.linspace(0, jsons[worm_name]['max_t'], len(jsons[worm_name][beh]))
    ax1.plot(time, jsons[worm_name][beh])
    simpleaxis(ax1)
    ax1.set_title(beh)
    ax1.grid(True, axis='x')
plt.show()

In [None]:
comp = 0
step = 50
trace_grid_space = 1
n_worms_to_plot = 2
nlabs_newind = [list(nlabels)[i] for i in nindices]
# One figure per recording (first `n_worms_to_plot`): each row of
# projected_data[j, :, :, comp] in its own panel, followed by the behaviour traces,
# all on one shared x axis in the same units as the behaviour time axis.
for j, w in enumerate(list(worms)[:n_worms_to_plot]):
    proj_mat = projected_data[j, :, :, comp]
    n_rows = proj_mat.shape[0]
    real_time = np.linspace(0, jsons[w]['max_t'], proj_mat.shape[1])
    f = plt.figure(figsize=(12,12), layout='constrained')
    gs = matplotlib.gridspec.GridSpec(len(behav_var)+n_rows, 1, figure=f, height_ratios=[1]*n_rows + [1]*len(behav_var))
    ax = None  # share x only among panels of this figure, not with an earlier figure
    for l, p in enumerate(proj_mat):
        ax = f.add_subplot(gs[l:l+1, 0], sharex=ax)
        # Plot against real_time so the data match the ticks and the behaviour panels.
        ax.plot(real_time, p)
        ax.grid(True, axis='x', which='both')
        simpleaxis(ax)
        ax.set_xticks(real_time[::step])
        ax.set_title(f"proj{l+1}")
    for k, beh in enumerate(behav_var):
        ax1 = f.add_subplot(gs[n_rows+k:n_rows+k+1, 0], sharex=ax)
        time = np.linspace(0, jsons[w]['max_t'], len(jsons[w][beh]))
        ax1.plot(time, jsons[w][beh])
        simpleaxis(ax1)
        ax1.set_title(beh)
        ax1.grid(True, axis='x')
    plt.show()


In [None]:
# For each lag index `comp`, worm and neuron row, correlate the window-start profile of
# the pooled Tucker projection with each behaviour averaged over the same window
# length, and plot pairs with |r| > corr_thres.
# NOTE: the per-worm decompositions (`core_worms` / `factor_worms`) are never filled
# because that loop is disabled. This cell therefore uses the pooled `best_core` /
# `best_factors` and leaves them, and `projected_data`, unchanged.
corr_thres = 0.9
projected_by_neuron = tl.tenalg.multi_mode_dot(best_core, best_factors, skip=2)
for comp in range(2, min(n_rank, projected_by_neuron.shape[3])):
    for wind, w in enumerate(worms):
        proj_mat = projected_by_neuron[wind, :, :, comp]
        # Window-averaged behaviours, computed once per worm instead of once per neuron.
        behaviour_avg = {b: moving_window_average(jsons[w][b][:num_timepoints], window_size, 1) for b in behav_var}
        for i, p in enumerate(proj_mat):
            for j, b in enumerate(behav_var):
                corr_np = np.corrcoef(p, behaviour_avg[b])[0,1]
                if np.abs(corr_np) > corr_thres:
                    print(wind, comp, (nlabels[i], b), corr_np)
                    f = plt.figure(figsize=(2,2), layout='constrained')
                    ax = f.add_subplot()
                    # Flip anti-correlated profiles so both traces rise together.
                    ax.plot(p if corr_np > 0 else -p, color='k', label=i+1)
                    ax1 = ax.twinx()
                    ax1.plot(behaviour_avg[b], color='gray', label=b)
                    plt.legend()
                    plt.show()
    

In [None]:
# Neuron-factor heatmap on a symmetric diverging colour scale.
# `factor_avg` feeds the averaging cell below. With the single pooled Tucker model
# (the per-worm loop is disabled) it holds just this factor matrix, so that mean is
# well defined instead of np.mean([]).
factor_avg = [best_factors[1]]
vm = np.abs(best_factors[1]).max()
f, ax = plt.subplots(figsize=(4,12), layout='constrained')
ax.pcolor(best_factors[1], vmin=-vm, vmax=vm, cmap='PuOr')
plt.xticks(np.arange(n_rank)+0.5, np.arange(n_rank)+1)
plt.yticks(np.arange(len(nlabels))+0.5, nlabels)
plt.show()

In [None]:
# Element-wise mean of the neuron-factor matrices collected in `factor_avg`, shown on
# a symmetric diverging colour scale.
# NOTE(review): if `factor_avg` is empty, np.mean returns NaN and pcolor cannot draw it.
mean_factor = np.mean(factor_avg, axis=0)
vm = np.abs(mean_factor).max()
f, ax = plt.subplots(figsize=(4,12), layout='constrained')
cbar = ax.pcolor(mean_factor, vmin=-vm, vmax=vm, cmap='PuOr')
plt.xticks(np.arange(n_rank)+0.5, np.arange(n_rank)+1)
plt.yticks(np.arange(len(nlabels))+0.5, nlabels)
plt.colorbar(cbar)
plt.show()

In [None]:
# Collect connections of type `connFilter` that carry one of `ligands`, grouped by
# each matching putative neurotransmitter/receptor tuple; values are [name0, name1]
# pairs taken from the connection key.
conns = {}
ligands = ('Serotonin', ) #'Dopamine',) #)
connFilter = 'chemical-synapse' #'gap-junction'
for c, e in nn.connections.items():
    # A falsy connFilter means "do not filter by connection type".
    if connFilter and e.connection_type != connFilter:
        continue
    # Test once per edge so an edge carrying several ligands is not appended repeatedly.
    if not any(ligand in e.ligands for ligand in ligands):
        continue
    for nt_rec in e.putative_neurotrasmitter_receptors:
        if nt_rec[0] in ligands:
            conns.setdefault(nt_rec, []).append([c[0].name, c[1].name])

In [None]:
# Inspect the collected ligand connections.
conns

In [None]:
# Every neuron name appearing at either end of any collected connection. Indexing
# conns[key][0] / [1] would only take the first two pairs, and would raise an
# IndexError for a key with a single pair.
all_ser = {name for pairs in conns.values() for pair in pairs for name in pair}

In [None]:
# NOTE(review): `mat_nx` is only created in the threshold-sweep cell below, so this
# raises NameError on a fresh kernel unless that cell has run first.
mat_nx.shape

In [None]:
# For a sweep of |loading| thresholds on one neuron component: the fraction of
# above-threshold neurons that appear in the collected ligand connections, with the
# chance fraction among all analysed neurons (nlabels) for comparison.
component_index = 19
n_components = best_factors[1].shape[1]
if not 0 <= component_index < n_components:
    raise IndexError(f"component_index={component_index} is out of range for {n_components} neuron components")
fac1_dot = best_factors[1].T[component_index]
abs_loading = np.abs(fac1_dot)
# Thresholds span the |loading| range, matching the |loading| comparison below.
conn_arr = np.linspace(np.min(abs_loading), 0.99*np.max(abs_loading), 100)
fracCommon = []
commonNeurs = []
for conn_thres in conn_arr:
    mat_nx = abs_loading > conn_thres
    connected_n = [nlabels[j] for j in np.flatnonzero(mat_nx)]
    commons = set(connected_n).intersection(all_ser)
    commonNeurs.append(commons)
    fracCommon.append(len(commons)/len(connected_n) if connected_n else np.nan)
# Chance level: fraction of the analysed neurons that are in all_ser.
fracAll = len(all_ser.intersection(nlabels))/len(mat_nx)

In [None]:
# Cumulative distribution of the selected component's loadings (dashed line at 0.8).
f, ax = plt.subplots()
ax.hist(np.ravel(fac1_dot), bins=50, cumulative=True, density=True, histtype='step')
ax.axhline(y=0.8, linestyle='--')
plt.show()

In [None]:
# Keep references to the pooled Tucker result and its skip=1 projection:
# (worms, n_rank neuron components, window starts, lags).
weights_neuron, factors_neuron = best_core, best_factors
projected_data_neuron = tl.tenalg.multi_mode_dot(best_core, best_factors, skip=1)

In [None]:
# Fraction of above-threshold neurons in the ligand set vs. threshold; the dashed line
# is the chance fraction fracAll.
f,ax = plt.subplots(figsize=(1.5,1.5))
ax.scatter(conn_arr, fracCommon, color='gray', s=8)
ax.axhline(y=fracAll, linestyle='--', color='k')
simpleaxis(ax)
ax.set_ylim((0,1.05))
plt.show()

In [None]:
# Neurons above the last threshold of the sweep (display only).
connected_n

In [None]:
# NOTE(review): `j` is whatever an earlier loop left behind.
np.mean(projected_data_neuron[:,j,:,0], axis=0).shape

In [None]:
weight_mat

In [None]:
# Shapes of the first lag factor and of one window-start profile.
best_factors[3].T[0].shape, projected_data_neuron[0,0,:,0].shape

In [None]:
# Plot lag component `t_ind` of the lag factor. `m` is computed but not used further.
t_ind = 14
m = np.convolve(best_factors[3].T[0], projected_data_neuron[0,0,:,0])
n = best_factors[3].T[t_ind]
plt.plot(n)
plt.show()

In [None]:
best_factors[3].T.shape, rank_grid['t_rank'], corr_thres, projected_data_neuron[:,j,:,:].shape

In [None]:
def moving_average_recompose(recon_tensor):
    """Overlap-average a stack of sliding windows back into one time series per row.

    ``recon_tensor[:, i, :]`` is treated as the window that starts at time ``i``. Windows
    are summed into place, and each output sample is divided by the number of windows
    that covered it.

    Parameters
    ----------
    recon_tensor : ndarray, shape (n_rows, n_windows, window_len)

    Returns
    -------
    ndarray, shape (n_rows, n_windows + window_len - 1)
    """
    n_rows, n_windows, window_len = recon_tensor.shape
    reconstructed_matrix = np.zeros((n_rows, n_windows + window_len - 1))
    weight_matrix = np.zeros_like(reconstructed_matrix)
    for i in range(n_windows):
        # Window length comes from the tensor itself. It was hard-coded to 100, which broke
        # (broadcast error) for any other window size.
        reconstructed_matrix[:, i:i + window_len] += recon_tensor[:, i, :]
        weight_matrix[:, i:i + window_len] += 1
    # With n_windows >= 1 every output column is covered by at least one window.
    reconstructed_matrix /= weight_matrix
    return reconstructed_matrix

In [None]:
f, ax = plt.subplots(figsize=(12,3*rank_grid['n_rank']), nrows=rank_grid['n_rank'], sharex=True, sharey=True, layout='constrained')
ax1 = [a.twinx() for a in ax]
naive_avg = {}
trained_avg = {}
t_ind = 2
corr_thres = 0.81
for j in range(rank_grid['n_rank']):
    # naive_avg[r] = []
    # trained_avg[r] = []
    avg_beh = []
    avg_comp_tc = []
    reconstruction = moving_average_recompose(projected_data_neuron[:,j,:,:])
    for i,(p,w) in enumerate(zip(reconstruction, jsons.keys())):
        component_timecourse = p
        avg_beh.append(jsons[w]['pumping'][:num_timepoints])
        avg_comp_tc.append(component_timecourse)
        if np.abs(np.corrcoef(component_timecourse, jsons[w]['pumping'][:num_timepoints])[0,1])>corr_thres:
            print(j,i, np.corrcoef(component_timecourse, jsons[w]['pumping'][:num_timepoints])[0,1])
            ax[j].plot(component_timecourse, color='gray', alpha=0.5)
            ax1[j].plot(jsons[w]['pumping'], color='orange', alpha=0.5)
        # ax[j].plot(jsons[w]['pumping'], color='orange', alpha=0.5)
    # ax[j].plot(np.mean(avg_comp_tc, axis=0), color='k')
    # ax[j].plot(np.mean(avg_beh, axis=0), color='red')
    # naive_avg[r] = np.mean(naive_avg[r], axis=0)
    # trained_avg[r] = np.mean(trained_avg[r], axis=0)
    # ax[j].plot(np.linspace(0,90, len(p)), naive_avg[r], color='k') 
    # ax[j].plot(np.linspace(0,90, len(p)), trained_avg[r], color='r')
    simpleaxis(ax[j])
f.supxlabel("Time(s)")
f.supylabel(f"Compomnent projection amplitude")
plt.show()

In [None]:
for a,b in zip(fracCommon, commonNeurs):
    print(a,b)

In [None]:
# NOTE(review): in this notebook `edges` is initialised to [] in the threshold-sweep cell and
# never appended to, so np.diff(edges) is empty and cannot match conn_arr[:-1] (99 points);
# this plot fails unless `edges` is filled elsewhere.
plt.plot(conn_arr[:-1], np.diff(edges))
plt.show()

In [None]:
# Threshold |fac1_dot| into a boolean adjacency, build a graph, relabel nodes with neuron names,
# and keep only nodes that take part in at least one edge (H).
# NOTE(review): from_numpy_array expects a square 2-D adjacency matrix, but fac1_dot as set in
# the sweep cell (best_factors[1].T[19]) is 1-D -- this presumably expects a different fac1_dot.
conn_thres = 0.01
mat_nx = np.abs(fac1_dot)>conn_thres
graph = utils.nx.from_numpy_array(mat_nx)
node_labels = {i:n for i,n in enumerate(nlabels)}
G = utils.nx.relabel_nodes(graph, node_labels)

H = G.edge_subgraph(G.edges())

In [None]:
len(graph.edges())

In [None]:
all_ser

In [None]:
H.nodes

In [None]:
set(all_ser).intersection(set(H.edges()))

In [None]:
# Draw the thresholded graph H with a Kamada-Kawai layout.
f, ax = plt.subplots(figsize=(12,12))
utils.nx.draw_kamada_kawai(H, with_labels=True, ax=ax, node_size=1200)
plt.show()

In [None]:
# Presumably attaches neurotransmitter annotations to `nn` -- confirm in cedne.utils.
utils.loadNeurotransmitters(nn)

In [None]:
# One layered connectivity plot per neurotransmitter-receptor key in `conns`.
# NOTE(review): requires `conns` to still be the receptor -> pairs dict; a later cell rebinds it.
for ntr, conn in conns.items():
    print(ntr)
    pos = utils.plot_layered(conn, nn, nodeColors={}, edgeColors = 'gray', save=False, title=ntr, extraNodes=[], extraEdges=[], pos=[], mark_anatomical=False, colorbar=False)

In [None]:
%matplotlib inline
# Heat map of the core tensor slice at last index 9, with a colour range symmetric about zero.
# NOTE(review): x, y, z, c and alpha are computed but not used by the plot. The three-way unpack of
# np.nonzero assumes best_core is 3-D, while projected_data_neuron[:,j,:,0] above implies a core
# with at least 4 modes -- confirm which core this cell expects.
x, y, z = np.nonzero(best_core)
c = best_core[x, y, z]  # Color by value
alpha = np.where(c > 1e-2, 1.0, 0.0)
vm=0.01
plt.pcolor(best_core[:,:,9], cmap='PuOr', vmin=-vm, vmax=vm)
plt.show()

In [None]:
%matplotlib qt
# Interactive (Qt backend) 3-D scatter of the non-zero core entries, coloured by value.
# NOTE(review): same 3-D assumption on best_core as the cell above.
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Get the coordinates for non-zero elements
x, y, z = np.nonzero(best_core)
c = best_core[x, y, z]  # Color by value

# Scatter plot
sc = ax.scatter(x, y, z, c=c, cmap='viridis', s=100)

# Add color bar
plt.colorbar(sc, ax=ax)

# Set labels
ax.set_xlabel('X axis')
ax.set_ylabel('Y axis')
ax.set_zlabel('Z axis')

plt.title("3D Core Tensor from Tucker Decomposition")
plt.show()

In [None]:
# Values of the non-zero core entries from the cell above.
c

In [None]:
across_pairs = {}
min_pairs = 20
min_reps = 10
while len(across_pairs)<min_pairs:
    i = np.random.randint(len(neuron_labels))
    j = np.random.randint(len(neuron_labels)) 
    n1 = neuron_labels[i]
    n2 = neuron_labels[j]
    tempvec = []
    if n1 == 'RMGL':
        for database in jsons.keys():
            vecs = [by_database[database][window_size][i,j] for window_size in window_sizes if not np.isnan(by_database[database][window_size][i,j])]
            if len(vecs):
                tempvec.append(vecs)
        
        if len(tempvec)>min_reps:
            f, ax = plt.subplots()
            for t in tempvec:
                ax.plot(window_sizes, t, color='gray')
            ax.plot(window_sizes, np.median(tempvec, axis=0), color='k')

            ax.set_title((n1,n2))
            ax.set_ylim((-1,1))
            plt.show()
            across_pairs[(n1,n2)] = tempvec

In [None]:
weight_window = np.zeros((len(measuredNeurons[database]), len(measuredNeurons[database])))
conn_mat = np.zeros((len(measuredNeurons[database]), len(measuredNeurons[database]))) 

for i,n1 in enumerate(neuron_labels):
    for j,n2 in enumerate(neuron_labels):
        if n1 in weight_mat:
            if n2 in weight_mat[n1]:
                weight_window[i,j] = weight_mat[n1][n2]
                if (nn.neurons[n1], nn.neurons[n2], 0) in nn.connections.keys():
                    conn_mat[i,j] = 1

ax0_ind = np.argwhere(conn_mat.ravel()==0)
ax1_ind = np.argwhere(conn_mat.ravel()==1)

f, ax = plt.subplots(figsize= (4,2*len(by_window.keys())), ncols=2, nrows=len(by_window.keys()), sharex=True, sharey=True)

for j, window_size in enumerate(window_sizes):
    ax[j,0].scatter(by_window[window_size].ravel()[ax0_ind], weight_window.ravel()[ax0_ind], c='gray')
    ax[j,1].scatter(by_window[window_size].ravel()[ax1_ind], weight_window.ravel()[ax1_ind], c='purple')
    for a in ax[j]:
        a.axhline(y=0, linestyle='--', color='gray')
        a.axvline(x=0, linestyle='--', color='gray')
        a.set_xlim((-1,1))
plt.show()

In [None]:
neuron_labels

In [None]:
nodeColors = {}
nodelist = []
for neuron in nn.neurons:
    if neuron in measuredNeurons[database]:
        nodeColors[neuron] = nn.neurons[neuron].amplitude
        nn.neurons[neuron].set_property('color', nn.neurons[neuron].amplitude)
        nodelist.append(neuron)

In [None]:
nn_2 = nn.subnetwork(nodelist, as_view=False)

In [None]:
## Giving the sex specific neurons an interneuron type for positioning on graph.
sex_neurons = ['CANL', 'CANR']
for n in nn_2.neurons:
    if n in sex_neurons:
        nn_2.neurons[n].type = 'interneuron'

In [None]:
edgeColors = []
for e in nn_2.connections:
    edgeColors.append(nn_2.connections[e].weight)
cmap = plt.get_cmap('PuOr')
max_color = max(np.abs(edgeColors))
norm = matplotlib.colors.Normalize(vmin=-max_color,vmax=max_color)
m = cm.ScalarMappable(norm=norm, cmap=cmap)
edge_color_dict = {e: m.to_rgba(nn_2.connections[e].weight) for e in nn_2.connections}


max_color = max(np.abs(list(nodeColors.values())))
cmap2 = plt.get_cmap('RdYlGn')
norm2 = matplotlib.colors.Normalize(vmin=-max_color,vmax=max_color)
o = cm.ScalarMappable(norm=norm2, cmap=cmap2)

node_color_dict = {nn_2.neurons[n]: o.to_rgba(nn_2.neurons[n].color) for n in nn_2.neurons}

In [None]:
# Shell layout with the sensory neurons of nn_2 at the centre, coloured by the edge/node colour
# dictionaries built earlier.
center = [n for n in nn_2.neurons if nn_2.neurons[n].type == 'sensory']
utils.plot_shell(nn_2, center=center, figsize=(10,10), edge_color_dict=edge_color_dict, node_color_dict=node_color_dict, save=False)#"weight-activity.pdf")

In [None]:
# NOTE(review): this rebinds `conns` (earlier the receptor -> pairs dict) to a list of edge name
# tuples, and `edgeColors` to raw weights; earlier cells using `conns` break if re-run after this.
conns = [(e[0].name, e[1].name) for e in nn_2.connections]
edgeColors = [nn_2.connections[e].weight for e in nn_2.connections]

utils.plot_layered(conns, nn_2, nodeColors=nodeColors, edgeColors=edgeColors)

In [None]:
conns = [(e[0].name, e[1].name) for e in nn_2.connections]

In [None]:
nn_2.edges

In [None]:
# Inspect the type and category of every neuron in the subnetwork.
for n in nn_2.neurons:
    print(n, nn_2.neurons[n].type, nn_2.neurons[n].category)

In [None]:
import pandas as pd


TOPDIR = '../../' ## Change this to cedne and write a function to download data from an online server for heavy data.
DATADIR = TOPDIR + 'data_sources/'
DOWNLOAD_DIR = TOPDIR + 'data_sources/downloads/'

prefix_NT = 'Wang_2019/'
prefix_CENGEN = 'CENGEN/'
prefix_NP = 'Ripoll-Sanchez_2023/'
prefix_synaptic_weights = 'Randi_2023/'
weightMatrix = DOWNLOAD_DIR + prefix_synaptic_weights + "41586_2023_6683_MOESM13_ESM.xls"
tMat = pd.read_excel(weightMatrix, index_col=0).T

In [None]:
# Print each outgoing connection of AWCL whose key contains the AWCR neuron object.
for e in nn.neurons['AWCL'].out_connections:
    if(nn.neurons['AWCR'] in e):
        print(nn.connections[e])

In [None]:
# NOTE(review): repeats the type/category listing done a few cells above, one value per line.
for n in nn_2.neurons: 
    print(nn_2.neurons[n].type)
    print(nn_2.neurons[n].category)