In [46]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
print("PyTorch version: ", torch.__version__)

PyTorch version:  1.10.2


In [130]:
torch.manual_seed(1337)
with open('../karpathy-ai/m-gpt/input.txt', encoding = 'utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)

# create mapping from charcter to integer
stoi = {ch: i for  i,ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

# train and test splits
data = torch.tensor(encode(text), dtype = torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
test_data = data[n:]

In [206]:
class FeedForward(nn.Module):
    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout))
        
    def forward(self, x):
        return self.net(x)

class MultiHeadedSelfAttention(nn.Module):
    def __init__(self, n_heads, n_embed, context_length):
        super().__init__()
        self.n_embed = n_embed
        self.head_size = n_embed // n_heads
        self.register_buffer('tril', torch.tril(torch.ones((context_length, context_length))).view(1,1, context_length, context_length))
        self.c_atten = nn.Linear(n_embed, 3 * n_embed, bias = False)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        #print("X.shape at mhsa: ", x.shape)
        B, T, C = x.shape
        q, k, v = self.c_atten(x).split(self.n_embed, dim = 2)
        q = q.view(B, T, n_heads, self.head_size).transpose(1, 2) # B, hd, T, hs
        k = k.view(B, T, n_heads, self.head_size).transpose(1, 2) # B, hd, T, hs
        v = v.view(B, T, n_heads, self.head_size).transpose(1, 2) # B, hd, T, hs
        wei = k @ q.transpose(-2, -1) # (B, hd, T, hs) @ (B, hd, hs, T) = (B, hd, T, T)
        wei = F.softmax(wei.masked_fill(self.tril[:, :, :T, :T] == 0, float('-inf')) / math.sqrt(k.size(-1)), dim = -1)
        wei = self.dropout(wei)
        out = wei @ v # (B, hd, T, T) @ B, hd, T, hs) = (B, hd, T, hs)
        return out.transpose(1,2).contiguous().view(B, T, C)

    
class Block(nn.Module):
    def __init__(self, n_heads, n_embed, context_length):
        super().__init__()
        self.mhsa = MultiHeadedSelfAttention(n_heads, n_embed, context_length)
        self.llnorm_1 = nn.LayerNorm(n_embed)
        self.llnorm_2 = nn.LayerNorm(n_embed)
        self.ffwd = FeedForward(n_embed)
        
    def forward(self, x):
        x = x + self.mhsa(self.llnorm_1(x))
        ffwd = self.ffwd(self.llnorm_2(x))
        x = x + ffwd
        return x
    
class GPT(nn.Module):
    def __init__(self, n_heads, vocab_size , n_embed, context_length, n_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, n_embed)
        self.position_embedding = nn.Embedding(context_length, n_embed)
        self.transformer = nn.ModuleList([Block(n_heads, n_embed, context_length) for _ in range(n_layers)])
        self.llm_head = nn.Linear(n_embed, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T = x.shape
        val_embd = self.embedding(x)
        pos_embd = self.position_embedding(torch.arange(0, T))
        x = val_embd + pos_embd
        for block in self.transformer:
            x = block(x)
        output = self.llm_head(x)
        return output

In [207]:
n_heads = 4
n_embed = 32
context_length = 50
vocab_size = 65
batch_size = 64
n_layers = 3
token_dim = 1
eval_interval = 100
eval_iteration = 100
iterations = 5000
lr = 1e-3
n_layers = 2
dropout = 0.1

x = torch.randint(vocab_size, (batch_size, context_length, ))

gpt = GPT(n_heads = n_heads, 
                          vocab_size = vocab_size, 
                          n_embed = n_embed, 
                          context_length = context_length,
                         n_layers = n_layers)

optimizer = torch.optim.AdamW(gpt.parameters(), lr = lr)

In [208]:
print("Model paramters: ", sum(p.nelement() for p in gpt.parameters()) / 1e6, "M")

Model paramters:  0.028929 M


In [209]:
@torch.no_grad()
def evaluate():
    gpt.eval()
    out = {}
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iteration)
        for k in range(eval_iteration):
            xb, yb = get_batch(split)
            logits = gpt(xb)
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B*T, C), yb.view(B*T))
            losses[k] = loss.item()
        out[split] = losses.mean()
    gpt.train()
    return out

def forward_pass_with_grad(xb, yb):
    gpt.train()
    logits = gpt(xb)
    B, T, C = logits.shape
    loss = F.cross_entropy(logits.view(B*T, C), yb.view(B*T))
    return loss, logits

@torch.no_grad()
def forward_pass_without_grad(xb):
    gpt.eval()
    logits = gpt(xb)
    return logits

def get_batch(split = 'train'):
    data = train_data if split == 'train' else test_data
    ix = torch.randint(len(data) - context_length, (batch_size,))
    x = torch.stack([data[i: i+ context_length] for i in ix])
    y = torch.stack([data[i+1 : i + context_length + 1] for i in ix])
    return x, y


for i in range(iterations):
    if i % eval_interval == 0:
        loss = evaluate()
        print(f"Loss at {i+1} iterations - train loss: {loss['train']:.4f}, val loss: {loss['val']:.4f}")
    else:
        xb, yb = get_batch(split = 'train')
        loss, logits = forward_pass_with_grad(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

Loss at 1 iterations - train loss: 4.5124, val loss: 4.5246
Loss at 101 iterations - train loss: 2.8582, val loss: 2.8878
Loss at 201 iterations - train loss: 2.6429, val loss: 2.6631
Loss at 301 iterations - train loss: 2.5628, val loss: 2.5680
Loss at 401 iterations - train loss: 2.5124, val loss: 2.5184
Loss at 501 iterations - train loss: 2.4762, val loss: 2.4818
Loss at 601 iterations - train loss: 2.4387, val loss: 2.4478
Loss at 701 iterations - train loss: 2.4089, val loss: 2.4188
Loss at 801 iterations - train loss: 2.3821, val loss: 2.3956
Loss at 901 iterations - train loss: 2.3597, val loss: 2.3667
Loss at 1001 iterations - train loss: 2.3352, val loss: 2.3475
Loss at 1101 iterations - train loss: 2.3078, val loss: 2.3223
Loss at 1201 iterations - train loss: 2.2836, val loss: 2.3024
Loss at 1301 iterations - train loss: 2.2598, val loss: 2.2834
Loss at 1401 iterations - train loss: 2.2421, val loss: 2.2665
Loss at 1501 iterations - train loss: 2.2242, val loss: 2.2453
Loss

In [210]:
def generate(tokens, max_new_tokens):
    for _ in range(max_new_tokens):
        tokens_cond = tokens[:, -context_length:]
        logits = forward_pass_without_grad(tokens_cond)
        logits = logits[:, -1, :]
        probs = F.softmax(logits, dim = 1)
        next_token = torch.multinomial(probs, num_samples = 1)
        tokens = torch.cat((tokens, next_token), dim = 1)
    return decode(tokens[0].tolist())

In [211]:
context = torch.zeros((1,1), dtype = torch.long)
print(generate(context, 5000))


st yeciogire head gue of muttrie.

BUKENENCESCERY:
I but, I both ping of theos: you have,
Whout wure
Seat may brof.

BARDUCHENGA:
Forrrous of thise proting'd past of colvet will rechert
Is it hang hattmy that shas pose;
That if me but betaresh'd my shen.

KINCENN:
Whet it I mid lerives ugot, is bek'd,
Abunw you like pagaare frooms; faar my it ga;
Aft Ifal is molid, withr my lome
feack land that Pedien, her moes; pheant,
How the compis, this, of thy sust non.

Pon QUEEjon,'se bit co yound Cheall, nown
Fir bothere be loss, tray knaXIIZENT:
And the Jepuly the wortanst, it.

SICHARD BEOLLO:
I he hovaratin!
In:
Lat what theise parpt
And iblie, with or my lisce: dentie;
Hid they shastaling! Vesslaing the have awitith ibeabod.

FRCHICH I RICIIO

NIIO:
The gault:
Gam whan fing will you this mave my,
A mainy nor:
Bet
Gef and have and at that be altle sen. mike;
And is, whe &hf it? Muthwen lim, I Enot!

Nord, shall theine peprowill love, witht lle
And ting I steen!
But noth? Hat staul gitteing,