# MiniGPT LLM

In [None]:
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch

from nimrod.text.tokenizers import CharTokenizer
from nimrod.text.datasets import SimpleCharDataset
from nimrod.models.transformer import TransformerBlock

from pathlib import Path
from dataclasses import dataclass, asdict

## Data

In [None]:
# Load the raw corpus and build a character-level tokenizer from it.
text = Path('../data/text/tiny_shakespeare.txt').read_text()
print(text[:25])
tok = CharTokenizer.from_text(text)
print(f"vocabulary size: {len(tok)}")

# Encode the whole corpus once; the dataset is built on top of the token ids.
encoded = tok.encode(text)
print(encoded[:10])
# NOTE(review): the second argument is presumably the context length of each
# sample -- confirm against SimpleCharDataset.
ds = SimpleCharDataset(encoded, 10)
print(f"length of data str: {len(ds)}")

# 80/10/10 split; the test split takes the remainder so the three lengths
# always sum to len(ds), which random_split requires.
n_train = int(0.8 * len(ds))
n_eval = int(0.1 * len(ds))
n_test = len(ds) - n_train - n_eval
# A seeded generator makes the split identical across kernel restarts.
train_ds, eval_ds, test_ds = random_split(
    ds,
    lengths=(n_train, n_eval, n_test),
    generator=torch.Generator().manual_seed(42),
)

# Each loader wraps its own split. Only the training loader shuffles.
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
eval_dl = DataLoader(eval_ds, batch_size=32, shuffle=False)
test_dl = DataLoader(test_ds, batch_size=32, shuffle=False)

First Citizen:
Before we 
vocabulary size: 65
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])
length of data str: 1115384


## Model

In [None]:
class MiniGPT(nn.Module):
    """Minimal GPT-style language model.

    Token embeddings and learned positional embeddings are summed, passed
    through a stack of ``TransformerBlock`` modules, layer-normalised and
    projected to per-token vocabulary logits.

    Args:
        vocab_size: number of distinct token ids (size of the output layer).
        embed_dim: embedding width ``C``.
        n_head: forwarded to every ``TransformerBlock``.
        context_length: size of the positional embedding table, i.e. the
            longest sequence ``forward`` can process.
        dropout: forwarded to every ``TransformerBlock``.
        n_blocks: number of ``TransformerBlock`` modules stacked in sequence.
    """

    def __init__(
        self,
        vocab_size:int,
        embed_dim:int,
        n_head:int,
        context_length:int,
        dropout:float,
        n_blocks:int
    ):

        super().__init__()
        self.context_length = context_length
        self.tok_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(context_length, embed_dim)
        block_params = [embed_dim, n_head, context_length, dropout]
        self.blocks = nn.Sequential(
            *(TransformerBlock(*block_params) for _ in range(n_blocks))
        )
        self.ln = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(
        self,
        x:torch.LongTensor # (B,T) batch of token ids, T <= context_length
        ):
        """Return logits of shape (B, T, vocab_size) for token ids ``x``."""
        _, T = x.shape
        tok_embed = self.tok_embed(x) # (B, T, C)
        positions = torch.arange(T, device=x.device)
        pos_embed = self.pos_embed(positions) # (T, C), broadcast over the batch
        x = tok_embed + pos_embed
        x = self.blocks(x)
        x = self.ln(x)
        logits = self.head(x)
        return logits

    def generate(
        self,
        x:torch.LongTensor, # tok ids (B, T)
        max_tokens: int # number of new tokens to sample
    ):
        """Sample ``max_tokens`` new tokens and append them to the prompt ``x``.

        The model only ever sees the last ``context_length`` tokens, so the
        prompt plus the generated text may be longer than the context window.
        Sampling runs in eval mode without gradients; the module's previous
        train/eval mode is restored before returning.

        Returns:
            Tensor of shape (B, T + max_tokens) holding prompt and samples.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(max_tokens):
                    # Crop to the window covered by the positional table.
                    context = x[:, -self.context_length:]
                    logits = self(context) # (B, T, n_classes)
                    # look at last time step only
                    logits = logits[:, -1, :]
                    probs = torch.softmax(logits, dim=-1)
                    pred = torch.multinomial(probs, num_samples=1)
                    x = torch.cat([x, pred], dim=1) # (B, T+1)
        finally:
            self.train(was_training)
        return x



### Config

In [None]:
@dataclass
class GPTConfig:
    """Hyper-parameters for ``MiniGPT``.

    Field names match the keyword arguments of ``MiniGPT.__init__`` so the
    config can be expanded with ``MiniGPT(**asdict(cfg))``. Only
    ``vocab_size`` is mandatory; the other fields default to a tiny model.
    """

    vocab_size: int
    embed_dim: int = 16
    n_head: int = 2
    context_length: int = 10
    n_blocks: int = 2
    dropout: float = 0.1

    def __post_init__(self):
        """Validate the field combination right after construction."""
        # A non-zero remainder means the width cannot be divided by n_head.
        assert not self.embed_dim % self.n_head, "embed_dim must be divisible by n_head"
        # Dropout is a probability.
        assert 0 <= self.dropout <= 1, "dropout must be between 0 and 1"

# Default-sized config; the vocabulary size is taken from the tokenizer.
cfg = GPTConfig(vocab_size=len(tok))

In [None]:
# asdict(cfg) yields keyword arguments whose names match MiniGPT.__init__.
m = MiniGPT(**asdict(cfg))
print(m)

MiniGPT(
  (tok_embed): Embedding(65, 768)
  (pos_embed): Embedding(32, 768)
  (blocks): Sequential(
    (0): TransformerBlock(
      (ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-3): 4 x AttentionHead(
            (key): Linear(in_features=768, out_features=192, bias=False)
            (query): Linear(in_features=768, out_features=192, bias=False)
            (value): Linear(in_features=768, out_features=192, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (ffwd): FeedFoward(
        (net): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): ReLU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
    

### Usage

In [None]:
device = 'cpu'

# Batch size, sequence length and embedding width for the smoke test.
B = 5
T = cfg.context_length
C = cfg.embed_dim

# Random token ids of shape (B, T) stand in for a real batch.
x = torch.randint(0, cfg.vocab_size, (B, T), dtype=torch.long)
x = x.to(device)
m = m.to(device)

# Forward pass: one logit vector per position, shape (B, T, n_classes).
y = m(x)
print(y.shape)

torch.Size([5, 32, 65])


In [None]:
prompt = "hello!"
# Add a batch axis and put the prompt on the same device as the model, so the
# cell keeps working if `device` is changed above.
encoded = tok.encode(prompt).unsqueeze(dim=0).to(device)
print(encoded)
# The model is still untrained here, so the sampled continuation is noise.
res = m.generate(encoded, max_tokens=25)
print(res)
print(tok.decode(res.squeeze()))

tensor([[46, 43, 50, 50, 53,  2]])
tensor([[46, 43, 50, 50, 53,  2, 59, 11, 41, 34,  7,  4,  8, 63, 16, 30, 18, 32,
         34, 58, 51, 54, 61, 40, 26, 45,  6, 22, 16, 64, 26]])
hello!u;cV-&.yDRFTVtmpwbNg,JDzN


## Training Loop

In [None]:
# Use an accelerator when one is available instead of hard-coding Apple's MPS
# backend, so the cell also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Small model for a quick run. Passing the overrides to the constructor (rather
# than mutating the fields afterwards) keeps the __post_init__ validation.
cfg = GPTConfig(
    vocab_size=len(tok),
    embed_dim=8,
    n_head=2,
    context_length=10,
    n_blocks=1,
)

m = MiniGPT(**asdict(cfg)).to(device)

n_epochs = 1
lr = 3e-4

criterion = nn.CrossEntropyLoss()
optim = torch.optim.AdamW(m.parameters(), lr=lr)

for epoch in range(n_epochs):
    m.train()
    running_loss = 0
    # Optimise on the training loader, not the evaluation loader.
    for step, batch in enumerate(train_dl):
        optim.zero_grad()
        x, y = batch
        x, y = x.to(device), y.to(device)
        logits = m(x) # B, T, V
        B, T, V = logits.shape
        # CrossEntropyLoss expects (N, V) logits and (N,) targets, so the batch
        # and time axes are flattened together.
        logits = logits.view(B*T, V)
        y = y.view(B*T)
        loss = criterion(logits, y)
        running_loss += loss.item()
        if not step%1000:
            print(loss.item())
        loss.backward()
        optim.step()
    # Mean loss per step over the loader that was actually iterated.
    print(f"epoch: {epoch} loss: {running_loss/len(train_dl)}")
        


4.290432929992676
3.3735451698303223
3.000006675720215
2.966372013092041
2.8455679416656494
2.646514415740967
2.8075928688049316
2.6514816284179688
2.506645917892456
2.5771987438201904
2.5555624961853027
2.4803662300109863
2.600022554397583
2.507206439971924
2.5301594734191895
2.5535435676574707
2.4812862873077393
2.5459506511688232
2.4723029136657715
2.5209922790527344
2.545898914337158
2.502852439880371
2.5002646446228027
2.502060890197754
2.447946071624756
2.347792387008667
2.389284610748291
2.4702415466308594
epoch: 0 loss: 2.619110928211072


## Test

In [None]:
prompt = "hello!"
# Add a batch axis and move the prompt to the device the model was trained on.
encoded = tok.encode(prompt).unsqueeze(dim=0).to(device)
print(encoded)
# Sample a continuation from the trained model.
res = m.generate(encoded, max_tokens=10)
print(res)
# Drop the batch axis before decoding the ids back to text.
print(tok.decode(res.squeeze()))

tensor([[46, 43, 50, 50, 53,  2]], device='mps:0')
tensor([[46, 43, 50, 50, 53,  2,  0, 31, 10,  0]], device='mps:0')
hello!
S:

