In [58]:
import torch 

In [59]:
# Slurp the entire corpus into one in-memory string.
with open("./data/input.txt", mode="r", encoding='utf-8') as corpus_file:
  text = corpus_file.read()

# Trailing expression: the notebook shows the corpus length in characters.
len(text)

1115394

In [60]:
# Preview the first 100 characters of the corpus.
print(text[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [61]:
# Character-level vocabulary: every distinct symbol in the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

In [62]:
# Lookup tables between characters and their integer token ids.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(x):
  """Map a string to a list of integer token ids, one per character.

  Raises KeyError if a character is not in the vocabulary.
  """
  return [stoi[c] for c in x]


def decode(x):
  """Map an iterable of integer token ids back to the string they spell.

  Raises KeyError if an id is not in the vocabulary.
  """
  return ''.join(itos[i] for i in x)


# Round-trip check: decoding the encoding must reproduce the input.
print(encode("Hello There!"))
print(decode(encode("Hello There!")))

[20, 43, 50, 50, 53, 1, 32, 46, 43, 56, 43, 2]
Hello There!


In [63]:
# Encode the whole corpus once and keep it as a 1-D tensor of int64 token ids.
data = torch.tensor(encode(text), dtype=torch.int64)
print(data.shape, data.dtype)

torch.Size([1115394]) torch.int64


In [64]:
# Train on the first 90% of the tokens; hold out the final 10% for validation.
t = int(0.9 * len(data))
train_data, val_data = data[:t], data[t:]

print(train_data.shape, val_data.shape)

torch.Size([1003854]) torch.Size([111540])


In [65]:
# One window of block_size + 1 tokens yields block_size (context, target) pairs:
# the context grows by one token at a time and the target is the token after it.
block_size = 8
X = train_data[:block_size]
y = train_data[1:block_size + 1]  # same window shifted one position to the left
for pos, target in enumerate(y.tolist()):
  context = X[:pos + 1].tolist()
  print(f"When input is {decode(context)}, the target is {decode([target])}")

When input is F, the target is i
When input is Fi, the target is r
When input is Fir, the target is s
When input is Firs, the target is t
When input is First, the target is  
When input is First , the target is C
When input is First C, the target is i
When input is First Ci, the target is t


In [66]:
torch.manual_seed(1337)  # fix the RNG so the batch sampled below is reproducible
batch_size = 4  # number of sequences get_batch stacks into one batch
block_size = 8  # length, in tokens, of each sequence get_batch slices out

def get_batch(split):
  """Sample a random batch of (input, target) windows from one data split.

  Args:
    split: 'train' selects train_data; any other value selects val_data.

  Returns:
    (x, y): two stacked tensors with batch_size rows of block_size tokens
    each, where y is x shifted one token to the right in the source data.
    batch_size and block_size are read from module scope at call time.
  """
  source = train_data if split == 'train' else val_data
  # randint's upper bound is exclusive, so start + block_size + 1 never
  # runs past the end of the split.
  starts = torch.randint(len(source) - block_size, size=(batch_size,))
  inputs = torch.stack([source[s:s + block_size] for s in starts])
  targets = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
  return inputs, targets

# Draw one training batch and spell out every (context, target) pair in it.
xb, yb = get_batch("train")

for inputs_row, targets_row in zip(xb.tolist(), yb.tolist()):
  for pos, target in enumerate(targets_row):
    context = inputs_row[:pos + 1]
    print(f"when input is {decode(context)}, target is {decode([target])}")
  print('-' * 40)


when input is L, target is e
when input is Le, target is t
when input is Let, target is '
when input is Let', target is s
when input is Let's, target is  
when input is Let's , target is h
when input is Let's h, target is e
when input is Let's he, target is a
----------------------------------------
when input is f, target is o
when input is fo, target is r
when input is for, target is  
when input is for , target is t
when input is for t, target is h
when input is for th, target is a
when input is for tha, target is t
when input is for that, target is  
----------------------------------------
when input is n, target is t
when input is nt, target is  
when input is nt , target is t
when input is nt t, target is h
when input is nt th, target is a
when input is nt tha, target is t
when input is nt that, target is  
when input is nt that , target is h
----------------------------------------
when input is M, target is E
when input is ME, target is O
when input is MEO, target is :
when in

In [67]:
from torch import nn, optim
from torch.nn import functional as F 
from einops import rearrange
torch.manual_seed(1337)  # re-seed so the model created in this cell starts from reproducible weights

class BigramLanguageModel(nn.Module):
  """Bigram character model: next-token logits depend only on the current token.

  The only parameter is a (vocab_size, vocab_size) embedding table; the row
  looked up for a token is used directly as the logits over the next token.
  """

  def __init__(self, vocab_size) -> None:
    super().__init__()

    # Row i is read out as the next-token logits for token id i.
    self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

  def forward(self, idx, target=None):
    """Compute next-token logits and, when targets are given, the loss.

    Args:
      idx: integer token ids laid out as (batch, time).
      target: optional integer token ids with the same layout as idx.

    Returns:
      (logits, loss). Without target, logits keeps its (batch, time, vocab)
      layout and loss is None. With target, logits is returned flattened to
      (batch * time, vocab) -- the layout passed to F.cross_entropy -- and
      loss is the resulting cross-entropy.
    """
    logits = self.token_embedding_table(idx)

    if target is None:
      return logits, None

    # F.cross_entropy wants one row of class scores per prediction, so fold
    # the batch and time axes together on both tensors.
    logits = rearrange(logits, 'b t c -> (b t) c')
    target = rearrange(target, 'b t -> (b t)')
    loss = F.cross_entropy(logits, target)
    return logits, loss

  @torch.no_grad()  # sampling never backpropagates; skip building autograd graphs
  def generate(self, idx, max_new_tokens):
    """Autoregressively sample max_new_tokens tokens after the given context.

    Args:
      idx: integer token ids laid out as (batch, time), used as the prompt.
      max_new_tokens: number of tokens to append to every row.

    Returns:
      Tensor of token ids laid out as (batch, time + max_new_tokens): the
      prompt followed by the sampled continuation.
    """
    for _ in range(max_new_tokens):
      logits, _ = self(idx)
      # Only the prediction at the last time step decides the next token.
      last_step_logits = logits[:, -1, :]
      probs = F.softmax(last_step_logits, dim=-1)
      idx_next = torch.multinomial(probs, num_samples=1)
      idx = torch.cat([idx, idx_next], dim=1)
    return idx
      

# Build the model and evaluate the loss of its untrained weights on the
# current (xb, yb) batch; the trailing expression displays that loss.
model = BigramLanguageModel(vocab_size)
logits, loss = model(idx=xb, target=yb)
loss.item()



4.878634929656982

In [68]:
# AdamW over all of the model's parameters, learning rate 1e-3.
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

In [69]:
# get_batch reads batch_size from module scope, so this reassignment enlarges
# every batch sampled from here on.
batch_size = 32
# NOTE: despite the name, each iteration is one optimizer step on one random
# mini-batch, not a full pass over the training data.
n_epochs = 10000
for step in range(n_epochs):
  # Clear gradients left over from the previous step before computing new ones.
  optimizer.zero_grad(set_to_none=True)

  xb, yb = get_batch("train")
  logits, loss = model(idx=xb, target=yb)

  loss.backward()
  optimizer.step()

# Loss of the final mini-batch only.
print(loss.item())




2.394822597503662


In [71]:
# Prompt the model with a single token of id 0 and sample 300 more tokens.
idx = torch.zeros((1, 1), dtype=torch.long)
# decode already returns one joined string, so no further ''.join is needed.
ypred = decode(model.generate(idx, max_new_tokens=300)[0].tolist())
print(ypred)


Iyoteng h hasbe pave pirance
Rie hicomyonthar's
Plinseard ith henoure wounonthioneir thondy, y heltieiengerofo'dsssit ey
KIN d pe wither vouprrouthercc.
hathe; d!
My hind tt hinig t ouchos tes; st yo hind wotte grotonear 'so it t jod weancotha:
h hay.JUCle n prids, r loncave w hollular s O:
HIs; ht 


## **Self Attention**

In [81]:
# Toy input for the averaging experiments below.
torch.manual_seed(1337)
B = 4
T = 8
C = 2
x = torch.randn((B, T, C))
x.size()

torch.Size([4, 8, 2])

In [82]:
# Version 1 (explicit loops): xbow[b, t] is the mean of x[b, 0], ..., x[b, t].
xbow = torch.zeros((B, T, C))
for batch_idx, sequence in enumerate(x):
  for step in range(T):
    xbow[batch_idx, step] = sequence[:step + 1].mean(dim=0)

# print(x[0])
# print(xbow[0])

In [89]:
# Version 2 (matrix multiply): row t of a row-normalised lower-triangular
# matrix puts equal weight on positions 0..t and zero on everything after.
lower_tri = torch.tril(torch.ones((T, T)))
wei = lower_tri / lower_tri.sum(dim=1, keepdim=True)

# wei is (T, T) and is broadcast over the batch: (T, T) @ (B, T, C) -> (B, T, C)
xbow2 = torch.matmul(wei, x)

torch.allclose(xbow, xbow2, rtol=1e-4)
# print((xbow - xbow2).abs().max().item)

True

In [94]:
# NOTE: Version 3 expresses the same average as a masked softmax. The scores in
# `wei` start out all equal (zeros), so the softmax gives uniform weights; attention
# replaces those constant scores with learned, data-dependent ones.

tril = torch.tril(torch.ones((T, T)))
wei = torch.zeros((T, T))
# -inf scores become exactly zero weight after the softmax, so position t
# cannot draw on positions after t.
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

xbow3 = wei @ x
torch.allclose(xbow3, xbow2, rtol=1e-4)

True

: 

In [10]:
import torch
class LayerNorm1d: # (used to be BatchNorm1d)
  """Normalise every vector along the last axis to zero mean and unit variance.

  Each vector is standardised on its own, independently of the other entries
  in the batch, then scaled by gamma and shifted by beta elementwise.
  The variance comes from Tensor.var with its default settings, so on 2-D
  input the numbers match the previous axis-1 implementation exactly.
  """

  def __init__(self, dim, eps=1e-5, momentum=0.1):
    # momentum is never read; it stays in the signature so existing callers
    # that pass it keep working.
    self.eps = eps  # added to the variance before the square root
    self.gamma = torch.ones(dim)   # elementwise scale
    self.beta = torch.zeros(dim)   # elementwise shift

  def __call__(self, x):
    """Return the normalised, scaled and shifted copy of x (same shape).

    x must have a last axis that broadcasts against gamma and beta.
    The result is also stored on self.out.
    """
    # Reduce over the last axis -- the same axis gamma and beta broadcast
    # along -- so inputs with extra leading axes are normalised per vector.
    # (A hard-coded axis 1 only coincides with this for 2-D input.)
    xmean = x.mean(-1, keepdim=True) # per-vector mean over the features
    xvar = x.var(-1, keepdim=True) # per-vector variance over the features
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance
    self.out = self.gamma * xhat + self.beta
    return self.out

  def parameters(self):
    """Return the scale and shift tensors as [gamma, beta]."""
    return [self.gamma, self.beta]

# Smoke test: push a batch of 32 random 100-dimensional vectors through the layer.
torch.manual_seed(1337)
module = LayerNorm1d(dim=100)
x = module(torch.randn((32, 100)))
x.shape

torch.Size([32, 100])

In [11]:
# Mean and std of one feature (column 0) taken across the batch.
x[:,0].mean(), x[:,0].std() 

(tensor(0.1469), tensor(0.8803))

In [12]:
# Mean and std of one example (row 0) taken across its features.
x[0,:].mean(), x[0,:].std() 

(tensor(-9.5367e-09), tensor(1.0000))

: 