In [None]:
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split
from tensorflow.keras import backend
from tensorflow.keras.preprocessing.text import Tokenizer 
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import Input, Dense, Embedding, LSTM, GRU, TimeDistributed, Attention, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras import optimizers

import matplotlib.pyplot as plt
import seaborn as sns
import os
import warnings
import time

# Notebook-wide display setup: use matplotlib's stock style and suppress
# Python warnings so cell output is not flooded by library notices.
plt.style.use("default")
warnings.filterwarnings(action="ignore")

In [None]:
# Reset Keras' global state so that re-running the notebook in the same kernel
# does not accumulate stale layers/graphs from a previous run.
backend.clear_session()

In [None]:
# Load the cleaned corpus and drop exact duplicate rows in one chain; the index
# is rebuilt so row positions stay contiguous after the drop.
df = (
    pd.read_csv("data_clean.csv")
    .drop_duplicates()
    .reset_index(drop=True)
)

In [None]:
def _word_counts(texts):
    """Return an array with the number of whitespace-separated words in each text."""
    return np.array([len(text.split()) for text in texts])


# Per-row word counts for the articles and their summaries.
len_c = _word_counts(df['Content'])
len_s = _word_counts(df['Summary'])

In [None]:
# Length caps, counted in whitespace-separated words. Rows exceeding either cap
# are filtered out of `df` below, and the same values are reused as the padded
# sequence lengths when the texts are tokenized.
max_len_content = 400
max_len_summary = 25

In [None]:
# Keep only rows whose article and summary both fit within the length caps.
# A boolean mask is used instead of positional indices: if this cell is re-run
# after `df` has already been filtered, the stale `len_c` / `len_s` arrays no
# longer line up with `df`, and positional indexing would silently pick the
# wrong rows. The check below turns that situation into an explicit error.
within_limits = (len_c <= max_len_content) & (len_s <= max_len_summary)
assert len(within_limits) == len(df), (
    "len_c/len_s are out of sync with df; re-run the word-count cell before filtering"
)
df = df.loc[within_limits].reset_index(drop=True)

In [None]:
x_train.shape, x_test.shape

In [None]:
x_train, x_test, y_train, y_test = train_test_split(df['Content'].values, 
                                                    df['Summary_clean'].values, 
                                                    test_size=0.1,
                                                    random_state=767, 
                                                    shuffle=True)

x_train, x_val, y_train, y_val = train_test_split(x_train, 
                                                  y_train, 
                                                  test_size=0.1, 
                                                  random_state=767, 
                                                  shuffle=True)

In [None]:
st = time.time()

# The article vocabulary is learned from the training split only, so the
# validation and test texts never influence the word index.
tokenizer_content = Tokenizer()
tokenizer_content.fit_on_texts(x_train)


def _encode_content(texts):
    """Convert raw article texts to integer ids, zero-padded at the end to max_len_content."""
    sequences = tokenizer_content.texts_to_sequences(texts)
    return pad_sequences(sequences, maxlen=max_len_content, padding='post')


# Note: the raw-text arrays are replaced by their encoded versions here, so this
# cell is meant to run exactly once after the split.
x_train, x_val, x_test = (_encode_content(split) for split in (x_train, x_val, x_test))

et = time.time()
elapsed = et - st
hours = int(elapsed / 3600)
minutes = int((elapsed % 3600) / 60)
seconds = (elapsed % 3600) % 60
print("Time taken: {:d} h {:d} min {:.2f} s".format(hours, minutes, seconds))

In [None]:
st = time.time()

# The summary vocabulary is likewise learned from the training split only.
tokenizer_summary = Tokenizer()
tokenizer_summary.fit_on_texts(y_train)


def _encode_summary(texts):
    """Convert raw summary texts to integer ids, zero-padded at the end to max_len_summary."""
    # NOTE(review): pad_sequences truncates from the front by default, so any
    # summary longer than max_len_summary tokens would lose its leading tokens.
    # Confirm that no tokenized summary exceeds the cap.
    sequences = tokenizer_summary.texts_to_sequences(texts)
    return pad_sequences(sequences, maxlen=max_len_summary, padding='post')


# Note: the raw-text arrays are replaced by their encoded versions here, so this
# cell is meant to run exactly once after the split.
y_train, y_val, y_test = (_encode_summary(split) for split in (y_train, y_val, y_test))

et = time.time()
elapsed = et - st
hours = int(elapsed / 3600)
minutes = int((elapsed % 3600) / 60)
seconds = (elapsed % 3600) % 60
print("Time taken: {:d} h {:d} min {:.2f} s".format(hours, minutes, seconds))

In [None]:
# Vocabulary sizes used for the embedding and output layers. The Keras
# Tokenizer numbers words starting from 1 and the padded sequences are filled
# with 0, so one extra slot is added for the padding index.
x_voc = len(tokenizer_content.word_index) + 1
y_voc = len(tokenizer_summary.word_index) + 1

## Encoder

In [None]:
# Layer widths shared by the encoder and the decoder.
lstm_units = 800
embedding_units = 500

# Every encoder LSTM emits its full output sequence plus its final states.
_encoder_lstm_opts = dict(return_sequences=True, return_state=True)

# Fixed-length sequence of article token ids, mapped to dense vectors.
encoder_input = Input(shape=(max_len_content,))
encoder_embedding = Embedding(
    x_voc, embedding_units, trainable=True, name="encoder_emb"
)(encoder_input)

# Three stacked LSTMs: each one consumes the full output sequence of the
# previous one. Only the last layer's outputs and states are used downstream.
encoder_lstm1 = LSTM(lstm_units, name="encoder_lstm1", **_encoder_lstm_opts)
encoder_layer1, state_a1, state_c1 = encoder_lstm1(encoder_embedding)

encoder_lstm2 = LSTM(lstm_units, name="encoder_lstm2", **_encoder_lstm_opts)
encoder_layer2, state_a2, state_c2 = encoder_lstm2(encoder_layer1)

encoder_lstm3 = LSTM(lstm_units, name="encoder_lstm3", **_encoder_lstm_opts)
encoder_layer_last, state_a_last, state_c_last = encoder_lstm3(encoder_layer2)

## Decoder

In [None]:
# Variable-length sequence of summary token ids (teacher-forced during training).
decoder_input = Input(shape=(None,))

# The embedding is kept as a standalone layer object so the inference graph
# can reuse the same weights.
decoder_embedding = Embedding(y_voc, embedding_units, trainable=True, name="decoder_emb")
decoder_emb_layer = decoder_embedding(decoder_input)

# The decoder LSTM starts from the final states of the last encoder LSTM.
# Despite their names, decoder_state_f / decoder_state_b are the two state
# tensors returned because of return_state=True, not forward/backward states.
decoder_lstm = LSTM(lstm_units, return_sequences=True, return_state=True, name="decoder_lstm")
decoder_layer, decoder_state_f, decoder_state_b = decoder_lstm(decoder_emb_layer, initial_state=[state_a_last, state_c_last])

# Attention with the decoder outputs as the query and the encoder outputs as
# the value, giving one context vector per decoder timestep.
attention = Attention()
attention_layer = attention([decoder_layer, encoder_layer_last])

# Join each decoder output with its attention context along the feature axis.
decoder_concat = Concatenate(axis=-1)([decoder_layer, attention_layer])

# Per-timestep softmax over the summary vocabulary.
decoder_dense = TimeDistributed(Dense(y_voc, activation="softmax"))
decoder_output = decoder_dense(decoder_concat)

In [None]:
# Training graph: article tokens and decoder input tokens go in, a
# per-timestep distribution over the summary vocabulary comes out.
model = Model(
    inputs=[encoder_input, decoder_input],
    outputs=[decoder_output],
)

# Targets are integer token ids, hence the sparse form of cross-entropy.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
)

# Stop training once validation loss has gone 5 epochs without improving.
early_stopping = EarlyStopping(
    patience=5,
    mode='min',
    monitor='val_loss',
)

In [None]:
# Print the layer-by-layer architecture and parameter counts as a sanity check.
model.summary()

In [None]:
weights_file = "model_weights.h5"

if weights_file in os.listdir():
    # Reuse existing weights from the working directory instead of retraining.
    # NOTE(review): nothing in the visible cells writes this file, so this
    # branch only triggers if the weights were saved elsewhere -- confirm.
    model.load_weights(weights_file)

else:
    # Teacher forcing: the decoder is fed every token except the last one and
    # is trained to predict the same sequence shifted one step to the left.
    train_inputs = [x_train, y_train[:, :-1]]
    train_targets = y_train[:, 1:, np.newaxis]

    val_inputs = [x_val, y_val[:, :-1]]
    val_targets = y_val[:, 1:, np.newaxis]

    history = model.fit(
        x=train_inputs,
        y=train_targets,
        validation_data=(val_inputs, val_targets),
        epochs=20,
        callbacks=[early_stopping],
    )

## Inference

In [None]:
# Encoder used at inference time: returns the last encoder LSTM's full output
# sequence (needed for attention) together with its two final states.
encoder_model = Model(inputs=[encoder_input], outputs=[encoder_layer_last, state_a_last, state_c_last])

# Placeholders fed on every decoding step. Despite its name,
# inference_decoder_input carries the encoder's per-timestep outputs; the other
# two inputs carry the decoder LSTM states from the previous step.
inference_decoder_input = Input(shape=(max_len_content, lstm_units))
decoder_input_a = Input(shape=(lstm_units,))
decoder_input_c = Input(shape=(lstm_units,))

# Reuse the trained embedding / LSTM / attention / dense layer objects so the
# inference graph shares its weights with the training graph.
inference_decoder_emb = decoder_embedding(decoder_input)

inference_decoder_layer, inf_state_a, inf_state_c = decoder_lstm(inference_decoder_emb, 
                                                         initial_state=[decoder_input_a, decoder_input_c])

# Query = decoder output for the current step, value = encoder outputs.
inference_attention = attention([inference_decoder_layer, inference_decoder_input])

inference_concat = Concatenate()([inference_decoder_layer, inference_attention])

inference_decoder_output = decoder_dense(inference_concat)

In [None]:
# Single-step decoder for inference. It takes the previous token, the encoder
# outputs (for attention) and the current LSTM states, and returns the
# next-token distribution together with the updated states.
inference_model = Model(
    inputs=[decoder_input, inference_decoder_input, decoder_input_a, decoder_input_c],
    outputs=[inference_decoder_output, inf_state_a, inf_state_c],
)

In [None]:
# Optionally restore separately stored decoder weights when the file is present
# in the working directory; otherwise the layers keep their current weights.
inference_weights_file = "inference_model_weights.h5"
if inference_weights_file in os.listdir():
    inference_model.load_weights(inference_weights_file)

In [None]:
# Lookup tables in both directions (id -> word and word -> id) for each
# tokenizer. The summary tables are what predict_summary uses to seed the
# decoder and to turn predicted ids back into words.
index_word_content = tokenizer_content.index_word
index_word_summary = tokenizer_summary.index_word
word_index_content = tokenizer_content.word_index
word_index_summary = tokenizer_summary.word_index

## Make Predictions

In [None]:
def predict_summary(input_tokens):
    """Greedily decode a summary for one tokenized, padded article.

    Parameters
    ----------
    input_tokens : array of shape (1, max_len_content)
        Token ids of a single article, as produced by the content tokenizer
        and padding step.

    Returns
    -------
    str
        The generated summary, always wrapped as "start ... end". Decoding
        stops when the model emits the 'end' token, when it emits the padding
        index 0, or when the sentence reaches max_len_summary words
        (counting the leading "start").
    """
    encoder_output, state_a, state_c = encoder_model.predict(input_tokens, verbose=0)

    # Decoding is seeded with the 'start' token; states come from the encoder.
    token_idx = word_index_summary['start']
    words = ["start"]

    while True:
        decoder_input_token = np.array([[token_idx]])

        # verbose=0 keeps Keras from printing a progress bar on every step.
        inference_output, state_a, state_c = inference_model.predict(
            [decoder_input_token, encoder_output, state_a, state_c], verbose=0
        )

        # Most probable token at the (only) decoded timestep.
        token_idx = int(np.argmax(inference_output[0, -1, :]))

        # Index 0 is the padding class and has no entry in index_word. Because
        # padding follows 'end' in the padded targets, it is treated as
        # end-of-sequence instead of raising KeyError.
        pred_word = index_word_summary.get(token_idx, "end")

        if pred_word != "end":
            words.append(pred_word)

        if pred_word == "end" or len(words) >= max_len_summary:
            break

    words.append("end")
    return " ".join(words)

In [None]:
# Compare the generated summary with the reference for one test example.
i = 36

sample_tokens = x_test[i].reshape(1, -1)  # add the batch dimension
print(predict_summary(sample_tokens))

reference_summary = tokenizer_summary.sequences_to_texts([y_test[i]])
print(reference_summary)