In [1]:
from __future__ import annotations
import typing

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

import os
from tqdm import tqdm

# Seed every RNG the notebook draws from. torch covers weight init and
# torch.multinomial sampling; numpy's global RNG is seeded too because
# torch.manual_seed does not affect np.random.
SEED = 3654
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fa39e7f7150>

In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

'cuda'

In [3]:
def count_trainable_parameters(model):
    """Return the total number of scalar parameters in ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

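# Trainable-parameter counts and the training loss reached after 5k training steps:
# NOTE(review): these losses were most likely recorded while SelfAttentionHead
# normalized attention with softmax over dim=1 (the query axis). That lets later
# tokens influence earlier positions, so the numbers are probably optimistic.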

# --- small 10k dataset ---
# bigram model: 3844 ==> loss to 2.5
# with one multihead attention (5 heads): 53262 ==> loss to 0.35
# with one multihead attention (5 heads) and residual connection: 53262 ==> loss to 0.30
# with one full block: 134162 ==> loss to 0.24
# with 5 blocks: 617762 ==> loss to 0.12

# --- large dataset --- 
# bigger model: 5015318 (5M) ==> loss to 0.0116

### Get text data

In [4]:
# Load the full corpus; slice the read() result with [:10_000] for the small-dataset experiments.
with open("text_corpus.txt", mode="r", encoding="utf-8") as corpus_file:
    text = corpus_file.read()

In [5]:
# Corpus size in characters.
len(text)

38739496

In [6]:
# Character-level vocabulary: every distinct character of the corpus, in sorted order.
vocab = sorted(frozenset(text))
print(vocab)

['\n', ' ', '!', '"', '$', '%', '&', "'", '+', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', '>', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '´', 'É', 'Ö', 'Ü', 'á', 'ã', 'å', 'ç', 'è', 'é', 'ê', 'í', 'ó', 'ô', 'ö', 'ø', 'ù', 'ü', 'ć', 'ł', 'ő', 'Б', 'В', 'Г', 'Д', 'К', 'Н', 'П', 'Р', 'С', 'Т', 'У', 'Я', 'а', 'б', 'в', 'г', 'д', 'е', 'ж', 'з', 'и', 'й', 'к', 'л', 'м', 'н', 'о', 'п', 'р', 'с', 'т', 'у', 'х', 'ц', 'ч', 'ш', 'щ', 'ъ', 'ы', 'ь', 'э', 'ю', 'я', 'ё', '–', '—', '…', '가', '들', '어']


In [7]:
# Number of distinct characters; sizes the token embedding table and the output layer.
vocab_size = len(vocab)
vocab_size

150

### Encode the characters to integers

In [8]:
# Encoder (char -> id) and decoder (id -> char) tables. Ids are positions in the
# sorted `vocab`, so reuse it instead of re-scanning the whole corpus with
# sorted(set(text)) a second time.
char_int_mapping = {char: idx for idx, char in enumerate(vocab)}
int_char_mapping = {idx: char for idx, char in enumerate(vocab)}

In [9]:
def encode(string: str) -> list[int]:
    """Map each character of ``string`` to its integer id via ``char_int_mapping``.

    Raises:
        KeyError: if ``string`` contains a character that is not in the vocabulary.
    """
    return [char_int_mapping[char] for char in string]

def decode(int_list: list[int]) -> str:
    """Map a sequence of integer ids back to text via ``int_char_mapping`` (inverse of ``encode``)."""
    return "".join(int_char_mapping[num] for num in int_list)

### Make it a tensor

In [10]:
# Encode the entire corpus once; long dtype so the ids can index nn.Embedding
# and serve as cross_entropy targets.
data = torch.as_tensor(encode(text), dtype=torch.long)
data.shape

torch.Size([38739496])

### Train/Test split

In [11]:
# Hold out the last 10% of the corpus; no shuffling, so the test text is a contiguous tail.
N = int(len(data) * 0.9)
train_data, test_data = data[:N], data[N:]

print(len(train_data), len(test_data))

34865546 3873950


### Create minibatches

In [12]:
# Toy sizes for inspecting batches; the hyperparameter cell further down overrides both.
batch_size = 4
block_size = 8  # maximum context length, i.e. characters per input block

In [13]:
def get_batch(split: torch.Tensor):
    """Sample a random minibatch of (input, next-character target) blocks.

    Reads the notebook globals ``batch_size``, ``block_size`` and ``device`` at
    call time, so the hyperparameter cell's values apply.

    Args:
        split: 1-D tensor of token ids, e.g. ``train_data`` or ``test_data``.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, block_size)`` on ``device``;
        ``y`` is ``x`` shifted one position to the right.
    """
    # Starts are drawn from [0, len - block_size) so y's last index (i + block_size)
    # stays in range. torch.randint (rather than np.random) makes the sampling
    # follow torch.manual_seed.
    offsets = torch.randint(len(split) - block_size, (batch_size,))
    x = torch.stack([split[i : i + block_size] for i in offsets])
    y = torch.stack([split[i + 1 : i + block_size + 1] for i in offsets])
    return x.to(device), y.to(device)

### Transformer

In [14]:
# Hyperparameters (these override the toy block_size/batch_size set earlier).
training_steps = 15000
embed_dims = 256  # model width, d_model in "Attention Is All You Need"
block_size = 256  # maximum context length in characters
batch_size = 64
n_heads = 8
n_layers = 25

# Each head gets an equal slice of the model width, so the concatenated heads are
# exactly embed_dims wide.
assert embed_dims % n_heads == 0, "embed_dims must be divisible by n_heads"
head_size = embed_dims // n_heads

In [15]:
class SelfAttentionHead(nn.Module):
    """One head of causal (masked) scaled dot-product self-attention.

    Projects inputs of width ``embed_dims`` to queries, keys and values of width
    ``head_size``; both are read from the notebook's hyperparameter globals.
    """

    def __init__(self):
        super().__init__()
        self.proj_q = nn.Linear(embed_dims, head_size, bias=False)
        self.proj_k = nn.Linear(embed_dims, head_size, bias=False)
        self.proj_v = nn.Linear(embed_dims, head_size, bias=False)
        # 1/sqrt(d_k) scaling from "Attention Is All You Need".
        self.scale = head_size ** -0.5
        # Causal mask built once. As a buffer it follows module.to(device);
        # persistent=False keeps it out of state_dict so existing checkpoints still load.
        self.register_buffer(
            "tril", torch.tril(torch.ones(block_size, block_size)), persistent=False
        )

    def forward(self, x):
        """Apply masked scaled dot-product attention to ``x``.

        Args:
            x: (B, T, embed_dims) tensor with T <= block_size.

        Returns:
            (B, T, head_size) tensor in which position t only mixes values from positions <= t.
        """
        T = x.shape[1]

        Q = self.proj_q(x)  # (B, T, head_size)
        K = self.proj_k(x)
        V = self.proj_v(x)

        W = (Q @ K.transpose(-1, -2)) * self.scale  # (B, T, T), indexed [batch, query, key]
        # Forbid attending to later positions; slice the mask for contexts shorter than block_size.
        W = W.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        # Normalize over the key axis so each query's weights sum to 1. Normalizing over
        # dim=1 (queries) would put later queries' scores into the denominator and leak
        # future tokens into earlier outputs.
        W = F.softmax(W, dim=-1)
        return W @ V

In [16]:
class MultiHeadAttention(nn.Module):
    """Runs ``n_heads`` causal self-attention heads side by side and mixes their outputs."""

    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttentionHead() for _ in range(n_heads)])
        # Heads are concatenated on the feature axis, so the projection's input width is
        # n_heads * head_size. That equals embed_dims only when embed_dims divides evenly.
        self.proj = nn.Linear(n_heads * head_size, embed_dims, bias=False)

    def forward(self, x):
        """Concatenate the per-head outputs on the last axis and project back to ``embed_dims``."""
        out = torch.cat([head(x) for head in self.heads], dim=-1)  # (..., n_heads * head_size)
        return self.proj(out)

In [17]:
class Block(nn.Module):
    """Pre-LayerNorm transformer block: causal multi-head attention, then an MLP,
    each added back onto the input through a residual connection."""

    def __init__(self):
        super().__init__()
        self.attn = MultiHeadAttention()
        self.ln1 = nn.LayerNorm(embed_dims)
        self.ln2 = nn.LayerNorm(embed_dims)

        # Position-wise feed-forward network Linear -> ReLU -> Linear with a 4x hidden
        # width, as in "Attention Is All You Need". There is deliberately no activation
        # after the second Linear: a trailing ReLU would force the residual update to
        # be non-negative. Parameter names (mlp.0.*, mlp.2.*) are unchanged.
        self.mlp = nn.Sequential(
            nn.Linear(embed_dims, 4 * embed_dims),
            nn.ReLU(),
            nn.Linear(4 * embed_dims, embed_dims),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, T, embed_dims); the output has the same shape."""
        # LayerNorm is applied before each sub-layer (pre-LN). The original paper applies
        # it afterwards, but pre-LN trains more stably: https://arxiv.org/pdf/2002.04745.pdf
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

In [18]:
class TransformerLanguageModel(nn.Module):
    """Decoder-only character-level transformer: token + learned positional embeddings,
    ``n_layers`` pre-LN blocks, and a linear head over the vocabulary."""

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, embed_dims)
        # Learned positional embeddings, one per position up to block_size.
        self.pos_embedding_table = nn.Embedding(block_size, embed_dims)
        self.blocks = nn.Sequential(*[Block() for _ in range(n_layers)])
        self.lm_head = nn.Linear(embed_dims, vocab_size)

    def forward(self, context, targets=None):
        """Compute next-token logits and, when ``targets`` is given, the loss.

        Args:
            context: (B, T) tensor of token ids, with T <= block_size.
            targets: optional (B, T) tensor of next-token ids.

        Returns:
            ``(logits, loss)``. Without targets, logits are (B, T, vocab_size) and loss
            is None. With targets, logits are flattened to (B*T, vocab_size) and loss
            is the (mean) cross-entropy.
        """
        B, T = context.shape

        token_emb = self.token_embedding_table(context)  # (B, T, embed_dims)
        # Build positions on the input's device rather than the notebook-global `device`.
        pos_emb = self.pos_embedding_table(torch.arange(T, device=context.device))  # (T, embed_dims)
        x = token_emb + pos_emb  # broadcasts over the batch

        x = self.blocks(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.reshape(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, context, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``context``.

        Args:
            context: (B, T) tensor of token ids to continue.
            max_new_tokens: number of tokens to append.

        Returns:
            (B, T + max_new_tokens) tensor: the input followed by the sampled tokens.
        """
        # Sampling never backpropagates; disabling autograd (decorator) avoids storing
        # activations for backward on every forward pass.
        for _ in range(max_new_tokens):
            # Only the last block_size tokens fit the positional embedding table.
            logits = self(context[:, -block_size:])[0]  # (B, T', vocab_size)
            next_logits = logits[:, -1, :]  # (B, vocab_size): prediction for the next token
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
            context = torch.cat((context, next_token), dim=1)  # (B, T+1)
        return context

    def generate_to_text(self, context, max_new_tokens):
        """Generate from ``context`` and decode the first sequence of the batch to a string."""
        tokens = self.generate(context, max_new_tokens)
        return decode(tokens[0].tolist())

# Module-level instance shared by train() and the prompting cells.
model = TransformerLanguageModel()
model = model.to(device)

In [19]:
def train():
    """Train the global ``model`` on random ``train_data`` minibatches with Adam (lr 3e-4).

    Runs ``training_steps`` optimization steps. A checkpoint is written every 500 steps
    and at the final step to ``weights/step<step>_loss<loss>.pt``; it holds the model
    and optimizer state dicts, the step (under the ``'epoch'`` key) and the loss as a float.

    Returns:
        list[float]: the training loss at every step.
    """
    losses = []
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    os.makedirs("weights", exist_ok=True)  # once, up front; tolerates an existing dir

    for step in tqdm(range(training_steps)):
        xb, yb = get_batch(train_data)
        _, loss = model(xb, yb)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        losses.append(loss_value)

        if step % 500 == 0 or step == training_steps - 1:
            # The step in the filename keeps checkpoints with equal rounded loss from
            # overwriting each other. Storing a float (not the graph-attached tensor)
            # and the optimizer state lets training be resumed.
            torch.save(
                {
                    "epoch": step,  # key kept for compatibility; this is the step index
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": loss_value,
                },
                f"weights/step{step:06d}_loss{loss_value:.5f}.pt",
            )

    return losses

In [None]:
losses = train()

# Plot the per-step training loss when train() hands it back.
if losses:
    fig, ax = plt.subplots()
    ax.plot(losses)
    ax.set(title="Training loss", xlabel="step", ylabel="cross-entropy loss")
    plt.show()
    print(f"mean loss over the last 50 steps: {np.mean(losses[-50:]):.4f}")

  1%|▌                                                                                    | 103/15000 [00:43<1:41:46,  2.44it/s]

In [None]:
# inference
# zero_context = torch.zeros((batch_size,1), device=device)
# xb, yb = get_batch(train_data)
# model.generate_to_text(xb, max_new_tokens=200)

In [None]:
# model = TransformerLanguageModel().to(device)
#model.load_state_dict(torch.load("5M_1k_steps"))
#model.eval();

In [None]:
def prompt_model(model, prompt=None, max_new_tokens=200):
    """Sample a continuation of ``prompt`` from ``model`` and return it as text.

    Args:
        model: e.g. a TransformerLanguageModel; anything with ``generate_to_text``.
        prompt: seed text. ``None`` or ``""`` starts from the single token id 0
            (the first character of the sorted vocabulary).
        max_new_tokens: number of characters to sample.

    Returns:
        str: the decoded prompt followed by the sampled continuation.

    Raises:
        KeyError: if ``prompt`` contains a character outside the vocabulary.
    """
    if not prompt:
        prompt_tensor = torch.zeros((1, 1), dtype=torch.long, device=device)
    else:
        # Wrap in a list to add the batch dimension: shape (1, len(prompt)).
        prompt_tensor = torch.tensor([encode(prompt)], dtype=torch.long, device=device)
    # Inference only: no autograd bookkeeping is needed while sampling.
    with torch.no_grad():
        return model.generate_to_text(prompt_tensor, max_new_tokens)

In [None]:
# A prompt made of one phrase repeated several times; print its length, then sample.
prompt = "Lex, think theLex, do you think theLex, do you think theLex, do you think theLex, do you think theLex, do you think theLex, do you think theLex, do you think the"
prompt_length = len(prompt)
print(prompt_length)
prompt_model(model, prompt=prompt, max_new_tokens=100)

In [None]:
# prompt_tensor = torch.tensor(encode(prompt), dtype=torch.long, device=device)
# prompt_tensor = torch.unsqueeze(prompt_tensor, 0)
# output = model.generate_to_text(prompt_tensor, 20)
# output

In [None]:
# prompt_tensor = torch.zeros((1,1), dtype=torch.long, device=device)
# output = model.generate_to_text(prompt_tensor, 20)
# output