In [None]:
# @title Imports

import torch
import random
import nbimporter
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from utils import NEURONS_302, DEVICE
from datasets import load_dataset
from transformers import AutoTokenizer
from models._utils import load_model_from_checkpoint
from tokenizers.pre_tokenizers import WhitespaceSplit
from preprocess._utils import smooth_data_preprocess, reshape_calcium_data
from CreateSyntheticDataset import (
    save_synthetic_dataset,
    plot_neural_signals,
    plot_3d_trajectory,
)

In [None]:
# @title HuggingFace Tokenizers

# Load the `bert-base-uncased` tokenizer and exercise both entry points:
# `encode` (ids only) and a direct call (full encoding).
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
output = tokenizer.encode("Welcome to the 🤗 Tokenizers library.")
encoding = tokenizer("We are very happy to show you the 🤗 Transformers library.")

# Token ids, their decoded text, then the full encoding, each followed by a blank line.
print(output, tokenizer.decode(output), encoding, sep="\n\n", end="\n\n")

In [None]:
# @title Helper functions


def pre_tokenize_and_chunk(text, max_length=510):
    """Split `text` on whitespace and greedily pack the words into chunks.

    Words are joined with single spaces, and a new chunk is started whenever
    adding the next word would push the current chunk past `max_length`.

    Note that `max_length` is a budget in *characters*, not tokens. The
    default of 510 mirrors the 512-token model limit minus two special
    tokens, but nothing here guarantees a chunk tokenizes to <= 512 tokens;
    callers still need truncation when tokenizing.

    Args:
        text: The string to split.
        max_length: Maximum number of characters per chunk. A single word
            longer than this is kept whole as its own (oversized) chunk.

    Returns:
        list[str]: The chunks in original order. Never contains empty
        strings; empty or whitespace-only input yields an empty list.
    """
    pre_tokenizer = WhitespaceSplit()
    pre_tokens = pre_tokenizer.pre_tokenize_str(text)

    chunks = []
    current_chunk = ""
    for token, _offsets in pre_tokens:
        if not current_chunk:
            # Nothing accumulated yet: start the chunk with this word even if
            # it is over budget. Flushing here would emit an empty chunk.
            current_chunk = token
        elif len(current_chunk) + 1 + len(token) > max_length:
            # The +1 accounts for the joining space.
            chunks.append(current_chunk)
            current_chunk = token
        else:
            current_chunk += " " + token
    if current_chunk:
        chunks.append(current_chunk)

    return chunks


def tokenize_and_chunk(examples, tokenizer=None):
    """Chunk every text of a batch and tokenize each chunk on its own.

    Args:
        examples: Mapping whose "text" entry is an iterable of strings.
        tokenizer: Callable tokenizer. When None, the "bert-base-uncased"
            tokenizer is loaded via `AutoTokenizer.from_pretrained`.

    Returns:
        dict with three parallel lists, "text", "input_ids" and
        "attention_mask", holding one entry per chunk (not per input text).
    """
    tokenizer = (
        AutoTokenizer.from_pretrained("bert-base-uncased")
        if tokenizer is None
        else tokenizer
    )

    chunk_texts, chunk_ids, chunk_masks = [], [], []
    for document in examples["text"]:
        for piece in pre_tokenize_and_chunk(document):
            encoded = tokenizer(
                piece,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            # One chunk is encoded per call, so `[0]` picks its single row
            # before converting back to a plain Python list.
            chunk_texts.append(piece)
            chunk_ids.append(encoded["input_ids"][0].tolist())
            chunk_masks.append(encoded["attention_mask"][0].tolist())

    return {
        "text": chunk_texts,
        "input_ids": chunk_ids,
        "attention_mask": chunk_masks,
    }

In [None]:
# @title Load the dataset

text_dataset = load_dataset("tiny_shakespeare")
print(text_dataset, end="\n\n")

# Same five-field report for every split: container type, number of rows,
# element type, and length of the first element.
for split_name in ("train", "validation", "test"):
    # Read the column once per split instead of once per printed field.
    split_texts = text_dataset[split_name]["text"]
    print(
        f"{split_name}:",
        type(split_texts),
        len(split_texts),
        type(split_texts[0]),
        len(split_texts[0]),
        end="\n\n",
    )

In [None]:
# @title Tokenization and Chunking
# @markdown Apply the tokenization and chunking to each split.

# Pass the already-loaded tokenizer to the mapped function. Without it,
# `tokenize_and_chunk` falls back to `AutoTokenizer.from_pretrained` on every
# batch, and the ids decoded below are only coincidentally produced by the
# same tokenizer as the global `tokenizer` that decodes them.
text_dataset = text_dataset.map(
    tokenize_and_chunk,
    batched=True,
    fn_kwargs={"tokenizer": tokenizer},
)
print(text_dataset, end="\n\n")
print(text_dataset["train"]["text"][0], end="\n\n")
print(tokenizer.decode(text_dataset["train"]["input_ids"][0]), end="\n\n")

In [None]:
# @title Embedding Table

# Lookup table with one half-precision vector of width `embedding_dim`
# per entry of the tokenizer vocabulary.
embedding_dim = 302
embedding = torch.nn.Embedding(tokenizer.vocab_size, embedding_dim, dtype=torch.half)

In [None]:
# @title Synthetic dataset
# @markdown Neural activity is a sequence of token embeddings.

# Concatenate the tokenized chunks of all three splits, in train / validation / test order.
train_tokens = text_dataset["train"]["input_ids"]
validation_tokens = text_dataset["validation"]["input_ids"]
test_tokens = text_dataset["test"]["input_ids"]
all_tokens = train_tokens + validation_tokens + test_tokens
worm_dataset = dict()
# Minimum number of stacked embedding rows before a worm is emitted; a worm may
# overshoot this by up to one chunk because whole chunks are appended.
max_timesteps = 3000
num_signals = embedding_dim
num_named_neurons = 302
# With embedding_dim == 302 this evaluates to 0, i.e. every signal is treated as named.
num_unknown_neurons = num_signals - num_named_neurons
smooth_method = None
dataset_name = "Shakespeare0000"

worm_idx = 0
calcium_data = []  # per-chunk embedding arrays for the worm currently being built
total_time = 0  # number of rows accumulated in `calcium_data` so far
# NOTE: the loop variable `chunk` stays bound after the loop and is read by a later cell.
for chunk in all_tokens:
    # `>` rather than `>=`: worm0 .. worm200 pass this check, so at most 201 worms are built.
    if worm_idx > 200:
        break
    # Embed each token id of the chunk and convert to NumPy (rows = tokens, columns = embedding dims).
    embd_data = embedding(torch.LongTensor(chunk)).detach().numpy()
    calcium_data.append(embd_data)
    total_time += embd_data.shape[0]
    # Enough rows collected: turn the accumulated chunks into one synthetic worm.
    if total_time >= max_timesteps:
        calcium_data = np.vstack(calcium_data)
        # One time unit per token, stored as a column vector.
        time_in_seconds = np.arange(total_time).reshape(-1, 1)
        dt = np.gradient(time_in_seconds, axis=0)
        # Guard against division by zero below; for the unit-spaced `arange` above dt is all ones.
        dt[dt == 0] = np.finfo(float).eps
        resample_dt = np.round(np.median(dt).item(), 2)
        # Finite-difference derivative of the signals along the time axis.
        residual_calcium = np.gradient(calcium_data, axis=0) / dt
        # NOTE(review): `smooth_method` is None here; presumably the helper then skips smoothing — confirm.
        smooth_calcium_data = smooth_data_preprocess(
            calcium_data, time_in_seconds, smooth_method
        )
        smooth_residual_calcium = smooth_data_preprocess(
            residual_calcium, time_in_seconds, smooth_method
        )
        # Sampling `num_named_neurons` of `num_signals` indices without replacement;
        # when the two are equal this is a permutation, so every neuron ends up named.
        named_neuron_indices = random.sample(range(num_signals), num_named_neurons)
        named_neurons = set(NEURONS_302[i] for i in named_neuron_indices)
        # Named neurons are keyed by their name, all others by their index as a string.
        neuron_to_idx = {
            (neuron) if neuron in named_neurons else str(idx): idx
            for idx, neuron in enumerate(NEURONS_302)
        }
        idx_to_neuron = {idx: neuron for neuron, idx in neuron_to_idx.items()}

        worm_data = dict()

        worm_data["worm"] = f"worm{worm_idx}"
        worm_data["dataset"] = dataset_name
        worm_data["smooth_method"] = smooth_method
        worm_data["calcium_data"] = calcium_data
        worm_data["smooth_calcium_data"] = smooth_calcium_data
        worm_data["residual_calcium"] = residual_calcium
        worm_data["smooth_residual_calcium"] = smooth_residual_calcium
        # Stores the actual length of this worm, not the `max_timesteps` threshold.
        worm_data["max_timesteps"] = total_time
        worm_data["time_in_seconds"] = time_in_seconds
        worm_data["dt"] = dt
        worm_data["resample_median_dt"] = resample_dt
        worm_data["neuron_to_idx"] = neuron_to_idx
        worm_data["idx_to_neuron"] = idx_to_neuron
        worm_data["num_neurons"] = num_signals
        worm_data["num_named_neurons"] = num_named_neurons
        worm_data["num_unknown_neurons"] = num_unknown_neurons

        # NOTE(review): the plotting cell reads "slot_to_neuron", "named_neurons_mask" and
        # "named_neuron_to_slot", none of which are set above — presumably added by this helper; confirm.
        worm_data = reshape_calcium_data(worm_data)
        worm_dataset[f"worm{worm_idx}"] = worm_data

        # Reset the accumulators for the next worm.
        worm_idx += 1
        calcium_data = []
        total_time = 0

# Chunks accumulated after the last completed worm (fewer than `max_timesteps` rows) are not saved.
# Save the dataset
save_synthetic_dataset(f"processed/neural/{dataset_name}.pickle", worm_dataset)

In [None]:
# @title Plotting the dataset

# Pick one worm at random and take its first `num_named_neurons` slot keys for plotting.
num_worms = len(worm_dataset)
worm_idx = random.choice([f"worm{i}" for i in range(num_worms)])
worm = worm_dataset[worm_idx]
neuron_idx = list(worm["slot_to_neuron"].keys())[:num_named_neurons]

# Signal traces of the selected worm
plot_neural_signals(
    data=worm["calcium_data"],
    time_tensor=worm["time_in_seconds"],
    neuron_idx=neuron_idx,
    yax_limit=False,
    suptitle=f"{dataset_name} - {worm_idx}",
)

# Sample covariance over the masked columns of the smoothed data
data = worm["smooth_calcium_data"]
mask = worm["named_neurons_mask"]
neurons = sorted(worm["named_neuron_to_slot"])

X = data[:, mask].numpy()
n = X.shape[0]
X_bar = X - X.mean(axis=0)  # centre every column
cov = 1 / (n - 1) * X_bar.T @ X_bar

plt.figure()
ax = sns.heatmap(cov, cmap="coolwarm", xticklabels=neurons, yticklabels=neurons)
ax.set_title(f"Covariance matrix : {dataset_name}, {worm_idx}")
plt.show()

# 3D trajectory of the same masked data
plot_3d_trajectory(
    X,
    axis_labels=tuple(neurons),
    title=f"{dataset_name} neural trajectory",
)

In [None]:
# @title Load a trained model

# NOTE(review): absolute, user-specific path — this cell only runs on a machine where
# this exact file exists. Consider deriving it from a configurable checkpoint directory.
checkpoint_path = "/om2/user/qsimeon/worm-graph/logs/hydra/2023_12_25_06_45_54/exp10/train/checkpoints/model_best.pt"
# Restore the model from the checkpoint file via the project helper.
model = load_model_from_checkpoint(checkpoint_path)

In [None]:
# `chunk` is whatever the synthetic-dataset loop left bound when it finished.
print(tokenizer.decode(chunk), end="\n\n")

embedding = embedding.to(DEVICE)
# Token ids with a leading batch axis of size one.
idx = torch.LongTensor(chunk)[None, :].to(DEVICE)
input = embedding(idx)
# All-True mask with one entry per embedding channel (last axis of `input`).
mask = torch.ones(1, input.shape[-1], dtype=torch.bool).to(DEVICE)

print(idx.shape, input.shape, mask.shape, end="\n\n")

In [None]:
# Generate a continuation of the prompt and show the prompt followed by
# only the newly generated tail.
max_new_tokens = 100
idx_gen = model.transformer_generate(idx, max_new_tokens, mask, embedding)

print(
    tokenizer.decode(idx[0].tolist()),
    tokenizer.decode(idx_gen[0].tolist()[-max_new_tokens:]),
    sep="\n\n",
    end="\n\n",
)