# Combine Head and MultiHead layer in a single block.

We should have a single layer. The idea is to use `num_heads` as a batch dimension.

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from IPython.display import display, Markdown
%matplotlib inline

In [6]:
with open("../data/tiny-shakespeare/input.txt") as file:
    data = file.read()

len(data)

1115394

In [7]:
display(Markdown(data[:1000]))

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [8]:
chars = sorted(list(set(data)))
"".join(char for char in chars)

"\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

In [9]:
len(chars)

65

In [74]:
stoi, itos = {}, {}

for i, char in enumerate(chars):
    stoi[char] = i
    itos[i] = char

encode = lambda text: [stoi[c] for c in text]
decode = lambda idx_list: "".join(itos[i] for i in idx_list)

In [90]:
# prepare train and test data
split_idx = int(0.80 * len(data))
tokens = encode(data)
train_tokens = tokens[:split_idx]
val_tokens = tokens[split_idx:]
len(train_tokens), len(val_tokens)

(892315, 223079)

In [102]:
# create the model

class MHA(nn.Module):
    def __init__(self, emb_dim, block_size, n_heads, head_dim, dropout):
        super().__init__()

        self.n_heads = n_heads
        self.head_dim = head_dim

        # 1st LayerNorm
        self.ln1 = nn.LayerNorm(emb_dim)

        # first Linear to get from emb_dim --> 3 * n_heads*head_dim, to get k,q,v, then proj back to emb_dim
        self.c_proj = nn.Linear(emb_dim, 3 * n_heads * head_dim, bias=False)
        self.proj = nn.Linear(n_heads * head_dim, emb_dim)

        # 2nd LayerNorm
        self.ln2 = nn.LayerNorm(emb_dim)

        # finally thinking layer
        self.ffn = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(4 * emb_dim, emb_dim)
        )

        self.dropout1 = nn.Dropout(dropout)

        # finally register the tril matrix
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        # get the shape
        B, T, C = x.shape

        # Layer norm
        ln_x = self.ln1(x)

        # Project and extract k,q,v
        c = self.c_proj(ln_x) # (B,T,C)  --> (B,T,3*nh*H)
        c = c.view(B, T, self.n_heads, 3 * self.head_dim) # (B,T,nh,3*H)
        k, q, v = torch.split(c, self.head_dim, dim=-1) # each of shape B,T,nh,H
        k, q, v = k.transpose(-3, -2), q.transpose(-3, -2), v.transpose(-3, -2) # B, nh, T, H

        # Get the attention weights
        wei = q @ k.transpose(-2, -1) * (self.head_dim**-0.50) # (B,nh,T,H) @ (B,nh,H,T) -> (B,nh,T,T)
        wei = wei.masked_fill(self.mask[:T, :T] == 0, -float("inf"))
        wei = torch.softmax(wei, dim=-1)
        wei = self.dropout1(wei)

        # Apply to v
        act = wei @ v # (B,nh,T,T) @ (B,nh,T,H) -> (B,nh,T,H)
        act = act.transpose(-3, -2) # B,T,nh,H
        act = act.reshape(B, T, self.n_heads * self.head_dim)

        # Transform to emb_dim and skip connection
        act = self.proj(act) # (B, T,C)
        act = x + act

        # Think and skip connections
        ln_act = self.ln2(act)
        out = self.ffn(ln_act) # (B,T,C)
        out = x + out # x shape (B,T,C)

        return out


class NanoGPT(nn.Module):
    def __init__(self, vocab_size, block_size, emb_dim, n_layers, n_heads, head_dim, dropout, device):
        super().__init__()

        # helper variables
        self.block_size = block_size
        self.device = device

        # Embedding lookup table
        self.token_embbeding_table = nn.Embedding(vocab_size, emb_dim)
        self.position_embedding_table = nn.Embedding(block_size, emb_dim)

        # MHA head
        self.MHA = nn.Sequential(*[MHA(emb_dim, block_size, n_heads, head_dim, dropout) for _ in range(n_layers)])

        # Layernorm
        self.ln = nn.LayerNorm(emb_dim)

        # final linear layer
        self.lm_layer = nn.Linear(emb_dim, vocab_size)

        # init weights
        self.apply(self._init_weights)

        print(f"Number of parameters: {sum([p.numel() for p in self.parameters()])}")

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, x, targets=None):
        # x shape (B, T)
        B, T = x.shape
        
        token_emb = self.token_embbeding_table(x)
        pos_emb = self.position_embedding_table(torch.arange(0, T).to(self.device))
        emb = token_emb + pos_emb

        emb = self.MHA(emb)
        emb = self.ln(emb)
        logits = self.lm_layer(emb) # (B, T, V)

        loss = None

        if targets is not None:
            B, T, V = logits.shape
            loss = F.cross_entropy(logits.view(B*T, V), targets.view(B*T))

        return logits, loss
    
    def generate(self, max_tokens=1000):
        with torch.no_grad():
            cur_window, idx_list = torch.LongTensor([[0]]).to(self.device), [0] # (1, 1)

            for i in range(max_tokens):
                cur_window = cur_window[:, -self.block_size:] # (1, B)
                logits, _ = self.forward(cur_window) # (1,B,V)
                probs = torch.softmax(logits, dim=-1).squeeze(dim=0) # (B,V)
                idx = torch.multinomial(probs, num_samples=1, replacement=True)[-1].item()
                cur_window = torch.concat([cur_window, torch.LongTensor([[idx]]).view(1, 1).to(self.device)], dim=-1)
                idx_list.append(idx)

            generated_text = decode(idx_list)

            return generated_text

In [93]:
def get_batch(tokens, block_size, batch_size):
    batch = torch.randint(0, len(tokens)-block_size, (batch_size,)) # B dimension array of random indices
    Xb = torch.stack([torch.LongTensor(tokens[i:i+block_size]) for i in batch], dim=0) # Create (B, T) dimension array
    yb = torch.stack([torch.LongTensor(tokens[i+1:i+block_size+1]) for i in batch], dim=0) # Create (B, T) dimension array
    return Xb, yb

In [94]:
@torch.no_grad()
def compute_loss(tokens, block_size, batch_size, model, device):
    loss_values = []
    for _ in range(100):
        Xb, yb = get_batch(tokens, block_size, batch_size)
        Xb, yb = Xb.to(device), yb.to(device)

        _, loss = model(Xb, yb)
        loss_values.append(loss.item())

    mean_loss = torch.FloatTensor(loss_values).mean().item()
    return mean_loss

In [110]:
def train(train_tokens, val_tokens, model, optimizer, device, block_size, batch_size, n_iters, eval_interval):
    train_lossi, val_lossi = [], []

    for i in range(n_iters):
        model.train()
        Xb, yb = get_batch(train_tokens, block_size, batch_size)
        Xb, yb = Xb.to(device), yb.to(device)

        # forward
        _, loss = model(Xb, yb)

        # set grads to zero
        optimizer.zero_grad(set_to_none=True)

        # do backward
        loss.backward()

        # optimizer step
        optimizer.step()

        if (i % eval_interval == 0) or (i == n_iters - 1):
            model.eval()
            train_loss = compute_loss(train_tokens, block_size, batch_size, model, device)
            val_loss = compute_loss(val_tokens, block_size, batch_size, model, device)

            train_lossi.append(train_loss)
            val_lossi.append(val_loss)

            print(f"Step {i}/{n_iters} --> Train: {train_loss:.4f} | Val: {val_loss:.4f}")

        # break

    return train_lossi, val_lossi

In [111]:
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?
n_iters = 5000
eval_interval = n_iters//10
lr = 3e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
emb_dim = 32
n_heads = 4
head_dim = emb_dim // n_heads
n_layers = 1
dropout = 0.2
vocab_size = len(stoi)

In [112]:
model = NanoGPT(emb_dim=emb_dim, vocab_size=vocab_size, block_size=block_size, n_heads=n_heads,\
                 n_layers=n_layers, head_dim=head_dim, device=device)
model = model.to(device)

Number of parameters: 17153


In [113]:
optimizer = optim.AdamW(model.parameters(), lr=lr)

In [114]:
train_lossi, val_lossi = train(train_tokens=train_tokens, val_tokens=val_tokens, model=model, optimizer=optimizer,\
      device=device, block_size=block_size, batch_size=batch_size, n_iters=n_iters, eval_interval=eval_interval)

Step 0/5000 --> Train: 4.1585 | Val: 4.1585
Step 500/5000 --> Train: 2.5426 | Val: 2.5566
Step 1000/5000 --> Train: 2.3062 | Val: 2.3372
Step 1500/5000 --> Train: 2.2060 | Val: 2.2385
Step 2000/5000 --> Train: 2.1434 | Val: 2.2008
Step 2500/5000 --> Train: 2.0995 | Val: 2.1563
Step 3000/5000 --> Train: 2.0762 | Val: 2.1338
Step 3500/5000 --> Train: 2.0497 | Val: 2.1164
Step 4000/5000 --> Train: 2.0393 | Val: 2.1098
Step 4500/5000 --> Train: 2.0223 | Val: 2.0954
Step 4999/5000 --> Train: 1.9987 | Val: 2.0755


In [116]:
model.eval()

NanoGPT(
  (token_embbeding_table): Embedding(65, 32)
  (position_embedding_table): Embedding(8, 32)
  (MHA): Sequential(
    (0): MHA(
      (ln1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (c_proj): Linear(in_features=32, out_features=96, bias=False)
      (proj): Linear(in_features=32, out_features=32, bias=True)
      (ln2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (linear): Sequential(
        (0): Linear(in_features=32, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=32, bias=True)
      )
    )
  )
  (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  (lm_layer): Linear(in_features=32, out_features=65, bias=True)
)

In [117]:
display(Markdown(model.generate()))


MEO:
Theer and
Thichfere that the one modd on,
But rear twill sity of roman have thampter, hild
Men of bureds to notine,
And wouldeordingsir my rel p, comuseds of'er shalion have.

ICK:
Nobloy lar me sonce,
And loblover my lod it a well sty shie it shad knis to
Your , thins, of or seent: shink, wherech over
Frow may.

LETERWIOP:
He our
banknom a to the marie!
On stall oldy hosed thy to nower:
To the dow! in in cespeeath fire it fady affounple thaslo shourder'd blodds, wentle's alle your there?

CIUSSARn ferven his Qso loves humplef it trove gaarde Whater dony.
Rame't,
Tubrod, this with a will'd stram andcom: what Murght me nant the a do thorse brothou to mart
Wath and hathnire vere the; do by sigued to thoum oun of'end.

When lagiveVer beove vaw
Yenter
Shat awam sirtime scrows
Sirth
Bitspliding this hold nided.
Achorloofer one
Thing thore hath dey
But with sone in thatrung be a rean's sharignasule refortoods, and flatip inst, junding hereatined whould me cond, ight be.
But my you vour 