In [1]:
%tensorflow_version 2.x


TensorFlow 2.x selected.


In [2]:
import tensorflow as tf
import numpy as np


class Text_Encode_Embedding(tf.keras.layers.Layer):
    """
    Tensorflow Subclass layer
    The same embedding layer
    """

    def __init__(self, vocab_size, num_units, **kwargs):
        super(Text_Encode_Embedding, self).__init__(**kwargs)
        self.vocab_size = vocab_size
        self.num_units = num_units
        self.weight_init = tf.keras.initializers.TruncatedNormal(0.0, 0.1)
        self.text_emb = tf.keras.layers.Embedding(self.vocab_size, self.num_units,
                                                  embeddings_initializer=self.weight_init)

    def call(self, x):
        return self.text_emb(x)


class TTS_Layer_Norm(tf.keras.layers.Layer):
    '''
    Tensorlfow Subclass layer
    The same norm_layer
    '''

    def __init__(self):
        super(TTS_Layer_Norm, self).__init__()
        self.norm_layer = tf.keras.layers.LayerNormalization()

    def call(self, x):
        return self.norm_layer(x)


class Dense_Highway(tf.keras.layers.Layer):
    '''
    y = H(x,WH)· T(x,WT) + x · (1 −T(x,WT)).
    '''

    def __init__(self, num_units, bias_init_value):
        super(Dense_Highway, self).__init__()
        self.num_units = num_units
        self.bias_init_value = bias_init_value
        # self.T_bias_init = tf.keras.initializers.constant(value=self.bias_init_value)

        self.H = tf.keras.layers.Dense(self.num_units,
                                       activation='relu')

        self.T = tf.keras.layers.Dense(self.num_units,
                                       activation='sigmoid',
                                       bias_initializer=tf.keras.initializers.constant(self.bias_init_value))

    def call(self, x):
        h = self.H(x)
        t = self.T(x)
        transform_gate = tf.keras.layers.Multiply()([h, t])
        carry_gate = tf.keras.layers.Lambda(lambda x: 1. - x)(t)
        carry_gate = tf.keras.layers.Multiply()([x, carry_gate])
        return tf.keras.layers.Add()([transform_gate, carry_gate])

# ------------------------------------------------- #

#               Test Environments                   #

# ------------------------------------------------- #

# Get Data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=10000)
x_train.shape
x_train = tf.keras.preprocessing.sequence.pad_sequences(x_train, 100)
x_test = tf.keras.preprocessing.sequence.pad_sequences(x_test, 100)
print(x_train.shape)
print(y_train.shape)


# Model Builder
input_ = tf.keras.layers.Input(shape=(100,))
emb = Text_Encode_Embedding(10000 + 1, 100)(input_)
emb = tf.keras.layers.Flatten()(emb)
x = tf.keras.layers.Dense(256, activation='relu')(emb)
x = TTS_Layer_Norm()(x)
x = Dense_Highway(256, -3.)(x)

# x = Dense_Highway(256, -3.)(x)
# # x = TTS_Layer_Norm()(x)
out_ = tf.keras.layers.Dense(1, activation='sigmoid')(x)

test_model = tf.keras.models.Model(input_, out_)

test_model.summary()
test_model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
hist = test_model.fit(x_train, y_train, batch_size=128, epochs=100, validation_data=(x_test, y_test))

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
(25000, 100)
(25000,)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
text__encode__embedding (Tex (None, 100, 100)          1000100   
_________________________________________________________________
flatten (Flatten)            (None, 10000)             0         
_________________________________________________________________
dense (Dense)                (None, 256)               2560256   
_________________________________________________________________
tts__layer__norm (TTS_Layer_ (None, 256)               512       
_________________________________________________________________
dense__highway (Dense_Highwa (None, 256)               131584    
______________

KeyboardInterrupt: ignored