In [33]:
# Tutorial based on http://androidkt.com/text-classification-using-attention-mechanism-in-keras/
%matplotlib inline

import tensorflow as tf
from keras_preprocessing import sequence
from tensorflow import keras
from tensorflow.python.keras import Input
from tensorflow.python.keras.layers import Concatenate


In [29]:
import matplotlib.pyplot as plt
plt.style.use('ggplot')

def plot_history(history):
    acc = history.history['acc']
    val_acc = history.history['val_acc']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    x = range(1, len(acc) + 1)

    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(x, acc, 'b', label='Training acc')
    plt.plot(x, val_acc, 'r', label='Validation acc')
    plt.title('Training and validation accuracy')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(x, loss, 'b', label='Training loss')
    plt.plot(x, val_loss, 'r', label='Validation loss')
    plt.title('Training and validation loss')
    plt.legend()


vocab_size = 10000
 
pad_id = 0
start_id = 1
oov_id = 2
index_offset = 2
 
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=vocab_size, start_char=start_id,
                                                                        oov_char=oov_id, index_from=index_offset)
 
word2idx = tf.keras.datasets.imdb.get_word_index()
 
idx2word = {v + index_offset: k for k, v in word2idx.items()}
 
idx2word[pad_id] = '<PAD>'
idx2word[start_id] = '<START>'
idx2word[oov_id] = '<OOV>'
 
max_len = 200
rnn_cell_size = 128
 
x_train = sequence.pad_sequences(x_train,
                                 maxlen=max_len,
                                 truncating='post',
                                 padding='post',
                                 value=pad_id)
x_test = sequence.pad_sequences(x_test, maxlen=max_len,
                                truncating='post',
                                padding='post',
                                value=pad_id)

In [8]:
class Attention(tf.keras.Model):
    def __init__(self, units):
        super(Attention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, features, hidden):
        hidden_with_time_axis = tf.expand_dims(hidden, 1)
        score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))
        attention_weights = tf.nn.softmax(self.V(score), axis=1)
        context_vector = attention_weights * features
        context_vector = tf.reduce_sum(context_vector, axis=1)

        return context_vector, attention_weights


In [4]:
sequence_input = Input(shape=(max_len,), dtype='int32')

embedded_sequences = keras.layers.Embedding(vocab_size, 128, input_length=max_len)(sequence_input)

In [10]:
lstm = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM
                                     (rnn_cell_size,
                                      dropout=0.3,
                                      return_sequences=True,
                                      return_state=True,
                                      recurrent_activation='relu',
                                      recurrent_initializer='glorot_uniform'), name="bi_lstm_0")(embedded_sequences)

layers_lstm_birnn = tf.keras.layers.LSTM(rnn_cell_size, dropout=0.2, return_sequences=True, return_state=True,
                                         recurrent_activation='relu', recurrent_initializer='glorot_uniform')

lstm, forward_h, forward_c, backward_h, backward_c = tf.keras.layers.Bidirectional(layers_lstm_birnn)(lstm)

state_h = Concatenate()([forward_h, backward_h])
state_c = Concatenate()([forward_c, backward_c])

attention = Attention(128)
context_vector, attention_weights = attention(lstm, state_h)

output = keras.layers.Dense(1, activation='sigmoid')(context_vector)

model = keras.Model(inputs=sequence_input, outputs=output)

# summarize layers
print(model.summary())


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 200)          0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 200, 128)     1280000     input_1[0][0]                    
__________________________________________________________________________________________________
bi_lstm_0 (Bidirectional)       [(None, 200, 256), ( 263168      embedding[0][0]                  
__________________________________________________________________________________________________
bidirectional_4 (Bidirectional) [(None, 200, 256), ( 394240      bi_lstm_0[0][0]                  
                                                                 bi_lstm_0[0][1]                  
          

In [12]:
model.compile(optimizer=tf.train.AdamOptimizer(),
              loss='binary_crossentropy',
              metrics=['accuracy'])
 
early_stopping_callback = keras.callbacks.EarlyStopping(monitor='val_loss',
                                                        min_delta=0,
                                                        patience=1,
                                                        verbose=0, mode='auto')



In [13]:
history = model.fit(x_train,
                    y_train,
                    epochs=10,
                    batch_size=200,
                    validation_split=.3, verbose=1, callbacks=[early_stopping_callback])


Train on 17500 samples, validate on 7500 samples
Epoch 1/10


  200/17500 [..............................] - ETA: 11:38 - loss: 0.6932 - acc: 0.4750

  400/17500 [..............................] - ETA: 8:46 - loss: 0.6931 - acc: 0.4875 

  600/17500 [>.............................] - ETA: 7:44 - loss: 0.6931 - acc: 0.4983

  800/17500 [>.............................] - ETA: 7:11 - loss: 0.6931 - acc: 0.5112

 1000/17500 [>.............................] - ETA: 6:53 - loss: 0.6930 - acc: 0.5240

 1200/17500 [=>............................] - ETA: 6:38 - loss: 0.6930 - acc: 0.5233

 1400/17500 [=>............................] - ETA: 6:25 - loss: 0.6931 - acc: 0.5114

 1600/17500 [=>............................] - ETA: 6:14 - loss: 0.6931 - acc: 0.5037

 1800/17500 [==>...........................] - ETA: 6:05 - loss: 0.6931 - acc: 0.5011

 2000/17500 [==>...........................] - ETA: 5:57 - loss: 0.6931 - acc: 0.5040

 2200/17500 [==>...........................] - ETA: 5:49 - loss: 0.6931 - acc: 0.5064

 2400/17500 [===>..........................] - ETA: 5:42 - loss: 0.6930 - acc: 0.5104

 2600/17500 [===>..........................] - ETA: 5:35 - loss: 0.6930 - acc: 0.5127

 2800/17500 [===>..........................] - ETA: 5:29 - loss: 0.6930 - acc: 0.5104

 3000/17500 [====>.........................] - ETA: 5:24 - loss: 0.6929 - acc: 0.5100

 3200/17500 [====>.........................] - ETA: 5:18 - loss: 0.6929 - acc: 0.5112

 3400/17500 [====>.........................] - ETA: 5:12 - loss: 0.6928 - acc: 0.5126

 3600/17500 [=====>........................] - ETA: 5:07 - loss: 0.6929 - acc: 0.5097

 3800/17500 [=====>........................] - ETA: 5:01 - loss: 0.6927 - acc: 0.5142

 4000/17500 [=====>........................] - ETA: 4:56 - loss: 0.6926 - acc: 0.5157









































































































































In [14]:
result = model.evaluate(x_test, y_test)
print(result)


   32/25000 [..............................] - ETA: 10:43

   64/25000 [..............................] - ETA: 8:09 

   96/25000 [..............................] - ETA: 7:28

  128/25000 [..............................] - ETA: 7:00

  160/25000 [..............................] - ETA: 6:47

  192/25000 [..............................] - ETA: 6:32

  224/25000 [..............................] - ETA: 6:20

  256/25000 [..............................] - ETA: 6:09

  288/25000 [..............................] - ETA: 6:05

  320/25000 [..............................] - ETA: 6:02

  352/25000 [..............................] - ETA: 5:58

  384/25000 [..............................] - ETA: 5:52

  416/25000 [..............................] - ETA: 5:50

  448/25000 [..............................] - ETA: 5:51

  480/25000 [..............................] - ETA: 5:48

  512/25000 [..............................] - ETA: 5:46

  544/25000 [..............................] - ETA: 5:44

  576/25000 [..............................] - ETA: 5:43

  608/25000 [..............................] - ETA: 5:41

  640/25000 [..............................] - ETA: 5:41

  672/25000 [..............................] - ETA: 5:39

  704/25000 [..............................] - ETA: 5:37

  736/25000 [..............................] - ETA: 5:35

  768/25000 [..............................] - ETA: 5:33

  800/25000 [..............................] - ETA: 5:33

  832/25000 [..............................] - ETA: 5:32

  864/25000 [>.............................] - ETA: 5:30

  896/25000 [>.............................] - ETA: 5:29

  928/25000 [>.............................] - ETA: 5:28

  960/25000 [>.............................] - ETA: 5:27

  992/25000 [>.............................] - ETA: 5:26

 1024/25000 [>.............................] - ETA: 5:25

 1056/25000 [>.............................] - ETA: 5:24

 1088/25000 [>.............................] - ETA: 5:24

 1120/25000 [>.............................] - ETA: 5:24

 1152/25000 [>.............................] - ETA: 5:31

 1184/25000 [>.............................] - ETA: 5:44

 1216/25000 [>.............................] - ETA: 5:45

 1248/25000 [>.............................] - ETA: 5:47

 1280/25000 [>.............................] - ETA: 5:54

 1312/25000 [>.............................] - ETA: 5:55

 1344/25000 [>.............................] - ETA: 5:57

 1376/25000 [>.............................] - ETA: 6:02

 1408/25000 [>.............................] - ETA: 6:06

 1440/25000 [>.............................] - ETA: 6:09

 1472/25000 [>.............................] - ETA: 6:10

 1504/25000 [>.............................] - ETA: 6:09

 1536/25000 [>.............................] - ETA: 6:09

 1568/25000 [>.............................] - ETA: 6:08

 1600/25000 [>.............................] - ETA: 6:10

 1632/25000 [>.............................] - ETA: 6:17

 1664/25000 [>.............................] - ETA: 6:18

 1696/25000 [=>............................] - ETA: 6:16

 1728/25000 [=>............................] - ETA: 6:14

 1760/25000 [=>............................] - ETA: 6:13

 1792/25000 [=>............................] - ETA: 6:11

 1824/25000 [=>............................] - ETA: 6:11

 1856/25000 [=>............................] - ETA: 6:11

 1888/25000 [=>............................] - ETA: 6:10

 1920/25000 [=>............................] - ETA: 6:09

 1952/25000 [=>............................] - ETA: 6:09

 1984/25000 [=>............................] - ETA: 6:08

 2016/25000 [=>............................] - ETA: 6:07

 2048/25000 [=>............................] - ETA: 6:06

 2080/25000 [=>............................] - ETA: 6:05

 2112/25000 [=>............................] - ETA: 6:05

 2144/25000 [=>............................] - ETA: 6:05

 2176/25000 [=>............................] - ETA: 6:04

 2208/25000 [=>............................] - ETA: 6:04

 2240/25000 [=>............................] - ETA: 6:02

 2272/25000 [=>............................] - ETA: 6:01

 2304/25000 [=>............................] - ETA: 6:02

 2336/25000 [=>............................] - ETA: 6:01

 2368/25000 [=>............................] - ETA: 6:00

 2400/25000 [=>............................] - ETA: 5:59

 2432/25000 [=>............................] - ETA: 5:58

 2464/25000 [=>............................] - ETA: 5:56

 2496/25000 [=>............................] - ETA: 5:55

 2528/25000 [==>...........................] - ETA: 5:54

 2560/25000 [==>...........................] - ETA: 5:53

 2592/25000 [==>...........................] - ETA: 5:52

 2624/25000 [==>...........................] - ETA: 5:51

 2656/25000 [==>...........................] - ETA: 5:50

 2688/25000 [==>...........................] - ETA: 5:49

 2720/25000 [==>...........................] - ETA: 5:50

 2752/25000 [==>...........................] - ETA: 5:50

 2784/25000 [==>...........................] - ETA: 5:49

 2816/25000 [==>...........................] - ETA: 5:51

 2848/25000 [==>...........................] - ETA: 5:51

 2880/25000 [==>...........................] - ETA: 5:50

 2912/25000 [==>...........................] - ETA: 5:49

 2944/25000 [==>...........................] - ETA: 5:47

 2976/25000 [==>...........................] - ETA: 5:46

 3008/25000 [==>...........................] - ETA: 5:45

 3040/25000 [==>...........................] - ETA: 5:44

 3072/25000 [==>...........................] - ETA: 5:43

 3104/25000 [==>...........................] - ETA: 5:42

 3136/25000 [==>...........................] - ETA: 5:42

 3168/25000 [==>...........................] - ETA: 5:41

 3200/25000 [==>...........................] - ETA: 5:40

 3232/25000 [==>...........................] - ETA: 5:40

 3264/25000 [==>...........................] - ETA: 5:39

 3296/25000 [==>...........................] - ETA: 5:40

 3328/25000 [==>...........................] - ETA: 5:40

 3360/25000 [===>..........................] - ETA: 5:41

 3392/25000 [===>..........................] - ETA: 5:40

 3424/25000 [===>..........................] - ETA: 5:39

 3456/25000 [===>..........................] - ETA: 5:39

 3488/25000 [===>..........................] - ETA: 5:38

 3520/25000 [===>..........................] - ETA: 5:37

 3552/25000 [===>..........................] - ETA: 5:36

 3584/25000 [===>..........................] - ETA: 5:36

 3616/25000 [===>..........................] - ETA: 5:35

 3648/25000 [===>..........................] - ETA: 5:34

 3680/25000 [===>..........................] - ETA: 5:33

 3712/25000 [===>..........................] - ETA: 5:32

 3744/25000 [===>..........................] - ETA: 5:31

 3776/25000 [===>..........................] - ETA: 5:30

 3808/25000 [===>..........................] - ETA: 5:29

 3840/25000 [===>..........................] - ETA: 5:29

 3872/25000 [===>..........................] - ETA: 5:28

 3904/25000 [===>..........................] - ETA: 5:27

 3936/25000 [===>..........................] - ETA: 5:26

 3968/25000 [===>..........................] - ETA: 5:26

 4000/25000 [===>..........................] - ETA: 5:25

 4032/25000 [===>..........................] - ETA: 5:24

 4064/25000 [===>..........................] - ETA: 5:22

 4096/25000 [===>..........................] - ETA: 5:22

 4128/25000 [===>..........................] - ETA: 5:21

 4160/25000 [===>..........................] - ETA: 5:20

 4192/25000 [====>.........................] - ETA: 5:19

 4224/25000 [====>.........................] - ETA: 5:18

 4256/25000 [====>.........................] - ETA: 5:17

 4288/25000 [====>.........................] - ETA: 5:16

 4320/25000 [====>.........................] - ETA: 5:15

 4352/25000 [====>.........................] - ETA: 5:14

 4384/25000 [====>.........................] - ETA: 5:13

 4416/25000 [====>.........................] - ETA: 5:12

 4448/25000 [====>.........................] - ETA: 5:12

 4480/25000 [====>.........................] - ETA: 5:11

 4512/25000 [====>.........................] - ETA: 5:10

 4544/25000 [====>.........................] - ETA: 5:09

 4576/25000 [====>.........................] - ETA: 5:08

 4608/25000 [====>.........................] - ETA: 5:08

 4640/25000 [====>.........................] - ETA: 5:07

 4672/25000 [====>.........................] - ETA: 5:07

 4704/25000 [====>.........................] - ETA: 5:06

 4736/25000 [====>.........................] - ETA: 5:05

 4768/25000 [====>.........................] - ETA: 5:04

 4800/25000 [====>.........................] - ETA: 5:03

 4832/25000 [====>.........................] - ETA: 5:03

 4864/25000 [====>.........................] - ETA: 5:02

 4896/25000 [====>.........................] - ETA: 5:01

 4928/25000 [====>.........................] - ETA: 5:00

 4960/25000 [====>.........................] - ETA: 4:59

 4992/25000 [====>.........................] - ETA: 4:59

 5024/25000 [=====>........................] - ETA: 4:58

 5056/25000 [=====>........................] - ETA: 4:57

 5088/25000 [=====>........................] - ETA: 4:57

 5120/25000 [=====>........................] - ETA: 4:56

 5152/25000 [=====>........................] - ETA: 4:56

 5184/25000 [=====>........................] - ETA: 4:55

 5216/25000 [=====>........................] - ETA: 4:54

 5248/25000 [=====>........................] - ETA: 4:54

 5280/25000 [=====>........................] - ETA: 4:53

 5312/25000 [=====>........................] - ETA: 4:52

 5344/25000 [=====>........................] - ETA: 4:52

 5376/25000 [=====>........................] - ETA: 4:51

 5408/25000 [=====>........................] - ETA: 4:50

 5440/25000 [=====>........................] - ETA: 4:50

 5472/25000 [=====>........................] - ETA: 4:49

 5504/25000 [=====>........................] - ETA: 4:48

 5536/25000 [=====>........................] - ETA: 4:48

 5568/25000 [=====>........................] - ETA: 4:47

 5600/25000 [=====>........................] - ETA: 4:46

 5632/25000 [=====>........................] - ETA: 4:46

 5664/25000 [=====>........................] - ETA: 4:45

 5696/25000 [=====>........................] - ETA: 4:44

 5728/25000 [=====>........................] - ETA: 4:43

 5760/25000 [=====>........................] - ETA: 4:43

 5792/25000 [=====>........................] - ETA: 4:42

 5824/25000 [=====>........................] - ETA: 4:41

















































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































[nan, 0.5]


In [28]:
model.predict(x_test[1].reshape(1, -1))

array([[nan]], dtype=float32)