In [1]:
# @title Imports

import os
import math
import time
import torch
import nbimporter

from transformers import AutoTokenizer
from tempfile import TemporaryDirectory
from utils import DEVICE, BLOCK_SIZE
from models._utils import NeuralTransformer
from datasets import load_dataset as load_hf_dataset
from CreateSyntheticDataset import tokenize_and_chunk

CUDA device found.


## Language modeling with tiny Shakespeare

In [2]:
# @title HuggingFace Tokenizers
# @markdown Note there are two ways to call the tokenizer's encoder.

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Explicit entry point: `encode` returns a plain list of token ids.
expl_text = "Welcome to the 🤗 Tokenizers library."
expl_encode = tokenizer.encode(expl_text)

# Implicit entry point: calling the tokenizer returns a dict-like object whose
# "input_ids" entry holds the same kind of id list (its keys are printed below).
impl_text = "We are very happy to show you the 🤗 Transformers library."
impl_encode = tokenizer(impl_text)

print(
    "Calling `tokenizer.encode(text)`:"
    f"\n\ttext: {expl_text}"
    f"\n\ttokenized: {expl_encode}"
    f"\n\tdecoded: {tokenizer.decode(expl_encode)}",
    end="\n\n",
)
print(
    "Calling `tokenizer(text)`:"
    f"\n\tobject.keys(): {impl_encode.keys()}"
    f"\n\ttext: {impl_text}"
    f"\n\ttokenized: {impl_encode['input_ids']}"
    f"\n\tdecoded: {tokenizer.decode(impl_encode['input_ids'])}",
    end="\n\n",
)

Calling `tokenizer.encode(text)`:
	text: Welcome to the 🤗 Tokenizers library.
	tokenized: [101, 6160, 2000, 1996, 100, 19204, 17629, 2015, 3075, 1012, 102]
	decoded: [CLS] welcome to the [UNK] tokenizers library. [SEP]

Calling `tokenizer(text)`:
	object.keys(): dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
	text: We are very happy to show you the 🤗 Transformers library.
	tokenized: [101, 2057, 2024, 2200, 3407, 2000, 2265, 2017, 1996, 100, 19081, 3075, 1012, 102]
	decoded: [CLS] we are very happy to show you the [UNK] transformers library. [SEP]



In [3]:
# @title HuggingFace Datasets

text_dataset = load_hf_dataset("tiny_shakespeare")
print(text_dataset, end="\n\n")
print("~~~" * 50, end="\n\n")

# Every split stores a single row whose "text" field is one long string, so
# report the column container (type, row count) and that string (type, length).
for split_name in ("train", "validation", "test"):
    split_texts = text_dataset[split_name]["text"]
    print(
        f"{split_name}:",
        type(split_texts),
        len(split_texts),
        type(split_texts[0]),
        len(split_texts[0]),
        end="\n\n",
    )

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 1
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 1
    })
    test: Dataset({
        features: ['text'],
        num_rows: 1
    })
})

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

train: <class 'list'> 1 <class 'str'> 1003854

validation: <class 'list'> 1 <class 'str'> 55770

test: <class 'list'> 1 <class 'str'> 55770



In [4]:
# @title Tokenization and Chunking
# @markdown Apply the tokenization and chunking to each split.

text_dataset = text_dataset.map(
    tokenize_and_chunk,
    batched=True,
    fn_kwargs={"tokenizer": tokenizer},
)
print(text_dataset, end="\n\n")
print("~~~" * 50, end="\n\n")

# Pull the tokenized train column out once; every print below inspects it.
train_ids = text_dataset["train"]["input_ids"]
first_chunk = train_ids[0]

# Structure of the tokenized column: a list of chunks, each a list of ints.
print(
    "text_dataset['train']['input_ids']:\n",
    "\ttype:",
    type(train_ids),
    "\n\tlength:",
    len(train_ids),
    end="\n\n",
)
print(
    "text_dataset['train']['input_ids'][0]:\n",
    "\ttype:",
    type(first_chunk),
    "\n\tlength:",
    len(first_chunk),
    end="\n\n",
)
print(
    "text_dataset['train']['input_ids'][0][0]:\n",
    "\ttype:",
    type(first_chunk[0]),
    "\n\tvalue:",
    first_chunk[0],
    end="\n\n",
)
print("~~~" * 50, end="\n\n")

# Round trip for the first row: raw text -> token ids -> decoded text.
first_text = text_dataset["train"]["text"][0]
print(f"Original sequence (text):\n\t{first_text}", end="\n\n")
print(f"Encoded sequence (tokens):\n\t {first_chunk}", end="\n\n")
print(f"Decoded sequence (tokens):\n\t {tokenizer.decode(first_chunk)}", end="\n\n")

DatasetDict({
    train: Dataset({
        features: ['text', 'input_ids', 'attention_mask'],
        num_rows: 1963
    })
    validation: Dataset({
        features: ['text', 'input_ids', 'attention_mask'],
        num_rows: 109
    })
    test: Dataset({
        features: ['text', 'input_ids', 'attention_mask'],
        num_rows: 109
    })
})

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

text_dataset['train']['input_ids']:
 	type: <class 'list'> 
	length: 1963



text_dataset['train']['input_ids'][0]:
 	type: <class 'list'> 
	length: 135

text_dataset['train']['input_ids'][0][0]:
 	type: <class 'int'> 
	value: 101

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Original sequence (text):
	First Citizen: Before we proceed any further, hear me speak. All: Speak, speak. First Citizen: You are all resolved rather to die than to famish? All: Resolved. resolved. First Citizen: First, you know Caius Marcius is chief enemy to the people. All: We know't, we know't. First Citizen: Let us kill him, and we'll have corn at our own price. Is't a verdict? All: No more talking on't; let it be done: away, away! Second Citizen: One word, good citizens. First Citizen: We are accounted poor citizens, the

Encoded sequence (tokens):
	 [101, 2034, 6926, 1024, 2077, 2057, 10838, 2151, 2582, 1010, 2963, 2033, 3713, 1012, 2035, 1024, 3713, 1010, 3713, 1012, 2034, 692

In [5]:
# @title Define the Transformer model


class PositionalEncoding(torch.nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input, then apply dropout.

    Even feature columns receive ``sin(pos * w_i)`` and odd columns receive
    ``cos(pos * w_i)``, with frequencies ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: Feature dimension of the inputs (odd values are supported).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest sequence length supported by the precomputed table.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        # One frequency per (sin, cos) column pair: ceil(d_model / 2) entries.
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(1, max_len, d_model)  # batch_first=True
        pe[0, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so the
        # cosine half uses only the first d_model // 2 frequencies.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``

        Returns:
            Tensor of the same shape with position encodings added and dropout applied.

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_len`` given at construction.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(1)}"
            )
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)


class TransformerModel(torch.nn.Module):
    """Causal language model built from a stack of TransformerEncoder layers.

    Token ids are embedded, scaled by ``sqrt(d_model)``, given sinusoidal position
    encodings, passed through ``nlayers`` self-attention layers under a square
    subsequent (causal) mask, and projected back to ``ntoken`` logits.

    Args:
        ntoken: Vocabulary size (embedding rows and number of output logits).
        d_model: Embedding / model dimension; must be divisible by ``nhead``.
        nhead: Number of attention heads.
        d_hid: Width of the feed-forward sublayer.
        nlayers: Number of stacked encoder layers.
        dropout: Dropout probability for the encoder layers and position encoding.
    """

    def __init__(
        self,
        ntoken: int,
        d_model: int,
        nhead: int,
        d_hid: int,
        nlayers: int,
        dropout: float = 0.5,
    ):
        super().__init__()
        self.model_type = "Transformer"
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = torch.nn.TransformerEncoderLayer(
            d_model, nhead, d_hid, dropout, batch_first=True
        )
        self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layers, nlayers)
        self.embedding = torch.nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.linear = torch.nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        """Uniformly initialise embedding and output weights in [-0.1, 0.1]; zero the output bias."""
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: torch.Tensor, src_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Arguments:
            src: Tensor, shape [batch_size, seq_len]
            src_mask: Tensor, shape [seq_len, seq_len]; a causal mask is built when None.

        Returns:
            output Tensor of shape [batch_size, seq_len, ntoken]
        """
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        if src_mask is None:
            # Causal mask over the sequence axis (dim 1, since batch_first=True),
            # created on the same device as the activations.
            src_mask = torch.nn.Transformer.generate_square_subsequent_mask(
                src.size(1)
            ).to(src.device)
        output = self.transformer_encoder(src, src_mask)
        output = self.linear(output)
        return output

    @torch.no_grad()
    def generate(
        self,
        idx: torch.LongTensor,
        max_new_tokens: int,
        temperature=1.0,
        top_k=None,
    ):
        """Autoregressively sample ``max_new_tokens`` tokens after the prompt ``idx``.

        At each step the context is cropped to the last ``BLOCK_SIZE`` tokens, the
        logits of the final position are divided by ``temperature``, optionally
        restricted to the ``top_k`` largest, turned into probabilities with softmax,
        and one token per batch row is drawn with ``torch.multinomial``.

        Args:
            idx: LongTensor of shape (batch, t) holding the conditioning token ids.
            max_new_tokens: Number of tokens to append.
            temperature: Positive scale applied to the logits before sampling.
            top_k: If given, sample only among the ``top_k`` most likely tokens.

        Returns:
            LongTensor of shape (batch, t + max_new_tokens).

        The module runs in eval mode while sampling; its previous train/eval
        mode is restored before returning.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # if the sequence context is growing too long we must crop it at block_size
                idx_cond = idx if idx.size(1) <= BLOCK_SIZE else idx[:, -BLOCK_SIZE:]
                # pluck the logits at the final step and scale by desired temperature
                logits = self(idx_cond)[:, -1, :] / temperature
                # optionally crop the logits to only the top k options
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("Inf")
                # apply softmax to convert logits to (normalized) probabilities
                probs = torch.nn.functional.softmax(logits, dim=-1)
                # multinomial already returns shape (batch, 1), so every batch row
                # keeps its own sample (a fixed view(1, 1) would fail for batch > 1).
                idx_next = torch.multinomial(probs, num_samples=1)
                # append sampled index to the running sequence and continue
                idx = torch.cat((idx, idx_next), dim=1)
        finally:
            self.train(was_training)

        return idx

In [None]:
# @title Instantiate the model

ntokens = tokenizer.vocab_size  # vocabulary size (embedding rows / output logits)
emsize = 302  # embedding dimension (d_model)
d_hid = 512  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 1  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2  # number of heads in nn.MultiheadAttention (NOTE: nhead must be a divisor of emsize)
dropout = 0.1  # dropout probability
# Fail fast with a readable message instead of an error deep inside MultiheadAttention.
assert emsize % nhead == 0, f"emsize ({emsize}) must be divisible by nhead ({nhead})"
model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(DEVICE)

In [None]:
# @title Train the Transformer model

criterion = torch.nn.CrossEntropyLoss()
lr = 5.0  # initial learning rate for plain SGD
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# Multiply the learning rate by 0.99 on every scheduler.step() call
# (StepLR documents step_size as an int).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)


def train(model: torch.nn.Module) -> None:
    """Run one training epoch of the language model over ``text_dataset["train"]``.

    Each chunk of token ids is fed as a single-sequence batch: the model sees
    tokens ``[:-1]`` and is trained with ``criterion`` to predict tokens ``[1:]``.
    Relies on the notebook-level ``text_dataset``, ``criterion``, ``optimizer``,
    ``scheduler`` (only to read the current learning rate for logging),
    ``DEVICE`` and the module-level ``epoch`` counter (read for logging).

    Args:
        model: The language model, updated in place.
    """
    model.train()  # turn on train mode
    log_interval = 300
    window_loss = 0.0  # summed loss since the last log line
    window_batches = 0  # number of batches contributing to window_loss
    start_time = time.time()

    # Reading a column from a HF Dataset materializes it, so fetch it once
    # instead of on every iteration.
    train_ids = text_dataset["train"]["input_ids"]
    num_batches = len(train_ids)
    for batch, tokens in enumerate(train_ids):
        # parse into input and next-token target
        inputs = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(
            0
        )  # ``[batch_size=1, seq_len]``
        target = torch.tensor(
            tokens[1:], dtype=torch.long, device=DEVICE
        )  # ``[batch_size=1 * seq_len]``
        # forward pass
        output = model(inputs)  # ``[batch_size=1, seq_len, ntokens]``
        output_flat = output.reshape(-1, output.size(-1))  # ``[batch_size=1 * seq_len, ntokens]``
        # backpropagation step
        optimizer.zero_grad()
        loss = criterion(output_flat, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        window_loss += loss.item()
        window_batches += 1
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            # Average over the batches actually seen: the first window spans
            # batches 0..log_interval, i.e. log_interval + 1 batches.
            ms_per_batch = (time.time() - start_time) * 1000 / window_batches
            cur_loss = window_loss / window_batches
            ppl = math.exp(cur_loss)
            print(
                f"| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | "
                f"lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | "
                f"loss {cur_loss:5.2f} | ppl {ppl:8.2f}"
            )
            window_loss = 0.0
            window_batches = 0
            start_time = time.time()


def evaluate(model: torch.nn.Module) -> float:
    """Return the mean per-chunk ``criterion`` loss of ``model`` on ``text_dataset["validation"]``.

    Each chunk is scored independently as a single-sequence batch (inputs
    ``[:-1]``, next-token targets ``[1:]``) and the per-chunk losses are averaged.

    Args:
        model: The language model to evaluate (switched to eval mode).

    Raises:
        ValueError: if the validation split has no chunks.
    """
    model.eval()  # turn on evaluation mode
    # Fetch the column once; indexing a HF Dataset column materializes it.
    val_ids = text_dataset["validation"]["input_ids"]
    num_batches = len(val_ids)
    if num_batches == 0:
        raise ValueError("validation split is empty; cannot compute a mean loss")

    total_loss = 0.0
    with torch.no_grad():
        for tokens in val_ids:
            inputs = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
            target = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE)
            output = model(inputs)
            output_flat = output.reshape(-1, output.size(-1))
            total_loss += criterion(output_flat, target).item()
    return total_loss / num_batches

In [None]:
# @markdown Loop over epochs. Save the model if the validation loss is the best we’ve seen so far. Adjust the learning rate after each epoch.

best_val_loss = float("inf")
epochs = 1

final_model_params_path = os.path.join("../models/", "shakespeare_transformer_model.pt")
if os.path.exists(final_model_params_path):
    print("Loading a previously saved model checkpoint...")
    # map_location lets a checkpoint written on another device load on DEVICE.
    model.load_state_dict(torch.load(final_model_params_path, map_location=DEVICE))

with TemporaryDirectory() as tempdir:
    best_model_params_path = os.path.join(tempdir, "best_model_params.pt")

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        train(model)
        val_loss = evaluate(model)
        val_ppl = math.exp(val_loss)
        elapsed = time.time() - epoch_start_time
        print("-" * 89)
        print(
            f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
            f"valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}"
        )
        print("-" * 89)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_params_path)

        scheduler.step()

    # The temporary checkpoint exists only if some epoch beat best_val_loss
    # (not the case when epochs == 0 or every val_loss is NaN).
    if os.path.exists(best_model_params_path):
        print("Loading and saving the new best model...")
        model.load_state_dict(torch.load(best_model_params_path, map_location=DEVICE))
        # Create the models directory on a fresh checkout so torch.save succeeds.
        os.makedirs(os.path.dirname(final_model_params_path), exist_ok=True)
        torch.save(model.state_dict(), final_model_params_path)  # save the best model for later use
    else:
        print("No epoch improved the validation loss; no checkpoint was saved.")

In [None]:
# @title Generate new text using test input

# Prompt with the first test chunk and sample a continuation, drawing each new
# token from the 5 most likely candidates.
max_new_tokens = 100
idx = torch.tensor(
    [text_dataset["test"]["input_ids"][0]], dtype=torch.long, device=DEVICE
)  # ``[batch_size=1, prompt_len]``
idx_gen = model.generate(idx, max_new_tokens, top_k=5)

print(idx.shape, idx_gen.shape, end="\n\n")
# Decoded prompt, then only the freshly generated tail.
print(tokenizer.decode(idx[0].tolist()), end="\n\n")
print(tokenizer.decode(idx_gen[0, -max_new_tokens:].tolist()), end="\n\n")

## Neural data modeling

In [6]:
# @title Create neural datasets
# @markdown A synthetic dataset in which the neural activity is the embedding of each token from tiny Shakespeare.

ntokens = tokenizer.vocab_size
emsize = 302
d_hid = 512

# The "neural activity" is a random, frozen embedding of each token. Seed the
# RNG first so the same embedding table (and hence the same dataset) is rebuilt
# on every run. Without a seed, a checkpoint reloaded from ../models/ was trained
# against a different random embedding than the one generated here. Checkpoints
# saved before this seed was introduced should be discarded.
EMBEDDING_SEED = 0
torch.manual_seed(EMBEDDING_SEED)
embedding = torch.nn.Embedding(ntokens, emsize, _freeze=True)


def embed_split(split: str) -> list:
    """Embed every chunk of ``text_dataset[split]["input_ids"]`` into a ``[seq_len, emsize]`` tensor.

    A single embedding lookup on the whole id sequence gives the same rows as
    embedding each token separately and stacking the results.
    """
    with torch.no_grad():
        return [
            embedding(torch.LongTensor(sequence))
            for sequence in text_dataset[split]["input_ids"]
        ]


train_dataset = embed_split("train")
validation_dataset = embed_split("validation")
test_dataset = embed_split("test")

# get a test sample for an example
data = test_dataset[0].unsqueeze(0)
print("data: ", data.shape, data.dtype, data.requires_grad, data.device, end="\n\n")

mask = torch.ones(emsize, dtype=torch.bool).unsqueeze(0)
print("mask:", mask.shape, mask.dtype, mask.requires_grad, mask.device, end="\n\n")

data:  torch.Size([1, 132, 302]) torch.float32 False cpu

mask: torch.Size([1, 302]) torch.bool False cpu



In [7]:
# @title Create a NeuralTransformer model

model = NeuralTransformer(input_size=emsize, hidden_size=d_hid).to(DEVICE)

# Smoke-test one forward pass: the target is the input shifted one step ahead in time.
mask = mask.to(DEVICE)
input = data[:, :-1, :].to(DEVICE)
target = data[:, 1:, :].to(DEVICE)
output = model(input, mask)

for label, value in (("input:", input), ("target:", target), ("output:", output)):
    print(
        label,
        value.shape,
        value.dtype,
        value.requires_grad,
        value.device,
        end="\n\n",
    )

input: torch.Size([1, 131, 302]) torch.float32 False cuda:0

target: torch.Size([1, 131, 302]) torch.float32 False cuda:0

output: torch.Size([1, 131, 2048]) torch.float16 True cuda:0



In [8]:
# @title Train the NeuralTransformer model

lr = 5.0  # initial learning rate for plain SGD
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# Multiply the learning rate by 0.99 on every scheduler.step() call
# (StepLR documents step_size as an int).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)


def train(model: torch.nn.Module) -> None:
    """Run one training epoch of the NeuralTransformer over ``train_dataset``.

    Each element of ``train_dataset`` is used as a single-sequence batch. The
    model sees steps ``[:-1]``, and ``model.loss_fn()`` compares its output with
    steps ``[1:]`` under an all-True feature mask. Relies on the notebook-level
    ``train_dataset``, ``emsize``, ``optimizer``, ``scheduler`` (only to read the
    current learning rate for logging), ``DEVICE`` and the module-level
    ``epoch`` counter (read for logging).

    Args:
        model: The NeuralTransformer, updated in place.
    """
    model.train()  # turn on train mode
    log_interval = 300
    window_loss = 0.0  # summed loss since the last log line
    window_batches = 0  # number of batches contributing to window_loss
    start_time = time.time()

    # Every feature is observed, so one all-True mask serves every sequence;
    # build it once rather than once per batch.
    mask = torch.ones(emsize, dtype=torch.bool).unsqueeze(0).to(DEVICE)

    num_batches = len(train_dataset)
    for batch, sequence in enumerate(train_dataset):
        data = sequence.unsqueeze(0)
        # parse into input and target (target = input shifted one step ahead)
        inputs = data[:, :-1, :].to(DEVICE)
        target = data[:, 1:, :].to(DEVICE)
        # forward pass
        output = model(inputs, mask)  # ``[batch_size=1, seq_len, model output dim]``
        # backpropagation step
        optimizer.zero_grad()
        loss = model.loss_fn()(output, target, mask)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        window_loss += loss.item()
        window_batches += 1
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            # Average over the batches actually seen: the first window spans
            # batches 0..log_interval, i.e. log_interval + 1 batches.
            ms_per_batch = (time.time() - start_time) * 1000 / window_batches
            cur_loss = window_loss / window_batches
            # exp(loss) is a true perplexity only for a cross-entropy loss, which
            # model.loss_fn() may not be. It is kept for parity with the LM log,
            # and large losses would overflow math.exp, so clamp to inf.
            try:
                ppl = math.exp(cur_loss)
            except OverflowError:
                ppl = float("inf")
            print(
                f"| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | "
                f"lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | "
                f"loss {cur_loss:5.2f} | ppl {ppl:8.2f}"
            )
            window_loss = 0.0
            window_batches = 0
            start_time = time.time()


def evaluate(model: torch.nn.Module) -> float:
    """Return the mean per-sequence ``model.loss_fn()`` loss on ``validation_dataset``.

    Each sequence is scored independently as a single-sequence batch (inputs
    ``[:-1]``, targets ``[1:]``) under an all-True feature mask, and the
    per-sequence losses are averaged.

    Args:
        model: The NeuralTransformer to evaluate (switched to eval mode).

    Raises:
        ValueError: if ``validation_dataset`` is empty.
    """
    model.eval()  # turn on evaluation mode
    num_batches = len(validation_dataset)
    if num_batches == 0:
        raise ValueError("validation_dataset is empty; cannot compute a mean loss")

    # Every feature is observed, so one all-True mask serves every sequence.
    mask = torch.ones(emsize, dtype=torch.bool).unsqueeze(0).to(DEVICE)
    total_loss = 0.0
    with torch.no_grad():
        for sequence in validation_dataset:
            data = sequence.unsqueeze(0)
            inputs = data[:, :-1, :].to(DEVICE)
            target = data[:, 1:, :].to(DEVICE)
            output = model(inputs, mask)
            total_loss += model.loss_fn()(output, target, mask).item()
    return total_loss / num_batches

In [21]:
# @markdown Loop over epochs. Save the model if the validation loss is the best we’ve seen so far. Adjust the learning rate after each epoch.

best_val_loss = float("inf")
epochs = 10

final_model_params_path = os.path.join("../models/", "neural_transformer_model.pt")
if os.path.exists(final_model_params_path):
    print("Loading a previously saved model checkpoint...")
    # map_location lets a checkpoint written on another device load on DEVICE.
    model.load_state_dict(torch.load(final_model_params_path, map_location=DEVICE))

with TemporaryDirectory() as tempdir:
    best_model_params_path = os.path.join(tempdir, "best_model_params.pt")

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        train(model)
        val_loss = evaluate(model)
        # exp() of a large non-cross-entropy loss can overflow; report inf instead of crashing.
        try:
            val_ppl = math.exp(val_loss)
        except OverflowError:
            val_ppl = float("inf")
        elapsed = time.time() - epoch_start_time
        print("-" * 89)
        print(
            f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
            f"valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}"
        )
        print("-" * 89)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            print("Saving the new best model...")
            torch.save(model.state_dict(), best_model_params_path)

        scheduler.step()

    # The temporary checkpoint exists only if some epoch beat best_val_loss
    # (not the case when epochs == 0 or every val_loss is NaN).
    if os.path.exists(best_model_params_path):
        print("Loading the best overall model...")
        model.load_state_dict(torch.load(best_model_params_path, map_location=DEVICE))
        print("Saving the best overall model...")
        # Create the models directory on a fresh checkout so torch.save succeeds.
        os.makedirs(os.path.dirname(final_model_params_path), exist_ok=True)
        torch.save(model.state_dict(), final_model_params_path)  # save the best model for later use
    else:
        print("No epoch improved the validation loss; no checkpoint was saved.")

Loading a previously saved model checkpoint...
| epoch   1 |   300/ 1963 batches | lr 4.48 | ms/batch  8.17 | loss 17.00 | ppl 24236865.91
| epoch   1 |   600/ 1963 batches | lr 4.48 | ms/batch 10.06 | loss 17.44 | ppl 37579872.78
| epoch   1 |   900/ 1963 batches | lr 4.48 | ms/batch  6.93 | loss 17.83 | ppl 55663976.27
| epoch   1 |  1200/ 1963 batches | lr 4.48 | ms/batch  6.77 | loss 16.92 | ppl 22223629.37
| epoch   1 |  1500/ 1963 batches | lr 4.48 | ms/batch  5.88 | loss 16.92 | ppl 22265917.60
| epoch   1 |  1800/ 1963 batches | lr 4.48 | ms/batch  6.01 | loss 16.61 | ppl 16392225.22
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 14.62s | valid loss 14.00 | valid ppl 1197186.19
-----------------------------------------------------------------------------------------
Saving the new best model...
| epoch   2 |   300/ 1963 batches | lr 4.43 | ms/batch  6.29 | loss 16.06 | ppl 9427737.19
| epoch   2 |   600/ 1963

In [22]:
# @title Generate new data using test input

# Condition on the first test sequence ([1, seq_len, emsize]) with every feature
# enabled, and move the frozen embedding table next to the model for decoding.
max_new_tokens = 100
data = test_dataset[0].to(DEVICE).unsqueeze(0)
mask = torch.ones(1, emsize, dtype=torch.bool, device=DEVICE)
embedding = embedding.to(DEVICE)

data_gen = model.transformer_generate(data, mask, max_new_tokens)

print(mask.shape, data.shape, data_gen.shape, embedding.weight.shape, end="\n\n")

# @markdown To read the output as text, map each generated vector back to a token.
# Find the nearest row of the embedding table to each generated vector and use
# that row's index as the token id.

torch.Size([1, 302]) torch.Size([1, 132, 302]) torch.Size([1, 232, 302]) torch.Size([30522, 302])



In [23]:
# First, decode the original test sequence, whose true token ids are known.
with torch.no_grad():
    tokens = model.tokenize_neural_data(
        neural_sequence=data, feature_mask=mask, token_matrix=embedding.weight
    )
    recovered_ids = tokens.squeeze().tolist()
    print(recovered_ids, end="\n\n")
    # decode into text
    print(tokenizer.decode(recovered_ids), end="\n\n")

# If correct this should match
print(text_dataset["test"]["input_ids"][0], end="\n\n")

# Now decode only the newly generated tail of the sequence.
with torch.no_grad():
    sequence = data_gen[:, -max_new_tokens:, :]
    tokens = model.tokenize_neural_data(
        neural_sequence=sequence,
        feature_mask=mask,
        token_matrix=embedding.weight,
    )
    generated_ids = tokens.squeeze().tolist()
    print(generated_ids, end="\n\n")
    # decode into text
    print(tokenizer.decode(generated_ids), end="\n\n")

[101, 2743, 3401, 11937, 1005, 4372, 2004, 4618, 2007, 2593, 2112, 1005, 1055, 3820, 3233, 1029, 7550, 2050, 1024, 2025, 1999, 2026, 2160, 1010, 19913, 16778, 2080, 1025, 2005, 1010, 2017, 2113, 1010, 23232, 2031, 5551, 1010, 1998, 1045, 2031, 2116, 8858, 1024, 4661, 1010, 2214, 24665, 23238, 2080, 2003, 2963, 7520, 2075, 2145, 1025, 1998, 11361, 2057, 2453, 2022, 7153, 1012, 25283, 3695, 1024, 2059, 2012, 2026, 26859, 1010, 2019, 2009, 2066, 2017, 1024, 2045, 11089, 2232, 2026, 2269, 4682, 1025, 1998, 2045, 1010, 2023, 2305, 1010, 2057, 1005, 2222, 3413, 1996, 2449, 9139, 1998, 2092, 1012, 4604, 2005, 2115, 2684, 2011, 2115, 7947, 2182, 1024, 2026, 2879, 4618, 18584, 1996, 8040, 3089, 8159, 2121, 12825, 1012, 1996, 5409, 2003, 2023, 1010, 2008, 1010, 2012, 2061, 10944, 5432, 1010, 2017, 102]


[101, 2743, 3401, 11937, 1005, 4372, 2004, 4618, 2007, 2593, 2112, 1005, 1055, 3820, 3233, 1029, 7550, 2050, 1024, 2025, 1999, 2026, 2160, 1010, 19913, 16778, 2080, 1025, 2005, 1010, 2017, 2113,