In [2]:
# @title Imports

import os
import math
import time
import torch
import random
import nbimporter
import numpy as np

from datasets import load_dataset
from transformers import AutoTokenizer
from tempfile import TemporaryDirectory
from utils import NEURONS_302, DEVICE, MAX_TOKEN_LEN
from tokenizers.pre_tokenizers import WhitespaceSplit
from preprocess._utils import smooth_data_preprocess, reshape_calcium_data
from CreateSyntheticDataset import (
    save_synthetic_dataset,
    plot_neural_signals,
    plot_3d_trajectory,
)

CUDA device found.


#### Create a synthetic dataset where the neural activity is the embeddings of tokens from the tiny Shakespeare dataset.

In [3]:
def pre_tokenize_and_chunk(text, max_length=510):
    """Split ``text`` on whitespace and greedily pack the words into chunks.

    Words are joined with single spaces; a new chunk is started as soon as
    adding the next word (plus its separating space) would push the chunk
    past ``max_length``.

    Note that ``max_length`` is measured in *characters* (``len`` of the
    chunk string), not in tokenizer tokens. The default of 510 was chosen to
    sit slightly below 512 to leave room for special tokens.

    Args:
        text: The raw string to split.
        max_length: Maximum number of characters per chunk. A single word
            that is itself longer than this is emitted as its own
            (over-long) chunk rather than being split.

    Returns:
        A list of non-empty chunk strings, in original order.
    """
    pre_tokenizer = WhitespaceSplit()
    # Each entry is ``(word, (start, end))``; only the word itself is needed.
    pre_tokens = pre_tokenizer.pre_tokenize_str(text)

    chunks = []
    current_chunk = ""
    for token, _offsets in pre_tokens:
        if len(current_chunk) + len(token) + 1 > max_length:
            # Flush the chunk built so far. The guard avoids emitting an
            # empty string when an over-long word arrives while the current
            # chunk is still empty.
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = token
        elif current_chunk:
            current_chunk += " " + token
        else:
            current_chunk = token
    if current_chunk:
        chunks.append(current_chunk)

    return chunks


def tokenize_and_chunk(examples, tokenizer=None):
    """Chunk every text in a batch and tokenize each chunk separately.

    Intended for ``Dataset.map(..., batched=True)``: one input row may expand
    into many output rows (one per chunk).

    Args:
        examples: Mapping with a ``"text"`` entry holding an iterable of
            strings.
        tokenizer: Callable Hugging Face tokenizer. When ``None``, the
            ``"bert-base-uncased"`` tokenizer is loaded.

    Returns:
        A dict with parallel lists under ``"text"`` (the chunk strings),
        ``"input_ids"`` and ``"attention_mask"`` (plain Python lists of ints,
        one per chunk).
    """
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    tokenized_batches = {"text": [], "input_ids": [], "attention_mask": []}

    for text in examples["text"]:
        for chunk in pre_tokenize_and_chunk(text):
            # Each chunk is encoded on its own, so no padding is needed and
            # the tokenizer can hand back plain lists directly instead of a
            # tensor that is immediately converted back to a list.
            tokenized_output = tokenizer(
                chunk,
                truncation=True,
                max_length=512,
            )
            tokenized_batches["text"].append(chunk)
            tokenized_batches["input_ids"].append(
                list(tokenized_output["input_ids"])
            )
            tokenized_batches["attention_mask"].append(
                list(tokenized_output["attention_mask"])
            )

    return tokenized_batches

In [4]:
def create_synthetic_dataset_shakespeare(
    max_timesteps: int = 1000,
    num_signals: int = 302,
    num_named_neurons: int = 1,
    dataset_name: str = "Shakespeare0000",
) -> dict:
    """Build a synthetic "worm" dataset from token embeddings of tiny Shakespeare.

    The text dataset is chunked and tokenized, every token id is looked up in
    a freshly constructed (never trained) embedding table of width
    ``num_signals``, and the resulting embedding vectors are treated as
    per-timestep neural activity. Consecutive chunks are concatenated until
    at least ``max_timesteps`` rows are collected; that block becomes one
    worm.

    Side effects:
        Rebinds the module-level names ``tokenizer`` and ``embedding`` (see
        the ``global`` statement below, marked DEBUG in the original).

    Args:
        max_timesteps: Minimum number of timesteps (tokens) per worm. Because
            whole chunks are appended, a worm may contain more than this.
        num_signals: Embedding dimension, used as the number of neurons.
        num_named_neurons: How many neurons keep their real name in the
            ``neuron_to_idx`` mapping; the rest are keyed by their index.
        dataset_name: Stored under ``"dataset"`` in every worm record.

    Returns:
        A dict mapping ``"worm{i}"`` to that worm's record, after it has been
        passed through ``reshape_calcium_data``.
    """
    # Want to access the tokenizer and the embedding table outside this function
    global tokenizer, embedding  # DEBUG
    # Load the Shakespeare dataset
    text_dataset = load_dataset("tiny_shakespeare")
    # Create a tokenizer
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    # Apply the tokenization and chunking to each split
    text_dataset = text_dataset.map(
        tokenize_and_chunk, batched=True, fn_kwargs=dict(tokenizer=tokenizer)
    )
    # Create an embedding table (one row per vocabulary entry, half precision)
    embedding_dim = num_signals
    embedding = torch.nn.Embedding(
        num_embeddings=tokenizer.vocab_size,
        embedding_dim=embedding_dim,
        dtype=torch.half,
    )
    # Extract all splits of the text dataset and pool them into one chunk list
    train_tokens = text_dataset["train"]["input_ids"]
    validation_tokens = text_dataset["validation"]["input_ids"]
    test_tokens = text_dataset["test"]["input_ids"]
    all_tokens = train_tokens + validation_tokens + test_tokens
    # Set up for creating the synthetic dataset
    num_unknown_neurons = num_signals - num_named_neurons
    # No smoothing method is selected; how `smooth_data_preprocess` treats
    # `None` is defined in that helper -- presumably a pass-through, TODO confirm.
    smooth_method = None
    dataset = dict()
    # Create data for as many worms as possible
    worm_idx = 0
    calcium_data = []  # embeddings of the chunks collected for the current worm
    total_time = 0  # number of timesteps accumulated for the current worm
    for chunk in all_tokens:
        # Hard cap: stop once worm0..worm200 (201 worms) have been produced.
        if worm_idx > 200:
            break
        # Look up the embedding of every token in the chunk: (len(chunk), num_signals)
        embd_data = embedding(torch.LongTensor(chunk)).detach().numpy()
        calcium_data.append(embd_data)
        total_time += embd_data.shape[0]
        # Enough timesteps collected: finalize this worm.
        if total_time >= max_timesteps:
            calcium_data = np.vstack(calcium_data)
            # Synthetic time axis 0, 1, 2, ... as a column vector.
            time_in_seconds = np.arange(total_time).reshape(-1, 1)
            dt = np.gradient(time_in_seconds, axis=0)
            # Guard the division below against zero time steps.
            dt[dt == 0] = np.finfo(float).eps
            resample_dt = np.round(np.median(dt).item(), 2)
            # Finite-difference derivative of the signal along time.
            residual_calcium = np.gradient(calcium_data, axis=0) / dt
            smooth_calcium_data = smooth_data_preprocess(
                calcium_data, time_in_seconds, smooth_method
            )
            smooth_residual_calcium = smooth_data_preprocess(
                residual_calcium, time_in_seconds, smooth_method
            )
            # Randomly choose which neurons keep their real names for this worm.
            named_neuron_indices = random.sample(range(num_signals), num_named_neurons)
            named_neurons = set(NEURONS_302[i] for i in named_neuron_indices)
            # Named neurons are keyed by name, all others by their index as a string.
            # NOTE(review): this iterates over all of NEURONS_302 regardless of
            # `num_signals`; presumably `num_signals == len(NEURONS_302)` is
            # assumed -- confirm before calling with another width.
            neuron_to_idx = {
                (neuron) if neuron in named_neurons else str(idx): idx
                for idx, neuron in enumerate(NEURONS_302)
            }
            idx_to_neuron = {idx: neuron for neuron, idx in neuron_to_idx.items()}

            worm_data = dict()

            worm_data["worm"] = f"worm{worm_idx}"
            worm_data["dataset"] = dataset_name
            worm_data["smooth_method"] = smooth_method
            worm_data["calcium_data"] = calcium_data
            worm_data["smooth_calcium_data"] = smooth_calcium_data
            worm_data["residual_calcium"] = residual_calcium
            worm_data["smooth_residual_calcium"] = smooth_residual_calcium
            # Actual number of collected timesteps (may exceed `max_timesteps`).
            worm_data["max_timesteps"] = total_time
            worm_data["time_in_seconds"] = time_in_seconds
            worm_data["dt"] = dt
            worm_data["resample_median_dt"] = resample_dt
            worm_data["neuron_to_idx"] = neuron_to_idx
            worm_data["idx_to_neuron"] = idx_to_neuron
            worm_data["num_neurons"] = num_signals
            worm_data["num_named_neurons"] = num_named_neurons
            worm_data["num_unknown_neurons"] = num_unknown_neurons

            # Project helper; its exact reshaping is defined in preprocess._utils.
            worm_data = reshape_calcium_data(worm_data)
            dataset[f"worm{worm_idx}"] = worm_data

            # Reset the accumulators for the next worm.
            worm_idx += 1
            calcium_data = []
            total_time = 0

    # Chunks left over that never reached `max_timesteps` are not returned.
    return dataset

In [5]:
# # Initialize parameters
# max_timesteps = 3000
# num_signals = 302
# num_named_neurons = num_signals  # 50 #  DEBUG
# dataset_name = "Shakespeare0000"

# # Creating and saving datasets
# dataset = create_synthetic_dataset_shakespeare(
#     max_timesteps=max_timesteps,
#     num_signals=num_signals,
#     num_named_neurons=num_named_neurons,
#     dataset_name=dataset_name,
# )

# # Get the number of worms in the dataset
# num_worms = len(dataset)

# # Save the dataset
# save_synthetic_dataset(f"processed/neural/{dataset_name}.pickle", dataset)

In [6]:
# # Selecting a worm and all the neurons to plot
# num_worms = len(dataset)
# worm_idx = random.choice([f"worm{i}" for i in range(num_worms)])
# neuron_idx = [idx for idx in dataset[worm_idx]["slot_to_neuron"].keys()][
#     :num_named_neurons
# ]

# # Plotting dataset
# plot_neural_signals(
#     data=dataset[worm_idx]["calcium_data"],
#     time_tensor=dataset[worm_idx]["time_in_seconds"],
#     neuron_idx=neuron_idx,
#     yax_limit=False,
#     suptitle=f"{dataset_name} - {worm_idx}",
# )

# # Visualize covariance matrix
# data = dataset[worm_idx]["smooth_calcium_data"]
# mask = dataset[worm_idx]["named_neurons_mask"]
# neurons = sorted(dataset[worm_idx]["named_neuron_to_slot"])

# X = data[:, mask].numpy()
# n = X.shape[0]
# X_bar = X - np.mean(X, axis=0)
# cov = 1 / (n - 1) * X_bar.T @ X_bar

# plt.figure()
# ax = sns.heatmap(cov, cmap="coolwarm", xticklabels=neurons, yticklabels=neurons)
# ax.set_title(f"Covariance matrix : {dataset_name}, {worm_idx}")
# plt.show()

# # Plotting 3D trajectory
# plot_3d_trajectory(
#     X, axis_labels=tuple(neurons), title=f"{dataset_name} neural trajectory"
# )

In [7]:
# @title HuggingFace Tokenizers
# @markdown Note there are two ways to call the tokenizer's encoder.

# Pretrained BERT tokenizer; later cells reuse this module-level `tokenizer`.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

expl_text = "Welcome to the 🤗 Tokenizers library."
impl_text = "We are very happy to show you the 🤗 Transformers library."

# Explicit route: `.encode` yields just the token ids.
expl_encode = tokenizer.encode(expl_text)
# Implicit route: calling the tokenizer yields a mapping (ids plus masks).
impl_encode = tokenizer(impl_text)

print(
    "Calling `tokenizer.encode(text)`:\n"
    f"\ttext: {expl_text}\n"
    f"\ttokenized: {expl_encode}\n"
    f"\tdecoded: {tokenizer.decode(expl_encode)}",
    end="\n\n",
)
print(
    "Calling `tokenizer(text)`:\n"
    f"\tobject.keys(): {impl_encode.keys()}\n"
    f"\ttext: {impl_text}\n"
    f"\ttokenized: {impl_encode['input_ids']}\n"
    f"\tdecoded: {tokenizer.decode(impl_encode['input_ids'])}",
    end="\n\n",
)

Calling `tokenizer.encode(text)`:
	text: Welcome to the 🤗 Tokenizers library.
	tokenized: [101, 6160, 2000, 1996, 100, 19204, 17629, 2015, 3075, 1012, 102]
	decoded: [CLS] welcome to the [UNK] tokenizers library. [SEP]

Calling `tokenizer(text)`:
	object.keys(): dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
	text: We are very happy to show you the 🤗 Transformers library.
	tokenized: [101, 2057, 2024, 2200, 3407, 2000, 2265, 2017, 1996, 100, 19081, 3075, 1012, 102]
	decoded: [CLS] we are very happy to show you the [UNK] transformers library. [SEP]



In [8]:
# @title HuggingFace Datasets

# Download / load the tiny Shakespeare corpus (train, validation, test splits).
text_dataset = load_dataset("tiny_shakespeare")
print(text_dataset, end="\n\n")
print("~~~" * 50, end="\n\n")

# For each split, report the container type and size of the "text" column and
# the type and character length of its first entry. The column is fetched once
# per split instead of once per printed value.
for split_name in ("train", "validation", "test"):
    split_texts = text_dataset[split_name]["text"]
    print(
        f"{split_name}:",
        type(split_texts),
        len(split_texts),
        type(split_texts[0]),
        len(split_texts[0]),
        end="\n\n",
    )

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 1
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 1
    })
    test: Dataset({
        features: ['text'],
        num_rows: 1
    })
})

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

train: <class 'list'> 1 <class 'str'> 1003854

validation: <class 'list'> 1 <class 'str'> 55770

test: <class 'list'> 1 <class 'str'> 55770



In [9]:
# @title Tokenization and Chunking
# @markdown Apply the tokenization and chunking to each split.

# Replace each split's single long text row with one row per tokenized chunk.
text_dataset = text_dataset.map(
    tokenize_and_chunk, batched=True, fn_kwargs=dict(tokenizer=tokenizer)
)
print(text_dataset, end="\n\n")
print("~~~" * 50, end="\n\n")

# Read the columns once: every `text_dataset["train"][column]` access builds
# the whole column again, so repeating it for each print is wasted work.
train_input_ids = text_dataset["train"]["input_ids"]
first_input_ids = train_input_ids[0]
first_text = text_dataset["train"]["text"][0]

print(
    "text_dataset['train']['input_ids']:\n",
    "\ttype:",
    type(train_input_ids),
    "\n\tlength:",
    len(train_input_ids),
    end="\n\n",
)
print(
    "text_dataset['train']['input_ids'][0]:\n",
    "\ttype:",
    type(first_input_ids),
    "\n\tlength:",
    len(first_input_ids),
    end="\n\n",
)
print(
    "text_dataset['train']['input_ids'][0][0]:\n",
    "\ttype:",
    type(first_input_ids[0]),
    "\n\tvalue:",
    first_input_ids[0],
    end="\n\n",
)
print("~~~" * 50, end="\n\n")

# Round trip for the first chunk: raw text -> token ids -> decoded text.
print(f"Original sequence (text):\n\t{first_text}", end="\n\n")
print(
    f"Encoded sequence (tokens):\n\t {first_input_ids}",
    end="\n\n",
)
print(
    f"Decoded sequence (tokens):\n\t {tokenizer.decode(first_input_ids)}",
    end="\n\n",
)

DatasetDict({
    train: Dataset({
        features: ['text', 'input_ids', 'attention_mask'],
        num_rows: 1963
    })
    validation: Dataset({
        features: ['text', 'input_ids', 'attention_mask'],
        num_rows: 109
    })
    test: Dataset({
        features: ['text', 'input_ids', 'attention_mask'],
        num_rows: 109
    })
})

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

text_dataset['train']['input_ids']:
 	type: <class 'list'> 
	length: 1963



text_dataset['train']['input_ids'][0]:
 	type: <class 'list'> 
	length: 135

text_dataset['train']['input_ids'][0][0]:
 	type: <class 'int'> 
	value: 101

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Original sequence (text):
	First Citizen: Before we proceed any further, hear me speak. All: Speak, speak. First Citizen: You are all resolved rather to die than to famish? All: Resolved. resolved. First Citizen: First, you know Caius Marcius is chief enemy to the people. All: We know't, we know't. First Citizen: Let us kill him, and we'll have corn at our own price. Is't a verdict? All: No more talking on't; let it be done: away, away! Second Citizen: One word, good citizens. First Citizen: We are accounted poor citizens, the

Encoded sequence (tokens):
	 [101, 2034, 6926, 1024, 2077, 2057, 10838, 2151, 2582, 1010, 2963, 2033, 3713, 1012, 2035, 1024, 3713, 1010, 3713, 1012, 2034, 692

In [10]:
# @title Define the Transformer model


class PositionalEncoding(torch.nn.Module):
    """Add fixed sinusoidal position information to a batch-first sequence.

    A table of shape ``[1, max_len, d_model]`` is precomputed once and stored
    as the non-trainable buffer ``pe``. Even feature indices hold sines and
    odd feature indices hold cosines of ``position * frequency``.

    Args:
        d_model: Feature (embedding) dimension. Both even and odd values are
            supported.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions in the precomputed table; inputs passed
            to ``forward`` must not be longer than this.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )  # [ceil(d_model / 2)]
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(1, max_len, d_model)  # batch_first=True
        pe[0, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd-index column than even-index
        # column, so drop the surplus frequency; for even d_model the slice
        # keeps everything.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]`` with
                ``seq_len <= max_len``.

        Returns:
            Tensor of the same shape as ``x``: ``x`` plus the encoding of its
            first ``seq_len`` positions, passed through dropout.
        """
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)


class TransformerModel(torch.nn.Module):
    """Token-level language model built on ``torch.nn.TransformerEncoder``.

    Pipeline: token embedding (scaled by ``sqrt(d_model)``) -> positional
    encoding -> stack of batch-first encoder layers -> linear projection to
    vocabulary logits.

    Args:
        ntoken: Vocabulary size (rows of the embedding, width of the output).
        d_model: Embedding / model dimension.
        nhead: Number of attention heads per encoder layer.
        d_hid: Hidden size of the feed-forward sublayer in each encoder layer.
        nlayers: Number of stacked encoder layers.
        dropout: Dropout probability used by the positional encoding and the
            encoder layers.
    """

    def __init__(
        self,
        ntoken: int,
        d_model: int,
        nhead: int,
        d_hid: int,
        nlayers: int,
        dropout: float = 0.5,
    ):
        super().__init__()
        self.model_type = "Transformer"
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = torch.nn.TransformerEncoderLayer(
            d_model, nhead, d_hid, dropout, batch_first=True
        )
        self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layers, nlayers)
        self.embedding = torch.nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.linear = torch.nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        """Initialize embedding and output weights uniformly in [-0.1, 0.1]; zero the output bias."""
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: torch.Tensor, src_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Arguments:
            src: Tensor, shape [batch_size, seq_len]
            src_mask: Tensor, shape [seq_len, seq_len]. When omitted, a
                square subsequent (causal) mask is generated so that each
                position only attends to itself and earlier positions.

        Returns:
            output Tensor of shape [batch_size, seq_len, ntoken]
        """
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        if src_mask is None:
            # src is batch-first, so the sequence length is dimension 1.
            seq_len = src.size(1)
            src_mask = torch.nn.Transformer.generate_square_subsequent_mask(
                seq_len
            ).to(src.device)
        output = self.transformer_encoder(src, src_mask)
        output = self.linear(output)
        return output

    @torch.no_grad()
    def generate(
        self,
        idx: torch.LongTensor,
        max_new_tokens: int,
        temperature=1.0,
        top_k=None,
    ):
        """
        Autoregressively extend a batch of token sequences.

        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and
        complete the sequence max_new_tokens times, feeding the predictions back
        into the model each time. At every step the logits of the last position
        are divided by ``temperature``, optionally restricted to the ``top_k``
        largest values, turned into probabilities with softmax, and the next
        token is *sampled* from that distribution.

        Note: this switches the module to evaluation mode and does not switch
        it back.

        Returns:
            LongTensor of shape (b, t + max_new_tokens): the input followed by
            the sampled tokens.
        """
        # Set model to evaluation mode
        self.eval()

        # Loop through time
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it to the
            # last MAX_TOKEN_LEN tokens
            idx_cond = idx if idx.size(1) <= MAX_TOKEN_LEN else idx[:, -MAX_TOKEN_LEN:]
            # forward the model to get the output
            outputs = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = outputs[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                top_values, _indices = torch.topk(logits, min(top_k, logits.size(-1)))
                # everything below the k-th largest logit of its row is masked out
                logits[logits < top_values[:, [-1]]] = -float("Inf")
            # apply softmax to convert logits to (normalized) probabilities
            probs = torch.nn.functional.softmax(logits, dim=-1)
            # sample one token per row; the result already has shape (b, 1), so
            # it must not be reshaped to (1, 1) (that breaks batches with b > 1)
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

In [11]:
# @title Initiate an instance

# Hyperparameters of the language model; `ntokens` is also read by the
# training and evaluation functions defined below.
ntokens = tokenizer.vocab_size
emsize = 302  # embedding dimension (passed as d_model)
d_hid = 302  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 1  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2  # number of heads in nn.MultiheadAttention (NOTE: nhead must be a divisor of emsize, the embedding dimension)
dropout = 0.1  # dropout probability
model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(DEVICE)

In [12]:
# @title Train the Transformer model

# Loss, optimizer and learning-rate schedule shared by `train` and `evaluate`.
criterion = torch.nn.CrossEntropyLoss()
lr = 1.0  # 5.0 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# Multiply the learning rate by 0.95 after every epoch; `step_size` is a
# number of scheduler steps and is documented as an integer.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)


def train(model: torch.nn.Module) -> None:
    """Run one pass over the training chunks, one chunk per optimizer step.

    Every chunk is used as a batch of size 1 for next-token prediction: the
    input is the chunk without its last token and the target is the chunk
    without its first token.

    Relies on module-level state: ``text_dataset``, ``DEVICE``, ``ntokens``,
    ``criterion``, ``optimizer``, ``scheduler`` and the loop variable
    ``epoch`` (only used in the progress message).

    Progress is printed every ``log_interval`` batches. Note that the first
    report covers batches ``0..log_interval`` inclusive while still dividing
    by ``log_interval``.

    Args:
        model: The model to train; it is switched to training mode.
    """
    model.train()  # turn on train mode
    total_loss = 0.0
    log_interval = 300
    start_time = time.time()
    global epoch

    # Read the column once: indexing `text_dataset["train"]["input_ids"]`
    # inside the loop rebuilds the whole column on every iteration.
    train_sequences = text_dataset["train"]["input_ids"]
    num_batches = len(train_sequences)
    for batch, tokens in enumerate(train_sequences):
        data = (
            torch.LongTensor(tokens[:-1]).unsqueeze(0).to(DEVICE)
        )  # ``[batch_size=1, seq_len]``
        targets = torch.LongTensor(tokens[1:]).to(DEVICE)  # ``[batch_size=1 * seq_len]``
        output = model(data)  # ``[batch_size=1, seq_len, ntokens]``
        output_flat = output.view(-1, ntokens)  # ``[batch_size=1 * seq_len, ntokens]``
        loss = criterion(output_flat, targets)

        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to keep the large SGD step size stable.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            ppl = math.exp(cur_loss)
            print(
                f"| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | "
                f"lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | "
                f"loss {cur_loss:5.2f} | ppl {ppl:8.2f}"
            )
            # Start a fresh reporting window.
            total_loss = 0
            start_time = time.time()


def evaluate(model: torch.nn.Module) -> float:
    """Compute the average next-token loss over the validation chunks.

    Relies on module-level state: ``text_dataset``, ``DEVICE``, ``ntokens``
    and ``criterion``.

    Args:
        model: The model to evaluate; it is switched to evaluation mode.

    Returns:
        The mean of the per-chunk losses. Each chunk contributes equally,
        regardless of its length.
    """
    model.eval()  # turn on evaluation mode
    total_loss = 0.0
    with torch.no_grad():
        # Read the column once instead of rebuilding it on every iteration.
        validation_sequences = text_dataset["validation"]["input_ids"]
        num_batches = len(validation_sequences)
        for tokens in validation_sequences:
            data = torch.LongTensor(tokens[:-1]).unsqueeze(0).to(DEVICE)
            targets = torch.LongTensor(tokens[1:]).to(DEVICE)
            output = model(data)
            output_flat = output.view(-1, ntokens)
            total_loss += criterion(output_flat, targets).item()
    return total_loss / num_batches

In [13]:
# @markdown Loop over epochs. Save the model if the validation loss is the best we’ve seen so far. Adjust the learning rate after each epoch.

best_val_loss = float("inf")
epochs = 1

# Location of the persistent checkpoint; used both for warm-starting and for
# saving the best weights at the end of this cell.
final_model_params_path = os.path.join("../models/", "shakespeare_transformer_model.pt")
if os.path.exists(final_model_params_path):
    print("Loading a previously saved model checkpoint...")
    # map_location lets a checkpoint saved on another device (e.g. CUDA) be
    # loaded onto the device used in this session.
    model.load_state_dict(torch.load(final_model_params_path, map_location=DEVICE))

with TemporaryDirectory() as tempdir:
    # Scratch file holding the best weights seen during this run.
    best_model_params_path = os.path.join(tempdir, "best_model_params.pt")

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        train(model)
        val_loss = evaluate(model)
        val_ppl = math.exp(val_loss)
        elapsed = time.time() - epoch_start_time
        print("-" * 89)
        print(
            f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
            f"valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}"
        )
        print("-" * 89)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_params_path)

        scheduler.step()

    print("Loading and saving the new best model...")
    model.load_state_dict(
        torch.load(best_model_params_path, map_location=DEVICE)
    )  # load best model states
    # Make sure the target directory exists before writing the checkpoint.
    os.makedirs(os.path.dirname(final_model_params_path), exist_ok=True)
    torch.save(
        model.state_dict(), final_model_params_path
    )  # save the best model for later use

Loading a previously saved model checkpoint...


| epoch   1 |   300/ 1963 batches | lr 1.00 | ms/batch 107.07 | loss  4.12 | ppl    61.86
| epoch   1 |   600/ 1963 batches | lr 1.00 | ms/batch 104.77 | loss  4.17 | ppl    64.61
| epoch   1 |   900/ 1963 batches | lr 1.00 | ms/batch 103.82 | loss  4.26 | ppl    71.07
| epoch   1 |  1200/ 1963 batches | lr 1.00 | ms/batch 112.13 | loss  4.25 | ppl    70.26
| epoch   1 |  1500/ 1963 batches | lr 1.00 | ms/batch 121.48 | loss  4.24 | ppl    69.08
| epoch   1 |  1800/ 1963 batches | lr 1.00 | ms/batch 105.14 | loss  4.25 | ppl    70.36
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 214.47s | valid loss  4.92 | valid ppl   137.00
-----------------------------------------------------------------------------------------
Loading and saving the new best model...


In [15]:
# @title Generate new text using test input

max_new_tokens = 100

# Prompt: the first tokenized chunk of the test split, shaped [1, seq_len].
idx = torch.LongTensor([text_dataset["test"]["input_ids"][0]]).to(DEVICE)
# Prompt followed by `max_new_tokens` sampled tokens.
idx_gen = model.generate(idx, max_new_tokens)

print(idx.shape, idx_gen.shape, end="\n\n")
# Decoded prompt, then only the newly generated continuation.
print(tokenizer.decode(idx[0].tolist()), end="\n\n")
print(tokenizer.decode(idx_gen[0, -max_new_tokens:].tolist()), end="\n\n")

torch.Size([1, 132]) torch.Size([1, 232])


. katharina : i will see you do this is that for me at all hold you [SEP] : my gossips of signior baptista? if you [SEP]. gremio : signior baptista to be that ; yea minola as nay, master frowapost [SEP] heaven. gremio : this wicked trusty you [SEP] you small husband. bianca : one that is condemn thee! what you so reproductive to be long. tranio ; if you know striker with a poor

