## Tree building evaluation on gold EDUs (mostly) and playground for tree building scripts

1. Modifications of library components for tree building
2. Scripts for test and evaluation of Sklearn-, AllenNLP- and gold-annotation-based RST parsers on manually segmented corpus

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import numpy as np
from utils.print_tree import printBTree
#from utils.rst_annotation import DiscourseUnit

import sys
sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../../')

In [3]:
from isanlp.annotation_rst import DiscourseUnit

class DiscourseUnitCreator:
    def __init__(self, id):
        self.id = id
        
    def __call__(self, left_node, right_node, proba):
        self.id += 1
        return DiscourseUnit(
            id=id,
            left=left_node,
            right=right_node,
            relation=1,
            proba=proba
        )

In [4]:
# %%bash

# pip install dostoevsky
# dostoevsky download fasttext-social-network-model

In [5]:
from _isanlp_rst.src.isanlp_rst.rst_tree_predictor import *
from _isanlp_rst.src.isanlp_rst.greedy_rst_parser import GreedyRSTParser
from _isanlp_rst.src.isanlp_rst.features_extractor import FeaturesExtractor
from _isanlp_rst.src.isanlp_rst.features_processor_default import FeaturesProcessor
from _isanlp_rst.src.isanlp_rst.classifier_wrappers import *

In [6]:
from utils.train_test_split import split_train_dev_test

train, dev, test = split_train_dev_test('./data')

news in train: 0.5344827586206896,	in dev: 0.6470588235294118,	in test: 0.6086956521739131
ling in train: 0.0,	in dev: 0.0,	in test: 0.0
comp in train: 0.0,	in dev: 0.0,	in test: 0.0
blog in train: 0.43103448275862066,	in dev: 0.5294117647058824,	in test: 0.4782608695652174


# Evaluation (Parser)

In [57]:
_SPAN_PREDICTOR = {
    'bimpm': (AllenNLPCustomBiMPMClassifier, 'structure_predictor_bimpm', 0., 0.6),
    'baseline': (SklearnClassifier, 'structure_predictor_baseline', 0.1, 0.2),
    'ensemble': (EnsembleClassifier,)
}

_LABEL_PREDICTOR = {
    'bimpm': (AllenNLPBiMPMClassifier, 'label_predictor_bimpm'),
    'esim': (AllenNLPBiMPMClassifier, 'label_predictor_esim'),
    'baseline': (SklearnClassifier, 'label_predictor_baseline'),
    'ensemble': (EnsembleClassifier,)
}

In [73]:
NEURAL_BINARY_PATH = 'models/structure_predictor_bimpm/elmo/'
BASELINE_BINARY_PATH = 'models/structure_predictor_baseline/'

NEURAL_LABEL_PATH = '../../models/label_predictor_esim/'
BASELINE_LABEL_PATH = 'models/relation_predictor_baseline/'

# neural_binary_classifier = _SPAN_PREDICTOR['bimpm'][0](NEURAL_BINARY_PATH)
baseline_binary_classifier = _SPAN_PREDICTOR['baseline'][0](BASELINE_BINARY_PATH)
# binary_classifier = _SPAN_PREDICTOR['ensemble'][0]((neural_binary_classifier, baseline_binary_classifier))
binary_classifier = baseline_binary_classifier

# neural_label_classifier = _LABEL_PREDICTOR['esim'][0](NEURAL_LABEL_PATH)
baseline_label_classifier = _LABEL_PREDICTOR['baseline'][0](BASELINE_LABEL_PATH)
# label_classifier = _LABEL_PREDICTOR['ensemble'][0]((neural_label_classifier, baseline_label_classifier), 
#                                                    weights=[1., 2.])
# while label classifier is on train!
label_classifier = baseline_label_classifier

features_processor = FeaturesProcessor(model_dir_path='models', verbose=False)
features_extractor = FeaturesExtractor(features_processor)

_predictor = [LargeNNTreePredictor,  # both classifiers are neural
              EnsembleNNTreePredictor,  # structure predictions are neural, for labels use an ensemble
              DoubleEnsembleNNTreePredictor,  # both classifiers are ensembles
             ]

predictor = _predictor[2](features_processor=features_extractor, 
                            relation_predictor_sentence=None,
                            relation_predictor_text=binary_classifier, 
                            label_predictor=label_classifier)

# paragraph_parser = GreedyRSTParser(predictor,
#                                    confidence_threshold=_SPAN_PREDICTOR['bimpm'][2], 
#                                    _same_sentence_bonus=1.)

# document_parser = GreedyRSTParser(predictor,
#                                   confidence_threshold=_SPAN_PREDICTOR['bimpm'][3], 
#                                   _same_sentence_bonus=0.)

# additional_document_parser = GreedyRSTParser(predictor,
#                                              confidence_threshold=_SPAN_PREDICTOR['bimpm'][3]-0.15, 
#                                              _same_sentence_bonus=0.)


paragraph_parser = GreedyRSTParser(predictor,
                                   confidence_threshold=_SPAN_PREDICTOR['baseline'][2], 
                                   _same_sentence_bonus=1.)

document_parser = GreedyRSTParser(predictor,
                                  confidence_threshold=_SPAN_PREDICTOR['baseline'][3], 
                                  _same_sentence_bonus=0.)

additional_document_parser = GreedyRSTParser(predictor,
                                             confidence_threshold=_SPAN_PREDICTOR['baseline'][3]-0.15, 
                                             _same_sentence_bonus=0.)



2021-12-15 16:06:26,462 - INFO - gensim.models.utils_any2vec - loading projection weights from models/w2v/default/model.vec
2021-12-15 16:07:03,976 - INFO - gensim.models.utils_any2vec - loaded (195071, 300) matrix from models/w2v/default/model.vec


In [74]:
from isanlp.annotation import Sentence

def split_by_paragraphs(annot_text, annot_tokens, annot_sentences, annot_lemma, annot_morph, annot_postag,
                        annot_syntax_dep_tree):

    def split_on_two(sents, boundary):
        list_sum = lambda l: sum([len(sublist) for sublist in l])

        i = 1
        while list_sum(sents[:i]) < boundary and i < len(sents):
            i += 1

        intersentence_boundary = min(len(sents[i - 1]), boundary - list_sum(sents[:i - 1]))
        return (sents[:i - 1] + [sents[i - 1][:intersentence_boundary]],
                [sents[i - 1][intersentence_boundary:]] + sents[i:])

    def recount_sentences(chunk):
        sentences = []
        lemma = []
        morph = []
        postag = []
        syntax_dep_tree = []
        tokens_cursor = 0

        for i, sent in enumerate(chunk['syntax_dep_tree']):
            if len(sent) > 0:
                sentences.append(Sentence(tokens_cursor, tokens_cursor + len(sent)))
                lemma.append(chunk['lemma'][i])
                morph.append(chunk['morph'][i])
                postag.append(chunk['postag'][i])
                syntax_dep_tree.append(chunk['syntax_dep_tree'][i])
                tokens_cursor += len(sent)

        chunk['sentences'] = sentences
        chunk['lemma'] = lemma
        chunk['morph'] = morph
        chunk['postag'] = postag
        chunk['syntax_dep_tree'] = syntax_dep_tree

        return chunk

    chunks = []
    prev_right_boundary = -1

    for i, token in enumerate(annot_tokens[:-1]):

        if '\n' in annot_text[token.end:annot_tokens[i + 1].begin]:
            if prev_right_boundary > -1:
                chunk = {
                    'text': annot_text[annot_tokens[prev_right_boundary].end:token.end + 1].strip(),
                    'tokens': annot_tokens[prev_right_boundary + 1:i + 1]
                }
            else:
                chunk = {
                    'text': annot_text[:token.end + 1].strip(),
                    'tokens': annot_tokens[:i + 1]
                }

            lemma, annot_lemma = split_on_two(annot_lemma, i - prev_right_boundary)
            morph, annot_morph = split_on_two(annot_morph, i - prev_right_boundary)
            postag, annot_postag = split_on_two(annot_postag, i - prev_right_boundary)
            syntax_dep_tree, annot_syntax_dep_tree = split_on_two(annot_syntax_dep_tree, i - prev_right_boundary)

            chunk.update({
                'lemma': lemma,
                'morph': morph,
                'postag': postag,
                'syntax_dep_tree': syntax_dep_tree,
            })
            chunks.append(recount_sentences(chunk))

            prev_right_boundary = i  # number of last token in the last chunk

    chunk = {
        'text': annot_text[annot_tokens[prev_right_boundary].end:].strip(),
        'tokens': annot_tokens[prev_right_boundary + 1:],
        'lemma': annot_lemma,
        'morph': annot_morph,
        'postag': annot_postag,
        'syntax_dep_tree': annot_syntax_dep_tree,
    }

    chunks.append(recount_sentences(chunk))
    return chunks

In [75]:
def split_by_paragraphs_edus(edus, text):
    res = []
    parag = []
    
    for edu in edus:
        parag.append(edu)
        boundary = text.find(edu)+len(edu)
        if boundary < len(text):
            if text[boundary] == '\n':
                res.append(parag)
                parag = []
         
    if parag:
        res.append(parag)
    return res

In [76]:
from utils.evaluation import prepare_gold_pairs

### Find edus containing multiple paragraphs and add to exceptions 

In [77]:
from tqdm import tqdm_notebook as tqdm
from utils.file_reading import *
from utils.evaluation import extr_pairs, extr_pairs_forest


broken_files = []
smallest_file = 'data/news2_4.edus'
coolest_file = 'data/blogs_17.edus'
shit = 'data/blogs_99.edus'
#test[:1]
for file in tqdm([smallest_file]):
    filename = '.'.join(file.split('.')[:-1])
    edus = read_edus(filename)
    gold = prepare_gold_pairs(read_gold(filename, features=True))
    annot = read_annotation(filename)
    
    for missegmentation in ("\nIMG", 
                            "\nгимнастический коврик;",
                            "\nгантели или бутылки с песком;",
                            "\nнебольшой резиновый мяч;",
                            "\nэластичная лента (эспандер);",
                            "\nхула-хуп (обруч).",
                            "\n200?",
                            "\n300?",
                            "\nНе требуйте странного.",
                            "\nИспользуйте мою модель.",
                            '\n"А чего вы от них требуете?"',
                            '\n"Решить проблемы с тестерами".',
                            "\nКак гончая на дичь.", "\nИ крупная.",
                            "\nВ прошлом году компания удивила рынок",
                            "\nЧужой этики особенно.",
                            "\nНо и своей тоже.",
                            "\nАэропорт имени,",
                            "\nА вот и монголы.",
                            "\nЗолотой Будда.", 
                            "\nДворец Богдо-Хана.",
                            "\nПлощадь Сухэ-Батора.",
                            "\nОдноклассники)",
                            "\nВечерняя площадь.",
                            "\nТугрики.",
                            "\nВнутренние монголы.",
                            "\nВид сверху.",
                            "\nНациональный парк Тэрэлж. IMG IMG",
                            '\nГора "Черепаха".',
                            "\nПуть к медитации.",
                            "\nЖить надо высоко,",
                            "\nЧан с кумысом.",
                            "\nЖилая юрта.",
                            "\nКумыс.",
                            "\nТрадиционное занятие монголов",
                            "\nДвугорбый верблюд мало где",
                            "\nМонгол Шуудан переводится",
                            "\nОвощные буузы.",
                            "\nЗнаменитый чай!"
                            ):
        annot['text'] = annot['text'].replace(missegmentation, ' '+missegmentation[1:])

    for edu in edus:
        if annot['text'].find(edu) == -1:
            print(f'::: {filename} ::: {edu}')

  0%|          | 0/1 [00:00<?, ?it/s]

### Evaluate on test

In [63]:
from utils.export_to_rs3 import ForestExporter  # for list of units (whole document)
from utils.export_to_rs3 import Exporter  # for single unit (one tree)

exporter = ForestExporter(encoding='utf-8')

! rm -r gold_predictions
! mkdir gold_predictions/

In [78]:
cache = []

In [79]:
from tqdm import tqdm_notebook as tqdm
from utils.file_reading import *
from utils.evaluation import *


broken_files = []
smallest_file = 'data/news2_4.edus'
weirdest_file = 'data/blogs_63.edus'

for file in tqdm(test[1:3]):
    filename = '.'.join(file.split('.')[:-1])
    edus = read_edus(filename)
    gold = prepare_gold_pairs(read_gold(filename, features=True))
    annot = read_annotation(filename)
    annot['text'] = annot['text'].strip()
    
    for missegmentation in ("\nIMG", 
                            "\nгимнастический коврик;",
                            "\nгантели или бутылки с песком;",
                            "\nнебольшой резиновый мяч;",
                            "\nэластичная лента (эспандер);",
                            "\nхула-хуп (обруч).",
                            "\n200?",
                            "\n300?",
                            "\nНе требуйте странного.",
                            "\nИспользуйте мою модель.",
                            '\n"А чего вы от них требуете?"',
                            '\n"Решить проблемы с тестерами".',
                            "\nКак гончая на дичь.", "\nИ крупная.",
                            "\nВ прошлом году компания удивила рынок",
                            "\nЧужой этики особенно.",
                            "\nНо и своей тоже.",
                            "\nАэропорт имени,",
                            "\nА вот и монголы.",
                            "\nЗолотой Будда.", 
                            "\nДворец Богдо-Хана.",
                            "\nПлощадь Сухэ-Батора.",
                            "\nОдноклассники)",
                            "\nВечерняя площадь.",
                            "\nТугрики.",
                            "\nВнутренние монголы.",
                            "\nВид сверху.",
                            "\nНациональный парк Тэрэлж. IMG IMG",
                            '\nГора "Черепаха".',
                            "\nПуть к медитации.",
                            "\nЖить надо высоко,",
                            "\nЧан с кумысом.",
                            "\nЖилая юрта.",
                            "\nКумыс.",
                            "\nТрадиционное занятие монголов",
                            "\nДвугорбый верблюд мало где",
                            "\nМонгол Шуудан переводится",
                            "\nОвощные буузы.",
                            "\nЗнаменитый чай!",
                            ):
        annot['text'] = annot['text'].replace(missegmentation, ' ' + missegmentation[1:])

    
    if '\n' in annot['text']:
        chunks = split_by_paragraphs(
            annot['text'],
            annot['tokens'], 
            annot['sentences'], 
            annot['lemma'], 
            annot['morph'], 
            annot['postag'], 
            annot['syntax_dep_tree'])
        
        chunked_edus = split_by_paragraphs_edus(edus, annot['text'])
    
    dus = []
    start_id = 0
    for i, chunk in enumerate(tqdm(chunks)):
        _edus = []
        last_end = 0
        
        for max_id in range(len(chunked_edus[i])):
            start = len(annot['text'][:last_end]) + annot['text'][last_end:].find(chunked_edus[i][max_id])
            end = start + len(chunked_edus[i][max_id])
            temp = DiscourseUnit(
                    id=start_id,
                    left=None,
                    right=None,
                    relation='edu',
                    start=start,
                    end=end,
                    orig_text=annot['text'],
                    proba=1.,
                )

            _edus.append(temp)
            last_end = end + 1
            start_id += 1
            
        if len(_edus) == 1:
            dus += _edus
            start_id = _edus[-1].id + 1

        elif len(_edus) > 1:
            trees = paragraph_parser(_edus,
                annot['text'], chunk['tokens'], chunk['sentences'], chunk['lemma'],
                chunk['morph'], chunk['postag'], chunk['syntax_dep_tree'])
            
            dus += trees
            start_id = max([tree.id for tree in dus]) + 1
        
    parsed = document_parser(
                dus, 
                annot['text'], 
                annot['tokens'], 
                annot['sentences'], 
                annot['lemma'], 
                annot['morph'], 
                annot['postag'], 
                annot['syntax_dep_tree'],
                genre=filename.split('_')[0])
    
    if len(parsed) > len(annot['text']) // 400:
        parsed = additional_document_parser(
            parsed, 
            annot['text'], 
            annot['tokens'], 
            annot['sentences'], 
            annot['lemma'], 
            annot['morph'], 
            annot['postag'], 
            annot['syntax_dep_tree'],
            genre=filename.split('_')[0]
        )
        
    exporter(parsed, f"gold_predictions/{filename.split('/')[-1]}_parsed_goldedu.rs3")
    parsed_pairs = pd.DataFrame(extr_pairs_forest(parsed, annot['text']), 
                                columns=['snippet_x', 'snippet_y', 'category_id', 'order'])
    evaluation = eval_pipeline(parsed_pairs=parsed_pairs,
                               gold_edus=edus,
                               gold_pairs=gold[['snippet_x', 'snippet_y', 'category_id', 'order']],
                               text=annot['text'],
                               trees=parsed)
    evaluation['filename'] = file
    print(evaluation)
    cache.append(evaluation)

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

{'seg_true_pred': 112, 'seg_all_pred': 112, 'seg_all_true': 112, 'unlab_true_pred': 35, 'unlab_all_pred': 96, 'unlab_all_true': 104, 'lab_true_pred': 19, 'lab_all_pred': 96, 'lab_all_true': 104, 'nuc_true_pred': 23, 'nuc_all_pred': 96, 'nuc_all_true': 104, 'full_true_pred': 18, 'full_all_pred': 96, 'full_all_true': 104, 'filename': './data/news1_47.edus'}


  0%|          | 0/35 [00:00<?, ?it/s]

{'seg_true_pred': 245, 'seg_all_pred': 246, 'seg_all_true': 490, 'unlab_true_pred': 92, 'unlab_all_pred': 216, 'unlab_all_true': 227, 'lab_true_pred': 63, 'lab_all_pred': 216, 'lab_all_true': 227, 'nuc_true_pred': 78, 'nuc_all_pred': 216, 'nuc_all_true': 227, 'full_true_pred': 63, 'full_all_pred': 216, 'full_all_true': 227, 'filename': './data/news2_38.edus'}


In [80]:
re = 35./104
pr = 35./98
2 * pr * re / (pr + re)

0.3465346534653466

In [81]:
re = 19./104
pr = 19./98
2 * pr * re / (pr + re)

0.1881188118811881

In [82]:
from utils.file_reading import *
from utils.evaluation import *

In [83]:
print(parsed[0])

id: 454
text: Проект по созданию объединенных арабских ВС требует перезапуска - Флоренс Гауб
Проект по созданию объединенных арабских вооруженных сил, отложенный на неопределенный срок из-за разногласий между членами Лиги арабских государств (ЛАГ) и недостаточной проработки законодательной базы, нуждается в перезапуске в новом формате и с уточненными целями. К такому выводу пришла старший аналитик Института Европейского союза по вопросам безопасности (EUISS) Флоренс Гауб (Florence Gaub) в статье "Застряли в бараках: объединенные арабские вооруженные силы".
Она напоминает, что в январе 2015 года секретариат ЛАГ вышел за рамки ранее обсуждавшихся ограниченных военных альянсов и предложил сформировать единые межарабские силы быстрого реагирования, ориентированные на борьбу с терроризмом, на основании договора о совместной обороне и экономическом сотрудничестве от 1950 года.
proba: 0.4533515563355963
relation: joint
nuclearity: NN
left: Проект по созданию объединенных арабских ВС требует п

In [84]:
print(parsed[4].right)

id: 80
text: Что необычно, в первую очередь совет ЛАГ поспешил внести поправки в устав Межарабского совета мира и безопасности для организации встреч на правительственном уровне два раза в год. Ранее эта структура, основанная в 2006 году, не имела никаких полномочий и состояла всего из пяти выборных членов. Перед советом была поставлена задача - подготовить стратегии по сохранению мира в регионе и укреплению безопасности в арабских странах.
proba: 0.21887695231527632
relation: joint
nuclearity: NN
left: Что необычно, в первую очередь совет ЛАГ поспешил внести поправки в устав Межарабского совета мира и безопасности для организации встреч на правительственном уровне два раза в год. Ранее эта структура, основанная в 2006 году, не имела никаких полномочий и состояла всего из пяти выборных членов.
right: Перед советом была поставлена задача - подготовить стратегии по сохранению мира в регионе и укреплению безопасности в арабских странах.
start: 2389
end: 2820


In [85]:
#tmp = pd.DataFrame(cache[7:27] + cache[28:])
tmp = pd.DataFrame(cache)
tmp['pr_seg'] = tmp.seg_true_pred / tmp.seg_all_pred
tmp['re_seg'] = tmp.seg_true_pred / tmp.seg_all_true
tmp['f1_seg'] = 2 * tmp.pr_seg * tmp.re_seg / (tmp.pr_seg + tmp.re_seg)
tmp['pr_unlab'] = tmp.unlab_true_pred / tmp.unlab_all_pred
tmp['re_unlab'] = tmp.unlab_true_pred / tmp.unlab_all_true
tmp['f1_unlab'] = 2 * tmp.pr_unlab * tmp.re_unlab / (tmp.pr_unlab + tmp.re_unlab)
tmp['pr_lab'] = tmp.lab_true_pred / tmp.lab_all_pred
tmp['re_lab'] = tmp.lab_true_pred / tmp.lab_all_true
tmp['f1_lab'] = 2 * tmp.pr_lab * tmp.re_lab / (tmp.pr_lab + tmp.re_lab)
tmp['pr_nuc'] = tmp.nuc_true_pred / tmp.nuc_all_pred
tmp['re_nuc'] = tmp.nuc_true_pred / tmp.nuc_all_true
tmp['f1_nuc'] = 2 * tmp.pr_nuc * tmp.re_nuc / (tmp.pr_nuc + tmp.re_nuc)
tmp['pr_full'] = tmp.full_true_pred / tmp.full_all_pred
tmp['re_full'] = tmp.full_true_pred / tmp.full_all_true
tmp['f1_full'] = 2 * tmp.pr_full * tmp.re_full / (tmp.pr_full + tmp.re_full)
tmp.sort_values('f1_full')

Unnamed: 0,seg_true_pred,seg_all_pred,seg_all_true,unlab_true_pred,unlab_all_pred,unlab_all_true,lab_true_pred,lab_all_pred,lab_all_true,nuc_true_pred,...,f1_unlab,pr_lab,re_lab,f1_lab,pr_nuc,re_nuc,f1_nuc,pr_full,re_full,f1_full
0,112,112,112,35,96,104,19,96,104,23,...,0.35,0.197917,0.182692,0.19,0.239583,0.221154,0.23,0.1875,0.173077,0.18
1,245,246,490,92,216,227,63,216,227,78,...,0.41535,0.291667,0.277533,0.284424,0.361111,0.343612,0.352144,0.291667,0.277533,0.284424


Unnamed: 0,seg_true_pred,seg_all_pred,seg_all_true,unlab_true_pred,unlab_all_pred,unlab_all_true,lab_true_pred,lab_all_pred,lab_all_true,nuc_true_pred,...,f1_unlab,pr_lab,re_lab,f1_lab,pr_nuc,re_nuc,f1_nuc,pr_full,re_full,f1_full
1,112,112,112,29,96,103,14,96,103,23,...,0.291457,0.145833,0.135922,0.140704,0.239583,0.223301,0.231156,0.145833,0.135922,0.140704
0,91,91,91,37,76,79,21,76,79,28,...,0.477419,0.276316,0.265823,0.270968,0.368421,0.35443,0.36129,0.276316,0.265823,0.270968
2,245,245,245,87,206,219,64,206,219,80,...,0.409412,0.31068,0.292237,0.301176,0.38835,0.365297,0.376471,0.31068,0.292237,0.301176


In [86]:
tmp2 = tmp[:]

In [87]:
tmp = tmp2[:]

In [88]:
tmp_news = tmp2[tmp2.filename.str.contains('news')]

In [89]:
tmp_blog = tmp2[tmp2.filename.str.contains('blog')]

Unlabeled tree building score

In [90]:
pr_micro = tmp.unlab_true_pred.sum() / tmp.unlab_all_pred.sum() * 100.
re_micro = tmp.unlab_true_pred.sum() / tmp.unlab_all_true.sum() * 100.
f1_micro = 2. * pr_micro * re_micro / (pr_micro + re_micro)

unlab_micro = (pr_micro, re_micro, f1_micro)
unlab_micro

(40.705128205128204, 38.368580060422964, 39.502332814930014)

In [91]:
pr_macro = tmp.pr_unlab.sum() / tmp.shape[0] * 100.
re_macro = tmp.re_unlab.sum() / tmp.shape[0] * 100.
f1_macro = 2. * pr_macro * re_macro / (pr_macro + re_macro)

unlab_macro = (pr_macro, re_macro, f1_macro)
unlab_macro

(39.52546296296296, 37.091240257539816, 38.26968223444606)

In [104]:
pr_macro = tmp_blog.pr_unlab.sum() / tmp_blog.shape[0] * 100.
re_macro = tmp_blog.re_unlab.sum() / tmp_blog.shape[0] * 100.
unlab_blog = 2. * pr_macro * re_macro / (pr_macro + re_macro)

pr_macro = tmp_news.pr_unlab.sum() / tmp_news.shape[0] * 100.
re_macro = tmp_news.re_unlab.sum() / tmp_news.shape[0] * 100.
unlab_news = 2. * pr_macro * re_macro / (pr_macro + re_macro)

Labeled tree building score

In [93]:
pr_micro = tmp.lab_true_pred.sum() / tmp.lab_all_pred.sum() * 100.
re_micro = tmp.lab_true_pred.sum() / tmp.lab_all_true.sum() * 100.
f1_micro = 2. * pr_micro * re_micro / (pr_micro + re_micro)

lab_micro = (pr_micro, re_micro, f1_micro)
lab_micro

(26.282051282051285, 24.773413897280967, 25.5054432348367)

In [94]:
pr_macro = tmp.pr_lab.sum() / tmp.shape[0] * 100.
re_macro = tmp.re_lab.sum() / tmp.shape[0] * 100.
f1_macro = 2. * pr_macro * re_macro / (pr_macro + re_macro)

lab_macro = (pr_macro, re_macro, f1_macro)
lab_macro

(24.479166666666668, 23.011267366994236, 23.72253109704657)

In [105]:
pr_macro = tmp_blog.pr_lab.sum() / tmp_blog.shape[0] * 100.
re_macro = tmp_blog.re_lab.sum() / tmp_blog.shape[0] * 100.
lab_blog = 2. * pr_macro * re_macro / (pr_macro + re_macro)

pr_macro = tmp_news.pr_lab.sum() / tmp_news.shape[0] * 100.
re_macro = tmp_news.re_lab.sum() / tmp_news.shape[0] * 100.
lab_news = 2. * pr_macro * re_macro / (pr_macro + re_macro)

Nuclearity score

In [96]:
pr_micro = tmp.nuc_true_pred.sum() / tmp.nuc_all_pred.sum() * 100.
re_micro = tmp.nuc_true_pred.sum() / tmp.nuc_all_true.sum() * 100.
f1_micro = 2. * pr_micro * re_micro / (pr_micro + re_micro)

nuc_micro = (pr_micro, re_micro, f1_micro)
nuc_micro

(32.371794871794876, 30.513595166163142, 31.41524105754277)

In [97]:
pr_macro = tmp.pr_nuc.sum() / tmp.shape[0] * 100.
re_macro = tmp.re_nuc.sum() / tmp.shape[0] * 100.
f1_macro = 2. * pr_macro * re_macro / (pr_macro + re_macro)

nuc_macro = (pr_macro, re_macro, f1_macro)
nuc_macro

(30.03472222222222, 28.23830904778041, 29.10882615135016)

In [98]:
pr_macro = tmp_blog.pr_nuc.sum() / tmp_blog.shape[0] * 100.
re_macro = tmp_blog.re_nuc.sum() / tmp_blog.shape[0] * 100.
nuc_blog = 2. * pr_macro * re_macro / (pr_macro + re_macro)

pr_macro = tmp_news.pr_nuc.sum() / tmp_news.shape[0] * 100.
re_macro = tmp_news.re_nuc.sum() / tmp_news.shape[0] * 100.
nuc_news = 2. * pr_macro * re_macro / (pr_macro + re_macro)

Full tree building score

In [99]:
pr_micro = tmp.full_true_pred.sum() / tmp.full_all_pred.sum() * 100.
re_micro = tmp.full_true_pred.sum() / tmp.full_all_true.sum() * 100.
f1_micro = 2. * pr_micro * re_micro / (pr_micro + re_micro)

full_micro = pr_micro, re_micro, f1_micro
full_micro

(25.961538461538463, 24.47129909365559, 25.194401244167963)

In [100]:
pr_macro = tmp.pr_full.sum() / tmp.shape[0] * 100.
re_macro = tmp.re_full.sum() / tmp.shape[0] * 100.
f1_macro = 2. * pr_macro * re_macro / (pr_macro + re_macro)

full_macro = (pr_macro, re_macro, f1_macro)
full_macro

(23.958333333333336, 22.530498136225006, 23.222488819371154)

In [101]:
pr_macro = tmp_blog.pr_full.sum() / tmp_blog.shape[0] * 100.
re_macro = tmp_blog.re_full.sum() / tmp_blog.shape[0] * 100.
full_blog = 2. * pr_macro * re_macro / (pr_macro + re_macro)

pr_macro = tmp_news.pr_full.sum() / tmp_news.shape[0] * 100.
re_macro = tmp_news.re_full.sum() / tmp_news.shape[0] * 100.
full_news = 2. * pr_macro * re_macro / (pr_macro + re_macro)

Draw a table

blogs

In [106]:
evaluation_table = pd.DataFrame(columns=['component', 'P', 'R', 'F1', 'P', 'R', 'F1'], data=[
    ['span', unlab_micro[0], unlab_micro[1], unlab_micro[2], unlab_macro[0], unlab_macro[1], unlab_macro[2]],
    ['nuclearity', nuc_micro[0], nuc_micro[1], nuc_micro[2], nuc_macro[0], nuc_macro[1], nuc_macro[2]],
    ['relation', lab_micro[0], lab_micro[1], lab_micro[2], lab_macro[0], lab_macro[1], lab_macro[2]],
    ['full', full_micro[0], full_micro[1], full_micro[2], full_macro[0], full_macro[1], full_macro[2]],
])

print(evaluation_table.to_latex(index=False, float_format='%.2f', column_format='|l|l|l|l|'))

\begin{tabular}{|l|l|l|l|}
\toprule
 component &     P &     R &    F1 &     P &     R &    F1 \\
\midrule
      span & 40.71 & 38.37 & 39.50 & 39.53 & 37.09 & 38.27 \\
nuclearity & 32.37 & 30.51 & 31.42 & 30.03 & 28.24 & 29.11 \\
  relation & 26.28 & 24.77 & 25.51 & 24.48 & 23.01 & 23.72 \\
      full & 25.96 & 24.47 & 25.19 & 23.96 & 22.53 & 23.22 \\
\bottomrule
\end{tabular}



news

In [107]:
evaluation_table = pd.DataFrame(columns=['component', 'P', 'R', 'F1', 'P', 'R', 'F1'], data=[
    ['span', unlab_micro[0], unlab_micro[1], unlab_micro[2], unlab_macro[0], unlab_macro[1], unlab_macro[2]],
    ['nuclearity', nuc_micro[0], nuc_micro[1], nuc_micro[2], nuc_macro[0], nuc_macro[1], nuc_macro[2]],
    ['relation', lab_micro[0], lab_micro[1], lab_micro[2], lab_macro[0], lab_macro[1], lab_macro[2]],
    ['full', full_micro[0], full_micro[1], full_micro[2], full_macro[0], full_macro[1], full_macro[2]],
])

print(evaluation_table.to_latex(index=False, float_format='%.2f', column_format='|l|l|l|l|'))

\begin{tabular}{|l|l|l|l|}
\toprule
 component &     P &     R &    F1 &     P &     R &    F1 \\
\midrule
      span & 40.71 & 38.37 & 39.50 & 39.53 & 37.09 & 38.27 \\
nuclearity & 32.37 & 30.51 & 31.42 & 30.03 & 28.24 & 29.11 \\
  relation & 26.28 & 24.77 & 25.51 & 24.48 & 23.01 & 23.72 \\
      full & 25.96 & 24.47 & 25.19 & 23.96 & 22.53 & 23.22 \\
\bottomrule
\end{tabular}



append separated genres to the main table

In [108]:
evaluation_table = pd.DataFrame(columns=['component', 'P', 'R', 'F1', 'P', 'R', 'F1', 'blogs', 'news'], data=[
    ['span', unlab_micro[0], unlab_micro[1], unlab_micro[2], unlab_macro[0], unlab_macro[1], unlab_macro[2], unlab_blog, unlab_news],
    ['nuclearity', nuc_micro[0], nuc_micro[1], nuc_micro[2], nuc_macro[0], nuc_macro[1], nuc_macro[2], nuc_blog, nuc_news],
    ['relation', lab_micro[0], lab_micro[1], lab_micro[2], lab_macro[0], lab_macro[1], lab_macro[2], lab_blog, lab_news],
    ['full', full_micro[0], full_micro[1], full_micro[2], full_macro[0], full_macro[1], full_macro[2], full_blog, full_news],
])

print(evaluation_table.to_latex(index=False, float_format='%.2f', column_format='|l|l|l|l|'))

\begin{tabular}{|l|l|l|l|}
\toprule
 component &     P &     R &    F1 &     P &     R &    F1 &  blogs &  news \\
\midrule
      span & 40.71 & 38.37 & 39.50 & 39.53 & 37.09 & 38.27 &    NaN & 38.27 \\
nuclearity & 32.37 & 30.51 & 31.42 & 30.03 & 28.24 & 29.11 &    NaN & 29.11 \\
  relation & 26.28 & 24.77 & 25.51 & 24.48 & 23.01 & 23.72 &    NaN & 23.72 \\
      full & 25.96 & 24.47 & 25.19 & 23.96 & 22.53 & 23.22 &    NaN & 23.22 \\
\bottomrule
\end{tabular}



# Evaluation (Gold)

In [109]:
cache = {}

In [110]:
from utils.evaluation import metric_parseval_df as metric_parseval
from utils.evaluation import extr_pairs_forest
from utils.file_reading import *

In [111]:
filenames = []
true_pos = []
all_parsed = []
all_gold = []

for key, value in cache.items():
    c_true_pos, c_all_parsed, c_all_gold = metric_parseval(value[0], value[1])
    filenames.append(key)
    true_pos.append(c_true_pos)
    all_parsed.append(c_all_parsed)
    all_gold.append(c_all_gold)
    
results = pd.DataFrame({'filename': filenames, 
                    'true_pos': true_pos,
                    'all_parsed': all_parsed,
                    'all_gold': all_gold})

In [112]:
from _isanlp_rst.src.isanlp_rst.rst_tree_predictor import GoldTreePredictor
from isanlp.annotation_rst import ForestExporter
exporter = ForestExporter(encoding='utf8')

In [123]:
def parse_golds(filename):
    filename = '.'.join(filename.split('.')[:-1])
    edus = read_edus(filename)
    gold = read_gold(filename, features=True)
    annot = read_annotation(filename)
    annot['text'] = annot['text'].replace('\n', ' ').replace('  ', ' ').replace('  ', ' ')
    
    _edus = []
    last_end = 0
    last_id = 0
    for max_id in range(len(edus)):
        start = len(annot['text'][:last_end]) + annot['text'][last_end:].find(edus[max_id])
        end = start + len(edus[max_id])
        temp = DiscourseUnit(
                id=start,
                left=None,
                right=None,
                relation='edu',
                start=start,
                end=end,
                text=edus[max_id],
                orig_text=annot['text'],
                proba=1.,
            )
        _edus.append(temp)
        last_end = end
        last_id += 1

    parser = GreedyRSTParser(GoldTreePredictor(gold), confidence_threshold=0.)
    parsed = parser(_edus, annot['text'], annot['tokens'], annot['sentences'],
                    annot['postag'], annot['morph'], annot['lemma'], annot['syntax_dep_tree'])
    
    filename = filename.split('/')[-1]
    exporter(parsed, 'parsed_golds_0406/'+filename+'.rs3')
    
    parsed_pairs = pd.DataFrame(extr_pairs_forest(parsed, annot['text'], locations=True), 
                                columns=['snippet_x', 'snippet_y', 'category_id', 'order', 'loc_x', 'loc_y'])
    
    return filename, metric_parseval(parsed_pairs, gold)

In [124]:
! mkdir parsed_golds_0406

mkdir: cannot create directory ‘parsed_golds_0406’: File exists


In [125]:
%%writefile /opt/.pyenv/versions/3.7.4/lib/python3.7/site-packages/isanlp/annotation_rst.py

class DiscourseUnit:
    def __init__(self, id, left=None, right=None, text='', start=None, end=None,
                 orig_text=None, relation='', nuclearity='', proba=1.):
        """
        :param int id:
        :param DiscourseUnit left:
        :param DiscourseUnit right:
        :param str text: (optional)
        :param int start: start position in original text
        :param int end: end position in original text
        :param string relation: {the relation between left and right components | 'elementary' | 'root'}
        :param string nuclearity: {'NS' | 'SN' | 'NN'}
        :param float proba: predicted probability of the relation occurrence
        """
        self.id = id
        self.left = left
        self.right = right
        self.relation = relation
        self.nuclearity = nuclearity
        self.proba = proba
        self.start = start
        self.end = end

        if self.left:
            self.start = left.start
            self.end = right.end

        if orig_text:
            self.text = orig_text[self.start:self.end + 1].strip()
        else:
            self.text = text.strip()

        self._exporter = None

    def __str__(self):
        result = f"id: {self.id}\n"
        result += f"text: {self.text}\n"
        result += f"proba: {self.proba}\n"
        result += f"relation: {self.relation}\n"
        result += f"nuclearity: {self.nuclearity}\n"
        result += f"left: {self.left.text if self.left else None}\n"
        result += f"right: {self.right.text if self.right else None}\n"
        result += f"start: {self.start}\n"
        result += f"end: {self.end}"
        return result

    def to_rs3(self, filename, encoding='utf8'):
        self._exporter = Exporter(encoding=encoding)
        self._exporter(self, filename)


class Segment:
    def __init__(self, _id, parent, relname, text):
        self.id = _id + 1
        self.parent = parent + 1
        self.relname = relname
        self.text = text

    def __str__(self):
        if self.parent != -1:
            return f'<segment id="{self.id}" parent="{self.parent}" relname="{self.relname}">{self.text}</segment>'

        return f'<segment id="{self.id}" relname="{self.relname}">{self.text}</segment>'


class Group:
    def __init__(self, _id, type, parent, relname):
        self.id = _id + 1
        self.type = type
        self.parent = parent + 1
        self.relname = relname

    def __str__(self):
        return f'<group id="{self.id}" type="{self.type}" parent="{self.parent}" relname="{self.relname}"/>'


class Root(Group):
    def __init__(self, _id, type="span"):
        Group.__init__(self, _id, type=type, parent=-1, relname="span")

    def __str__(self):
        return f'<group id="{self.id}" type="{self.type}"/>'


class Exporter:
    def __init__(self, encoding='cp1251', verbose=False):
        self._encoding = encoding
        self.verbose = verbose
        self.max_id = 0

    def __call__(self, tree, filename):

        with open(filename, 'w', encoding=self._encoding) as fo:
            fo.write('<rst>\n')
            fo.write(self.make_header(tree))
            fo.write(self.make_body(tree))
            fo.write('</rst>')

    def compile_relation_set(self, tree):
        result = ['_'.join([tree.relation, tree.nuclearity])] + ['antithesis_NN']
        if not tree.left:
            return result
        if tree.left.left:
            result += self.compile_relation_set(tree.left)
        if tree.right.left:
            result += self.compile_relation_set(tree.right)

        return result

    def make_header(self, tree):
        relations = list(set(self.compile_relation_set(tree)))
        relations = [value if value != "elementary__" else "antithesis_NN" for value in relations]

        result = '\t<header>\n'
        result += '\t\t<relations>\n'
        for rel in relations:
            _relname, _type = rel.split('_')
            _type = 'multinuc' if _type == 'NN' else 'rst'
            result += f'\t\t\t<rel name="{_relname}" type="{_type}" />\n'
        result += '\t\t</relations>\n'
        result += '\t</header>\n'

        return result

    def print_log(self, log):
        if self.verbose:
            print(log)

    def get_max_id(self):
        self.max_id += 1
        return self.max_id

    def get_groups_and_edus(self, tree, terminal=False):
        groups = []
        edus = []

        if not tree.left:
            edus.append(Segment(tree.id, parent=-2, relname='antithesis', text=tree.text))
            return groups, edus

        if not tree.left.left:
            if tree.nuclearity == "SN":
                edus.append(Segment(tree.left.id, parent=tree.right.id, relname=tree.relation, text=tree.left.text))
            elif tree.nuclearity == "NS":
                edus.append(Segment(tree.left.id, parent=tree.id, relname='span', text=tree.left.text))
            else:
                edus.append(Segment(tree.left.id, parent=tree.id, relname=tree.relation, text=tree.left.text))

        else:
            if tree.nuclearity == "SN":
                groups.append(Group(tree.left.id, type='span', parent=tree.right.id, relname=tree.relation))
            elif tree.nuclearity == "NS":
                groups.append(Group(tree.left.id, type='span', parent=tree.id, relname='span'))
            else:
                groups.append(Group(tree.left.id, type='multinuc', parent=tree.id, relname=tree.relation))

            _groups, _edus = self.get_groups_and_edus(tree.left)
            groups += _groups
            edus += _edus

        if not tree.right.left:
            if tree.nuclearity == "SN":
                edus.append(Segment(tree.right.id, parent=tree.id, relname='span', text=tree.right.text))
            elif tree.nuclearity == "NS":
                edus.append(Segment(tree.right.id, parent=tree.left.id, relname=tree.relation, text=tree.right.text))
            else:
                edus.append(Segment(tree.right.id, parent=tree.id, relname=tree.relation, text=tree.right.text))

        else:
            if tree.nuclearity == "SN":
                groups.append(Group(tree.right.id, type='multinuc', parent=tree.id, relname='span'))
            elif tree.nuclearity == "NS":
                groups.append(Group(tree.right.id, type='span', parent=tree.left.id, relname=tree.relation))
            else:
                groups.append(Group(tree.right.id, type='span', parent=tree.id, relname=tree.relation))

            _groups, _edus = self.get_groups_and_edus(tree.right)
            groups += _groups
            edus += _edus

        if terminal:
            if len(edus) > 1:
                if tree.nuclearity == "NN":
                    groups.append(Root(tree.id, type='multinuc'))
                else:
                    groups.append(Root(tree.id))

        return groups, edus

    def make_body(self, tree):
        groups, edus = self.get_groups_and_edus(tree, terminal=True)
        result = '\t<body>\n'
        for edu in edus + groups:
            result += '\t\t' + str(edu) + '\n'
        result += '\t</body>\n'

        return result


class ForestExporter:
    def __init__(self, encoding='cp1251', verbose=False):
        self._encoding = encoding
        self._tree_exporter = Exporter(self._encoding, verbose=verbose)

    def __call__(self, trees, filename):

        with open(filename, 'w', encoding=self._encoding) as fo:
            fo.write('<rst>\n')
            fo.write(self.make_header(trees))
            fo.write(self.make_body(trees))
            fo.write('</rst>')

    def compile_relation_set(self, trees):
        result = []

        for tree in trees:
            result += list(set(self._tree_exporter.compile_relation_set(tree)))

        result = [value if value != "elementary__" else "antithesis_NN" for value in result]
        return result

    def make_header(self, trees):
        relations = list(set(self.compile_relation_set(trees)))

        result = '\t<header>\n'
        result += '\t\t<relations>\n'
        for rel in relations:
            _relname, _type = rel.split('_')
            _type = 'multinuc' if _type == 'NN' else 'rst'
            result += f'\t\t\t<rel name="{_relname}" type="{_type}" />\n'
        result += '\t\t</relations>\n'
        result += '\t</header>\n'

        return result

    def make_body(self, trees):
        groups, edus = [], []
        roots = []

        for tree in trees:
            _groups, _edus = self._tree_exporter.get_groups_and_edus(tree, terminal=True)

            if len(edus) > 1:
                if tree.nuclearity == "NN":
                    roots.append(Root(tree.id, type='multinuc'))
                else:
                    roots.append(Root(tree.id))

            groups += _groups
            edus += _edus

        result = '\t<body>\n'
        for edu in edus + groups:
            result += '\t\t' + str(edu) + '\n'
        result += '\t</body>\n'

        return result.replace('\u2015', '-')

Overwriting /opt/.pyenv/versions/3.7.4/lib/python3.7/site-packages/isanlp/annotation_rst.py


In [126]:
%%time

import multiprocessing as mp

pool = mp.Pool(5)
result = pool.map(parse_golds, test)
pool.close()

CPU times: user 20.9 ms, sys: 100 ms, total: 121 ms
Wall time: 15.9 s


In [127]:
result = [(f[0], f[1][0], f[1][1], f[1][2]) for f in result]

In [128]:
results = pd.DataFrame(columns=['filename', 'true_pos', 'all_parsed', 'all_gold'], data=result)
difference = results['all_parsed'] - results['true_pos']
results['all_gold'] += difference
results['true_pos'] = results['all_parsed']

results['recall'] = results['true_pos'] / results['all_gold']
results['precision'] = results['true_pos'] / results['all_parsed']
results['F1'] = 2 * results['precision'] * results['recall'] / (results['precision'] + results['recall'])

In [129]:
results.sort_values('F1')

Unnamed: 0,filename,true_pos,all_parsed,all_gold,recall,precision,F1
16,blogs_21,265,265,272,0.974265,1.0,0.986965
8,news2_34,159,159,161,0.987578,1.0,0.99375
17,blogs_99,200,200,202,0.990099,1.0,0.995025
22,blogs_39,202,202,204,0.990196,1.0,0.995074
5,news1_28,164,164,165,0.993939,1.0,0.99696
0,news1_23,83,83,83,1.0,1.0,1.0
21,blogs_86,136,136,136,1.0,1.0,1.0
20,blogs_63,84,84,84,1.0,1.0,1.0
19,blogs_52,91,91,91,1.0,1.0,1.0
18,blogs_31,155,155,155,1.0,1.0,1.0


In [130]:
results.F1.mean()

0.9987109548727691