# Transformer: BERT Style Encoder

## Functional API

### Input Embedding

In [1]:
import tensorflow as tf


def wpe(length: int, depth: int) -> tf.Tensor:
    """Word Positional Encoding.

    This function is used to encode the position of each token in the sequence,
    which is then added to the embedding of each token.

    Args:
        length: Sequence length.
        depth: Embedding depth.

    Returns:
        Positional encoding tensor of shape (length, depth).
    """
    depth = depth / 2

    # (seq, 1)
    positions = tf.range(length, dtype=tf.float32)[:, tf.newaxis]
    # (1, depth)
    depths = tf.range(depth, dtype=tf.float32)[tf.newaxis, :] / depth

    angle_rates = 1 / (10000**depths)  # (1, depth)
    angle_rads = positions * angle_rates  # (pos, depth)

    return tf.concat([tf.sin(angle_rads), tf.cos(angle_rads)], axis=-1)


def wte(x: tf.Tensor, vocab_size: int, depth: int) -> tf.Tensor:
    """Word Token Embedding.

    This function is used to encode each token in the sequence into a vector.

    Args:
        x: Token ID sequence tensor of shape (batch, seq).
        vocab_size: Vocabulary size.
        depth: Embedding depth.

    Returns:
        tf.Tensor: Embedding tensor of shape (batch, seq, depth).
    """
    return tf.keras.layers.Embedding(vocab_size, depth)(x)


def embedding(x: tf.Tensor, vocab_size: int, depth: int) -> tf.Tensor:
    """Input Embedding.

    This function is used to encode each token in the sequence into a vector,
    and then add the positional encoding to each token.

    Args:
        x: Token ID sequence tensor of shape (batch, seq).
        vocab_size: Vocabulary size.
        depth: Embedding depth.

    Returns:
        tf.Tensor: Embedding tensor of shape (batch, seq, depth).
    """
    seq_len = tf.shape(x)[1]
    pos_enc = wpe(seq_len, depth)
    x = wte(x, vocab_size, depth)
    return x + pos_enc

### Layers

In [2]:
def msa(x: tf.Tensor, n_head: int, d_mha: int, drop_rate: float) -> tf.Tensor:
    """Multi-head Self Attention.

    Args:
        x: input tensor of shape (batch_size, len_q, d_model)
        n_head: number of heads
        d_mha: dimension of multi-head attention
        drop_rate: dropout rate

    Returns:
        output tensor of shape (batch_size, len_q, d_model)
    """
    x_org = tf.identity(x)
    x = tf.keras.layers.LayerNormalization()(x)
    x = tf.keras.layers.MultiHeadAttention(
        num_heads=n_head,
        key_dim=d_mha,
        dropout=drop_rate,
    )(query=x, key=x, value=x, return_attention_scores=False)
    return x + x_org

def ff(x: tf.Tensor, d_ff: int, d_model: int, drop_rate: float) -> tf.Tensor:
    """Feed Forward.

    Args:
        x: input tensor of shape (batch_size, len_q, d_model)
        d_ff: dimension of feed forward
        d_model: embedding size
        drop_rate: dropout rate

    Returns:
        output tensor of shape (batch_size, len_q, d_model)
    """
    x_org = tf.identity(x)
    x = tf.keras.layers.LayerNormalization()(x)
    x = tf.keras.layers.Dense(units=d_ff, activation="relu")(x)
    x = tf.keras.layers.Dense(units=d_model)(x)
    x = tf.keras.layers.Dropout(rate=drop_rate)(x)
    return x + x_org

### Encoder

In [3]:
def encoder(  # noqa: PLR0913
    x: tf.Tensor,
    n_head: int,
    d_mha: int,
    d_ff: int,
    d_model: int,
    drop_rate: float,
) -> tf.Tensor:
    """Encoder of Transformer.

    Args:
        x: Input tensor.
        n_head: Number of heads.
        d_mha: Dimension of multi-head attention.
        d_ff: Dimension of feed-forward layer.
        d_model: Dimension of embedding.
        drop_rate: Dropout rate.

    Returns:
        tf.Tensor: Output tensor.
    """
    x = ff(x, d_ff, d_model, drop_rate)
    return msa(x, n_head, d_mha, drop_rate)

### TX Model

In [4]:
def tx_func(  # noqa: PLR0913
    x: tf.Tensor,
    vocab_size: int,
    n_layer: int,
    n_head: int,
    d_model: int,
    d_mha: int,
    d_ff: int,
    d_label: int,
    drop_rate: float,
) -> tf.Tensor:
    """Transformer Model.

    Args:
        x: Input tensor (English / translated text) of shape (B, LEN_X).
        vocab_size: Vocabulary size.
        n_layer: Number of layers.
        n_head: Number of heads in multi-head attention.
        d_model: Model dimension.
        d_mha: Multi-head attention dimension.
        d_ff: Feed-forward dimension.
        d_label: Label dimension (vocab_size) for output layer.
        drop_rate: Dropout rate.

    Returns:
        tf.Tensor: translated (EN) text of shape (B, LEN_X, d_label).
    """
    x = embedding(x, vocab_size, d_model)
    x = tf.keras.layers.Dropout(drop_rate)(x)

    for _ in range(n_layer):
        x = encoder(x, n_head, d_mha, d_ff, d_model, drop_rate)

    return tf.keras.layers.Dense(d_label, activation="softmax")(x)

In [5]:
BATCH = 4
VOCAB = 32000
N_LAYER = 6
N_HEAD = 8
D_MODEL = 512
D_MHA = D_MODEL // N_HEAD
D_FF = D_MODEL * 4
D_LABEL = 30522
DROP_RATE = 0.1

In [6]:
def get_tx_encoder() -> tf.keras.Model:
    """Get standard Encoder TX transformer model.

    Returns:
        tf.keras.Model: the transformer model
    """
    inputs = tf.keras.Input(shape=(None, ), dtype=tf.int32, name="x")
    out = tx_func(
        inputs,
        VOCAB,
        N_LAYER,
        N_HEAD,
        D_MODEL,
        D_MHA,
        D_FF,
        D_LABEL,
        DROP_RATE,
    )
    return tf.keras.Model(inputs=inputs, outputs=out)

### Test

In [7]:
x = tf.random.uniform(
    shape=(BATCH, 10),
    minval=0,
    maxval=VOCAB,
    dtype=tf.int32,
)

tx = get_tx_encoder()
prd = tx(x, training=False)
print(prd.shape)

2023-08-21 00:29:39.815242: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1


(4, 10, 30522)


### Model Summary

In [8]:
tx.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 x (InputLayer)              [(None, None)]               0         []                            
                                                                                                  
 tf.compat.v1.shape (TFOpLa  (2,)                         0         ['x[0][0]']                   
 mbda)                                                                                            
                                                                                                  
 tf.__operators__.getitem (  ()                           0         ['tf.compat.v1.shape[0][0]']  
 SlicingOpLambda)                                                                                 
                                                                                              