In [1]:
from os import listdir
from os.path import isfile, join

test_path = "../resources/testset"
dev_path = "../resources/devset"

test_files = [join(test_path,f) for f in listdir(test_path) if isfile(join(test_path, f))]
dev_files = [join(dev_path,f) for f in listdir(dev_path) if isfile(join(dev_path, f))]

In [2]:
class Document:
    def __init__(self, id, tokens, entities, sentences):
        self.id = id #book id
        self.tokens = tokens
        self.entities = entities
        self.sentences = sentences
    def __repr__(self):
        return "Tokens:{}\nEntities:{}\nSentences:{}\n".format(self.tokens, self.entities, self.sentences)

class Entity:
    def __init__(self, id, ent_type, start_token, end_token, symbol_start_idx, symbol_len):
        self.id = id #id in objects file
        self.ent_type = ent_type#location or locorg
        self.start_token = start_token
        self.end_token = end_token
        self.symbol_start_idx = symbol_start_idx
        self.symbol_len = symbol_len
        
    def __repr__(self):
        return " id: " + str(self.id) + " type: " + self.ent_type + " start: " + str(self.start_token) + " end: " + str(
            self.end_token) + " symbol_start: {} symbol_len: {}".format(self.symbol_start_idx, self.symbol_len)

class Sentence:
    def __init__(self, start_token, end_token):
        self.start_token = start_token
        self.end_token = end_token
    def __repr__(self):
        return " start: " + str(self.start_token) + " end: " + str(
            self.end_token)

In [3]:
import re

def generate_doc(doc_id, tokens_file, span_file, objects_file):
    tokens = []
    token_ids = []
    sentences = []
    entities = []
    spans = []
    
    current_start = 0
    current_id = 0
    for line in tokens_file.strip().split('\n'):
        if line == "":
            sentences.append(Sentence(current_start,current_id))
            current_start = current_id
            continue
        
        lst = line.strip().split(' ')
        token = lst[-1]
        token_id = int(lst[0])
        
        tokens.append(token)
        token_ids.append(token_id)
        current_id += 1
    
    sentences.append(Sentence(current_start, len(tokens)))
    
    for line in span_file.strip().split('\n'):
        
        span_lst = line.strip().split(' ')
        span_token_id = int(span_lst[4])
        entity_len = int(span_lst[5])
        entity_type = span_lst[1]
        entity_symbol_start_idx = int(span_lst[2])
        entity_symbol_len = int(span_lst[3])
        span_id = int(span_lst[0])
        
        for i, token_id in enumerate(token_ids):
            if token_id == span_token_id:
                entity_start = i
                break
        
        spans.append((span_id, entity_start, entity_len, entity_symbol_start_idx, entity_symbol_len))
    
    for line in objects_file.strip().split('\n'):
            object_lst = line.strip().split(' ')
            object_id = int(object_lst[0])
            object_type = object_lst[1]
            
            span_ids = []
            for i in range(2, len(object_lst)):
                sp_id = object_lst[i]
                if sp_id == '#':
                    break
                span_ids.append(int(sp_id))
                
            if object_type == "Location":
                object_type = 'loc'
            elif object_type == "LocOrg":
                object_type = 'locorg'
            elif object_type == "Person":
                object_type = "per"
            elif object_type == "Org":
                object_type = "org"
            elif object_type == "Project":
                object_type = "org"
                
                            
            entity_spans = []
            
            for span_id in span_ids:
                for span in spans:
                    if span[0] == span_id:
                        entity_spans.append(span)
                        break
              
            max_token_idx = None
            min_token_idx = None
            max_symbol_idx = None
            min_symbol_idx = None
            
#             start_sorted_spans = sorted(entity_spans, key = lambda x: x[1])
#             end_sorted_spans = sorted(entity_spans, key = lambda x: x[2])
            
#             min_token_idx = start_sorted_spans[0][1]
#             max_token_idx = end_sorted_spans[-1][2]
#             max_symbol_idx = end_sorted_spans[-1][4]
#             min_symbol_idx = start_sorted_spans[0][3]
            
            
            for span in entity_spans:
                token_start = span[1]
                token_end = span[2] + token_start
                symbol_start = span[3]
                symbol_end = span[4] + symbol_start
                
                if min_symbol_idx is None or min_symbol_idx > symbol_start:
                    min_symbol_idx = symbol_start
                    
                if max_symbol_idx is None or max_symbol_idx < symbol_end:
                    max_symbol_idx = symbol_end
                    
                if min_token_idx is None or min_token_idx > token_start:
                    min_token_idx = token_start
                
                if max_token_idx is None or max_token_idx < token_end:
                    max_token_idx = token_end
            
            entities.append(Entity(object_id, object_type, min_token_idx, 
                                   max_token_idx, min_symbol_idx, max_symbol_idx-min_symbol_idx))
    return Document(doc_id, tokens, entities, sentences)   

def read_docs(files):
    re_pattern = "\d+"
    p = re.compile(re_pattern)
    
    files_with_id = dict()
    for f in files:
        search = p.search(f)
        if search is None:
            continue
        file_id = int(search.group())
        lst = files_with_id.setdefault(file_id, list())
        lst.append(f)
    
    docs = []
    for id, lst in files_with_id.items():
        tokens = []
        entities = []
        
        span_file = None
        tokens_file = None
        objects_file = None
        
        for file in lst:
            with open(file, "r") as open_file:
                file_content = open_file.read()
                                      
            if file[-5:] == "spans":
                span_file = file_content
            elif file[-7:] == "objects":
                objects_file = file_content
            elif file[-6:] == "tokens":
                tokens_file = file_content
                
        docs.append(generate_doc(id, tokens_file, span_file, objects_file))
    
    return docs             
                     

In [4]:
import os

def write_docs(docs, path):
    for doc in docs:
        doc_id = doc.id
        object_file_path = os.path.join(path, "book_"+str(doc_id)+".task1")
        
        lines_to_write = []
            
        for entity in doc.entities:
            str_format = '{} {} {}\n'.format(entity.ent_type, entity.symbol_start_idx,entity.symbol_len+2)
            lines_to_write.append(str_format)
        
        with open(object_file_path, "w") as f:
            f.writelines(lines_to_write)

# docs = read_docs(test_files)

# print(docs[0].id)

# write_docs(docs, "../resources/out")

In [5]:
embedding_model = {}

with open('../resources/word_vec.txt', "r") as f:
    for line in f:
        split = line.strip().split(' ')
        embedding_model[split[0]] = [float(num) for num in split[1:]]

word_emb_size = len(embedding_model['цикл'])
embedding_model["\""] = [0]*word_emb_size

In [75]:
import pymorphy2

def find_entity_sentence_idx(doc, entity):
    for sent in doc.sentences:
        if sent.start_token <= entity.start_token< sent.end_token:
            return sent

def get_samples(docs, word_window_size, max_len, pos_dict=None):   
    morph = pymorphy2.MorphAnalyzer()
    POS_tags = set()
    
    for doc in docs:
        doc.pos_tags = []
        for token in doc.tokens:
            pos_tag = morph.parse(token)[0].tag.POS
            doc.pos_tags.append(pos_tag)
            POS_tags.add(pos_tag)
    
    if pos_dict is None: 
        pos_dict = {}

        for i, tag in enumerate(sorted(POS_tags, key=str)):
            pos_dict[tag] = i

        pos_dict["OOV"] = len(pos_dict)
         
    samples_x = []
    samples_y = []
    for doc in docs:
        for ent in doc.entities:
            if ent.ent_type in {"loc", "locorg", "org"}:
                sent = find_entity_sentence_idx(doc, ent)
                sample = [0] * word_emb_size * word_window_size * 2
                for i in range(1, word_window_size + 1):
                    #left part of the window
                    token_pos = ent.start_token - i
                    if token_pos >= sent.start_token:
                        sample[(i - 1) * word_emb_size: i * word_emb_size] = embedding_model[doc.tokens[token_pos].lower()]
                        pos_tag_id = pos_dict[doc.pos_tags[token_pos]]
                    else:
                        pos_tag_id = pos_dict["OOV"]
                    
                    pos_one_hot = [0]*len(pos_dict)
                    pos_one_hot[pos_tag_id] = 1
                    sample += pos_one_hot
                    
                    
                    #right part of the window
                    token_pos = ent.end_token + i - 1
                    if token_pos < sent.end_token:
                        sample[(word_window_size + i - 1) * word_emb_size:(word_window_size + i) * word_emb_size] = embedding_model[
                            doc.tokens[token_pos].lower()]
                        pos_tag_id = pos_dict[doc.pos_tags[token_pos]]
                    else:
                        pos_tag_id = pos_dict["OOV"]
                        
                    pos_one_hot = [0]*len(pos_dict)
                    pos_one_hot[pos_tag_id] = 1
                    sample += pos_one_hot
                
                entity_len_one_hot = [0]*(max_len+1)
                entity_len = ent.end_token-ent.start_token
                
                if entity_len > max_len:
                    entity_len = max_len
                
                entity_len_one_hot[entity_len] = 1
                
                sample += entity_len_one_hot
                
                samples_x.append(sample)
                if ent.ent_type == "loc":
                    label = 0
#                 elif ent.ent_type == "locorg":
#                     label = 1
                else:
                    label = 1
                samples_y.append(label)
    return samples_x, samples_y, pos_dict

In [82]:
dev_docs = read_docs(dev_files)
word_window_size = 5
x, y, pos_dict = get_samples(dev_docs, word_window_size, 10)

In [84]:
len(x[0])

3201

In [78]:
from sklearn.ensemble import GradientBoostingClassifier

clf = GradientBoostingClassifier(loss="exponential", n_estimators=600, learning_rate=0.1, verbose=True)

clf.fit(x,y)

      Iter       Train Loss   Remaining Time 
         1           0.8572            7.70m
         2           0.8187            7.68m
         3           0.7841            7.66m
         4           0.7551            7.64m
         5           0.7289            7.63m
         6           0.7053            7.64m
         7           0.6844            7.63m
         8           0.6668            7.61m
         9           0.6491            7.60m
        10           0.6335            7.59m
        20           0.5385            7.45m
        30           0.4844            7.32m
        40           0.4458            7.16m
        50           0.4119            6.99m
        60           0.3807            6.86m
        70           0.3509            6.73m
        80           0.3250            6.59m
        90           0.3047            6.44m
       100           0.2855            6.31m
       200           0.1636            4.99m
       300           0.0977            3.71m
       40

GradientBoostingClassifier(criterion='friedman_mse', init=None,
              learning_rate=0.1, loss='exponential', max_depth=3,
              max_features=None, max_leaf_nodes=None,
              min_impurity_decrease=0.0, min_impurity_split=None,
              min_samples_leaf=1, min_samples_split=2,
              min_weight_fraction_leaf=0.0, n_estimators=600,
              n_iter_no_change=None, presort='auto', random_state=None,
              subsample=1.0, tol=0.0001, validation_fraction=0.1,
              verbose=True, warm_start=False)

In [None]:
from sklearn.neural_network import MLPClassifier

clf = MLPClassifier(max_iter=15, hidden_layer_sizes=(3000,300), solver="adam", batch_size=16, learning_rate_init=0.03, verbose=True)

clf.fit(x,y)

In [79]:
test_docs = read_docs(test_files)
test_samples, _, _= get_samples(test_docs, word_window_size, 10, pos_dict)

pred = clf.predict(test_samples)

In [80]:
j = 0
for doc in test_docs:
    for i, ent in enumerate(doc.entities):
        if ent.ent_type in {"loc", "locorg"}:
            doc.entities[i].ent_type = "loc" if pred[j] == 0 else "locorg"
            j += 1
        if ent.ent_type == "org":
            j += 1

In [81]:
write_docs(test_docs, "../resources/out")