# Imports

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Utils

In [None]:
class FeedForwardNetwork(layers.Layer):
    """Feed-forward block with a residual connection and layer normalization.

    The input is passed through a small MLP (a `4 * dims`-unit dense layer
    with GELU activation, a `dims`-unit dense layer, then dropout). The MLP
    output is added back to the input and the sum is layer-normalized.

    Args:
        dims (`int`): Dimension of the model.
        dropout (`float`): Dropout probability.
    """

    def __init__(self, dims: int, dropout: float = 0.1, **kwargs):
        super().__init__(**kwargs)
        # Kept on the instance so `get_config` can report the constructor
        # arguments.
        self.dims = dims
        self.dropout = dropout
        self.ffn = keras.Sequential(
            [
                layers.Dense(units=4 * dims, activation=tf.nn.gelu),
                layers.Dense(units=dims),
                layers.Dropout(rate=dropout),
            ]
        )
        self.add = layers.Add()
        self.layernorm = layers.LayerNormalization(epsilon=1e-5)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Apply the MLP, add the residual, and normalize.

        Args:
            x (`tf.Tensor`): Input tensor. It is expected to have a last
                dimension of `dims` so that it can be summed with the MLP
                output.

        Returns:
            `tf.Tensor`: The layer-normalized sum of `x` and the MLP output.
        """
        x = self.add([x, self.ffn(x)])
        x = self.layernorm(x)
        return x

    def get_config(self):
        """Return the layer config, including the constructor arguments."""
        config = super().get_config()
        config.update({"dims": self.dims, "dropout": self.dropout})
        return config


class BaseAttention(layers.Layer):
    """The base attention module.

    Bundles the sub-layers shared by the attention variants: a multi-head
    attention layer, an `Add` layer for the residual connection, and a layer
    normalization. Subclasses provide `call` and decide what is used as
    query, key and value.

    Args:
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for query and key.
        dropout (`float`): Dropout probability.

    """

    def __init__(self, num_heads: int, key_dim: int, dropout: float = 0.1, **kwargs):
        super().__init__(**kwargs)
        # Kept on the instance so `get_config` can report the constructor
        # arguments.
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.dropout = dropout
        self.mha = layers.MultiHeadAttention(num_heads, key_dim, dropout=dropout)
        self.layernorm = layers.LayerNormalization(epsilon=1e-5)
        self.add = layers.Add()
        # Subclasses overwrite this on every call; defining it here means the
        # attribute can be read safely before the first forward pass.
        self.last_attention_scores = None

    def get_config(self):
        """Return the layer config, including the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "key_dim": self.key_dim,
                "dropout": self.dropout,
            }
        )
        return config


class CrossAttention(BaseAttention):
    """Attention where `x` supplies the queries and `context` the keys/values.

    The attention output is added to `x` and the sum is layer-normalized.
    The attention scores of the most recent call are stored on
    `last_attention_scores`.
    """

    def call(self, x: tf.Tensor, context: tf.Tensor):
        """Attend from `x` over `context` and return the normalized residual sum.

        Args:
            x (`tf.Tensor`): Tensor used as the attention query.
            context (`tf.Tensor`): Tensor used as both key and value.

        Returns:
            The layer-normalized sum of `x` and the attention output.
        """
        attended, scores = self.mha(
            query=x,
            key=context,
            value=context,
            return_attention_scores=True,
        )
        # Keep the scores of the latest call available on the layer.
        self.last_attention_scores = scores
        residual = self.add([x, attended])
        return self.layernorm(residual)


class SelfAttention(BaseAttention):
    """Attention where `x` serves as query, key and value.

    The attention output is added to `x` and the sum is layer-normalized.
    The attention scores of the most recent call are stored on
    `last_attention_scores`.
    """

    def call(self, x):
        """Attend from `x` over itself and return the normalized residual sum.

        Args:
            x: Tensor used as query, key and value.

        Returns:
            The layer-normalized sum of `x` and the attention output.
        """
        attended, scores = self.mha(
            query=x,
            key=x,
            value=x,
            return_attention_scores=True,
        )
        # Keep the scores of the latest call available on the layer.
        self.last_attention_scores = scores
        residual = self.add([x, attended])
        return self.layernorm(residual)

# Perceptual Module

# Temporal Latent Bottleneck Model