In [43]:
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers

class Attention(layers.Layer):
    def __init__(self, unit):
        super(Attention, self).__init__()
        self.unit = unit

    def build(self, input_shape):
        assert len(input_shape) == 3
        
        self.weight = self.add_weight(shape=(input_shape[-1], self.unit), 
                                      initializer=keras.initializers.RandomNormal(),
                                      trainable=True)
        self.features_dim = input_shape[-1]
        self.bias = self.add_weight(shape=(self.unit,),
                                    initializer=keras.initializers.Zeros(),
                                    trainable=True)
        self.u = self.add_weight(shape=(self.unit,),
                                 initializer=keras.initializers.RandomNormal(),
                                 trainable=True)

    def call(self, inputs, mask=None):
        v = tf.tanh(tf.tensordot(inputs, self.weight, axes=1) + self.bias)
        vu = tf.tensordot(v, self.u, axes=1)
        alphas = tf.nn.softmax(vu)
        output = tf.reduce_sum(inputs * tf.expand_dims(alphas, axis=-1), axis=1)
        return output

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.features_dim

In [44]:
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Embedding, Dense, Dropout, Bidirectional, LSTM

class TextAttBiRNN(object):
    def __init__(self, maxlen, max_features, embedding_dims,
                 class_num=1,
                 last_activation='sigmoid'):
        self.maxlen = maxlen
        self.max_features = max_features
        self.embedding_dims = embedding_dims
        self.class_num = class_num
        self.last_activation = last_activation

    def get_model(self):
        input = Input((self.maxlen,))

        embedding = Embedding(self.max_features, self.embedding_dims, input_length=self.maxlen)(input)
        x = Bidirectional(LSTM(128, return_sequences=True))(embedding)  # LSTM or GRU
        x = Attention(self.maxlen)(x)

        output = Dense(self.class_num, activation=self.last_activation)(x)
        model = Model(inputs=input, outputs=output)
        return model

In [45]:
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing import sequence


max_features = 5000
maxlen = 400
batch_size = 32
embedding_dims = 50
epochs = 10

print('Loading data...')
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
print(len(x_train), 'train sequences')
print(len(x_test), 'test sequences')

print('Pad sequences (samples x time)...')
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)

print('Build model...')
model = TextAttBiRNN(maxlen, max_features, embedding_dims).get_model()
model.summary()
model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])

print('Train...')
early_stopping = EarlyStopping(monitor='val_acc', patience=3, mode='max')
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          callbacks=[early_stopping],
          validation_data=(x_test, y_test))

print('Test...')
result = model.predict(x_test)

Loading data...
25000 train sequences
25000 test sequences
Pad sequences (samples x time)...
x_train shape: (25000, 400)
x_test shape: (25000, 400)
Build model...
Model: "model_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_15 (InputLayer)        [(None, 400)]             0         
_________________________________________________________________
embedding_14 (Embedding)     (None, 400, 50)           250000    
_________________________________________________________________
bidirectional_14 (Bidirectio (None, 400, 256)          183296    
_________________________________________________________________
attention_13 (Attention)     (None, 256)               103200    
_________________________________________________________________
dense_5 (Dense)              (None, 1)                 257       
Total params: 536,753
Trainable params: 536,753
Non-trainable params: 0
______________________

W0916 16:50:21.084877 140735620006784 deprecation.py:323] From /usr/local/lib/python3.7/site-packages/tensorflow_core/python/ops/nn_impl.py:183: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where




W0916 16:59:14.697167 140735620006784 callbacks.py:1249] Early stopping conditioned on metric `val_acc` which is not available. Available metrics are: loss,accuracy,val_loss,val_accuracy


Epoch 2/10

W0916 17:08:11.926981 140735620006784 callbacks.py:1249] Early stopping conditioned on metric `val_acc` which is not available. Available metrics are: loss,accuracy,val_loss,val_accuracy


Epoch 3/10

W0916 17:16:55.745413 140735620006784 callbacks.py:1249] Early stopping conditioned on metric `val_acc` which is not available. Available metrics are: loss,accuracy,val_loss,val_accuracy


Epoch 4/10

W0916 17:25:43.450053 140735620006784 callbacks.py:1249] Early stopping conditioned on metric `val_acc` which is not available. Available metrics are: loss,accuracy,val_loss,val_accuracy


Epoch 5/10

W0916 17:34:29.260735 140735620006784 callbacks.py:1249] Early stopping conditioned on metric `val_acc` which is not available. Available metrics are: loss,accuracy,val_loss,val_accuracy


Epoch 6/10

W0916 17:43:15.641942 140735620006784 callbacks.py:1249] Early stopping conditioned on metric `val_acc` which is not available. Available metrics are: loss,accuracy,val_loss,val_accuracy


Epoch 7/10

W0916 17:52:05.872829 140735620006784 callbacks.py:1249] Early stopping conditioned on metric `val_acc` which is not available. Available metrics are: loss,accuracy,val_loss,val_accuracy


Epoch 8/10

W0916 18:00:48.534768 140735620006784 callbacks.py:1249] Early stopping conditioned on metric `val_acc` which is not available. Available metrics are: loss,accuracy,val_loss,val_accuracy


Epoch 9/10
 2112/25000 [=>............................] - ETA: 6:36 - loss: 0.0648 - accuracy: 0.9808

W0916 18:01:25.166435 140735620006784 callbacks.py:1249] Early stopping conditioned on metric `val_acc` which is not available. Available metrics are: loss,accuracy


KeyboardInterrupt: 