In [None]:
import keras

In [None]:
# feedforward
# Simple position-wise feed-forward layer, the same as in a GPT decoder.
@keras.saving.register_keras_serializable()
class FeedForward(keras.layers.Layer):
  """Position-wise feed-forward network: Dense(factor * d) -> ReLU -> Dense(d).

  The same two weight matrices are applied at every position, so the output
  has the same shape as the input (d is the input's last dimension).

  Args:
    factor: Width multiplier of the hidden layer relative to d (default 4).
    **kwargs: Forwarded to `keras.layers.Layer`.
  """

  def __init__(self, factor=4, **kwargs):
    super().__init__(**kwargs)
    self.factor = factor
    self.relu = keras.activations.relu

  def build(self, batch_input_shape):
    # Only the last (feature) axis determines the weight shapes, so any input
    # rank >= 2 is supported, e.g. (batch, embed) or (batch, time, embed).
    embed_size = batch_input_shape[-1]
    hidden_size = self.factor * embed_size
    # Weights are created in the same order as before (kernel1, bias1,
    # kernel2, bias2) so order-based weight loading keeps working.
    self.kernel1 = self.add_weight(shape=(embed_size, hidden_size))
    self.bias1 = self.add_weight(shape=(hidden_size,), initializer="zeros")
    self.kernel2 = self.add_weight(shape=(hidden_size, embed_size))
    self.bias2 = self.add_weight(shape=(embed_size,), initializer="zeros")

  def call(self, inputs):
    # Expand to the hidden width, apply ReLU, then project back to d.
    hidden = self.relu(keras.ops.matmul(inputs, self.kernel1) + self.bias1)
    return keras.ops.matmul(hidden, self.kernel2) + self.bias2

  def get_config(self):
    base_config = super().get_config()
    return {
        **base_config,
        "factor": self.factor,
    }
# EmbeddingWithPosition
# Token embedding plus a learned positional embedding.
@keras.saving.register_keras_serializable()
class EmbeddingWithPosition(keras.layers.Layer):
  """
  Computes a token embedding and adds a learned positional embedding.

  Takes integer token ids of shape (batch, length), with
  length <= max_seq_length, and returns embeddings of shape
  (batch, length, embed_size). Optional dropout is applied to the sum.
  This layer does not support masking.

  Args:
    num_tokens: Vocabulary size (rows of the token embedding table).
    max_seq_length: Number of learned positions (rows of the positional table).
    embed_size: Embedding dimension.
    dropout: Dropout rate in [0, 1] applied to the summed embeddings;
      0 disables dropout.
    **kwargs: Forwarded to `keras.layers.Layer`.

  Raises:
    ValueError: If `dropout` is outside [0, 1].
  """

  def __init__(self, num_tokens, max_seq_length, embed_size, dropout=0, **kwargs):
    super().__init__(**kwargs)
    # Fail early on an invalid rate instead of with an AttributeError in call()
    # (a negative rate used to skip creating dropout_layer in build()).
    if not 0 <= dropout <= 1:
      raise ValueError(f"dropout must be in [0, 1], got {dropout!r}")
    self.num_tokens = num_tokens
    self.max_seq_length = max_seq_length
    self.embed_size = embed_size
    self.dropout = dropout

  def build(self, batch_input_shape):
    # The input shape is not needed: table sizes come from the constructor.
    self.kernel = self.add_weight(shape=(self.num_tokens, self.embed_size))
    self.pos_kernel = self.add_weight(shape=(self.max_seq_length, self.embed_size))
    if self.dropout > 0:
      self.dropout_layer = keras.layers.Dropout(self.dropout)

  def call(self, inputs):
    _, length = keras.ops.shape(inputs)

    embeddings = keras.ops.take(self.kernel, inputs, axis=0)  # (batch, length, embed_size)
    pos_embeddings = self.pos_kernel[:length]                 # (length, embed_size)

    # Broadcast the positional table over the batch axis. Any Keras mask is lost.
    summed = embeddings + pos_embeddings

    # Same condition as in build(), so dropout_layer always exists when used.
    if self.dropout > 0:
      summed = self.dropout_layer(summed)
    return summed

  def get_config(self):
    base_config = super().get_config()
    return {
        **base_config,
        "num_tokens": self.num_tokens,
        "max_seq_length": self.max_seq_length,
        "embed_size": self.embed_size,
        "dropout": self.dropout
    }

# Encoder

![](../image/encoder.png)

## init

In [None]:
def __init__(self, num_heads, embed_size, dropout=0, **kwargs):
    """Create the sublayers of a Transformer encoder block.

    NOTE(review): this is written as a method (it takes `self` and uses
    zero-argument `super()`, which only works inside a class body). It
    presumably shows the `__init__` of the `EncoderBlock` layer used further
    down — TODO confirm.

    Args:
        num_heads: Number of attention heads.
        embed_size: Model dimension; each head gets embed_size // num_heads.
        dropout: Dropout rate after attention and feed-forward; 0 disables it.
        **kwargs: Forwarded to the parent layer's constructor.
    """
    super().__init__(**kwargs)
    self.num_heads, self.embed_size, self.dropout = num_heads, embed_size, dropout

    # One shared Dropout instance, only created for a positive rate.
    if dropout > 0:
        self.dropout_layer = keras.layers.Dropout(dropout)

    head_dim = embed_size // num_heads
    # Self-attention. Despite the attribute name, no causal mask is configured
    # here; masking depends on what the caller passes at call time.
    # Sublayers are created in a fixed order (attention, norm, FFN, norm).
    self.masked_multi_head_attn = keras.layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=head_dim)
    self.layer_norm_1 = keras.layers.LayerNormalization()
    self.feed_forward = FeedForward()
    self.layer_norm_2 = keras.layers.LayerNormalization()

## call

In [None]:
def call(self, inputs, attention_mask=None):
    """Forward pass of the encoder block (post-norm residual layout).

    Self-attention -> dropout -> add & norm -> feed-forward -> dropout -> add & norm.

    NOTE(review): written as a method (takes `self`); presumably the `call`
    of the `EncoderBlock` layer used further down — TODO confirm.

    Args:
        inputs: Sequence embeddings, presumably (batch, length, embed_size) — TODO confirm.
        attention_mask: Optional mask forwarded to MultiHeadAttention.

    Returns:
        The normalized output of the feed-forward residual branch.
    """
    # Residual branch 1: self-attention (query = value = inputs).
    attn_out = self.masked_multi_head_attn(inputs, inputs, attention_mask=attention_mask)
    # NOTE(review): __init__ only creates dropout_layer when dropout > 0, so a
    # negative rate would raise AttributeError here.
    if self.dropout != 0:
        attn_out = self.dropout_layer(attn_out)
    # An Add layer (rather than `+`) keeps Keras mask propagation intact.
    normed = self.layer_norm_1(keras.layers.Add()([attn_out, inputs]))

    # Residual branch 2: position-wise feed-forward network.
    ff_out = self.feed_forward(normed)
    if self.dropout != 0:
        ff_out = self.dropout_layer(ff_out)
    return self.layer_norm_2(keras.layers.Add()([normed, ff_out]))

# get model

In [None]:
def get_classification_encoder_model(num_tokens, max_seq_length, embed_size, num_heads, num_blocks, num_classes, use_mask, scale_embeddings, dropout):
  """Build an encoder-only Transformer classifier as a functional Keras model.

  NOTE(review): a later cell redefines this function (with type hints and
  default arguments); once both cells have run, that later definition
  shadows this one.

  Pipeline: token + positional embedding -> `num_blocks` EncoderBlocks ->
  (masked) global average pooling -> Dense(relu) -> optional dropout ->
  sigmoid (num_classes == 1) or softmax output.

  Args:
    num_tokens: Vocabulary size.
    max_seq_length: Fixed input length (token ids per example).
    embed_size: Embedding / model dimension.
    num_heads: Attention heads per encoder block.
    num_blocks: Number of stacked encoder blocks.
    num_classes: 1 for a single sigmoid output, otherwise softmax over classes.
    use_mask: If True, token id 0 is treated as padding. It is masked as a key
      position in every encoder block and excluded from the average pooling.
    scale_embeddings: If True, multiply the pooled vector by the square root of
      the number of pooled positions (all max_seq_length positions when no
      mask is used).
    dropout: Dropout rate for the embedding, the encoder blocks and the head.

  Returns:
    An uncompiled `keras.Model` mapping int token ids of shape
    (batch, max_seq_length) to class probabilities.
  """
  inputs = keras.layers.Input(shape=[max_seq_length], dtype=int)
  mask, attention_mask = None, None

  if use_mask:
    mask = keras.ops.not_equal(inputs, 0)                  # (B, L): True for real tokens
    attention_mask = keras.ops.expand_dims(mask, axis=1)  # (B, 1, L): broadcast over queries

  embed = EmbeddingWithPosition(num_tokens=num_tokens, max_seq_length=max_seq_length, embed_size=embed_size, dropout=dropout, name="pos_embedding")(inputs)

  for index in range(num_blocks):
    embed = EncoderBlock(num_heads=num_heads, embed_size=embed_size, dropout=dropout, name=f"encoder_block_{index}")(embed, attention_mask)

  # Simple classification head
  embed = keras.layers.GlobalAveragePooling1D()(embed, mask=mask)  # (B, embed_size)
  if scale_embeddings:
    if mask is None:
      # Without a mask the average was taken over all positions. Previously
      # this path crashed because keras.ops.sum was called on None.
      embed = embed * (max_seq_length ** 0.5)
    else:
      # Number of real tokens per example as a float column (B, 1).
      token_count = keras.ops.sum(keras.ops.cast(mask, embed.dtype), axis=1, keepdims=True)
      embed = embed * keras.ops.sqrt(token_count)

  embed = keras.layers.Dense(units=embed_size, activation='relu')(embed)

  if dropout > 0:
    embed = keras.layers.Dropout(dropout)(embed)
  if num_classes == 1:
    output = keras.layers.Dense(units=1, activation='sigmoid')(embed)
  else:
    output = keras.layers.Dense(units=num_classes, activation='softmax')(embed)

  return keras.Model(inputs=inputs, outputs=output)

In [None]:
def get_classification_encoder_model(num_tokens: int, max_seq_length: int, embed_size: int, num_heads:int,
    num_blocks:int, num_classes:int, use_mask:bool=False, scale_embeddings:bool=False, dropout:float=0):
  """Build an encoder-only Transformer classifier as a functional Keras model.

  Pipeline: token + positional embedding -> `num_blocks` EncoderBlocks ->
  (masked) global average pooling -> Dense(relu) -> optional dropout ->
  sigmoid (num_classes == 1) or softmax output.

  Args:
    num_tokens: Vocabulary size.
    max_seq_length: Fixed input length (token ids per example).
    embed_size: Embedding / model dimension.
    num_heads: Attention heads per encoder block.
    num_blocks: Number of stacked encoder blocks.
    num_classes: 1 for a single sigmoid output, otherwise softmax over classes.
    use_mask: If True, token id 0 is treated as padding. It is masked as a key
      position in every encoder block and excluded from the average pooling.
    scale_embeddings: If True, multiply the pooled vector by the square root of
      the number of pooled positions (all max_seq_length positions when no
      mask is used).
    dropout: Dropout rate for the embedding, the encoder blocks and the head.

  Returns:
    An uncompiled `keras.Model` mapping int token ids of shape
    (batch, max_seq_length) to class probabilities.
  """
  inputs = keras.layers.Input(shape=[max_seq_length], dtype=int)
  mask, attention_mask = None, None
  if use_mask:
    mask = keras.ops.not_equal(inputs, 0)                  # (B, L): True for real tokens
    attention_mask = keras.ops.expand_dims(mask, axis=1)  # (B, 1, L): broadcast over queries

  embed = EmbeddingWithPosition(num_tokens=num_tokens,
                                max_seq_length=max_seq_length,
                                embed_size=embed_size,
                                dropout=dropout,
                                name="pos_embedding")(inputs)

  for index in range(num_blocks):
    embed = EncoderBlock(num_heads=num_heads, embed_size=embed_size, dropout=dropout,
                         name=f"encoder_block_{index}")(embed, attention_mask)

  # Simple classification head
  embed = keras.layers.GlobalAveragePooling1D()(embed, mask=mask) # (B, embed_size)
  if scale_embeddings:
    if mask is None:
      # Without a mask the average was taken over all positions. Previously
      # this path crashed because keras.ops.sum was called on None.
      embed = embed * (max_seq_length ** 0.5)
    else:
      # Number of real tokens per example as a float column (B, 1).
      token_count = keras.ops.sum(keras.ops.cast(mask, embed.dtype), axis=1, keepdims=True)
      embed = embed * keras.ops.sqrt(token_count)

  embed = keras.layers.Dense(units=embed_size, activation='relu')(embed)
  # Additional dropout just before output layer
  if dropout > 0:
    embed = keras.layers.Dropout(dropout)(embed)
  if num_classes == 1:
    output = keras.layers.Dense(units=1, activation='sigmoid')(embed)
  else:
    output = keras.layers.Dense(units=num_classes, activation='softmax')(embed)

  return keras.Model(inputs=inputs, outputs=output)

# decoder

Encoder-decoder attention (cross-attention) zit niet in een decoder-only model, omdat deze laag de output van de encoder als input nodig heeft.

![](../image/decoder.png)

Een decoder-only blok bestaat uit:
- masked multi head attention
- layer normalization
- feedforward layer
- layer normalization

## init

`__init__` legt de architectuur van het blok vast: hier worden alle sublagen aangemaakt.

In [None]:
def __init__(self, num_heads, embed_size, **kwargs):
    """Create the sublayers of a decoder-only Transformer block.

    NOTE(review): written as a method (takes `self` and uses zero-argument
    `super()`, which only works inside a class body). It presumably shows the
    `__init__` of a decoder layer class — TODO confirm.

    Args:
        num_heads: Number of attention heads.
        embed_size: Model dimension; each head gets embed_size // num_heads.
        **kwargs: Forwarded to the parent layer's constructor.
    """
    super().__init__(**kwargs)
    self.num_heads, self.embed_size = num_heads, embed_size
    head_dim = embed_size // num_heads

    # Self-attention; the causal mask is requested when the block is called.
    self.masked_multi_head_attn = keras.layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=head_dim)
    self.layer_norm_1 = keras.layers.LayerNormalization()

    # Position-wise feed-forward network followed by its own normalization.
    self.feed_forward = FeedForward()
    self.layer_norm_2 = keras.layers.LayerNormalization()

## call

`call` beschrijft hoe de data door de architectuur stroomt (de forward pass).

In [None]:
def call(self, inputs):
    """Forward pass of the decoder-only block (post-norm residual layout).

    Causal self-attention -> add & norm -> feed-forward -> add & norm.
    `use_causal_mask=True` restricts every position to attend to itself and
    earlier positions only.

    NOTE(review): written as a method (takes `self`); presumably the `call`
    of a decoder layer class — TODO confirm.

    Args:
        inputs: Sequence embeddings, presumably (batch, length, embed_size) — TODO confirm.

    Returns:
        The normalized output of the feed-forward residual branch.
    """
    # Residual branch 1: causal self-attention, `inputs` is the skip connection.
    attn_out = self.masked_multi_head_attn(inputs, inputs, use_causal_mask=True)
    normed = self.layer_norm_1(keras.layers.Add()([attn_out, inputs]))

    # Residual branch 2: feed-forward network, `normed` is the skip connection.
    ff_out = self.feed_forward(normed)
    return self.layer_norm_2(keras.layers.Add()([normed, ff_out]))

# encoder-decoder decoder

![](../image/decoder.png)

De decoder van een encoder-decoder model bestaat uit:
- masked multi head attention
- layer normalization
- encoder-decoder attention
- layer normalization
- feedforward layer
- layer normalization

## init

`__init__` legt de architectuur van het blok vast: hier worden alle sublagen aangemaakt.

In [None]:
def __init__(self, num_heads, embed_size, **kwargs):
    """Create the sublayers of a decoder block for an encoder-decoder Transformer.

    NOTE(review): written as a method (takes `self` and uses zero-argument
    `super()`, which only works inside a class body). It presumably shows the
    `__init__` of a decoder layer class — TODO confirm.

    Args:
        num_heads: Number of heads in both attention layers.
        embed_size: Model dimension; each head gets embed_size // num_heads.
        **kwargs: Forwarded to the parent layer's constructor.
    """
    super().__init__(**kwargs)
    self.num_heads, self.embed_size = num_heads, embed_size
    head_dim = embed_size // num_heads

    # 1) self-attention over the decoder sequence, plus its normalization
    self.masked_multi_head_attn = keras.layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=head_dim)
    self.layer_norm_1 = keras.layers.LayerNormalization()

    # 2) cross-attention from the decoder to the encoder output, plus its normalization
    self.cross_attn = keras.layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=head_dim)
    self.layer_norm_2 = keras.layers.LayerNormalization()

    # 3) position-wise feed-forward network, plus its normalization
    self.feed_forward = FeedForward()
    self.layer_norm_3 = keras.layers.LayerNormalization()

## call

`call` beschrijft hoe de data door de architectuur stroomt (de forward pass).

In [None]:
def call(self, inputs, cross_attention_mask=None):
    """Forward pass of the encoder-decoder decoder block (post-norm layout).

    Causal self-attention -> add & norm -> cross-attention -> add & norm ->
    feed-forward -> add & norm.

    NOTE(review): written as a method (takes `self`); presumably the `call`
    of a decoder layer class — TODO confirm.

    Args:
        inputs: Indexable pair; inputs[0] is the decoder sequence and
            inputs[1] the encoder output it attends to.
        cross_attention_mask: Optional mask forwarded as `attention_mask` to
            the cross-attention layer.

    Returns:
        The normalized output of the feed-forward residual branch.
    """
    target, memory = inputs[0], inputs[1]

    # 1) causal self-attention over the decoder sequence, then add & norm
    self_attn_out = self.masked_multi_head_attn(target, target, use_causal_mask=True)
    hidden = self.layer_norm_1(keras.layers.Add()([self_attn_out, target]))

    # 2) cross-attention: decoder positions (query) attend to the encoder output (key/value)
    cross_attn_out = self.cross_attn(
        query=hidden, value=memory, key=memory,
        attention_mask=cross_attention_mask)
    hidden = self.layer_norm_2(keras.layers.Add()([cross_attn_out, hidden]))

    # 3) position-wise feed-forward network, then add & norm
    ff_out = self.feed_forward(hidden)
    return self.layer_norm_3(keras.layers.Add()([hidden, ff_out]))