In [1]:
import os
import glob
import pandas as pd
import codecs
import copy
import logging
import math
import os
import pickle
import random
import tempfile
from typing import List, Tuple, Union
import numpy as np

In [2]:
import keras
from keras.utils import to_categorical

Using TensorFlow backend.


In [3]:
import tensorflow as tf
import tensorflow_hub as hub

In [4]:
NAMED_ENTITIES = ['O', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC']
EMBEDDING_SIZE = 1024

In [5]:
logging.basicConfig(level=logging.INFO)
factrueval_logger = logging.getLogger('elmo_lstm')

In [6]:
def load_document(tokens_file_name: str, spans_file_name: str,
                  objects_file_name: str) -> Tuple[List[List[Tuple[str, int, int]]], List[List[int]]]:
    texts = []
    new_text = []
    tokens_dict = dict()
    with codecs.open(tokens_file_name, mode='r', encoding='utf-8', errors='ignore') as fp:
        line_idx = 1
        cur_line = fp.readline()
        while len(cur_line) > 0:
            err_msg = 'File `{0}`: line {1} is wrong!'.format(tokens_file_name, line_idx)
            prep_line = cur_line.strip()
            if len(prep_line) > 0:
                token_description = prep_line.split()
                if len(token_description) != 4:
                    raise ValueError(err_msg)
                if (not token_description[0].isdigit()) or (not token_description[1].isdigit()) or \
                        (not token_description[2].isdigit()):
                    raise ValueError(err_msg)
                token_id = int(token_description[0])
                if token_id in tokens_dict:
                    raise ValueError(err_msg)
                new_text.append(token_id)
                token_start = int(token_description[1])
                token_length = int(token_description[2])
                tokens_dict[token_id] = (token_description[-1], 'O', token_start, token_length,
                                         (len(texts), len(new_text) - 1))
            else:
                if len(new_text) == 0:
                    raise ValueError(err_msg)
                texts.append(copy.copy(new_text))
                new_text.clear()
            cur_line = fp.readline()
            line_idx += 1
    if len(new_text) > 0:
        texts.append(copy.copy(new_text))
        new_text.clear()
    spans_dict = dict()
    with codecs.open(spans_file_name, mode='r', encoding='utf-8', errors='ignore') as fp:
        line_idx = 1
        cur_line = fp.readline()
        while len(cur_line) > 0:
            err_msg = 'File `{0}`: line {1} is wrong!'.format(spans_file_name, line_idx)
            prep_line = cur_line.strip()
            if len(prep_line) > 0:
                comment_pos = prep_line.find('#')
                if comment_pos < 0:
                    raise ValueError(err_msg)
                prep_line = prep_line[:comment_pos].strip()
                if len(prep_line) == 0:
                    raise ValueError(err_msg)
                span_description = prep_line.split()
                if len(span_description) != 6:
                    raise ValueError(err_msg)
                if (not span_description[0].isdigit()) or (not span_description[-1].isdigit()) or \
                        (not span_description[-2].isdigit()):
                    raise ValueError(err_msg)
                span_id = int(span_description[0])
                token_IDs = list()
                start_token_id = int(span_description[-2])
                n_tokens = int(span_description[-1])
                if (n_tokens <= 0) or (start_token_id not in tokens_dict):
                    raise ValueError(err_msg)
                text_idx = tokens_dict[start_token_id][4][0]
                token_pos_in_text = tokens_dict[start_token_id][4][1]
                for idx in range(n_tokens):
                    token_id = texts[text_idx][token_pos_in_text + idx]
                    if token_id not in tokens_dict:
                        raise ValueError(err_msg)
                    token_IDs.append(token_id)
                if span_id not in spans_dict:
                    spans_dict[span_id] = tuple(token_IDs)
            cur_line = fp.readline()
            line_idx += 1
    with codecs.open(objects_file_name, mode='r', encoding='utf-8', errors='ignore') as fp:
        line_idx = 1
        cur_line = fp.readline()
        while len(cur_line) > 0:
            err_msg = 'File `{0}`: line {1} is wrong!'.format(objects_file_name, line_idx)
            prep_line = cur_line.strip()
            if len(prep_line) > 0:
                comment_pos = prep_line.find('#')
                if comment_pos < 0:
                    raise ValueError(err_msg)
                prep_line = prep_line[:comment_pos].strip()
                if len(prep_line) == 0:
                    raise ValueError(err_msg)
                object_description = prep_line.split()
                if len(object_description) < 3:
                    raise ValueError(err_msg)
                if object_description[1] not in {'LocOrg', 'Org', 'Person', 'Location'}:
                    factrueval_logger.warning(err_msg + ' The entity `{0}` is unknown.'.format(object_description[1]))
                else:
                    span_IDs = []
                    for idx in range(2, len(object_description)):
                        if not object_description[idx].isdigit():
                            raise ValueError(err_msg)
                        span_id = int(object_description[idx])
                        if span_id not in spans_dict:
                            raise ValueError(err_msg)
                        span_IDs.append(span_id)
                    span_IDs.sort(key=lambda span_id: tokens_dict[spans_dict[span_id][0]][2])
                    token_IDs = []
                    for span_id in span_IDs:
                        start_token_id = spans_dict[span_id][0]
                        end_token_id = spans_dict[span_id][-1]
                        text_idx = tokens_dict[start_token_id][4][0]
                        token_pos_in_text = tokens_dict[start_token_id][4][1]
                        while token_pos_in_text < len(texts[text_idx]):
                            token_id = texts[text_idx][token_pos_in_text]
                            token_IDs.append(token_id)
                            if token_id == end_token_id:
                                break
                            token_pos_in_text += 1
                        if token_pos_in_text >= len(texts[text_idx]):
                            raise ValueError(err_msg)
                    if object_description[1] in {'LocOrg', 'Location'}:
                        class_label = 'LOC'
                    elif object_description[1] == 'Person':
                        class_label = 'PER'
                    else:
                        class_label = 'ORG'
                    tokens_are_used = False
                    if tokens_dict[token_IDs[0]][1] != 'O':
                        tokens_are_used = True
                    else:
                        for token_id in token_IDs[1:]:
                            if tokens_dict[token_id][1] != 'O':
                                tokens_are_used = True
                                break
                    if not tokens_are_used:
                        tokens_dict[token_IDs[0]] = (
                            tokens_dict[token_IDs[0]][0], 'B-' + class_label,
                            tokens_dict[token_IDs[0]][2], tokens_dict[token_IDs[0]][3],
                            tokens_dict[token_IDs[0]][4]
                        )
                        for token_id in token_IDs[1:]:
                            tokens_dict[token_id] = (
                                tokens_dict[token_id][0], 'I-' + class_label,
                                tokens_dict[token_id][2], tokens_dict[token_id][3],
                                tokens_dict[token_id][4]
                            )
            cur_line = fp.readline()
            line_idx += 1
    list_of_texts = []
    list_of_labels = []
    for tokens_sequence in texts:
        
        new_text = []
        new_labels_sequence = []
        for token_id in tokens_sequence:
            new_text.append((tokens_dict[token_id][0], tokens_dict[token_id][2], tokens_dict[token_id][3]))
            new_labels_sequence.append(NAMED_ENTITIES.index(tokens_dict[token_id][1]))
        list_of_texts.append(new_text)
        list_of_labels.append(new_labels_sequence)
    return list_of_texts, list_of_labels


def load_data_for_training(data_dir_name: str) -> Tuple[List[List[str]], List[List[int]]]:
    names_of_files = sorted(list(filter(lambda it: it.startswith('book_'), os.listdir(data_dir_name))))
    if len(names_of_files) == 0:
        raise ValueError('The directory `{0}` is empty!'.format(data_dir_name))
    if (len(names_of_files) % 6) != 0:
        raise ValueError('The directory `{0}` contains wrong data!'.format(data_dir_name))
    list_of_all_texts = []
    list_of_all_labels = []
    for idx in range(len(names_of_files) // 6):
        base_name = names_of_files[idx * 6]
        point_pos = base_name.rfind('.')
        if point_pos <= 0:
            raise ValueError('The file `{0}` has incorrect name.'.format(base_name))
        prepared_base_name = base_name[:point_pos].strip()
        if len(prepared_base_name) == 0:
            raise ValueError('The file `{0}` has incorrect name.'.format(base_name))
        tokens_file_name = os.path.join(data_dir_name, prepared_base_name + '.tokens')
        if not os.path.isfile(tokens_file_name):
            raise ValueError('The file `{0}` does not exist!'.format(tokens_file_name))
        spans_file_name = os.path.join(data_dir_name, prepared_base_name + '.spans')
        if not os.path.isfile(spans_file_name):
            raise ValueError('The file `{0}` does not exist!'.format(spans_file_name))
        objects_file_name = os.path.join(data_dir_name, prepared_base_name + '.objects')
        if not os.path.isfile(objects_file_name):
            raise ValueError('The file `{0}` does not exist!'.format(objects_file_name))
        list_of_texts, list_of_labels = load_document(tokens_file_name, spans_file_name, objects_file_name)
        list_of_all_texts += [list(map(lambda it: it[0], cur_text)) for cur_text in list_of_texts]
        list_of_all_labels += list_of_labels
    return list_of_all_texts, list_of_all_labels


def load_data_for_testing(data_dir_name: str) -> Tuple[List[List[str]], List[Tuple[str, List[List[Tuple[int, int]]]]]]:
    names_of_files = sorted(list(filter(lambda it: it.startswith('book_'), os.listdir(data_dir_name))))
    if len(names_of_files) == 0:
        raise ValueError('The directory `{0}` is empty!'.format(data_dir_name))
    if (len(names_of_files) % 6) != 0:
        raise ValueError('The directory `{0}` contains wrong data!'.format(data_dir_name))
    list_of_all_texts = []
    list_of_all_token_bounds = []
    for idx in range(len(names_of_files) // 6):
        base_name = names_of_files[idx * 6]
        point_pos = base_name.rfind('.')
        if point_pos <= 0:
            raise ValueError('The file `{0}` has incorrect name.'.format(base_name))
        prepared_base_name = base_name[:point_pos].strip()
        if len(prepared_base_name) == 0:
            raise ValueError('The file `{0}` has incorrect name.'.format(base_name))
        tokens_file_name = os.path.join(data_dir_name, prepared_base_name + '.tokens')
        if not os.path.isfile(tokens_file_name):
            raise ValueError('The file `{0}` does not exist!'.format(tokens_file_name))
        spans_file_name = os.path.join(data_dir_name, prepared_base_name + '.spans')
        if not os.path.isfile(spans_file_name):
            raise ValueError('The file `{0}` does not exist!'.format(spans_file_name))
        objects_file_name = os.path.join(data_dir_name, prepared_base_name + '.objects')
        if not os.path.isfile(objects_file_name):
            raise ValueError('The file `{0}` does not exist!'.format(objects_file_name))
        list_of_texts, list_of_labels = load_document(tokens_file_name, spans_file_name, objects_file_name)
        list_of_all_texts += [list(map(lambda it: it[0], cur_text)) for cur_text in list_of_texts]
        list_of_all_token_bounds.append(
            (
                base_name,
                [list(map(lambda it: (it[1], it[2]), cur_text)) for cur_text in list_of_texts]
            )
        )
    return list_of_all_texts, list_of_all_token_bounds


In [7]:
def get_maximum_length_of_data(train_texts, test_texts):
    train_maximum_length = len(max(train_texts,key=len))
    test_maximum_length = len(max(test_texts,key=len))
    return max(train_maximum_length, test_maximum_length)
    

In [8]:
def cast_data_to_the_shape(texts, labels, maximum_length):
    cast_texts = [x[:min(len(x), maximum_length)] + ['']*max((maximum_length - len(x)),0) for x in texts]
    cast_labels = [to_categorical(x[:min(len(x), maximum_length)] + [0]*max((maximum_length - len(x)),0), num_classes=len(NAMED_ENTITIES)) for x in labels]
    return cast_texts, cast_labels

In [10]:
def get_embeddings_from_texts(texts, max_length, embedding_size, batch_size):
    embeddings = np.zeros((len(texts), max_length, embedding_size), dtype=np.float32)
    steps_per_batch = int(math.ceil(len(texts)/float(batch_size)))                         
    with tf.Graph().as_default():
        elmo = hub.Module(ELMO_URL, trainable=True)
        with tf.Session() as sess:
            token_sentences = tf.placeholder(shape=(None, None), dtype=tf.string, name='tokens')
            token_sentences_lenghts = tf.placeholder(shape=(None,), dtype=tf.int32, name='tokens_length')
            embeddings_of_texts = elmo(
                inputs={
                    'tokens': token_sentences,
                    'sequence_len': token_sentences_lenghts
                },
                signature='tokens',
                as_dict=True)["elmo"]
            sess.run(tf.global_variables_initializer())
            sess.run(tf.tables_initializer())
            for current_position in range(steps_per_batch, len(texts), steps_per_batch):
                previous_position = current_position - steps_per_batch
                current_texts = texts[previous_position:current_position]
                current_embeddings = sess.run(
                    embeddings_of_texts,
                    feed_dict={
                        token_sentences: current_texts,
                        token_sentences_lenghts: [len(np.where(np.array(text) != '')[0]) for text in current_texts]
                    })
                embeddings[previous_position:current_position] = current_embeddings
            
    return embeddings

<h2>Transform data</h2>

In [10]:
train_data, train_labels = load_data_for_training('data/devset')
test_data, test_labels = load_data_for_training('data/testset')



In [11]:
MAX_LENGTH = get_maximum_length_of_data(train_data, test_data)

In [12]:
cast_train_data, cast_train_labels = cast_data_to_the_shape(train_data, train_labels, MAX_LENGTH)

In [13]:
int(len(cast_train_data)/4)

442

<h2>Get embeddings</h2>

In [14]:
ELMO_URL = "http://files.deeppavlov.ai/deeppavlov_data/elmo_ru-news_wmt11-16_1.5M_steps.tar.gz"

In [15]:
train_embeddings = get_embeddings_from_texts(cast_train_data, MAX_LENGTH, EMBEDDING_SIZE, 4)

INFO:tensorflow:Using C:\Users\Goodman\AppData\Local\Temp\tfhub_modules to cache modules.


INFO:tensorflow:Using C:\Users\Goodman\AppData\Local\Temp\tfhub_modules to cache modules.


INFO:tensorflow:Saver not created because there are no variables in the graph to restore


INFO:tensorflow:Saver not created because there are no variables in the graph to restore


In [16]:
with open('train_embeddings.pickle', 'wb') as f:
    pickle.dump(train_embeddings, f)

<h2>Model creating and fitting</h2>

In [20]:
from keras.layers import LSTM, Bidirectional, BatchNormalization, TimeDistributed
from keras.models import Sequential
from keras.layers import Dense

In [21]:
def named_entity_recognition_nn(input_length):
    model = Sequential()
    model.add(Bidirectional(LSTM(50, activation="sigmoid", return_sequences = True, dropout = 0.2), input_shape=(MAX_LENGTH, EMBEDDING_SIZE)))
    model.add(BatchNormalization())
    model.add(TimeDistributed(Dense(len(NAMED_ENTITIES), activation='softmax')))

    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['categorical_accuracy'])
    return model

In [112]:
model = named_entity_recognition_nn(len(train_data))

In [113]:
model.fit(train_embeddings, np.array(cast_train_labels),  epochs=10, batch_size=4)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x1e353779198>

<h2>Prediction</h2>

In [114]:
cast_test_data, cast_test_labels = cast_data_to_the_shape(test_data, test_labels, MAX_LENGTH)

In [118]:
test_embeddings = get_embeddings_from_texts(cast_test_data, MAX_LENGTH, EMBEDDING_SIZE, 4)

INFO:tensorflow:Saver not created because there are no variables in the graph to restore


INFO:tensorflow:Saver not created because there are no variables in the graph to restore


In [119]:
predicted_test_probs = model.predict(test_embeddings)

<h2>F1 score</h2>

In [120]:
from sklearn.metrics import f1_score

In [121]:
predicted_test_labels = np.array([[np.argmax(word) for word in sentence] for sentence in predicted_test_probs])

In [122]:
true_test_labels = np.array([[np.argmax(word) for word in sentence] for sentence in cast_test_labels])

In [123]:
f1_score(np.hstack(true_test_labels), np.hstack(predicted_test_labels),
         average='macro')

0.8208276831999892