In [15]:
# @title Imports

import math
import torch
import random
import nbimporter
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from datasets import load_dataset
from transformers import AutoTokenizer
from models._utils import load_model_from_checkpoint
from utils import NEURONS_302, DEVICE, MAX_TOKEN_LEN
from tokenizers.pre_tokenizers import WhitespaceSplit
from preprocess._utils import smooth_data_preprocess, reshape_calcium_data
from CreateSyntheticDataset import (
    save_synthetic_dataset,
    plot_neural_signals,
    plot_3d_trajectory,
)

#### Create a synthetic dataset where the neural activity is the embeddings of tokens from the tiny Shakespeare dataset.

In [2]:
def pre_tokenize_and_chunk(
    text, max_length=510
):  # slightly less than 512 to account for special tokens
    """Split ``text`` on whitespace and greedily pack the words into chunks.

    Words are re-joined with single spaces. A chunk is closed as soon as
    appending the next word (plus one separating space) would push its length
    past ``max_length``. Note that the budget is measured in *characters* of
    the chunk string, not in tokenizer tokens.

    Args:
        text: The string to split.
        max_length: Character budget per chunk. A single word longer than the
            budget is never split; it ends up as a chunk of its own.

    Returns:
        A list of non-empty chunk strings in document order. Empty or
        whitespace-only input yields an empty list.
    """
    pre_tokenizer = WhitespaceSplit()
    pre_tokens = pre_tokenizer.pre_tokenize_str(text)

    chunks = []
    current_words = []
    current_len = 0  # always equals len(" ".join(current_words))
    for token, _offsets in pre_tokens:
        if current_len + len(token) + 1 > max_length:
            # Flush only when something has been collected. Without this guard
            # an over-long first word would emit an empty "" chunk.
            if current_words:
                chunks.append(" ".join(current_words))
            current_words = [token]
            current_len = len(token)
        else:
            # The separating space is only paid for when a word precedes us.
            current_len += len(token) + (1 if current_words else 0)
            current_words.append(token)
    if current_words:
        chunks.append(" ".join(current_words))

    return chunks


def tokenize_and_chunk(examples, tokenizer=None):
    """Chunk every text in a batch and tokenize each chunk separately.

    Written to be used with ``Dataset.map(..., batched=True)``: one input row
    may expand into many output rows (one per chunk).

    Args:
        examples: Mapping with a ``"text"`` entry holding an iterable of strings.
        tokenizer: Callable tokenizer. When ``None``, the pretrained
            ``"bert-base-uncased"`` tokenizer is loaded via ``AutoTokenizer``.

    Returns:
        A dict with parallel lists ``"text"``, ``"input_ids"`` and
        ``"attention_mask"``, holding one entry per chunk.
    """
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    chunk_texts, chunk_ids, chunk_masks = [], [], []
    for document in examples["text"]:
        for piece in pre_tokenize_and_chunk(document):
            encoded = tokenizer(
                piece,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            # The tokenizer returns a batch of size one; unwrap row 0.
            chunk_texts.append(piece)
            chunk_ids.append(encoded["input_ids"][0].tolist())
            chunk_masks.append(encoded["attention_mask"][0].tolist())

    return {
        "text": chunk_texts,
        "input_ids": chunk_ids,
        "attention_mask": chunk_masks,
    }

In [3]:
def create_synthetic_dataset_shakespeare(
    max_timesteps: int = 1000,
    num_signals: int = 302,
    num_named_neurons: int = 1,
    dataset_name: str = "Shakespeare0000",
    seed: int = None,
):
    """Build a synthetic "worm" dataset from tiny-Shakespeare token embeddings.

    The text is chunked and tokenized with the ``bert-base-uncased`` tokenizer,
    each token id is looked up in a randomly initialised half-precision
    ``torch.nn.Embedding`` of width ``num_signals``, and the resulting rows are
    treated as calcium activity (one row per time step, one column per signal).
    Consecutive chunks are accumulated until at least ``max_timesteps`` rows
    have been collected; that stack becomes one worm. Tokens left over after
    the last complete worm are discarded.

    Side effects:
        * Rebinds the module-level names ``tokenizer`` and ``embedding`` so
          later cells can reuse them.
        * Prints the first ten weights of embedding row 0.
        * When ``seed`` is given, reseeds the global ``random`` and ``torch``
          generators.

    Args:
        max_timesteps: Minimum number of time steps (tokens) per worm.
        num_signals: Embedding width, i.e. number of columns per worm.
            Indices below this value are used to index ``NEURONS_302``
            (presumably 302 names -- TODO confirm).
        num_named_neurons: How many neurons keep their real name; the others
            are keyed by their stringified index. Must not exceed
            ``num_signals`` (``random.sample`` raises ``ValueError`` otherwise).
        dataset_name: Label stored in each worm record.
        seed: Optional seed for reproducible embeddings and named-neuron
            selection. ``None`` (default) leaves the global RNG state alone,
            which matches the previous behaviour.

    Returns:
        A dict mapping ``"worm{i}"`` to that worm's record, after it has been
        passed through ``reshape_calcium_data``.
    """
    # Want to access the tokenizer and the embedding table outside this function
    global tokenizer, embedding  # DEBUG
    # Optional reproducibility: the embedding initialisation uses torch's
    # global generator and the named-neuron choice uses `random`.
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
    # Load the Shakespeare dataset
    text_dataset = load_dataset("tiny_shakespeare")
    # Create a tokenizer
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    # Apply the tokenization and chunking to each split
    text_dataset = text_dataset.map(
        tokenize_and_chunk, batched=True, fn_kwargs=dict(tokenizer=tokenizer)
    )
    # Create an embedding table
    embedding_dim = num_signals
    embedding = torch.nn.Embedding(
        num_embeddings=tokenizer.vocab_size,
        embedding_dim=embedding_dim,
        dtype=torch.half,
    )
    print(embedding.weight[0, :10].detach().cpu().numpy(), end="\n\n")  # DEBUG
    # Extract all splits of the text dataset
    train_tokens = text_dataset["train"]["input_ids"]
    validation_tokens = text_dataset["validation"]["input_ids"]
    test_tokens = text_dataset["test"]["input_ids"]
    all_tokens = train_tokens + validation_tokens + test_tokens
    # Set up for creating the synthetic dataset
    num_unknown_neurons = num_signals - num_named_neurons
    smooth_method = None
    dataset = dict()
    # Create data for as many worms as possible
    worm_idx = 0
    calcium_data = []
    total_time = 0
    for chunk in all_tokens:
        # Hard cap: stop once worm index 200 has been emitted (<= 201 worms)
        if worm_idx > 200:
            break
        # One embedding row per token in the chunk
        embd_data = embedding(torch.LongTensor(chunk)).detach().numpy()
        calcium_data.append(embd_data)
        total_time += embd_data.shape[0]
        if total_time >= max_timesteps:
            # Enough time steps collected: finalise this worm
            calcium_data = np.vstack(calcium_data)
            # Time axis is simply 0, 1, ..., total_time - 1 as a column vector
            time_in_seconds = np.arange(total_time).reshape(-1, 1)
            dt = np.gradient(time_in_seconds, axis=0)
            dt[dt == 0] = np.finfo(float).eps  # guard the division below
            resample_dt = np.round(np.median(dt).item(), 2)
            # Finite-difference derivative of the signals w.r.t. time
            residual_calcium = np.gradient(calcium_data, axis=0) / dt
            smooth_calcium_data = smooth_data_preprocess(
                calcium_data, time_in_seconds, smooth_method
            )
            smooth_residual_calcium = smooth_data_preprocess(
                residual_calcium, time_in_seconds, smooth_method
            )
            # Pick which neurons keep their real names for this worm
            named_neuron_indices = random.sample(range(num_signals), num_named_neurons)
            named_neurons = set(NEURONS_302[i] for i in named_neuron_indices)
            # Named neurons are keyed by name, all others by their index string
            neuron_to_idx = {
                (neuron) if neuron in named_neurons else str(idx): idx
                for idx, neuron in enumerate(NEURONS_302)
            }
            idx_to_neuron = {idx: neuron for neuron, idx in neuron_to_idx.items()}

            worm_data = dict()

            worm_data["worm"] = f"worm{worm_idx}"
            worm_data["dataset"] = dataset_name
            worm_data["smooth_method"] = smooth_method
            worm_data["calcium_data"] = calcium_data
            worm_data["smooth_calcium_data"] = smooth_calcium_data
            worm_data["residual_calcium"] = residual_calcium
            worm_data["smooth_residual_calcium"] = smooth_residual_calcium
            worm_data["max_timesteps"] = total_time
            worm_data["time_in_seconds"] = time_in_seconds
            worm_data["dt"] = dt
            worm_data["resample_median_dt"] = resample_dt
            worm_data["neuron_to_idx"] = neuron_to_idx
            worm_data["idx_to_neuron"] = idx_to_neuron
            worm_data["num_neurons"] = num_signals
            worm_data["num_named_neurons"] = num_named_neurons
            worm_data["num_unknown_neurons"] = num_unknown_neurons

            worm_data = reshape_calcium_data(worm_data)
            dataset[f"worm{worm_idx}"] = worm_data

            # Reset the accumulators for the next worm
            worm_idx += 1
            calcium_data = []
            total_time = 0

    return dataset

In [4]:
# Dataset generation settings (kept as notebook-level names; other cells read them)
dataset_name = "Shakespeare0000"
num_signals = 302
num_named_neurons = num_signals  # 50 #  DEBUG
max_timesteps = 3000

# Build the synthetic worms from the settings above
dataset = create_synthetic_dataset_shakespeare(
    max_timesteps, num_signals, num_named_neurons, dataset_name
)

# Number of worms that were generated
num_worms = len(dataset)

# Persist the dataset under the processed/neural directory
save_synthetic_dataset(f"processed/neural/{dataset_name}.pickle", dataset)

[ 0.5547  0.1342 -2.775   0.664   0.1761  0.4463  0.3494  1.41   -0.1451
 -1.047 ]



In [5]:
# # Selecting a worm and all the neurons to plot
# num_worms = len(dataset)
# worm_idx = random.choice([f"worm{i}" for i in range(num_worms)])
# neuron_idx = [idx for idx in dataset[worm_idx]["slot_to_neuron"].keys()][
#     :num_named_neurons
# ]

# # Plotting dataset
# plot_neural_signals(
#     data=dataset[worm_idx]["calcium_data"],
#     time_tensor=dataset[worm_idx]["time_in_seconds"],
#     neuron_idx=neuron_idx,
#     yax_limit=False,
#     suptitle=f"{dataset_name} - {worm_idx}",
# )

# # Visualize covariance matrix
# data = dataset[worm_idx]["smooth_calcium_data"]
# mask = dataset[worm_idx]["named_neurons_mask"]
# neurons = sorted(dataset[worm_idx]["named_neuron_to_slot"])

# X = data[:, mask].numpy()
# n = X.shape[0]
# X_bar = X - np.mean(X, axis=0)
# cov = 1 / (n - 1) * X_bar.T @ X_bar

# plt.figure()
# ax = sns.heatmap(cov, cmap="coolwarm", xticklabels=neurons, yticklabels=neurons)
# ax.set_title(f"Covariance matrix : {dataset_name}, {worm_idx}")
# plt.show()

# # Plotting 3D trajectory
# plot_3d_trajectory(
#     X, axis_labels=tuple(neurons), title=f"{dataset_name} neural trajectory"
# )

In [6]:
# @title HuggingFace Tokenizers Demo

# `encode` returns only the list of token ids; calling the tokenizer directly
# returns the full encoding (ids plus the accompanying masks).
output = tokenizer.encode("Welcome to the 🤗 Tokenizers library.")
encoding = tokenizer("We are very happy to show you the 🤗 Transformers library.")

# Show the ids, their round-trip decoding, and the full encoding,
# each followed by a blank line.
print(output, tokenizer.decode(output), encoding, sep="\n\n", end="\n\n")

[101, 6160, 2000, 1996, 100, 19204, 17629, 2015, 3075, 1012, 102]

[CLS] welcome to the [UNK] tokenizers library. [SEP]

{'input_ids': [101, 2057, 2024, 2200, 3407, 2000, 2265, 2017, 1996, 100, 19081, 3075, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}



In [7]:
# @title HuggingFace Datasets Demo

text_dataset = load_dataset("tiny_shakespeare")
print(text_dataset, end="\n\n")

# For every split report: container type, number of rows, the type of the
# first row, and the length of the first row.
for split_name in ("train", "validation", "test"):
    split_texts = text_dataset[split_name]["text"]
    print(
        f"{split_name}:",
        type(split_texts),
        len(split_texts),
        type(split_texts[0]),
        len(split_texts[0]),
        end="\n\n",
    )

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 1
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 1
    })
    test: Dataset({
        features: ['text'],
        num_rows: 1
    })
})

train: <class 'list'> 1 <class 'str'> 1003854

validation: <class 'list'> 1 <class 'str'> 55770

test: <class 'list'> 1 <class 'str'> 55770



In [8]:
# @title Tokenization and Chunking Demo
# @markdown Apply the tokenization and chunking to each split.

# Reuses the notebook-level `tokenizer` so every chunk is encoded consistently.
text_dataset = text_dataset.map(
    tokenize_and_chunk,
    batched=True,
    fn_kwargs={"tokenizer": tokenizer},
)

# Show the new dataset layout, the first training chunk as text, and the same
# chunk after an encode/decode round trip.
print(
    text_dataset,
    text_dataset["train"]["text"][0],
    tokenizer.decode(text_dataset["train"]["input_ids"][0]),
    sep="\n\n",
    end="\n\n",
)

DatasetDict({
    train: Dataset({
        features: ['text', 'input_ids', 'attention_mask'],
        num_rows: 1963
    })
    validation: Dataset({
        features: ['text', 'input_ids', 'attention_mask'],
        num_rows: 109
    })
    test: Dataset({
        features: ['text', 'input_ids', 'attention_mask'],
        num_rows: 109
    })
})

First Citizen: Before we proceed any further, hear me speak. All: Speak, speak. First Citizen: You are all resolved rather to die than to famish? All: Resolved. resolved. First Citizen: First, you know Caius Marcius is chief enemy to the people. All: We know't, we know't. First Citizen: Let us kill him, and we'll have corn at our own price. Is't a verdict? All: No more talking on't; let it be done: away, away! Second Citizen: One word, good citizens. First Citizen: We are accounted poor citizens, the

[CLS] first citizen : before we proceed any further, hear me speak. all : speak, speak. first citizen : you are all resolved rather to die than

In [None]:
# @title Initialize test input

# Use the last training chunk as the prompt
chunk = text_dataset["train"]["input_ids"][-1]
print(tokenizer.decode(chunk), end="\n\n")

# Move the shared embedding table to the compute device and peek at row 0
embedding = embedding.to(DEVICE)
print(embedding.weight.detach()[0, :10].cpu().numpy(), end="\n\n")  # DEBUG

# Token ids with a leading batch axis of size one
idx = torch.tensor(chunk, dtype=torch.long)[None, :].to(DEVICE)
# NOTE: `input` shadows the Python builtin; the name is kept as-is here.
input = embedding(idx)
# All-True mask with one entry per embedding dimension, plus a batch axis
mask = torch.ones((1, input.shape[-1]), dtype=torch.bool).to(DEVICE)
print(idx.shape, input.shape, mask.shape, end="\n\n")

In [9]:
# @title TODO: Train a Transformer model on the Dataset


class TransformerModel(torch.nn.Module):
    """Transformer-encoder language model over token indices.

    Pipeline in ``forward``: token embedding (scaled by ``sqrt(d_model)``) ->
    positional encoding -> ``torch.nn.TransformerEncoder`` -> linear projection
    to ``ntoken`` logits.
    """

    def __init__(
        self,
        ntoken: int,
        d_model: int,
        nhead: int,
        d_hid: int,
        nlayers: int,
        dropout: float = 0.5,
    ):
        """
        Arguments:
            ntoken: vocabulary size (rows of the embedding, width of the output)
            d_model: embedding / model width
            nhead: number of attention heads per encoder layer
            d_hid: feed-forward width inside each encoder layer
            nlayers: number of stacked encoder layers
            dropout: dropout probability for the positional encoding and the
                encoder layers
        """
        super().__init__()
        self.model_type = "Transformer"
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = torch.nn.TransformerEncoderLayer(
            d_model, nhead, d_hid, dropout
        )
        self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layers, nlayers)
        self.embedding = torch.nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.linear = torch.nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        """Uniformly initialise embedding and output weights; zero the output bias."""
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: torch.Tensor, src_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Arguments:
            src: Tensor, shape ``[seq_len, batch_size]``
            src_mask: Tensor, shape ``[seq_len, seq_len]``. When omitted, a
                square causal mask is generated on the same device as ``src``.

        Returns:
            output Tensor of shape ``[seq_len, batch_size, ntoken]``
        """
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        if src_mask is None:
            # Generate a square causal mask for the sequence. The masked
            # positions are filled with float('-inf'), unmasked positions with
            # float(0.0). The mask follows the input's device (rather than a
            # global device constant) so the model works wherever it is placed.
            src_mask = torch.nn.Transformer.generate_square_subsequent_mask(
                src.size(0)
            ).to(src.device)
        output = self.transformer_encoder(src, src_mask)
        output = self.linear(output)
        return output

    @torch.no_grad()
    def generate(
        self,
        idx: torch.LongTensor,
        max_new_tokens: int,
        mask: torch.Tensor,
        embedding: torch.nn.Module,
        temperature=1.0,
        top_k=None,
    ):
        """
        Special generate method for the Transformer model.
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Since we trained the model to directly predict the next token we take the index as the argmin
        over the distance between the output and the embedding table.

        NOTE(review): ``temperature`` and ``top_k`` are accepted but not used.
        NOTE(review): this method hands already-embedded float inputs and
        ``mask`` to ``self(...)``, whereas ``forward`` above performs its own
        index lookup and treats its second argument as ``src_mask``. As written
        the two look incompatible -- confirm the intended model interface
        before relying on this method.
        """
        # Set model to evaluation mode
        self.eval()

        # Place embedding table on the same device as idx
        embedding = embedding.to(idx.device)

        # Loop through time
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= MAX_TOKEN_LEN else idx[:, -MAX_TOKEN_LEN:]
            # get appropriate input for the model based on idx
            input = embedding(idx_cond)  # shape (batch_size, seq_len, neurons)
            mask = mask.to(input.device)
            # forward the model to get the output
            outputs = self(input, mask)
            # we only care about the last time step
            outputs = outputs[:, -1, :]
            # Compute the Euclidean distances between output and embedding table
            distances = torch.norm(outputs - embedding.weight, dim=1)
            # Find the index of the minimizer
            # (the ``view(1, 1)`` below assumes a batch size of one)
            idx_next = torch.argmin(distances).view(1, 1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx


class PositionalEncoding(torch.nn.Module):
    """Add fixed sinusoidal position information to a sequence, then apply dropout.

    Even feature columns hold ``sin(position * div_term)`` and odd columns hold
    ``cos(position * div_term)``. The table is precomputed for ``max_len``
    positions and stored as the non-trainable buffer ``pe`` with shape
    ``[max_len, 1, d_model]``.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        Arguments:
            d_model: feature width of the inputs (odd values are supported)
            dropout: dropout probability applied after adding the encoding
            max_len: number of positions to precompute
        """
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim div_term to the number of cosine columns. For even d_model
        # the slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``

        Returns:
            Tensor of the same shape as ``x``.
        """
        # The size-1 middle axis of `pe` broadcasts over the batch dimension.
        x = x + self.pe[: x.size(0)]
        return self.dropout(x)

In [13]:
# @title Load a trained model

# Absolute path to a saved checkpoint. It only resolves on the filesystem where
# that checkpoint lives, so it must be edited when running elsewhere.
checkpoint_path = "/om2/user/qsimeon/worm-graph/logs/hydra/2023_12_25_16_27_08/exp0/train/checkpoints/model_best.pt"
# `load_model_from_checkpoint` comes from the project's `models._utils`; the
# object it returns is used below through `model.transformer_generate(...)`.
model = load_model_from_checkpoint(checkpoint_path)

In [14]:
# @title Generate new tokens
max_new_tokens = 100
# Continue the prompt `idx` using the loaded model and the shared embedding table
idx_gen = model.transformer_generate(idx, max_new_tokens, mask, embedding)

# Print the prompt first, then only the last `max_new_tokens` generated ids
print(tokenizer.decode(idx[0].tolist()), end="\n\n")
print(tokenizer.decode(idx_gen[0].tolist()[-max_new_tokens:]), end="\n\n")

[CLS] and for your love to her lead apes in hell. talk not to me : i will go sit and weep till i can find occasion of revenge. baptista : was ever gentleman thus grieved as i? but who comes here [SEP]

[CLS] prompt : prompt pigeons, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt, prompt,

