### Create the files containing the neural activity data
---
*Last updated: 21 October 2024*

This is meant to be a direct implementation of the main operation of the `data` submodule (`data/_main.py`).

The final output is the creation of the files `data/datasets/combined_dataset.pickle` and `data/datasets/full_dataset.pt`.

In [4]:
import os
import torch
import pickle
from utils import ROOT_DIR
from omegaconf import DictConfig, OmegaConf
from data._utils import (
    NeuralActivityDataset,
    load_dataset,
    select_desired_worms,
    select_labeled_neurons,
    rename_worm_keys,
    filter_loaded_combined_dataset,
)

In [5]:
# Read the data submodule's YAML config, keep only its `dataset` section,
# and echo that section back as YAML for a quick visual check.
submodule_config = OmegaConf.load("../configs/submodule/data.yaml")
data_config = submodule_config.dataset
config_as_yaml = OmegaConf.to_yaml(data_config)
print(config_as_yaml, end="\n\n")

use_these_datasets:
  path: null
  num_worms: null
save_datasets: true
source_datasets:
  Kato2015: all
  Nichols2017: all
  Skora2018: all
  Kaplan2020: all
  Nejatbakhsh2020: all
  Yemini2021: all
  Uzel2022: all
  Dag2023: all
  Flavell2023: all
  Leifer2023: all
  Lin2023: all
  Venkatachalam2024: all
num_labeled_neurons: null
seq_len: null
num_train_samples: 16
num_val_samples: 16
reverse: false
use_residual: false
use_smooth: false
train_split_first: false
train_split_ratio: 0.5




In [6]:
# Unpack the config entries into plain variables used by the cells below.

# Which source datasets to draw worms from, and the labeled-neuron setting
source_datasets, num_labeled_neurons = (
    data_config.source_datasets,
    data_config.num_labeled_neurons,
)

# Sampling settings: samples per split and the sequence length
# (a `seq_len` of None is resolved automatically in a later cell)
num_train_samples, num_val_samples = (
    data_config.num_train_samples,
    data_config.num_val_samples,
)
seq_len = data_config.seq_len

# Flags forwarded to the dataset construction
reverse, use_residual, use_smooth = (
    data_config.reverse,
    data_config.use_residual,
    data_config.use_smooth,
)

# Train/validation split settings
train_split_first, train_split_ratio = (
    data_config.train_split_first,
    data_config.train_split_ratio,
)

# Hard-coded here instead of reading `data_config.save_datasets`
save = True

In [7]:
# Build `combined_dataset_dict`: a dict keyed by worm name ("worm0", "worm1", ...)
# holding the data of every requested worm from every source dataset.
# A previously saved pickle is reused when present; otherwise the dict is
# assembled from the individual source datasets.

# Make the datasets directory if it doesn't exist
datasets_dir = os.path.join(ROOT_DIR, "data", "datasets")
os.makedirs(datasets_dir, exist_ok=True)

# Look for the cached pickle in the same directory it is saved to, so the
# check does not depend on the notebook's current working directory.
pickle_file = os.path.join(datasets_dir, "combined_dataset.pickle")
if os.path.exists(pickle_file):
    # `pickle.load` requires an open binary file object, not a path string.
    # NOTE: unpickling can execute arbitrary code; only load a pickle that
    # was produced locally by this pipeline.
    with open(pickle_file, "rb") as f:
        combined_dataset_dict = pickle.load(f)
    combined_dataset_dict, dataset_info = filter_loaded_combined_dataset(
        combined_dataset_dict,
        data_config.use_these_datasets.num_worms,
        num_labeled_neurons,
    )

# Otherwise create it from scratch
else:
    # Convert DictConfig to a plain dict so it can be iterated with `.items()`
    if isinstance(source_datasets, DictConfig):
        source_datasets = OmegaConf.to_object(source_datasets)

    # Load the dataset(s)
    combined_dataset_dict = dict()

    for dataset_name, worms in source_datasets.items():
        # Skip if no worms requested for this dataset
        if worms is None or worms == 0:
            print(f"Skipping all worms from {dataset_name} dataset.")
            continue

        # Create a multi-worm dataset
        multi_worms_dataset = load_dataset(dataset_name)

        # Select desired worms from this dataset
        multi_worms_dataset = select_desired_worms(multi_worms_dataset, worms)

        # Select the `num_labeled_neurons` neurons and overwrite the masks
        multi_worms_dataset = select_labeled_neurons(multi_worms_dataset, num_labeled_neurons)

        # Add the worms from this dataset to the combined dataset
        for worm in multi_worms_dataset:
            if worm in combined_dataset_dict:
                # Key collision with a worm from an earlier dataset: store this
                # worm under the next free index and remember its original key.
                next_index = (
                    max(int(key.split("worm")[-1]) for key in combined_dataset_dict) + 1
                )
                worm_ = "worm" + str(next_index)
                combined_dataset_dict[worm_] = multi_worms_dataset[worm]
                combined_dataset_dict[worm_]["worm"] = worm_
                combined_dataset_dict[worm_]["original_worm"] = worm
            else:
                combined_dataset_dict[worm] = multi_worms_dataset[worm]
                combined_dataset_dict[worm]["original_worm"] = worm

    print("Combined dataset has {} worms".format(len(combined_dataset_dict)))

    # Rename the worm keys so that they are ordered
    combined_dataset_dict = rename_worm_keys(combined_dataset_dict)

Combined dataset has 919 worms


In [8]:
# Sanity check: number of worms, the worm keys, and the fields stored for the first worm
print(len(combined_dataset_dict))
worm_keys = combined_dataset_dict.keys()
print(worm_keys)
first_worm_fields = combined_dataset_dict["worm0"].keys()
print(first_worm_fields)

919
dict_keys(['worm0', 'worm1', 'worm2', 'worm3', 'worm4', 'worm5', 'worm6', 'worm7', 'worm8', 'worm9', 'worm10', 'worm11', 'worm12', 'worm13', 'worm14', 'worm15', 'worm16', 'worm17', 'worm18', 'worm19', 'worm20', 'worm21', 'worm22', 'worm23', 'worm24', 'worm25', 'worm26', 'worm27', 'worm28', 'worm29', 'worm30', 'worm31', 'worm32', 'worm33', 'worm34', 'worm35', 'worm36', 'worm37', 'worm38', 'worm39', 'worm40', 'worm41', 'worm42', 'worm43', 'worm44', 'worm45', 'worm46', 'worm47', 'worm48', 'worm49', 'worm50', 'worm51', 'worm52', 'worm53', 'worm54', 'worm55', 'worm56', 'worm57', 'worm58', 'worm59', 'worm60', 'worm61', 'worm62', 'worm63', 'worm64', 'worm65', 'worm66', 'worm67', 'worm68', 'worm69', 'worm70', 'worm71', 'worm72', 'worm73', 'worm74', 'worm75', 'worm76', 'worm77', 'worm78', 'worm79', 'worm80', 'worm81', 'worm82', 'worm83', 'worm84', 'worm85', 'worm86', 'worm87', 'worm88', 'worm89', 'worm90', 'worm91', 'worm92', 'worm93', 'worm94', 'worm95', 'worm96', 'worm97', 'worm98', 'worm

In [9]:
# When no sequence length was configured, derive one from the worm with the
# fewest timesteps: half of its length, minus the larger of the two
# per-split sample counts, minus one.
if seq_len is None:
    max_num_samples = (
        num_train_samples if num_train_samples >= num_val_samples else num_val_samples
    )
    min_timesteps = min(
        worm_record["max_timesteps"] for worm_record in combined_dataset_dict.values()
    )
    half_of_shortest = min_timesteps // 2
    seq_len = half_of_shortest - max_num_samples - 1
print(f"Chosen sequence length: {seq_len}\n")

Chosen sequence length: 329



In [10]:
# Wrap every worm's recording in a `NeuralActivityDataset`, then chain all of
# them into one PyTorch dataset.
samples_per_worm = num_train_samples + num_val_samples
combined_datasets = []
for worm_record in combined_dataset_dict.values():
    # TODO: Encapsulate this inner part as a function `split_single_dataset`.
    worm_samples = NeuralActivityDataset(
        data=worm_record["calcium_data"],
        neurons_mask=worm_record["labeled_neurons_mask"],
        time_vec=worm_record["time_in_seconds"],
        # name of the original experimental dataset the data is from
        worm_dataset=worm_record["source_dataset"],
        # worm ID from the original experimental dataset
        wormID=worm_record["original_worm"],
        seq_len=seq_len,
        num_samples=samples_per_worm,
        use_residual=use_residual,
        reverse=reverse,
    )
    combined_datasets.append(worm_samples)

    # ### DEBUG ###
    # mask = worm_record["labeled_neurons_mask"]
    # print(f"\nDEBUG neurons_mask: {mask.shape, mask.sum()}")
    # print(worm_record["calcium_data"][:seq_len, mask])
    # X, M, info = worm_samples[0]
    # print(f"\nDEBUG M: {M.shape, M.sum()}")
    # print(X[:, M])
    # ### DEBUG ###

# Concatenate the datasets (None when there is nothing to concatenate).
# number of train sequences = number train samples * number of worms
if len(combined_datasets):
    combined_dataset_pt = torch.utils.data.ConcatDataset(combined_datasets)
else:
    combined_dataset_pt = None

In [11]:
# DEBUG
# len(combined_datasets), 919*32, len(combined_dataset_dict)
# wormid = 'worm275'
# combined_dataset_dict[wormid]['source_dataset'], combined_dataset_dict[wormid]['worm'], combined_dataset_dict[wormid]['original_worm']

In [12]:
# Save the combined dataset (both pickle and pt version).
# Honor the `save` flag set when the config was parsed; previously the flag
# was defined but never consulted, so the files were written unconditionally.
if save:
    output_dir = os.path.join(ROOT_DIR, "data", "datasets")
    # PyTorch version: the concatenated dataset object (may be None if no worms were loaded)
    torch.save(combined_dataset_pt, os.path.join(output_dir, "full_dataset.pt"))
    # Pickle version: the raw per-worm dictionary
    with open(os.path.join(output_dir, "combined_dataset.pickle"), "wb") as f:
        pickle.dump(combined_dataset_dict, f)