In [1]:
# Load the whole Mahabharata corpus into memory as a single string
with open('mahabharata.txt', 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [2]:
# Remove the BOM character if it exists.
# U+FEFF is the byte-order mark; str.replace drops every occurrence, so it
# cannot end up as an entry in the character vocabulary built from `text`.
text = text.replace('\ufeff', '')

In [3]:
# Corpus size in characters (not bytes): len() of a str counts code points
print("number of characters in dataset:", len(text))

number of characters in dataset: 3641165


In [4]:
# Peek at the first 1000 characters to sanity-check that the file loaded as expected
print(text[:1000])

THE MAHABHARATA

ADI PARVA

SECTION I

Om! Having bowed down to Narayana and Nara, the most exalted male being,
and also to the goddess Saraswati, must the word Jaya be uttered.

Ugrasrava, the son of Lomaharshana, surnamed Sauti, well-versed in the
Puranas, bending with humility, one day approached the great sages of
rigid vows, sitting at their ease, who had attended the twelve years’
sacrifice of Saunaka, surnamed Kulapati, in the forest of Naimisha. Those
ascetics, wishing to hear his wonderful narrations, presently began to
address him who had thus arrived at that recluse abode of the inhabitants
of the forest of Naimisha. Having been entertained with due respect by
those holy men, he saluted those Munis (sages) with joined palms, even
all of them, and inquired about the progress of their asceticism. Then
all the ascetics being again seated, the son of Lomaharshana humbly
occupied the seat that was assigned to him. Seeing that he was
comfortably seated, and recovered from fatigue,

In [5]:
# Character-level vocabulary: every distinct character in the corpus, in sorted order.
# sorted() accepts any iterable, so the set can be passed to it directly.
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !&(),-.0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyz‘’“”
80


In [18]:
# Character-level tokenizer: lookup tables plus encode/decode helpers.

# Forward (char -> id) and reverse (id -> char) tables built from `chars`.
# The loop variable is named `index` so the builtin `int` is not shadowed.
token_to_int = {token: index for index, token in enumerate(chars)}
int_to_token = {index: token for index, token in enumerate(chars)}


def encode(s):
    """Convert a string to a list of integer token ids, one id per character.

    Raises:
        KeyError: if `s` contains a character that is not a key of `token_to_int`.
    """
    return [token_to_int[c] for c in s]


def decode(l):
    """Convert a list of integer token ids back to a string.

    Raises:
        KeyError: if an id is not a key of `int_to_token`.
    """
    return ''.join(int_to_token[i] for i in l)


# Round-trip sanity checks
print(encode("what is up"))
print(decode(encode("what is up")))
print(decode([57, 58, 2]))

[72, 57, 50, 69, 1, 58, 68, 1, 70, 65]
what is up
hi!


In [19]:
# encode the entire text into a 1-D tensor of token ids

import torch

# one torch.long id per character of `text`
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:100])  # first 100 token ids


torch.Size([3641165]) torch.int64
tensor([41, 29, 26,  1, 34, 22, 29, 22, 23, 29, 22, 39, 22, 41, 22,  0,  0, 22,
        25, 30,  1, 37, 22, 39, 43, 22,  0,  0, 40, 26, 24, 41, 30, 36, 35,  1,
        30,  0,  0, 36, 62,  2,  1, 29, 50, 71, 58, 63, 56,  1, 51, 64, 72, 54,
        53,  1, 53, 64, 72, 63,  1, 69, 64,  1, 35, 50, 67, 50, 74, 50, 63, 50,
         1, 50, 63, 53,  1, 35, 50, 67, 50,  6,  1, 69, 57, 54,  1, 62, 64, 68,
        69,  1, 54, 73, 50, 61, 69, 54, 53,  1])


In [20]:
# train and validation split: first 90% of the tokens for training, the rest for validation.
# The split index is derived from `data` (the tensor actually being sliced) rather than
# from `text`, so it stays correct even if the tokenizer stops being one token per character.

n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [23]:
# Length of each training window, i.e. the number of input tokens per example
block_size = 8
# One extra token is shown because the targets are the inputs shifted by one position
train_data[:block_size+1]

tensor([41, 29, 26,  1, 34, 22, 29, 22, 23])

In [24]:
# A window of block_size + 1 tokens packs block_size training examples:
# the first t+1 tokens of x form the context and y[t] is the token to predict.
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t, target in enumerate(y):
    context = x[:t+1]
    print(f"when input is {context} the target is {target}")


when input is tensor([41]) the target is 29
when input is tensor([41, 29]) the target is 26
when input is tensor([41, 29, 26]) the target is 1
when input is tensor([41, 29, 26,  1]) the target is 34
when input is tensor([41, 29, 26,  1, 34]) the target is 22
when input is tensor([41, 29, 26,  1, 34, 22]) the target is 29
when input is tensor([41, 29, 26,  1, 34, 22, 29]) the target is 22
when input is tensor([41, 29, 26,  1, 34, 22, 29, 22]) the target is 23


In [28]:
# create batches

torch.manual_seed(42)  # fix the RNG so the sampled batch below is reproducible

block_size = 8 # number of tokens in each input sequence
batch_size = 4 # independent sequences to be processed in parallel

def get_batch(split):
    """Sample a random batch of input/target windows from one data split.

    Reads the module-level `train_data`, `val_data`, `block_size` and
    `batch_size` at call time.

    Args:
        split: 'train' selects `train_data`; any other value selects `val_data`.

    Returns:
        (x, y): two tensors of shape (batch_size, block_size), where y holds
        the same windows as x shifted one position to the right.
    """
    source = train_data if split == 'train' else val_data
    # Upper bound leaves room for the target window, which ends one token later
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start+block_size])
        targets.append(source[start+1:start+block_size+1])
    return torch.stack(inputs), torch.stack(targets)

# Draw one training batch and show the inputs next to their shifted targets
xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

# Spell out every (context, target) pair contained in the batch:
# row b contributes block_size examples with contexts of length 1..block_size
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"when input is {context.tolist()} the target is {target}")

inputs:
torch.Size([4, 8])
tensor([[63, 53,  1, 25, 50, 63, 50, 71],
        [54, 68,  6,  1, 68, 50, 58, 53],
        [65, 64, 54, 62,  8,  0,  0, 78],
        [50, 69,  1, 40, 50, 51, 57, 50]])
targets:
torch.Size([4, 8])
tensor([[53,  1, 25, 50, 63, 50, 71, 50],
        [68,  6,  1, 68, 50, 58, 53,  1],
        [64, 54, 62,  8,  0,  0, 78, 46],
        [69,  1, 40, 50, 51, 57, 50,  8]])
----
when input is [63] the target is 53
when input is [63, 53] the target is 1
when input is [63, 53, 1] the target is 25
when input is [63, 53, 1, 25] the target is 50
when input is [63, 53, 1, 25, 50] the target is 63
when input is [63, 53, 1, 25, 50, 63] the target is 50
when input is [63, 53, 1, 25, 50, 63, 50] the target is 71
when input is [63, 53, 1, 25, 50, 63, 50, 71] the target is 50
when input is [54] the target is 68
when input is [54, 68] the target is 6
when input is [54, 68, 6] the target is 1
when input is [54, 68, 6, 1] the target is 68
when input is [54, 68, 6, 1, 68] the target is

In [29]:
# The sampled input batch that is fed to the model below
print(xb)

tensor([[63, 53,  1, 25, 50, 63, 50, 71],
        [54, 68,  6,  1, 68, 50, 58, 53],
        [65, 64, 54, 62,  8,  0,  0, 78],
        [50, 69,  1, 40, 50, 51, 57, 50]])


In [34]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(42)

class BigramLanguageModel(nn.Module):
    """Bigram language model: next-token logits depend only on the current token.

    The whole model is one (vocab_size x vocab_size) embedding table; row ``i``
    holds the unnormalised scores for the token that follows token ``i``.
    """

    def __init__(self, vocab_size):
        """Create the lookup table.

        Args:
            vocab_size: number of distinct tokens; used both as the number of
                embeddings and as the embedding width.
        """
        super().__init__()
        self.token_emb_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: (B, T) integer tensor of token ids.
            targets: optional (B, T) integer tensor of next-token ids.

        Returns:
            (logits, loss). Without targets, logits has shape (B, T, C) with
            C = vocab_size and loss is None. With targets, logits is returned
            flattened to (B*T, C) and loss is the cross-entropy over all B*T
            positions.
        """
        logits = self.token_emb_table(idx)  # (B, T, C)

        if targets is None:
            loss = None
        else:
            # cross_entropy expects the class dimension second, so fold batch and time together
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    # generation loop copied from karpathy's github; no_grad added because sampling
    # never needs gradients, so there is no reason to record an autograd graph per step
    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: (B, T) integer tensor of token ids used as the starting context.
            max_new_tokens: number of tokens to append to every row.

        Returns:
            (B, T + max_new_tokens) integer tensor: the original context
            followed by the sampled tokens.
        """
        for _ in range(max_new_tokens):
            # get the predictions; the loss is None because no targets are passed
            logits, _loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :]  # (B, C)
            # turn the scores into a probability distribution
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample one token per row from that distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append the sampled token to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx
    

    
# Instantiate the model and run one forward pass on the sample batch
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)  # flattened to (B*T, vocab_size) because targets were supplied
print(loss)  # loss of the untrained model on this batch
# Sample 100 tokens from the untrained model, starting from a single token id 0
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))

torch.Size([32, 80])
tensor(5.0515, grad_fn=<NllLossBackward0>)

d(9[!”o[t,w62XYWQe76:62
9Xa,BavjX3’Dm
”FTDobUDXFRNy
X4ec&TDLY4,BzfRQz,W3L:4c
7kT1
r]
JMKF(cS&-SYYguB


Train the Bigram model:

In [35]:
# create an optimizer: AdamW over all parameters of `m`, learning rate 1e-3
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [42]:
# training loop

batch_size = 32 # let's up from 4 to 32; get_batch reads this global on every call
for steps in range(10000):
    # clear the gradients of the previous step so they do not accumulate
    optimizer.zero_grad(set_to_none=True)

    # forward pass on a freshly sampled training batch
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)

    # backward pass and parameter update
    loss.backward()
    optimizer.step()

print(loss.item())  # loss of the final batch only

2.404163360595703


In [49]:
# Sample 500 tokens from the model after training, starting from a single token id 0
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))


t. marathedysther
USulive tod ares) Pandr ts tiouind hndsorangthong fitrcrer-borithico iod botharrsa sothe, a oth. ans,
he w cof
anglin t Ser ran Bhr theg! Aun klo vasondhioAngg Asetu Prives
o o thest ld wng btin t whu a we?’
atitha wis, f ankem wfr, dinoa, h utithe. ofto fompimee, O o ing, br f aing whravesowh qund acelevey se whe
g
Visolall, ins
Brjonentheshegeeer f An mowhthinfrer hthe was, IO u s Sache Ratha dumy h benl
huceinend najowof the thir URauencor aind eraveranthen and thecof O d fo


### The mathematical trick in self-attention

Averaging the channels of all tokens up to and including the current token. This is kind of a bottleneck, but we don't have to worry about it right now. Let's implement this on toy data first:

In [64]:
torch.manual_seed(1337)  # reproducible toy data
B, T, C = 4, 8, 2 # batch, time, channel
x = torch.randn(B, T, C)  # random toy input of shape (B, T, C)
x.shape

torch.Size([4, 8, 2])

In [65]:
# Reference implementation of causal averaging: xbow[b, t] = mean of x[b, 0..t]
xbow = torch.zeros((B, T, C))  # "bow" = bag of words
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]  # (t+1, C): every time step up to and including t
        xbow[b, t] = torch.mean(xprev, dim=0)

The trick: efficiency using matrix multiplication.

In [67]:
torch.manual_seed(42)

# A row-normalised lower-triangular matrix turns a matmul into a running average:
# row i of `a` puts equal weight on rows 0..i of `b` and zero weight on later rows
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, dim=1, keepdim=True)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
# one print call with a newline separator emits the same text as separate prints
print('a=', a, '--', 'b=', b, '--', 'c=', c, sep='\n')

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [68]:
# version 2: the same causal average as the explicit loop, expressed as one matrix multiply
wei = torch.tril(torch.ones(T, T))  # lower-triangular ones: row t selects positions 0..t
wei /= wei.sum(dim=1, keepdim=True)  # normalise each row so the weights sum to 1
# wei is (T, T) and broadcasts over the batch: (T, T) @ (B, T, C) -> (B, T, C)
xbow2 = wei @ x
torch.allclose(xbow, xbow2)

True

In [74]:
# 3rd way: the same averaging weights, obtained from a masked softmax
tril = torch.tril(torch.ones(T, T))
# start from all-zero scores and set every position above the diagonal to -inf
wei = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
# softmax over each row: -inf entries become 0, the equal remaining scores become uniform weights
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x

torch.allclose(xbow, xbow3)

True

In [80]:
# v4: self-attention (one head, causal mask)

torch.manual_seed(69)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# single attention head: three bias-free projections from C down to head_size
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)    # (B, T, head_size)
q = query(x)  # (B, T, head_size)
v = value(x)  # (B, T, head_size)

# raw affinities between every query and every key:
# (B, T, head_size) @ (B, head_size, T) -> (B, T, T); no scaling is applied here
wei = q @ k.transpose(-2, -1)

# causal mask: position t may only attend to positions 0..t
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)  # each row becomes a distribution over the allowed positions

# weighted sum of the values: (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
out = wei @ v

out.shape

torch.Size([4, 8, 16])

In [81]:
# Attention weights of the first batch element, shape (T, T)
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3989, 0.6011, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2886, 0.1684, 0.5431, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5726, 0.2702, 0.0419, 0.1153, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2158, 0.3648, 0.1488, 0.0369, 0.2337, 0.0000, 0.0000, 0.0000],
        [0.1819, 0.0757, 0.3826, 0.1124, 0.0363, 0.2110, 0.0000, 0.0000],
        [0.1323, 0.1567, 0.0327, 0.2290, 0.2709, 0.0831, 0.0953, 0.0000],
        [0.2437, 0.0516, 0.1671, 0.0837, 0.0228, 0.3515, 0.0353, 0.0443]],
       grad_fn=<SelectBackward0>)