In [None]:
# Sanity check that the kernel executes cells.
print("Hello world!")

In [None]:
# TODO:
# 1. done.
# 2. use multiple blocks sequentially ==> done.
# 3. rename to transformer ==> done. (TransformerLanguageModel)
# 4. add tqdm ==> done.
# 5. use full dataset ==> done.
# 6. make model bigger ==> done.
# 7. find out good batch size (for full gpu) ==> done.
# 8. find training time -> steps ==> done. (8.30min per 1k steps)
# 9. train

In [None]:
from __future__ import annotations
import typing

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm

# Seed every RNG the notebook draws from. torch covers weight init and
# torch.multinomial sampling; numpy's global RNG (np.random) has independent
# state and must be seeded separately for runs to be reproducible.
SEED = 3654
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
def count_trainable_parameters(model):
    """Return the total number of scalar parameters of `model` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Trainable parameter count | training loss after 5k steps, per model variant:

# --- small 10k dataset ---
# bigram model: 3844 ==> loss to 2.5
# with one multihead attention (5 heads): 53262 ==> loss to 0.35
# with one multihead attention (5 heads) and residual connection: 53262 ==> loss to 0.30
# with one full block: 134162 ==> loss to 0.24
# with 5 blocks: 617762 ==> loss to 0.12

# --- large dataset --- 
# bigger model: 5015318 (5M) ==> loss to 0.0116

### Get text data

In [None]:
# Read the whole training corpus into one string.
# (Slice it, e.g. [:10_000], for quick experiments on a small subset.)
with open("text_corpus.txt", mode="r", encoding="utf-8") as corpus_file:
    text = corpus_file.read()

In [None]:
# Corpus size in characters.
len(text)

In [None]:
# Vocabulary: the distinct characters of the corpus, in sorted order.
vocab = sorted(set(text))
print(vocab)

In [None]:
# Number of distinct characters (= number of token ids).
vocab_size = len(vocab)
vocab_size

### Encode the characters to integers

In [None]:
# Create encoder and decoder dicts.
# Each character's id is its rank in the sorted set of distinct corpus
# characters; int_char_mapping is the exact inverse of char_int_mapping.
char_int_mapping = {ch: idx for idx, ch in enumerate(sorted(set(text)))}
int_char_mapping = {idx: ch for ch, idx in char_int_mapping.items()}

In [None]:
# examples: the id of the character "g", and the character with id 42
# (each lookup raises KeyError if the key is absent from the vocabulary)
print(char_int_mapping["g"])
print(int_char_mapping[42])

In [None]:
def encode(string: str) -> list[int]:
    """Convert ``string`` into the list of its character ids.

    Raises:
        KeyError: if ``string`` contains a character that is not in the
            vocabulary (``char_int_mapping``).
    """
    try:
        return [char_int_mapping[char] for char in string]
    except KeyError as err:
        # re-raise with a readable message instead of just the missing character
        raise KeyError(f"character {err.args[0]!r} is not in the vocabulary") from None


def decode(int_list: list[int]) -> str:
    """Convert a sequence of character ids back into a string (inverse of ``encode``)."""
    return "".join(int_char_mapping[num] for num in int_list)

In [None]:
# examples: encode a word, then decode the ids back to text.
# Round-tripping avoids hard-coding ids, which depend on the corpus's character set.
print(encode("hellooo"))
print(decode(encode("hellooo")))

In [None]:
# Encode a sample string (raises KeyError for characters outside the vocabulary).
encode("hello world")

### Make it a tensor

In [None]:
# Encode the whole corpus into one 1-D tensor of integer token ids.
data = torch.tensor(encode(text), dtype=torch.long)
data.shape

### Train/Test split

In [None]:
# Number of tokens in the encoded corpus.
len(data)

In [None]:
# Hold out the final 10% of the token stream as the test split
# (a contiguous split, so no text leaks between train and test).
N = int(0.9 * len(data))
train_data, test_data = data[:N], data[N:]

print(len(train_data), len(test_data))

### Create minibatches

In [None]:
# block_size: maximum context length (number of tokens in one input block).
# batch_size: number of independent blocks sampled per batch.
block_size, batch_size = 8, 4

In [None]:
# The first block_size tokens of the training split.
train_data[:block_size]

In [None]:
def get_batch(split: torch.Tensor):
    """Sample a random batch of (input, target) blocks from a 1-D token tensor.

    Reads the notebook globals ``batch_size``, ``block_size`` and ``device`` at
    call time. Offsets are drawn with ``torch.randint`` so ``torch.manual_seed``
    makes batch sampling reproducible.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, block_size)`` on ``device``,
        where ``y`` is ``x`` shifted one token to the right (next-token targets).

    Raises:
        ValueError: if ``split`` is too short to hold one block plus its target.
    """
    # y needs one token past the block, so the largest valid offset is
    # len(split) - block_size - 1; randint's upper bound is exclusive.
    max_offset = len(split) - block_size
    if max_offset <= 0:
        raise ValueError(
            f"split has {len(split)} tokens; need more than block_size={block_size}"
        )
    offsets = torch.randint(0, max_offset, (batch_size,)).tolist()
    x = torch.stack([split[i:i + block_size] for i in offsets]).to(device)
    y = torch.stack([split[i + 1:i + block_size + 1] for i in offsets]).to(device)
    return x, y

In [None]:
# Sample one (input, target) batch from the training split and show the inputs.
xb, yb = get_batch(train_data)
xb

In [None]:
# Targets: the inputs shifted one token to the right.
yb

In [None]:
# Decode the sequence at batch index 3 back to text (requires batch_size >= 4).
decode(xb[3].tolist())

In [None]:
# For every sequence in the batch, list each (context -> next-token target)
# training pair: first as decoded text, then as raw token ids.
for b in range(batch_size):
    print(f"----- BATCH {b} -----")
    pairs = [(xb[b][:t+1], yb[b][t]) for t in range(block_size)]

    for context, target in pairs:
        print(f"context: {decode(context.tolist())} -> target: {decode([int(target)])}")
    print()

    for context, target in pairs:
        print(f"context: {context.tolist()} -> target: {int(target)}")
    print()

### Create a simple model

In [None]:
class BigramLanguageModel(nn.Module):
    """Character-level bigram model: each position's logits depend only on its token.

    The embedding table is (vocab_size, vocab_size): row i holds the
    next-token logits for token i.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, context, targets=None):
        """Return ``(logits, loss)`` for a batch of token ids.

        Args:
            context: (B, T) tensor of token ids.
            targets: optional (B, T) tensor of next-token ids.

        Returns:
            Without targets: logits of shape (B, T, vocab_size) and ``None``.
            With targets: logits flattened to (B*T, vocab_size) and the mean
            cross-entropy loss.
        """
        logits = self.token_embedding_table(context)  # (B, T, C)
        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, context, max_new_tokens):
        """Extend ``context`` (B, T) by sampling ``max_new_tokens`` tokens per row.

        Returns:
            (B, T + max_new_tokens) tensor of token ids.
        """
        for _ in range(max_new_tokens):
            # A bigram model only needs the last token of the *given* context.
            logits, _ = self(context[:, -1:])  # (B, 1, C)
            logits = logits[:, -1, :]  # (B, C)
            probs = F.softmax(logits, dim=-1)  # (B, C)
            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append next token to the sequence
            context = torch.cat((context, next_token), dim=1)  # (B, T+1)
        return context

    def generate_to_text(self, context, max_new_tokens):
        """Generate tokens and decode the first sequence of the batch to text."""
        context = self.generate(context, max_new_tokens)
        return decode(context[0].tolist())

# Instantiate the bigram model on the selected device and display its layers.
model = BigramLanguageModel(vocab_size).to(device)
model

In [None]:
# vocab_size * vocab_size parameters: the single embedding table.
count_trainable_parameters(model)

### Train the simple model

In [None]:
# Train the bigram model with Adam for 5k steps.
# NOTE: batch_size is a notebook global read by get_batch; this rebinds it to 32.
batch_size = 32
losses = []
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(5_000):
    xb, yb = get_batch(train_data)         # fresh random training batch
    logits, loss = model(xb, yb)           # forward pass + loss

    optimizer.zero_grad(set_to_none=True)  # drop gradients from the previous step
    loss.backward()                        # backpropagate
    optimizer.step()                       # update parameters

    losses.append(loss.item())

In [None]:
# Per-step training loss of the bigram model.
plt.plot(losses)
plt.xlabel("training step")
plt.ylabel("cross-entropy loss")
plt.title("Bigram model training loss");

In [None]:
# inference: start every sequence in the batch from token id 0.
# Embedding lookups need integer (long) ids; torch.zeros defaults to float32.
zero_context = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
model.generate_to_text(zero_context, max_new_tokens=200)

### Attention

In [None]:
# Toy inputs of shape (5, 1); attention() treats each row as one position.
# Drawn from the torch RNG in this order, so later random values depend on it.
Q = torch.rand(5,1)
K = torch.rand(5,1)
V = torch.rand(5,1)
Q,K,V

In [None]:
def attention(Q, K, V):
    """Apply causal (masked) scaled dot-product attention.

    Args:
        Q, K: (T, d_k) tensors of queries and keys, one row per position.
        V: (T, d_v) tensor of values.

    Returns:
        (T, d_v) tensor; row t is a softmax-weighted average of V[0..t].

    Scores are scaled by sqrt(d_k), the query/key feature dimension, as in
    "Attention Is All You Need" -- not by the number of positions T.
    """
    T = Q.shape[0]
    d_k = K.shape[-1]
    W = (Q @ K.T) / d_k ** 0.5  # (T, T)

    # mask out forbidden connections: position t may only attend to positions <= t
    tril = torch.tril(torch.ones((T, T), device=Q.device))
    W = W.masked_fill(tril == 0, float("-inf"))

    W = F.softmax(W, dim=-1)  # normalise over keys
    return W @ V

In [None]:
# Causal attention on the toy (5, 1) inputs.
attention(Q,K,V)

### Multi-Head Attention

In [None]:
def multi_head_attention(Q, K, V, d_model=8, n_heads=2):
    """Toy multi-head attention over column vectors Q, K, V of shape (n, 1).

    NOTE: the projection and output layers are created fresh (randomly
    initialised, untrained) on every call. This is a shape walkthrough, not a
    learnable module; repeated calls give different results unless the torch
    RNG is reseeded.

    Each head projects the n-vectors to d_model dims and runs attention() on the
    projected (d_model, 1) columns. The n_heads outputs are concatenated and
    projected back to n dims.

    Args:
        Q, K, V: (n, 1) tensors.
        d_model: size each head projects into.
        n_heads: number of heads.

    Returns:
        (n, 1) tensor.
    """
    n = Q.shape[0]  # plain int: nn.Linear expects integer feature sizes

    # linear layers: one independent projection per (input, head)
    projections = {
        x: {h: nn.Linear(n, d_model, bias=False) for h in range(n_heads)}
        for x in ["Q", "K", "V"]
    }

    # layer to combine the concatenated attention-head outputs
    top_layer = nn.Linear(n_heads * d_model, n, bias=False)

    # forward pass
    result = torch.zeros(n_heads, d_model, 1)
    for h in range(n_heads):
        result[h] = attention(
            projections["Q"][h](Q.T).T,
            projections["K"][h](K.T).T,
            projections["V"][h](V.T).T,
        )

    concat_attn_out = result.view(1, n_heads * d_model)
    return top_layer(concat_attn_out).T

In [None]:
# Toy multi-head attention. Its layers are randomly initialised on every call,
# so the output changes between runs unless the RNG is reseeded.
multi_head_attention(Q,K,V)

In [None]:
# Step-by-step walkthrough of multi_head_attention(), using globals so each
# intermediate can be inspected. Here d_k is Q.shape[0] (the number of rows
# of Q), stored as a 0-dim tensor.
d_k = torch.tensor(Q.shape[0])
d_model = 8 # project into this space
N_heads = 2

# One randomly initialised projection per (input, head); creation order fixes
# which random values each layer gets.
projections = {
    x: {
        h: nn.Linear(d_k, d_model, bias=False) for h in range(N_heads)
    } for x in ["Q", "K", "V"]
}

projections

In [None]:
# Head-0 query projection: Linear(d_k, d_model) applied to the row vector Q.T,
# then transposed back into a column vector.
projections["Q"][0](Q.T).T

In [None]:
# Single-head attention on the unprojected inputs, for comparison.
attention(Q,K,V)

In [None]:
# Run every head on its own Q/K/V projections and stack the outputs
# into one (N_heads, d_model, 1) tensor.
result = torch.zeros(N_heads, d_model, 1)

for h in range(N_heads):
    q_h = projections["Q"][h](Q.T).T
    k_h = projections["K"][h](K.T).T
    v_h = projections["V"][h](V.T).T
    result[h] = attention(q_h, k_h, v_h)

In [None]:
# Per-head outputs, shape (N_heads, d_model, 1).
result

In [None]:
# Concatenate the head outputs into one (1, N_heads * d_model) row vector.
concat_attn_out = result.view(1, N_heads * d_model)
concat_attn_out

In [None]:
# Output projection back to d_k features (randomly initialised, consumes torch RNG).
top_layer = nn.Linear(N_heads * d_model, d_k, bias=False)

In [None]:
# Shape of the combined output, transposed back into a column.
top_layer(concat_attn_out).T.shape

### Masking

In [None]:
# Causal mask: ones on and below the diagonal, zeros above it.
T = 10
tril = torch.tril(torch.ones((T,T)))
tril

In [None]:
W = torch.rand((T,T)) # stand-in for real attention scores

# mask out forbidden connections
W = W.masked_fill(tril==0, float("-inf")) # set everywhere where tril is 0 to -inf (upper right)

# softmax over the last dim: each row sums to 1 and masked (-inf) entries become 0
W = F.softmax(W, dim=-1)
plt.imshow(W)
plt.colorbar()
plt.xlabel("key position")
plt.ylabel("query position")
plt.title("Causal attention weights");

### Positional encoding

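- learned: positions use a trainable `nn.Embedding` table (one vector per position up to `block_size`) instead of the fixed sinusoidal encoding of the original paper.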

### Transformer

In [None]:
# hyperparameters
training_steps = 5000
embed_dims = 128  # model width; equivalent to d_model in the paper
block_size = 256  # maximum context length
batch_size = 128
n_heads = 8
n_layers = 25

# Each head gets an equal slice of the model width: MultiHeadAttention
# concatenates n_heads outputs of size head_size and expects embed_dims back.
if embed_dims % n_heads != 0:
    raise ValueError(f"embed_dims ({embed_dims}) must be divisible by n_heads ({n_heads})")
head_size = embed_dims // n_heads

In [None]:
class SelfAttentionHead(nn.Module):
    """One causal self-attention head.

    Sizes come from the notebook globals ``embed_dims``, ``head_size`` and
    ``block_size``, read at construction time.
    """

    def __init__(self):
        super().__init__()
        self.proj_q = nn.Linear(embed_dims, head_size, bias=False)
        self.proj_k = nn.Linear(embed_dims, head_size, bias=False)
        self.proj_v = nn.Linear(embed_dims, head_size, bias=False)
        # Causal mask, built once. A buffer follows the module across
        # .to(device); persistent=False keeps it out of state_dict, so
        # checkpoints saved without it still load.
        self.register_buffer(
            "tril", torch.tril(torch.ones(block_size, block_size)), persistent=False
        )

    def forward(self, x):
        """Apply masked scaled dot-product attention to ``x``.

        Args:
            x: (B, T, embed_dims) tensor with T <= block_size.

        Returns:
            (B, T, head_size) tensor; position t only mixes values from positions <= t.
        """
        B, T, C = x.shape

        Q = self.proj_q(x)  # (B, T, head_size)
        K = self.proj_k(x)
        V = self.proj_v(x)

        W = Q @ K.transpose(-1, -2)  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        W = W * K.shape[-1] ** -0.5  # scale by 1 / sqrt(head_size)

        # mask out forbidden connections (slice so it fits when T < block_size)
        W = W.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        # Normalise over the KEY axis (last dim). Normalising over dim=1 (the
        # query axis) would make every weight depend on later queries, leaking
        # information from future tokens.
        W = F.softmax(W, dim=-1)
        return W @ V  # (B, T, head_size)

In [None]:
class MultiHeadAttention(nn.Module):
    """Run n_heads causal self-attention heads in parallel and mix their outputs.

    Each head yields head_size features. The concatenation is
    n_heads * head_size == embed_dims wide and is projected back to embed_dims.
    """

    def __init__(self):
        super().__init__()
        head_list = [SelfAttentionHead() for _ in range(n_heads)]
        self.heads = nn.ModuleList(head_list)
        self.proj = nn.Linear(embed_dims, embed_dims, bias=False)

    def forward(self, x):
        head_outputs = [head(x) for head in self.heads]
        concatenated = torch.cat(head_outputs, dim=-1)  # heads side by side on the last dim
        return self.proj(concatenated)

In [None]:
class Block(nn.Module):
    """Pre-LayerNorm transformer block.

    Causal multi-head attention, then a position-wise feed-forward network,
    each wrapped in a residual connection.
    """

    def __init__(self):
        super().__init__()
        self.attn = MultiHeadAttention()
        self.ln1 = nn.LayerNorm(embed_dims)
        self.ln2 = nn.LayerNorm(embed_dims)

        # Position-wise FFN: Linear -> ReLU -> Linear with a 4x hidden width,
        # following "Attention Is All You Need". There is deliberately no
        # activation after the second Linear: a trailing ReLU would force the
        # residual update to be non-negative. The Linear layers keep Sequential
        # indices 0 and 2, so state_dict keys are unchanged.
        self.mlp = nn.Sequential(
            nn.Linear(embed_dims, 4 * embed_dims),
            nn.ReLU(),
            nn.Linear(4 * embed_dims, embed_dims),
        )

    def forward(self, x):
        """Map a (B, T, embed_dims) tensor to a tensor of the same shape."""
        # LayerNorm is applied *before* each sub-layer (pre-LN). The original
        # paper applies it afterwards, but pre-LN trains more stably:
        # https://arxiv.org/pdf/2002.04745.pdf
        x = x + self.attn(self.ln1(x))  # (B, T, embed_dims)
        x = x + self.mlp(self.ln2(x))
        return x

In [None]:
class TransformerLanguageModel(nn.Module):
    """Decoder-only character language model built from transformer Blocks.

    Sizes come from the notebook globals ``vocab_size``, ``embed_dims``,
    ``block_size`` and ``n_layers``, read at construction time.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, embed_dims)

        # learned positional encoding: one vector per position < block_size
        self.pos_embedding_table = nn.Embedding(block_size, embed_dims)

        # transformer layers
        self.blocks = nn.Sequential(*[Block() for _ in range(n_layers)])

        # output layer: per-position logits over the vocabulary
        self.lm_head = nn.Linear(embed_dims, vocab_size)

    def forward(self, context, targets=None):
        """Compute next-token logits, and the loss when targets are given.

        Args:
            context: (B, T) tensor of token ids, with T <= block_size.
            targets: optional (B, T) tensor of next-token ids.

        Returns:
            ``(logits, loss)``. Without targets, logits are (B, T, vocab_size)
            and loss is None. With targets, logits are flattened to
            (B*T, vocab_size) and loss is the mean cross-entropy.

        Raises:
            ValueError: if T exceeds the number of learned positions.
        """
        B, T = context.shape
        max_len = self.pos_embedding_table.num_embeddings
        if T > max_len:
            raise ValueError(f"context length {T} exceeds block_size {max_len}")

        # token embeddings plus learned positional embeddings
        token_emb = self.token_embedding_table(context)  # (B, T, C)
        # positions are created on the input's device, so no global device is needed
        pos_emb = self.pos_embedding_table(torch.arange(T, device=context.device))  # (T, C)
        x = token_emb + pos_emb  # broadcasts over the batch

        # transformer forward pass
        x = self.blocks(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, context, max_new_tokens):
        """Extend ``context`` (B, T) by sampling ``max_new_tokens`` tokens per row.

        Only the last block_size tokens are fed to the model at each step,
        because the positional table has no entries beyond that.

        Returns:
            (B, T + max_new_tokens) tensor of token ids.
        """
        max_len = self.pos_embedding_table.num_embeddings
        for _ in range(max_new_tokens):
            logits, _ = self(context[:, -max_len:])  # (B, T', C)
            logits = logits[:, -1, :]  # (B, C): prediction for the next token
            probs = F.softmax(logits, dim=-1)  # (B, C)
            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append next token to the sequence
            context = torch.cat((context, next_token), dim=1)  # (B, T+1)
        return context

    def generate_to_text(self, context, max_new_tokens):
        """Generate tokens and decode the first sequence of the batch to text."""
        context = self.generate(context, max_new_tokens)
        return decode(context[0].tolist())

# Instantiate the transformer from the hyperparameter globals and move it to the device.
model = TransformerLanguageModel().to(device)

In [None]:
# Number of trainable parameters in the transformer.
count_trainable_parameters(model)

In [None]:
# Train the transformer with Adam for `training_steps` steps, logging each loss.
losses = []
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # 3e-4 noted as an alternative lr

progress = tqdm(range(training_steps))
for step in progress:
    xb, yb = get_batch(train_data)         # fresh random training batch
    logits, loss = model(xb, yb)           # forward pass + loss

    optimizer.zero_grad(set_to_none=True)  # drop gradients from the previous step
    loss.backward()                        # backpropagate
    optimizer.step()                       # update parameters

    losses.append(loss.item())

In [None]:
# Per-step training loss of the transformer.
plt.plot(losses)
plt.xlabel("training step")
plt.ylabel("cross-entropy loss")
plt.title("Transformer training loss");

In [None]:
# Mean training loss over the last 50 steps (less noisy than the final step).
np.mean(losses[-50:])

In [None]:
# inference: continue a random excerpt of the training data.
# zero_context is an alternative seed (token id 0 in every row). Embedding
# lookups need integer ids, hence dtype=torch.long (torch.zeros defaults to float32).
zero_context = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
xb, yb = get_batch(train_data)
model.generate_to_text(xb, max_new_tokens=200)

In [None]:
def generate_response(model, prompt=None, max_new_tokens=200):
    """Sample a continuation of ``prompt`` from ``model`` and return it.

    Args:
        model: object with ``generate_to_text(context, max_new_tokens)``.
        prompt: text to continue; every character must be in the vocabulary.
            ``None`` or ``""`` starts from a single token id 0 instead.
        max_new_tokens: number of tokens to sample.

    Returns:
        Whatever ``model.generate_to_text`` returns. For the models in this
        notebook, that is the decoded first sequence of the batch, prompt included.
    """
    if not prompt:
        # zero context: token id 0 in every row. Ids must be integer (long) for
        # the embedding lookup; torch.zeros defaults to float32. An empty prompt
        # would otherwise give a length-0 context with no last position to predict from.
        prompt_tensor = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
    else:
        # convert prompt to a batched tensor: the same prompt in every row
        prompt_tensor = torch.tensor(encode(prompt), dtype=torch.long, device=device)
        prompt_tensor = prompt_tensor.repeat(batch_size, 1)
    return model.generate_to_text(prompt_tensor, max_new_tokens)

In [None]:
# Sample a 200-token continuation of a fixed prompt from `model`.
prompt = "Lex, do you think the "
generate_response(model, prompt, max_new_tokens=200)

In [None]:
# torch.save(model.state_dict(), "5M_1k_steps")

In [None]:
# Reload a checkpoint written with torch.save(model.state_dict(), "5M_1k_steps").
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust (recent torch versions also accept weights_only=True to restrict this).
model2 = TransformerLanguageModel().to(device)
model2.load_state_dict(torch.load("5M_1k_steps", map_location=device))
model2.eval();

In [None]:
# Same prompt, sampled from the reloaded checkpoint model2.
prompt = "Lex, do you think the "
generate_response(model2, prompt, max_new_tokens=200)