# Transformers

# Introduction: Why were transformers created?

Until a few years ago, models such as RNNs and CNNs were the go-to architectures for NLP tasks. Despite their success, they are now known to have some major limitations, hence the need for a more suitable architecture such as the transformer.

## 0.1 RNN limitations:

<div align="center">
  <img src="rnn_orig.png" width="50%"/>
</div>

* RNN have limitations when facing long sequences.
* Limitation related to vanishing/exploding gradients during training.
* The final hidden state of the encoder has a representation of the whole input sequence and it is all the decoder has access to to generate the output.
* Even LSTMs, whose long-term memory mitigates the vanishing gradient issue, do not solve it completely.

<div align="center">
  <img src="encode_decoder_rnn.png" width="40%"/>
</div>

# The Transformer Architecture:

## The tiny Shakespeare dataset.

As supporting data I decided to use the tiny Shakespeare dataset, which contains text from Shakespeare's works.<br/>
Before diving into the transformer architecture, we'll quickly prepare the data to make it suitable for any implementation.

In [11]:
import os
import pickle
import inspect
import requests
import numpy as np

from typing import List, Tuple, Dict, Iterable, ContextManager


def create_directory(path: str) -> None:
    """Create ``path`` (parents included) unless it already exists.

    A one-line message reports which of the two cases occurred.

    Args:
        path: Directory to create.
    """
    if os.path.exists(path):
        print("Directory {} already exists.".format(path))
        return
    os.makedirs(path)
    print("Directory {} created.".format(path))

def download_tiny_shakespeare_dataset(path: str) -> None:
    """Download the tiny Shakespeare corpus to ``path`` unless the file is already there.

    The response is fetched and validated *before* the destination file is
    opened, so a failed request cannot leave behind an empty file that would
    make every later call skip the download.

    Args:
        path: Destination file for the raw text.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an error status.
    """
    if os.path.exists(path):
        return
    data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    # A timeout keeps the call from hanging forever on a stalled connection
    response = requests.get(data_url, timeout=30)
    response.raise_for_status()
    with open(path, "w") as f:
        f.write(response.text)
    print("Succesfully downloaded: {}".format(path))

def read_data(path: str) -> str:
    """Return the whole text stored at ``path`` and print its length in characters.

    Args:
        path: Text file to read.

    Returns:
        The file content as a single string.
    """
    with open(path, "r") as handle:
        text = handle.read()
    print(f"Length of dataset in characters: {len(text):,}")
    return text

def get_characters(data: str) -> List:
    """Return the sorted list of distinct characters found in ``data``.

    The printed summary leaves out the first character of the sorted list.

    Args:
        data: Text whose characters form the vocabulary.

    Returns:
        Unique characters in ascending order.
    """
    vocabulary = list(set(data))
    vocabulary.sort()
    shown = "".join(vocabulary[1:])
    print("All the unique characters: {}".format(shown))
    return vocabulary

def compute_vocabulary_size(characters: List) -> int:
    """Print and return the number of entries in ``characters``."""
    n_characters = len(characters)
    print("Vocabulary size: {}".format(n_characters))
    return n_characters

def build_mappings(characters: List) -> Tuple[Dict, Dict]:
    """Build the two lookup tables between characters and their integer ids.

    The id of a character is its position in ``characters``.

    Args:
        characters: Vocabulary, in id order.

    Returns:
        ``(char_to_int, int_to_char)``.
    """
    int_to_char = dict(enumerate(characters))
    char_to_int = {char: index for index, char in int_to_char.items()}
    return char_to_int, int_to_char

def encode(string: str, char_to_int: Dict) -> List:
    """Translate every character of ``string`` into its integer id.

    Args:
        string: Text to encode.
        char_to_int: Lookup table from character to id.

    Returns:
        One id per character, in order.

    Raises:
        KeyError: If a character is missing from ``char_to_int``.
    """
    return list(map(char_to_int.__getitem__, string))

def decode(integers: Iterable[int], int_to_char: Dict) -> str:
    """Turn a sequence of integer ids back into text.

    Args:
        integers: Ids to translate, in order.
        int_to_char: Lookup table from id to character.

    Returns:
        The decoded string (one character per id).

    Raises:
        KeyError: If an id is missing from ``int_to_char``.
    """
    return "".join(int_to_char[i] for i in integers)

def split_data(data: str) -> Tuple[str, str]:
    """Cut ``data`` into a leading 90% training part and a trailing validation part.

    Args:
        data: Full text.

    Returns:
        ``(train_data, val_data)``.
    """
    cut = int(len(data) * 0.9)
    return data[:cut], data[cut:]

def export_to_bin_files(directory: str, train_ids: List[int], val_ids: List[int]) -> None:
    """Write both id sequences into ``directory`` as raw ``uint16`` files.

    Args:
        directory: Existing folder receiving ``train.bin`` and ``val.bin``.
        train_ids: Ids of the training split.
        val_ids: Ids of the validation split.
    """
    # Convert both splits first so nothing is written if either conversion fails
    arrays = {
        file_name: np.array(ids, dtype=np.uint16)
        for file_name, ids in (("train.bin", train_ids), ("val.bin", val_ids))
    }
    for file_name, array in arrays.items():
        array.tofile(os.path.join(directory, file_name))

In [12]:
# Folder holding the raw corpus and every artefact derived from it
data_directory = "./data/shakespeare_char"
create_directory(path=data_directory)
path = os.path.join(data_directory, "input.txt")

# Fetch the corpus (skipped when the file already exists) and load it as one string
download_tiny_shakespeare_dataset(path=path)
data = read_data(path=path)
# Character-level vocabulary plus the lookup tables between characters and ids
characters = get_characters(data=data)
vocabulary_size = compute_vocabulary_size(characters=characters)
char_to_int, int_to_char = build_mappings(characters=characters)
# Split the raw text first, then encode each split into integer ids
train_data, val_data = split_data(data=data)
train_ids = encode(string=train_data, char_to_int=char_to_int)
val_ids = encode(string=val_data, char_to_int=char_to_int)
print(f"Train split has {len(train_ids):,} tokens")
print(f"Val split has {len(val_ids):,} tokens")
# Persist both splits as train.bin / val.bin inside data_directory
export_to_bin_files(directory=data_directory, train_ids=train_ids, val_ids=val_ids)

# Save the meta information as well, to help us encode/decode later
# ("itos" = integer-to-string table, "stoi" = string-to-integer table)
meta = {"vocab_size": vocabulary_size, "itos": int_to_char, "stoi": char_to_int}
with open(os.path.join(data_directory, "meta.pkl"), "wb") as f:
    pickle.dump(meta, f)

Directory ./data/shakespeare_char already exists.
Length of dataset in characters: 1,115,394
All the unique characters:  !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Vocabulary size: 65
Train split has 1,003,854 tokens
Val split has 111,540 tokens


<div align="center">
  <img src="transformer.png" width="40%"/>
</div>

The original transformer is based on the encoder-decoder architecture that is widely used for tasks like machine translation, where a sequence of words is translated from one language to another. This architecture consists of two components.

### Encoder:

Converts an input sequence of tokens into a sequence of embedding vectors.

### Decoder :

Uses the encoder's hidden state to iteratively generate an output sequence of tokens, one token at a time.

## 1. Embedding:

* Each input sentence is split into tokens.
* Each token in the input sequence is represented as an integer index. These indices correspond to the position of the token in the vocabulary.
* The input indices are then passed through an embedding layer.

In [13]:
# Encode a sample line character by character with the char_to_int table built earlier
sentence = "JULIET: O Romeo, Romeo! wherefore art thou Romeo?"
sentence_ids = encode(string=sentence, char_to_int=char_to_int)
print(sentence_ids)

[22, 33, 24, 21, 17, 32, 10, 1, 27, 1, 30, 53, 51, 43, 53, 6, 1, 30, 53, 51, 43, 53, 2, 1, 61, 46, 43, 56, 43, 44, 53, 56, 43, 1, 39, 56, 58, 1, 58, 46, 53, 59, 1, 30, 53, 51, 43, 53, 12]


We would like to transform each token id (i.e. an integer) into an embedding (i.e. a vector $\in \mathbb{R}^D$, in our case $D = 384$).

In [14]:
import torch
import numpy as np

# Let's build a module to map token ids into embeddings
# (one row of size `embedding_dime` per vocabulary entry)
embedding_dime = 384
wte = torch.nn.Embedding(vocabulary_size, embedding_dime)

# As we use a torch module, we need to format our data as a torch tensor
numpy_ids = np.array(sentence_ids, dtype=np.int64)
torch_ids = torch.from_numpy(numpy_ids)
# unsqueeze(0) adds a leading dimension of size 1 in front of the sequence dimension
torch_ids = torch_ids.unsqueeze(0)

# Get embeddings
token_embeddings = wte(torch_ids)
# Last expression of the cell: the notebook displays the resulting size
token_embeddings.size()

torch.Size([1, 49, 384])

## 2. Positional embedding:

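Matrix of vectors that represent the respective position of each token in a sentence. These vectors can be learned (as in the GPT model later in this notebook) or fixed, like the sinusoidal encoding of the original paper defined below.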

Such embeddings allow the transformer to learn how words need to be in a certain order to make sense in a sentence.


$$
\text{PE}(pos, 2i) = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
$$

$$
\text{PE}(pos, 2i+1) = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
$$




In [15]:
import torch
import numpy as np

def positional_encoding(max_seq_len: int, embedding_dim: int) -> torch.Tensor:
    """Build the fixed sinusoidal positional-encoding table.

    Even columns hold ``sin(pos / 10000^(2i / embedding_dim))`` and odd columns
    hold the matching ``cos`` term.

    Args:
        max_seq_len: Number of positions (rows) to encode.
        embedding_dim: Size of each encoding vector; odd sizes are supported.

    Returns:
        Float tensor of shape ``(1, max_seq_len, embedding_dim)``.
    """
    position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-np.log(10000.0) / embedding_dim))
    angles = position * div_term  # (max_seq_len, ceil(embedding_dim / 2))
    pe = torch.zeros(max_seq_len, embedding_dim)
    pe[:, 0::2] = torch.sin(angles)
    # With an odd embedding_dim there is one odd column fewer than even columns,
    # so the surplus frequency is dropped to keep the shapes compatible
    pe[:, 1::2] = torch.cos(angles[:, : embedding_dim // 2])
    return pe.unsqueeze(0)


# Get positional encodings, one row per token of the sample sentence
max_seq_len = len(sentence_ids)
pos_encodings = positional_encoding(max_seq_len, embedding_dime)

# Add positional encodings to token embeddings (element-wise sum)
token_embeddings_with_pos = token_embeddings + pos_encodings

# Check the size of the resulting tensor
print(token_embeddings_with_pos.size())


torch.Size([1, 49, 384])


## 3. Encoder

In this section, much of the explanation is taken from the famous paper [Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf).



The transformer's encoder consists of many encoder layers stacked next to each other. Each encoder layer receives a sequence of embeddings and feeds them through the following sublayers:
* A multi-head self attention layer.
* A fully connected feed forward layer that is applied to each input embedding.

<div align="center">
  <img src="encoder.png" width="10%"/>
</div>

### 3.1 The attention mechanism:

Intuitively, the attention mechanism enables models to focus on relevant parts of the input sequence when making predictions or generating outputs. Instead of processing the input sequentially, the attention mechanism assigns different weights to different parts of the sequence, thereby capturing dependencies and relationships more effectively.

* Query: The query is a feature vector that describes what we are looking for in the sequence, i.e. what would we maybe want to pay attention to.

* Keys: For each input element, we have a key which is again a feature vector. This feature vector roughly describes what the element is “offering”, or when it might be important. The keys should be designed such that we can identify the elements we want to pay attention to based on the query.

* Values: For each input element, we also have a value vector. This feature vector is the one we want to average over.

The transformer's encoder consists of many encoder layers stacked next to each other. Each encoder layer receives a sequence of embeddings and feeds them through the following sublayers:
* A multi-head self attention layer.
* A fully connected feed forward layer that is applied to each input embedding.

<div align="center">
  <img src="attention.png" width="50%"/>
</div>

To represent visually what happens in the attention mechanism:

In [25]:
from transformers import AutoTokenizer
from bertviz.transformers_neuron_view import BertModel
from bertviz.neuron_view import show
# The same checkpoint name is given to both the tokenizer and the model loader
model_ckpt = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = BertModel.from_pretrained(model_ckpt)
# The sample sentence uses the word "apple" twice, in two different senses
text = "I love the Apple phone but hate the apple fruit"
# Open bertviz's neuron view for this sentence
# NOTE(review): `layer` and `head` presumably pick the initially displayed layer/head — confirm in bertviz docs
show(model, "bert", tokenizer, text, display_mode="light", layer=0, head=8)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

### Attention Mechanism Math:


1. Query, Key, and Value Representations:
Let $X = \{x_1, x_2, ..., x_n\}$ be the input sequence. Linear transformations are applied to $X$ to generate three groups of representations: Query ($Q$), Key ($K$), and Value ($V$).
$$
\begin{align*}
Q &= \{q_1, q_2, ..., q_n\} \\
K &= \{k_1, k_2, ..., k_n\} \\
V &= \{v_1, v_2, ..., v_n\}
\end{align*}
$$

2. Dot Product Computation:
To compute attention scores, dot products are calculated for each pair of query and key representations $q_i$ and $k_j$:

$$
\hat{e}_{i,j} = q_i \cdot k_j
$$

The score $\hat{e}_{i,j}$ is then scaled by the square root of the dimension $D$ of the query representation:
$$
e_{i,j} = \frac{\hat{e}_{i,j}}{\sqrt{D}}
$$

3. Attention Scores:
A softmax function is applied to the scaled scores $\{e_{i,j}\}$ to obtain normalized attention weights:
$$
\alpha_{i,j} = \frac{\exp(e_{i,j})}{\sum_{j=1}^{n} \exp(e_{i,j})}
$$
where $\alpha_{i,j}$ represents the attention score.

4. Weighted Sum:
Finally, a weighted sum of the value vectors $v_j$ using the attention weights is computed to obtain the context vector $\hat{x}_i$:
$$
\hat{x}_i = \sum_{j=1}^{n} \alpha_{i,j} \cdot v_j
$$

In [56]:
class ModelConfig:
    """Plain container for the hyper-parameters read by the attention modules.

    Args:
        bias: Bias flag, stored as given.
        n_embd: Embedding width.
        n_head: Number of attention heads.
        vocab_size: Number of entries in the vocabulary.
    """

    def __init__(
        self,
        bias: bool = False,
        n_embd: int = 384,
        n_head: int = 6,
        vocab_size: int = 65,
    ) -> None:
        # Every argument is stored unchanged under its own name
        self.bias, self.n_embd, self.n_head, self.vocab_size = bias, n_embd, n_head, vocab_size

In [57]:
import torch
import torch.nn as nn
from math import sqrt
import torch.nn.functional as F

class AttentionHead(nn.Module):
    """One scaled dot-product attention head.

    Query, key and value inputs of width ``config.n_embd`` are each projected
    down to ``config.n_embd // config.n_head`` features before attention is
    computed.

    Args:
        config: Object exposing ``n_embd`` and ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        self.q = nn.Linear(self.n_embd, self.head_dim)
        self.k = nn.Linear(self.n_embd, self.head_dim)
        self.v = nn.Linear(self.n_embd, self.head_dim)

    def scaled_dot_product_attention(self, query, key, value, mask=None):
        """Return ``softmax(query @ key^T / sqrt(d)) @ value``.

        ``d`` is the size of the last dimension of ``query``, i.e. the
        dimension the dot product is taken over.

        Args:
            query: Projected queries, ``(..., n_queries, d)``.
            key: Projected keys, ``(..., n_keys, d)``.
            value: Projected values, ``(..., n_keys, d_v)``.
            mask: Optional tensor broadcastable to the score matrix; positions
                where it equals 0 are excluded from attention.

        Returns:
            Attention output of shape ``(..., n_queries, d_v)``.
        """
        # Scale by the query/key dimension, not by the full embedding width
        scores = torch.matmul(query, key.transpose(-2, -1)) / (query.size(-1) ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, value)

    def forward(self, query, key, value, mask=None):
        """Project the three inputs and attend; ``mask`` is optional so callers may pass one."""
        attn_outputs = self.scaled_dot_product_attention(
            self.q(query), self.k(key), self.v(value), mask=mask)
        return attn_outputs

In [91]:
class MultiHeadAttention(nn.Module):
    """Run ``config.n_head`` attention heads side by side and mix their outputs.

    Args:
        config: Object exposing ``n_embd`` and ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        head_list = [AttentionHead(config) for _ in range(config.n_head)]
        self.heads = nn.ModuleList(head_list)
        self.output_linear = nn.Linear(config.n_embd, config.n_embd)

    def forward(self, query, key, value, mask=None):
        """Concatenate every head's output on the last dimension, then project it.

        ``mask`` is handed unchanged to each head as its fourth positional argument.
        """
        per_head = [head(query, key, value, mask) for head in self.heads]
        concatenated = torch.cat(per_head, dim=-1)
        return self.output_linear(concatenated)

In [92]:
class FeedForward(nn.Module):
    """Position-wise layer: one width-preserving linear map followed by ReLU.

    Args:
        config: Object exposing ``n_embd``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.linear_1 = nn.Linear(width, width)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(linear_1(x))``."""
        return self.relu(self.linear_1(x))

In [93]:
class TransformerEncoderLayer(nn.Module):
    """One pre-layer-norm encoder layer: self-attention, then feed-forward.

    Each sub-layer is wrapped in a residual (skip) connection, so its output is
    added to its input instead of replacing it.

    Args:
        config: Object exposing ``n_embd`` and ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(config.n_embd)
        self.layer_norm_2 = nn.LayerNorm(config.n_embd)
        self.attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)

    def forward(self, x):
        """Return the layer output; the last dimension of ``x`` must be ``n_embd``."""
        # Apply layer normalization and then copy input into query, key, value
        hidden_state = self.layer_norm_1(x)
        # Self-attention sub-layer with skip connection
        x = x + self.attention(hidden_state, hidden_state, hidden_state)
        # Feed-forward sub-layer with skip connection
        x = x + self.feed_forward(self.layer_norm_2(x))
        return x

In [94]:
# Run the sample embeddings through a single encoder layer.
# Despite its name, `attention_layer` holds a full TransformerEncoderLayer, and the
# input is `token_embeddings`, i.e. the version without positional encodings.
config = ModelConfig()
attention_layer = TransformerEncoderLayer(config=config)
attention_embeddings = attention_layer(token_embeddings)
# Print input and output shapes so they can be compared
print(token_embeddings.shape)
print(attention_embeddings.shape)

torch.Size([1, 49, 384])
torch.Size([1, 49, 384])


### Decoder

The main difference between the decoder and the encoder is that the decoder has two attention sublayers:
* Masked multi-head self-attention layer.
* Encoder-decoder attention layer.

<div align="center">
  <img src="decoder.png" width="10%"/>
</div>

#### Masked multi head attention

<div align="center">
  <img src="decoder_attention.png" width="50%"/>
</div>

In [95]:
class AttentionHead(nn.Module):
    """One scaled dot-product attention head with optional masking.

    Query, key and value inputs of width ``config.n_embd`` are each projected
    down to ``config.n_embd // config.n_head`` features before attention is
    computed.

    Args:
        config: Object exposing ``n_embd`` and ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        self.q = nn.Linear(self.n_embd, self.head_dim)
        self.k = nn.Linear(self.n_embd, self.head_dim)
        self.v = nn.Linear(self.n_embd, self.head_dim)

    def scaled_dot_product_attention(self, query, key, value, mask=None):
        """Return ``softmax(query @ key^T / sqrt(d)) @ value``.

        ``d`` is the size of the last dimension of ``query``, i.e. the
        dimension the dot product is taken over.

        Args:
            query: Projected queries, ``(..., n_queries, d)``.
            key: Projected keys, ``(..., n_keys, d)``.
            value: Projected values, ``(..., n_keys, d_v)``.
            mask: Optional tensor broadcastable to the score matrix; positions
                where it equals 0 are excluded from attention. A query row
                whose positions are all masked yields NaN after the softmax.

        Returns:
            Attention output of shape ``(..., n_queries, d_v)``.
        """
        # Scale by the query/key dimension, not by the full embedding width
        scores = torch.matmul(query, key.transpose(-2, -1)) / (query.size(-1) ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        # matmul (unlike bmm) also accepts inputs with extra leading dimensions
        return torch.matmul(weights, value)

    def forward(self, query, key, value, mask=None):
        """Project the three inputs, then attend under the optional ``mask``."""
        attn_outputs = self.scaled_dot_product_attention(
            self.q(query), self.k(key), self.v(value), mask=mask)
        return attn_outputs


#### Encoder-Decoder attention

<div align="center">
  <img src="encoder_decoder_attention.png" width="50%"/>
</div>

In [99]:
class TransformerDecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, encoder-decoder attention, feed-forward.

    The two layer norms are created once in ``__init__`` so that their
    parameters are registered on the module (and therefore trained, saved and
    moved with it) instead of being rebuilt from a global ``config`` on every
    forward pass.

    Args:
        config: Object exposing ``n_embd`` and ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        self.masked_self_attention = MultiHeadAttention(config)
        self.encoder_decoder_attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)
        self.layer_norm_1 = nn.LayerNorm(config.n_embd)
        self.layer_norm_2 = nn.LayerNorm(config.n_embd)

    def forward(self, decoder_hidden_states, encoder_hidden_states, decoder_mask=None, encoder_decoder_mask=None):
        """Return the decoder layer output.

        Args:
            decoder_hidden_states: Decoder-side input; used as query, key and
                value of the masked self-attention.
            encoder_hidden_states: Encoder output; used as key and value of the
                encoder-decoder attention.
            decoder_mask: Optional mask handed to the self-attention.
            encoder_decoder_mask: Optional mask handed to the encoder-decoder attention.
        """
        self_attention_output = self.masked_self_attention(
            query=decoder_hidden_states,
            key=decoder_hidden_states,
            value=decoder_hidden_states,
            mask=decoder_mask
        )

        # Apply residual connection and layer normalization
        self_attention_output = self_attention_output + decoder_hidden_states
        self_attention_output = self.layer_norm_1(self_attention_output)

        # Encoder-decoder attention: queries come from the decoder, keys/values from the encoder
        encoder_decoder_attention_output = self.encoder_decoder_attention(
            query=self_attention_output,
            key=encoder_hidden_states,
            value=encoder_hidden_states,
            mask=encoder_decoder_mask
        )

        # Apply residual connection and layer normalization
        encoder_decoder_attention_output = encoder_decoder_attention_output + self_attention_output
        encoder_decoder_attention_output = self.layer_norm_2(encoder_decoder_attention_output)

        # Apply feed-forward layer
        output = self.feed_forward(encoder_decoder_attention_output)

        return output


In [100]:
# Run the sample embeddings through a single decoder layer; the same tensor
# stands in for both the decoder input and the encoder output.
config = ModelConfig()
attention_layer = TransformerDecoderLayer(config=config)
seq_length = token_embeddings.size(-2)
# The attention head hides positions where the mask equals 0, so a causal mask must be
# lower-triangular (ones on and below the diagonal): each token sees itself and the past.
# An upper-triangular mask would expose only the future and leave the last token with
# no visible position at all, which turns its softmax output into NaN.
attention_embeddings = attention_layer(token_embeddings, token_embeddings, torch.tril(torch.ones(seq_length, seq_length)))
print(token_embeddings.shape)
print(attention_embeddings.shape)

49
torch.Size([1, 49, 384])
torch.Size([1, 49, 384])


# GPT from scratch

In [None]:
import math
import torch
import torch.nn.functional as F

from typing import Tuple, Dict


class ModelConfig:
    """Plain container for the hyper-parameters of the GPT model.

    Args:
        bias: Whether the linear and layer-norm modules built from this config get a bias.
        casual: Whether self-attention applies the left-to-right mask
            (the spelling is part of the public interface).
        n_embd: Embedding width.
        n_head: Number of attention heads.
        dropout: Dropout probability.
        block_size: Maximum sequence length.
        n_layer: Number of transformer blocks.
        vocab_size: Number of entries in the vocabulary.
    """

    def __init__(
        self,
        bias: bool = False,
        casual: bool = True,
        n_embd: int = 384,
        n_head: int = 6,
        dropout: float = 0.2,
        block_size: int = 256,
        n_layer: int = 6,
        vocab_size: int = 65,
    ) -> None:
        # Every argument is stored unchanged under its own name
        self.bias, self.casual, self.n_embd, self.n_head = bias, casual, n_embd, n_head
        self.dropout, self.block_size, self.n_layer, self.vocab_size = dropout, block_size, n_layer, vocab_size


class SelfAttention(torch.nn.Module):
    """Multi-head self-attention with an optional left-to-right mask.

    A single linear layer produces query, key and value for all heads at once;
    the heads are then separated by reshaping.

    Args:
        config: Hyper-parameters; reads ``casual``, ``n_embd``, ``n_head``,
            ``bias``, ``dropout`` and ``block_size``.
    """

    def __init__(self, config: ModelConfig) -> None:
        super().__init__()
        # The embedding must divide evenly across the heads
        assert config.n_embd % config.n_head == 0

        # Plain settings kept on the module
        self.casual = config.casual  # masked (left-to-right) or bi-directional
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout

        # Fused projection: query, key and value of every head in one matmul
        self.c_attn = torch.nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Projection applied to the re-assembled head outputs
        self.c_proj = torch.nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # Regularization on the attention weights and on the block output
        self.attn_dropout = torch.nn.Dropout(config.dropout)
        self.resid_dropout = torch.nn.Dropout(config.dropout)

        if self.casual:
            # Lower-triangular ones: position i may only look at positions <= i.
            # Stored as a buffer named "bias" with two leading singleton dimensions.
            size = config.block_size
            lower_triangle = torch.tril(torch.ones(size, size))
            self.register_buffer("bias", lower_triangle.view(1, 1, size, size))

    def get_query_key_value_representations(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Step 1: project ``x`` and cut the result into query, key and value."""
        fused = self.c_attn(x)
        query, key, value = fused.split(self.n_embd, dim=2)
        return query, key, value

    def compute_dot_product(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        """Step 2: query-key dot products, scaled by ``1 / sqrt(size of k's last dim)``."""
        scale = 1.0 / math.sqrt(k.size(-1))
        raw_scores = q @ k.transpose(-2, -1)
        return raw_scores * scale

    def compute_attention_scores(self, scores: torch.Tensor) -> torch.Tensor:
        """Step 3: softmax over the last dimension."""
        return F.softmax(scores, dim=-1)

    def compute_weighted_sum(self, scores: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Step 4: combine the value vectors with the attention weights."""
        return scores @ v

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x`` of shape ``(batch, length, embedding)``; the shape is preserved."""
        n_batch, n_tokens, width = x.size()
        head_width = width // self.n_head

        def to_heads(tensor: torch.Tensor) -> torch.Tensor:
            # (batch, length, embedding) -> (batch, heads, length, embedding / heads)
            return tensor.view(n_batch, n_tokens, self.n_head, head_width).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.get_query_key_value_representations(x=x))

        # Manual implementation of attention
        att = self.compute_dot_product(q=q, k=k)
        if self.casual:
            visible = self.bias[:, :, :n_tokens, :n_tokens]
            att = att.masked_fill(visible == 0, float("-inf"))
        att = self.attn_dropout(self.compute_attention_scores(scores=att))
        y = self.compute_weighted_sum(scores=att, v=v)  # (batch, heads, length, embedding / heads)

        # Put the heads back side by side: -> (batch, length, embedding)
        y = y.transpose(1, 2).contiguous().view(n_batch, n_tokens, width)

        # Output projection followed by residual dropout
        return self.resid_dropout(self.c_proj(y))

# Let's get a self-attention layer and let's feed this layer with `token_embeddings`
# (default ModelConfig values are used)
config = ModelConfig()
attention_layer = SelfAttention(config=config)
attention_embeddings = attention_layer(token_embeddings)
# Print input and output shapes so they can be compared
print(token_embeddings.shape)
print(attention_embeddings.shape)



In [None]:
class LayerNorm(torch.nn.Module):
    """Layer normalization over the last dimension with an optional bias term.

    Args:
        ndim: Size of the normalized (last) dimension.
        bias: When false, no bias parameter is created and none is applied.
    """

    def __init__(self, ndim: int, bias: bool) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Normalize ``input`` with ``eps=1e-5`` using the stored weight and bias."""
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)


class MLP(torch.nn.Module):
    """Feed-forward part of a block: expand to 4x the width, GELU, project back, dropout.

    Args:
        config: Hyper-parameters; reads ``n_embd``, ``bias`` and ``dropout``.
    """

    def __init__(self, config: ModelConfig) -> None:
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = torch.nn.Linear(width, hidden, bias=config.bias)
        self.gelu = torch.nn.GELU()
        self.c_proj = torch.nn.Linear(hidden, width, bias=config.bias)
        self.dropout = torch.nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through the four stages in order."""
        for stage in (self.c_fc, self.gelu, self.c_proj, self.dropout):
            x = stage(x)
        return x

class Block(torch.nn.Module):
    """Transformer block: pre-norm self-attention and pre-norm MLP, each with a skip connection.

    Args:
        config: Hyper-parameters forwarded to the sub-modules.
    """

    def __init__(self, config: ModelConfig) -> None:
        super().__init__()
        # Creation order is kept: ln_1, attn, ln_2, mlp
        self.ln_1 = LayerNorm(config.n_embd, config.bias)
        self.attn = SelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, config.bias)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``h + mlp(ln_2(h))`` where ``h = x + attn(ln_1(x))``."""
        after_attention = x + self.attn(self.ln_1(x))
        return after_attention + self.mlp(self.ln_2(after_attention))

# Let's get a transformer block and let's feed it with `token_embeddings`
# (default ModelConfig values are used)
config = ModelConfig()
transformer_block = Block(config=config)
transformer_embeddings = transformer_block(token_embeddings)

# Show the shape of the block output
print(transformer_embeddings.shape)

In [None]:
class GPT(torch.nn.Module):
    """Decoder-only language model: token + position embeddings, a stack of
    ``Block`` modules, a final layer norm and a linear head over the vocabulary.
    """

    def __init__(self, config: ModelConfig) -> None:
        """Init a GPT model

        Args:
            config (ModelConfig): See ModelConfig's documentation.
        """
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        # Let's define the transformer module
        # wte: token embeddings, wpe: (learned) position embeddings, drop: embedding dropout,
        # h: the stack of blocks, ln_f: final layer norm
        self.transformer = torch.nn.ModuleDict(dict(wte = torch.nn.Embedding(config.vocab_size, config.n_embd),
                                                    wpe = torch.nn.Embedding(config.block_size, config.n_embd),
                                                    drop = torch.nn.Dropout(config.dropout),
                                                    h = torch.nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                                                    ln_f = LayerNorm(config.n_embd, bias=config.bias)))

        # Let's define the prediction head (to choose a token)
        self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight sharing between the weights of the embedding and softmax layers
        # Better explained here: https://paperswithcode.com/method/weight-tying
        # After this line both modules hold the very same Parameter object
        self.transformer.wte.weight = self.lm_head.weight

        # Init all weights
        self.apply(self._init_weights)

        # Apply special scaled init to the residual projections, per GPT-2 paper:
        # every parameter whose name ends with "c_proj.weight" is re-drawn with a
        # standard deviation divided by sqrt(2 * n_layer)
        for name, parameter in self.named_parameters():
            if name.endswith("c_proj.weight"):
                torch.nn.init.normal_(parameter,
                                      mean=0.0,
                                      std=0.02/math.sqrt(2*config.n_layer))

        # Display number of parameters
        print("Number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding: bool = True) -> int:
        """Return the number of parameters in the model.
           For non-embedding count (default), the position embeddings get subtracted.
           The token embeddings would too, except due to the parameter sharing these
           params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module: torch.nn.Module) -> None:
        """Init weights for Linear and Embedding modules

        Weights are drawn from a normal distribution (mean 0, std 0.02);
        Linear biases, when present, are set to zero.

        Args:
            module (torch.nn.Module): Model module
        """
        if isinstance(module, torch.nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, torch.nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def compute_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Cross-entropy between ``logits`` and ``targets`` after flattening all positions.

        Args:
            logits (torch.Tensor): Scores whose last dimension indexes the classes.
            targets (torch.Tensor): Expected class ids; entries equal to -1 are ignored.

        Returns:
            torch.Tensor: The loss returned by ``F.cross_entropy``.
        """
        reshaped_logits = logits.view(-1, logits.size(-1)) # Shape: (batch_size * max_length, vocabulary_size)
        reshaped_targets = targets.view(-1) # Shape: (batch_size * max_length)
        # Compute loss, cross_entropy between the probability distribution over classes and the expected target
        # Target values == -1 are ignored
        loss = F.cross_entropy(reshaped_logits, reshaped_targets, ignore_index=-1)
        return loss

    def forward(self, idx: torch.Tensor, targets: torch.Tensor=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Process text input and target and computes the loss.

        Args:
            idx (torch.Tensor): Text input (token ids). Input shape: batch_size x length. Input and model should be in the same device.
            targets (torch.Tensor, optional): Text target. Target shape: Same as input. Defaults to None.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Logits and loss. With targets, logits cover
            every position; without targets, logits cover only the last position and the
            loss is None.
        """
        # Get device
        device = idx.device

        # Check text length
        batch_size, text_length = idx.size()
        max_length_msg = f"Cannot forward sequence of length {text_length}, block size is only {self.config.block_size}"
        assert text_length <= self.config.block_size, max_length_msg

        # Create a positional vector
        pos = torch.arange(0, text_length, dtype=torch.long, device=device)

        # Transform the text input into text embeddings
        tok_emb = self.transformer.wte(idx) # Shapes: (batch_size, max_length) -> (batch_size, max_length, embedding_dim)

        # Transform the positional indices into positional embeddings
        pos_emb = self.transformer.wpe(pos) # Shapes: (max_length) -> (max_length, embedding_dim)

        # Merge text and positional information + apply dropout
        x = self.transformer.drop(tok_emb + pos_emb) # Shapes: (batch_size, max_length, embedding_dim) + (max_length, embedding_dim) -> (batch_size, max_length, embedding_dim)

        # Pass the text embedding through all the layers of the transformer
        for block in self.transformer.h:
            x = block(x) # Shape doesn't change: (batch_size, max_length, embedding_dim)
        # Normalize output
        x = self.transformer.ln_f(x) # Shape doesn't change: (batch_size, max_length, embedding_dim)

        # let's compute the loss
        if targets is not None:
            logits = self.lm_head(x) # Shapes: (batch_size, max_length, embedding_dim) -> (batch_size, max_length, vocabulary_size)
            loss = self.compute_loss(logits=logits, targets=targets)
        else:
            # Otherwise, inference-time mini-optimization: only forward the lm_head on the very last position
            last_token = x[:, [-1], :] # Shape: (batch_size, 1, embedding_dim)
            logits = self.lm_head(last_token) # Shape: (batch_size, 1, vocabulary_size)
            loss = None
        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        """Build an AdamW optimizer with weight decay applied only to tensors of 2+ dimensions.

        Args:
            weight_decay: Decay coefficient for the decayed group.
            learning_rate: Learning rate given to AdamW.
            betas: AdamW betas.
            device_type: The fused AdamW variant is requested only when this equals 'cuda'
                and the installed AdamW accepts a ``fused`` argument.

        Returns:
            The configured ``torch.optim.AdamW`` instance.
        """
        # Start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # Filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # Create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [{'params': decay_params, 'weight_decay': weight_decay},
                        {'params': nodecay_params, 'weight_decay': 0.0}]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    @torch.no_grad()
    def generate(self,
                 idx: torch.Tensor,
                 max_new_tokens: int,
                 temperature: float=1.0,
                 top_k: int = None) -> torch.Tensor:
        """Take a conditioning sequence of indices idx (LongTensor of shape (batch_size, length)) and complete
           the sequence max_new_tokens times, feeding the predictions back into the model each time.
           Most likely you'll want to make sure to be in model.eval() mode of operation for this.

        Args:
            idx (torch.Tensor): Conditioning sequence.
            max_new_tokens (int): Max number of generated tokens.
            temperature (float, optional): Softmax temperature. Defaults to 1.0.
            top_k (int, optional): For each prediction keep top_k tokens. Defaults to None.

        Returns:
            torch.Tensor: The conditioning sequence followed by the generated tokens.
        """
        for _ in range(max_new_tokens):
            # If the sequence context is growing too long (more than block_size), we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]

            # Forward the model to get the logits for the index in the sequence
            logits, _ = self.forward(idx=idx_cond)

            # Scale by desired temperature
            logits = logits[:, -1, :] / temperature # Shape: (batch_size, vocabulary_size)

            # Optionally, crop the logits to only the top k options
            if top_k is not None:
                top_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Smallest of the kept values, per row; everything below it is discarded
                min_top_value = top_values[:, [-1]]
                logits[logits < min_top_value] = -float("Inf")

            # Apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)

            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)

            # Append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

# Let's build the full GPT model and feed it with the token ids `torch_ids`
# (no targets, so the model returns logits for the last position only and a loss of None)
config = ModelConfig()
model = GPT(config=config)
outputs = model(idx=torch_ids, targets=None)
# outputs[0] holds the logits
print(outputs[0].shape)

In [None]:
# Dataloader serving random batches of the preprocessed (tokenized) text
class CharacterDataloader:
    """Sample random (input, target) windows from the train/val token files.

    The token ids are read from `./data/<dataset_name>/train.bin` and
    `./data/<dataset_name>/val.bin` through read-only uint16 memory maps.
    """

    def __init__(self,
                 dataset_name: str,
                 device: str,
                 device_type: str,
                 block_size: int,
                 batch_size: int) -> None:
        """Store the batch settings and memory-map both data splits.

        Args:
            dataset_name (str): Name of the sub-folder of `./data` holding the `.bin` files.
            device (str): Device the batches are moved to.
            device_type (str): Device family; "cuda" enables pinned, non-blocking transfers.
            block_size (int): Length of every sampled sequence.
            batch_size (int): Number of sequences per batch.
        """
        # Batch geometry
        self.batch_size = batch_size
        self.block_size = block_size

        # Where the batches must end up
        self.device_type = device_type
        self.device = device

        # Memory-map the train/val splits (train first, then val)
        self.directory = os.path.join("./data", dataset_name)
        train_path = os.path.join(self.directory, "train.bin")
        val_path = os.path.join(self.directory, "val.bin")
        self.train_data = np.memmap(train_path, dtype=np.uint16, mode="r")
        self.val_data = np.memmap(val_path, dtype=np.uint16, mode="r")

    def get_batch(self, split: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Draw one random batch from the requested split.

        Args:
            split (str): "train" selects the training data, anything else the validation data.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Inputs and targets of shape
                (batch_size, block_size), dtype int64, already on `self.device`.
                The targets are the inputs shifted one position to the right.
        """
        source = self.val_data if split != "train" else self.train_data

        # One random start offset per sequence of the batch
        offsets = torch.randint(len(source) - self.block_size, (self.batch_size,))

        def window(begin):
            # Copy `block_size` tokens starting at `begin` into an int64 tensor
            chunk = source[begin:begin + self.block_size]
            return torch.from_numpy(chunk.astype(np.int64))

        x = torch.stack([window(offset) for offset in offsets])
        y = torch.stack([window(offset + 1) for offset in offsets])

        if self.device_type != "cuda":
            return x.to(self.device), y.to(self.device)

        # Pinned memory lets the copies to the GPU run asynchronously (non_blocking=True)
        x = x.pin_memory().to(self.device, non_blocking=True)
        y = y.pin_memory().to(self.device, non_blocking=True)
        return x, y

# Gradient clipping helper used to keep training stable
def clip_gradients(value: float,
                   model: torch.nn.Module,
                   scaler: torch.cuda.amp.GradScaler,
                   optimizer: torch.optim.Optimizer) -> None:
    """Clip the global gradient norm of `model` to `value`.

    Args:
        value (float): Maximum gradient norm. A falsy value (0, 0.0, None) disables clipping.
        model (torch.nn.Module): Model whose gradients are clipped in place.
        scaler (torch.cuda.amp.GradScaler): Scaler used for the backward pass.
        optimizer (torch.optim.Optimizer): Optimizer holding the parameters to unscale.
    """
    # Nothing to do when clipping is disabled
    if not value:
        return

    # Gradients were produced from a scaled loss: bring them back to their true scale first
    scaler.unscale_(optimizer)

    # Rescale the gradients so that their global norm does not exceed `value`
    parameters = model.parameters()
    torch.nn.utils.clip_grad_norm_(parameters, max_norm=value)

# Estimate train/val losses
@torch.no_grad()
def estimate_loss(context: ContextManager,
                  model: torch.nn.Module,
                  dataloader: Iterable,
                  iterations: int) -> Dict:
    """Estimate the mean train and validation losses over a few random batches.

    The model is switched to eval mode for the estimation and is put back in the
    mode it was in when the function was called, even if a batch raises.

    Args:
        context (ContextManager): Context entered around each forward pass (e.g. autocast).
        model (torch.nn.Module): Model called as `model(inputs, targets)` and returning
            a `(logits, loss)` pair.
        dataloader (Iterable): Object exposing `get_batch(split=...)` that returns
            an `(inputs, targets)` pair.
        iterations (int): Number of batches averaged per split.

    Returns:
        Dict: Maps "train" and "val" to the mean loss (a 0-dim tensor) of that split.
    """
    outputs = {}

    # Remember the current mode so we don't force train mode on a model that was in eval mode
    was_training = model.training
    model.eval()
    try:
        for split in ("train", "val"):
            losses = torch.zeros(iterations)
            for k in range(iterations):
                text_input, text_output = dataloader.get_batch(split=split)
                with context:
                    _, loss = model(text_input, text_output)
                losses[k] = loss.item()
            outputs[split] = losses.mean()
    finally:
        # Restore the original mode, also when the estimation fails half-way
        model.train(was_training)
    return outputs

In [None]:
# Get dataloader
# NOTE(review): the device is hard-coded to the first CUDA GPU, so this cell needs a CUDA device to run
device = "cuda:0"
device_type = "cuda"
dataset_name = "shakespeare_char"
batch_size = 64
# Sequences sampled by the dataloader have the same length as the model context
block_size = config.block_size
dataloader = CharacterDataloader(dataset_name=dataset_name,
                                 device=device,
                                 device_type=device_type,
                                 block_size=block_size,
                                 batch_size=batch_size)

# Get data type: bfloat16 when the GPU supports it, float16 otherwise
dtype = "bfloat16" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "float16"

# Get model
model = GPT(config=config)
model.to(device)

# Initialize a GradScaler. Loss scaling is only enabled when training in float16.
scaler = torch.cuda.amp.GradScaler(enabled=(dtype=="float16"))

# Get context: forward passes run under autocast with the data type selected above
dtype_dict = {"bfloat16": torch.bfloat16, "float16": torch.float16}
context = torch.amp.autocast(device_type=device_type, dtype=dtype_dict[dtype])

# Get optimizer
weight_decay = 1e-1
learning_rate = 1e-3
beta1 = 0.9
beta2 = 0.99
optimizer = model.configure_optimizers(weight_decay=weight_decay,
                                       learning_rate=learning_rate,
                                       betas=(beta1, beta2),
                                       device_type=device_type)

# Compile the model to go faster
# NOTE(review): the name `compile` shadows the Python builtin of the same name for the rest of the notebook
compile = True
if compile:
    model = torch.compile(model)

# Create results folder
results_directory = "./results/shakespeare_char"
create_directory(path=results_directory)

# Train the model!
num_iterations = 5000
# Start from a very large value so that the first validation always saves a checkpoint
best_val_loss = 1e9
for iteration in range(num_iterations):
    # Run validation every 100 iterations (including iteration 0, before any update),
    # averaging the loss over 200 random batches per split
    if iteration % 100 == 0:
        losses = estimate_loss(context=context, model=model, dataloader=dataloader, iterations=200)
        print(f"Iteration {iteration}: Train loss {losses['train']:.4f}, Val loss {losses['val']:.4f}")
        # Keep only the checkpoint with the lowest validation loss seen so far (the file is overwritten)
        if losses["val"] < best_val_loss:
            best_val_loss = losses['val']
            # NOTE: when `compile` is True, `model` is the compiled wrapper returned by torch.compile
            checkpoint = {"model": model.state_dict(),
                          "best_val_loss": best_val_loss,
                          "config": config}
            print("Saving checkpoint to {}".format(results_directory))
            torch.save(checkpoint, os.path.join(results_directory, "ckpt.pt"))

    # Run training: forward pass on one random training batch under autocast
    X, Y = dataloader.get_batch(split="train")
    with context:
        logits, loss = model(X, Y)
    # Backward pass on the scaled loss
    scaler.scale(loss).backward()
    # Unscale the gradients and clip their global norm to 1.0
    clip_gradients(value=1.0,
                   model=model,
                   scaler=scaler,
                   optimizer=optimizer)
    # Optimizer step goes through the scaler, which then updates its scale factor
    scaler.step(optimizer)
    scaler.update()
    # Reset the gradients before the next iteration
    optimizer.zero_grad(set_to_none=True)


In [None]:
# Load the best model
device = "cuda:0"
results_directory = "./results/shakespeare_char"
checkpoint_path = os.path.join(results_directory, "ckpt.pt")

# The checkpoint stores a pickled `config` object next to the weights, so it cannot be read
# by the restricted weights-only unpickler that recent PyTorch releases use by default.
# NOTE: full unpickling can execute arbitrary code - only load checkpoints you created yourself.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
model = GPT(config=checkpoint["config"])

# Strip the "_orig_mod." prefix from the parameter names (when present) so that they
# match the names of a plain, non-compiled GPT instance
state_dict = checkpoint["model"]
unwanted_prefix = "_orig_mod."
for k in list(state_dict):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)

# Inference only from here on: switch to eval mode and move the model to the target device
model.eval()
model.to(device)

In [None]:
# Get context
# Derive the autocast device type from `device` so that this cell does not depend on the training cell
device_type = torch.device(device).type
dtype = "bfloat16" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "float16"
dtype_dict = {"bfloat16": torch.bfloat16, "float16": torch.float16}
context = torch.amp.autocast(device_type=device_type, dtype=dtype_dict[dtype])

# Define a start: encode the prompt and add a leading batch dimension of size 1
start = "\n stay away"
start = encode(string=start, char_to_int=char_to_int)
torch_start = torch.tensor(start, dtype=torch.long, device=device)[None, ...]

# Define generation parameters
num_samples = 2
temperature = 0.8
top_k = 200
max_new_tokens = 500

# Generate `num_samples` independent continuations of the same prompt, without tracking gradients
with torch.no_grad():
    with context:
        for k in range(num_samples):
            print("=================== Start sample {} ===================".format(k))
            y = model.generate(torch_start, max_new_tokens, temperature=temperature, top_k=top_k)
            # Only one sequence is in the batch: decode it back to text
            generated_text = decode(integers=y[0].tolist(), int_to_char=int_to_char)
            print(generated_text)
            print("=================== End sample {} =================== \n".format(k))
