In [1]:
import os
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
import string

from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformer import Transformer, SinusoidalPositionalEncoding

## Dataset

We begin by downloading the dataset (credit to Andrej Karpathy) to the `data/shakespeare.txt` file:

In [2]:
DATA_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
DATA_DIR = "data"
DATA_PATH = f"{DATA_DIR}/shakespeare.txt"

def load_shakespeare():
    # Download Shakespeare's works
    os.makedirs(DATA_DIR, exist_ok=True)
    if not os.path.exists(DATA_PATH):
        response = requests.get(DATA_URL)
        with open(DATA_PATH, 'wb') as f:
            f.write(response.content)
    
    # Load the text
    with open(DATA_PATH, 'r') as f:
        text = f.read()

    return text

The dataset is a single `.txt` file with text from Shakespeare. Let's now create a very basic tokenizer, that maps each character to a different token:

In [3]:
class CharTokenizer:
    def __init__(self, text: str):
        self.vocab = sorted(list(set(text)))
        self.char_to_idx = {ch: idx for idx, ch in enumerate(self.vocab)}
        self.idx_to_char = {idx: ch for idx, ch in enumerate(self.vocab)}
        self.padding_idx = len(self.vocab)

    def encode(self, text: str):
        return [self.char_to_idx[ch] for ch in text]
    
    def decode(self, indices: list[int]):
        return ''.join([self.idx_to_char[idx] for idx in indices if idx != self.padding_idx])
    
    def batch_encode(self, texts: list[str], max_len=None):
        encoded_texts = [self.encode(text) for text in texts]

        max_len_texts = max(len(text) for text in encoded_texts)
        if max_len is None:
            max_len = max_len_texts
        else:
            max_len = min(max_len, max_len_texts)

        padded_texts = [
            text + [self.padding_idx] * (max_len - len(text))
            for text in encoded_texts
        ]
        attention_mask = [
            [True] * len(text) + [False] * (max_len - len(text))
            for text in encoded_texts
        ]

        return padded_texts, attention_mask
    
    def batch_decode(self, indices: list[list[int]]):
        return [self.decode(text) for text in indices]
    
    def get_vocab_size(self):
        return len(self.vocab) + 1  # +1 for padding
    
    def get_padding_idx(self):
        return self.padding_idx

Let's explain the code step by step. This code defines a Python class `CharTokenizer` that is responsible for converting text into sequences of character indices and vice versa, as well as handling padding and batching for processing multiple texts.

---

### 1. **Constructor: `__init__`**

```python
def __init__(self, text: str):
    self.vocab = sorted(list(set(text)))
    self.char_to_idx = {ch: idx for idx, ch in enumerate(self.vocab)}
    self.idx_to_char = {idx: ch for idx, ch in enumerate(self.vocab)}
    self.padding_idx = len(self.vocab)
```

- **`__init__(self, text: str)`**: This is the constructor method that gets called when an instance of the class is created. It initializes the object with a `text` string.
  
- **`self.vocab = sorted(list(set(text)))`**:
  - `set(text)` converts the `text` string into a set of unique characters.
  - `list(set(text))` turns the set into a list, which is then sorted to create an ordered list of unique characters, forming the vocabulary.

- **`self.char_to_idx = {ch: idx for idx, ch in enumerate(self.vocab)}`**:
  - This dictionary comprehension creates a mapping of characters (`ch`) to their corresponding indices (`idx`).
  - For example, if the vocabulary is `['a', 'b', 'c']`, the dictionary would be `{'a': 0, 'b': 1, 'c': 2}`.

- **`self.idx_to_char = {idx: ch for idx, ch in enumerate(self.vocab)}`**:
  - This dictionary comprehension does the reverse mapping, creating a dictionary that maps indices to characters.
  - For example, `idx_to_char = {0: 'a', 1: 'b', 2: 'c'}`.

- **`self.padding_idx = len(self.vocab)`**:
  - The padding index is assigned the value of the length of the vocabulary.
  - This means that the padding character will have an index one greater than the highest index of any character in the vocabulary.
  - The padding token is useful when we encode texts of different lengths in a single batch. Since all of the sequences in a tensor must have the same length, the shorter ones are padded to match length of the largest ones, with this padding token.

---

### 2. **Method: `encode`**

```python
def encode(self, text: str):
    return [self.char_to_idx[ch] for ch in text]
```

- **`encode(self, text: str)`**: This method takes a string `text` and returns a list of indices corresponding to each character in the string based on the `char_to_idx` dictionary.
  
- For example, if `text = "abc"`, it will return `[0, 1, 2]` (based on the vocab `['a', 'b', 'c']`).

---

### 3. **Method: `decode`**

```python
def decode(self, indices: list[int]):
    return ''.join([self.idx_to_char[idx] for idx in indices if idx != self.padding_idx])
```

- **`decode(self, indices: list[int])`**: This method takes a list of character indices and converts them back into a string of characters.
  
- **`''.join([...])`**: Joins the list of characters into a single string.
  
- **`if idx != self.padding_idx`**: This ensures that if there are any padding indices (those with value `self.padding_idx`), they are ignored during the decoding process.

- For example, if `indices = [0, 1, 2]` and the vocab is `['a', 'b', 'c']`, it will return `"abc"`. If the list contains padding indices like `[0, 1, 2, 4]` (where 4 is the padding index), it will ignore the padding index and return `"abc"`.

---

### 4. **Method: `batch_encode`**

```python
def batch_encode(self, texts: list[str], max_len=None):
    encoded_texts = [self.encode(text) for text in texts]

    max_len_texts = max(len(text) for text in encoded_texts)
    if max_len is None:
        max_len = max_len_texts
    else:
        max_len = min(max_len, max_len_texts)

    padded_texts = [
        text + [self.padding_idx] * (max_len - len(text))
        for text in encoded_texts
    ]
    attention_mask = [
        [True] * len(text) + [False] * (max_len - len(text))
        for text in encoded_texts
    ]

    return padded_texts, attention_mask
```

- **`batch_encode(self, texts: list[str], max_len=None)`**: This method takes a list of strings (`texts`) and encodes them into a batch of padded sequences. Optionally, a `max_len` can be provided to specify the maximum length of the padded sequences.

- **`encoded_texts = [self.encode(text) for text in texts]`**: Each string in the `texts` list is encoded into a list of indices using the `encode` method.

- **`max_len_texts = max(len(text) for text in encoded_texts)`**: This computes the maximum length of all encoded texts in the batch.

- **`if max_len is None:`**:
  - If `max_len` is not provided, the method uses `max_len_texts` (the length of the longest encoded text).
  - If `max_len` is provided, it will be used, but the method will take the smaller of `max_len` and `max_len_texts` to avoid unnecessary length.

- **`padded_texts = [...]`**:
  - For each encoded text, it pads the text with the `padding_idx` (using list comprehension) until the length reaches `max_len`.

- **`attention_mask = [...]`**:
  - This creates a corresponding attention mask for each padded text.
  - `True` values correspond to actual characters, and `False` values correspond to padding positions.

- **Return**:
  - The method returns two lists:
    - `padded_texts`: A list of padded character sequences.
    - `attention_mask`: A list of masks, indicating which positions are actual characters and which are padding.

---

### 5. **Method: `batch_decode`**

```python
def batch_decode(self, indices: list[list[int]]):
    return [self.decode(text) for text in indices]
```

- **`batch_decode(self, indices: list[list[int]])`**: This method takes a batch of encoded texts (a list of lists of indices) and decodes each one using the `decode` method.

- It returns a list of decoded strings, one for each list of indices in the input batch.

---

### 6. **Method: `get_vocab_size`**

```python
def get_vocab_size(self):
    return len(self.vocab) + 1  # +1 for padding
```

- **`get_vocab_size(self)`**: This method returns the size of the vocabulary, including the padding token (which has its own index).
  - The length of `self.vocab` gives the number of unique characters in the vocabulary.
  - `+1` accounts for the padding token, which increases the vocabulary size.

---

### 7. **Method: `get_padding_idx`**

```python
def get_padding_idx(self):
    return self.padding_idx
```

- **`get_padding_idx(self)`**: This method simply returns the index reserved for padding.

---

### Summary of the Class

The `CharTokenizer` class provides the following functionality:
- **Tokenization**: Converts text into a sequence of character indices (`encode`).
- **Detokenization**: Converts sequences of indices back into text, ignoring padding (`decode`).
- **Batch Processing**: Handles multiple texts at once, encoding them, padding them to a fixed length, and creating attention masks for each sequence (`batch_encode` and `batch_decode`).
- **Padding**: Handles padding by assigning a special index to padding tokens (`padding_idx`).
- **Vocabulary Information**: Provides the size of the vocabulary and the padding index (`get_vocab_size` and `get_padding_idx`).

Let's now build the dataset, which simply takes the source text and the sequence length. The class basically breaks down the text into chunks of the size provided during initialization:

In [4]:
class TextDataset(Dataset):
    def __init__(self, text, seq_len):
        self.text = text
        self.seq_len = seq_len

    def __len__(self):
        return len(self.text) // self.seq_len

    def __getitem__(self, idx):
        start_idx, end_idx = idx * self.seq_len, (idx + 1) * self.seq_len
        return self.text[start_idx:end_idx]

We notice two special things in this toy dataset, which will not be present in a real world scenario:
- We have only one text, and all chunks will be of the same size, except perhaps the last one. This makes padding and masking less important (we could drop the last chunk and have all samples with the same length). However, in real datasets we have multiple texts, each with a different length.
- We also find that with a single text, it doesn't really make sense to add a special token `End of Sequence`, since it would only appear once in the dataset. **This means that the model is not learning how to stop generating on its own**. Thus, we cap the generation process with a maximum number of tokens.

Let's now create a `CollateFn` function to prepare samples for our training procedure:

In [5]:
class CollateFn:
    def __init__(self, tokenizer, max_seq_len):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def __call__(self, batch):
        # batch: List[str]
        token_idxs, attention_mask = self.tokenizer.batch_encode(batch, self.max_seq_len)

        token_idxs = torch.tensor(token_idxs)  # (batch_size, seq_len)
        attention_mask = torch.tensor(attention_mask)  # (batch_size, seq_len)

        input_ids = token_idxs[:, :-1]  # (batch_size, seq_len - 1)
        target_ids = token_idxs[:, 1:]  # (batch_size, seq_len - 1)
        input_attention_mask = attention_mask[:, :-1]  # (batch_size, seq_len - 1)
        target_attention_mask = attention_mask[:, 1:]  # (batch_size, seq_len - 1)

        return input_ids, target_ids, input_attention_mask, target_attention_mask

1. **Class Definition: `CollateFn`**

```python
class CollateFn:
    def __init__(self, tokenizer, max_seq_len):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
```

- **`CollateFn`**: This class defines a custom collate function. A collate function is responsible for batching and transforming raw data into a form that the model can process.
- **`__init__(self, tokenizer, max_seq_len)`**: This is the constructor method that initializes the `CollateFn` class.
  - **`tokenizer`**: This is an instance of a tokenizer, which is used to convert text into token indices.
  - **`max_seq_len`**: This parameter specifies the maximum sequence length for each input text after padding/truncation. If any input text exceeds this length, it will be truncated.

---

2. **Calling the Collate Function**

```python
def __call__(self, batch):
```

- **`__call__(self, batch)`**: This is the function that gets called when the collate function is used. It takes a `batch` as input, which is a list of strings (text data).

---

3. **Batch Encoding and Padding**

```python
token_idxs, attention_mask = self.tokenizer.batch_encode(batch, self.max_seq_len)
```

- **`batch_encode`**:
  - The `batch_encode` method of the tokenizer is called to encode the list of text data (`batch`) into token indices.
  - **`batch`**: This is a list of strings, where each string is a sentence or text.
  - **`max_seq_len`**: The maximum sequence length after padding/truncation.
  
- **Output**: 
  - `token_idxs`: A list of token indices for each sentence in the batch. This represents the input tokens in numerical form.
  - `attention_mask`: A mask that indicates which tokens are actual tokens (True) and which are padding (False). It helps the model ignore padded tokens.

---

4. **Convert to Tensors**

```python
token_idxs = torch.tensor(token_idxs)  # (batch_size, seq_len)
attention_mask = torch.tensor(attention_mask)  # (batch_size, seq_len)
```

- **`torch.tensor(token_idxs)`**: Converts the list of token indices into a PyTorch tensor of shape `(batch_size, seq_len)`. `batch_size` is the number of sentences in the batch, and `seq_len` is the length of the token sequence (after padding/truncation).
  
- **`torch.tensor(attention_mask)`**: Converts the list of attention masks into a tensor of shape `(batch_size, seq_len)`. Each entry in the attention mask tensor will be `True` (1) for actual tokens and `False` (0) for padding tokens.

---

5. **Create Input and Target Sequences for Next-Token Prediction**

```python
input_ids = token_idxs[:, :-1]  # (batch_size, seq_len - 1)
target_ids = token_idxs[:, 1:]  # (batch_size, seq_len - 1)
attention_mask = attention_mask[:, :-1]  # (batch_size, seq_len - 1)
target_attention_mask = attention_mask[:, 1:]  # (batch_size, seq_len - 1)
```

- **`input_ids = token_idxs[:, :-1]`**:
  - This slices the `token_idxs` tensor to remove the last token from each sequence in the batch.
  - The `input_ids` represent the input tokens used to predict the next token. For example, given a sequence `[a, b, c]`, the `input_ids` will be `[a, b]`.
  - The shape of `input_ids` is `(batch_size, seq_len - 1)`.
- **`target_ids = token_idxs[:, 1:]`**:
  - This slices the `token_idxs` tensor to remove the first token from each sequence in the batch.
  - The `target_ids` represent the "true" next token for the model to predict. For example, given a sequence `[a, b, c]`, the `target_ids` will be `[b, c]`.
  - The shape of `target_ids` is also `(batch_size, seq_len - 1)`.
- **`input_attention_mask = attention_mask[:, :-1]`**:
  - This slices the `attention_mask` to match the length of the `input_ids` (i.e., removing the last token's attention mask).
  - This mask will be used when passing the input tokens through the model.
  - The resulting `attention_mask` tensor will have shape `(batch_size, seq_len - 1)` and will indicate which tokens are actual tokens (True) and which are padding (False).
- **`target_attention_mask = attention_mask[:, 1:]`**:
  - This slices the `attention_mask` to match the length of the `target_input_ids` (i.e., removing the first token's attention mask).
  - This mask will be used when calculating the final loss. Since padding tokens should not be predicted, we must be able to remove them from the loss!
  - The resulting `attention_mask` tensor will have shape `(batch_size, seq_len - 1)` and will indicate which tokens are actual tokens (True) and which are padding (False).

---

6. **Return the Processed Batch**

```python
return input_ids, target_ids, attention_mask
```

- **Return Values**:
  - **`input_ids`**: The input tokens (all but the last token) of shape `(batch_size, seq_len - 1)`.
  - **`target_ids`**: The target tokens (all but the first token) of shape `(batch_size, seq_len - 1)`.
  - **`attention_mask`**: A mask of shape `(batch_size, seq_len - 1)` that indicates which tokens are real and which are padding.

These values are now ready for input to a model that performs next-token prediction. The model will use `input_ids` to predict the next token, and the loss can be computed by comparing the predictions with `target_ids`.

We can print an example:

In [6]:
text = load_shakespeare()
tokenizer = CharTokenizer(text)
max_seq_len = 100

dataset = TextDataset(text, max_seq_len)
collate_fn = CollateFn(tokenizer, max_seq_len)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

In [7]:
input_idxs, target_idxs, input_attention_mask, target_attention_mask  = next(iter(dataloader))
input_idxs[0], target_idxs[0], input_attention_mask[0], target_attention_mask[0]

(tensor([46, 47, 41, 46,  1, 57, 46, 39, 50, 50,  1, 40, 56, 43, 39, 49,  1, 46,
         47, 57,  1, 52, 43, 41, 49,  1, 53, 56,  1, 46, 39, 64, 39, 56, 42,  1,
         51, 47, 52, 43,  6,  0, 35, 46, 43, 52, 43,  5, 43, 56,  1, 61, 43,  1,
         41, 53, 51, 43,  1, 58, 53,  1, 53, 59, 56,  1, 39, 41, 41, 53, 59, 52,
         58,  8,  0,  0, 24, 47, 43, 59, 58, 43, 52, 39, 52, 58, 10,  0, 31, 47,
         56,  6,  1, 21,  1, 40, 43, 57, 43]),
 tensor([47, 41, 46,  1, 57, 46, 39, 50, 50,  1, 40, 56, 43, 39, 49,  1, 46, 47,
         57,  1, 52, 43, 41, 49,  1, 53, 56,  1, 46, 39, 64, 39, 56, 42,  1, 51,
         47, 52, 43,  6,  0, 35, 46, 43, 52, 43,  5, 43, 56,  1, 61, 43,  1, 41,
         53, 51, 43,  1, 58, 53,  1, 53, 59, 56,  1, 39, 41, 41, 53, 59, 52, 58,
          8,  0,  0, 24, 47, 43, 59, 58, 43, 52, 39, 52, 58, 10,  0, 31, 47, 56,
          6,  1, 21,  1, 40, 43, 57, 43, 43]),
 tensor([True, True, True, True, True, True, True, True, True, True, True, True,
         True, 

Finally, let's create a Pytorch Lightning data module, and also create a split of the dataset with a train and validation set.

In [8]:
class ShakespeareDataModule(pl.LightningDataModule):
    def __init__(self, text, tokenizer, max_seq_len, batch_size):
        super().__init__()
        self.text = text
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.batch_size = batch_size

    def setup(self, stage=None):
        dataset = TextDataset(self.text, self.max_seq_len)

        train_size = int(0.9 * len(dataset))
        val_size = len(dataset) - train_size
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            dataset, [train_size, val_size]
        )

        self.collate_fn = CollateFn(self.tokenizer, self.max_seq_len)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.collate_fn)
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)

## Model

We will be using a Transformer model to predict the next token. The purpose of this class is not to learn how to code the Transformer, so it is provided for you in the `transformer.py` file, already prepared with the **causal mask** for next-token prediction as explained in class. This model is properly explained in the Deep Learning class of the Master. The rest of the Pytorch Lightning Module is explained after the code:

In [9]:
class ShakespeareLightningModel(pl.LightningModule):
    def __init__(
        self,
        # Transformer params
        vocab_size: int,
        d_model: int,
        nhead: int,
        dim_feedforward: int,
        dropout: float,
        num_layers: int,
        # Embedding params
        padding_idx: int,
        # Positional encoding params
        max_len: int,
        # Training params
        optimizer_params: dict,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.padding_idx = padding_idx

        # Transformer model
        self.model = Transformer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            num_layers=num_layers,
        )

        # Embedding layers
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
        # self.pos_embedding = nn.Embedding(max_len, d_model)

        # Classification head
        self.fc = nn.Linear(d_model, vocab_size)

        # Training params
        self.optimizer_params = optimizer_params

    def forward(self, input_idxs, attention_mask):
        input_embed = self.embedding(input_idxs)  # (batch_size, seq_len, d_model)

        _, seq_len, d_model = input_embed.size()

        # NOTE Option 1: Learnable positional embeddings
        # positions = (
        #     torch.arange(seq_len, device=self.device)
        #     .unsqueeze(0)
        #     .to(self.device)
        # )  # (1, seq_len)
        # pos_embeddings = self.pos_embedding(positions)  # (1, seq_len, d_model)

        # NOTE Option 2: Sinusoidal positional embeddings
        pos_embeddings = (
            SinusoidalPositionalEncoding
            .get_positional_encoding(seq_len, d_model)
            .to(self.device)
            .unsqueeze(0)
        )  # (1, seq_len, d_model)

        input_embed = input_embed + pos_embeddings
        output = self.model(input_embed, attention_mask)
        return self.fc(output)
    
    def training_step(self, batch, batch_idx):
        loss = self._step(batch)
        self.log("train_loss", loss, prog_bar=True, on_epoch=True, on_step=True)
        return loss
     
    def validation_step(self, batch, batch_idx):
        loss = self._step(batch)
        self.log("val_loss", loss, prog_bar=True, on_epoch=True, on_step=False)
        return loss
    
    def _step(self, batch):
        input_idxs, target_idxs, input_attention_mask, target_attention_mask = batch  # (batch_size, seq_len - 1)
        output = self(input_idxs, input_attention_mask)  # (batch_size, seq_len - 1, vocab_size)

        B, L, V = output.size()
        loss = F.cross_entropy(output.view(B * L, V), target_idxs.view(-1), reduction="none")  # (B * L)

        # Mask out the padding tokens
        target_attention_mask = target_attention_mask.view(-1) / target_attention_mask.sum()  # (B * L)
        loss = (loss * target_attention_mask).sum()

        return loss
    
    def configure_optimizers(self):
        optimizer = optim.AdamW(self.model.parameters(), **self.optimizer_params)
        return optimizer
    
    def configure_callbacks(self):
        return super().configure_callbacks() + [
            pl.callbacks.ModelCheckpoint(monitor="val_loss"),
        ]


### Step-by-Step Explanation of the `ShakespeareLightningModel` Class

This class defines a PyTorch Lightning model for a transformer-based architecture, likely for sequence modeling tasks (e.g., next-token prediction) with specific configurations for text generation. Below is a detailed breakdown of the class and its methods.

---

### 1. **Class Definition: `ShakespeareLightningModel`**

```python
class ShakespeareLightningModel(pl.LightningModule):
```

- **`ShakespeareLightningModel`**: This class inherits from `pl.LightningModule`, which is a base class in PyTorch Lightning designed to simplify training, validation, and testing.
- **Purpose**: It models a sequence-to-sequence task using a Transformer architecture with added capabilities for token classification (next-token prediction, etc.).

---

### 2. **Constructor: `__init__`**

```python
def __init__(
    self,
    # Transformer params
    vocab_size: int,
    d_model: int,
    nhead: int,
    dim_feedforward: int,
    dropout: float,
    num_layers: int,
    # Embedding params
    padding_idx: int,
    # Positional encoding params
    max_len: int,
    # Training params
    optimizer_params: dict,
):
```

- **Purpose**: Initializes the model by setting hyperparameters and defining various components of the transformer architecture.
  
#### Parameters:

- **`vocab_size`**: The size of the vocabulary. This defines the number of unique tokens in the input text.
- **`d_model`**: The dimensionality of the model's hidden state (the size of the embeddings and attention layers).
- **`nhead`**: The number of attention heads in the multi-head self-attention mechanism.
- **`dim_feedforward`**: The size of the hidden feedforward layer inside the transformer blocks.
- **`dropout`**: Dropout rate used for regularization.
- **`num_layers`**: Number of transformer layers in the model.
- **`padding_idx`**: Index used for padding in the input sequences. Padding tokens are ignored in loss computation and attention.
- **`max_len`**: Maximum sequence length for positional encoding.
- **`optimizer_params`**: A dictionary of parameters that define how the optimizer will behave during training (e.g., learning rate, weight decay, etc.).

#### Inside the constructor:

```python
super().__init__()
self.save_hyperparameters()
```
- **`super().__init__()`**: Calls the constructor of the parent class (`pl.LightningModule`).
- **`self.save_hyperparameters()`**: Saves the hyperparameters defined in the constructor, which is a convenience method provided by PyTorch Lightning. It stores the arguments in the model, making it easy to save/load the model configuration.

---

### 3. **Defining Model Components**

#### Transformer Model:

```python
self.model = Transformer(
    d_model=d_model,
    nhead=nhead,
    dim_feedforward=dim_feedforward,
    dropout=dropout,
    num_layers=num_layers,
)
```

- **`self.model`**: A decoder-only Transformer model object that consists of multiple layers of multi-head **causal** self-attention and feed-forward networks. It's initialized with the parameters provided (e.g., model dimensions, number of layers).
  
#### Embedding Layer:

```python
self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
```

- **`self.embedding`**: An embedding layer that converts input tokens (represented by their indices) into dense vectors of size `d_model`.
  - **`vocab_size`**: The number of tokens in the vocabulary.
  - **`d_model`**: The dimensionality of the embeddings.
  - **`padding_idx`**: Specifies the index that corresponds to the padding token, which will be ignored during training.

#### Classification Head (Output Layer):

```python
self.fc = nn.Linear(d_model, vocab_size)
```

- **`self.fc`**: A fully connected (linear) layer that maps the final output representation from the transformer (`d_model`-dimensional) to the size of the vocabulary (`vocab_size`). This layer is used for classification (i.e., predicting the next token in a sequence).

---

### 4. **Forward Pass: `forward`**

```python
def forward(self, input_idxs, attention_mask):
```

- **`forward(self, input_idxs, attention_mask)`**: This method defines how the data flows through the model. It takes two inputs:
  - **`input_idxs`**: The token indices of the input sequence, typically shaped `(batch_size, seq_len)`.
  - **`attention_mask`**: A mask tensor indicating which tokens are padding (`0`) and which are real tokens (`1`).

#### Input Embedding:

```python
input_embed = self.embedding(input_idxs)  # (batch_size, seq_len, d_model)
```

- **`input_embed`**: The input token indices (`input_idxs`) are passed through the embedding layer, resulting in dense representations of the input tokens with shape `(batch_size, seq_len, d_model)`.

#### Positional Encoding (Sinusoidal):

```python
pos_embeddings = (
    SinusoidalPositionalEncoding
    .get_positional_encoding(seq_len, d_model)
    .to(self.device)
    .unsqueeze(0)
)  # (1, seq_len, d_model)
```

- **`pos_embeddings`**: The model uses **sinusoidal positional encodings** to provide information about the relative or absolute position of tokens in the sequence. This encoding is added to the token embeddings to preserve order information.
  - **`get_positional_encoding(seq_len, d_model)`**: This function generates sinusoidal positional encodings of size `(seq_len, d_model)`.
  - **`.unsqueeze(0)`**: Adds a batch dimension, so the shape becomes `(1, seq_len, d_model)` (since the positional encoding is shared across the batch).
  
#### Combining Input Embeddings and Positional Encodings:

```python
input_embed = input_embed + pos_embeddings
```

- **Combining**: The learned token embeddings (`input_embed`) and the positional encodings (`pos_embeddings`) are summed element-wise. This gives the model information about both the content of the tokens and their positions in the sequence.

#### Passing Through Transformer:

```python
output = self.model(input_embed, attention_mask)
```

- **`output`**: The transformed input embeddings (with positional encodings) are passed through the Transformer model. The `attention_mask` ensures that padding tokens are ignored during attention calculation.

#### Final Output Layer:

```python
return self.fc(output)
```

- **`self.fc(output)`**: The transformer output (`output`) is passed through the final linear layer (`self.fc`) to get the prediction for each token in the sequence. The output shape is `(batch_size, seq_len, vocab_size)`, where each token in the sequence has a probability distribution over the vocabulary.

---

### 5. **Training Step: `_step`**

```python
def _step(self, batch):
    input_idxs, target_idxs, input_attention_mask, target_attention_mask = batch  # (batch_size, seq_len - 1)
    output = self(input_idxs, input_attention_mask)  # (batch_size, seq_len - 1, vocab_size)

    B, L, V = output.size()
    loss = F.cross_entropy(output.view(B * L, V), target_idxs.view(-1), reduction="none")  # (B * L)

    # Mask out the padding tokens
    target_attention_mask = target_attention_mask.view(-1) / target_attention_mask.sum()  # (B * L)
    loss = (loss * target_attention_mask).sum()

    return loss
```

- **Purpose**: This method defines a single step of training: from the forward pass to loss computation.
  
#### Extract Batch Elements:

```python
input_idxs, target_idxs, attention_mask = batch
```

#### Forward Pass:

```python
output = self(input_idxs, input_attention_mask)  # (batch_size, seq_len - 1, vocab_size)
```

- **`output`**: The model's output after the forward pass is a tensor of shape `(batch_size, seq_len - 1, vocab_size)`, where each token in the sequence has a probability distribution over the vocabulary.

#### Loss Computation:

```python
B, L, V = output.size()
loss = F.cross_entropy(output.view(B * L, V), target_idxs.view(-1), reduction="none")  # (B * L)

# Mask out the padding tokens
target_attention_mask = target_attention_mask.view(-1) / target_attention_mask.sum()  # (B * L)
loss = (loss * target_attention_mask).sum()
```

- **`F.cross_entropy`**: The Cross Entropy Loss is computed between the predicted output and the target indices.
  - The output is reshaped to `(B * L, V)` for proper loss computation (flattening the batch and sequence length dimensions).
  - `target_idxs.view(-1)` flattens the target indices to match the reshaped output.
  - `reduction="none"` means that the loss functions is not reduced across elements, as we are used to. Typically, a mean over all elements in the batch (and sequence) are calculated. However, we include padding in these sequences, and we do not need to predict that. Thus, we prevent this basic reduction, and do it manually ourselves.
  - **`target_attention_mask = target_attention_mask.view(-1) / target_attention_mask.sum()`**: This specifies the coefficients we use to compute the mean across elements in the batch and the sequences. `target_attention_mask.sum()` is the total amount of non-padding elements in the target ids. Thus, each non-padding element will be multiplied by the `1 / target_attention_mask.sum()` to compute the mean.

## Training

In [10]:
text = load_shakespeare()
tokenizer = CharTokenizer(text)

vocab_size = tokenizer.get_vocab_size()
padding_idx = tokenizer.get_padding_idx()

d_model = 128
nhead = 8
dim_feedforward = 512
dropout = 0.1
num_layers = 8

max_seq_len = 100

batch_size = 64
optimizer_params = dict(lr=3e-4, weight_decay=1e-2)
max_epochs = 50

In [11]:
data_module = ShakespeareDataModule(text, tokenizer, max_seq_len, batch_size)
data_module.setup()

model = ShakespeareLightningModel(
    vocab_size=vocab_size,
    d_model=d_model,
    nhead=nhead,
    dim_feedforward=dim_feedforward,
    dropout=dropout,
    num_layers=num_layers,
    padding_idx=padding_idx,
    max_len=max_seq_len,
    optimizer_params=optimizer_params,
)

In [12]:
trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator="gpu",
    devices=[1],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/pablo/.micromamba/envs/mdl_gen/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [13]:
trainer.fit(model, data_module)

The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]

  | Name      | Type        | Params
------------------------------------------
0 | model     | Transformer | 1.6 M 
1 | embedding | Embedding   | 8.4 K 
2 | fc        | Linear      | 8.5 K 
------------------------------------------
1.6 M     Trainable params
0         Non-trainable params
1.6 M     Total params
6.413     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/home/pablo/.micromamba/envs/mdl_gen/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.


                                                                           

/home/pablo/.micromamba/envs/mdl_gen/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.


Epoch 49: 100%|██████████| 157/157 [00:04<00:00, 37.21it/s, v_num=28, train_loss_step=1.340, val_loss=1.420, train_loss_epoch=1.350]

`Trainer.fit` stopped: `max_epochs=50` reached.


Epoch 49: 100%|██████████| 157/157 [00:04<00:00, 36.57it/s, v_num=28, train_loss_step=1.340, val_loss=1.420, train_loss_epoch=1.350]


## Decoding

Let's now generate some text! We implement two of the most simple decoding strategies seen in class:
1. **Greedy**: always choosing the most likely token.
2. **Sampling**: random sampling from the probability distribution as is.

In [14]:
def greedy_decoding(logits):
    # logits: (batch_size, vocab_size)
    return logits.argmax(dim=-1)

def sampling_decoding(logits):
    # logits: (batch_size, vocab_size)
    return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)

In [15]:
def complete_text(
    decoding_fn,
    text,
    tokenizer,
    model,
    num_completions=3,
    max_len=100,
):
    device = next(model.parameters()).device

    for _ in range(num_completions):
        encoded_text = tokenizer.encode(text)[-max_len:]  # NOTE: Truncate the text! The model was
                                                          # trained with a maximum length, and will 
                                                          # break down if the input is longer than that.
        encoded_text = torch.tensor(encoded_text).unsqueeze(0).to(device)
        attention_mask = torch.tensor([[True] * len(encoded_text[0])]).to(device)

        with torch.no_grad():
            output = model(encoded_text, attention_mask)  # (1, seq_len, vocab_size)
            output = output[:, -1, :]  # (1, vocab_size)

            next_tokens = decoding_fn(output)  # (1, 1)

            predicted_char = tokenizer.idx_to_char[next_tokens.item()]

        text += predicted_char

    return text

You can load the model from disk if you have trained on a previous run of the notebook:

In [16]:
# CKPT_PATH = "lightning_logs/version_26/checkpoints/epoch=49-step=7850.ckpt"
# model = ShakespeareLightningModel.load_from_checkpoint(CKPT_PATH)

In [17]:
model.to("cuda:0")

ShakespeareLightningModel(
  (model): Transformer(
    (layers): ModuleList(
      (0-7): 8 x CausalTransformerLayer(
        (attn): MultiheadAttention(
          (q_proj): Linear(in_features=128, out_features=128, bias=True)
          (k_proj): Linear(in_features=128, out_features=128, bias=True)
          (v_proj): Linear(in_features=128, out_features=128, bias=True)
          (out_proj): Linear(in_features=128, out_features=128, bias=True)
        )
        (dropout): Dropout(p=0.1, inplace=False)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (ff): Sequential(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=512, out_features=128, bias=True)
        )
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
    )
  )
  (embedding): Embedding(66, 128, padding_idx=65)
  (fc): Linear(in_features=128, out_features=66, bias=True)
)

In [18]:
text = "To be or not to be, that is the "
completed_text = complete_text(greedy_decoding, text, tokenizer, model, num_completions=1000, max_len=max_seq_len)
print(completed_text)

To be or not to be, that is the strange of the
should be so contented the state.

POLIXENES:
What is the son?


KING EDWARD IV:
The master of the comple of his household be his honour,
That thou hast a breathed the cannot be contented
That thou canst the seal that they shall be strange
That with the strange of the contrainted of the date.


PROSPERO:
The provost of SaintAble and love to the seath.


POMPEY:
What she had beeen the stan of the strange?


POMPSPEY:
The good love another than the senators of the day.


PETRUCHIO:
Not I willl not so for the stand and long the thought of the
that thou shalt be the dead of this day wont,
That with the shall be the senator's day,
That thou didst be the strange of the comple
To the coward of the consins and like the complainty.


KING RICHARD III:
The may soul that well hath the doth of the death.


QUEEN ELIZABETH:
The great that would be the stimph of the contrate
To the fire of this wof the power of the death.


POLIXENES:
What is the sworn 

In [19]:
text = "To be or not to be, that is the "
completed_text = complete_text(sampling_decoding, text, tokenizer, model, num_completions=1000, max_len=max_seq_len)
print(completed_text)

To be or not to be, that is the friar!

ANTONIO:
Come, I kmardled at thXpeal,
That all us, and trickely more than sherm  with
scorn!
The thown some speeaking of r and bids vaultlat, no
Forth that thou are:
Calll to breothe in ownderous; that I did
Such rider I in the isguing,
And give the firor lionghs in cait, death incennt.
In soorlaly and by iss reharly 'O mine!'

Shepherd:
In a mine againstar off wons scape undertake
The appehing into againsst of trorclay of his son!
But loves their nest:
Outward boldling away, sirve he disg man;
Nothing to unwits thy brader panI thanke chat
it iss toward Lay, all traTh thou dast!
Good to those hath canon, cannot bid:
TowDo fought of Warwick, that was thou not the stat
On of not to son wit, wan with o'clock ond
Thard hath that olst thought thou thinkest an'st edColl'd
t by dost she bliftt winkels within thy angree.
Somp my cunglioc and us ous it wordon hasband!
What is mind own your freshelf deseated?
What vizes saw is limberard, sweeet with O,
Gro