In [13]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Example

In [288]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax

# class Transformer(nn.Module):
#     num_layers: int
#     num_heads: int
#     d_model: int
#     d_ff: int
#     vocab_size: int
#     max_len: int

#     def setup(self):
#         self.token_embed = nn.Embed(num_embeddings=self.vocab_size, features=self.d_model)
#         self.pos_embed = nn.Embed(num_embeddings=self.max_len, features=self.d_model)
#         self.encoder_layers = [nn.SelfAttention(num_heads=self.num_heads, qkv_features=self.d_model)
#                                for _ in range(self.num_layers)]
#         self.ffn_layers = [nn.Dense(features=self.d_ff) for _ in range(self.num_layers)]
#         self.output_dense = nn.Dense(features=self.vocab_size)

#     def __call__(self, x):
#         seq_len = x.shape[1]
#         token_embeddings = self.token_embed(x)
#         position_embeddings = self.pos_embed(jnp.arange(seq_len))
#         x = token_embeddings + position_embeddings

#         for encoder, ffn in zip(self.encoder_layers, self.ffn_layers):
#             x = encoder(x)
#             x = nn.relu(ffn(x))
#         return self.output_dense(x)


In [289]:
from transformers import BertTokenizer


In [8]:
from datasets import load_dataset

In [9]:
import optax


# My transformer

In [11]:
import jax
import jax.numpy as jnp
from flax import linen as nn

# importing required libraries
import math, copy, re
import warnings
# import pandas as pd
# import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

warnings.simplefilter("ignore")

# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# DEVICE = torch.device('cpu')

DEBUGGING = False
# DEBUGGING = True
if DEBUGGING:
    dbg_figs = ['encoder-decoder attention']
    figs = {k: plt.figure() for k in dbg_figs}
    plt.ion()
    plt.show()

dropout_rate = 0.1

In [12]:
key = jax.random.PRNGKey(0)
key

Array([0, 0], dtype=uint32)

In [61]:
from typing import Optional

## Embedding

In [271]:
class Embedding(nn.Module):
    vocab_size: int
    embed_dim: int
    embedding: Optional[jnp.ndarray]=None
    
    def setup(self):
        """
        Args:
            vocab_size: size of vocabulary
            embed_dim: dimension of embeddings
        """
        super(Embedding, self).__init__()
        if self.embedding is not None:
            # assert embedding.shape == (vocab_size, embed_dim)
            self.embed = nn.Embed(self.vocab_size, self.embed_dim, embedding_init=lambda *args: self.embedding)
            assert True
        else:
            self.embed = nn.Embed(self.vocab_size, self.embed_dim)

    def __call__(self, x):
        """
        Args:
            x: input vector
        Returns:
            out: embedding vector
        """
        out = self.embed(x)
        return out

    def hidden_states(self, x):
        """
        Args:
            x: input vector
        Returns:
            out: {'': embedding vector}
        """
        out = self.embed(x)
        return {'': out}
Embedding(10, 16), Embedding(10, 16, jax.random.ball(key, 16, shape=(10,)))

(Embedding(
     # attributes
     vocab_size = 10
     embed_dim = 16
     embedding = None
 ),
 Embedding(
     # attributes
     vocab_size = 10
     embed_dim = 16
     embedding = Array([[-2.87476331e-01,  2.78082073e-01, -1.73102710e-02,
             -3.03498030e-01,  2.67142653e-01,  1.80099145e-01,
             -1.70347854e-01, -4.68755960e-01,  1.86945155e-01,
             -5.18128425e-02, -2.80789323e-02,  2.15174913e-01,
              3.42391841e-02, -3.02901477e-01, -5.08673489e-02,
             -3.79411519e-01],
            [-2.90833622e-01,  5.64623713e-01,  5.08974195e-02,
              3.39994580e-01, -1.88382730e-01,  4.01012897e-02,
             -1.17090441e-01,  5.94152473e-02,  2.81335157e-03,
              2.45557472e-01, -4.65076149e-01, -5.62565252e-02,
              1.03506058e-01,  8.74283444e-03, -2.38579288e-02,
             -1.75395131e-01],
            [ 1.68782607e-01,  3.82230878e-02, -5.11262774e-01,
             -2.96743274e-01,  1.48061633e-01,  6.6110

In [72]:
jax.random.ball(key, 16, shape=(10,)).shape

(10, 16)

In [None]:
l = Embedding(10, 20)

x = jnp.ones((1, 9), dtype=int)
params = l.init(jax.random.key(0), x)
params

x = l.apply(params, x)
x

## PositionalEmbedding

In [21]:
import numpy as np

In [87]:
# register buffer in Pytorch ->
# If you have parameters in your model, which should be saved and restored in the state_dict,
# but not trained by the optimizer, you should register them as buffers.


class PositionalEmbedding(nn.Module):
    max_seq_len: int
    embed_model_dim: int
    def setup(self):
        """
        Args:
            seq_len: length of input sequence
            embed_model_dim: demension of embedding
        """
        # super(PositionalEmbedding, self).__init__()
        self.embed_dim = self.embed_model_dim

        pe = np.zeros((self.max_seq_len, self.embed_model_dim))
        for pos in range(self.max_seq_len):
            for i in range(0, self.embed_dim, 2):
                pe[pos, i] = math.sin(pos / (10000 ** (i / self.embed_dim)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** (i / self.embed_dim)))
        # pe = pe.unsqueeze(0)
        # self.pe = jnp.array(pe)
        self.pe = jax.device_put(pe)
        # self.register_buffer('pe', pe)

    def __call__(self, x):
        """
        Args:
            x: input vector
        Returns:
            x: output
        """

        # make embeddings relatively larger
        x = x * math.sqrt(self.embed_dim)
        # add constant to embedding
        seq_len = x.shape[1]
        x = x + self.pe[:seq_len, :]
        return x

    def hidden_states(self, x):
        # make embeddings relatively larger
        x = x * math.sqrt(self.embed_dim)
        # add constant to embedding
        seq_len = x.shape[1]
        x = x + self.pe[:seq_len, :]
        return {'': x}



In [88]:
x.shape

(1, 9)

In [93]:
l = PositionalEmbedding(10, 20)

dummy = jnp.ones_like(x)
params = l.init(jax.random.key(0), dummy)
params

{}

In [94]:
x = l.apply(params, x)
x

Array([[[ 1.2973769 ,  3.441404  ,  0.8275838 ,  1.1559722 ,
          0.97165006,  0.7105597 , -2.61547   ,  1.4329365 ,
         -1.057561  ,  3.00142   , -1.2118894 ,  0.15464747,
          0.9254025 ,  0.66376436,  0.8197205 ,  1.1931446 ,
         -0.3509831 ,  2.641285  , -1.2605853 , -0.17066157],
        [ 2.1388478 ,  2.9817064 ,  1.215258  ,  1.0777686 ,
          1.1294767 ,  0.69802654, -2.552416  ,  1.4309466 ,
         -1.0324448 ,  3.0011046 , -1.2018895 ,  0.15459746,
          0.9293836 ,  0.6637565 ,  0.8213054 ,  1.1931432 ,
         -0.35035217,  2.641285  , -1.2603341 , -0.17066163],
        [ 2.2066743 ,  2.0252573 ,  1.5422972 ,  0.85538954,
          1.2833472 ,  0.6607412 , -2.489613  ,  1.4249849 ,
         -1.0073445 ,  3.0001583 , -1.1918907 ,  0.1544475 ,
          0.93336457,  0.66373265,  0.8228903 ,  1.1931396 ,
         -0.3497212 ,  2.6412842 , -1.260083  , -0.17066169],
        [ 1.4384968 ,  1.4514116 ,  1.7575502 ,  0.5236167 ,
          1.4294046 ,

In [95]:
x.shape

(1, 9, 20)

## Attention

In [526]:
class MultiHeadAttention(nn.Module):
    embed_dim: int=512
    n_heads: int=8
    def setup(self):
        """
        Args:
            embed_dim: dimension of embeding vector output
            n_heads: number of self attention heads
        """
        self.single_head_dim = int(self.embed_dim / self.n_heads)  # 512/8 = 64  . each key,query, value will be of 64d

        # key,query and value matrixes    #64 x 64
        self.query_matrix = nn.Dense(self.single_head_dim)  # single key matrix for all 8 keys #512x512
        self.key_matrix = nn.Dense(self.single_head_dim)
        self.value_matrix = nn.Dense(self.single_head_dim)
        # self.out = nn.Dense(self.n_heads * self.single_head_dim, self.embed_dim)
        self.out = nn.Dense(self.embed_dim)

    def __call__(self, key, query, value, mask=None):  # batch_size x sequence_length x embedding_dim    # 32 x 10 x 512

        """
        Args:
           key : key vector
           query : query vector
           value : value vector
           mask: mask for decoder

        Returns:
           output vector from multihead attention
        """
        batch_size = key.shape[0]
        seq_length = key.shape[1]

        # query dimension can change in decoder during inference.
        # so we cant take general seq_length
        seq_length_query = query.shape[1]

        # 32x10x512 => (32x10x8x64)
        key = key.reshape(batch_size, seq_length, self.n_heads, self.single_head_dim)
        query = query.reshape(batch_size, seq_length_query, self.n_heads, self.single_head_dim)
        value = value.reshape(batch_size, seq_length, self.n_heads, self.single_head_dim)

        k = self.key_matrix(key)  # (32x10x8x64)
        q = self.query_matrix(query)
        v = self.value_matrix(value)

        q = q.transpose(0, 2, 1, 3)  # (batch_size, n_heads, seq_len, single_head_dim)    # (32 x 8 x 10 x 64)
        k = k.transpose(0, 2, 1, 3)  # (batch_size, n_heads, seq_len, single_head_dim)
        v = v.transpose(0, 2, 1, 3)  # (batch_size, n_heads, seq_len, single_head_dim)

        # computes attention
        # adjust key for matrix multiplication
        # 10 -> seq_len
        k_adjusted = k.transpose(0, 1, 3, 2)  # (batch_size, n_heads, single_head_dim, seq_len)  #(32 x 8 x 64 x 10)
        product = q @ k_adjusted  # (32 x 8 x 10 x 64) x (32 x 8 x 64 x 10) = #(32x8x10x10)
        # product = jnp.dot(q, k_adjusted)

        # fill those positions of product matrix as (-1e20) where mask positions are 0
        if mask is not None:
            # product = product.masked_fill(mask == 0, float("-1e20"))
            product = jnp.where(mask == 0, float("-1e20"), product)

        # divising by square root of key dimension
        product = product / math.sqrt(self.single_head_dim)  # / sqrt(64)

        # applying softmax
        scores = nn.softmax(product, axis=-1)
        if DEBUGGING:
            if scores.shape[-1] != scores.shape[-2]:
                fig = figs['encoder-decoder attention']
                if not fig.axes:
                    fig.subplots(1, 1)
                ax = fig.axes[0]
                ax.imshow(scores[0, 0].detach())
                plt.pause(0.001)
                plt.ion()

        # mutiply with value matrix
        # scores = torch.matmul(scores, v)  ##(32x8x 10x 10) x (32 x 8 x 10 x 64) = (32 x 8 x 10 x 64)
        scores = scores @ v

        # concatenated output
        concat = scores.transpose(0, 2, 1, 3).reshape(
            batch_size, seq_length_query, self.single_head_dim * self.n_heads)  # (32x8x10x64) -> (32x10x8x64)  -> (32,10,512)

        output = self.out(concat)  # (32,10,512) -> (32,10,512)

        return output

    def hidden_states(self, key, query, value,
                      mask=None):  # batch_size x sequence_length x embedding_dim    # 32 x 10 x 512

        """
        Args:
           key : key vector
           query : query vector
           value : value vector
           mask: mask for decoder
        """
        batch_size = key.shape[0]
        seq_length = key.shape[1]

        # query dimension can change in decoder during inference.
        # so we cant take general seq_length
        seq_length_query = query.shape[1]


        # 32x10x512 => (32x10x8x64)
        key = key.reshape(batch_size, seq_length, self.n_heads, self.single_head_dim)
        query = query.reshape(batch_size, seq_length_query, self.n_heads, self.single_head_dim)
        value = value.reshape(batch_size, seq_length, self.n_heads, self.single_head_dim)


        k = self.key_matrix(key)  # (32x10x8x64)
        q = self.query_matrix(query)
        v = self.value_matrix(value)

        states = {'key': key, 'query': query, 'value': value,
                    'k': k,       'q': q,         'v': v}

        q = q.transpose(0, 2, 1, 3)  # (batch_size, n_heads, seq_len, single_head_dim)    # (32 x 8 x 10 x 64)
        k = k.transpose(0, 2, 1, 3)  # (batch_size, n_heads, seq_len, single_head_dim)
        v = v.transpose(0, 2, 1, 3)  # (batch_size, n_heads, seq_len, single_head_dim)

        # computes attention
        # adjust key for matrix multiplication
        k_adjusted = k.transpose(0, 1, 3, 2)  # (batch_size, n_heads, single_head_dim, seq_len)  #(32 x 8 x 64 x 10)
        product = q @ k_adjusted  # (32 x 8 x 10 x 64) x (32 x 8 x 64 x 10) = #(32x8x10x10)
        # product = jnp.dot(q, k_adjusted)

        # fill those positions of product matrix as (-1e20) where mask positions are 0
        if mask is not None:
            # product = product.masked_fill(mask == 0, float("-1e20"))
            product = jnp.where(mask == 0, float("-1e20"), product)

        # divising by square root of key dimension
        product = product / math.sqrt(self.single_head_dim)  # / sqrt(64)

        states['attention'] = product

        # applying softmax
        scores = nn.softmax(product, axis=-1)
        if DEBUGGING:
            if scores.shape[-1] != scores.shape[-2]:
                fig = figs['encoder-decoder attention']
                if not fig.axes:
                    fig.subplots(1, 1)
                ax = fig.axes[0]
                ax.imshow(scores[0, 0].detach())
                plt.pause(0.001)
                plt.ion()

        states['attention_softmax'] = scores

        # mutiply with value matrix
        # scores = torch.matmul(scores, v)  ##(32x8x 10x 10) x (32 x 8 x 10 x 64) = (32 x 8 x 10 x 64)
        scores = scores @ v

        states['v_extracted'] = scores

        # concatenated output
        concat = scores.transpose(0, 2, 1, 3).reshape(
            batch_size, seq_length_query, self.single_head_dim * self.n_heads)  # (32x8x10x64) -> (32x10x8x64)  -> (32,10,512)

        output = self.out(concat)  # (32,10,512) -> (32,10,512)
        states['out'] = output

        states[''] = output
        return states

MultiHeadAttention()

MultiHeadAttention(
    # attributes
    embed_dim = 512
    n_heads = 8
)

In [527]:
l = MultiHeadAttention(20, 5)
dummy = jnp.ones_like(x)
params = l.init(jax.random.PRNGKey(0), dummy, dummy, dummy)
params

TypeError: reshape total size must be unchanged, got new_sizes (1, 9, 5, 4) for shape (1, 9).

In [389]:
x = l.apply(params, x, x, x)
x

TypeError: reshape total size must be unchanged, got new_sizes (1, 9, 5, 4) for shape (1, 9).

In [390]:
x = l.apply(params, x, x, x)
x = l.apply(params, x, x, x)
x

TypeError: reshape total size must be unchanged, got new_sizes (1, 9, 5, 4) for shape (1, 9).

## TransformerBlock

In [391]:
class TransformerBlock(nn.Module):
    embed_dim: int
    expansion_factor: int=4
    n_heads: int=8
    def setup(self):
        super(TransformerBlock, self).__init__()

        """
        Args:
           embed_dim: dimension of the embedding
           expansion_factor: fator ehich determines output dimension of linear layer
           n_heads: number of attention heads

        """
        self.attention = MultiHeadAttention(self.embed_dim, self.n_heads)

        self.norm1 = nn.LayerNorm()
        self.norm2 = nn.LayerNorm()

        self.feed_forward = nn.Sequential([
            nn.Dense(self.expansion_factor * self.embed_dim),
            nn.relu,
            nn.sigmoid,
            nn.Dense(self.embed_dim)
        ])

        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

    def __call__(self, key, query, value, train=True):
        """
        Args:
           key: key vector
           query: query vector
           value: value vector
           norm2_out: output of transformer block

        """
        attention_out = self.attention(key, query, value)  # 32x10x512
        attention_residual_out = attention_out + query  # 32x10x512
        norm1_out = self.dropout1(self.norm1(attention_residual_out), deterministic=not train)  # 32x10x512

        feed_fwd_out = self.feed_forward(norm1_out)  # 32x10x512 -> #32x10x2048 -> 32x10x512
        feed_fwd_residual_out = feed_fwd_out + norm1_out  # 32x10x512
        norm2_out = self.dropout2(self.norm2(feed_fwd_residual_out), deterministic=not train)  # 32x10x512

        return norm2_out

    def hidden_states(self, key, query, value):
        """
        Args:
           key: key vector
           query: query vector
           value: value vector
           norm2_out: output of transformer block

        """

        states = {}
        extend_states(states, 'attention',
                      self.attention.hidden_states(key, query, value)  # 32x10x512
                      )
        attention_residual_out = states['attention'] + query  # 32x10x512
        states['attention_residual'] = attention_residual_out
        states['norm1'] = self.norm1(attention_residual_out)  # 32x10x512
        states['dropout1'] = self.dropout1(states['norm1'], deterministic=False)

        states['feed_fwd'] = self.feed_forward(states['dropout1'])  # 32x10x512 -> #32x10x2048 -> 32x10x512
        states['feed_fwd_residual'] = states['feed_fwd'] + states['norm1']  # 32x10x512
        states['norm2'] = self.norm2(states['feed_fwd_residual'])
        states['dropout2'] = self.dropout2(states['norm2'], deterministic=False)

        states[''] = states['dropout2']
        return states

In [392]:
root_key = jax.random.key(seed=0)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)


In [393]:
l = TransformerBlock(20, 4, 5)
dummy = jnp.ones_like(x)
params = l.init({'params': params_key, 'dropout': dropout_key}, dummy, dummy, dummy)
params

TypeError: reshape total size must be unchanged, got new_sizes (1, 9, 5, 4) for shape (1, 9).

In [394]:
x = l.apply(params, x, x, x, rngs={'dropout': dropout_key})

TypeError: reshape total size must be unchanged, got new_sizes (1, 9, 5, 4) for shape (1, 9).

In [395]:
x

Array([[1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)

## TransformerEncoder

In [396]:
class TransformerEncoder(nn.Module):
    """
    Args:
        seq_len : length of input sequence
        embed_dim: dimension of embedding
        num_layers: number of encoder layers
        expansion_factor: factor which determines number of linear layers in feed forward layer
        n_heads: number of heads in multihead attention

    Returns:
        out: output of the encoder
    """
    seq_len: int
    vocab_size: int
    embed_dim: int
    num_layers: int=2
    expansion_factor: int=4
    n_heads: int=8
    word_embedding: Optional[jnp.ndarray]=None

    def setup(self):
        self.embedding_layer = Embedding(self.vocab_size, self.embed_dim, self.word_embedding)
        self.positional_encoder = PositionalEmbedding(self.seq_len, self.embed_dim)

        self.layers = [
            TransformerBlock(
                self.embed_dim,
                self.expansion_factor,
                self.n_heads
            ) for i in range(self.num_layers)
        ]

    def __call__(self, x, train=True):
        embed_out = self.embedding_layer(x)
        out = self.positional_encoder(embed_out)
        for layer in self.layers:
            out = layer(out, out, out, train)

        return out  # 32x10x512

    def hidden_states(self, x, train=True):
        states = {}
        extend_states(states, 'embedding_layer',
                      self.embedding_layer.hidden_states(x)
                      )
        extend_states(states, 'positional_encoder',
                      self.positional_encoder.hidden_states(states['embedding_layer'])
                      )
        out = states['positional_encoder']
        for i, layer in enumerate(self.layers):
            out = extend_states(states, f'layer.{i}',
                                layer.hidden_states(out, out, out, train))

        states[''] = out
        return states

In [397]:
l = TransformerEncoder(9, 22, 20, 2, 2, 5)

x = jnp.ones((1, 9), dtype=int)
params = l.init({'params': params_key, 'dropout': dropout_key}, x)
params

(1, 5, 9, 4) (1, 5, 4, 9) (1, 5, 9, 9)
(1, 5, 9, 4) (1, 5, 4, 9) (1, 5, 9, 9)


{'params': {'embedding_layer': {'embed': {'embedding': Array([[ 3.40739101e-01,  1.37797877e-01,  2.72358924e-01,
            -1.14617124e-01,  5.27124941e-01,  1.30033538e-01,
             2.35915124e-01, -2.99763143e-01, -1.28704578e-01,
             1.06023505e-01,  1.73475847e-01,  2.19109803e-01,
             4.53408778e-01,  1.80053078e-02, -1.88508064e-01,
            -5.44008791e-01,  3.03656548e-01, -1.54989958e-01,
             2.99121350e-01,  1.39290705e-01],
           [-1.63057372e-02, -1.39550030e-01, -3.03789154e-02,
            -1.18934819e-02, -3.25240463e-01, -3.75061750e-01,
             6.32392382e-03, -5.26759624e-02, -1.90114692e-01,
             3.41678351e-01, -1.21545315e-01,  2.46547729e-01,
             2.98387161e-03,  1.77394986e-01, -1.34345040e-01,
            -1.57021388e-01, -6.17517941e-02, -7.06213936e-02,
             1.16219029e-01,  9.09966975e-02],
           [ 6.60828725e-02,  1.82694376e-01,  3.01550746e-01,
             1.18764430e-01, -3.3099

In [398]:
x = l.apply(params, x, rngs={'dropout': dropout_key})
x

(1, 5, 9, 4) (1, 5, 4, 9) (1, 5, 9, 9)
(1, 5, 9, 4) (1, 5, 4, 9) (1, 5, 9, 9)


Array([[[-0.11565194, -1.6598902 , -1.4172521 ,  0.53541154,
         -1.0068995 ,  0.66421616,  0.7811579 ,  1.6725011 ,
         -0.93520933,  1.0702497 , -1.356227  ,  0.        ,
          0.4493922 ,  0.        , -1.9744937 , -0.6584746 ,
          1.0557652 ,  1.5334997 , -0.3267221 ,  0.09394453],
        [ 0.6572015 , -0.7420527 , -1.1805828 ,  0.7322328 ,
         -2.194664  ,  0.        , -0.3185876 ,  2.2280846 ,
         -1.0266783 ,  0.9867189 , -2.0689952 ,  0.09805519,
          0.        , -0.3869323 ,  0.2505166 , -0.3874002 ,
          1.1780714 ,  1.7159901 , -0.5626618 ,  0.58282185],
        [-0.11513499, -0.74673676, -0.7222719 ,  0.62054735,
         -2.057434  ,  0.96022016, -0.7308702 ,  0.5823666 ,
         -0.9596438 ,  1.039979  , -1.7560956 ,  0.14752632,
          0.        ,  1.9350797 , -1.530718  , -0.44935864,
          1.4612373 ,  1.7188765 , -0.29112458,  0.11971954],
        [ 0.2590491 , -2.47408   , -0.6671802 ,  0.1780893 ,
          0.01852507,

## DecoderBlock

In [399]:
class DecoderBlock(nn.Module):
    embed_dim: int
    expansion_factor: int=4
    n_heads: int=8
    residue_link: bool=False
    def setup(self):
        """
        Args:
           embed_dim: dimension of the embedding
           expansion_factor: fator ehich determines output dimension of linear layer
           n_heads: number of attention heads

        """
        self.attention = MultiHeadAttention(self.embed_dim, n_heads=self.n_heads)
        self.norm = nn.LayerNorm(self.embed_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.transformer_block = TransformerBlock(self.embed_dim, self.expansion_factor, self.n_heads)

    def __call__(self, key, x, value, mask, train=True):
        """
        Args:
           key: key vector
           query: query vector
           value: value vector
           mask: mask to be given for multi head attention
        Returns:
           out: output of transformer block

        """

        # we need to pass mask mask only to fst attention
        attention = self.attention(x, x, x, mask=mask)  # 32x10x512
        if self.residue_link:
            attention = attention + x
        query = self.dropout(self.norm(attention), deterministic=not train)

        out = self.transformer_block(key, query, value)

        return out

    def hidden_states(self, key, x, value, mask, train=True):
        """
        Args:
           key: key vector
           query: query vector
           value: value vector
           mask: mask to be given for multi head attention
        Returns:
           out: output of transformer block

        """

        states = {}
        # we need to pass mask mask only to fst attention
        attention = extend_states(states, 'attention',
                                  self.attention.hidden_states(x, x, x, mask=mask))  # 32x10x512
        if self.residue_link:
            attention = states['with_residue'] = attention + x
        norm = states['norm'] = self.norm(attention)
        query = states['dropout'] = self.dropout(norm, deterministic=not train)

        out = extend_states(states, 'transformer_block',
                            self.transformer_block.hidden_states(key, query, value))

        states[''] = out
        return states

In [400]:
x.shape

(1, 9, 20)

In [401]:
l = DecoderBlock(20, 4, 5)
dummy = jnp.ones_like(x)
mask = jnp.ones([x.shape[1], x.shape[1]])
params = l.init({'params': params_key, 'dropout': dropout_key}, dummy, dummy, dummy, mask)
params

(1, 5, 9, 4) (1, 5, 4, 9) (1, 5, 9, 9)
(1, 5, 9, 4) (1, 5, 4, 9) (1, 5, 9, 9)


{'params': {'attention': {'key_matrix': {'kernel': Array([[-0.18677855,  0.18058035,  0.6181211 , -0.2998894 ],
           [-0.32722148, -0.94128335, -0.70250386,  0.37284884],
           [ 0.35065067, -0.1755944 ,  0.81841916, -0.16105253],
           [-0.28495827, -0.42211068,  0.31499204, -0.7232241 ]],      dtype=float32),
    'bias': Array([0., 0., 0., 0.], dtype=float32)},
   'query_matrix': {'kernel': Array([[ 0.4894005 , -0.3451201 , -0.45575768,  0.01895893],
           [ 0.50087863,  0.23168735, -0.00885719, -0.45979372],
           [-0.75233257, -0.8167948 , -0.6542018 , -0.6405037 ],
           [-0.27382806, -0.5299147 ,  0.38899225,  0.30272093]],      dtype=float32),
    'bias': Array([0., 0., 0., 0.], dtype=float32)},
   'value_matrix': {'kernel': Array([[ 0.23126236,  0.4436052 , -0.29465643, -0.34886923],
           [-0.37178776,  0.68388206,  0.0160487 , -0.24249408],
           [-0.77710074, -0.94769955, -0.39375776, -0.26743925],
           [ 0.44620758,  0.47507367

In [402]:
x = l.apply(params, x, x, x, mask, rngs={'dropout': dropout_key})
x

(1, 5, 9, 4) (1, 5, 4, 9) (1, 5, 9, 9)
(1, 5, 9, 4) (1, 5, 4, 9) (1, 5, 9, 9)


Array([[[ 2.46022567e-01,  1.46098563e-03, -1.42914581e+00,
         -5.38634777e-01, -7.06733167e-01, -1.53858244e+00,
          5.59874594e-01,  1.90151536e+00,  0.00000000e+00,
          1.51196682e+00,  4.26133960e-01,  1.28290981e-01,
          1.34291482e+00,  0.00000000e+00,  1.43035233e+00,
         -1.21255827e+00,  0.00000000e+00,  1.51028931e+00,
         -7.74921000e-01, -1.19906414e+00],
        [ 3.21431607e-01,  0.00000000e+00, -1.22170544e+00,
         -7.37037286e-02, -7.45696127e-01, -1.35879421e+00,
          3.31473768e-01,  2.03196049e+00, -9.17375863e-01,
          1.49787772e+00,  5.33217847e-01, -1.88892829e+00,
          1.74825382e+00, -9.46330547e-01,  1.51540029e+00,
         -1.13508582e+00,  1.82265528e-02, -1.23175994e-01,
          1.06789815e+00, -8.77075434e-01],
        [ 4.15128648e-01,  3.41603965e-01, -1.14918101e+00,
         -2.13871792e-01, -3.33832651e-01, -1.30437350e+00,
          0.00000000e+00,  1.09846044e+00, -1.00399053e+00,
         -8.

## TransformerDecoder

In [403]:
class TransformerDecoder(nn.Module):
    vocab_size: int
    embed_dim: int
    seq_len: int
    num_layers: int = 2
    expansion_factor: int = 4
    n_heads: int = 8
    word_embedding: Optional[jnp.ndarray]=None
    residue_links: bool=False
    def setup(self):
        """  
        Args:
           target_vocab_size: vocabulary size of taget
           embed_dim: dimension of embedding
           seq_len : length of input sequence
           num_layers: number of encoder layers
           expansion_factor: factor which determines number of linear layers in feed forward layer
           n_heads: number of heads in multihead attention

        """
        self.embedding_layer = Embedding(self.vocab_size, self.embed_dim, self.word_embedding)
        # self.word_embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = PositionalEmbedding(self.seq_len, self.embed_dim)
        self.dropout = nn.Dropout(dropout_rate)

        self.layers = [
            DecoderBlock(self.embed_dim, expansion_factor=self.expansion_factor,
                         n_heads=self.n_heads, residue_link=self.residue_links)
            for _ in range(self.num_layers)
        ]
        self.fc_out = nn.Dense(self.vocab_size)

    def __call__(self, x, enc_out, mask, train=True):
        """
        Args:
            x: input vector from target
            enc_out : output from encoder layer
            trg_mask: mask for decoder self attention
        Returns:
            out: output vector
        """
        x = self.embedding_layer(x)  # 32x10x512
        # x = self.word_embedding(x)  # 32x10x512
        x = self.position_embedding(x)  # 32x10x512
        x = self.dropout(x, deterministic=not train)

        for layer in self.layers:
            x = layer(enc_out, x, enc_out, mask, train)

        out = nn.softmax(self.fc_out(x), axis=-1)

        return out

    def hidden_states(self, x, enc_out, mask, train=True):
        states = {}
        x = extend_states(states, 'embedding_layer',
                          self.embedding_layer.hidden_states(x))  # 32x10x512

        # x = self.word_embedding(x)  # 32x10x512
        x = states['position_embedding'] = self.position_embedding(x)  # 32x10x512
        x = states['dropout'] = self.dropout(x, deterministic=not train)

        for i, layer in enumerate(self.layers):
            x = extend_states(states, f'layer.{i}', layer.hidden_states(enc_out, x, enc_out, mask))

        x = states['fc_out'] = self.fc_out(x)
        out = states['softmax'] = nn.softmax(x, axis=-1)

        states[''] = out
        return states

## Transformer

In [528]:
class Transformer(nn.Module):
    embed_dim: int
    src_vocab_size: int
    target_vocab_size: int
    source_seq_length: int
    target_seq_length: int
    num_layers: int=2
    expansion_factor: int=4
    n_heads: int=8
    target_mask_fn: Optional[callable]=None
    source_embedding: Optional[jnp.ndarray]=None
    target_embedding: Optional[jnp.ndarray]=None
    decoder_residue_links: bool=True
    def __post_init__(self):
        super().__post_init__()
        if self.target_mask_fn is None:
            self.target_mask_fn = self.make_trg_mask
    def setup(self):
        """  
        Args:
           embed_dim:  dimension of embedding 
           src_vocab_size: vocabulary size of source
           target_vocab_size: vocabulary size of target
           seq_length : length of input sequence
           num_layers: number of encoder layers
           expansion_factor: factor which determines number of linear layers in feed forward layer
           n_heads: number of heads in multihead attention

        """

        self.encoder = TransformerEncoder(self.source_seq_length, self.src_vocab_size, self.embed_dim, num_layers=self.num_layers,
                                          expansion_factor=self.expansion_factor, n_heads=self.n_heads,
                                          word_embedding=self.source_embedding)
        self.decoder = TransformerDecoder(self.target_vocab_size, self.embed_dim, self.target_seq_length, num_layers=self.num_layers,
                                          expansion_factor=self.expansion_factor, n_heads=self.n_heads,
                                          word_embedding=self.target_embedding, residue_links=self.decoder_residue_links)



    def encode(self, src):
        return self.encoder(src)

    def decode(self, trg, mem):
        mask = self.make_trg_mask(trg)
        return self.decoder(trg, mem, mask)

    @staticmethod
    def make_trg_mask(trg):
        """
        Args:
            trg: target sequence
        Returns:
            trg_mask: target mask
        """
        batch_size, trg_len = trg.shape
        # returns the lower triangular part of matrix filled with ones
        trg_mask = jnp.tril(jnp.ones((trg_len, trg_len))).reshape(
            1, 1, trg_len, trg_len
        )
        # returns the upper triangular part of matrix filled with ones
        # trg_mask = torch.triu(torch.ones((trg_len, trg_len))).expand(
        #     1, 1, trg_len, trg_len
        # )
        return trg_mask

    def __call__(self, src, trg, train=True):
        """
        Args:
            src: input to encoder
            trg: input to decoder
        out:
            out: final vector which returns probabilities of each target word
        """
        batch_size, trg_len = trg.shape
        trg_mask = self.target_mask_fn(trg).reshape(1, 1, trg_len, trg_len)
        enc_out = self.encoder(src, train)
        outputs = self.decoder(trg, enc_out, trg_mask, train)
        return outputs

    def hidden_states(self, src, trg, train=True):
        batch_size, trg_len = trg.shape
        trg_mask = self.target_mask_fn(trg).expand(1, 1, trg_len, trg_len)
        states = {}
        extend_states(states, 'encoder', self.encoder.hidden_states(src, train))

        extend_states(states, 'decoder',
                      self.decoder.hidden_states(trg, states['encoder'], trg_mask, train))
        states[''] = states['decoder']
        return states

In [529]:
transformer = Transformer(embed_dim=20, src_vocab_size=10,
                          target_vocab_size=10, source_seq_length=9, target_seq_length=9,
                          num_layers=2, expansion_factor=4, n_heads=5,
                          # target_mask_fn=target_mask_fn,
                          # target_embedding=torch.tensor(sampled_embedding, dtype=torch.float),
                          # decoder_residue_links=True
                         )

In [530]:
x = jnp.ones((1, 9), dtype=int)
params = transformer.init({'params': params_key, 'dropout': dropout_key}, x, x)
params

{'params': {'encoder': {'embedding_layer': {'embed': {'embedding': Array([[-0.35061383, -0.03230591,  0.10626552,  0.42444903,  0.17031135,
             -0.02585911,  0.18707182,  0.29888493,  0.12937152, -0.1375768 ,
             -0.2042156 , -0.24436855, -0.18710352,  0.34488687,  0.14218727,
             -0.11418119,  0.04331478, -0.40311542, -0.16473353,  0.1425994 ],
            [-0.2837518 ,  0.07933215, -0.03812821, -0.13036345,  0.17909051,
             -0.00351502, -0.14935017,  0.22353616,  0.33425716,  0.16363573,
              0.20019042,  0.00570672,  0.18387173,  0.4413731 ,  0.10931236,
             -0.09923771, -0.21314052, -0.26975507, -0.19149259,  0.11731356],
            [ 0.1479116 , -0.19780372, -0.05210437, -0.15691863, -0.05059471,
              0.25138798,  0.09263865, -0.41243762,  0.03876812, -0.35490045,
              0.07417419, -0.20628579, -0.02170018, -0.08809549,  0.26343006,
              0.16287436, -0.05236056,  0.21390586,  0.10913108, -0.37126184],

In [531]:
def extend_states(states, mod_name, mod_states):
    for k, state in mod_states.items():
        if k:
            states[f'{mod_name}.{k}'] = state
        else:
            states[f'{mod_name}'] = state
    return mod_states['']

In [532]:
def prepare_data():
    dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')['train']
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    
    def tokenize_function(examples):
        return tokenizer(examples['text'], padding="max_length", truncation=True)

    tokenized_datasets = dataset.map(tokenize_function, batched=True)
    tokenized_datasets = tokenized_datasets.shuffle().batch(32)
    return tokenized_datasets

In [533]:
import pandas as pd
from datasets import Dataset
def prepare_data2():
    df = pd.read_csv('medium_articles/eng_de.csv', header=None, names=['en', 'de'])
    dataset = Dataset.from_pandas(df)
    # return dataset
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    
    def tokenize_function(examples):
        ret = {
            'input_ids_en': tokenizer(examples['en'], padding='longest').input_ids,
            'input_ids_de': tokenizer(examples['de'], padding='longest').input_ids,
        }
        return ret

    tokenized_datasets = dataset.map(tokenize_function, batched=True)
    # tokenized_datasets = tokenized_datasets.map(tokenize_function_tgt, batched=True)
    tokenized_datasets = tokenized_datasets.shuffle().batch(2)
    tokenized_datasets = tokenized_datasets.remove_columns(['en', 'de'])
    tokenized_datasets = tokenized_datasets.with_format("jax")

    return tokenized_datasets
prepare_data2()

Map:   0%|          | 0/36 [00:00<?, ? examples/s]

Batching examples:   0%|          | 0/36 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids_en', 'input_ids_de'],
    num_rows: 18
})

In [534]:
TGT_LANGUAGE = 'de'
SRC_LANGUAGE = 'en'
toydata = pd.read_csv('medium_articles/eng_de.csv', header=None, names=['en', 'de'])

In [535]:
tokenizers = {
    SRC_LANGUAGE: BertTokenizer.from_pretrained('bert-base-uncased'),
    TGT_LANGUAGE: BertTokenizer.from_pretrained('bert-base-uncased'),
}

In [536]:
SRC_VOCAB_SIZE = len(tokenizers[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(tokenizers[TGT_LANGUAGE])
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3


In [537]:
SRC_VOCAB_SIZE, TGT_VOCAB_SIZE

(30522, 30522)

In [544]:
def cross_entropy_loss(logits, labels):
    one_hot_labels = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
    loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels)
    return loss.mean()
    
def create_train_state(rng, model, learning_rate):
    main_key, params_key, dropout_key = jax.random.split(key=rng, num=3)
    params = model.init({'params': params_key, 'dropout': dropout_key}, jnp.ones((1, 128), dtype=jnp.int32), jnp.ones((1, 128), dtype=jnp.int32))['params']
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

In [545]:
@jax.jit
def train_step(state, batch):
    print(batch)
    def loss_fn(params):
        tgt_input = batch['input_ids_de'][:, :-1]
        tgt_output = batch['input_ids_de'][:, 1:]
        print(tgt_input, tgt_output)
        logits = state.apply_fn({'params': params}, batch['input_ids_en'], tgt_input, rngs={'dropout': dropout_key})
        loss = cross_entropy_loss(logits, tgt_output)
        return loss

    grads = jax.grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state


In [546]:
rng = jax.random.PRNGKey(0)
# model = Transformer(num_layers=6, num_heads=8, d_model=512, d_ff=2048, vocab_size=30522, max_len=512)
model = Transformer(embed_dim=512, src_vocab_size=30522,
                    target_vocab_size=30522, source_seq_length=512, target_seq_length=512,
                    num_layers=6, expansion_factor=4, n_heads=8,
                    # target_mask_fn=target_mask_fn,
                    # target_embedding=torch.tensor(sampled_embedding, dtype=torch.float),
                    # decoder_residue_links=True
                   )
state = create_train_state(rng, model, learning_rate)

In [550]:
def train_model(state, num_epochs, learning_rate):
    
    ds = prepare_data2()

    train_ds = ds[:20]
    val_ds = ds[20:]

    for epoch in range(num_epochs):
        for batch in train_ds:
            print(batch)
            state = train_step(state, batch)
        evaluate_model(state, val_ds)
        print(f"Epoch {epoch + 1} completed")


In [551]:
@jax.jit
def evaluate(model, params, batch):
    logits = model.apply({'params': params}, batch['input_ids_en'])
    predictions = jnp.argmax(logits, axis=-1)
    return predictions

def evaluate_model(state, val_ds):
    for batch in val_ds:
        predictions = evaluate(state.apply_fn, state.params, batch)
        # Compute accuracy or other metrics


In [549]:
num_epochs = 3
learning_rate = 1e-4
trained_state = train_model(state, num_epochs, learning_rate)


Map:   0%|          | 0/36 [00:00<?, ? examples/s]

Batching examples:   0%|          | 0/36 [00:00<?, ? examples/s]

{'input_ids_en': Array([[ 101, 2057, 3191, 2338,  102,    0,    0],
       [ 101, 2057, 2215, 2000, 3191, 3780,  102]], dtype=int32), 'input_ids_de': Array([[  101, 15536,  2099,  4649,  2368, 20934,  2818,   102,     0,
            0,     0,     0],
       [  101, 15536,  2099,  9587, 10143,  2368, 27838, 28813,  4649,
         2368,   102,     0]], dtype=int32)}
{'input_ids_de': Traced<ShapedArray(int32[2,12])>with<DynamicJaxprTrace(level=1/0)>, 'input_ids_en': Traced<ShapedArray(int32[2,7])>with<DynamicJaxprTrace(level=1/0)>}
Traced<ShapedArray(int32[2,11])>with<DynamicJaxprTrace(level=1/0)> Traced<ShapedArray(int32[2,11])>with<DynamicJaxprTrace(level=1/0)>






{'input_ids_en': Array([[ 101, 2057, 2064, 3191, 2338,  102,    0],
       [ 101, 2057, 4392, 5404,  102,    0,    0]], dtype=int32), 'input_ids_de': Array([[  101, 15536,  2099, 12849, 10087,  2078, 20934,  2818,  4649,
         2368,   102,     0],
       [  101, 15536,  2099, 13012,  8950,  2368, 12170,  2121,   102,
            0,     0,     0]], dtype=int32)}
{'input_ids_en': Array([[ 101, 2057, 2064, 4521, 7852,  102,    0],
       [ 101, 1045, 2215, 2000, 3191, 3780,  102]], dtype=int32), 'input_ids_de': Array([[  101, 15536,  2099, 12849, 10087,  2078, 22953,  2102, 29032,
          102,     0,     0],
       [  101, 22564,  9587, 10143,  2368, 27838, 28813,  4649,  2368,
          102,     0,     0]], dtype=int32)}
{'input_ids_en': Array([[ 101, 2057, 4392, 2300,  102,    0,    0],
       [ 101, 1045, 3191, 3780,  102,    0,    0]], dtype=int32), 'input_ids_de': Array([[  101, 15536,  2099, 13012,  8950,  2368,  2001,  8043,   102,
            0,     0,     0],
       [  101, 