# Transformer

This notebook trains a decoder-only transformer to perform next token prediction on the Tiny Shakespeare dataset.

The parameters are chosen to make training fast (e.g., short context) and will not produce a performant model.
The interested reader should feel free to alter these to produce a larger model.

In [1]:
import micrograd_pp as mpp
import numpy.typing as npt

np = mpp.numpy  # `import numpy as np` will not work when using GPU

In [2]:
CONTEXT_WIDTH = 32
DROPOUT = 0.0
EMBEDDING_DIM = 128
EVAL_FREQ = 3_000
HIDDEN_SIZE = EMBEDDING_DIM * 4
LEARNING_RATE = 0.1
NUM_BLOCKS = 3
NUM_HEADS = 4
NUM_ITERS = 30_000
TRAIN_BATCH_SIZE = 96
TRAIN_FRAC = 0.99
VAL_BATCH_SIZE = 4_096

In [3]:
text = mpp.datasets.load_tiny_shakespeare()

In [4]:
vocab = sorted(set(text))
vocab_size = len(vocab)

char2token = {char: token for token, char in enumerate(vocab)}
all_tokens = np.array([char2token[char] for char in text], dtype=np.int32)

first_val_index = int(TRAIN_FRAC * all_tokens.size)
train_tokens = all_tokens[:first_val_index]
val_tokens = all_tokens[first_val_index:]

In [5]:
class Block:
    def __init__(self) -> None:
        self._ln1 = mpp.LayerNorm(EMBEDDING_DIM)
        self._attn = mpp.MultiheadAttention(
            embed_dim=EMBEDDING_DIM,
            num_heads=NUM_HEADS,
            batch_first=True
        )
        attn_mask_np = np.zeros((CONTEXT_WIDTH, CONTEXT_WIDTH))
        attn_mask_np[np.triu_indices_from(attn_mask_np, k=1)] = -np.inf
        self._attn_mask = mpp.Constant(attn_mask_np)
        self._dropout = mpp.Dropout(DROPOUT)
        self._ln2 = mpp.LayerNorm(EMBEDDING_DIM)
        self._ff = mpp.Sequential(
            mpp.Linear(in_features=EMBEDDING_DIM, out_features=HIDDEN_SIZE),
            mpp.ReLU(),
            mpp.Linear(in_features=HIDDEN_SIZE, out_features=EMBEDDING_DIM),
            mpp.Dropout(DROPOUT),
        )

    def __call__(
        self,
        x: mpp.Expr  # (N, L, E)
    ) -> mpp.Expr:
        x = self._ln1(x)
        x = x + self._dropout(self._attn(x, x, x, attn_mask=self._attn_mask)[0])
        x = self._ln2(x)
        x = x + self._ff(x)
        return x  # (N, L, E)

class DecoderOnlyTransformer:
    def __init__(self) -> None:
        self._tok_embedding = mpp.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=EMBEDDING_DIM,
            label="token_embedding",
        )
        self._pos_embedding = mpp.Embedding(
            num_embeddings=CONTEXT_WIDTH,
            embedding_dim=EMBEDDING_DIM,
            label="positional_embedding",
        )
        self._blocks = mpp.Sequential(*[Block() for _ in range(NUM_BLOCKS)])
        self._ln = mpp.LayerNorm(EMBEDDING_DIM)
        self._output_proj = mpp.Linear(
            in_features=EMBEDDING_DIM,
            out_features=vocab_size,
            label="output_projection",
        )

    def __call__(
        self,
        tokens: npt.NDArray,  # (N, L)
    ) -> None:
        t = self._tok_embedding(tokens)  # (N, L, E)
        p = self._pos_embedding(np.arange(CONTEXT_WIDTH))  # (L, E)
        x = t + p  # (N, L, E)
        x = self._blocks(x)  # (N, L, E)
        x = self._ln(x)  # (N, L, E)
        return self._output_proj(x)  # (N, L, V)

def loss(model: mpp.Module, indices: npt.NDArray, user_data: npt.NDArray) -> mpp.Expr:
    """Compute loss on a batch."""
    x = np.stack([user_data[index - CONTEXT_WIDTH    :index    ] for index in indices])  # (N, L)
    y = np.stack([user_data[index - CONTEXT_WIDTH + 1:index + 1] for index in indices])  # (N, L)
    yhat = model(x).reshape((-1, vocab_size))  # (N * L, V)
    y = y.reshape(-1)  # (N * L,)
    return mpp.cross_entropy_loss(yhat, y)

def train_loss(model: mpp.Module) -> mpp.Expr:
    """Compute loss on a random batch from the training set."""
    indices = np.random.randint(low=CONTEXT_WIDTH, high=train_tokens.size, size=(TRAIN_BATCH_SIZE,))
    return loss(model=model, indices=indices, user_data=train_tokens)

def val_loss(model: mpp.Module) -> npt.NDArray:
    """Approximate loss on the validation set."""
    losses = []
    with mpp.eval(), mpp.no_grad():
        n = 0
        low = CONTEXT_WIDTH
        while low < val_tokens.size:
            high = min(low + VAL_BATCH_SIZE, val_tokens.size)
            indices = np.arange(low, high)
            item = loss(model=model, indices=indices, user_data=val_tokens).value
            losses.append(item)
            low = high
        return np.array(losses).mean()

def generate_sentence(model: mpp.Module, init: npt.NDArray | None = None, length: int = 512) -> str:
    """Use a learned decoder-only transformer to generate a sentence."""
    with mpp.eval(), mpp.no_grad():
        if init is None:
            init = np.zeros((CONTEXT_WIDTH,), dtype=np.int32)
        context = init
        tokens = []
        for _ in range(length):
            logits = model(context.reshape(1, -1))
            pvals = mpp.softmax(logits, dim=-1)[0, -1, :]
            token = np.random.multinomial(n=1, pvals=pvals.value).argmax().item()
            context[:-1] = context[1:]
            context[-1] = token
            tokens.append(token)
        return ''.join(vocab[token] for token in tokens)

In [6]:
np.random.seed(0)
model = DecoderOnlyTransformer()

print(f"""
Uninitialized Embedding
-----------------------
Loss: {val_loss(model).item()}
Random sentence: {generate_sentence(model)}
""")


Uninitialized Embedding
-----------------------
Loss: 4.616665830144252
Random sentence: wWm
!YuG.b3xgXgfuNcXiw,PGmOW,xnxg.xgufo,xZTfiwLmXlpLWnm.g
gQxBzX,
!''c,cTBwgH,YfYw,?.gw.m'w'uYT,XRgmt nXX
kIhLD
EHr?JgOXC,!mztx,,,RakcYlZuXc,G?ZozU?wgu'z,?wg
z
Zg.'PXNT?xgN'XKiwwxmR,gJD,gGs Xbfw&R,XG.q.lfz,,
o,u'wWZLLzFSXGmnUTSw
LXXLLnzBcaiwXRguuFcggYX,;'ohcLR cpcYQx,gg&!IRiR,XsxxLwwh,tgwgHmkjYu
Gc,XbXB,ifxBuG,mpfg$&xm,g?RULmJlwU'XB'XbLwwjfB
,Wug&uug
B,LcX&N&XX
ByX,l! 
k
D?Vr,l?HP,HcZzR?fu,iXYuufwx,tac,gDitN$!!m
uKtDqhm.ZR
LmTfx,bu''bf
O
ufOKGb-Ykm,m!W,Bu,L TcotB X'wccuaigl,$WgLBwBbfRXYlBJ&WmZugvXF&KXX?G
.z



In [7]:
opt = mpp.SGD(lr=LEARNING_RATE)

n = 0
while True:
    if n % EVAL_FREQ == 0:
        print(f"""
Iteration {n:8d}
------------------
Loss: {val_loss(model).item()}
Random sentence: {generate_sentence(model)}
""")

    if n >= NUM_ITERS:
        break

    train_loss(model).backward(opt=opt)
    opt.step()

    n += 1


Iteration        0
------------------
Loss: 4.616665830144252
Random sentence: wG'swRgmGRDiwGYwmalYu g?BYcX
XlixXx'zJmXc.,I-MuUmuXgYmzN,xLgqNu.wl,cXm,?mGK,uK&Xbo,fRzXckuWXBfBuK,K,3wgguGtG$mZx,uLgex
'oK',kW
Mcm
u!fwu.muabuDhxgR'fzYlNDXP
guX,xG,BQ,Rdf,YXV,cTg'B,$x,zv.
wK$wz&BuldX$fWg
,KBfBuYli
G $mXguu,NiX,uTBwLmu:B.ww$vwS,mXxuB,WRDViRiDGlVPt-XmG,'vTBR,NKgm'Z
ZbumLXKw,?D,,mGLz yhqHjwmx,eG3$bStFRZtXlRi3fdfumUP'cVFg,BK&iKiLlxgqko,OTx,ntq,RKXYufVknlKn,MVyO'O'w-mZTKiLxB Yu'UIUBfQuwwWK.oG $iczmufU'wrG',xBKivwmRBilCjFjhBPwePuWg&3,Ml$l,mxg
Nhwov,gLg$B;Duc&X'CHwwtG!,DitXDt.w-aoRHXgge,nzVDXP&uUw


Iteration     3000
------------------
Loss: 2.0463488263845897
Random sentence: CRUMINAMELELVOCENTA:
Our I hangir, shall waing I befiste it unnt: en not hem ben starome not to for elier plese,
If bese, I'n cupasely tend ald thie, patiledn.
Pear at wit Fram wh me som kinequess. I bun stou!
Sin to that morrie met the was heasol un Ve,
Gry, Ratings:
'Sels for, sickinch, baddled kith
Tild bruck bemp-manes 