In [2]:
import jax
import jax.numpy as jnp 
import numpy as np
from transformers import Transformer

In [3]:
transformer = Transformer(input_vocab=10, output_vocab=10, model_dim=4, feedforward_dim=4, num_attention_layer=2)

In [4]:
input_sample = jnp.arange(6).reshape((1,6))
output_sample = jnp.arange(10).reshape((1,10))

In [5]:
val = transformer.init(jax.random.key(0), input_sample, output_sample, None, None)

In [6]:
transformer.apply(val, input_sample, output_sample, None, None)

Array([[[-3.5400112 , -3.1865373 , -4.278422  , -2.6243584 ,
         -2.461668  , -0.8055344 , -2.1971378 , -3.346291  ,
         -1.9560196 , -3.7539833 ],
        [-3.538907  , -3.179956  , -4.2882752 , -2.617251  ,
         -2.4648693 , -0.80752194, -2.2028275 , -3.337326  ,
         -1.9495046 , -3.7598567 ],
        [-3.5391853 , -3.1818264 , -4.2855744 , -2.6192086 ,
         -2.4639773 , -0.80691355, -2.2012715 , -3.3398287 ,
         -1.9514494 , -3.7582471 ],
        [-3.5384138 , -3.1775239 , -4.2919655 , -2.614549  ,
         -2.4660563 , -0.8082184 , -2.2049944 , -3.333952  ,
         -1.947267  , -3.7620478 ],
        [-3.5373738 , -3.1718285 , -4.3002157 , -2.6084504 ,
         -2.4687643 , -0.8100149 , -2.2098422 , -3.326214  ,
         -1.9417067 , -3.766941  ],
        [-3.538507  , -3.1778598 , -4.2914166 , -2.6149535 ,
         -2.465886  , -0.8081418 , -2.2046657 , -3.3344417 ,
         -1.9475198 , -3.761723  ],
        [-3.536104  , -3.1651356 , -4.3096533 , -2.6

## Training and testing a transformer 
We'll try to create a transformer which is able to copy tasks perfectly

In [7]:
from utils import create_train_state, train_model, Batch

In [8]:
transformer = Transformer(input_vocab=10, output_vocab=10, model_dim=4, feedforward_dim=4, num_attention_layer=2)

In [9]:
s = create_train_state(model=transformer, learning_rate=1e-3, key=jax.random.key(0), batch_size=5, input_seq_len=10, output_seq_len=9)

In [10]:
def copy_data_generator(num_batches=10, batch_size=5, input_size=10):
    for _ in range(num_batches):
        data = jax.random.randint(key=jax.random.key(0), shape=(batch_size,input_size), minval=0, maxval=9)
        batch = Batch(data, data, 0)
        yield batch

In [11]:
trained_s = train_model(s, data_generator=copy_data_generator, num_epoch=10)

In [12]:
trained_s.params

{'decode_preprocessor': {'embedding': {'embed': {'embedding': Array([[ 0.45437104, -0.20376906, -0.45638642,  0.5413851 ],
           [ 0.8758611 , -0.47267005,  0.6453514 ,  0.14286825],
           [-0.22862883,  0.4096755 , -0.36848074, -0.15795685],
           [-0.51804626, -0.21827853, -0.83361703, -0.6102724 ],
           [-0.12268893,  1.0895722 , -0.60204995, -0.0556876 ],
           [ 0.25102603, -0.16714285, -0.6975144 ,  0.36186084],
           [ 0.28035256, -0.29600814,  0.05154541, -0.2329466 ],
           [-1.1236365 ,  0.759048  , -0.92856157, -0.86625826],
           [-0.6288313 , -0.54888487, -0.60785353,  0.5248716 ],
           [-1.1798702 , -0.04060038, -0.6768519 ,  0.21358785]],      dtype=float32)}}},
 'decoder': {'layers_0': {'attention1': {'wk': {'bias': Array([ 0.00507987,  0.00228052, -0.00303675,  0.00406729], dtype=float32),
     'kernel': Array([[-1.1227298 , -0.46607286, -0.97124445, -0.4808624 ],
            [-0.09241292, -0.16312562,  1.0357897 ,  0.1776

We'll now try to run this transformer

In [20]:
from utils import decode

input = jax.random.randint(key=jax.random.key(0), shape=(5,10), minval=0, maxval=9)
output_init = input[:, 0].reshape(-1, 1)

In [21]:
input

Array([[2, 3, 7, 4, 8, 7, 1, 0, 6, 7],
       [0, 6, 1, 3, 8, 1, 8, 1, 2, 2],
       [7, 3, 1, 4, 8, 6, 4, 1, 4, 0],
       [8, 8, 2, 5, 2, 8, 3, 1, 2, 2],
       [4, 6, 1, 0, 3, 7, 1, 0, 8, 5]], dtype=int32)

In [22]:
decode(trained_s, input, output_init, 10)

[[2]
 [0]
 [7]
 [8]
 [4]]
[[2 2]
 [0 2]
 [7 2]
 [8 8]
 [4 2]]
[[2 2 2]
 [0 2 2]
 [7 2 2]
 [8 8 6]
 [4 2 2]]
[[2 2 2 2]
 [0 2 2 2]
 [7 2 2 2]
 [8 8 6 2]
 [4 2 2 2]]
[[2 2 2 2 2]
 [0 2 2 2 2]
 [7 2 2 2 2]
 [8 8 6 2 8]
 [4 2 2 2 2]]
[[2 2 2 2 2 2]
 [0 2 2 2 2 2]
 [7 2 2 2 2 2]
 [8 8 6 2 8 2]
 [4 2 2 2 2 2]]
[[2 2 2 2 2 2 2]
 [0 2 2 2 2 2 2]
 [7 2 2 2 2 2 2]
 [8 8 6 2 8 2 8]
 [4 2 2 2 2 2 2]]
[[2 2 2 2 2 2 2 2]
 [0 2 2 2 2 2 2 2]
 [7 2 2 2 2 2 2 2]
 [8 8 6 2 8 2 8 2]
 [4 2 2 2 2 2 2 2]]
[[2 2 2 2 2 2 2 2 2]
 [0 2 2 2 2 2 2 2 2]
 [7 2 2 2 2 2 2 2 2]
 [8 8 6 2 8 2 8 2 2]
 [4 2 2 2 2 2 2 2 2]]


Array([[2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
       [0, 2, 2, 2, 2, 2, 2, 2, 2, 2],
       [7, 2, 2, 2, 2, 2, 2, 2, 2, 2],
       [8, 8, 6, 2, 8, 2, 8, 2, 2, 2],
       [4, 2, 2, 2, 2, 2, 2, 2, 2, 2]], dtype=int32)