# In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Layer

class CustomConcat(Layer):
    """Concatenate a list of tensors along ``axis`` and propagate the first input's mask.

    Args:
        axis: Axis to concatenate along (default ``-1``, the last axis, as before).
        **kwargs: Standard Keras ``Layer`` arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, axis=-1, **kwargs):
        super(CustomConcat, self).__init__(**kwargs)
        self.axis = axis

    def call(self, inputs):
        return tf.concat(inputs, self.axis)

    def compute_mask(self, inputs, mask=None):
        # Keras hands the input masks to this method via ``mask`` (a structure
        # aligned with ``inputs``, or None when no input is masked). Reading it from
        # there avoids the private ``_keras_mask`` tensor attribute, which may be
        # absent on unmasked tensors and then raises AttributeError.
        if mask is None:
            return None
        return mask[0]

class CustomDot(Layer):
    """Batched matrix product ``inputs[0] @ inputs[1]`` that keeps the first input's mask.

    ``inputs`` is a two-element list. For example, ``x`` of shape
    ``(batch, time, d_in)`` times ``matrix`` of shape ``(batch, d_in, d_out)``
    gives ``(batch, time, d_out)``.

    Args:
        **kwargs: Standard Keras ``Layer`` arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, **kwargs):
        super(CustomDot, self).__init__(**kwargs)

    def call(self, inputs):
        return tf.matmul(inputs[0], inputs[1])

    def compute_mask(self, inputs, mask=None):
        # ``mask`` is aligned with ``inputs`` (or None when no input is masked).
        # The product keeps the leading axes of inputs[0], so its mask still
        # applies. Using the argument avoids the private ``_keras_mask`` tensor
        # attribute, which may be absent on unmasked tensors (AttributeError).
        if mask is None:
            return None
        return mask[0]

class CustomActivation(Layer):
    """Apply ReLU to the middle third of the last axis; the outer thirds pass through linearly.

    The size of the last axis must be divisible by 3, otherwise ``tf.split``
    raises. Any incoming mask is returned unchanged.
    """

    def __init__(self):
        super().__init__()
        self.activation = tf.keras.layers.Activation("relu")

    def call(self, inputs):
        left, middle, right = tf.split(inputs, num_or_size_splits=3, axis=-1)
        return tf.concat([left, self.activation(middle), right], axis=-1)

    def compute_mask(self, inputs, mask=None):
        # Purely feature-wise op: whatever mask came in still applies.
        return mask

# This adds positional encoding as integers to the embeddings
class PositionalIndexAppender(Layer):
    def __init__(self):
        super(PositionalIndexAppender, self).__init__()

    def call(self, inputs):
        # inputs shape: (batch_size, sequence_length, embedding_dim)
        batch_size, sequence_length, _ = tf.shape(inputs)[0], tf.shape(inputs)[1], tf.shape(inputs)[2]
        
        # Generate positional indices for the sequence length and tile for each batch
        positions = tf.range(1, sequence_length + 1, dtype=tf.float32)  # Shape: [sequence_length]
        positions = tf.reshape(positions, [1, sequence_length, 1])      # Shape: [1, sequence_length, 1]
        positions = tf.tile(positions, [batch_size, 1, 1])              # Shape: [batch_size, sequence_length, 1]

        # Concatenate along the last dimension
        return tf.concat([inputs, positions], axis=-1)
        
    def compute_mask(self, inputs, mask=None):
        return mask


class LoongStyleUnit(Layer):
    """Custom recurrent cell for ``tf.keras.layers.RNN`` with a gated, normalised update.

    At each step the new state is ``LN(v3 * state + s4 * s1)``, where:

    * ``s1`` is an affine projection of the input;
    * ``v3`` and ``s4`` are gate-like values computed from the previous state and
      the input through chains of affine maps;
    * LN is a parameter-free layer normalisation over the last axis.

    The step output equals the new state.

    Args:
        units: Size of the hidden state and of the output.
        epsilon: Variance epsilon used by the internal layer normalisations
            (default ``1e-3``, the value previously hard-coded).
        **kwargs: Standard Keras ``Layer`` arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, units, epsilon=1e-3, **kwargs):
        super(LoongStyleUnit, self).__init__(**kwargs)
        self.state_size = units
        self.output_size = units
        self.epsilon = epsilon

    def build(self, input_shape):
        input_dim = input_shape[-1]
        units = self.state_size
        # Created in the same order as before (kernel_i, then bias_i, for i = 1..7)
        # so that anything matching weights by position still lines up. Each weight
        # now also has a distinct name.
        self.kernel_1, self.bias_1 = self._add_affine(1, units + input_dim)  # [state, x]  -> v1
        self.kernel_2, self.bias_2 = self._add_affine(2, units * 2)          # [v1, state] -> v2
        self.kernel_3, self.bias_3 = self._add_affine(3, units)              # v2          -> v3
        self.kernel_4, self.bias_4 = self._add_affine(4, input_dim)          # x           -> s1
        self.kernel_5, self.bias_5 = self._add_affine(5, units * 2)          # [state, s1] -> s2
        self.kernel_6, self.bias_6 = self._add_affine(6, units * 2)          # [s2, s1]    -> s3
        self.kernel_7, self.bias_7 = self._add_affine(7, units)              # s3          -> s4
        super(LoongStyleUnit, self).build(input_shape)

    def _add_affine(self, index, fan_in):
        """Create the ``(fan_in, units)`` kernel and ``(units,)`` bias for affine map ``index``."""
        # No initializer is passed, so Keras' default for float weights is used for
        # both kernel and bias, exactly as before.
        # NOTE(review): that default is glorot_uniform per the add_weight docs; zeros
        # is the more usual choice for biases.
        kernel = self.add_weight(name="kernel_%d" % index, shape=(fan_in, self.state_size))
        bias = self.add_weight(name="bias_%d" % index, shape=(self.state_size,))
        return kernel, bias

    def get_initial_state(self, batch_size=None, inputs=None, dtype=None):
        """Return the all-zeros initial state ``[zeros((batch_size, units))]``.

        ``tf.keras`` 2.x ``RNN`` calls
        ``get_initial_state(inputs=None, batch_size=..., dtype=...)``, while Keras 3
        calls ``get_initial_state(batch_size=...)``. Both forms are accepted. If
        ``batch_size`` is missing, it is read from ``inputs``. ``dtype`` defaults
        to float32.
        """
        if batch_size is None and inputs is not None:
            batch_size = tf.shape(inputs)[0]
        if dtype is None:
            dtype = tf.float32
        return [tf.zeros((batch_size, self.state_size), dtype=dtype)]

    def _layer_norm(self, x):
        """Parameter-free layer normalisation over the last axis."""
        mean, variance = tf.nn.moments(x, axes=[-1], keepdims=True)
        return (x - mean) / tf.sqrt(variance + self.epsilon)

    def call(self, inputs, states):
        state = states[0]

        # Like an LSTM, the update is gated. Here, though, each gate value comes
        # from several stacked affine maps instead of a single one, so the way new
        # input is combined depends more richly on the current context.
        #
        # v-branch: mix state and input (v1), re-mix with the state (v2), then
        #   project (v3). v3 gates the previous state.
        # s-branch: project the input (s1), mix with the state (s2), re-mix with s1
        #   (s3), then project (s4). s4 gates s1.
        #
        # The layer normalisations between the affine maps act as the
        # non-linearity. Without them, each branch's chain of affine maps would
        # collapse into a single affine map.
        v1 = self._layer_norm(
            tf.matmul(tf.concat([state, inputs], axis=-1), self.kernel_1) + self.bias_1)
        v2 = self._layer_norm(
            tf.matmul(tf.concat([v1, state], axis=-1), self.kernel_2) + self.bias_2)
        v3 = tf.matmul(v2, self.kernel_3) + self.bias_3

        s1 = tf.matmul(inputs, self.kernel_4) + self.bias_4
        s2 = self._layer_norm(
            tf.matmul(tf.concat([state, s1], axis=-1), self.kernel_5) + self.bias_5)
        s3 = self._layer_norm(
            tf.matmul(tf.concat([s2, s1], axis=-1), self.kernel_6) + self.bias_6)
        s4 = tf.matmul(s3, self.kernel_7) + self.bias_7

        final = self._layer_norm(v3 * state + s4 * s1)
        return final, [final]

# In [None]:
strategy = tf.distribute.MirroredStrategy()

# Model hyper-parameters. Sizes that depend on each other are derived from these
# instead of being hard-coded (previously Dense(9900) / Reshape((100, 99))).
VOCAB_SIZE = 10000
EMBED_DIM = 99        # per-token feature size N; must be divisible by 3 (CustomActivation)
CONTEXT_UNITS = 150   # summary size M produced by each layer's LoongStyleUnit RNN
NUM_LAYERS = 12       # number Z of stacked loong-style layers
FINAL_UNITS = 1200    # state size of the final RNN; must be divisible by 3 (CustomActivation)
NUM_CLASSES = 1000
DROPOUT_RATE = 0.2
LEARNING_RATE = 0.001

with strategy.scope():
    # None of these layers creates weights, so one instance of each is reused at
    # every call site below.
    pos_append = PositionalIndexAppender()
    custom_dot = CustomDot()
    activation = CustomActivation()
    custom_concat = CustomConcat()

    # Design (author's description):
    # A Transformer mixes value vectors with attention weights; the premise here is
    # that this is not precise enough. Instead, each layer does the following,
    # where the input vectors have length N + 1 (N features plus the appended
    # positional index):
    #   1. A LoongStyleUnit RNN reads the whole sequence and produces a summary
    #      embedding of length M.
    #   2. A Dense layer (a trained M x ((N + 1) * N) kernel) maps that summary to
    #      (N + 1) * N values, which are reshaped into an (N + 1) x N matrix. This
    #      matrix is dynamic: its values depend on the input sequence, whereas the
    #      Dense kernel is fixed after training.
    #   3. Every input vector is multiplied by that matrix, giving new vectors of
    #      length N.
    # Layers are stacked. After Z layers, the outputs of all layers are concatenated
    # (vectors of length Z * N). A final LoongStyleUnit RNN and a softmax Dense
    # layer then produce the prediction.
    def loong_style_layer(x):
        """Apply one loong-style mixing layer.

        ``x`` should have EMBED_DIM + 1 features (EMBED_DIM features plus the
        column appended by ``pos_append``). The result has EMBED_DIM features.
        """
        in_dim = EMBED_DIM + 1
        info = tf.keras.layers.RNN(cell=LoongStyleUnit(CONTEXT_UNITS))(x)
        matrix = tf.keras.layers.Dense(in_dim * EMBED_DIM)(info)
        matrix = tf.keras.layers.Reshape((in_dim, EMBED_DIM))(matrix)
        output = custom_dot([x, matrix])
        output = tf.keras.layers.LayerNormalization(center=False, scale=False)(output)
        output = activation(output)
        output = tf.keras.layers.Dropout(DROPOUT_RATE)(output)
        return output

    # Named token_ids (not ``input``) to avoid shadowing the builtin.
    token_ids = tf.keras.Input(shape=(None,))
    output = tf.keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM, mask_zero=True)(token_ids)
    output = pos_append(output)
    outputs = []
    for _ in range(NUM_LAYERS):
        output = loong_style_layer(output)
        outputs.append(output)
        output = pos_append(output)
    output = custom_concat(outputs)
    output = pos_append(output)
    output = tf.keras.layers.RNN(cell=LoongStyleUnit(FINAL_UNITS))(output)
    output = activation(output)
    output = tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')(output)

    model = tf.keras.Model(inputs=token_ids, outputs=output)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE), loss='sparse_categorical_crossentropy')