In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

In [2]:
text = open('tiny-shakespeare.txt', 'r', encoding='utf-8').read()
print('dataset len: ', len(text))
print(text[:1000])

dataset len:  1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in

In [3]:
chars = sorted(list(set(c for c in text)))
print(''.join(chars))
stoi = {s: i for i, s in enumerate(chars)}
itos = {i: s for s, i in stoi.items()}
vocab_size = len(chars)
print('vocab_size: ', vocab_size)
print(len(stoi), len(itos))
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
print(decode(encode('hello world!!!')))


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab_size:  65
65 65
hello world!!!


In [4]:
data = torch.tensor(encode(text), dtype=torch.long)
train_data = data[:int(0.9*len(data))]
val_data = data[int(0.9*len(data)):]
print('train size: ', len(train_data))
print('val size: ', len(val_data))

train size:  1003854
val size:  111540


In [5]:
torch.manual_seed(1337)
batch_size = 4
block_size = 8

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(0, len(data)-block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

x, y = get_batch('train')
for b in range(batch_size):
    for t in range(block_size):
        context = x[b, :t+1]
        target = y[b, t]
        print('input: ', context.tolist(), ', output: ', target.item())

input:  [24] , output:  43
input:  [24, 43] , output:  58
input:  [24, 43, 58] , output:  5
input:  [24, 43, 58, 5] , output:  57
input:  [24, 43, 58, 5, 57] , output:  1
input:  [24, 43, 58, 5, 57, 1] , output:  46
input:  [24, 43, 58, 5, 57, 1, 46] , output:  43
input:  [24, 43, 58, 5, 57, 1, 46, 43] , output:  39
input:  [44] , output:  53
input:  [44, 53] , output:  56
input:  [44, 53, 56] , output:  1
input:  [44, 53, 56, 1] , output:  58
input:  [44, 53, 56, 1, 58] , output:  46
input:  [44, 53, 56, 1, 58, 46] , output:  39
input:  [44, 53, 56, 1, 58, 46, 39] , output:  58
input:  [44, 53, 56, 1, 58, 46, 39, 58] , output:  1
input:  [52] , output:  58
input:  [52, 58] , output:  1
input:  [52, 58, 1] , output:  58
input:  [52, 58, 1, 58] , output:  46
input:  [52, 58, 1, 58, 46] , output:  39
input:  [52, 58, 1, 58, 46, 39] , output:  58
input:  [52, 58, 1, 58, 46, 39, 58] , output:  1
input:  [52, 58, 1, 58, 46, 39, 58, 1] , output:  46
input:  [25] , output:  17
input:  [25, 17

### n/w layers

In [29]:
class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size) -> None:
        super().__init__()
        self.token_embed_table = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, idx):
        # idx - B x T
        logits = self.token_embed_table(idx)
        return logits
    
    def loss(self, logits, targets):
        B, T, C = logits.shape
        loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T), reduction='mean')
        return loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits = self(idx)
            logits = logits[:,-1,:]
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, 1, replacement=True)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [13]:
from typing import Any


class SelfAttention(nn.Module):
    def __init__(self, C, headsize) -> None:
        super().__init__()
        self.q = nn.randn((C, headsize))
        self.k = nn.randn((C, headsize))
        self.v = nn.randn((C, headsize))
    
    def __call__(self, x) -> torch.tensor: # x: B x T x C
        qx = x @ self.q # B x T x C
        kx = x @ self.k # B x T x C
        vx = x @ self.v # B x T x C
        qkx = qx @ kx.transpose(-1, -2) # B x T x T
        wt = torch.tril(x.shape[1]) # T x T
        wt[wt==0] = -torch.float('inf')
        wt = nn.Softmax(wt)
        y = (qkx * wt) @ vx # B x T x C
        return y

In [30]:
x, y = get_batch('train')
print(x.shape, y.shape)
m = BigramLanguageModel(vocab_size=vocab_size)
logits = m(x)
loss = m.loss(logits, y)
print('loss: ', loss)
print('expected loss: ', -torch.log(torch.tensor(1./vocab_size)))

torch.Size([4, 8]) torch.Size([4, 8])
loss:  tensor(4.9153, grad_fn=<NllLossBackward0>)
expected loss:  tensor(4.1744)


In [34]:
out = m.generate(torch.randint(0, vocab_size, (1,1)), 100)[0]
print(decode(out.tolist()))
print(out.shape)
print(out)

'aHKuUWBCTCcxt,mpII3eMbbkcdPJAiu:giVWN jlJYYHZWE,P3djltHLtHR whjJr,gj,GP&TcDcIyaBQVysp.3cEN$$L
D?RtHq
torch.Size([101])
tensor([ 5, 39, 20, 23, 59, 33, 35, 14, 15, 32, 15, 41, 62, 58,  6, 51, 54, 21,
        21,  9, 43, 25, 40, 40, 49, 41, 42, 28, 22, 13, 47, 59, 10, 45, 47, 34,
        35, 26,  1, 48, 50, 22, 37, 37, 20, 38, 35, 17,  6, 28,  9, 42, 48, 50,
        58, 20, 24, 58, 20, 30,  1, 61, 46, 48, 22, 56,  6, 45, 48,  6, 19, 28,
         4, 32, 41, 16, 41, 21, 63, 39, 14, 29, 34, 63, 57, 54,  8,  9, 41, 17,
        26,  3,  3, 24,  0, 16, 12, 30, 58, 20, 55])


In [36]:
optim = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [45]:
batch_size = 32
for i in range(10000):
    x, y = get_batch('train')
    logits = m(x)
    loss = m.loss(logits, y)
    optim.zero_grad(set_to_none=True)
    loss.backward()
    optim.step()
print(loss.item())

2.4659721851348877


In [47]:
print(decode(m.generate(torch.zeros((1,1), dtype=torch.long), 500)[0].tolist()))


NELLERWino t o r. men,
Frnt tyoraksheishouging.
Misir sugrghertit nke the wo I maral, anecyorest blorke. te pern d,
Thiven y s,
Th tenofire danged wis w,
Pamyork lle'st s agavel e cer.
Hitongno me whishen me IFomse y istety R wis; thengimid gNICOFRD:

GAs tlirveecro pous momoulenan whisirve besthos w e nthe h.
KI agerisupof,
Bed ithe'd n te acet s! de jef bo an hace y, IN iserso u-we be hin ty sts, cu murmupash angshot akes, macoo ouly is t mucef hed: ivougonghanterd inch.
Y ckle tr' t he fodese
