In [1]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

In [2]:
# Load the two pickled objects from ./en_fr/ into `en_enc` and `fr_enc`.
# Presumably these are the encoded English / French corpora — TODO confirm.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# run this cell on files that you produced yourself or otherwise trust.
import pickle

# Both files are opened together, so a missing file fails before either load runs.
with open("./en_fr/en_enc.pickle", "rb") as en, open("./en_fr/fr_enc.pickle", "rb") as fr:
    en_enc = pickle.load(en)
    fr_enc = pickle.load(fr)

In [None]:
# Embedding width; the model dimension is tied to the same value.
EMB_SIZE = 128
MODEL_DIM = EMB_SIZE

# Number of positions used by the positional-encoding cell below.
SEQ_LEN = 16

In [None]:
# Sinusoidal positional encodings: one value per sequence position, using
# sin at even indices and cos at odd indices.
#
# Fixes relative to the previous version of this cell:
#   * `position` was an undefined name (NameError); it is `positions`.
#   * the literal 16 is replaced by SEQ_LEN so the cell follows the config.
#   * the even/odd selector is built with jnp instead of a Python list.
#
# NOTE(review): as written this yields a vector of shape (SEQ_LEN,) in which
# the position index also plays the role of the feature index, with base 1000.
# The encoding in "Attention Is All You Need" is a (SEQ_LEN, MODEL_DIM) table
# with base 10000 — confirm which one is intended before adding this to
# embeddings.
positions = jnp.arange(SEQ_LEN)
positions = positions / (1000 ** ((2 * positions) / MODEL_DIM))
positional_encodings = jnp.where(
    jnp.arange(SEQ_LEN) % 2 == 0,  # True at even indices
    jnp.sin(positions),
    jnp.cos(positions),
)

In [None]:
class Linear(eqx.Module):
    """Affine layer computing ``x @ weights + bias``.

    ``eqx.Module`` is a dataclass, so every attribute assigned in ``__init__``
    must be declared as an annotated field on the class; without the
    declarations below, construction fails.
    """

    weights: jax.Array  # shape (nin, nout), He-uniform initialised
    bias: jax.Array     # shape (nout,), initialised to ones

    def __init__(self, key, nin, nout):
        """Create the layer parameters.

        Args:
            key: JAX PRNG key used to draw the weight matrix.
            nin: Size of the input feature dimension.
            nout: Size of the output feature dimension.
        """
        init = jax.nn.initializers.he_uniform()
        self.weights = init(key, (nin, nout))
        self.bias = jnp.ones(nout)

    @eqx.filter_jit
    def __call__(self, x):
        """Apply the layer to ``x`` whose last dimension has size ``nin``."""
        # Fix: the bias lives on the instance; a bare `bias` is a NameError.
        return x @ self.weights + self.bias
    
class FFNN(eqx.Module):
    """Feed-forward network: ``nlayers`` Linear layers with ReLU in between.

    Layout: Linear(nin, nhidden) -> ReLU -> ... -> ReLU -> Linear(nhidden, nout).
    With ``nlayers == 1`` this is a single Linear(nin, nout).
    """

    # Declared as a field because eqx.Module is a dataclass. Holds Linear
    # modules interleaved with the (parameter-free) ReLU function.
    layers: list

    def __init__(self, key, nin, nout, nhidden, nlayers):
        """Build the layer stack.

        Args:
            key: JAX PRNG key; split into one sub-key per Linear layer.
            nin: Input feature size.
            nout: Output feature size.
            nhidden: Width of every hidden layer.
            nlayers: Number of Linear layers (must be >= 1).

        Raises:
            ValueError: If ``nlayers`` is smaller than 1.
        """
        if nlayers < 1:
            raise ValueError("nlayers must be at least 1")
        # Fix: the function is `jax.random.split` (was misspelled `spilt`).
        keys = jax.random.split(key, num=nlayers)
        # Feature sizes at the boundary of each Linear layer.
        widths = [nin] + [nhidden] * (nlayers - 1) + [nout]
        layers = []
        for i in range(nlayers):
            if i > 0:
                # A non-linearity goes between every pair of Linear layers,
                # including before the output layer; previously the last two
                # Linear layers were stacked directly and collapsed into one
                # affine map.
                layers.append(jax.nn.relu)
            layers.append(Linear(keys[i], widths[i], widths[i + 1]))
        self.layers = layers

    @eqx.filter_jit
    def __call__(self, x):
        """Feed ``x`` through every layer in order and return the result."""
        # Fix: iterate over the layers themselves (was `range(x)` and a call
        # on the list object).
        for layer in self.layers:
            x = layer(x)
        return x

class SelfAttention(eqx.Module):
    """Single-head scaled dot-product self-attention.

    Computes ``softmax(Q K^T / sqrt(dim_k)) V`` where Q, K and V are linear
    projections of the same input.
    """

    # eqx.Module is a dataclass: attributes set in __init__ must be declared.
    wquery: jax.Array  # (dim, dim_k)
    wkey: jax.Array    # (dim, dim_k)
    wvalue: jax.Array  # (dim, dim_v)
    mask: "jax.Array | None"  # optional attention mask, see __call__

    def __init__(self, key, dim, dim_k, dim_v, mask=None):
        """Create the projection matrices.

        Args:
            key: JAX PRNG key; split into three sub-keys (query, key, value).
            dim: Feature size of the input.
            dim_k: Feature size of the query/key projections.
            dim_v: Feature size of the value projection.
            mask: Optional mask that broadcasts against the (seq, seq) score
                matrix. Assumed convention (implied by the earlier
                multiplicative mask): non-zero / True = may attend,
                zero / False = blocked — TODO confirm with callers.
        """
        qkey, kkey, vkey = jax.random.split(key, num=3)
        init = jax.nn.initializers.he_uniform()
        self.wquery = init(qkey, (dim, dim_k))
        self.wkey = init(kkey, (dim, dim_k))
        self.wvalue = init(vkey, (dim, dim_v))
        self.mask = mask

    @eqx.filter_jit
    def __call__(self, x):
        """Attend over a 2-D input ``x`` of shape (seq, dim); returns (seq, dim_v)."""
        query = x @ self.wquery
        key = x @ self.wkey
        # Fix: the value projection is stored as `wvalue` (there is no `vkey`).
        value = x @ self.wvalue
        scores = query @ key.T / jnp.sqrt(query.shape[-1])
        if self.mask is not None:
            # Fix: multiplying logits by the mask sets blocked scores to 0,
            # which still receives softmax weight. Blocked positions need a
            # very negative logit instead. A finite minimum (not -inf) keeps a
            # fully blocked row from turning into NaN.
            allowed = jnp.asarray(self.mask, dtype=bool)
            scores = jnp.where(allowed, scores, jnp.finfo(scores.dtype).min)
        return jax.nn.softmax(scores, axis=-1) @ value
    
class MutliHeadAttention(eqx.Module):
    def __init__(self, key, heads, dim):
        """Create ``heads`` attention heads plus the output projection.

        Args:
            key: JAX PRNG key; split into ``heads + 1`` sub-keys (one for the
                output projection, one per head).
            heads: Number of attention heads.
            dim: Model dimension; must be divisible by ``heads``.

        Raises:
            ValueError: If ``dim`` is not a multiple of ``heads``.
        """
        # NOTE(review): eqx.Module is a dataclass, so `weights` and `heads`
        # also need annotated field declarations in the class body.
        # Fix: the parameter is `heads`; `head` and `h` were undefined names.
        if (dim % heads) != 0:
            raise ValueError("Model dimensions must be a multiple of no. of heads")
        dim_k = dim_v = dim // heads
        init = jax.nn.initializers.he_uniform()
        keys = jax.random.split(key, num=heads+1)
        # Output projection applied to the concatenated head outputs.
        self.weights = init(keys[0], (heads * dim_v, dim))
        self.heads = [SelfAttention(k, dim, dim_k, dim_v) for k in keys[1:]]
        
    @eqx.filter_jit
    def __call__(self, x):
        