# Imports

In [66]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Building a Custom RNN Cell

In [81]:
class CustomCell(layers.Layer):
    """RNN cell that consumes the sequence one chunk of timesteps per step.

    When wrapped in ``layers.RNN``, each step receives a chunk of
    ``chunk_size`` timesteps. The chunk is passed through a small MLP. The
    MLP output is emitted as the step output and also added to the previous
    state to form the new state.

    Args:
        chunk_size: Number of timesteps per chunk (second axis of each
            per-step input).
        units: Feature width of the cell's output and state.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, chunk_size, units, **kwargs):
        # Initialise the base Layer first, so Keras attribute/sub-layer
        # tracking exists before ``self.mlp`` is assigned below.
        super().__init__(**kwargs)
        self.chunk_size = chunk_size
        self.units = units
        # Per sample, both the state and the step output are (chunk_size, units).
        self.state_size = tf.TensorShape([chunk_size, units])
        self.output_size = tf.TensorShape([chunk_size, units])

        # TODO: replace this MLP with the custom chunk-processing logic.
        self.mlp = keras.Sequential([
            layers.Dense(units, activation="relu"),
            layers.Dense(units * 4, activation="relu"),
            layers.Dense(units, activation="relu"),
        ])

    def call(self, inputs, states):
        """Process one chunk.

        Args:
            inputs: Per-step input, expected (batch, chunk_size, dims).
            states: Sequence whose first element is the previous state,
                expected (batch, chunk_size, units).

        Returns:
            ``(outputs, [new_state])``. ``outputs`` is the MLP output and
            ``new_state`` is ``outputs + previous state``.
        """
        # NOTE(review): assumes Keras passes ``states`` as a list even for a
        # single state; verify for the Keras version in use.
        prev_state = states[0]

        outputs = self.mlp(inputs)
        # The MLP does not read the state; the state just accumulates outputs.
        new_state = outputs + prev_state

        return outputs, [new_state]

    def get_config(self):
        """Return a serialisable config, including the base-Layer settings."""
        config = super().get_config()
        # Use the instance attributes, not a notebook-global ``chunk_size``.
        config.update({"units": self.units, "chunk_size": self.chunk_size})
        return config

# Test the RNN Layer

In [83]:
keras.backend.clear_session()

# Cell hyperparameters.
chunk_size = 8
units = 32
dims = 16

# Data dimensions (num_batches is not used in this cell).
timestep = 80
num_batches = 10
batch_size = 64

# Random input split into timestep // chunk_size chunks of chunk_size steps:
# shape (batch_size, timestep // chunk_size, chunk_size, dims).
inputs = tf.random.normal(
    shape=(batch_size, timestep // chunk_size, chunk_size, dims),
)

cell = CustomCell(chunk_size=chunk_size, units=units)
rnn = layers.RNN(cell=cell, return_sequences=True)

# Displayed by the notebook:
# (batch_size, timestep // chunk_size, chunk_size, units).
rnn(inputs).shape

TensorShape([64, 10, 8, 32])