In [1]:
import numpy as np
import pandas as pd
import matplotlib as plt
import json

import spacy
from spacy import displacy

from collections import Counter
import glob
import os


from spacy.util import minibatch, compounding
from spacy.util import decaying
import random
import re
from spacy.gold import GoldParse

%matplotlib inline

In [2]:
def getFileList(ftypes, start_path = '.'):
    """Recursively collect file paths under `start_path` that end with `ftypes`.

    Args:
        ftypes: a suffix string (e.g. '.ann') or a tuple of suffixes, exactly as
            accepted by `str.endswith`.
        start_path: directory to walk; defaults to the current working directory.

    Returns:
        List of paths (each joined onto its walked directory), in `os.walk` order.
    """
    candidate_paths = (
        os.path.join(root, name)
        for root, _subdirs, names in os.walk(start_path)
        for name in names
    )
    return [path for path in candidate_paths if path.endswith(ftypes)]

In [3]:
def _read_brat_ann(lines, relation_types):
    """Parse the lines of a BRAT .ann file.

    Args:
        lines: iterable of text lines from the .ann file.
        relation_types: container of relation type names to keep.

    Returns:
        (relations, entities) where `relations` is a list of
        (relation_type, arg1_id, arg2_id) for kept 'R' lines, and `entities` maps a
        text-bound ('T') id to (entity_type, start, end). A discontinuous span
        "s1 e1;s2 e2" is reduced to its overall extent (s1, e2).
    """
    relations, entities = [], {}
    for line in lines:
        if line.startswith('R'):
            # "R1\thasRole Arg1:T3 Arg2:T7"
            content = line.split()
            if content[1] in relation_types:
                relations.append((content[1],
                                  content[2].split(':', 1)[1],
                                  content[3].split(':', 1)[1]))
        elif line.startswith('T'):
            # "T1\tType START END[;START END...]\tsurface text"
            ident, rest = line.split(None, 1)
            etype, offsets = rest.split('\t', 1)[0].split(None, 1)
            fragments = offsets.split(';')
            entities[ident] = (etype,
                               int(fragments[0].split()[0]),
                               int(fragments[-1].split()[1]))
    return relations, entities


def loadRelations(ann_files, txt_files):
    """Read paired BRAT .ann / .txt files into spaCy-style relation training data.

    Only 'hasRole' and 'roleDepartment' relations are kept. Each relation is
    normalised so its first argument is the 'Person' entity (hasRole) or the
    'Role' entity (roleDepartment).

    Args:
        ann_files: iterable of paths to BRAT .ann files.
        txt_files: iterable of paths to the matching .txt files. Paired with
            `ann_files` positionally via zip, so both must be in the same order;
            surplus entries of the longer one are ignored.

    Returns:
        List of (text, {"relations": [(rel_type, (start, end, surface),
        (start, end, surface)), ...]}), one entry per document that has at least
        one kept relation whose two arguments are text-bound entities.
    """
    # relation type -> entity type that must be the first argument
    head_types = {'hasRole': 'Person', 'roleDepartment': 'Role'}
    REL = []
    for af, tf in zip(ann_files, txt_files):
        # newline='' disables universal-newline translation: BRAT offsets count
        # every character of the file, including the '\r' of CRLF line endings.
        with open(tf, 'r', encoding='utf8', newline='') as text_data:
            text = text_data.read()
        # .ann files contain entity surface text, so decode them explicitly
        # rather than with the platform default encoding.
        with open(af, 'r', encoding='utf8') as ann_data:
            relations, entities = _read_brat_ann(ann_data, head_types)

        res = []
        for rel_type, arg1, arg2 in relations:
            if arg1 not in entities or arg2 not in entities:
                # argument is missing or not a text-bound entity: cannot align it
                continue
            first, second = entities[arg1], entities[arg2]
            head = head_types[rel_type]
            if first[0] != head and second[0] == head:
                first, second = second, first
            res.append((rel_type,
                        (first[1], first[2], text[first[1]:first[2]]),
                        (second[1], second[2], text[second[1]:second[2]])))
        if res:
            REL.append((text, {"relations": res}))
    return REL

In [11]:
# Load document-level relation annotations from the paired BRAT files.
# NOTE(review): `ann` and `txt` are not defined in any visible cell (execution
# counts jump from In[3] to In[11]) — presumably the .ann / .txt path lists built
# with getFileList in a missing cell. loadRelations pairs them positionally with
# zip, so both lists must be in matching order (e.g. both sorted) — TODO confirm.
rds = loadRelations(ann, txt)

In [12]:
from nltk import tokenize
def splitRelations(old_train, sent_tokenizer=None):
    """Split document-level relation annotations into sentence-level examples.

    Args:
        old_train: list of (text, {'relations': [(rel_type, (start, end, surface),
            (start, end, surface)), ...]}) with offsets relative to the full text,
            as produced by loadRelations.
        sent_tokenizer: optional callable text -> list of sentence strings, each a
            verbatim substring of the text, in reading order. Defaults to NLTK's
            French sentence tokenizer.

    Returns:
        List of (sentence, {'relations': [...]}) with offsets re-based to the
        sentence. A relation is kept only when both arguments lie entirely
        inside the sentence; sentences without such a relation are dropped.
    """
    if sent_tokenizer is None:
        sent_tokenizer = lambda text: tokenize.sent_tokenize(text, language='french')

    new_train = []
    for article in old_train:
        full_text, old_ents = article[0], article[1]['relations']
        cursor = 0
        for sent in sent_tokenizer(full_text):
            # Search from the end of the previous sentence so that a sentence
            # repeated verbatim maps to its own position, not the first occurrence.
            sent_start = full_text.index(sent, cursor)
            sent_end = sent_start + len(sent)
            cursor = sent_end

            new_ents = []
            for item in old_ents:
                arg1, arg2 = item[1], item[2]
                # both start AND end must be inside, otherwise offsets go negative
                if all(sent_start <= arg[0] and arg[1] <= sent_end for arg in (arg1, arg2)):
                    new_ents.append((item[0],
                                     (arg1[0] - sent_start, arg1[1] - sent_start, arg1[2]),
                                     (arg2[0] - sent_start, arg2[1] - sent_start, arg2[2])))
            if new_ents:
                new_train.append((sent, {'relations': new_ents}))
    return new_train

In [22]:
def filter_spans(spans):
    """Drop overlapping spans, preferring longer ones (then later-starting ones).

    Spans are ranked by (length, start) descending; ties keep their input order.
    A span is kept only if neither its first nor its last token is already
    covered by a previously kept span.

    Args:
        spans: iterable of span-like objects with integer `start` / `end` token
            indices (end exclusive).

    Returns:
        List of the kept span objects, in ranking order.
    """
    ranked = sorted(spans, key=lambda sp: (sp.end - sp.start, sp.start), reverse=True)
    kept = []
    covered = set()
    for candidate in ranked:
        first_tok, last_tok = candidate.start, candidate.end - 1
        if first_tok in covered or last_tok in covered:
            continue
        kept.append(candidate)
        covered.update(range(candidate.start, candidate.end))
    return kept

def _role_relation(entity, role):
    """Build the relation linking the ROL token `role` to `entity`.

    Returns ('hasRole', person, role) when `entity` is labelled PER,
    ('roleDepartment', role, organisation) when it is labelled ORG,
    and None for any other token.
    """
    if entity.ent_type_ == 'PER':
        return ('hasRole', str(entity), str(role))
    if entity.ent_type_ == 'ORG':
        return ('roleDepartment', str(role), str(entity))
    return None


def extract_relations(doc):
    """Rule-based extraction of hasRole / roleDepartment relations from a parsed Doc.

    Entity spans are first merged into single tokens; this retokenizes `doc`
    in place. Every token labelled 'ROL' is then linked to PER/ORG tokens
    reached through the dependency patterns below.

    Args:
        doc: a spaCy Doc carrying entity labels (PER / ORG / ROL) and a
            dependency parse.

    Returns:
        List of ('hasRole', person, role) and ('roleDepartment', role,
        organisation) string triples, each listed once, in first-found order.
    """
    spans = filter_spans(list(doc.ents))
    with doc.retokenize() as retokenizer:
        for span in spans:
            retokenizer.merge(span)

    relations = []

    def add(entity, role):
        # Several rules can fire for the same pair; duplicates would be counted
        # more than once by the evaluation, so keep each relation only once.
        relation = _role_relation(entity, role)
        if relation is not None and relation not in relations:
            relations.append(relation)

    for target in [w for w in doc if w.ent_type_ == 'ROL']:
        head = target.head

        # Role is an oblique of a verb whose passive subject is the entity.
        # e.g. 'Ce matin, Valérie Pécresse sera installée à la présidence de la région Ile-de-France.'
        #      -> ('hasRole', 'Valérie Pécresse', 'présidence')
        if target.dep_ == "obl":
            subjects = [w for w in head.lefts if w.dep_ == "nsubj:pass"]
            if subjects:
                add(subjects[0], target)

        if target.dep_ == "appos":
            if head.ent_type_ in ('PER', 'ORG'):
                # Role in apposition to the entity itself.
                # e.g. '... commente Jacky Lintignat, directeur général de KPMG France.'
                #      -> ('hasRole', 'Jacky Lintignat', 'directeur général')
                add(head, target)
            else:
                # Role in apposition to a non-entity head: use that head's first nmod child.
                # e.g. '... analyse Shahin Vallée, chercheur invité au cercle de réflexion européen Bruegel.'
                #      -> ('hasRole', 'Shahin Vallée', 'chercheur')
                nmods = [w for w in head.children if w.dep_ == "nmod"]
                if nmods:
                    add(nmods[0], target)

        if target.dep_ in ("appos", "acl"):
            # The role itself has an nmod child that is a PER/ORG entity.
            # e.g. "Après Hubert Joly, le patron de l'américain Best Buy, ..."
            #      -> ('roleDepartment', 'patron', 'Best Buy')
            # e.g. "... prédit Philippe Lerouge, fondateur du Salon Mobile Payment."
            #      -> ('roleDepartment', 'fondateur', 'Salon Mobile Payment')
            entity_mods = [w for w in target.children
                           if w.dep_ == 'nmod' and w.ent_type_ in ("PER", "ORG")]
            if entity_mods:
                add(entity_mods[0], target)

        if target.dep_ == "nsubj":
            # Role is a subject; link its first obj/appos child.
            # e.g. "Le PDG de la division aviation commerciale d'Airbus, Fabrice Brégier, a estimé ..."
            #      -> ('hasRole', 'Fabrice Brégier', 'PDG')
            objects = [w for w in target.children if w.dep_ in ("obj", "appos")]
            if objects:
                add(objects[0], target)

        # Any role: each nmod / nsubj / obl child that is itself an entity,
        # labelled by that child's own entity type.
        # e.g. '... explique Gonzague de Vallois, vice-président de Gameloft.'
        #      -> ('roleDepartment', 'vice-président', 'Gameloft')
        for child in target.children:
            if child.dep_ in ("nmod", "nsubj", "obl"):
                add(child, target)

    return relations

In [8]:
# Load the trained spaCy pipeline from a directory relative to the notebook's
# working directory.
# NOTE(review): presumably a model saved by an earlier training run, providing the
# PER / ORG / ROL entities and dependency parse that extract_relations relies on,
# and passed as `model` to evaluate / evaluate_partials — confirm.
nlp = spacy.load("./nerav1")

In [9]:
def evaluate(test_data, model):
    """Exact-match precision / recall / F1 of `extract_relations` on BRAT data.

    Args:
        test_data: document-level examples as returned by loadRelations; they are
            split into sentences with splitRelations before scoring.
        model: a spaCy pipeline (callable text -> Doc) with NER and a parser.

    Returns:
        (precision, recall, f1). Relations are compared as
        (type, arg1 text, arg2 text) triples, de-duplicated per sentence so a
        relation found by several rules is counted once. precision / recall are
        NaN when there are no predictions / gold relations; f1 is NaN if either is
        NaN and 0.0 when both are 0.
    """
    s, m, n = 0, 0, 0
    for item in splitRelations(test_data):
        sentence, annotations = item[0], item[1]
        golds = {(gold[0], gold[1][2], gold[2][2])
                 for gold in annotations['relations']
                 if gold[0] in ('roleDepartment', 'hasRole')}
        preds = set(extract_relations(model(sentence)))
        n += len(golds)
        m += len(preds)
        s += len(preds & golds)

    precision = s / m if m else float('nan')
    recall = s / n if n else float('nan')
    # NaN is truthy, so a NaN precision/recall propagates to f1 rather than 0.0.
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

In [10]:
def overlaps(element, golds):
    """Lenient match: does predicted relation `element` partially match a gold one?

    `element` and each gold are (relation_type, arg1_text, arg2_text) triples.
    A gold matches when it has the same relation type and, for each argument,
    one text is a substring of the other.

    Returns:
        True if at least one gold matches, else False.
    """
    def _nested(a, b):
        # True when either string contains the other
        return a in b or b in a

    candidates = [gold for gold in golds if gold[0] == element[0]]
    for gold in candidates:
        if _nested(element[1], gold[1]) and _nested(element[2], gold[2]):
            return True
    return False

def evaluate_partials(test_data, model):
    """Lenient (partial-match) precision / recall / F1 of `extract_relations`.

    Like `evaluate`, but a predicted relation matches a gold one when `overlaps`
    says so (same type, and each argument text contains or is contained in the
    gold's).

    Args:
        test_data: document-level examples as returned by loadRelations; they are
            split into sentences with splitRelations before scoring, as in evaluate.
        model: a spaCy pipeline (callable text -> Doc) with NER and a parser.

    Returns:
        (precision, recall, f1). Precision counts predictions that match some
        gold; recall counts golds matched by some prediction (so recall cannot
        exceed 1 when several partial predictions hit one gold). Relations are
        de-duplicated per sentence. precision / recall are NaN when there are no
        predictions / gold relations; f1 is NaN if either is NaN and 0.0 when
        both are 0.
    """
    matched_preds, matched_golds, m, n = 0, 0, 0, 0
    for item in splitRelations(test_data):
        sentence, annotations = item[0], item[1]
        golds = {(gold[0], gold[1][2], gold[2][2])
                 for gold in annotations['relations']
                 if gold[0] in ('roleDepartment', 'hasRole')}
        preds = set(extract_relations(model(sentence)))
        n += len(golds)
        m += len(preds)
        # `overlaps` is symmetric in its two relations, so it can test either side.
        matched_preds += sum(1 for pred in preds if overlaps(pred, golds))
        matched_golds += sum(1 for gold in golds if overlaps(gold, preds))

    precision = matched_preds / m if m else float('nan')
    recall = matched_golds / n if n else float('nan')
    # NaN is truthy, so a NaN precision/recall propagates to f1 rather than 0.0.
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1