In [None]:
import jax
import optax
import unicodedata
import numpy as np
import equinox as eqx
import functools as ft
import jax.numpy as jnp
import sentencepiece as spm
from jax import lax, config
from itertools import islice, chain, repeat

In [2]:
config.update("jax_debug_nans", True)

In [3]:
def load_data(path: str):
    """
    Load data from the given path of a file in utf-8 format.

    Parameters
    ----------
    path: str
        Path of the file to load.
    """
    with open(path, encoding="utf-8") as f:
        return f.readlines()

In [4]:
def train_spm(iterable, prefix, vocab=2_000, sentence_size=50_000, model_type="bpe"):
    """
    Train a sentence-piece tokenizer.
    The pad, unk, bos, eos tokens corresponds to 0, 1, 3 and 4 id respectively.

    Parameters
    ----------
    iterable: Sequence[str]
        A list of sentences to train the toknizer.
    prefix: str
        A prefix for .model, .vocab files.
    vocab: int, optional
        Size of the vocabulary (default is 2_000)
    sentence_size: int, optional
        Size of sentences to sample for training. (default is 50_000)
    model_type: str, optional
        Type of model. Either "bpe" or "unigram". (default is "bpe")
    """
    spm.SentencePieceTrainer.train(sentence_iterator=iter(iterable), model_prefix=prefix, vocab_size=vocab,
                                   model_type="bpe", normalization_rule_name="identity", 
                                   input_sentence_size=sentence_size, shuffle_input_sentence=True, 
                                   pad_id=0, unk_id=1, bos_id=2, eos_id=3)

In [5]:
def load_spm(path):
    """
    Load a pretrained tokenizer

    Parameters
    ----------
    path: str
        Path to the model file

    Returns
    -------
    SentencePieceProcessor
        A trained sentence-piece tokenizer.
    """
    local_spm = spm.SentencePieceProcessor(model_file=path)
    return local_spm

In [6]:
def normalize(arr, form):
    """
    Performs unicode noramlization on sentences.

    Paramters
    ---------
    arr: str
        A list of sentences
    form: str
        Normalization form to use. NFD/NFC/NFKD/NFKC. (default is "NFC")

    Returns
    -------
    list
        A list of normalized sentences.
    """
    return [unicodedata.normalize(form, sentence.strip("\n")) for sentence in arr]

In [7]:
def create_mask(arr):
    """
    Creates a mask array in the same shape as the given a padded array.

    Parameters
    ----------
    arr: jax.Array
        Padded Array

    Returns
    -------
    jax.Array
        A mask array.
    """
    return jnp.where(arr == 0, np.NINF, 0)

In [8]:
def data_split(arr, train_size=0.8, val_size=0.1):
    """
    Splits a given list into train, validation and test sets. The sum of train &
    validation sizes should not exceed 1. If the sum if < 1, test set has the remaining
    elements of the array.

    Parameters
    ----------
    arr: list
        A list of sentences.
    train_size: float, optional
        Percentage of train set. (dafault is 0.8)
    val_size: float, optional
        Percentage of validation set (default is 0.1)

    Returns
    -------
    (list, list, list)
        A tuple of train, validation and test set.
    """
    LEN = len(arr)
    val_size = train_size + val_size
    return arr[:int(LEN*train_size)], arr[int(LEN*train_size):int(LEN*val_size)], arr[int(LEN*val_size):]

In [9]:
def pad_or_truncate(seq, size=32, pad=0):
    """
    Pads and truncates a list tokens a specific length.

    Parameters:
    seq: list
        A list of tokens.
    size: int, optional
        Max size of the sequence. (default is 32)
    pad: int, optional
        The ID to use for padding. (defualt is 0)
    
    Returns
    -------
    list
        Truncated or padded list of tokens.
    """
    return list(islice(chain(seq, repeat(pad)), size))

In [10]:
def dataloader(X, y, seq_len, batch_size=32):
    """
    A dataloader for sequence to sequence tasks.
    Loads data in batches.
    
    Parameters
    ----------
    X: list
        Sequence of inputs.
    y: list
        Sequence of output.
    seq_len: int
        Max length of the sequence.
    batch_size: int, optional
        Size of a batch. (default is 32)
    
    Returns
    -------
    generator[tuple]
        Batches of X, y and labels
    """
    if len(X) != len(y):
        raise ValueError("Length of X and y are not equal")
    n = batch_size
    X, y = iter(X), iter(y)
    while batch := tuple(islice(zip(X, y), n)):
        X_batched = [x[0] for x in batch]
        y_batched = [x[1] for x in batch]
        
        X_batched = [pad_or_truncate(seq, seq_len) for seq in X_batched]
        y_batched = [pad_or_truncate(seq, seq_len+1) for seq in y_batched]
        ybt = [x[:-1] for x in y_batched]
        labels = [x[1:] for x in y_batched]
        
        yield X_batched, ybt, labels

In [11]:
en, fr = load_data("./en_fr/en"), load_data("./en_fr/fr")

In [12]:
en, fr = normalize(en, "NFC"), normalize(fr, "NFC")

In [13]:
Xtr, Xdev, Xte = data_split(en, train_size=0.7, val_size=0.15)
ytr, ydev, yte = data_split(fr, train_size=0.7, val_size=0.15)

In [14]:
VOCAB = 5_000

In [15]:
# train_spm(Xtr, "./tokenizer/en", VOCAB, sentence_size=1_000_000)
# train_spm(ytr, "./tokenizer/fr", VOCAB, sentence_size=1_000_000)

In [16]:
en_spm, fr_spm = load_spm("./tokenizer/en.model"), load_spm("./tokenizer/fr.model") 

In [17]:
Xtr, Xdev = [en_spm.tokenize(x, out_type=int) for x in (Xtr, Xdev)]
ytr, ydev = [fr_spm.tokenize(x, out_type=int, add_bos=True, add_eos=True) for x in (ytr, ydev)]

In [18]:
del en
del fr

In [19]:
MODEL_DIM = EMB_SIZE = 256
SEQ_LEN = 64

In [96]:
class Linear(eqx.Module):
    weights: jax.Array
    bias: jax.Array
    use_bias: bool = eqx.field(static=True)

    def __init__(self, key, nin, nout, use_bias=False):
        init = jax.nn.initializers.he_uniform()
        self.weights = init(key=key, shape=(nin, nout))
        if use_bias:
            self.bias = jnp.ones(nout)
        else: 
            self.bias = None
        self.use_bias = use_bias
        
    @eqx.filter_jit
    def __call__(self, x):
        x = x @ self.weights
        if self.use_bias:
            x = x + self.bias
        return x

In [97]:
class FFNN(eqx.Module):
    layers: list
    def __init__(self, key, nin, nout, nhidden, n_layers=2, use_bias=False):
        keys = jax.random.split(key, num=n_layers)
        layers = [
            Linear(keys[0], nin, nhidden, use_bias)
        ]
        for i in range(1, n_layers-1):
            layers.append(jax.nn.gelu)
            layers.append(Linear(keys[i], nhidden, nhidden, use_bias))
        if n_layers != 1:
            layers.append(Linear(keys[-1], nhidden, nout, use_bias))
        self.layers = layers

    @eqx.filter_jit
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

In [98]:
class SelfAttention(eqx.Module):
    @eqx.filter_jit
    def __call__(self, query, key, value, mask):
        scaled_dot_prod = query @ jnp.transpose(key, (0, 2, 1)) / jnp.sqrt(query.shape[-1])
        scaled_dot_prod = mask + scaled_dot_prod
        return (jax.nn.softmax(scaled_dot_prod) @ value)

In [99]:
class MultiHeadAttention(eqx.Module):
    wquery: jax.Array
    wkey: jax.Array
    wvalue: jax.Array
    weights: jax.Array
    attn: eqx.Module
    n_heads: int = eqx.field(static=True)
    dim_k: int = eqx.field(static=True)

    def __init__(self, key, n_heads, dim):
        if (dim % n_heads) != 0:
            raise ValueError("Model dimensions must be a multiple of no. of heads")
        dim_k = dim // n_heads
        init = jax.nn.initializers.he_uniform()
        wkey, qkey, kkey, vkey = jax.random.split(key, num=4)
        self.weights = init(key=wkey, shape=(n_heads * dim_k, dim))
        self.wquery = init(key=qkey, shape=(dim, dim))
        self.wkey = init(key=kkey,shape=(dim, dim))
        self.wvalue = init(key=vkey, shape=(dim, dim))
        self.attn = SelfAttention()
        self.n_heads = n_heads
        self.dim_k = dim_k

    @eqx.filter_jit
    def __call__(self, query, key, value, mask):
        query, key, value = query @ self.wquery, key @ self.wkey, value @ self.wvalue
        query, key, value = [jnp.transpose(jnp.reshape(x, (-1, self.n_heads, self.dim_k)), (1, 0, 2)) for x in (query, key, value)]
        mask = jnp.expand_dims(mask, axis=0)
        attn = self.attn(query, key, value, mask)
        return jnp.reshape(jnp.transpose(attn, (1, 0, 2)), (-1, self.n_heads * self.dim_k)) @ self.weights

In [100]:
class LayerNorm(eqx.Module):
    gamma: jax.Array
    bias: jax.Array
    eps: int = eqx.field(static=True)

    def __init__(self, size, eps=1e-6):
        self.gamma = jnp.ones(size)
        self.bias = jnp.ones(size)
        self.eps = 1e-6

    @eqx.filter_jit
    def __call__(self, x):
        mean = jnp.mean(x, -1, keepdims=True)
        std = jnp.std(x, -1, keepdims=True)
        return (self.gamma * (x - mean) / (std + self.eps)) + self.bias

In [101]:
class Encoder(eqx.Module):
    emb: jax.Array
    attn_layers: list
    ff_layers:list
    attn_norms: list
    ff_norms: list
    n_layers: int = eqx.field(static=True)
    pe: jax.Array = eqx.field(static=True)

    def __init__(self, key, n_layers, n_heads, dim, seq_len, vocab):
        keys = jax.random.split(key, num=n_layers*2+1)
        emb_key, attn_keys, ff_keys = keys[0], keys[1:n_layers+1], keys[n_layers+1:]
        self.emb = jax.random.normal(emb_key, (vocab, dim))
        # Self-Attention & Forward Layers
        self.attn_layers = [MultiHeadAttention(key, n_heads, dim) for key in attn_keys]
        self.ff_layers = [FFNN(key, dim, dim, dim*2) for key in ff_keys]
        # Layer Norms
        self.attn_norms = [LayerNorm(dim) for _ in range(n_layers)]
        self.ff_norms = [LayerNorm(dim) for _ in range(n_layers)]
        # Positional Encodings
        pos = jnp.arange(seq_len)[:, jnp.newaxis]
        div_term = 10_000 ** (2 * jnp.arange(0, dim, 2) / dim)
        pe = jnp.empty((seq_len, dim))
        pe = pe.at[:, 0::2].set(jnp.sin(pos / div_term))
        pe = pe.at[:, 1::2].set(jnp.cos(pos / div_term))
        self.pe = pe
        # Static Arguments
        self.n_layers = n_layers

    @eqx.filter_jit
    def __call__(self, x, mask):
        x = self.emb[x[...]]
        x = x + self.pe
        for i in range(self.n_layers):
            x = self.attn_norms[i](self.attn_layers[i](x, x, x, mask) + x)
            x = self.ff_norms[i](self.ff_layers[i](x) + x)
        return x

In [102]:
class Decoder(eqx.Module):
    emb: jax.Array
    mask: jax.Array = eqx.field(static=True)
    masked_attn_layers: list
    attn_layers: list
    ff_layers:list
    masked_attn_norms: list
    attn_norms: list
    ff_norms: list
    n_layers: int = eqx.field(static=True)
    pe: jax.Array = eqx.field(static=True)

    def __init__(self, key, n_layers, n_heads, dim, seq_len, vocab):
        keys = jax.random.split(key, num=n_layers*3+1)
        emb_key, attn_keys, ff_keys, masked_attn_keys = keys[0], keys[1:n_layers+1], keys[n_layers+1:n_layers*2+1], keys[n_layers*2+1:]
        self.emb = jax.random.normal(emb_key, (vocab, dim))
        self.mask = jnp.where(jnp.triu(jnp.ones((seq_len, seq_len)), 1) == 1, np.NINF, 0)
        # Masked-Attention, Self-Attention & Forward Layers
        self.masked_attn_layers = [MultiHeadAttention(key, n_heads, dim) for key in masked_attn_keys]
        self.attn_layers = [MultiHeadAttention(key, n_heads, dim) for key in attn_keys]
        self.ff_layers = [FFNN(key, dim, dim, dim*2) for key in ff_keys]
        # Layer Norms
        self.masked_attn_norms = [LayerNorm(dim) for _ in range(n_layers)]
        self.attn_norms = [LayerNorm(dim) for _ in range(n_layers)]
        self.ff_norms = [LayerNorm(dim) for _ in range(n_layers)]
        # Positional Encodings
        pos = jnp.arange(seq_len)[:, jnp.newaxis]
        div_term = 10_000 ** (2 * jnp.arange(0, dim, 2) / dim)
        pe = jnp.empty((seq_len, dim))
        pe = pe.at[:, 0::2].set(jnp.sin(pos / div_term))
        pe = pe.at[:, 1::2].set(jnp.cos(pos / div_term))
        self.pe = pe
        # Static Arguments
        self.n_layers = n_layers

    @eqx.filter_jit
    def __call__(self, x, m, x_mask, m_mask):
        x = self.emb[x[...]]
        x = x + self.pe
        x_mask = self.mask + x_mask
        for i in range(self.n_layers):
            x = self.masked_attn_norms[i](self.masked_attn_layers[i](x, x, x, x_mask) + x)
            x = self.attn_norms[i](self.attn_layers[i](x, m, m, m_mask) + x)
            x = self.ff_norms[i](self.ff_layers[i](x) + x)
        return x

In [103]:
class EncoderDecoder(eqx.Module):
    encoder: eqx.Module
    decoder: eqx.Module

    def __init__(self, key, dim, enc_heads, enc_layers, dec_heads, dec_layers, in_vocab, out_vocab, seq_len):
        enc_key, dec_key = jax.random.split(key, num=2)
        self.encoder = Encoder(enc_key, enc_layers, enc_heads, dim, seq_len, in_vocab)
        self.decoder = Decoder(dec_key, dec_layers, dec_heads, dim, seq_len, out_vocab)

    @eqx.filter_jit
    def __call__(self, X, y, X_mask, y_mask):
        m = self.encoder(X, X_mask)
        h = self.decoder(y, m, y_mask, X_mask)
        return h

In [104]:
class Transformer(eqx.Module):
    enc_dec: eqx.Module
    linear: eqx.Module

    def __init__(self, key, dim, enc_heads, enc_layers, dec_heads, dec_layers, in_vocab, out_vocab, seq_len):
        encdec_key, linear_key = jax.random.split(key)
        self.enc_dec = EncoderDecoder(encdec_key, dim,
                                      enc_heads, enc_layers, 
                                      dec_heads, dec_layers, 
                                      in_vocab, out_vocab, seq_len)
        self.linear = Linear(linear_key, dim, out_vocab)

    @eqx.filter_jit
    def __call__(self, X, y, X_mask, y_mask):
        logits = self.enc_dec(X, y, X_mask, y_mask)
        return jax.nn.softmax(self.linear(logits)) + 1e-9

In [105]:
N_ENCODER_HEADS = N_DECODER_HEADS = 8
N_ENCODER_LAYERS = N_DECODER_LAYERS = 3
INPUT_VOCAB = OUTPUT_VOCAB = VOCAB

In [106]:
key = jax.random.PRNGKey(0)
model = Transformer(key, MODEL_DIM, 
                    N_ENCODER_HEADS, N_ENCODER_LAYERS, 
                    N_DECODER_HEADS, N_DECODER_LAYERS, 
                    INPUT_VOCAB, OUTPUT_VOCAB, SEQ_LEN)

In [107]:
BATCH_SIZE = 32 
EPOCHS = 5

In [108]:
n_steps =  int(EPOCHS * len(Xtr)/BATCH_SIZE)

In [109]:
scheduler = optax.warmup_cosine_decay_schedule(0.01, 10, n_steps//100, n_steps//10, 0.5)
optimizer = optax.adam(learning_rate=scheduler)

In [110]:
def predict(model, X, y, X_mask, y_mask):
    return model(X, y, X_mask, y_mask)

In [111]:
def loss(model, X, y, X_mask, y_mask, labels):
    y_pred = jnp.log(predict(model, X, y, X_mask, y_mask))
    y_pred = jnp.where(labels==0, 0, jnp.take(y_pred, labels, axis=-1))
    count = jnp.count_nonzero(y_pred)
    return -jnp.sum(y_pred)/count

In [112]:
def optim(model, optimizer, loss_fn, vectorize=True, in_axes=None, out_axes=None):
    opt_state = optimizer.init(model)
    grad = jax.value_and_grad(loss_fn)
    if vectorize:
        gradient = jax.vmap(grad, in_axes=in_axes, out_axes=out_axes)

    def step(model, opt_state, X, y, X_mask, y_mask, labels):
        loss_value, grads = gradient(model, X, y, X_mask, y_mask, labels)
        if vectorize:
            loss_value = jnp.mean(loss_value)
            grads = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), grads)
        updates, opt_state = optimizer.update(grads, opt_state, model)
        model = optax.apply_updates(model, updates)
        return model, opt_state, loss_value

    return opt_state, step

In [113]:
opt_state, step = optim(model, optimizer, loss, in_axes=(None, 0, 0, 0, 0, 0), out_axes=0)

In [114]:
for e in range(EPOCHS):
    total_loss = 0
    num_batches = 0
    total_tokens = 0
    for i, (Xbt, ybt, labelbt) in enumerate(dataloader(Xtr, ytr, SEQ_LEN)):
        total_tokens += len([token for seq in labelbt for token in list(filter(lambda x: x!=0, seq))])
        Xbt, ybt, labelbt = [jnp.array(x) for x in (Xbt, ybt, labelbt)]
        Xmask, ymask = [create_mask(x) for x in (Xbt, ybt)]

        model, opt_state, batch_loss = step(model, opt_state, Xbt, ybt, Xmask, ymask, labelbt)
        total_loss += batch_loss
        num_batches += 1

        if num_batches % 20 == 0:
            print(f"Batches trained: {num_batches} | Avg. Batch loss: {total_loss/num_batches}")

    epoch_loss = total_loss / num_batches
    print(f"Epoch {e} | loss: {epoch_loss}")

Batches trained: 20 | Avg. Batch loss: 14.472692489624023
Batches trained: 40 | Avg. Batch loss: 14.5695161819458
Batches trained: 60 | Avg. Batch loss: 15.60781192779541
Batches trained: 80 | Avg. Batch loss: 16.869060516357422
Batches trained: 100 | Avg. Batch loss: 17.63990020751953
Batches trained: 120 | Avg. Batch loss: 18.153791427612305
Batches trained: 140 | Avg. Batch loss: 18.520845413208008
Batches trained: 160 | Avg. Batch loss: 18.796138763427734
Batches trained: 180 | Avg. Batch loss: 19.0102596282959
Batches trained: 200 | Avg. Batch loss: 19.181550979614258
Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.


FloatingPointError: invalid value (nan) encountered in jit(sub)