In [46]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
print("PyTorch version: ", torch.__version__)

PyTorch version:  1.10.2


In [130]:
torch.manual_seed(1337)
with open('../karpathy-ai/m-gpt/input.txt', encoding = 'utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)

# create mapping from charcter to integer
stoi = {ch: i for  i,ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

# train and test splits
data = torch.tensor(encode(text), dtype = torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
test_data = data[n:]

In [131]:
class FeedForward(nn.Module):
    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout))
        
    def forward(self, x):
        return self.net(x)

class MultiHeadedSelfAttention(nn.Module):
    def __init__(self, n_heads, n_embed, context_length):
        super().__init__()
        self.n_embed = n_embed
        self.head_size = n_embed // n_heads
        self.register_buffer('tril', torch.tril(torch.ones((context_length, context_length))).view(1,1, context_length, context_length))
        self.c_atten = nn.Linear(n_embed, 3 * n_embed, bias = False)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        #print("X.shape at mhsa: ", x.shape)
        B, T, C = x.shape
        q, k, v = self.c_atten(x).split(self.n_embed, dim = 2)
        q = q.view(B, T, n_heads, self.head_size).transpose(1, 2) # B, hd, T, hs
        k = k.view(B, T, n_heads, self.head_size).transpose(1, 2) # B, hd, T, hs
        v = v.view(B, T, n_heads, self.head_size).transpose(1, 2) # B, hd, T, hs
        wei = k @ q.transpose(-2, -1) # (B, hd, T, hs) @ (B, hd, hs, T) = (B, hd, T, T)
        wei = F.softmax(wei.masked_fill(self.tril[:, :, :T, :T] == 0, float('-inf')) / math.sqrt(k.size(-1)), dim = -1)
        wei = self.dropout(wei)
        out = wei @ v # (B, hd, T, T) @ B, hd, T, hs) = (B, hd, T, hs)
        return out.transpose(1,2).contiguous().view(B, T, C)
    
class Transformer(nn.Module):
    def __init__(self, n_heads, vocab_size , n_embed, context_length):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, n_embed)
        self.position_embedding = nn.Embedding(context_length, n_embed)
        self.mhsa = MultiHeadedSelfAttention(n_heads, n_embed, context_length)
        self.llnorm_1 = nn.LayerNorm(n_embed)
        self.llnorm_2 = nn.LayerNorm(n_embed)
        self.ffwd = FeedForward(n_embed)
        self.llm_head = nn.Linear(n_embed, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T = x.shape
        val_embd = self.embedding(x)
        pos_embd = self.position_embedding(torch.arange(0, T))
        x = val_embd + pos_embd
        x = x + self.mhsa(self.llnorm_1(x)) # add layer norm
        ffwd = self.ffwd(self.llnorm_2(x))
        x = x + ffwd
        output = self.llm_head(x)
        return output

In [178]:
n_heads = 8
n_embed = 128
context_length = 50
vocab_size = 65
batch_size = 128
token_dim = 1
eval_interval = 100
eval_iteration = 100
iterations = 1000
lr = 1e-3
dropout = 0.1

x = torch.randint(vocab_size, (batch_size, context_length, ))

transformer = Transformer(n_heads = n_heads, 
                          vocab_size = vocab_size, 
                          n_embed = n_embed, 
                          context_length = context_length)

optimizer = torch.optim.AdamW(transformer.parameters(), lr = lr)

In [179]:
print("Model paramters: ", sum(p.nelement() for p in transformer.parameters()) / 1e6, "M")

Model paramters:  0.204481 M


In [180]:
@torch.no_grad()
def evaluate():
    transformer.eval()
    out = {}
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iteration)
        for k in range(eval_iteration):
            xb, yb = get_batch(split)
            logits = transformer(xb)
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B*T, C), yb.view(B*T))
            losses[k] = loss.item()
        out[split] = losses.mean()
    transformer.train()
    return out

def forward_pass_with_grad(xb, yb):
    transformer.train()
    logits = transformer(xb)
    B, T, C = logits.shape
    loss = F.cross_entropy(logits.view(B*T, C), yb.view(B*T))
    return loss, logits

@torch.no_grad()
def forward_pass_without_grad(xb):
    transformer.eval()
    logits = transformer(xb)
    return logits

def get_batch(split = 'train'):
    data = train_data if split == 'train' else test_data
    ix = torch.randint(len(data) - context_length, (batch_size,))
    x = torch.stack([data[i: i+ context_length] for i in ix])
    y = torch.stack([data[i+1 : i + context_length + 1] for i in ix])
    return x, y


for i in range(iterations):
    if i % eval_interval == 0:
        loss = evaluate()
        print(f"Loss at {i+1} iterations - train loss: {loss['train']:.4f}, val loss: {loss['val']:.4f}")
    else:
        xb, yb = get_batch(split = 'train')
        loss, logits = forward_pass_with_grad(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

Loss at 1 iterations - train loss: 4.5549, val loss: 4.5455
Loss at 101 iterations - train loss: 2.4823, val loss: 2.4994
Loss at 201 iterations - train loss: 2.2720, val loss: 2.2937
Loss at 301 iterations - train loss: 2.1295, val loss: 2.1617
Loss at 401 iterations - train loss: 2.0293, val loss: 2.0889
Loss at 501 iterations - train loss: 1.9537, val loss: 2.0268
Loss at 601 iterations - train loss: 1.8919, val loss: 1.9862
Loss at 701 iterations - train loss: 1.8441, val loss: 1.9626
Loss at 801 iterations - train loss: 1.8064, val loss: 1.9326
Loss at 901 iterations - train loss: 1.7724, val loss: 1.9121


In [181]:
def generate(tokens, max_new_tokens):
    for _ in range(max_new_tokens):
        tokens_cond = tokens[:, -context_length:]
        logits = forward_pass_without_grad(tokens_cond)
        logits = logits[:, -1, :]
        probs = F.softmax(logits, dim = 1)
        next_token = torch.multinomial(probs, num_samples = 1)
        tokens = torch.cat((tokens, next_token), dim = 1)
    return decode(tokens[0].tolist())

In [183]:
context = torch.zeros((1,1), dtype = torch.long)
print(generate(context, 5000))


ESTER:
True shalk this sith jeart of Geam,
Shold thavery, his merrows morned yow
A onf thummen itle Mance:
Richartuns for liver with, you hows; and makes 'way in halt
whith some supple ento so, sir; an which'd is wadgingranty.

DUKE VINCENTIO:
A flets yout you senot meet, eame, ap willshalling: by thed mine teken,
Tilove royf fear to he cannat prie.
I mile his Xeaing itfe:
Now your wellands,
No nemore:
And there and you him will
I am any you shood he kile such'd my from sich,
Or tell, away! an the binartiam:
Well my nonswers and think let is that fair;
Loothers done creeppring: to with to dispyriteed,
Hase the let consentlet hine rong old
LeA thesedring its we seail, so was hearnently am such not you the
My my speak as of dauge inscansy no,
Tows, to of longue to he I am king,
Had peam:
And was a pursens' suckeeperselved
Thou posespleet dest. Auch will reas her here
Wall wore to my pence deephelp is, as not grain decempred wind couldes: on that bury
Tom mother out with tewdy dear teafo