# n-gram language model

An [n-gram language model](https://en.wikipedia.org/wiki/Word_n-gram_language_model) is a statistical model of language.
It assumes that the probability of the next token in a sequence depends only on a fixed-size window (a.k.a. context) of previous tokens.

For example, in a trigram ($n = 3$) model, the [likelihood](https://en.wikipedia.org/wiki/Likelihood_function) of observing the sentence $X_1 \, X_2 \cdots X_T$ is
$$
\mathbb{P}(X_1 \, X_2 \cdots X_T)
= \prod_{t = 1}^T \mathbb{P}(X_t \mid X_{t - 2} \, X_{t - 1})
$$
where, typically, $X_t$ is assigned a placeholder value (e.g., a null or start-of-sentence token) whenever $t \leq 0$.
The logarithm of the likelihood is
$$
\log \mathbb{P}(X_1 \, X_2 \cdots X_T)
\propto \frac{1}{T} \sum_{t = 1}^T \log \mathbb{P}(X_t \mid X_{t - 2} \, X_{t - 1})
$$
which we recognize as the negative of the [cross-entropy](https://parsiad.ca/blog/2023/motivating_the_cross_entropy_loss).
We allow each probability $p_{x,x^\prime,x^{\prime\prime}} \equiv \mathbb{P}(x \mid x^\prime \, x^{\prime\prime})$ to be a distinct parameter of the model.
In this case, letting $V$ denote the set of tokens (a.k.a. the vocabulary), the trigram model has $|V|^3$ parameters.
An efficient way to compute a [maximum likelihood estimator](https://en.wikipedia.org/wiki/Maximum_likelihood_estimation) for $p$ is by [counting n-grams](https://en.wikipedia.org/wiki/Word_n-gram_language_model#Approximation_method).
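
To make the counting estimator concrete, here is a minimal sketch in plain Python. It is not part of the original notebook: the function name `trigram_mle` and the padding token `"<s>"` are assumptions chosen for illustration. Each probability is estimated as the number of times a trigram occurs divided by the number of times its two-token context occurs.

```python
from collections import Counter


def trigram_mle(tokens, start="<s>"):
    """Estimate P(x | x', x'') by counting trigrams.

    `start` is a hypothetical placeholder token used to pad the beginning
    of the sequence so that the first two tokens also have a full context.
    """
    padded = [start, start] + list(tokens)
    context_counts = Counter()
    trigram_counts = Counter()
    for t in range(2, len(padded)):
        context = (padded[t - 2], padded[t - 1])
        context_counts[context] += 1
        trigram_counts[context + (padded[t],)] += 1
    # Keys are (second-to-last token, last token, next token).
    return {
        trigram: count / context_counts[trigram[:2]]
        for trigram, count in trigram_counts.items()
    }


probs = trigram_mle("abababac")
# After the context "ba", the sequence continues with "b" twice and "c" once.
print(probs[("b", "a", "b")], probs[("b", "a", "c")])
```

Trigrams that never occur in the data are simply absent from the dictionary, which corresponds to an estimated probability of zero.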

However, one could also approximate this estimator by performing gradient descent on the cross-entropy (equivalently, gradient ascent on the log-likelihood).
This notebook uses Micrograd++ to do just that.
While **this approach is not efficient**, it is useful in that it demonstrates how more complicated language models *without* closed-form solutions (e.g., [recurrent neural networks](https://en.wikipedia.org/wiki/Recurrent_neural_network) and [large language models](https://en.wikipedia.org/wiki/Large_language_model)) can be learned by iteratively minimizing the cross-entropy.
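
As a small illustration of the relationship between the two approaches, the sketch below uses plain NumPy (not Micrograd++) on a toy corpus. It is an assumption-laden example rather than the notebook's method: it uses a bigram model for brevity, full-batch gradient descent instead of random mini-batches, and a made-up token sequence. The table of logits plays the role of the embedding, and the gradient of the mean cross-entropy with respect to a row of logits is the softmax of that row minus the empirical next-token frequencies, weighted by how often the context occurs.

```python
import numpy as np


def softmax(z):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


# Hypothetical toy corpus over a 3-token vocabulary.
tokens = np.array([0, 1, 0, 1, 0, 2, 0, 1])
vocab_size = 3
contexts, targets = tokens[:-1], tokens[1:]  # bigram: one token of context

# Closed-form maximum likelihood estimate obtained by counting.
counts = np.zeros((vocab_size, vocab_size))
np.add.at(counts, (contexts, targets), 1.0)
row_totals = counts.sum(axis=1, keepdims=True)
mle = counts / row_totals

# Iterative estimate: gradient descent on the mean cross-entropy.
logits = np.zeros((vocab_size, vocab_size))
for _ in range(5_000):
    probs = softmax(logits)
    grad = (row_totals * probs - counts) / targets.size
    logits -= 1.0 * grad  # learning rate of 1.0

learned = softmax(logits)
print(np.round(mle, 3))
print(np.round(learned, 3))
```

The learned probabilities approach the count-based ones only in the limit: an estimated probability of exactly zero would require a logit to diverge to negative infinity, so the iterative estimate remains an approximation.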

In [1]:
import micrograd_pp as mpp
import numpy as np
import numpy.typing as npt
import scipy.special

In [2]:
BATCH_SIZE = 32
CONTEXT_WIDTH = 2  # Trigram
NUM_ITERS = 1_000_000
TRAIN_FRAC = 0.9

In [3]:
text = mpp.datasets.load_tiny_shakespeare()

In [4]:
vocab = sorted(set(text))

char2token = {char: token for token, char in enumerate(vocab)}
all_tokens = np.array([char2token[char] for char in text], dtype=np.int32)

first_val_index = int(TRAIN_FRAC * all_tokens.size)
train_tokens = all_tokens[:first_val_index]
val_tokens = all_tokens[first_val_index:]

def base_expand(context: npt.NDArray) -> npt.NDArray:
    """Map each context window of tokens to a single integer index."""
    # Treat the window as the digits of a base-len(vocab) number (first token least significant)
    c = len(vocab)**np.arange(CONTEXT_WIDTH)
    return (c * context).sum(axis=-1)

def loss(embedding: mpp.Embedding, val: bool = True) -> mpp.Expr:
    """Compute loss on a random training batch or the validation set."""
    if val:
        user_data = val_tokens
        indices = np.arange(start=CONTEXT_WIDTH, stop=val_tokens.size)
    else:
        user_data = train_tokens
        indices = np.random.randint(low=CONTEXT_WIDTH, high=train_tokens.size, size=(BATCH_SIZE,))
    x = np.stack([user_data[index - CONTEXT_WIDTH:index] for index in indices])  # (B, C)
    y = user_data[indices]  # (B,)
    logits = embedding(base_expand(x))
    return mpp.cross_entropy_loss(logits, y)

def generate_sentence(embedding: mpp.Embedding, init: npt.NDArray | None = None, length: int = 64) -> str:
    """Use a learned embedding to generate a sentence."""
    if init is None:
        init = np.zeros((CONTEXT_WIDTH,), dtype=np.int32)
    context = init.copy()  # Copy so that the caller's array is not mutated below
    tokens = context.tolist()
    for _ in range(length):
        logits = embedding(base_expand(context[np.newaxis, ...]))
        pvals = scipy.special.softmax(logits.value.squeeze())
        token = np.random.multinomial(n=1, pvals=pvals).argmax().item()
        context[:-1] = context[1:]
        context[-1] = token
        tokens.append(token)
    return ''.join(vocab[token] for token in tokens)
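
The role of `base_expand` may be easier to see on a toy example. The sketch below re-implements it with a hypothetical vocabulary of size 3 (the constant `VOCAB_SIZE` is an assumption standing in for `len(vocab)`) and checks that every possible two-token context is mapped to a distinct index between 0 and `VOCAB_SIZE**CONTEXT_WIDTH - 1`. This is why an embedding with `len(vocab)**CONTEXT_WIDTH` rows has exactly one row of logits per context.

```python
import numpy as np

VOCAB_SIZE = 3  # hypothetical toy vocabulary size, standing in for len(vocab)
CONTEXT_WIDTH = 2


def base_expand(context):
    """Read each context as a base-VOCAB_SIZE number (first token least significant)."""
    c = VOCAB_SIZE ** np.arange(CONTEXT_WIDTH)  # place values [1, 3]
    return (c * context).sum(axis=-1)


# Enumerate every possible context and compute its index.
contexts = np.array([[i, j] for i in range(VOCAB_SIZE) for j in range(VOCAB_SIZE)])
indices = base_expand(contexts)
for context, index in zip(contexts.tolist(), indices.tolist()):
    print(context, "->", index)
```

For instance, the context `[2, 1]` maps to `2 * 1 + 1 * 3 = 5`.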

In [5]:
np.random.seed(0)
embedding = mpp.Embedding(num_embeddings=len(vocab)**CONTEXT_WIDTH, embedding_dim=len(vocab))

with mpp.eval(), mpp.no_grad():
    print(f"""
Uninitialized Embedding
-----------------------
Loss: {loss(embedding).value.item()}
Random sentence: {generate_sentence(embedding)}
""")


Uninitialized Embedding
-----------------------
Loss: 4.618406097403782
Random sentence: 

? lNtAp.'ZcS-Um
US:I!X.DC&VTej:XdX'QVMw3IK Fkv?rvkLnVqFZC
lB&TC$



In [6]:
opt = mpp.SGD(lr=1.0)

n = 0
while True:
    if n % (NUM_ITERS // 10) == 0:
        with mpp.eval(), mpp.no_grad():
            print(f"""
Iteration {n:8d}
------------------
Loss: {loss(embedding).value.item()}
Random sentence: {generate_sentence(embedding)}
""")

    if n >= NUM_ITERS:
        break

    loss(embedding=embedding, val=False).backward(opt=opt)
    opt.step()

    n += 1


Iteration        0
------------------
Loss: 4.618406097403782
Random sentence: 

H,WIJa-Wk&'XXFCIq?lbJCBT?'XtDf-kW-Grq&FLIRpPz'tjY3Tpc.jLUXk loOn


Iteration   100000
------------------
Loss: 2.171994386892817
Random sentence: 

And myumbed VI:
WhourtR;Az'xMpU?gzXs
God, yoused
iNA:
HER:
O ban


Iteration   200000
------------------
Loss: 2.116169136143967
Random sentence: 

LEONTER:
POld come muchat Pome
ROKE Ennow
Wheithey ord's fie ast


Iteration   300000
------------------
Loss: 2.0921984053229843
Random sentence: 

so peaver, bod
What fore,
Nown,
That rall daun he berce wor the 


Iteration   400000
------------------
Loss: 2.085089184457391
Random sentence: 

Ex&
$MKFRIANNENTES:
Shat shose,
That,
My se her ass of smadeectu


Iteration   500000
------------------
Loss: 2.0779833496975297
Random sentence: 

Bol, wou wour eato clovirseend.

My oad to my bet your beam not 


Iteration   600000
------------------
Loss: 2.0715481547062122
Random sentence: 

CLAUNTISsHlp were enap?yl's