In [1]:
import os
import numpy as np
import pickle
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Tuple
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data_dir="data/shakespeare_char"
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r').astype(int)
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r').astype(int) 
meta = pickle.load(open(os.path.join(data_dir, 'meta.pkl'), 'rb'))
stoi=meta['stoi']
itos=meta['itos']
def encode(s):
    return [stoi[c] for c in s] # encoder: take a string, output a list of integers
def decode(l):
    return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
print(itos)

{0: '\n', 1: ' ', 2: '!', 3: '$', 4: '&', 5: "'", 6: ',', 7: '-', 8: '.', 9: '3', 10: ':', 11: ';', 12: '?', 13: 'A', 14: 'B', 15: 'C', 16: 'D', 17: 'E', 18: 'F', 19: 'G', 20: 'H', 21: 'I', 22: 'J', 23: 'K', 24: 'L', 25: 'M', 26: 'N', 27: 'O', 28: 'P', 29: 'Q', 30: 'R', 31: 'S', 32: 'T', 33: 'U', 34: 'V', 35: 'W', 36: 'X', 37: 'Y', 38: 'Z', 39: 'a', 40: 'b', 41: 'c', 42: 'd', 43: 'e', 44: 'f', 45: 'g', 46: 'h', 47: 'i', 48: 'j', 49: 'k', 50: 'l', 51: 'm', 52: 'n', 53: 'o', 54: 'p', 55: 'q', 56: 'r', 57: 's', 58: 't', 59: 'u', 60: 'v', 61: 'w', 62: 'x', 63: 'y', 64: 'z'}


In [3]:
# Check that MPS is available
if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    device = torch.device("cpu")
else:
    device = torch.device("mps")
device = torch.device("cpu")
print(device)

cpu


In [4]:
# Hyperparameters
batch_size = 16 # B
block_size = 128 # T
n_emb = 6*16 # E
n_head = 6
n_layer = 6
#head_size = n_emb # H
vocab_size = meta['vocab_size'] # C
num_epoch = 5000
learning_rate = 3e-4
print(f"vocab size is {vocab_size}")


vocab size is 65


In [5]:
def get_batch(split: str) -> Tuple[torch.Tensor, torch.Tensor]:
    if split == 'train':
        data = torch.from_numpy(np.array(train_data))
    elif split == 'val':
        data = torch.from_numpy(np.array(val_data))
    else:
        raise NotImplementedError
    ids = torch.randint(low=0, high=len(data)-block_size, size=(batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ids])
    y = torch.stack([data[i+1 : i+block_size+1] for i in ids])
    x, y = x.to(device), y.to(device)
    return x, y

xb, yb=get_batch('train')

In [6]:
class Head(nn.Module):
    def __init__(self, emb_size: int, head_size: int):
        super().__init__()
        self.query = nn.Linear(emb_size, head_size)
        self.key = nn.Linear(emb_size, head_size)
        self.value = nn.Linear(emb_size, head_size)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    
    def forward(self, emb: torch.Tensor) -> torch.Tensor:
        # input: (B, T, E)
        B, T, E = emb.shape
        q = self.query(emb) # q,k,v = (B, T, H)
        k = self.key(emb)
        v = self.value(emb)
        weight = q @ k.transpose(-2, -1) * (E**-0.5) # (B, T, T)
        # As a decoder block, mask out future information
        weight = weight.masked_fill(self.tril[:T, :T] == 0, -torch.inf)
        weight = torch.softmax(weight, dim=-1)
        out = weight @ v # (B, T, H)
        return out

head = Head(n_emb, 8)
out = head(torch.ones((batch_size, block_size, n_emb)))
assert out.shape == torch.Size([batch_size, block_size, 8])


In [7]:
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int, n_head: int, head_size: int):
        super().__init__()
        self.heads = nn.ModuleList([Head(emb_size, head_size) for _ in range(n_head)])
        self.proj = nn.Linear(emb_size, emb_size)
    
    def forward(self, input: torch.Tensor) -> torch.Tensor: 
        x = torch.cat([h(input) for h in self.heads], dim=-1)
        x = self.proj(x)
        return x

mhead = MultiHeadAttention(n_emb, 4, n_emb//4)
out = mhead(torch.ones((batch_size, block_size, n_emb)))
out.shape

torch.Size([16, 128, 96])

In [8]:
class FeedForward(nn.Module):
    def __init__(self, input_size: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_emb, 4 * n_emb),
            nn.ReLU(),
            nn.Linear(4 * n_emb,n_emb),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor: 
        return self.net(input)

mhead = FeedForward(n_emb)
out = mhead(torch.ones((batch_size, block_size, n_emb)))
out.shape

torch.Size([16, 128, 96])

In [9]:
class Block(nn.Module):
    def __init__(self, emb_size: int, n_head:int):
        super().__init__()
        self.sa_head = MultiHeadAttention(emb_size, n_head, n_emb//n_head)
        self.feed_forward = FeedForward(emb_size)
        self.ln1 = nn.LayerNorm(emb_size)
        self.ln2 = nn.LayerNorm(emb_size)
    
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x = input + self.sa_head(self.ln1(input)) # residual connection, pre-norm
        x = x + self.feed_forward(self.ln2(x)) # residual connection, pre-norm
        return x

block = Block(n_emb, n_head)
out = block(torch.ones((batch_size, block_size, n_emb)))
out.shape

torch.Size([16, 128, 96])

In [10]:
class BigramModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_emb)
        self.position_embedding_table = nn.Embedding(block_size, n_emb)
        self.blocks = nn.Sequential(
            *[Block(n_emb, n_head) for _ in range(n_layer)],
            nn.LayerNorm(n_emb)
        )
        self.lm_head = nn.Linear(n_emb, vocab_size)
    
    def forward(self, idx: torch.Tensor, targets: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        # idx shape = [B, T]
        idx = idx[:, -block_size:]
        B, T = idx.shape
        token_emb = self.token_embedding_table(idx) # [B, T, E]
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # [T, E]
        x = token_emb + pos_emb # [B, T, E]
        x = self.blocks(x) # [B, T, H=E]
        logits = self.lm_head(x)

        if targets is not None:
            logits=logits.view(batch_size * block_size, vocab_size)
            targets = targets.view(batch_size * block_size)
            loss = F.cross_entropy(logits, targets)
        else:
            loss = None
        return logits, loss
    
    def generate(self, idx: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
        for _ in range(max_new_tokens):
            logits, _ = self.forward(idx[:, -block_size:]) # [B, T ,C]
            logits = logits[:, -1, :] # [B, C]
            probs = torch.softmax(logits, dim=-1) # [B, C]
            idx_next = torch.multinomial(probs, num_samples=1) # [B, 1]
            idx = torch.cat((idx, idx_next), dim=-1) # [B, C+1]
        return idx 

model = BigramModel().to(device)
output = model.generate(torch.zeros((1,1), dtype=torch.int, device=device), 100)[0].tolist()
print(decode(output))



:AwBtSy;o!AVHNMWCtBhsBT3IM$CyQvXnrRzP,t?udTHSyn,cSITSh;y;eId3&frDMcIHq-YUhtjVOXcofFFK-xjhIT!ExYU
NWe


In [11]:
@torch.no_grad()
def estimate_loss(eval_iters: int = 10):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for i in range(eval_iters):
            input, targets = get_batch(split)
            _, loss = model(input, targets)
            losses[i] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
        
estimate_loss()

{'train': tensor(4.2996), 'val': tensor(4.2898)}

In [12]:
def train(model: nn.Module):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    for i in tqdm(range(num_epoch)):
        input, targets = get_batch('train')
        logits, loss = model(input, targets)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        if i % (num_epoch/10) == 0:
            losses = estimate_loss()
            print(losses)

train(model)



  0%|          | 2/5000 [00:01<42:19,  1.97it/s]  

{'train': tensor(4.1408), 'val': tensor(4.1466)}


 10%|█         | 502/5000 [01:09<26:05,  2.87it/s]

{'train': tensor(2.4675), 'val': tensor(2.4749)}


 20%|██        | 1002/5000 [02:19<26:30,  2.51it/s]

{'train': tensor(2.3021), 'val': tensor(2.3352)}


 30%|███       | 1501/5000 [03:30<27:53,  2.09it/s]

{'train': tensor(2.1683), 'val': tensor(2.1948)}


 40%|████      | 2002/5000 [04:43<17:46,  2.81it/s]

{'train': tensor(2.0289), 'val': tensor(2.0947)}


 50%|█████     | 2502/5000 [05:55<15:16,  2.73it/s]

{'train': tensor(1.9374), 'val': tensor(2.0111)}


 60%|██████    | 3002/5000 [07:07<11:38,  2.86it/s]

{'train': tensor(1.8349), 'val': tensor(2.0070)}


 70%|███████   | 3502/5000 [08:21<09:53,  2.52it/s]

{'train': tensor(1.7746), 'val': tensor(1.9153)}


 80%|████████  | 4002/5000 [09:35<06:45,  2.46it/s]

{'train': tensor(1.7340), 'val': tensor(1.8768)}


 90%|█████████ | 4502/5000 [10:48<03:04,  2.70it/s]

{'train': tensor(1.6653), 'val': tensor(1.8240)}


100%|██████████| 5000/5000 [12:00<00:00,  6.94it/s]


In [17]:
output = model.generate(torch.zeros((1,1), dtype=torch.long, device=device), 1000)[0].tolist()
print(decode(output))


Conteend yours benesss like own of no lince.
Why the wishicess,' any like word's cars war our of death be.

BUCKINGHAM:
May musterbless are once you, you have our mary?
When spies cuntilly accel'd,
Marke alies loss the gurliel more strups.
Hastile her pllop-leak over all underifurites
Mide words grace our shis dear ranges dabe
Admeneor and yeter a getels your their: it willows,
Madasted me happect spect woe of livess at o'ers all
I, may stirs both great's tream; you can make;
If day, death &norbhad o's thy beloud namply
cortrause a my garlown oner'd? they read cails;
Edwand the shall I charrine is my mean brow;
This to mose on life hams, mine like right
Than girth the gen death kill givods the breas,
and their of underous father, more boughters,,
Opparnoon their, are we whast this on
The commiatis,--that youb's naguy,
And barrs last bose pion him.

ISABELLOND:
Uneme are shalt O adle seems me. Awh, siry wity
These fanstence us, my in that he
The sen so crivent my that nurnes on.

GLUCE

In [14]:
T = 5
tril = torch.tril(torch.ones(T, T))
weight = torch.zeros(T,T)
weight = weight.masked_fill(tril==0, -torch.inf)
weight = torch.softmax(weight, dim=-1)
weight


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]])

In [15]:
t = torch.randn((2, 20))
ln = nn.LayerNorm(20)
print(t)
print(ln(t))
print(torch.mean(ln(t), dim=1))
print(torch.std(ln(t), dim=1))

tensor([[ 2.1667, -0.2986,  0.0434, -0.9022,  1.1912,  0.6353,  0.2423, -0.5491,
          1.9410,  0.7559, -1.2508,  0.0583, -0.8587, -0.4010,  0.9786, -0.5261,
          1.1254,  1.1608, -0.5628,  1.7550],
        [ 1.3411, -1.5851,  2.0341, -0.7048,  1.2602, -1.4971, -0.1538,  1.3706,
         -1.1876, -1.6367,  1.0304, -1.5667, -0.9364, -0.2546,  0.3365,  0.7078,
          0.8851,  1.6311, -1.0416, -0.5219]])
tensor([[ 1.8508, -0.6405, -0.2950, -1.2504,  0.8650,  0.3033, -0.0939, -0.8936,
          1.6226,  0.4251, -1.6028, -0.2798, -1.2065, -0.7440,  0.6502, -0.8704,
          0.7985,  0.8343, -0.9075,  1.4347],
        [ 1.1414, -1.3045,  1.7206, -0.5686,  1.0738, -1.2309, -0.1081,  1.1660,
         -0.9722, -1.3475,  0.8817, -1.2890, -0.7622, -0.1924,  0.3017,  0.6121,
          0.7603,  1.3838, -0.8502, -0.4158]],
       grad_fn=<NativeLayerNormBackward0>)
tensor([-5.9605e-09,  1.1921e-08], grad_fn=<MeanBackward1>)
tensor([1.0260, 1.0260], grad_fn=<StdBackward0>)


In [16]:
import torch
import math
# this ensures that the current MacOS version is at least 12.3+
print(torch.backends.mps.is_available())
# this ensures that the current current PyTorch installation was built with MPS activated.
print(torch.backends.mps.is_built())

True
True
