In [None]:
!pip install equinox datasets sentencepiece

In [2]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import functools as ft
from jax import lax
import numpy as np

In [3]:
from datasets import load_dataset
import sentencepiece as spm

In [4]:
# Fetch the training split of the tiny-shakespeare corpus from the Hugging Face Hub.
dataset = load_dataset(
    path="Trelis/tiny-shakespeare",
    split="train",
)

In [5]:
import re

# Two or more consecutive newlines mark a paragraph boundary.
SPLIT_PATTERN = re.compile("[\\n]{2,}")


def preprocess(seq):
  """Split a row's ``"Text"`` field into a list of non-empty paragraphs.

  The row is updated in place and returned, which is the contract
  ``Dataset.map`` expects.

  * Paragraphs that are empty or whitespace-only are dropped. ``re.split``
    produces them whenever the text starts or ends with blank lines, and
    they would later tokenise to a bare ``[bos, eos]`` pair.
  * If ``"Text"`` is already a list (the cell was re-run on mapped data) it is
    only filtered, not split again, so the function is idempotent.
  """
  text = seq["Text"]
  parts = SPLIT_PATTERN.split(text) if isinstance(text, str) else list(text)
  seq["Text"] = [part for part in parts if part.strip()]
  return seq

In [6]:
# Replace every row's raw text with its list of paragraphs.
dataset = dataset.map(function=preprocess)

In [7]:
# Size of the BPE vocabulary learned by SentencePiece.
VOCAB = 1000
# Fixed token length every sequence is padded / truncated to.
SEQ_LEN = 56

In [8]:
# Train a BPE tokenizer on every paragraph of the corpus. The model is written
# to "<model_prefix>.model" / "<model_prefix>.vocab" in the working directory.
_sentences = (seq for row in dataset["Text"] for seq in row)
spm.SentencePieceTrainer.train(
    sentence_iterator=_sentences,
    model_prefix="tiny-shakespeare",
    model_type="bpe",
    vocab_size=VOCAB,
    # Reserved ids for the special tokens.
    pad_id=0,
    unk_id=1,
    bos_id=2,
    eos_id=3,
)

In [9]:
# Load the tokenizer that the training cell wrote next to the notebook.
# The path is relative (prefix "tiny-shakespeare" + ".model") so it matches the
# trainer's `model_prefix` wherever the notebook runs, not only under Colab's
# /content working directory.
tokenizer = spm.SentencePieceProcessor(model_file="tiny-shakespeare.model")

In [10]:
# Width of the token embeddings; the model's hidden width is tied to it.
EMB_SIZE = 96
MODEL_DIM = EMB_SIZE

In [11]:
def sinusoidal_positional_encodings(seq_len, dim, base=10000.0):
    """Return the ``(seq_len, dim)`` sinusoidal position table.

    Entry ``[pos, j]`` is ``sin(pos / base ** (2 * (j // 2) / dim))`` for even
    ``j`` and the cosine of the same angle for odd ``j``, so each sin/cos pair
    shares one frequency.
    """
    pos = jnp.arange(seq_len)[:, None]   # (seq_len, 1) token positions
    idx = jnp.arange(dim)[None, :]       # (1, dim) feature indices
    angles = pos / (base ** ((2 * (idx // 2)) / dim))
    return jnp.where(idx % 2 == 0, jnp.sin(angles), jnp.cos(angles))


# The previous version divided each position by a power of *itself* (the
# position was used as the feature index), used base 1000, and produced a 1-D
# vector of length SEQ_LEN that cannot be added to (SEQ_LEN, MODEL_DIM)
# embeddings.
positions = jnp.arange(0, SEQ_LEN)
positional_encodings = sinusoidal_positional_encodings(SEQ_LEN, MODEL_DIM)

In [12]:
class Linear(eqx.Module):
    """Affine layer computing ``x @ weights + bias``.

    ``weights`` has shape ``(nin, nout)`` and is drawn with He-uniform
    initialisation; ``bias`` has shape ``(nout,)``.
    """

    weights: jax.Array
    bias: jax.Array

    def __init__(self, key, nin, nout):
        """Create the layer from PRNG ``key`` with ``nin`` inputs and ``nout`` outputs."""
        he_uniform = jax.nn.initializers.he_uniform()
        self.weights = he_uniform(key=key, shape=(nin, nout))
        # NOTE(review): the bias starts at one rather than zero — confirm this is intended.
        self.bias = jnp.ones(nout)

    @eqx.filter_jit
    def __call__(self, x):
        """Apply the affine map to ``x`` (last axis of size ``nin``)."""
        projected = x @ self.weights
        return projected + self.bias

In [13]:
class FFNN(eqx.Module):
    """Feed-forward network: ``n_layers`` ``Linear`` layers with GELU in between.

    Layer widths are ``nin -> nhidden -> ... -> nhidden -> nout``. With
    ``n_layers == 1`` the network is a single ``Linear(nin, nout)``.

    Only the ``Linear`` layers are stored. The activation is applied in
    ``__call__`` so that every leaf of the module is an array; a stored
    ``jax.nn.gelu`` function leaf cannot be passed through
    ``jax.value_and_grad`` or an optax optimiser.
    """

    layers: list

    def __init__(self, key, nin, nout, nhidden, n_layers=2):
        """Build the stack.

        Args:
            key: PRNG key, split into one key per layer.
            nin: input width.
            nout: output width.
            nhidden: width of every hidden layer.
            n_layers: number of ``Linear`` layers (at least 1).

        Raises:
            ValueError: if ``n_layers`` is smaller than 1.
        """
        if n_layers < 1:
            raise ValueError("FFNN needs at least one layer")
        keys = jax.random.split(key, num=n_layers)
        widths = [nin] + [nhidden] * (n_layers - 1) + [nout]
        self.layers = [
            Linear(layer_key, fan_in, fan_out)
            for layer_key, fan_in, fan_out in zip(keys, widths[:-1], widths[1:])
        ]

    @eqx.filter_jit
    def __call__(self, x):
        """Run ``x`` through the stack; GELU follows every layer except the last."""
        *hidden_layers, output_layer = self.layers
        for layer in hidden_layers:
            # Without this non-linearity the whole stack collapses to one affine map.
            x = jax.nn.gelu(layer(x))
        return output_layer(x)

In [14]:
class SelfAttention(eqx.Module):
    """Scaled dot-product attention with an additive mask."""

    @eqx.filter_jit
    def __call__(self, query, key, value, mask):
        """Return ``softmax(query @ key^T / sqrt(d) + mask) @ value``.

        ``key`` is transposed on its last two of three axes, so inputs are
        rank-3. ``mask`` is added to the scores before the softmax, which is
        taken over the last (key) axis.
        """
        scale = jnp.sqrt(query.shape[-1])
        key_t = jnp.transpose(key, (0, 2, 1))
        scores = (query @ key_t) / scale
        masked_scores = mask + scores
        attention = jax.nn.softmax(masked_scores, axis=-1)
        return attention @ value

In [15]:
class MultiHeadAttention(eqx.Module):
    """Multi-head attention built on ``SelfAttention``.

    Inputs are projected by ``wquery`` / ``wkey`` / ``wvalue`` (each
    ``(dim, dim)``), split into ``n_heads`` heads of width ``dim_k``, attended
    per head, merged back and projected by ``weights``.
    """

    wquery: jax.Array
    wkey: jax.Array
    wvalue: jax.Array
    weights: jax.Array
    attn: eqx.Module
    n_heads: int = eqx.field(static=True)
    dim_k: int = eqx.field(static=True)

    def __init__(self, key, n_heads, dim):
        """Initialise all projections with He-uniform weights.

        Raises:
            ValueError: if ``dim`` is not divisible by ``n_heads``.
        """
        if (dim % n_heads) != 0:
            raise ValueError("Model dimensions must be a multiple of no. of heads")
        head_width = dim // n_heads
        he_uniform = jax.nn.initializers.he_uniform()
        out_key, q_key, k_key, v_key = jax.random.split(key, num=4)
        self.weights = he_uniform(key=out_key, shape=(n_heads * head_width, dim))
        self.wquery = he_uniform(key=q_key, shape=(dim, dim))
        self.wkey = he_uniform(key=k_key, shape=(dim, dim))
        self.wvalue = he_uniform(key=v_key, shape=(dim, dim))
        self.attn = SelfAttention()
        self.n_heads = n_heads
        self.dim_k = head_width

    @eqx.filter_jit
    def __call__(self, query, key, value, mask):
        """Attend ``query`` over ``key`` / ``value`` under the additive ``mask``."""

        def split_heads(projected):
            # (rows, dim) -> (rows, n_heads, dim_k) -> (n_heads, rows, dim_k)
            per_head = jnp.reshape(projected, (-1, self.n_heads, self.dim_k))
            return jnp.transpose(per_head, (1, 0, 2))

        q = split_heads(query @ self.wquery)
        k = split_heads(key @ self.wkey)
        v = split_heads(value @ self.wvalue)

        # A leading axis lets the mask broadcast across the heads.
        head_mask = jnp.expand_dims(mask, axis=0)
        attended = self.attn(q, k, v, head_mask)

        # (n_heads, rows, dim_k) -> (rows, n_heads * dim_k)
        merged = jnp.reshape(
            jnp.transpose(attended, (1, 0, 2)),
            (-1, self.n_heads * self.dim_k),
        )
        return merged @ self.weights

In [16]:
class LayerNorm(eqx.Module):
    """Layer normalisation over the last axis with a learned scale and shift.

    Computes ``gamma * (x - mean) / (std + eps) + bias``.
    """

    gamma: jax.Array
    bias: jax.Array
    # eps is a float; it was previously annotated as int.
    eps: float = eqx.field(static=True)

    def __init__(self, size, eps=1e-6):
        """Create the layer.

        Args:
            size: length of the normalised (last) axis.
            eps: constant added to the standard deviation to avoid division
                by zero. The argument is now honoured; it used to be ignored
                in favour of a hard-coded ``1e-6``.
        """
        self.gamma = jnp.ones(size)
        # The shift starts at zero so a fresh layer outputs zero-mean
        # activations; starting it at one offset every normalised feature by +1.
        self.bias = jnp.zeros(size)
        self.eps = eps

    @eqx.filter_jit
    def __call__(self, x):
        """Normalise ``x`` along its last axis, then scale and shift."""
        mean = jnp.mean(x, -1, keepdims=True)
        std = jnp.std(x, -1, keepdims=True)
        return (self.gamma * (x - mean) / (std + self.eps)) + self.bias

In [17]:
class Encoder(eqx.Module):
    """Stack of ``n_layers`` self-attention + feed-forward blocks over token embeddings.

    Each sub-layer is wrapped as ``norm(sublayer(x) + x)``.
    """

    emb: jax.Array
    attn_layers: list
    ff_layers: list
    attn_norms: list
    ff_norms: list
    n_layers: int = eqx.field(static=True)

    def __init__(self, key, n_layers, n_heads, dim):
        """Build the encoder.

        Args:
            key: PRNG key; split into one key for the embedding and one per
                attention / feed-forward layer.
            n_layers: number of encoder blocks.
            n_heads: attention heads per block.
            dim: model width, which is also the embedding width.
        """
        keys = jax.random.split(key, num=n_layers * 2 + 1)
        emb_key, attn_keys, ff_keys = keys[0], keys[1:n_layers + 1], keys[n_layers + 1:]
        # The embedding width must equal `dim`, because the embeddings are fed
        # straight into the (dim, dim) attention projections. It used to read
        # the global EMB_SIZE, which only worked when the two happened to match.
        self.emb = jax.random.normal(emb_key, (VOCAB, dim))
        self.attn_layers = [MultiHeadAttention(k, n_heads, dim) for k in attn_keys]
        self.ff_layers = [FFNN(k, dim, dim, dim * 2) for k in ff_keys]
        self.attn_norms = [LayerNorm(dim) for _ in range(n_layers)]
        self.ff_norms = [LayerNorm(dim) for _ in range(n_layers)]
        self.n_layers = n_layers

    @eqx.filter_jit
    def __call__(self, x, mask):
        """Encode token ids ``x``; ``mask`` is the additive attention mask.

        NOTE(review): no positional information is added to the embeddings
        here — confirm whether that is intended.
        """
        x = self.emb[x]
        for i in range(self.n_layers):
            x = self.attn_norms[i](self.attn_layers[i](x, x, x, mask) + x)
            x = self.ff_norms[i](self.ff_layers[i](x) + x)
        return x

In [18]:
class Decoder(eqx.Module):
    """Stack of ``n_layers`` decoder blocks.

    Each block is causal self-attention, then cross-attention over the encoder
    output, then a feed-forward network. Every sub-layer is wrapped as
    ``norm(sublayer(x) + x)``.
    """

    emb: jax.Array
    # Causal mask: 0 on and below the diagonal, -inf above it. It is an
    # ordinary array leaf rather than a static field, because static fields
    # must be hashable and JAX arrays are not. It is kept constant during
    # training by the stop_gradient in __call__.
    mask: jax.Array
    masked_attn_layers: list
    attn_layers: list
    ff_layers: list
    masked_attn_norms: list
    attn_norms: list
    ff_norms: list
    n_layers: int = eqx.field(static=True)

    def __init__(self, key, n_layers, n_heads, dim):
        """Build the decoder.

        Args:
            key: PRNG key; split into one key for the embedding and one per
                attention / feed-forward / masked-attention layer.
            n_layers: number of decoder blocks.
            n_heads: attention heads per attention layer.
            dim: model width, which is also the embedding width.
        """
        keys = jax.random.split(key, num=n_layers * 3 + 1)
        emb_key = keys[0]
        attn_keys = keys[1:n_layers + 1]
        ff_keys = keys[n_layers + 1:n_layers * 2 + 1]
        masked_attn_keys = keys[n_layers * 2 + 1:]
        # The embedding width must equal `dim` (previously the global EMB_SIZE).
        self.emb = jax.random.normal(emb_key, (VOCAB, dim))
        # -jnp.inf replaces np.NINF, which was removed in NumPy 2.0.
        self.mask = jnp.where(jnp.triu(jnp.ones((SEQ_LEN, SEQ_LEN)), 1) == 1, -jnp.inf, 0)
        self.masked_attn_layers = [MultiHeadAttention(k, n_heads, dim) for k in masked_attn_keys]
        self.attn_layers = [MultiHeadAttention(k, n_heads, dim) for k in attn_keys]
        self.ff_layers = [FFNN(k, dim, dim, dim * 2) for k in ff_keys]
        self.attn_norms = [LayerNorm(dim) for _ in range(n_layers)]
        self.ff_norms = [LayerNorm(dim) for _ in range(n_layers)]
        self.masked_attn_norms = [LayerNorm(dim) for _ in range(n_layers)]
        self.n_layers = n_layers

    @eqx.filter_jit
    def __call__(self, x, m, mask):
        """Decode token ids ``x`` against encoder output ``m``.

        ``mask`` is the additive mask for cross-attention; it is added to
        scores whose last axis indexes the rows of ``m``.

        NOTE(review): no positional information is added to the embeddings
        here — confirm whether that is intended.
        """
        # The gradient w.r.t. the mask is forced to zero so the optimiser
        # never changes it.
        causal_mask = jax.lax.stop_gradient(self.mask)
        x = self.emb[x]
        for i in range(self.n_layers):
            x = self.masked_attn_norms[i](self.masked_attn_layers[i](x, x, x, causal_mask) + x)
            x = self.attn_norms[i](self.attn_layers[i](x, m, m, mask) + x)
            x = self.ff_norms[i](self.ff_layers[i](x) + x)
        return x

In [19]:
class EncoderDecoder(eqx.Module):
    """Encoder and decoder wired together."""

    encoder: eqx.Module
    decoder: eqx.Module

    def __init__(self, key, enc_heads, enc_layers, dec_heads, dec_layers, dim):
        """Build both halves from independent sub-keys of ``key``."""
        enc_key, dec_key = jax.random.split(key, num=2)
        self.encoder = Encoder(enc_key, enc_layers, enc_heads, dim)
        self.decoder = Decoder(dec_key, dec_layers, dec_heads, dim)

    @eqx.filter_jit
    def __call__(self, X, y, X_mask, y_mask):
        """Encode source ids ``X`` and decode target ids ``y`` against the result.

        Args:
            X: source token ids.
            y: target token ids fed to the decoder.
            X_mask: additive padding mask for the source; used by the encoder
                and by the decoder's cross-attention.
            y_mask: accepted for interface compatibility but not used. The
                decoder's self-attention applies its own causal mask.

        Returns:
            The decoder's hidden states.
        """
        m = self.encoder(X, X_mask)
        # Cross-attention keys/values are the encoder states, so the mask over
        # them must be the *source* padding mask. Passing y_mask (as before)
        # masked encoder positions according to where the target was padded.
        h = self.decoder(y, m, X_mask)
        return h

In [20]:
class Transformer(eqx.Module):
    """Encoder-decoder followed by a linear projection and a softmax."""

    enc_dec: eqx.Module
    linear: eqx.Module

    def __init__(self, key, dim, enc_heads, enc_layers, dec_heads, dec_layers, out_vocab):
        """Build the encoder-decoder and the ``dim -> out_vocab`` output layer."""
        encdec_key, linear_key = jax.random.split(key)
        self.enc_dec = EncoderDecoder(
            encdec_key,
            enc_heads,
            enc_layers,
            dec_heads,
            dec_layers,
            dim,
        )
        self.linear = Linear(linear_key, dim, out_vocab)

    @eqx.filter_jit
    def __call__(self, X, y, X_mask, y_mask):
        """Return the softmax (over the last axis) of the projected decoder states."""
        hidden = self.enc_dec(X, y, X_mask, y_mask)
        logits = self.linear(hidden)
        return jax.nn.softmax(logits)

In [21]:
# Encoder and decoder share the same depth and head count.
N_ENCODER_HEADS = 8
N_DECODER_HEADS = N_ENCODER_HEADS
N_ENCODER_LAYERS = 3
N_DECODER_LAYERS = N_ENCODER_LAYERS
# Source and target share the single tokenizer vocabulary.
OUTPUT_VOCAB = VOCAB
INPUT_VOCAB = OUTPUT_VOCAB

In [22]:
# Fixed seed so the parameter initialisation is reproducible.
key = jax.random.PRNGKey(0)
model = Transformer(
    key,
    MODEL_DIM,
    N_ENCODER_HEADS,
    N_ENCODER_LAYERS,
    N_DECODER_HEADS,
    N_DECODER_LAYERS,
    OUTPUT_VOCAB,
)

In [23]:
# Learning rate: starts at 0.05, begins decaying after 50 steps with a decay
# rate of 0.5 per 400 transition steps, bounded by end_value=1e-4.
scheduler = optax.exponential_decay(
    init_value=0.05,
    transition_steps=400,
    decay_rate=0.5,
    transition_begin=50,
    end_value=0.0001,
)
optimizer = optax.adam(learning_rate=scheduler)

In [24]:
def predict(model, X, y, X_mask, y_mask):
    """Call ``model`` on the given inputs and masks and return its output."""
    output = model(X, y, X_mask, y_mask)
    return output

In [25]:
def loss(model, X, y, X_mask, y_mask):
    """Mean next-token negative log-likelihood over the non-padding targets.

    ``predict`` is expected to return probabilities with the vocabulary on the
    last axis and the sequence position on the axis before it; ``y`` holds
    integer token ids along its last axis.

    The prediction at position ``t`` is scored against ``y[t + 1]``. The
    prediction for the final position has no target and is dropped. Targets
    equal to 0 (the padding id) are excluded from the average.

    The previous version indexed the log-probabilities with ``argmax(y)``.
    Because ``y`` holds ids rather than one-hot rows, that picked a single
    position and averaged its entire log-distribution, which is not a
    cross-entropy.

    Returns:
        Scalar loss; 0 when there is no non-padding target.
    """
    probs = predict(model, X, y, X_mask, y_mask)
    # Floor the probabilities so an underflowed 0 gives a large finite loss
    # instead of -inf.
    log_probs = jnp.log(jnp.maximum(probs, jnp.finfo(probs.dtype).tiny))

    targets = y[..., 1:]
    step_log_probs = log_probs[..., :-1, :]
    # Pick, for every position, the log-probability assigned to its target id.
    target_log_probs = jnp.take_along_axis(
        step_log_probs, targets[..., None], axis=-1
    )[..., 0]

    is_token = targets != 0
    n_tokens = jnp.maximum(jnp.sum(is_token), 1)
    return -jnp.sum(jnp.where(is_token, target_log_probs, 0.0)) / n_tokens

In [26]:
def optim(model, optimizer, loss_fn):
    """Prepare optimiser state and a training-step function.

    Args:
        model: the parameters to optimise.
        optimizer: object providing ``init`` and ``update``.
        loss_fn: ``loss_fn(model, X, y, X_mask, y_mask)`` for one example.

    Returns:
        ``(opt_state, step)`` where ``step(model, opt_state, X, y, X_mask,
        y_mask)`` returns ``(new_model, new_opt_state, mean_loss)``.
    """
    # Loss and gradient per example: the model is shared, and every other
    # argument is mapped over its leading axis.
    per_example_loss_and_grad = jax.vmap(
        jax.value_and_grad(loss_fn),
        in_axes=(None, 0, 0, 0, 0),
        out_axes=0,
    )

    def step(model, opt_state, X, y, X_mask, y_mask):
        losses, per_example_grads = per_example_loss_and_grad(model, X, y, X_mask, y_mask)
        # Average the per-example gradients into one batch gradient.
        batch_grads = jax.tree_util.tree_map(
            lambda g: jnp.mean(g, axis=0), per_example_grads
        )
        updates, new_opt_state = optimizer.update(batch_grads, opt_state, model)
        new_model = optax.apply_updates(model, updates)
        return new_model, new_opt_state, jnp.mean(losses)

    return optimizer.init(model), step

In [27]:
opt_state, step = optim(model, optimizer, loss)

In [28]:
from itertools import islice, chain, repeat

def pad_or_truncate(seq, size=SEQ_LEN, pad=0):
    """Return the first ``size`` items of ``seq`` as a list, padded with ``pad`` if shorter."""
    # An endless tail of `pad` values guarantees islice can always take `size` items.
    padded = chain(seq, repeat(pad))
    return list(islice(padded, size))

In [29]:
# Number of sequences per optimisation step.
MINI_BATCH_SIZE = 32
# Number of passes over the training data.
EPOCHS = 10

In [30]:
def create_mask(arr: jax.Array):
    """Build an additive attention mask from token ids.

    Returns an array shaped like ``arr`` holding ``-inf`` where ``arr`` is 0
    (the padding id) and 0 elsewhere.
    """
    # -jnp.inf replaces np.NINF, which was removed in NumPy 2.0.
    return jnp.where(arr == 0, -jnp.inf, 0)

In [None]:
# Tokenise and pad the corpus once; this work is identical in every epoch.
# Sequences with nothing between <bos> and <eos> are skipped: their source
# would be all padding, and a fully masked attention row makes the softmax NaN.
encoded = [
    ids
    for row in dataset["Text"]
    for ids in tokenizer.encode(row, add_bos=True, add_eos=True)
    if len(ids) > 2
]
if not encoded:
    raise ValueError("No non-empty training sequences were found in the dataset")

# Source = tokens without <bos>/<eos>; target = the full tokenised sequence.
X_all = np.array([pad_or_truncate(ids[1:-1]) for ids in encoded], dtype=np.int32)
y_all = np.array([pad_or_truncate(ids) for ids in encoded], dtype=np.int32)

for e in range(EPOCHS):
    total_loss = 0
    num_batches = 0

    # Fixed-size mini-batches (the last one may be smaller) instead of one
    # variable-size batch per dataset row.
    for start in range(0, len(X_all), MINI_BATCH_SIZE):
        X = jnp.asarray(X_all[start:start + MINI_BATCH_SIZE])
        y = jnp.asarray(y_all[start:start + MINI_BATCH_SIZE])
        X_mask, y_mask = [create_mask(x) for x in (X, y)]

        model, opt_state, batch_loss = step(model, opt_state, X, y, X_mask, y_mask)
        total_loss += batch_loss
        num_batches += 1

        if num_batches % 20 == 0:
            print(f"Batch: {num_batches} | Batch loss: {batch_loss}")

    epoch_loss = total_loss / num_batches
    print(f"Epoch {e} | loss: {epoch_loss}")