In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import torch

from data._main import get_dataset
from omegaconf import OmegaConf
from utils import NEURONS_302

In [3]:
# Load the dataset sub-config from the repository's YAML file
dataset_config = OmegaConf.load("../../../configs/submodule/dataset.yaml")
# Select two experimental datasets for training and request 10 worms
dataset_config.dataset.train.name = ['Kato2015', 'Nichols2017']
dataset_config.dataset.train.num_worms = 10
# Print the resulting config as YAML to confirm the values set above
print(OmegaConf.to_yaml(dataset_config))

dataset:
  train:
    name:
    - Kato2015
    - Nichols2017
    num_named_neurons: 5
    num_worms: 10
    save: false
  predict:
    name: Kato2015
    num_named_neurons: 1
    num_worms: 1
    save: false



In [4]:
# Build the combined dataset from the `train` section of the config.
# The result is iterated below with .items(), one entry per worm.
combined_dataset = get_dataset(dataset_config=dataset_config.dataset.train)

In [5]:
# Per-worm bookkeeping: source dataset, index in the original dataset, index
# in the combined dataset and the named neurons recorded for that worm
dataset_info = {
    'dataset': [],
    'original_index': [],
    'combined_dataset_index': [],
    'neurons': [],
}

# Flat list of every named neuron over all worms (repetitions included)
combined_dataset_neurons = []

for worm, data in combined_dataset.items():
    dataset_info['dataset'].append(data['dataset'])
    dataset_info['original_index'].append(data['original_worm'])
    dataset_info['combined_dataset_index'].append(data['worm'])
    # Only the neuron names are needed, not the slot keys
    worm_neurons = list(data['slot_to_named_neuron'].values())
    dataset_info['neurons'].append(worm_neurons)
    # Extend in place rather than rebuilding the whole list on every iteration
    combined_dataset_neurons.extend(worm_neurons)

dataset_info = pd.DataFrame(dataset_info)

combined_dataset_neurons = np.array(combined_dataset_neurons)
# Count occurrences of each neuron
neuron_counts = np.unique(combined_dataset_neurons, return_counts=True)
# Sort by neuron count (descending); the sort is stable, so neurons with equal
# counts keep the name order returned by np.unique
neuron_counts = sorted(
    zip(*(values.tolist() for values in neuron_counts)),
    key=lambda x: x[1],
    reverse=True,
)
# Create a dataframe straight from the (name, count) tuples. Passing them
# through np.array first would coerce the integer counts to strings.
neuron_counts = pd.DataFrame(neuron_counts, columns=["neuron", "count"])

In [6]:
# Number of entries (worms) in the combined dataset
len(combined_dataset)

10

In [36]:
from train._utils import split_train_test
from data._utils import NeuralActivityDataset, create_combined_dataset
from torch.utils.data import ConcatDataset
from omegaconf import OmegaConf, DictConfig

# Split (combined) dataset
---

In [37]:
def distribute_samples(data_splits, total_nb_samples):
    """Spread `total_nb_samples` as evenly as possible over the given splits.

    Every split receives `total_nb_samples // len(data_splits)` samples and the
    first `total_nb_samples % len(data_splits)` splits receive one extra, so
    the returned counts add up to `total_nb_samples`.

    Parameters
    ----------
    data_splits : sized collection
        The splits to distribute over. Only its length is used.
    total_nb_samples : int
        Total number of samples to distribute.

    Returns
    -------
    list of int
        One sample count per split, in split order. Empty if there are no
        splits (nothing can be allocated).
    """
    num_splits = len(data_splits)

    # Without splits there is nothing to allocate to; returning early also
    # avoids a ZeroDivisionError in the division below
    if num_splits == 0:
        return []

    # Base number of samples for each split, and how many are left over
    base_samples_per_split, remainder = divmod(total_nb_samples, num_splits)

    # The first `remainder` splits absorb the leftover samples, one each
    return [
        base_samples_per_split + 1 if i < remainder else base_samples_per_split
        for i in range(num_splits)
    ]

In [60]:
def _build_split_datasets(data_splits, time_vec_splits, samples_per_split, neurons_mask,
                          seq_len, tau, use_residual, reverse):
    """Create one NeuralActivityDataset per (data split, time split, sample count) triple."""
    return [
        NeuralActivityDataset(
            data = data_split.detach(),
            time_vec = time_split.detach(),
            neurons_mask = neurons_mask,
            seq_len = seq_len,
            num_samples = num_samples_split,
            tau = tau,
            use_residual = use_residual,
            reverse = reverse,
        )
        for data_split, time_split, num_samples_split
        in zip(data_splits, time_vec_splits, samples_per_split)
    ]


def split_combined_dataset(combined_dataset, k_splits, num_train_samples, num_val_samples, seq_len, tau, reverse, use_residual, smooth_data):
    """Split every worm's recording into alternating train/validation chunks.

    Each worm's data (and its time vector) is cut into `k_splits` consecutive
    chunks. Even-indexed chunks (0, 2, ...) go to training and odd-indexed
    chunks (1, 3, ...) go to validation. One `NeuralActivityDataset` is built
    per chunk, and the per-worm sample budget is shared between that worm's
    chunks with `distribute_samples`.

    Parameters
    ----------
    combined_dataset : dict-like
        Maps a worm ID to that worm's dataset. Each worm dataset must provide
        the selected data key plus "named_neurons_mask" and "time_in_seconds".
    k_splits : int
        Number of chunks per worm. Must be > 1 so that both a training and a
        validation chunk exist.
    num_train_samples : int
        Number of training samples requested per worm.
    num_val_samples : int
        Number of validation samples requested per worm.
    seq_len : int
        Sequence length of each sample. Must be shorter than the shortest chunk.
    tau : int
        Offset passed to `NeuralActivityDataset`. It must leave room inside the
        shortest chunk once `seq_len` is taken out.
    reverse : bool
        Forwarded unchanged to `NeuralActivityDataset`.
    use_residual : bool
        Use "residual_calcium" instead of "calcium_data"; also forwarded to
        `NeuralActivityDataset`.
    smooth_data : bool
        Use the "smooth_"-prefixed version of the selected data key.

    Returns
    -------
    tuple
        `(train_dataset, val_dataset)`, both `torch.utils.data.ConcatDataset`.

    Notes
    -----
    As a side effect, both datasets are written with `torch.save` to
    "train_dataset.pth" and "val_dataset.pth" in the current working directory.
    """
    # Fewer than two chunks would leave the validation set empty
    assert isinstance(k_splits, int) and k_splits > 1, "k_splits must be an integer > 1"

    # Choose whether to use calcium or residual data
    key_data = "residual_calcium" if use_residual else "calcium_data"

    # Choose whether to use original or smoothed data
    if smooth_data:
        key_data = "smooth_" + key_data

    # Store the training and validation datasets
    train_dataset = []
    val_dataset = []

    # Loop through the worms in the dataset
    for wormID, single_worm_dataset in combined_dataset.items():

        # Extract relevant features from the dataset
        data = single_worm_dataset[key_data]
        neurons_mask = single_worm_dataset["named_neurons_mask"]
        time_vec = single_worm_dataset["time_in_seconds"]

        # np.array_split yields chunks of len(data) // k_splits elements or
        # one more, so this is the shortest chunk a sample has to fit in
        min_split_len = len(data) // k_splits

        # Verifications
        assert isinstance(seq_len, int) and 0 < seq_len < len(data), "seq_len must be an integer > 0 and < len(data)"
        assert seq_len < min_split_len, "The desired seq_len is too long. Try a smaller seq_len or decreasing k_splits"
        # Samples are drawn from a single chunk, so tau is checked against the
        # chunk length rather than the full recording
        assert isinstance(tau, int) and tau < (min_split_len - seq_len), "The desired tau is too long. Try a smaller value"

        # Split the data and the time vector into k splits
        data_splits = np.array_split(data, k_splits)
        time_vec_splits = np.array_split(time_vec, k_splits)

        # Separate the splits into training (even) and validation (odd) sets
        train_data_splits = data_splits[::2]
        train_time_vec_splits = time_vec_splits[::2]
        val_data_splits = data_splits[1::2]
        val_time_vec_splits = time_vec_splits[1::2]

        # Number of samples in each split
        train_samples_per_split = distribute_samples(train_data_splits, num_train_samples)
        val_samples_per_split = distribute_samples(val_data_splits, num_val_samples)

        # Create a dataset for each split
        train_dataset.extend(
            _build_split_datasets(train_data_splits, train_time_vec_splits, train_samples_per_split,
                                  neurons_mask, seq_len, tau, use_residual, reverse)
        )
        val_dataset.extend(
            _build_split_datasets(val_data_splits, val_time_vec_splits, val_samples_per_split,
                                  neurons_mask, seq_len, tau, use_residual, reverse)
        )

    # Concatenate the datasets
    train_dataset = torch.utils.data.ConcatDataset(train_dataset) # Nb of train examples = nb train samples * nb of worms
    val_dataset = torch.utils.data.ConcatDataset(val_dataset) # Nb of val examples = nb val samples * nb of worms

    # Save the datasets (torch.save pickles the objects; only load these files from trusted sources)
    torch.save(train_dataset, "train_dataset.pth")
    torch.save(val_dataset, "val_dataset.pth")

    return train_dataset, val_dataset

In [61]:
# Load the new-style dataset config and print the `for_training` section.
# NOTE(review): this rebinds `dataset_config`, replacing the config loaded
# from dataset.yaml earlier in the notebook.
dataset_config = OmegaConf.load("../../../configs/submodule/dataset_new.yaml")
print(OmegaConf.to_yaml(dataset_config.dataset.for_training))

experimental_datasets:
- Kato2015
- Nichols2017
num_named_neurons: 2
num_worms: 10
k_splits: 2
num_train_samples: 16
num_val_samples: 4
seq_len: 120
tau: 1
reverse: false
use_residual: false
smooth_data: true



In [51]:
def get_datasets(dataset_config: DictConfig, name='train'):
    """Build the train and validation datasets described by a dataset config.

    The experimental datasets named in the config are first merged with
    `create_combined_dataset`, then split into train/validation datasets with
    `split_combined_dataset`.

    Parameters
    ----------
    dataset_config : DictConfig
        Config providing `experimental_datasets`, `num_named_neurons`,
        `num_worms`, `k_splits`, `num_train_samples`, `num_val_samples`,
        `seq_len`, `tau`, `reverse`, `use_residual` and `smooth_data`.
    name : str, optional
        Forwarded to `create_combined_dataset` as `name`. Defaults to 'train'.

    Returns
    -------
    tuple
        `(train_dataset, val_dataset)` as returned by `split_combined_dataset`.

    Raises
    ------
    AssertionError
        If `k_splits` is not an integer > 1, or if `num_named_neurons` or
        `num_worms` is neither a positive integer nor the string "all".
    """
    experimental_datasets = dataset_config.experimental_datasets
    num_named_neurons = dataset_config.num_named_neurons
    num_worms = dataset_config.num_worms
    k_splits = dataset_config.k_splits

    # Verifications
    assert isinstance(k_splits, int) and k_splits > 1, "k_splits must be an integer > 1"

    # Enforce positivity as well, to match what the messages promise
    assert (isinstance(num_named_neurons, int) and num_named_neurons > 0) or num_named_neurons == "all", (
        "num_named_neurons must be a positive integer or 'all'."
    )

    assert (isinstance(num_worms, int) and num_worms > 0) or num_worms == "all", (
        "num_worms must be a positive integer or 'all'."
    )

    combined_dataset = create_combined_dataset(experimental_datasets, num_named_neurons, num_worms, name=name)

    # Remaining settings are passed straight through to the splitting step
    train_dataset, val_dataset = split_combined_dataset(
        combined_dataset,
        k_splits,
        dataset_config.num_train_samples,
        dataset_config.num_val_samples,
        dataset_config.seq_len,
        dataset_config.tau,
        dataset_config.reverse,
        dataset_config.use_residual,
        dataset_config.smooth_data,
    )

    return train_dataset, val_dataset

In [52]:
# Build the train/validation datasets from the `for_training` section of the config
train_dataset, val_dataset = get_datasets(dataset_config.dataset.for_training, name='train')

In [55]:
# Grab the first training example; each example unpacks into four items
# (presumably input, target, neurons mask and one extra item — TODO confirm)
x1,y1,m1,m = next(iter(train_dataset))

In [58]:
# Sanity check: y1 is x1 shifted forward by one time step, i.e. every row of
# x1 from index 1 on equals the preceding row of y1
(x1[1:,:] == y1[:-1,:]).all()

tensor(True)

In [59]:
# Total number of validation examples (40 here, i.e. num_worms 10 x num_val_samples 4
# from the printed `for_training` config)
len(val_dataset)

40

In [17]:
# Split the combined dataset directly, with explicit arguments instead of a config.
# NOTE(review): this rebinds train_dataset / val_dataset and overwrites the
# "train_dataset.pth" / "val_dataset.pth" files written by split_combined_dataset.
train_dataset, val_dataset = split_combined_dataset(
    combined_dataset = combined_dataset,
    k_splits = 2,
    num_train_samples = 16,
    num_val_samples = 8,
    seq_len = 100,
    tau = 1,
    reverse = False,
    use_residual = False,
    smooth_data = True,
)

In [19]:
# Create the training dataloader

# Shuffled loader serving batches of 16 examples from the training dataset
trainloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size = 16,
    shuffle = True,
)

In [21]:
# Iterate once over the loader and print the shape of each batch.
# Per the output below, x and y are (16, 100, 302) and mask is (16, 302):
# batch_size 16 and seq_len 100; the 302 axis is presumably neurons — TODO confirm.
for x, y, mask, _ in trainloader:
    print(x.shape, y.shape, mask.shape)

torch.Size([16, 100, 302]) torch.Size([16, 100, 302]) torch.Size([16, 302])
torch.Size([16, 100, 302]) torch.Size([16, 100, 302]) torch.Size([16, 302])
torch.Size([16, 100, 302]) torch.Size([16, 100, 302]) torch.Size([16, 302])
torch.Size([16, 100, 302]) torch.Size([16, 100, 302]) torch.Size([16, 302])
torch.Size([16, 100, 302]) torch.Size([16, 100, 302]) torch.Size([16, 302])
torch.Size([16, 100, 302]) torch.Size([16, 100, 302]) torch.Size([16, 302])
torch.Size([16, 100, 302]) torch.Size([16, 100, 302]) torch.Size([16, 302])
torch.Size([16, 100, 302]) torch.Size([16, 100, 302]) torch.Size([16, 302])
torch.Size([16, 100, 302]) torch.Size([16, 100, 302]) torch.Size([16, 302])
torch.Size([16, 100, 302]) torch.Size([16, 100, 302]) torch.Size([16, 302])
