![Py4Eng](../logo.png)

# Transformers: character-level language model
## Yoav Ram

We will see here the [**Transformer** architecture](https://en.wikipedia.org/wiki/Transformer_(deep_learning_architecture)).
Transformers are the basis of large language models like OpenAI's [GPT](https://en.wikipedia.org/wiki/Generative_pre-trained_transformer)--the "T" stands for "Transformer".

Here, we apply transformers to the same problem to which we applied RNN and GRU: text generation by pretraining a character-level model.

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt

import jax 
import jax.numpy as np
print('jax', jax.__version__, jax.default_backend())
import optax # pip install optax

from collections import Counter

jax 0.4.35 cpu


# Data

The data is Shakespeare's writing as text.
The characters are converted to integers and then one-hot encoded.

In [2]:
filename = '../data/shakespear.txt'
with open(filename, 'rt') as f:
    text = f.read()

print("Number of characters: {}".format(len(text)))
print("Number of unique characters: {}".format(len(set(text))))
print("Number of lines: {}".format(text.count('\n')))
print("Number of words: {}".format(text.count(' ')))
print()
print("Excerpt:")
print("*" * len("Excerpt:"))
print(text[:500])

Number of characters: 99993
Number of unique characters: 62
Number of lines: 3298
Number of words: 15893

Excerpt:
********
That, poor contempt, or claim'd thou slept so faithful,
I may contrive our father; and, in their defeated queen,
Her flesh broke me and puttance of expedition house,
And in that same that ever I lament this stomach,
And he, nor Butly and my fury, knowing everything
Grew daily ever, his great strength and thought
The bright buds of mine own.

BIONDELLO:
Marry, that it may not pray their patience.'

KING LEAR:
The instant common maid, as we may less be
a brave gentleman and joiner: he that finds u


We start by creating:
- `chars`, a list of the unique characters
- `data_size`, the total number of characters
- `vocab_size`, the number of unique characters
- `int_to_char`, a dictionary from index to char
- `char_to_int`, a dictionary from char to index

We then convert `text` from a string to an array of integers representing the chars, and one-hot encode it to get `X`.

In [3]:
chars = sorted(set(text)) # sorted so the char-to-int mapping is the same in every session
data_size, vocab_size = len(text), len(chars)

# char to int and vice versa
int_to_char = dict(enumerate(chars)) #  == { i: ch for i,ch in enumerate(chars) }
char_to_int = dict(zip(int_to_char.values(), int_to_char.keys())) # { ch: i for i,ch in enumerate(chars) }

def onehot_encode(text):
    ints = [char_to_int[c] for c in text]
    ints = np.array(ints, dtype=int)
    return jax.nn.one_hot(ints, vocab_size)

def onehot_decode(data):
    ints = data.argmax(axis=1).tolist()
    chars = (int_to_char[k] for k in ints)
    return str.join('', chars)

X = onehot_encode(text)
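
As a quick sanity check of the encoding described above, here is a minimal sketch of the same round trip on a toy string. The names `toy_text`, `toy_chars`, `toy_int_to_char`, `toy_char_to_int` and `toy_X` are hypothetical and exist only for this illustration.

```python
import jax
import jax.numpy as np  # same alias as in the notebook

toy_text = "to be or not to be"  # hypothetical toy corpus
toy_chars = sorted(set(toy_text))
toy_int_to_char = dict(enumerate(toy_chars))
toy_char_to_int = {ch: i for i, ch in toy_int_to_char.items()}

# encode: chars -> ints -> one-hot rows
ints = np.array([toy_char_to_int[c] for c in toy_text], dtype=int)
toy_X = jax.nn.one_hot(ints, len(toy_chars))

# decode: argmax of each row -> int -> char
decoded = "".join(toy_int_to_char[k] for k in toy_X.argmax(axis=1).tolist())
print(toy_X.shape, repr(decoded))
```

Each row of `toy_X` has a single 1, so decoding with `argmax` recovers the original string.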

# Transformer model
## Self-attention

Previous recurrent and convolution models that we used had limited access to the input. 
Convolutions could only use nearby elements in the input, and recurrent models had to keep a "memory" of previous elements in the input.

In self-attention, this is solved by computing attention weights for each position in the sequence.
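First, we compute the **query** $q_i$ and **key** $k_j$,
$$
q_i = W^q x_i, \qquad k_j = W^k x_j
$$

Then, the **attention weight** of element $x_i$ to element $x_j$ is determined by 
$$
w_{ij} = \mathrm{softmax}_j\left(\frac{q_i \cdot k_j}{\sqrt{d_k}}\right)
$$
where the softmax is taken over $j$ and $d_k$ is the dimension of the query and key vectors.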

We then compute the **value** of position $j$,
$$
v_j = W^v x_j
$$
and so the **context** of position $i$, $z_i$, is given by
$$
z_i = \sum_j{w_{ij} v_j}
$$
The context vector $z$ is then used to compute the **self-attention output**,
$$
\textit{sa} = W^o z
$$
The learnable parameters of self-attention are the matrices $W^q$, $W^k$, $W^v$, and $W^o$.

Another element is the *mask*: we want position $i$ to attend only to positions $j \le i$ to preserve _causality_ (just as in an RNN, the output at position $i$ may depend on positions $j \le i$ but not on $j>i$). We implement it using a lower triangular matrix with `np.tril`.

In [4]:
def self_attention(x, W_q, W_k, W_v, W_o):
    Q = x @ W_q # query q
    K = x @ W_k # key k
    V = x @ W_v # value v
    d_k = Q.shape[-1] 
    w_logits = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k) # logits of attention w
    # causal mask
    mask = np.tril(np.ones_like(w_logits)) 
    w_logits = w_logits - 1e10 * (1 - mask)
    w = jax.nn.softmax(w_logits, axis=-1) # attention weights
    z = w @ V # context vector - no parameters here!
    sa = z @ W_o # self-attention output
    return sa
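
To see the causal mask at work, the following sketch perturbs only the last position of a random input and checks that the outputs at earlier positions do not change. The function body repeats the computation above so the block runs on its own; the toy sizes and random weights are assumptions made for illustration.

```python
import jax
import jax.numpy as np  # same alias as in the notebook

def self_attention(x, W_q, W_k, W_v, W_o):
    # same computation as the notebook's self_attention
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    w_logits = Q @ np.swapaxes(K, -1, -2) / np.sqrt(Q.shape[-1])
    mask = np.tril(np.ones_like(w_logits))
    w = jax.nn.softmax(w_logits - 1e10 * (1 - mask), axis=-1)
    return (w @ V) @ W_o

seq_len, d_model = 5, 4  # hypothetical toy sizes
k_x, k_q, k_k, k_v, k_o = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(k_x, (seq_len, d_model))
W_q, W_k, W_v, W_o = (
    jax.random.normal(k, (d_model, d_model)) for k in (k_q, k_k, k_v, k_o)
)

out = self_attention(x, W_q, W_k, W_v, W_o)
x_changed = x.at[-1].set(100.0)  # change only the last position
out_changed = self_attention(x_changed, W_q, W_k, W_v, W_o)

# earlier positions cannot "see" the last one, so their outputs are unchanged
same_prefix = bool(np.allclose(out[:-1], out_changed[:-1], atol=1e-5))
last_differs = not bool(np.allclose(out[-1], out_changed[-1]))
print(same_prefix, last_differs)
```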

## Positional encoding
We also implement positional encoding, which provides the model with information about the order of tokens without learning position embeddings.

Positional encoding here is a fixed, sinusoidal function that assigns each token a unique vector based on its position in the sequence. For a given sequence length and model dimension, it:
- Creates a grid where each row corresponds to a sequence position and each column to a model dimension.
- Computes an angle rate (a frequency term) for each column.
- Multiplies the position index by these rates to get an "angle" for each position-dimension pair, then applies sin to the even-indexed columns and cos to the odd-indexed ones.
- Concatenates the sine and cosine results to form a `(seq_len, d_model)` encoding matrix, with the sines in the first half of the columns and the cosines in the second half.

In [5]:
def positional_encoding(seq_len, d_model):
    pos = np.arange(seq_len)[:, None]
    i = np.arange(d_model)[None, :]
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    angle_rads = pos * angle_rates
    # apply sin to even indices in the array; cos to odd indices
    sines = np.sin(angle_rads[:, 0::2])
    cosines = np.cos(angle_rads[:, 1::2])
    pos_encoding = np.concatenate([sines, cosines], axis=-1)
    return pos_encoding
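
The properties claimed above can be checked directly. This sketch repeats the function so it runs on its own, then verifies the shape, the first row (angle zero), and that no two positions share the same vector. The sizes 25 and 128 match the values used later in the notebook.

```python
import jax.numpy as np  # same alias as in the notebook

def positional_encoding(seq_len, d_model):
    # copy of the notebook's function
    pos = np.arange(seq_len)[:, None]
    i = np.arange(d_model)[None, :]
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    angle_rads = pos * angle_rates
    sines = np.sin(angle_rads[:, 0::2])
    cosines = np.cos(angle_rads[:, 1::2])
    return np.concatenate([sines, cosines], axis=-1)

pe = positional_encoding(25, 128)

# position 0 has angle 0 everywhere: sin = 0 (first half), cos = 1 (second half)
first_row_ok = bool(np.allclose(pe[0, :64], 0.0) and np.allclose(pe[0, 64:], 1.0))

# pairwise distances between positions; ignore the zero diagonal
dists = np.linalg.norm(pe[:, None, :] - pe[None, :, :], axis=-1)
min_offdiag = float((dists + 1e9 * np.eye(25)).min())
print(pe.shape, first_row_ok, min_offdiag > 0)
```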

## Feed-forward

Self-attention works on an embedding of the characters to a Euclidean space, rather than on one-hot encoded vectors, so we start by converting one-hot encoded input to integers.
The embedding allows the self-attention mechanism to operate on continuous representations rather than sparse one-hot vectors. The embedding is implemented using a learned matrix. 
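
The embedding lookup is equivalent to multiplying the one-hot matrix by the embedding matrix, only cheaper. A minimal sketch, with hypothetical toy sizes and a random matrix standing in for the learned embedding:

```python
import jax
import jax.numpy as np  # same alias as in the notebook

vocab_size, d_model = 6, 3  # hypothetical toy sizes
embedding = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, d_model))

ints = np.array([2, 0, 5, 2])
onehot = jax.nn.one_hot(ints, vocab_size)

by_lookup = embedding[onehot.argmax(axis=1), :]  # row lookup, as in the model
by_matmul = onehot @ embedding                   # same result, multiplies by many zeros
print(bool(np.allclose(by_lookup, by_matmul, atol=1e-3)))
```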

We then:
- add the embedding to the positional encoding
- compute self-attention 
- add a residual connection
- pass the result through a shallow feed forward network
- add a residual connection
- give the result to a softmax model to predict the next character

Note that the self-attention layer takes its input after passing it through **layer normalization**, which stabilizes training by normalizing activations across the feature dimension so that they have consistent mean and variance. (In the code below, the feed-forward layer receives its input without this normalization.) Layer normalization is thought to reduce internal covariate shift, facilitates gradient flow, and helps in training deeper networks. It’s particularly useful in transformers to improve convergence and overall performance.

We also return the logits rather than the probabilities (so we don't apply softmax) to improve numerical stability: there is no point in exponentiating when the loss function takes the log again anyway.
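
To illustrate the stability argument, here is a small sketch with made-up logits. Taking the log of the softmax underflows to `-inf` for the unlikely class, while `jax.nn.log_softmax` computes the same quantity directly from the logits and stays finite.

```python
import jax
import jax.numpy as np  # same alias as in the notebook

logits = np.array([0.0, 200.0])  # made-up, deliberately extreme logits

naive = np.log(jax.nn.softmax(logits))  # probability underflows to 0, log gives -inf
stable = jax.nn.log_softmax(logits)     # computed from logits, stays finite
print(naive, stable)
```

If the true character were the first class, the naive loss would be infinite and its gradient useless, whereas the stable version gives a large but finite loss.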

In [12]:
def transformer_model(params, x, pos_enc):
    # convert from one-hot to integers
    x = x.argmax(axis=1)
    # embedding lookup and add positional encoding
    x = params['embedding'][x, :]
    x = x + pos_enc 
    # self-attention
    sa = self_attention(layer_norm(x), params['W_q'], params['W_k'], params['W_v'], params['W_o'])
    # residual connection
    x = x + sa  
    # feed forward network     
    hidden = jax.nn.relu(x @ params['W1'] + params['b1'])
    ff = hidden @ params['W2'] + params['b2']
    # residual connection
    x = x + ff      
    # output prediction - return logits rather than probabilities
    logits = x @ params['W_out'] + params['b_out']
    return logits

def layer_norm(x, eps=1e-6):
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.mean((x - mean) ** 2, axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)
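
A quick check of what `layer_norm` does, as a sketch on random input with an arbitrary mean and scale: after normalization, every row has mean close to 0 and variance close to 1. Note that this version has no learnable scale or shift parameters.

```python
import jax
import jax.numpy as np  # same alias as in the notebook

def layer_norm(x, eps=1e-6):
    # copy of the notebook's function
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.mean((x - mean) ** 2, axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

# hypothetical activations with mean 5 and standard deviation 3
x = 5.0 + 3.0 * jax.random.normal(jax.random.PRNGKey(0), (4, 128))
y = layer_norm(x)

row_means = y.mean(axis=-1)
row_vars = y.var(axis=-1)
print(row_means, row_vars)
```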

We initialize the weight matrices by drawing from a normal distribution scaled by 0.01, and the biases to zero.

In [13]:
# Parameter initialization
def init_params(key, seq_len, vocab_size, d_model, d_ff):
    keys = jax.random.split(key, 8)
    params = {
        'embedding': jax.random.normal(keys[0], (vocab_size, d_model)) * 0.01,
        'W_q': jax.random.normal(keys[1], (d_model, d_model)) * 0.01,
        'W_k': jax.random.normal(keys[2], (d_model, d_model)) * 0.01,
        'W_v': jax.random.normal(keys[3], (d_model, d_model)) * 0.01,
        'W_o': jax.random.normal(keys[4], (d_model, d_model)) * 0.01,
        'W1': jax.random.normal(keys[5], (d_model, d_ff)) * 0.01,
        'b1': np.zeros((d_ff,)),
        'W2': jax.random.normal(keys[6], (d_ff, d_model)) * 0.01,
        'b2': np.zeros((d_model,)),
        'W_out': jax.random.normal(keys[7], (d_model, vocab_size)) * 0.01,
        'b_out': np.zeros((vocab_size,)),
    }
    return params

key = jax.random.key(0)
seq_len = 25        # attend to 25 characters
d_model = 128
d_ff = 512

params = init_params(key, seq_len, vocab_size, d_model, d_ff)
pos_enc = positional_encoding(seq_len, d_model)
x = X[:seq_len, :]
logits = transformer_model(params, x, pos_enc)
assert logits.shape == (seq_len, vocab_size)

## Loss function

The loss function is a straightforward categorical cross-entropy.
The only trick is we use `log_softmax` rather than `softmax` so we don't need to take the `log` again.

In [14]:
def NLL(params, x, y, pos_enc):
    logits = transformer_model(params, x, pos_enc)
    log_probs = jax.nn.log_softmax(logits)
    return -np.sum(y * log_probs)

x = X[:seq_len, :]
y = X[1:seq_len+1, :] 
loss = NLL(params, x, y, pos_enc)
print(loss)

103.21503
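
A useful reference point for this number: a model that assigns equal probability to every character has a loss of $\log(\text{vocab\_size})$ per character. The sketch below computes this baseline for the sequence length and vocabulary size reported above.

```python
import math

seq_len, vocab_size = 25, 62  # values from the notebook

# summed NLL of a uniform prediction over seq_len characters
uniform_loss = seq_len * math.log(vocab_size)

# the untrained model's loss from above, expressed per character
per_char = 103.21503 / seq_len
print(round(uniform_loss, 2), round(per_char, 3))
```

The baseline is about 103.18, so the untrained model, whose small initial weights give nearly uniform predictions, starts very close to it. Dividing a reported loss by `seq_len` gives the per-character loss, which is easier to compare across sequence lengths.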


## Automatic differentiation
Now instead of manually deriving the gradient and implementing it as a Python program, we use JAX's automatic differentiation. [`jax.grad`](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html#jax-first-transformation-grad) takes a function `f(a, b, c)` and returns a function `dfda(a, b, c)` that returns the gradient of `f` with respect to `a` at the values of `a`, `b`, and `c`. It does so by automating the procedure we did manually using the chain rule.

In our case, `f` is `NLL`, `a` is `params`, and the remaining arguments are `x`, `y`, and `pos_enc`; that is, we differentiate `NLL(params, x, y, pos_enc)` to get `backprop(params, x, y, pos_enc)`.

The function [`jax.value_and_grad`](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html#value-and-grad) is used to return both `f(a,b,c)` (the "value") and the `dfda` (the "grad"). 
Since `NLL` returns only the loss, with no auxiliary value such as a hidden state, we do not need the `has_aux` argument here.

In [15]:
backprop = jax.value_and_grad(NLL)

loss, grads = backprop(params, x, y, pos_enc)
for k in params:
    assert params[k].shape == grads[k].shape
    assert not (grads[k] == 0).all()
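
To make the `f(a, b, c)` description concrete, here is a minimal sketch with a toy function whose gradient is easy to work out by hand: for $f(a, b, c) = \sum_i a_i b_i + c$, the gradient with respect to $a$ is simply $b$.

```python
import jax
import jax.numpy as np  # same alias as in the notebook

def f(a, b, c):
    # toy function for illustration only
    return np.sum(a * b) + c

a = np.array([1.0, 2.0, 3.0])
b = np.array([4.0, 5.0, 6.0])
c = 10.0

dfda = jax.grad(f)  # differentiates with respect to the first argument
g = dfda(a, b, c)   # equals b

value, g2 = jax.value_and_grad(f)(a, b, c)  # value and gradient in one call
print(g, value)
```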

# Adam optimizer

We use a JAX implementation of the Adam optimizer from the [Optax](https://optax.readthedocs.io/) library.
We first create the optimizer and initialize its state.

In [16]:
optimizer = optax.adam(learning_rate=0.001) # 0.001 is the default from Kingma et al 2014
opt_state = optimizer.init(params)

We then use the optimizer to compute the updates, and apply them.

In [17]:
loss, grads = backprop(params, x, y, pos_enc)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates) 

# JITing the training step

We write a function that does all this, and pass it to `jax.jit`, which [just-in-time compiles the function](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) so it can be executed efficiently in XLA.

In [18]:
@jax.jit # decreases runtime from 380 ms to <1 ms!
def update_params(params, opt_state, x, y, pos_enc):
    loss, grads = backprop(params, x, y, pos_enc)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

In [19]:
%timeit update_params(params, opt_state, x, y, pos_enc)

730 μs ± 1.11 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
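
A jitted function returns the same values as the original; only the execution changes. The sketch below checks this on a toy gradient step (the function `step` is hypothetical). It also calls `block_until_ready`, because JAX dispatches work asynchronously, so timings are more reliable when the measurement waits for the result.

```python
import jax
import jax.numpy as np  # same alias as in the notebook

def step(w, x):
    # toy gradient step on the loss sum((x @ w)**2), for illustration only
    grad = jax.grad(lambda w_: np.sum((x @ w_) ** 2))(w)
    return w - 0.1 * grad

step_jit = jax.jit(step)

w = np.ones(3)
x = np.arange(6.0).reshape(2, 3)

out_plain = step(w, x)
out_jit = step_jit(w, x)     # the first call traces and compiles
out_jit.block_until_ready()  # wait for the asynchronous computation to finish
print(out_plain, out_jit)
```

The first call to a jitted function includes compilation time, so it is best excluded when timing.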


In [20]:
params, opt_state, loss = update_params(params, opt_state, x, y, pos_enc)
print(loss)
params, opt_state, loss = update_params(params, opt_state, x, y, pos_enc)
print(loss)

100.29582
97.1355


# Sampling from the network

Finally, instead of a `predict` function, we have a `sample` function, which, given the parameters and the number of samples we want, produces a sample of text from the network.

It does so by drawing a random seed character for $x_0$ and then drawing $x_t$ for $t>0$ from the distribution given by $\widehat y_t$, which is the softmax of the transformer output.

Note that this function is computationally heavy as it iterates over each character position.

In [26]:
def sample(params, num_samples, key):    
    x = np.zeros((num_samples, vocab_size), dtype=float)
    pos_enc = positional_encoding(num_samples, d_model)
    keys = jax.random.split(key, num_samples)
    seed_char = jax.random.choice(keys[0], vocab_size)
    x = x.at[0, seed_char].set(1)
    for t in range(1, num_samples):
        logits = transformer_model(params, x[:t], pos_enc[:t])[t-1]
        yhat = jax.nn.softmax(logits)
        # draw from output distribution        
        i = jax.random.choice(keys[t], vocab_size, p=yhat)
        x = x.at[t, i].set(1)
    return onehot_decode(x)

print(sample(params, 100, jax.random.key(1)))

h 
   iu er nnttfan h h    e niLtei pQerfe
uyvq neqmr xeasfan f fnd ahfempe nlocmsYitghPib   -s?ooeH
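
As an aside, drawing from the softmax of the logits can also be done without computing the probabilities, using `jax.random.categorical`, which samples directly from logits. This is an alternative to the `softmax` followed by `jax.random.choice` used in `sample`, not what the notebook does. The sketch below uses made-up logits and checks that the empirical frequencies approach the softmax probabilities.

```python
import jax
import jax.numpy as np  # same alias as in the notebook

logits = np.array([2.0, 0.0, -1.0])  # made-up logits
probs = jax.nn.softmax(logits)

key = jax.random.PRNGKey(0)
draws = jax.random.categorical(key, logits, shape=(5000,))  # samples indices 0..2

# empirical frequency of each index
freqs = np.bincount(draws, length=3) / draws.size
print(probs, freqs)
```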


# Training the network

We set up the training.

In [30]:
max_batches = 10000000
pos = 0
batch = 0 
losses = []
key = jax.random.key(86)

seq_len = 25 # attend to 25 characters
d_model = 128
d_ff = 512

params = init_params(key, seq_len, vocab_size, d_model, d_ff)
pos_enc = positional_encoding(seq_len, d_model)

optimizer = optax.adam(learning_rate=0.001) # you can try with 0.01
opt_state = optimizer.init(params)

Now we can train the transformer model.

In [31]:
%%time
while batch <= max_batches:
    if pos + seq_len + 1 >= data_size:
        # reset data position (there is no hidden state to reset in a transformer)
        pos = 0
        
    x = X[pos : pos + seq_len]
    y = X[pos + 1 : pos + seq_len + 1]
    pos += seq_len
    
    params, opt_state, loss = update_params(params, opt_state, x, y, pos_enc)
    losses.append(loss)
    
    if batch % (max_batches // 10) == 0:
        print('batch {:d}, loss {:.6f}, pos {}'.format(batch, loss, pos))
        print()
        
        key, subkey = jax.random.split(key)        
        sample_text = sample(params, 50, subkey)
        print(sample_text)
        print('-'*80)
    batch += 1

batch 0, loss 102.919846, pos 25

lfbGIPBjXMkmYXJbLZZexlTj-KFdsI; xu'D,H,R;GAp,Nd,wg
--------------------------------------------------------------------------------
batch 1000000, loss 52.248795, pos 6275

n I mot, begh swommu;
And kid usoweranond merustow
--------------------------------------------------------------------------------
batch 2000000, loss 54.052490, pos 12525

k menses speed, paremet he rutheasoukiswevoukerali
--------------------------------------------------------------------------------
batch 3000000, loss 44.485901, pos 18775

lood wilt:
I weept, ing inde ba gh, in is wr tleso
--------------------------------------------------------------------------------
batch 4000000, loss 44.551559, pos 25025

QUEETH:
Madd,
Thavencearror helay owin.
T,
Th rrif
--------------------------------------------------------------------------------
batch 5000000, loss 50.127365, pos 31275

His thip
Welf did
Rome my t meaknoat, mystrourasee
----------------------------------------------

# References

- [Vaswani et al. 2017](http://arxiv.org/abs/1706.03762): _Attention Is All You Need_, the fundamental paper on transformers.


# Colophon
This notebook was written by [Yoav Ram](http://python.yoavram.com).

This work is licensed under a [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/) International License.

![Python logo](https://www.python.org/static/community_logos/python-logo.png)