In [1]:
import jax
import jax.numpy as jnp 
import numpy as np
from transformers import Transformer

In [2]:
transformer = Transformer(input_vocab=50, output_vocab=50, model_dim=4, feedforward_dim=4, num_attention_layer=2)

Try using the transformer with None mask

In [3]:
input_sample = jnp.arange(12).reshape((2,6))
output_sample = jnp.arange(20).reshape((2,10))

params = transformer.init(jax.random.key(0), input_sample, output_sample, None, None)
transformer.apply(params, input_sample, output_sample, None, None)

Array([[[-3.603883 , -4.0610104, -3.1380339, -3.374127 , -4.6955166,
         -2.9281735, -4.632578 , -5.0074053, -3.2173376, -3.6309705,
         -3.9314628, -4.15722  , -3.5566726, -4.9649153, -5.492261 ,
         -5.5499268, -3.6127524, -3.7578468, -2.7155385, -5.7924376,
         -4.4435678, -2.8005562, -4.3252234, -3.2564018, -3.0761425,
         -3.7882452, -3.8674212, -4.0647535, -4.6734343, -4.9556856,
         -3.6732337, -3.0245109, -5.5029955, -6.761177 , -4.890404 ,
         -4.755705 , -4.7640505, -2.5425756, -6.1702604, -4.9347587,
         -4.816683 , -4.641508 , -4.727723 , -5.4903345, -4.9537516,
         -3.3816879, -5.1942663, -4.923053 , -4.753208 , -6.513439 ],
        [-3.603629 , -4.0608835, -3.1365728, -3.3745825, -4.6952186,
         -2.9286842, -4.6325016, -5.0082474, -3.2178364, -3.6310973,
         -3.931076 , -4.157234 , -3.5563345, -4.9654717, -5.4919977,
         -5.5509434, -3.612389 , -3.7583237, -2.7159522, -5.7915955,
         -4.4434233, -2.8005686, 

Try using the transformer with real mask

In [4]:
key = jax.random.key(0)
input_sample = jnp.arange(12).reshape((2,6))
output_sample = jnp.arange(20).reshape((2,10))
input_mask_sample = jax.random.randint(key=key, minval=0, maxval=2, shape=(2,1,6))
output_mask_sample = jax.random.randint(key=key, minval=0, maxval=2, shape=(2, 10, 10))

params = transformer.init(key, input_sample, output_sample, input_mask_sample, output_mask_sample)
transformer.apply(params, input_sample, output_sample, input_mask_sample, output_mask_sample)

Array([[[-3.603989 , -4.0667534, -3.0990503, -3.401723 , -4.6962757,
         -2.9440773, -4.694457 , -5.065493 , -3.2189136, -3.6387858,
         -3.8920884, -4.144193 , -3.553751 , -4.9897647, -5.5128183,
         -5.5400167, -3.6427002, -3.71738  , -2.7410808, -5.777725 ,
         -4.4498806, -2.8218968, -4.295107 , -3.264205 , -3.064641 ,
         -3.7683206, -3.8424459, -4.063854 , -4.634813 , -4.950702 ,
         -3.6904452, -3.0409403, -5.5029   , -6.745652 , -4.891506 ,
         -4.7566223, -4.747865 , -2.5156102, -6.128814 , -4.9246073,
         -4.8040357, -4.6842966, -4.7367573, -5.4389534, -4.8977823,
         -3.4013178, -5.1699514, -4.8934565, -4.7408514, -6.4938803],
        [-3.5893905, -4.0576286, -3.0258148, -3.4202871, -4.6783237,
         -2.9700656, -4.6691704, -5.096091 , -3.2488794, -3.6442692,
         -3.8827024, -4.1497536, -3.5351818, -5.0145435, -5.4898996,
         -5.603633 , -3.6115198, -3.7599292, -2.75815  , -5.732032 ,
         -4.4393206, -2.8163948, 

## Training and testing transformer 
We'll try to create a transformer which is able to copy tasks perfectly

In [5]:
from utils import create_train_state, train_model, Batch

In [6]:
transformer = Transformer(input_vocab=10, output_vocab=10, model_dim=128, feedforward_dim=512, num_attention_layer=8)

In [7]:
state = create_train_state(model=transformer, learning_rate=1e-3, key=jax.random.key(0))

In [8]:
import random

def copy_data_generator(num_batches=50, batch_size=16, input_size=10):
    for _ in range(num_batches):
        data = jax.random.randint(key=jax.random.key(0), shape=(batch_size,input_size), minval=0, maxval=10)
        batch = Batch(data, data, 0)
        yield batch

In [9]:
trained_state = train_model(state, data_generator=copy_data_generator, num_epoch=10)

Epoch: 0, Loss: 2.262263774871826
Epoch: 1, Loss: 1.8123723268508911
Epoch: 2, Loss: 1.3162075281143188
Epoch: 3, Loss: 0.6910473108291626
Epoch: 4, Loss: 0.39335551857948303
Epoch: 5, Loss: 0.1555006355047226
Epoch: 6, Loss: 0.30309563875198364
Epoch: 7, Loss: 0.23683658242225647
Epoch: 8, Loss: 0.06251252442598343
Epoch: 9, Loss: 0.25028517842292786


In [10]:
params = trained_state.params 
params

{'decode_preprocessor': {'embedding': {'embed': {'embedding': Array([[ 0.13046482,  0.05108283,  0.07649177, ..., -0.05616669,
            -0.0596702 ,  0.08942231],
           [-0.08292061, -0.11324471,  0.03027694, ...,  0.02345745,
            -0.02232559, -0.09712467],
           [-0.03112921, -0.0152321 ,  0.09909549, ...,  0.02898805,
            -0.02250978, -0.03741901],
           ...,
           [-0.08962072, -0.096058  ,  0.18010834, ..., -0.00664432,
            -0.01227721,  0.01758045],
           [ 0.05646501, -0.0065653 ,  0.01465466, ...,  0.01682173,
             0.00196423,  0.0879422 ],
           [-0.06332215, -0.09996408, -0.03398724, ..., -0.03026902,
             0.20613018,  0.16358805]], dtype=float32)}}},
 'decoder': {'layers_0': {'attention1': {'wk': {'bias': Array([-2.1593855e-03, -1.7061057e-03, -3.8039095e-03,  2.2698171e-03,
             1.8988467e-03,  1.6266392e-03,  8.0279395e-04, -1.0158069e-03,
            -3.7158933e-03,  6.3158898e-04, -2.6964992e

We'll now try to run this transformer

In [11]:
from utils import decode

input = jax.random.randint(key=jax.random.key(0), shape=(5,10), minval=0, maxval=9)
output_init = input[:, 0].reshape(-1, 1)
input_mask = jnp.ones((1,1,10), dtype=int)

In [12]:
input

Array([[2, 3, 7, 4, 8, 7, 1, 0, 6, 7],
       [0, 6, 1, 3, 8, 1, 8, 1, 2, 2],
       [7, 3, 1, 4, 8, 6, 4, 1, 4, 0],
       [8, 8, 2, 5, 2, 8, 3, 1, 2, 2],
       [4, 6, 1, 0, 3, 7, 1, 0, 8, 5]], dtype=int32)

In [13]:
decode(trained_state, input, output_init, 10, input_mask)

Array([[2, 3, 9, 6, 4, 2, 2, 2, 2, 2],
       [0, 4, 3, 0, 0, 0, 6, 6, 6, 6],
       [7, 6, 3, 9, 1, 8, 6, 4, 2, 2],
       [8, 5, 8, 3, 0, 8, 8, 2, 8, 0],
       [4, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)