# Imports

In [1]:
import os

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import os

from data._utils import pick_worm, load_dataset

CUDA device found.
	 GPU: Tesla T4


In [2]:
NEURONS_302 = [ # TODO: Cite source of this list.
            "ADAL", "ADAR", "ADEL", "ADER", "ADFL", "ADFR", "ADLL", "ADLR", "AFDL", "AFDR",
            "AIAL", "AIAR", "AIBL", "AIBR", "AIML", "AIMR", "AINL", "AINR", "AIYL", "AIYR",
            "AIZL", "AIZR", "ALA", "ALML", "ALMR", "ALNL", "ALNR", "AQR", "AS1", "AS10",
            "AS11", "AS2", "AS3", "AS4", "AS5", "AS6", "AS7", "AS8", "AS9", "ASEL", "ASER",
            "ASGL", "ASGR", "ASHL", "ASHR", "ASIL", "ASIR", "ASJL", "ASJR", "ASKL", "ASKR",
            "AUAL", "AUAR", "AVAL", "AVAR", "AVBL", "AVBR", "AVDL", "AVDR", "AVEL", "AVER",
            "AVFL", "AVFR", "AVG", "AVHL", "AVHR", "AVJL", "AVJR", "AVKL", "AVKR", "AVL",
            "AVM", "AWAL", "AWAR", "AWBL", "AWBR", "AWCL", "AWCR", "BAGL", "BAGR", "BDUL",
            "BDUR", "CANL", "CANR", "CEPDL", "CEPDR", "CEPVL", "CEPVR", "DA1", "DA2", "DA3",
            "DA4", "DA5", "DA6", "DA7", "DA8", "DA9", "DB1", "DB2", "DB3", "DB4", "DB5",
            "DB6", "DB7", "DD1", "DD2", "DD3", "DD4", "DD5", "DD6", "DVA", "DVB", "DVC",
            "FLPL", "FLPR", "HSNL", "HSNR", "I1L", "I1R", "I2L", "I2R", "I3", "I4", "I5",
            "I6", "IL1DL", "IL1DR", "IL1L", "IL1R", "IL1VL", "IL1VR", "IL2DL", "IL2DR", "IL2L",
            "IL2R", "IL2VL", "IL2VR", "LUAL", "LUAR", "M1", "M2L", "M2R", "M3L", "M3R", "M4",
            "M5", "MCL", "MCR", "MI", "NSML", "NSMR", "OLLL", "OLLR", "OLQDL", "OLQDR",
            "OLQVL", "OLQVR", "PDA", "PDB", "PDEL", "PDER", "PHAL", "PHAR", "PHBL", "PHBR",
            "PHCL", "PHCR", "PLML", "PLMR", "PLNL", "PLNR", "PQR", "PVCL", "PVCR", "PVDL",
            "PVDR", "PVM", "PVNL", "PVNR", "PVPL", "PVPR", "PVQL", "PVQR", "PVR", "PVT",
            "PVWL", "PVWR", "RIAL", "RIAR", "RIBL", "RIBR", "RICL", "RICR", "RID", "RIFL",
            "RIFR", "RIGL", "RIGR", "RIH", "RIML", "RIMR", "RIPL", "RIPR", "RIR", "RIS",
            "RIVL", "RIVR", "RMDDL", "RMDDR", "RMDL", "RMDR", "RMDVL", "RMDVR", "RMED",
            "RMEL", "RMER", "RMEV", "RMFL", "RMFR", "RMGL", "RMGR", "RMHL", "RMHR", "SAADL",
            "SAADR", "SAAVL", "SAAVR", "SABD", "SABVL", "SABVR", "SDQL", "SDQR", "SIADL",
            "SIADR", "SIAVL", "SIAVR", "SIBDL", "SIBDR", "SIBVL", "SIBVR", "SMBDL", "SMBDR",
            "SMBVL", "SMBVR", "SMDDL", "SMDDR", "SMDVL", "SMDVR", "URADL", "URADR", "URAVL",
            "URAVR", "URBL", "URBR", "URXL", "URXR", "URYDL", "URYDR", "URYVL", "URYVR",
            "VA1", "VA10", "VA11", "VA12", "VA2", "VA3", "VA4", "VA5", "VA6", "VA7", "VA8",
            "VA9", "VB1", "VB10", "VB11", "VB2", "VB3", "VB4", "VB5", "VB6", "VB7", "VB8",
            "VB9", "VC1", "VC2", "VC3", "VC4", "VC5", "VC6", "VD1", "VD10", "VD11", "VD12",
            "VD13", "VD2", "VD3", "VD4", "VD5", "VD6", "VD7", "VD8", "VD9"
        ]

In [5]:
def correlate_with_specific_lag(f, g, lag):
    g = np.pad(g[lag:], (0, lag), 'constant')
    if len(f) > len(g):
        g = np.pad(g[lag:], (0, lag+(len(f)-len(g))), 'constant')
    else:
        f = np.pad(f, (0, len(g)-len(f)), 'constant')
    return np.sum(np.multiply(f, g))

In [6]:
def cross_correlate(X1, X2, lag_limit):
    vals = []
    for i in range(lag_limit):
        vals += [correlate_with_specific_lag(X1, X2, i)]

    arr = np.array(vals)
    final_arr = arr / np.max(np.abs(arr))
    
    return final_arr

In [8]:
def save_crosscorrelation(X, worm_idx=0, lag_limit=100, dataset="", mask=None):
    """
    Plot the crosscorrelation for each neuron's trajectory.

    Parameters:
    - X: A 2D numpy array of shape (max_timesteps, num_neurons) containing the neural trajectory data.
    - neurons: A list or array containing the neuron identifiers.

    Returns:
    - None: The function creates and displays a plot.
    """
    
    worm_corr_data = np.zeros((len(mask), len(mask), lag_limit))

    # Iterate over the number of neurons to create individual plots
    for i in tqdm(range(len(mask))):
        for j in range(len(mask)):
            if mask[i] == True and mask[j] == True:
                # 0 -> lag_limit
                corr = cross_correlate(X[:, i], X[:, j], lag_limit)
                
                worm_corr_data[i, j] = corr
            else:
                worm_corr_data[i, j] = np.full((lag_limit,), np.NaN)
    
    if not os.path.exists(f"./corr_data"):
        os.makedirs(f"./corr_data")
                            
    np.save('./corr_data/worm_' + str(dataset) + "_" + str(worm_idx) + '.npy', worm_corr_data)

In [9]:
dataset_name = "Kato2015"
Kato2015 = load_dataset(dataset_name)

dataset_name = "Nichols2017"
Nichols2017 = load_dataset(dataset_name)

dataset_name = "Skora2018"
Skora2018 = load_dataset(dataset_name)

dataset_name = "Kaplan2020"
Kaplan2020 = load_dataset(dataset_name)

dataset_name = "Yemini2021"
Yemini2021 = load_dataset(dataset_name)

dataset_name = "Uzel2022"
Uzel2022 = load_dataset(dataset_name)

dataset_name = "Lin2023"
Lin2023 = load_dataset(dataset_name)

dataset_name = "Leifer2023"
Leifer2023 = load_dataset(dataset_name)

dataset_name = "Flavell2023"
Flavell2023 = load_dataset(dataset_name)

datasets = [Kato2015, Nichols2017, Skora2018, Kaplan2020, Yemini2021, Uzel2022, Lin2023, Leifer2023, Flavell2023]

for i, dataset in enumerate(datasets):
        if i >= -1:
                worms = list(dataset.keys())   
                for idx in tqdm(range(len(worms))):
                        if i != 6 or (i == 6 and idx > 443):
                                worm = worms[idx]

                                single_worm_dataset = pick_worm(dataset, worm)

                                data = single_worm_dataset["calcium_data"]
                                mask = single_worm_dataset["named_neurons_mask"]
                                neurons = sorted(single_worm_dataset["named_neuron_to_slot"])

                                # X = data[:, mask].numpy()
                                X = data.numpy()

                                # plot autocorrelation and partial autocorrelation
                                save_crosscorrelation(X, idx, 100, i, mask)

100%|██████████| 302/302 [00:14<00:00, 21.20it/s]
100%|██████████| 302/302 [00:13<00:00, 22.58it/s]
100%|██████████| 302/302 [00:09<00:00, 30.94it/s]
100%|██████████| 302/302 [00:17<00:00, 17.12it/s]
100%|██████████| 302/302 [00:14<00:00, 21.44it/s]
100%|██████████| 302/302 [00:13<00:00, 22.64it/s]
100%|██████████| 302/302 [00:17<00:00, 17.04it/s]
100%|██████████| 302/302 [00:11<00:00, 25.64it/s]
100%|██████████| 302/302 [00:15<00:00, 19.78it/s]
100%|██████████| 302/302 [00:08<00:00, 35.62it/s]
100%|██████████| 302/302 [00:17<00:00, 17.12it/s]
100%|██████████| 302/302 [00:08<00:00, 35.93it/s]
100%|██████████| 12/12 [02:42<00:00, 13.54s/it]
100%|██████████| 302/302 [00:05<00:00, 59.66it/s]
100%|██████████| 302/302 [00:11<00:00, 26.97it/s]
100%|██████████| 302/302 [00:10<00:00, 28.91it/s]
100%|██████████| 302/302 [00:10<00:00, 28.73it/s]
100%|██████████| 302/302 [00:09<00:00, 32.12it/s]
100%|██████████| 302/302 [00:05<00:00, 51.44it/s]
100%|██████████| 302/302 [00:10<00:00, 30.11it/s]
10

# Produce figures from data

In [4]:
# collect all data
print("Collecting data...")

worm_files = sorted(os.listdir("./corr_data"))
print(worm_files)
seperated_files = [[]]

curr_set = 0

datasets = ["Kato2015", "Nichols2017", "Skora2018", "Kaplan2020", "Yemini2021", "Uzel2022", "Lin2023", "Leifer2023", "Flavell2023"]

if not os.path.exists(f"./corr_figs"):
    os.makedirs(f"./corr_figs")
                            
for file in worm_files:
    file_set = int(file.split("_")[1])
    if file_set != curr_set:
        curr_set = file_set
        seperated_files += [[file]]
    else:
        seperated_files[-1] += [file]

for set_idx, files in enumerate(seperated_files):
    if set_idx != 6 and set_idx > 0:
        all_data = np.empty((len(files), 302, 302, 100))
        for i, file in tqdm(enumerate(files)):
            all_data[i] = np.load(f'./corr_data/{file}')

        # filter data
        # length of filtered data list is 302 neurons*302 neurons
        print("Filtering data...")

        filtered_data = [None for i in range(302**2)]

        for i in tqdm(range(all_data.shape[1])):
            for j in range(all_data.shape[2]):
                for worm_idx in range(all_data.shape[0]):
                    if filtered_data[i*all_data.shape[1]+j] is None and not np.any(np.isnan(all_data[worm_idx, i, j])):
                        filtered_data[i*all_data.shape[1]+j] = [all_data[worm_idx, i, j]]
                    elif not np.any(np.isnan(all_data[worm_idx, i, j])):
                        filtered_data[i*all_data.shape[1]+j] += [all_data[worm_idx, i, j]]

        for i in range(len(filtered_data)):
            if filtered_data[i] is not None:
                filtered_data[i] = np.array(filtered_data[i])

        # generate plots 
        print("Generating plots...")
            
        for i, neuron1 in enumerate(NEURONS_302):
            if i > -1:
                for j, neuron2 in tqdm(enumerate(NEURONS_302)):
                    if filtered_data[i*len(NEURONS_302)+j] is not None:
                        plt.figure()
                        plt.ylim(-1, 1)
                        
                        plt.title(f"Dataset {datasets[set_idx]}: {neuron1}_{neuron2}")
                        plt.xlabel("Lag")
                        plt.ylabel("Correlation Score")
                        
                        mean = np.mean(filtered_data[i*len(NEURONS_302)+j], axis=0)
                        std = np.std(filtered_data[i*len(NEURONS_302)+j], axis=0)

                        plt.plot(np.arange(0, 100, 1), mean, linewidth=5.0, alpha=1.0, color="cornflowerblue")

                        for k in filtered_data[i*len(NEURONS_302)+j]:
                            plt.plot(np.arange(0, 100, 1), k, alpha=0.3, color="cornflowerblue")

                        z=1
                        plt.fill_between(np.arange(0, 100, 1), mean-(std*z), mean+(std*z), alpha=0.4)  
                        
                        if not os.path.exists(f"./corr_figs/{set_idx}"):
                            os.makedirs(f"./corr_figs/{set_idx}")
                            
                        if not os.path.exists(f"./corr_figs/{set_idx}/{neuron1}"):
                            os.makedirs(f"./corr_figs/{set_idx}/{neuron1}")
                            
                        plt.savefig(f"./corr_figs/{set_idx}/{neuron1}/{neuron1}_{neuron2}.png")
                        plt.clf()
                        

Collecting data...
['worm_0_0.npy', 'worm_0_1.npy', 'worm_0_10.npy', 'worm_0_11.npy', 'worm_0_2.npy', 'worm_0_3.npy', 'worm_0_4.npy', 'worm_0_5.npy', 'worm_0_6.npy', 'worm_0_7.npy', 'worm_0_8.npy', 'worm_0_9.npy', 'worm_1_0.npy', 'worm_1_1.npy', 'worm_1_10.npy', 'worm_1_11.npy', 'worm_1_12.npy', 'worm_1_13.npy', 'worm_1_14.npy', 'worm_1_15.npy', 'worm_1_16.npy', 'worm_1_17.npy', 'worm_1_18.npy', 'worm_1_19.npy', 'worm_1_2.npy', 'worm_1_20.npy', 'worm_1_21.npy', 'worm_1_22.npy', 'worm_1_23.npy', 'worm_1_24.npy', 'worm_1_25.npy', 'worm_1_26.npy', 'worm_1_27.npy', 'worm_1_28.npy', 'worm_1_29.npy', 'worm_1_3.npy', 'worm_1_30.npy', 'worm_1_31.npy', 'worm_1_32.npy', 'worm_1_33.npy', 'worm_1_34.npy', 'worm_1_35.npy', 'worm_1_36.npy', 'worm_1_37.npy', 'worm_1_38.npy', 'worm_1_39.npy', 'worm_1_4.npy', 'worm_1_40.npy', 'worm_1_41.npy', 'worm_1_42.npy', 'worm_1_43.npy', 'worm_1_5.npy', 'worm_1_6.npy', 'worm_1_7.npy', 'worm_1_8.npy', 'worm_1_9.npy', 'worm_2_0.npy', 'worm_2_1.npy', 'worm_2_10.npy',

0it [00:00, ?it/s]

2it [00:00,  2.57it/s]


KeyboardInterrupt: 

# Extract Autocorr

In [None]:
dirs = os.listdir("./corr_figs")
file_paths = []

for directory in dirs:
    subdirs = os.listdir("./corr_figs/" + directory)
    for subdir in subdirs:
        file_paths += [f"./corr_figs/{directory}/{subdir}/{subdir}_{subdir}.png"]
file_paths

['../corr_figs/3/OLQDL/OLQDL_OLQDL.png',
 '../corr_figs/3/RIFL/RIFL_RIFL.png',
 '../corr_figs/3/DB6/DB6_DB6.png',
 '../corr_figs/3/AFDL/AFDL_AFDL.png',
 '../corr_figs/3/RMER/RMER_RMER.png',
 '../corr_figs/3/VB9/VB9_VB9.png',
 '../corr_figs/3/DB1/DB1_DB1.png',
 '../corr_figs/3/AIBL/AIBL_AIBL.png',
 '../corr_figs/3/DB3/DB3_DB3.png',
 '../corr_figs/3/DVA/DVA_DVA.png',
 '../corr_figs/3/SMDVL/SMDVL_SMDVL.png',
 '../corr_figs/3/AVBL/AVBL_AVBL.png',
 '../corr_figs/3/DB5/DB5_DB5.png',
 '../corr_figs/3/SMDVR/SMDVR_SMDVR.png',
 '../corr_figs/3/AVAL/AVAL_AVAL.png',
 '../corr_figs/3/VD2/VD2_VD2.png',
 '../corr_figs/3/OLQVR/OLQVR_OLQVR.png',
 '../corr_figs/3/RIGR/RIGR_RIGR.png',
 '../corr_figs/3/RIGL/RIGL_RIGL.png',
 '../corr_figs/3/ASKR/ASKR_ASKR.png',
 '../corr_figs/3/URYVL/URYVL_URYVL.png',
 '../corr_figs/3/AIBR/AIBR_AIBR.png',
 '../corr_figs/3/ALA/ALA_ALA.png',
 '../corr_figs/3/VB3/VB3_VB3.png',
 '../corr_figs/3/VB1/VB1_VB1.png',
 '../corr_figs/3/RIVR/RIVR_RIVR.png',
 '../corr_figs/3/SIBVL/SIBV

In [None]:
import shutil

if not os.path.exists(f"./autocorr_figs"):
    os.makedirs(f"./autocorr_figs")
                            
for file in file_paths:
    split_path = file.split("/")
    dataset = split_path[2]
    if not os.path.exists(f"./autocorr_figs/{dataset}"):
        os.makedirs(f"./autocorr_figs/{dataset}")
    shutil.copyfile(file, f"{split_path[0]}/autocorr_figs/{split_path[2]}/{split_path[-1]}")