# Imports

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Utils

In [None]:
class FeedForwardNetwork(layers.Layer):
    """Position-wise feed-forward block with a residual connection.

    Expands the last dimension to ``4 * dims`` with a GELU-activated dense
    layer, projects back to ``dims``, applies dropout, then adds the result
    onto the input and layer-normalises the sum.

    Args:
        dims (`int`): Dimension of the model. Must equal the last dimension
            of the input, because the FFN output is added back onto it.
        dropout (`float`): Dropout probability.
    """

    def __init__(self, dims: int, dropout: float = 0.1, **kwargs):
        super().__init__(**kwargs)
        hidden_units = 4 * dims
        self.ffn = keras.Sequential(
            [
                layers.Dense(hidden_units, activation=tf.nn.gelu),
                layers.Dense(dims),
                layers.Dropout(dropout),
            ]
        )
        self.add = layers.Add()
        self.layernorm = layers.LayerNormalization(epsilon=1e-5)

    def call(self, x: tf.Tensor):
        transformed = self.ffn(x)
        summed = self.add([x, transformed])
        return self.layernorm(summed)


class BaseAttention(layers.Layer):
    """Shared base for the attention layers below.

    Owns a multi-head attention layer plus the residual ``Add`` and
    ``LayerNormalization`` layers that subclasses apply after attention.
    Subclasses implement ``call``.

    Args:
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for query and key.
        dropout (`float`): Dropout probability passed to the attention layer.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_heads: int, key_dim: int, dropout: float = 0.1, **kwargs):
        super().__init__(**kwargs)
        self.mha = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=key_dim,
            dropout=dropout,
        )
        self.layernorm = layers.LayerNormalization(epsilon=1e-5)
        self.add = layers.Add()


class CrossAttention(BaseAttention):
    """Attention from ``x`` (queries) over ``context`` (keys and values).

    The attention output is added onto ``x`` and layer-normalised. Scores
    from the most recent call are kept on ``self.last_attention_scores``
    for later inspection.
    """

    def call(self, x: tf.Tensor, context: tf.Tensor):
        attended, scores = self.mha(
            query=x,
            value=context,
            key=context,
            return_attention_scores=True,
        )
        self.last_attention_scores = scores
        return self.layernorm(self.add([x, attended]))


class SelfAttention(BaseAttention):
    """Attention of ``x`` over itself (query, key and value are all ``x``).

    The attention output is added onto ``x`` and layer-normalised. Scores
    from the most recent call are kept on ``self.last_attention_scores``
    for later inspection.
    """

    def call(self, x: tf.Tensor):
        attended, scores = self.mha(
            query=x,
            value=x,
            key=x,
            return_attention_scores=True,
        )
        self.last_attention_scores = scores
        return self.layernorm(self.add([x, attended]))

In [None]:
# Basic implementation (composition): each block owns an attention sub-layer and an FFN sub-layer instead of subclassing the attention classes.

class SelfAttentionWithFFN(layers.Layer):
    """Self-attention block followed by a feed-forward network.

    Applies ``SelfAttention`` (residual + layer norm), then
    ``FeedForwardNetwork`` (residual + layer norm) to its output.

    Args:
        ffn_dims (`int`): Output width of the FFN. Because the FFN adds its
            output back onto its input, this must equal the last dimension
            of ``x``.
        ffn_dropout (`float`): Dropout probability for FFN.
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for query and key.
        dropout (`float`): Dropout probability for the attention layer.
        **kwargs: Forwarded to ``keras.layers.Layer`` for this wrapper
            (e.g. ``name``, ``dtype``).
    """

    def __init__(self, ffn_dims: int = 128, ffn_dropout: float = 0.1,
                 num_heads: int = 4, key_dim: int = 256, dropout: float = 0.1, **kwargs):
        # Keras requires the base Layer to be initialised before sub-layers
        # are assigned as attributes or the layer is called.
        super().__init__(**kwargs)
        self.self_attention = SelfAttention(num_heads, key_dim, dropout)
        self.ffn = FeedForwardNetwork(ffn_dims, ffn_dropout)

    def call(self, x):
        attention_outputs = self.self_attention(x)
        x = self.ffn(attention_outputs)
        return x

class CrossAttentionWithFFN(layers.Layer):
    """Cross-attention block followed by a feed-forward network.

    Applies ``CrossAttention`` of ``x`` over ``context`` (residual + layer
    norm), then ``FeedForwardNetwork`` (residual + layer norm) to its output.

    Args:
        ffn_dims (`int`): Output width of the FFN. Because the FFN adds its
            output back onto its input, this must equal the last dimension
            of ``x``.
        ffn_dropout (`float`): Dropout probability for FFN.
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for query and key.
        dropout (`float`): Dropout probability for the attention layer.
        **kwargs: Forwarded to ``keras.layers.Layer`` for this wrapper
            (e.g. ``name``, ``dtype``).
    """

    def __init__(self, ffn_dims: int = 128, ffn_dropout: float = 0.1,
                 num_heads: int = 4, key_dim: int = 256, dropout: float = 0.1, **kwargs):
        # Keras requires the base Layer to be initialised before sub-layers
        # are assigned as attributes or the layer is called.
        super().__init__(**kwargs)
        self.cross_attention = CrossAttention(num_heads, key_dim, dropout)
        self.ffn = FeedForwardNetwork(ffn_dims, ffn_dropout)

    def call(self, x, context):
        attention_outputs = self.cross_attention(x, context)
        x = self.ffn(attention_outputs)
        return x

# Alternative implementation (inheritance, currently disabled): subclass SelfAttention/CrossAttention and append an FFN.

# class SelfAttentionWithFFN(SelfAttention):
#     def __init__(self, ffn_dims: int = 128, ffn_dropout: float = 0.1, 
#                  num_heads: int = 4, key_dim: int = 256, dropout: float = 0.1, **kwargs):
#         super(SelfAttentionWithFFN, self).__init__(num_heads, key_dim, dropout, **kwargs)
#         self.ffn = FeedForwardNetwork(ffn_dims, ffn_dropout)

#     def call(self, x):
#         x = super().call(x)
#         x = self.ffn(x)
#         return x

# class CrossAttentionWithFFN(CrossAttention):
#     def __init__(self, ffn_dims: int = 128, ffn_dropout: float = 0.1, 
#                  num_heads: int = 4, key_dim: int = 256, dropout: float = 0.1, **kwargs):
#         super(CrossAttentionWithFFN, self).__init__(num_heads, key_dim, dropout, **kwargs)
#         self.ffn = FeedForwardNetwork(ffn_dims, ffn_dropout)

#     def call(self, x, context):
#         x = super().call(x, context)
#         x = self.ffn(x)
#         return x

# Perceptual Module (WIP - sharp edges)

In [None]:
class PerceptualModule(layers.Layer):
    """Stack of attention blocks that reads a slow-stream context (WIP).

    Applies, in order: self-attention+FFN, cross-attention+FFN over
    ``slow_stream_context``, self-attention+FFN, cross-attention+FFN over
    ``slow_stream_context``, and a final self-attention+FFN. All five
    sub-blocks are built with the same hyperparameters.

    Args:
        ffn_dims (`int`): Output width of each FFN. Must equal the last
            dimension of ``x`` because of the residual additions.
        ffn_dropout (`float`): Dropout probability for the FFNs.
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for query and key.
        dropout (`float`): Dropout probability for the attention layers.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, ffn_dims: int = 128, ffn_dropout: float = 0.1,
                 num_heads: int = 4, key_dim: int = 256, dropout: float = 0.1, **kwargs):
        # Keras requires the base Layer to be initialised before sub-layers
        # are assigned as attributes or the layer is called.
        super().__init__(**kwargs)
        block_args = (ffn_dims, ffn_dropout, num_heads, key_dim, dropout)
        self.sa_ffn1 = SelfAttentionWithFFN(*block_args)
        self.ca_ffn1 = CrossAttentionWithFFN(*block_args)
        self.sa_ffn2 = SelfAttentionWithFFN(*block_args)
        self.ca_ffn2 = CrossAttentionWithFFN(*block_args)
        self.sa_ffn3 = SelfAttentionWithFFN(*block_args)

        # NOTE(review): intended to run as a cell inside `keras.layers.RNN`,
        # which reads `state_size` / `output_size` from the cell to size its
        # states and outputs. Both are still placeholders (None), and `call`
        # does not yet follow the RNN cell signature `call(inputs, states)`.
        # TODO: fill these in before wiring the module into an RNN layer.
        self.state_size = None
        self.output_size = None

    def call(self, x: tf.Tensor, slow_stream_context: tf.Tensor):
        x = self.sa_ffn1(x)
        x = self.ca_ffn1(x, slow_stream_context)
        x = self.sa_ffn2(x)
        x = self.ca_ffn2(x, slow_stream_context)
        x = self.sa_ffn3(x)
        return x

# Temporal Latent Bottleneck Model