In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import random
import time
import math
import tiktoken
import inspect

In [10]:
vocab_size = 50304 #50257
batch_size = 524288
mini_batches = 8
time_stamps = 512
context_len = 1024
emb_neur = 768
epochs = 50
num_blocks = 12
num_heads = 12
dropout_neur = 0.2


max_lr = 6e-4
min_lr = max_lr * 0.1
lr_steps = 10
weight_decay = 0.1
beta1, beta2 = 0.9, 0.95


enc = tiktoken.get_encoding("gpt2")

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
torch.manual_seed(1337)

cuda


<torch._C.Generator at 0x7b439351a110>

In [11]:
class DataLoader():
    def __init__(self, B, T):
        self.B = B
        self.T = T
        
        with open('input.txt', 'r') as f:
            text = f.read()
        text = enc.encode(text)
        self.tokens = torch.tensor(text)
        
        self.current_step = 1

        print(f"loaded {len(text)} tokens")

    def next_batch(self):
        B, T = self.B, self.T
        
        self.current_step += 1
        tokens = self.tokens[(self.current_step-1)*B*T:self.current_step*B*T+1]
        x = (tokens[:-1]).view(B, T)
        y = (tokens[1:]).view(B, T)
        if (self.current_step+1)*B*T+1 > len(self.tokens):
            self.current_step = 1
        return x, y

In [12]:
class SelfAttention(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        self.qkv = nn.Linear(emb_neur, 3 * emb_neur)
        self.proj = nn.Linear(emb_neur, emb_neur)
        self.proj.COMES_TO_RESIDUAL = 1
        # self.dropout = nn.Dropout(dropout_neur)

    def forward(self, idx):
        assert emb_neur % num_heads == 0, "Embedding dimension must be divisible by number of heads"

        B, T, C = idx.shape
        qkv = self.qkv(idx)
        q, k, v = qkv.split(emb_neur, dim=2)
        q = q.view(B, T, num_heads, C//num_heads).transpose(1, 2) # B, nh, T, hs
        k = k.view(B, T, num_heads, C//num_heads).transpose(1, 2) # B, nh, T, hs
        v = v.view(B, T, num_heads, C//num_heads).transpose(1, 2) # B, nh, T, hs

        # attention = q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.shape[-1]))
        # attention = torch.tril(attention[:, :, :T, :T])
        
        # attention = attention.masked_fill(attention == 0, float("-inf"))
        # attention = F.softmax(attention, dim=-1)
        # out = attention @ v # B, nh, T, hs 
        

        attention = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        out = attention.transpose(2, 1).contiguous().view(B, T, C)
        out = self.proj(out)
        # out = self.dropout(out)

        return out
        


class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        # self.net = nn.Sequential(
        #     nn.Linear(emb_neur, 4 * emb_neur),
        #     nn.GELU(),
        #     nn.Linear(4 * emb_neur, emb_neur),
        #     nn.Dropout(dropout_neur),
        # )
        self.upl = nn.Linear(emb_neur, 4 * emb_neur)
        self.gelu = nn.GELU()
        self.dwnl = nn.Linear(4 * emb_neur, emb_neur)
        self.dwnl.COMES_TO_RESIDUAL = 1

    def forward(self, idx):
        idx = self.upl(idx)
        idx = self.gelu(idx)
        idx = self.dwnl(idx)
        return idx
        # return self.net(idx)


class Block(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        self.attentions = SelfAttention(num_heads)
        self.ffn = FeedForward()
        self.ln1 = nn.LayerNorm(emb_neur)
        self.ln2 = nn.LayerNorm(emb_neur)

    def forward(self, idx):
        idx = idx + self.attentions(self.ln1(idx))
        idx = idx + self.ffn(self.ln2(idx))
        return idx

        
class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.tokens_embedding = nn.Embedding(vocab_size, emb_neur)
        self.position_embedding = nn.Embedding(context_len, emb_neur)
        self.blocks = nn.Sequential( *[Block(num_heads) for _ in range(num_blocks)])
        self.ln = nn.LayerNorm(emb_neur)
        self.ll_head = nn.Linear(emb_neur, vocab_size)

        self.tokens_embedding.weight = self.ll_head.weight

        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        std = (1.0 / math.sqrt(emb_neur))
        if isinstance(module, nn.Linear):
            if hasattr(module, "COMES_TO_RESIDUAL"):
                std *= (1.0)/(math.sqrt(2*num_blocks))
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=(1.0 / math.sqrt(emb_neur)))

    # I have taken this function [configure_optimizers] from Karpathy's nanoGPT
    # https://github.com/karpathy/nanoGPT
    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    def forward(self, idx, targets=None):
        B, T = idx.shape
        
        embedded_tokens = self.tokens_embedding(idx) # B, T, emb_neur
        embedded_position = self.position_embedding(torch.arange(T, device=device)) # T, emb_neur
        
        idx = embedded_tokens + embedded_position # B, T, emb_neur
        idx = self.blocks(idx)
        idx = self.ln(idx)
        logits = self.ll_head(idx)
        
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_tokens):
        for _ in range(max_tokens):
            logits, _ = self.forward(idx)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_token), dim=1)
        return idx

In [13]:
torch.set_float32_matmul_precision('high')

m = GPT()
m = m.to(device)
m = torch.compile(m)
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

data_loader = DataLoader(mini_batches, time_stamps)
# optmizer = torch.optim.Adam(m.parameters(), lr=lr, betas=(0.9, 0.95))

# I have taken this function [configure_optimizers] from Karpathy's nanoGPT
optmizer = m.configure_optimizers(weight_decay, max_lr, (beta1, beta2), device)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optmizer, T_max=lr_steps, eta_min=min_lr)

124.526208 M parameters
loaded 338025 tokens
num decayed parameter tensors: 50, with 124,354,560 parameters
num non-decayed parameter tensors: 99, with 171,648 parameters
using fused AdamW: True


In [14]:
assert batch_size % (mini_batches * time_stamps) == 0, "batch_size is not devided by B and T"
mini_epochs = int(batch_size / (mini_batches * time_stamps))

def get_lr(epoch):
    if epoch < lr_steps:
        return (max_lr * (epoch+1)/lr_steps)
    if epoch > epochs:
        return min_lr
    loc = (epoch - lr_steps)/(epochs - lr_steps)
    coef = 0.5 * (1.0 + math.cos(math.pi * loc))
    return min_lr + coef * (max_lr - min_lr)


for epoch in range(50):
    t0 = time.time()
    accumulated_loss = 0.0
    optmizer.zero_grad()

    for mini_epoch in range(mini_epochs):
        x, y = data_loader.next_batch()
        x, y = x.to(device), y.to(device)
    
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits, loss = m(x, y)
        loss /= mini_epochs
        accumulated_loss += loss.detach()
        loss.backward()

    norm = torch.nn.utils.clip_grad_norm_(m.parameters(), 1.0)
    lr = get_lr(epoch)
    for param_group in optmizer.param_groups:
        param_group['lr'] = lr
    optmizer.step()
    # scheduler.step()
    
    torch.cuda.synchronize()
    t1 = time.time()
    dt = t1-t0
    
    print(f"epoch: {epoch}, loss: {accumulated_loss:.5f}, norm: {norm:.5f}, time: {dt*1000:.2f}ms, tok/s: {data_loader.B*data_loader.T*mini_epochs/dt:.2f}")
    
    

epoch: 0, loss: 11.41157, norm: 27.69188, time: 22817.46ms, tok/s: 22977.49
epoch: 1, loss: 10.25615, norm: 59.41756, time: 4527.83ms, tok/s: 115792.32
epoch: 2, loss: 10.87149, norm: 11.13863, time: 4521.63ms, tok/s: 115951.12
epoch: 3, loss: 9.66047, norm: 5.64716, time: 4525.66ms, tok/s: 115847.91
epoch: 4, loss: 8.69950, norm: 2.97529, time: 4524.99ms, tok/s: 115865.06
epoch: 5, loss: 10.57604, norm: 27.93127, time: 4542.54ms, tok/s: 115417.34
epoch: 6, loss: 8.80328, norm: 16.43164, time: 4543.31ms, tok/s: 115397.87
epoch: 7, loss: 8.25292, norm: 6.36091, time: 4597.09ms, tok/s: 114047.81
epoch: 8, loss: 8.00966, norm: 3.81319, time: 4531.47ms, tok/s: 115699.39
epoch: 9, loss: 7.64843, norm: 3.09671, time: 4532.92ms, tok/s: 115662.44
epoch: 10, loss: 7.26969, norm: 2.38564, time: 4545.01ms, tok/s: 115354.52
epoch: 11, loss: 7.01213, norm: 2.86827, time: 4558.70ms, tok/s: 115008.13
epoch: 12, loss: 6.83554, norm: 3.10425, time: 4562.07ms, tok/s: 114923.32
epoch: 13, loss: 6.65110, 

In [None]:
enc.decode(m.generate(torch.tensor(enc.encode("Hello")).to(device).view(1, -1), 50)[0].tolist())

In [13]:
# tokens/sec:22406.09
# tokens/sec:45590.02 torch.set_float32_matmul_precision('high')
# tokens/sec:47236.09  with torch.autocast(device_type=device, dtype=torch.bfloat16):
# tokens/sec:63155.71 torch.compile(m)
# tokens/sec:67969.10 flash
# Nice number

# epoch: 49, loss: 6.08617, norm: 0.28814, time: 4674.63ms, tok/s: 112156.04

tensor(6.7199, device='cuda:0', grad_fn=<NllLossBackward0>)