In [1]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import functools as ft

In [None]:
# Model width; EMB_SIZE is kept as an alias of the same value.
MODEL_DIM = 128
EMB_SIZE = MODEL_DIM
# Fixed sequence length, used by the positional-encoding cell and the
# decoder's causal mask.
SEQ_LEN = 16

In [None]:
# Sinusoidal positional encodings, one MODEL_DIM-wide vector per position:
#   PE[pos, 2i]   = sin(pos / POS_ENC_BASE ** (2i / MODEL_DIM))
#   PE[pos, 2i+1] = cos(pos / POS_ENC_BASE ** (2i / MODEL_DIM))
# Result shape: (SEQ_LEN, MODEL_DIM).
POS_ENC_BASE = 10000  # wavelength base used in "Attention Is All You Need"
positions = jnp.arange(SEQ_LEN)[:, None]        # (SEQ_LEN, 1)
features = jnp.arange(MODEL_DIM)[None, :]       # (1, MODEL_DIM)
# Features 2i and 2i+1 share the same frequency, indexed by i = feature // 2.
angles = positions / (POS_ENC_BASE ** ((2 * (features // 2)) / MODEL_DIM))
# Even feature indices take sin, odd ones take cos (broadcast over positions).
positional_encodings = jnp.where(features % 2 == 0, jnp.sin(angles), jnp.cos(angles))

In [None]:
class Linear(eqx.Module):
    """Affine layer computing ``x @ weights + bias``.

    Args:
        key: PRNG key used to initialise ``weights``.
        nin: Size of the input's last axis.
        nout: Size of the output's last axis.
    """

    # Equinox modules are frozen dataclasses: every attribute assigned in
    # __init__ must be declared as a field, otherwise the assignment fails.
    weights: jax.Array  # (nin, nout), He-uniform initialised
    bias: jax.Array     # (nout,)

    def __init__(self, key, nin, nout):
        init = jax.nn.initializers.he_uniform()
        self.weights = init(key, (nin, nout))
        # NOTE(review): biases are more commonly zero-initialised; kept as ones.
        self.bias = jnp.ones(nout)

    @eqx.filter_jit
    def __call__(self, x):
        # Bug fix: the bias must be read from the module (was a bare `bias`).
        return x @ self.weights + self.bias

In [None]:
class FFNN(eqx.Module):
    """Multi-layer perceptron with ReLU between layers and a linear output.

    Args:
        key: PRNG key, split into one key per layer.
        nin: Size of the input's last axis.
        nout: Size of the output's last axis.
        nhidden: Width of every hidden layer.
        n_layers: Number of Linear layers (>= 1). With ``n_layers == 1`` the
            network is a single ``Linear(nin, nout)``.

    Raises:
        ValueError: if ``n_layers < 1``.
    """

    layers: list  # Linear layers, applied in order

    def __init__(self, key, nin, nout, nhidden, n_layers=2):
        if n_layers < 1:
            raise ValueError("n_layers must be at least 1")
        keys = jax.random.split(key, num=n_layers)
        # Layer widths: nin -> nhidden -> ... -> nhidden -> nout.
        sizes = [nin] + [nhidden] * (n_layers - 1) + [nout]
        self.layers = [
            Linear(k, size_in, size_out)
            for k, size_in, size_out in zip(keys, sizes[:-1], sizes[1:])
        ]

    @eqx.filter_jit
    def __call__(self, x):
        # ReLU after every layer except the last, which stays linear.
        for layer in self.layers[:-1]:
            x = jax.nn.relu(layer(x))
        return self.layers[-1](x)

In [None]:
class SelfAttention(eqx.Module):
    """Single-head scaled dot-product self-attention.

    Computes ``softmax(Q K^T / sqrt(d_k)) V`` with ``Q = x @ wquery``,
    ``K = x @ wkey`` and ``V = x @ wvalue``.

    Args:
        key: PRNG key, split three ways for the query/key/value projections.
        dim: Size of the input's last axis.
        dim_k: Projection width of queries and keys.
        dim_v: Projection width of values (the output's last axis).
        mask: Optional array broadcastable to the (query, key) score matrix.
            Truthy entries may be attended to and falsy entries are excluded.
            ``None`` disables masking.
    """

    wquery: jax.Array  # (dim, dim_k)
    wkey: jax.Array    # (dim, dim_k)
    wvalue: jax.Array  # (dim, dim_v)
    mask: jax.Array    # or None

    def __init__(self, key, dim, dim_k, dim_v, mask=None):
        qkey, kkey, vkey = jax.random.split(key, num=3)
        init = jax.nn.initializers.he_uniform()
        self.wquery = init(qkey, (dim, dim_k))
        self.wkey = init(kkey, (dim, dim_k))
        self.wvalue = init(vkey, (dim, dim_v))
        self.mask = mask

    @eqx.filter_jit
    def __call__(self, x):
        query = x @ self.wquery
        keys = x @ self.wkey
        # Bug fix: the value projection is `wvalue` (there is no `vkey` field).
        value = x @ self.wvalue
        # swapaxes(-1, -2) instead of .T so any leading axes are left alone.
        scores = query @ jnp.swapaxes(keys, -1, -2) / jnp.sqrt(query.shape[-1])
        if self.mask is not None:
            # Multiplying scores by the mask only zeroes the logits, and softmax
            # still gives those positions weight. Excluded positions instead get
            # the most negative finite logit (finite so a fully masked row
            # cannot produce NaNs).
            scores = jnp.where(self.mask, scores, jnp.finfo(scores.dtype).min)
        return jax.nn.softmax(scores, axis=-1) @ value

In [None]:
class MultiHeadAttention(eqx.Module):
    """Multi-head attention built from ``n_heads`` SelfAttention heads.

    Each head projects to ``dim // n_heads`` features. The heads' outputs are
    concatenated and mapped back to ``dim`` by ``weights``.

    Args:
        key: PRNG key, split into one key for the output projection plus one
            per head.
        n_heads: Number of attention heads.
        dim: Model width; must be divisible by ``n_heads``.
        mask: Optional mask forwarded unchanged to every head.

    Raises:
        ValueError: if ``dim`` is not a multiple of ``n_heads``.
    """

    weights: jax.Array  # (n_heads * dim_v, dim) output projection
    heads: list         # SelfAttention modules, one per head

    def __init__(self, key, n_heads, dim, mask=None):
        if (dim % n_heads) != 0:
            raise ValueError("Model dimensions must be a multiple of no. of heads")
        dim_k = dim_v = dim // n_heads
        init = jax.nn.initializers.he_uniform()
        keys = jax.random.split(key, num=n_heads + 1)
        # Bug fix: the head count is `n_heads` (an undefined `h` was used).
        self.weights = init(keys[0], (n_heads * dim_v, dim))
        self.heads = [SelfAttention(k, dim, dim_k, dim_v, mask) for k in keys[1:]]

    @eqx.filter_jit
    def __call__(self, x):
        # Concatenate along the feature (last) axis. hstack would concatenate
        # along axis 1, which is wrong for inputs with a leading batch axis.
        head_outputs = [head(x) for head in self.heads]
        return jnp.concatenate(head_outputs, axis=-1) @ self.weights


# Backward-compatible alias for the original, misspelled class name.
MutliHeadAttention = MultiHeadAttention

In [None]:
class LayerNorm(eqx.Module):
    """Layer normalisation over the last axis with a learnable scale and shift.

    Computes ``gamma * (x - mean) / (std + eps) + bias``, with ``mean`` and
    ``std`` taken over the last axis of ``x``.

    Args:
        size: Size of the normalised (last) axis.
        eps: Added to the standard deviation for numerical stability.
    """

    gamma: jax.Array  # (size,), starts at ones
    bias: jax.Array   # (size,), starts at zeros
    eps: float

    def __init__(self, size, eps=1e-6):
        self.gamma = jnp.ones(size)
        # Bug fix: was `jnp.ones(bias)` (NameError). The shift starts at zero so
        # a freshly built layer is a pure normalisation.
        self.bias = jnp.zeros(size)
        # Bug fix: honour the `eps` argument instead of hard-coding 1e-6.
        self.eps = eps

    def __call__(self, x):
        # keepdims=True keeps the per-row statistics broadcastable against x.
        # Without it, a (rows, size) input is normalised along the wrong axis,
        # or fails to broadcast.
        mean = jnp.mean(x, axis=-1, keepdims=True)
        std = jnp.std(x, axis=-1, keepdims=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.bias

In [None]:
class Encoder(eqx.Module):
    """Stack of ``n_layers`` post-norm Transformer encoder layers.

    Each layer applies multi-head self-attention and then a feed-forward
    network. Each sub-layer is wrapped as ``LayerNorm(sublayer(x) + x)``.

    Args:
        key: PRNG key, split into one key per attention and feed-forward layer.
        n_layers: Number of encoder layers.
        n_heads: Attention heads per layer.
        dim: Model width (the feed-forward hidden width is ``4 * dim``).
    """

    attn_layers: list
    ff_layers: list
    attn_norms: list
    ff_norms: list
    n_layers: int

    def __init__(self, key, n_layers, n_heads, dim):
        keys = jax.random.split(key, num=n_layers * 2)
        attn_keys, ff_keys = keys[:n_layers], keys[n_layers:]
        self.attn_layers = [MultiHeadAttention(k, n_heads, dim) for k in attn_keys]
        self.ff_layers = [FFNN(k, dim, dim, dim * 4) for k in ff_keys]
        self.attn_norms = [LayerNorm(dim) for _ in range(n_layers)]
        self.ff_norms = [LayerNorm(dim) for _ in range(n_layers)]
        self.n_layers = n_layers

    @eqx.filter_jit
    def __call__(self, x):
        # Bug fixes: loop over `self.n_layers` (a bare `n_layers` is undefined
        # here), and return only after every layer has run. The `return` was
        # inside the loop, so only the first layer was applied.
        for i in range(self.n_layers):
            x = self.attn_norms[i](self.attn_layers[i](x) + x)
            x = self.ff_norms[i](self.ff_layers[i](x) + x)
        return x

In [None]:
class Decoder(eqx.Module):
    """Stack of ``n_layers`` post-norm Transformer decoder layers.

    Each layer applies causally masked multi-head self-attention, a second
    multi-head attention sub-layer and a feed-forward network. Each sub-layer
    is wrapped as ``LayerNorm(sublayer(x) + x)``.

    Args:
        key: PRNG key, split into one key per sub-layer.
        n_layers: Number of decoder layers.
        n_heads: Attention heads per attention sub-layer.
        dim: Model width (the feed-forward hidden width is ``4 * dim``).
        seq_len: Side length of the causal mask. Defaults to the notebook's
            ``SEQ_LEN``, read at construction time.
    """

    masked_attn_layers: list
    attn_layers: list
    ff_layers: list
    attn_norms: list
    ff_norms: list
    masked_attn_norms: list
    n_layers: int

    def __init__(self, key, n_layers, n_heads, dim, seq_len=None):
        if seq_len is None:
            seq_len = SEQ_LEN
        keys = jax.random.split(key, num=n_layers * 3)
        attn_keys = keys[:n_layers]
        ff_keys = keys[n_layers:2 * n_layers]
        # Bug fix: this must be a slice of n_layers keys. `keys[n_layers*2]`
        # selected one single key.
        masked_attn_keys = keys[2 * n_layers:]
        # Causal mask over the (query, key) score matrix: position i may attend
        # only to positions j <= i, so the lower triangle is True. The previous
        # mask had shape (SEQ_LEN, dim) and kept the upper triangle. It also
        # had 1e-9 added, so no entry was ever zero.
        mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
        self.masked_attn_layers = [
            MultiHeadAttention(k, n_heads, dim, mask=mask) for k in masked_attn_keys
        ]
        self.attn_layers = [MultiHeadAttention(k, n_heads, dim) for k in attn_keys]
        self.ff_layers = [FFNN(k, dim, dim, dim * 4) for k in ff_keys]
        self.attn_norms = [LayerNorm(dim) for _ in range(n_layers)]
        self.ff_norms = [LayerNorm(dim) for _ in range(n_layers)]
        self.masked_attn_norms = [LayerNorm(dim) for _ in range(n_layers)]
        self.n_layers = n_layers

    @eqx.filter_jit
    def __call__(self, x):
        for i in range(self.n_layers):
            # Residual connection added for consistency with the other
            # sub-layers; previously the input x was discarded here.
            x = self.masked_attn_norms[i](self.masked_attn_layers[i](x) + x)
            # FIXME(review): in an encoder-decoder Transformer this sub-layer
            # should attend to the encoder output. MultiHeadAttention accepts a
            # single input, so this is unmasked self-attention over x. It lets
            # each position see later positions, undoing the causal mask.
            x = self.attn_norms[i](self.attn_layers[i](x) + x)
            x = self.ff_norms[i](self.ff_layers[i](x) + x)
        # Return after all layers (the return was previously inside the loop).
        return x

In [None]:
class EncoderDecoder(eqx.Module):
    def __init__(self, key, enc_heads, enc_layers, dec_layers, dec_heads, dim):
        enc_key, dec_key = jax.random.split(key, num=2)
        self.encoder = Encoder(enc_key, enc_layers, enc_heads)
        self.decoder = Decoder(dec_key, dec_layers, dec_heads)
        
    @eqx.filter_jit
    def __call__(self, x):
        