## Imports

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

ModuleNotFoundError: No module named 'tensorflow'

## Utils

In [None]:
class FeedForwardNetwork(layers.Layer):
    """Feed-forward block with a residual connection and layer normalization.

    The input goes through a small MLP (a GELU-activated dense layer with
    ``4 * dims`` units, a dense layer with ``dims`` units, then dropout). The
    MLP output is added to the input and the sum is layer-normalized.

    Args:
        dims (`int`): Dimension of the model.
        dropout (`float`): Dropout probability.
    """

    def __init__(self, dims: int, dropout: float = 0.1, **kwargs):
        super().__init__(**kwargs)

        # Sub-layers are created in the order they are applied.
        hidden_units = 4 * dims
        expand = layers.Dense(units=hidden_units, activation=tf.nn.gelu)
        project = layers.Dense(units=dims)
        regularize = layers.Dropout(rate=dropout)

        self.ffn = keras.Sequential([expand, project, regularize])
        self.add = layers.Add()
        self.layernorm = layers.LayerNormalization(epsilon=1e-5)

    def call(self, x: tf.Tensor):
        transformed = self.ffn(x)
        # Residual connection: the MLP output is summed with its own input.
        residual_sum = self.add([x, transformed])
        return self.layernorm(residual_sum)


class BaseAttention(layers.Layer):
    """Common sub-layers shared by the attention layers.

    Only the sub-layers are created here (multi-head attention, layer
    normalization and an element-wise add); this class does not define
    ``call`` itself.

    Args:
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for query and key.
        dropout (`float`): Dropout probability.
    """

    def __init__(self, num_heads: int, key_dim: int, dropout: float = 0.1, **kwargs):
        super().__init__(**kwargs)
        self.mha = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=key_dim,
            dropout=dropout,
        )
        self.layernorm = layers.LayerNormalization(epsilon=1e-5)
        self.add = layers.Add()


class CrossAttention(BaseAttention):
    """Attention with ``x`` as the query and ``context`` as key and value.

    The attention output is added to ``x`` and the sum is layer-normalized.
    The attention scores of the most recent call are stored on
    ``last_attention_scores``.
    """

    def call(self, x: tf.Tensor, context: tf.Tensor):
        attended, scores = self.mha(
            query=x, key=context, value=context, return_attention_scores=True
        )
        # Kept on the layer so the scores can be inspected after the call.
        self.last_attention_scores = scores

        residual_sum = self.add([x, attended])
        return self.layernorm(residual_sum)


class SelfAttention(BaseAttention):
    """Attention where ``x`` is used as query, key and value.

    The attention output is added to ``x`` and the sum is layer-normalized.
    The attention scores of the most recent call are stored on
    ``last_attention_scores``.
    """

    def call(self, x):
        attended, scores = self.mha(
            query=x, key=x, value=x, return_attention_scores=True
        )
        # Kept on the layer so the scores can be inspected after the call.
        self.last_attention_scores = scores

        residual_sum = self.add([x, attended])
        return self.layernorm(residual_sum)

class SelfAttentionWithFFN(layers.Layer):
    """Self-attention module with FFN

    Applies a `SelfAttention` layer and feeds its output through a
    `FeedForwardNetwork`.

    Args:
        ffn_dims (`int`): Number of units in FFN
        ffn_dropout (`float`): Dropout probability for FFN
        num_heads (`int`): Number of attention heads
        key_dim (`int`): Size of each attention head for query and key.
        dropout (`float`): Dropout probability.
        **kwargs: Forwarded to the base `Layer` constructor of this layer.

    """
    def __init__(self, ffn_dims: int = 128, ffn_dropout: float = 0.1,
                 num_heads: int = 4, key_dim: int = 256, dropout: float = 0.1, **kwargs):
        # The base Layer must be initialized before sub-layers are assigned as
        # attributes. The kwargs (e.g. `name`) describe this layer, so they go
        # here rather than to the inner attention layer.
        super().__init__(**kwargs)
        self.self_attention = SelfAttention(num_heads, key_dim, dropout)
        self.ffn = FeedForwardNetwork(ffn_dims, ffn_dropout)

    def call(self, x):
        attention_outputs = self.self_attention(x)
        return self.ffn(attention_outputs)

class CrossAttentionWithFFN(layers.Layer):
    """Cross-attention module with FFN

    Applies a `CrossAttention` layer (``x`` attends to ``context``) and feeds
    its output through a `FeedForwardNetwork`.

    Args:
        ffn_dims (`int`): Number of units in FFN
        ffn_dropout (`float`): Dropout probability for FFN
        num_heads (`int`): Number of attention heads
        key_dim (`int`): Size of each attention head for query and key.
        dropout (`float`): Dropout probability.
        **kwargs: Forwarded to the base `Layer` constructor of this layer.

    """
    def __init__(self, ffn_dims: int = 128, ffn_dropout: float = 0.1,
                 num_heads: int = 4, key_dim: int = 256, dropout: float = 0.1, **kwargs):
        # The base Layer must be initialized before sub-layers are assigned as
        # attributes. The kwargs (e.g. `name`) describe this layer, so they go
        # here rather than to the inner attention layer.
        super().__init__(**kwargs)
        self.cross_attention = CrossAttention(num_heads, key_dim, dropout)
        self.ffn = FeedForwardNetwork(ffn_dims, ffn_dropout)

    def call(self, x, context):
        attention_outputs = self.cross_attention(x, context)
        return self.ffn(attention_outputs)

## Rough implementation of full model

Studying the custom RNN example resolved my main doubt, which was how to handle the batch size and the chunk size within the scope of the `__init__` function. The implementation was more or less straightforward, but it makes a few assumptions and has severe blocking overheads that we need to try to optimize.

- Interesting question to ask: When initializing, do we take `num_layers` as the limiting factor, or `R` as the limiting factor?
Example: If we have `num_layers = 5` and `R = 2`, to maintain the last layer as Self-Attention, we'd have to exceed `num_layers`.
On the other hand, if we kept num_layers as limiting, we'd have a Cross Attention layer as the last one. 
The implementation below puts `R` as the limiting factor.

- There is an O(n^2) loop inside the Perceptual Module initialization, which became necessary to create a dynamic number of layers according to the restrictions imposed by `num_layers` and `R`. Is there any way to optimize or vectorize it?

- There is a really bad-looking loop in the `call` function that is necessary to identify the SelfAttention and CrossAttention layers and then accordingly give them an input. Any solutions for that?

- Obviously, need to test.

In [None]:
class CustomCell(layers.Layer):
    """Recurrent cell combining a perceptual module and a temporal latent bottleneck.

    On every step the incoming chunk (the "fast stream") is passed through the
    perceptual module: a stack of self-attention blocks interleaved with
    cross-attention blocks that read the recurrent state (the "slow stream").
    The state is then updated by cross-attending from the slow stream to the
    processed chunk.

    The perceptual module is built in groups of ``max(r, 1)`` self-attention
    blocks followed by one cross-attention block. Groups are added until the
    stack holds at least ``num_layers`` blocks, so the stack may be longer than
    ``num_layers`` and may end with a cross-attention block.

    Args:
        chunk_size (`int`): Chunk length; used for `state_size` and `output_size`.
        r (`int`): Number of self-attention blocks placed before each
            cross-attention block.
        num_layers (`int`): Minimum number of blocks in the perceptual module.
        ffn_dims (`int`): Number of units in FFN
        ffn_dropout (`float`): Dropout probability for FFN
        num_heads (`int`): Number of attention heads
        key_dim (`int`): Size of each attention head for query and key.
        dropout (`float`): Dropout probability.
        **kwargs: Forwarded to the base `Layer` constructor.
    """

    def __init__(self, chunk_size, r, num_layers: int = 5, ffn_dims: int = 128, ffn_dropout: float = 0.1,
                 num_heads: int = 4, key_dim: int = 256, dropout: float = 0.1, **kwargs):
        # The base Layer must be initialized before any sub-layer is assigned.
        super().__init__(**kwargs)

        self.chunk_size = chunk_size
        self.r = r
        self.num_layers = num_layers
        self.ffn_dims = ffn_dims
        self.ffn_dropout = ffn_dropout
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.dropout = dropout

        self.state_size = tf.TensorShape([chunk_size, ffn_dims])
        self.output_size = tf.TensorShape([chunk_size, ffn_dims])

        # Perceptual Module
        block_args = (ffn_dims, ffn_dropout, num_heads, key_dim, dropout)
        perceptual_module_layers = []
        while len(perceptual_module_layers) <= num_layers:
            perceptual_module_layers.append(SelfAttentionWithFFN(*block_args))
            if len(perceptual_module_layers) == num_layers:
                break
            # `extend` adds nothing when r <= 1, so no emptiness check is needed.
            perceptual_module_layers.extend(
                SelfAttentionWithFFN(*block_args) for _ in range(r - 1)
            )
            perceptual_module_layers.append(CrossAttentionWithFFN(*block_args))

        self.perceptual_module = perceptual_module_layers

        # Temporal Latent Bottleneck
        self.tlb_cross_attention = CrossAttentionWithFFN(*block_args)

    def possible_implemented_call(self, inputs, states):
        """Kept for backward compatibility; same as `call`."""
        return self.call(inputs, states)

    def call(self, inputs, states):
        """Run one recurrent step.

        Args:
            inputs: Current chunk, (batch, chunk_size, dims).
                NOTE(review): the residual add in the FFN blocks presumably
                requires dims == ffn_dims; confirm.
            states: Sequence whose first element is the previous state,
                (batch, chunk_size, units).

        Returns:
            A tuple ``(fast_stream, [slow_stream])``: the processed chunk and a
            one-element list holding the updated state.
        """
        slow_stream = states[0]

        fast_stream = inputs  # chunks
        for layer in self.perceptual_module:
            if isinstance(layer, CrossAttentionWithFFN):
                # The chunk reads from the recurrent state.
                fast_stream = layer(fast_stream, slow_stream)
            else:
                fast_stream = layer(fast_stream)

        # The recurrent state reads from the processed chunk.
        slow_stream = self.tlb_cross_attention(slow_stream, fast_stream)

        return fast_stream, [slow_stream]

    def get_config(self):
        """Return the constructor arguments merged into the base layer config."""
        config = super().get_config()
        config.update(
            {
                "chunk_size": self.chunk_size,
                "r": self.r,
                "num_layers": self.num_layers,
                "ffn_dims": self.ffn_dims,
                "ffn_dropout": self.ffn_dropout,
                "num_heads": self.num_heads,
                "key_dim": self.key_dim,
                "dropout": self.dropout,
            }
        )
        return config