## Tree-building evaluation on gold EDUs (mostly) and a playground for tree-building scripts

1. Modifications of library components for tree building
2. Scripts for testing and evaluating the Sklearn-, AllenNLP- and gold-annotation-based RST parsers on a manually segmented corpus

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np
from utils.print_tree import printBTree
#from utils.rst_annotation import DiscourseUnit

import sys
sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../../')

In [None]:
from isanlp.annotation_rst import DiscourseUnit

In [None]:
def printTree(tree):
    """Pretty-print a binary discourse tree with ``printBTree``.

    Args:
        tree: Root node. Every node is read through its ``relation``,
            ``proba``, ``text``, ``left`` and ``right`` attributes.

    Returns:
        Whatever ``printBTree`` returns.
    """
    def _node_info(node):
        # Relation nodes are labelled with (relation, probability);
        # nodes without a relation are labelled with their text.
        if node.relation:
            value = (node.relation, "%.2f" % node.proba)
        else:
            value = node.text
        return str(value), node.left, node.right

    # The tree itself has to be handed over; the callback alone gives the
    # printer nothing to traverse.
    # NOTE(review): assumes the signature printBTree(node, node_info) —
    # confirm against utils.print_tree.
    return printBTree(tree, _node_info)

In [None]:
class DiscourseUnitCreator:
    """Factory producing parent ``DiscourseUnit`` nodes with increasing ids."""

    def __init__(self, id):
        # ``id`` shadows the builtin, but the parameter name is kept so that
        # existing keyword callers keep working.
        self.id = id

    def __call__(self, left_node, right_node, proba):
        """Create a unit joining ``left_node`` and ``right_node``.

        The internal counter is incremented first, so the first unit created
        gets the id passed to the constructor plus one.

        Args:
            left_node: Left child of the new unit.
            right_node: Right child of the new unit.
            proba: Probability stored on the new unit.

        Returns:
            A new ``DiscourseUnit`` with ``relation=1``.
        """
        self.id += 1
        return DiscourseUnit(
            # Use the instance counter; a bare ``id`` here would be the
            # builtin function, not a number.
            id=self.id,
            left=left_node,
            right=right_node,
            relation=1,
            proba=proba
        )

In [None]:
from allennlp.predictors import Predictor

# Load an AllenNLP predictor from the archived model at this path.
# NOTE(review): a later cell rebinds the name `pr` to a precision value, so
# this predictor object is lost once that cell has run.
pr = Predictor.from_path('../isanlp_rst/models/structure_predictor_lstm/model.tar.gz')

In [None]:
from isanlp_rst.src.isanlp_rst.sklearn_classifier import SklearnClassifier
from isanlp_rst.src.isanlp_rst.allennlp_classifier import AllenNLPClassifier
from isanlp_rst.src.isanlp_rst.rst_tree_predictor import *
from isanlp_rst.src.isanlp_rst.greedy_rst_parser import GreedyRSTParser
from isanlp_rst.src.isanlp_rst.features_extractor import FeaturesExtractor
from isanlp_rst.src.isanlp_rst.features_processor_tokenizer import FeaturesProcessor

In [None]:
from utils.train_test_split import split_train_dev_test

# Split the corpus under ./data into three parts. `test` is later mapped over
# by parse_golds, which treats every item as a file name with an extension.
train, dev, test = split_train_dev_test('./data')

# Evaluation (Parser)

In [None]:
# Span (structure) predictor variants:
# (classifier class, model directory name, threshold, threshold).
# Index 2 is used below as the paragraph-level confidence threshold and
# index 3 as the document-level one.
_SPAN_PREDICTOR = {
    'lstm': (AllenNLPClassifier, 'structure_predictor_lstm', 0.1, 0.5),
    'ensemble': (SklearnClassifier, 'structure_predictor', 0.15, 0.2),
}

# Label predictor variants: (classifier class, model directory name).
# NOTE(review): this table is not read anywhere in the visible cells.
_LABEL_PREDICTOR = {
    'lstm': (AllenNLPClassifier, 'label_predictor_lstm'),
    'ensemble': (SklearnClassifier, 'label_predictor'),
}

In [None]:
# Build the LSTM-based pipeline: one classifier decides whether two units
# should be joined, the other labels the relation.
# NOTE(review): the model directories are spelled out here instead of being
# taken from _SPAN_PREDICTOR / _LABEL_PREDICTOR; keep them in sync by hand.
binary_classifier = AllenNLPClassifier('../isanlp_rst/models/structure_predictor_lstm/')
label_classifier = AllenNLPClassifier('../isanlp_rst/models/label_predictor_lstm/')

features_processor = FeaturesProcessor(model_dir_path='models', verbose=False)
features_extractor = FeaturesExtractor(features_processor)

# NOTE(review): the `features_processor` keyword receives the extractor
# (which wraps the processor), not the processor itself — presumably intended.
predictor = NNTreePredictor(features_processor=features_extractor, 
                            relation_predictor_sentence=None,
                            relation_predictor_text=binary_classifier, 
                            label_predictor=label_classifier)

# Same predictor, two thresholds: index 2 (0.1 for 'lstm') inside paragraphs ...
paragraph_parser = GreedyRSTParser(predictor,
                                   confidence_threshold=_SPAN_PREDICTOR['lstm'][2])

# ... and index 3 (0.5 for 'lstm') when joining paragraph trees.
document_parser = GreedyRSTParser(predictor,
                                  confidence_threshold=_SPAN_PREDICTOR['lstm'][3])

In [None]:
# Second document-level pass with the threshold lowered by 0.15; the
# evaluation loop runs it when the first pass leaves many separate trees.
additional_document_parser = GreedyRSTParser(predictor,
                                             confidence_threshold=_SPAN_PREDICTOR['lstm'][3]-0.15)

In [None]:
from isanlp.annotation import Sentence

def split_by_paragraphs(annot_text, annot_tokens, annot_sentences, annot_lemma, annot_morph, annot_postag,
                        annot_syntax_dep_tree):
    """Split a document's annotations into per-paragraph chunks.

    A paragraph boundary is placed after every token that is separated from
    the next token by text containing a newline.

    Args:
        annot_text: Full document text.
        annot_tokens: Tokens exposing ``begin`` and ``end`` offsets into
            ``annot_text``.
        annot_sentences: Not used; sentences are rebuilt for every chunk.
        annot_lemma: List of per-sentence lists (presumably one item per
            token, since the split counts items as tokens).
        annot_morph: Same layout as ``annot_lemma``.
        annot_postag: Same layout as ``annot_lemma``.
        annot_syntax_dep_tree: Same layout as ``annot_lemma``.

    Returns:
        List of dicts with keys 'text', 'tokens', 'sentences', 'lemma',
        'morph', 'postag' and 'syntax_dep_tree'. Sentence spans are token
        indices counted from the start of the chunk. A document without any
        paragraph break yields a single chunk covering the whole text.
    """

    def split_on_two(sents, boundary):
        """Cut a list of per-sentence lists after ``boundary`` items in total.

        The sentence straddling the cut is divided between the two halves.
        """
        def list_sum(l):
            return sum(len(sublist) for sublist in l)

        i = 1
        while list_sum(sents[:i]) < boundary and i < len(sents):
            i += 1

        intersentence_boundary = min(len(sents[i - 1]), boundary - list_sum(sents[:i - 1]))
        return (sents[:i - 1] + [sents[i - 1][:intersentence_boundary]],
                [sents[i - 1][intersentence_boundary:]] + sents[i:])

    def recount_sentences(chunk):
        """Drop empty sentences and rebuild sentence spans relative to the chunk."""
        sentences = []
        lemma = []
        morph = []
        postag = []
        syntax_dep_tree = []
        tokens_cursor = 0

        for i, sent in enumerate(chunk['syntax_dep_tree']):
            # Splitting can leave empty sentence fragments; skip them.
            if len(sent) > 0:
                sentences.append(Sentence(tokens_cursor, tokens_cursor + len(sent)))
                lemma.append(chunk['lemma'][i])
                morph.append(chunk['morph'][i])
                postag.append(chunk['postag'][i])
                syntax_dep_tree.append(chunk['syntax_dep_tree'][i])
                tokens_cursor += len(sent)

        chunk['sentences'] = sentences
        chunk['lemma'] = lemma
        chunk['morph'] = morph
        chunk['postag'] = postag
        chunk['syntax_dep_tree'] = syntax_dep_tree

        return chunk

    chunks = []
    prev_right_boundary = -1  # index of the last token of the previous chunk

    for i, token in enumerate(annot_tokens[:-1]):

        if '\n' in annot_text[token.end:annot_tokens[i + 1].begin]:
            if prev_right_boundary > -1:
                chunk = {
                    'text': annot_text[annot_tokens[prev_right_boundary].end:token.end + 1].strip(),
                    'tokens': annot_tokens[prev_right_boundary + 1:i + 1]
                }
            else:
                chunk = {
                    'text': annot_text[:token.end + 1].strip(),
                    'tokens': annot_tokens[:i + 1]
                }

            # The annotation lists are consumed from the front: each split
            # returns this chunk's part and the remainder for later chunks.
            lemma, annot_lemma = split_on_two(annot_lemma, i - prev_right_boundary)
            morph, annot_morph = split_on_two(annot_morph, i - prev_right_boundary)
            postag, annot_postag = split_on_two(annot_postag, i - prev_right_boundary)
            syntax_dep_tree, annot_syntax_dep_tree = split_on_two(annot_syntax_dep_tree, i - prev_right_boundary)

            chunk.update({
                'lemma': lemma,
                'morph': morph,
                'postag': postag,
                'syntax_dep_tree': syntax_dep_tree,
            })
            chunks.append(recount_sentences(chunk))

            prev_right_boundary = i  # number of last token in the last chunk

    # The trailing chunk starts right after the last boundary token. Without
    # any boundary it is the whole document; indexing the tokens with -1
    # there would start the text at the END of the last token (and fail on
    # an empty token list).
    tail_start = annot_tokens[prev_right_boundary].end if prev_right_boundary > -1 else 0

    chunk = {
        'text': annot_text[tail_start:].strip(),
        'tokens': annot_tokens[prev_right_boundary + 1:],
        'lemma': annot_lemma,
        'morph': annot_morph,
        'postag': annot_postag,
        'syntax_dep_tree': annot_syntax_dep_tree,
    }

    chunks.append(recount_sentences(chunk))
    return chunks

In [None]:
cache = {}

In [None]:
# NOTE(review): this cell reads like the body of a class method pasted into
# the notebook: it uses `self` and bare `annot_*` names that no visible cell
# defines at top level, so it presumably fails with NameError if run as is.
# The evaluation loop further down is the runnable counterpart.
# 1. Split text and annotations on paragraphs and process separately
dus = []
start_id = 0

if '\n' in annot_text:
    chunks = self.split_by_paragraphs(
        annot_text,
        annot_tokens,
        annot_sentences,
        annot_lemma,
        annot_morph,
        annot_postag,
        annot_syntax_dep_tree)

    for chunk in chunks:

        # start_id keeps unit ids unique across paragraphs
        edus = self.segmenter(annot_text, chunk['tokens'], chunk['sentences'], chunk['lemma'],
                              chunk['postag'], chunk['syntax_dep_tree'], start_id=start_id)

        if len(edus) == 1:
            dus += edus
            start_id = edus[-1].id + 1

        elif len(edus) > 1:
            trees = self.paragraph_parser(edus,
                                          annot_text, chunk['tokens'], chunk['sentences'], chunk['lemma'],
                                          chunk['morph'], chunk['postag'], chunk['syntax_dep_tree'])

            dus += trees
            start_id = dus[-1].id + 1

    # 2. Process paragraphs into the document-level annotation
    trees = self.document_parser(dus,
                                 annot_text,
                                 annot_tokens,
                                 annot_sentences,
                                 annot_lemma,
                                 annot_morph,
                                 annot_postag,
                                 annot_syntax_dep_tree)


In [None]:
annot['text']

In [None]:
def split_by_paragraphs_edus(edus, text):
    """Group EDU strings into paragraphs according to newlines in ``text``.

    An EDU closes a paragraph when the character right after it in ``text``
    is a newline.

    Args:
        edus: EDU strings, expected in document order.
        text: Full document text containing the EDUs.

    Returns:
        List of paragraphs, each a list of EDU strings; empty for no EDUs.
    """
    paragraphs = []
    current = []
    cursor = 0  # end of the previously located EDU

    for edu in edus:
        current.append(edu)

        # Search forward from the previous EDU so that a repeated EDU string
        # is matched at its own position, not at its first occurrence.
        start = text.find(edu, cursor)
        if start == -1:
            # Out-of-order input: fall back to a search from the beginning.
            start = text.find(edu)
        if start == -1:
            # EDU is not in the text at all; no boundary can be derived.
            continue

        boundary = start + len(edu)
        cursor = max(cursor, boundary)
        if boundary < len(text) and text[boundary] == '\n':
            paragraphs.append(current)
            current = []

    if current:
        paragraphs.append(current)
    return paragraphs

In [None]:
split_by_paragraphs_edus(edus, annot['text'])

In [None]:
cache = []

In [None]:
def prepare_gold_pairs(gold_pairs):
    """Collapse gold relation labels onto the reduced label set.

    Works on the 'category_id' and 'order' columns, adds a 'relation' column
    ('<category>_<order>') and finally rewrites 'category_id' and 'order'
    from the merged relation.

    Args:
        gold_pairs: DataFrame with 'category_id' and 'order' columns.

    Returns:
        The same DataFrame object, modified in place.
    """
    TARGET = 'category_id'

    # (column, values to replace, replacement) — applied top to bottom.
    raw_label_merges = (
        (TARGET, [0.0], 'same-unit_m'),
        ('order', [0.0], 'NN'),
        (TARGET, ['antithesis_r',], 'contrast_m'),
        (TARGET, ['cause_r', 'effect_r'], 'cause-effect_r'),
        (TARGET, ['conclusion_r',], 'restatement_m'),
        (TARGET, ['evaluation_r'], 'interpretation-evaluation_r'),
        (TARGET, ['motivation_r',], 'condition_r'),
    )
    for column, old_values, new_value in raw_label_merges:
        gold_pairs[column] = gold_pairs[column].replace(old_values, new_value)

    # Drop the trailing nuclearity letter ('r'/'m') but keep the underscore,
    # then append the order to get e.g. 'cause-effect_NS'.
    label_stems = gold_pairs[TARGET].map(lambda label: label[:-1])
    gold_pairs['relation'] = label_stems + gold_pairs['order']

    # (values to replace, replacement) on the combined relation — in order.
    relation_merges = (
        (['restatement_SN', 'restatement_NS'], 'restatement_NN'),
        (['contrast_SN', 'contrast_NS'], 'contrast_NN'),
        (['solutionhood_NS', 'preparation_NS'], 'elaboration_NS'),
        (['concession_SN', 'evaluation_SN',
          'elaboration_SN', 'evidence_SN'], 'preparation_SN'),
        ('background_NS', 'elaboration_NS'),
        ('background_SN', 'preparation_SN'),
        ('comparison_NN', 'contrast_NN'),
        ('interpretation-evaluation_SN', 'elaboration_NS'),
        ('evidence_NS', 'elaboration_NS'),
        ('restatement_NN', 'joint_NN'),
        ('sequence_NN', 'joint_NN'),
    )
    for old_values, new_value in relation_merges:
        gold_pairs['relation'] = gold_pairs['relation'].replace(old_values, new_value)

    # Split the merged relation back into its two parts.
    gold_pairs['order'] = gold_pairs['relation'].map(lambda rel: rel.split('_')[1])
    gold_pairs[TARGET] = gold_pairs['relation'].map(lambda rel: rel.split('_')[0])

    return gold_pairs

In [None]:
from tqdm import tqdm_notebook as tqdm
from utils.file_reading import *
from utils.evaluation import extr_pairs, extr_pairs_forest
# eval_pipeline is called at the bottom of the loop, so it has to be in scope
# before this cell runs.
from utils.evaluation import eval_pipeline


# Parse every listed file paragraph by paragraph starting from its gold EDUs,
# join the paragraph trees at document level and append the evaluation to
# `cache` (which must be a list when this cell runs).
broken_files = []
smallest_file = 'data/news2_4.edus'
coolest_file = 'data/blogs_17.edus'
#test[:1]
for file in tqdm([coolest_file]):
    filename = '.'.join(file.split('.')[:-1])
    edus = read_edus(filename)
    #gold = read_gold(filename)
    gold = prepare_gold_pairs(read_gold(filename, features=True))

    annot = read_annotation(filename)

    if '\n' in annot['text']:
        chunks = split_by_paragraphs(
            annot['text'],
            annot['tokens'],
            annot['sentences'],
            annot['lemma'],
            annot['morph'],
            annot['postag'],
            annot['syntax_dep_tree'])

        chunked_edus = split_by_paragraphs_edus(edus, annot['text'])
    else:
        # Single-paragraph document: the annotation dict already exposes
        # every key a chunk is read by below. Without this branch `chunks`
        # would be undefined, or left over from the previous file.
        chunks = [annot]
        chunked_edus = [edus]

    dus = []
    start_id = 0   # first free unit id, so ids stay unique across paragraphs
    last_end = 0   # search cursor in the text, carried across paragraphs

    # NOTE(review): `chunks` is split on tokens and `chunked_edus` on raw
    # text; this loop assumes both yield the same number of paragraphs.
    for i, chunk in enumerate(chunks):
        _edus = []

        for offset, edu_text in enumerate(chunked_edus[i]):
            # Search forward from the previous EDU so that a repeated EDU
            # string is located at its own position.
            start = annot['text'].find(edu_text, last_end)
            end = start + len(edu_text)
            temp = DiscourseUnit(
                    id=start_id + offset,
                    left=None,
                    right=None,
                    relation='edu',
                    start=start,
                    end=end,
                    orig_text=annot['text'],
                    proba=1.,
                )

            _edus.append(temp)
            last_end = end

        if len(_edus) == 1:
            dus += _edus

        elif len(_edus) > 1:
            trees = paragraph_parser(_edus,
                annot['text'], chunk['tokens'], chunk['sentences'], chunk['lemma'],
                chunk['morph'], chunk['postag'], chunk['syntax_dep_tree'])

            dus += trees
            # print('chunk processed')

        if dus:
            # NOTE(review): assumes a merged node never has a smaller id
            # than the units below it — confirm against GreedyRSTParser.
            start_id = max(unit.id for unit in dus) + 1

    # NOTE(review): `filename` still carries its directory here, so the
    # genre passed on is e.g. 'data/blogs' — confirm the parser expects that.
    parsed = document_parser(
                dus,
                annot['text'],
                annot['tokens'],
                annot['sentences'],
                annot['lemma'],
                annot['morph'],
                annot['postag'],
                annot['syntax_dep_tree'],
                genre=filename.split('_')[0])

    # More than one tree per 400 characters of text left over: run the
    # lower-threshold parser over the remaining trees.
    if len(parsed) > len(annot['text']) // 400:
        parsed = additional_document_parser(
            parsed,
            annot['text'],
            annot['tokens'],
            annot['sentences'],
            annot['lemma'],
            annot['morph'],
            annot['postag'],
            annot['syntax_dep_tree'],
            genre=filename.split('_')[0]
        )

    parsed_pairs = pd.DataFrame(extr_pairs_forest(parsed, annot['text']), columns=['snippet_x', 'snippet_y', 'category_id', 'order'])
    evaluation = eval_pipeline(parsed, edus, gold, annot['text'])
    evaluation['filename'] = file
    cache.append(evaluation)

In [None]:
# Show the top-level trees left after parsing: relation and nuclearity for
# each, plus the covered text for every tree that has a left child.
for tree in parsed:
    print('>>>', tree.relation + '_' + tree.nuclearity)
    if tree.left:
        print(tree.text)

In [None]:
from utils.evaluation import eval_pipeline

# Re-evaluate the most recent parse and replace `cache` with a one-element
# list holding only that evaluation (earlier entries are discarded).
evaluation = eval_pipeline(parsed, edus, gold, annot['text'])
evaluation['filename'] = file
cache = [evaluation]

In [None]:
# Turn the per-file evaluation counts into precision / recall / F1 columns
# for each scoring level, then show the table sorted by segmentation F1.
tmp = pd.DataFrame(cache)
for level in ('seg', 'unlab', 'lab', 'nuc'):
    hits = tmp[level + '_true_pred']
    precision = hits / tmp[level + '_all_pred']
    recall = hits / tmp[level + '_all_true']
    tmp['pr_' + level] = precision
    tmp['re_' + level] = recall
    tmp['f1_' + level] = 2 * precision * recall / (precision + recall)
tmp.sort_values('f1_seg', ascending=False)

In [None]:
tmp[[key for key in tmp.keys() if 'lab' in key]]

In [None]:
from utils.evaluation import _not_parsed_as_in_gold

err = _not_parsed_as_in_gold(parsed_pairs, gold, labeled=True)

In [None]:
err

In [None]:
from utils.evaluation import metric_parseval

# Collect per-file parseval counts into the `results` table.
# NOTE(review): this loop calls `cache.items()`, but the visible cells above
# last bind `cache` to a list; it presumably expects a dict mapping a file
# name to a pair of arguments for metric_parseval — confirm before running.
filenames = []
true_pos = []
all_parsed = []
all_gold = []

for key, value in cache.items():
    c_true_pos, c_all_parsed, c_all_gold = metric_parseval(value[0], value[1])
    filenames.append(key)
    true_pos.append(c_true_pos)
    all_parsed.append(c_all_parsed)
    all_gold.append(c_all_gold)
    
results = pd.DataFrame({'filename': filenames, 
                    'true_pos': true_pos,
                    'all_parsed': all_parsed,
                    'all_gold': all_gold})

In [None]:
# Per-file recall, precision and F1 from the parseval counts.
results['recall'] = results['true_pos'] / results['all_gold']
results['precision'] = results['true_pos'] / results['all_parsed']
results['F1'] = 2 * results['precision'] * results['recall'] / (results['precision'] + results['recall'])

In [None]:
# Micro-averaged scores: counts are summed over all files before dividing.
# NOTE(review): `re` shadows the name of the stdlib regex module and `pr`
# rebinds the AllenNLP predictor loaded in an earlier cell — consider
# renaming if either is needed afterwards.
re = results['true_pos'].sum() / results['all_gold'].sum()
pr = results['true_pos'].sum() / results['all_parsed'].sum()
f1 = 2 * pr * re / (pr+re)
print(re, pr, f1)

In [None]:
results['F1'].mean(), results['recall'].mean(), results['precision'].mean()

In [None]:
# Mean per-file scores restricted to files whose name contains 'blog'.
# NOTE(review): rebinds `tmp`, which an earlier cell uses for the metrics table.
tmp = results[results.filename.str.contains('blog')]

tmp['F1'].mean(), tmp['recall'].mean(), tmp['precision'].mean()

In [None]:
# Mean per-file scores restricted to files whose name contains 'news'.
tmp = results[results.filename.str.contains('news')]

tmp['F1'].mean(), tmp['recall'].mean(), tmp['precision'].mean()

# Evaluation (Gold)

In [None]:
cache = {}

In [None]:
from utils.evaluation import metric_parseval_df as metric_parseval

In [None]:
# Same aggregation as in the parser evaluation, here with `metric_parseval`
# bound to the DataFrame-based variant imported in the previous cell.
# NOTE(review): the visible cells reset `cache` to an empty dict and never
# fill it, so `results` comes out empty unless `cache` is populated elsewhere.
filenames = []
true_pos = []
all_parsed = []
all_gold = []

for key, value in cache.items():
    c_true_pos, c_all_parsed, c_all_gold = metric_parseval(value[0], value[1])
    filenames.append(key)
    true_pos.append(c_true_pos)
    all_parsed.append(c_all_parsed)
    all_gold.append(c_all_gold)
    
results = pd.DataFrame({'filename': filenames, 
                    'true_pos': true_pos,
                    'all_parsed': all_parsed,
                    'all_gold': all_gold})

In [None]:
from isanlp_rst.src.isanlp_rst.rst_tree_predictor import GoldTreePredictor

In [None]:
def parse_golds(filename):
    """Build trees for one file from its gold relations and score them.

    The gold EDUs are located in the document text, parsed greedily with a
    ``GoldTreePredictor`` at threshold 0, and the extracted pairs are scored
    against the gold pairs.

    Args:
        filename: File path with an extension; the extension is stripped
            before the EDU, gold and annotation files are read.

    Returns:
        Tuple of the extension-less file name followed by the values
        returned by ``metric_parseval(parsed_pairs, gold)``.
    """
    filename = '.'.join(filename.split('.')[:-1])
    edus = read_edus(filename)
    gold = read_gold(filename)
    annot = read_annotation(filename)
    text = annot['text']

    _edus = []
    last_end = 0
    for edu_id, edu_text in enumerate(edus):
        # Search forward from the end of the previous EDU so that a repeated
        # EDU string is located at its own position.
        # NOTE(review): an EDU missing from the text gives start == -1 and a
        # meaningless span; such files are not detected here.
        start = text.find(edu_text, last_end)
        end = start + len(edu_text)
        temp = DiscourseUnit(
                id=edu_id,
                left=None,
                right=None,
                relation='edu',
                start=start,
                end=end,
                orig_text=text,
                proba=1.,
            )
        _edus.append(temp)
        last_end = end

    parser = GreedyRSTParser(GoldTreePredictor(gold), confidence_threshold=0.)
    # Positional order is lemma, morph, postag — the same order every other
    # parser call in this notebook uses.
    parsed = parser(_edus, text, annot['tokens'], annot['sentences'],
                    annot['lemma'], annot['morph'], annot['postag'], annot['syntax_dep_tree'])

    parsed_pairs = pd.DataFrame(extr_pairs_forest(parsed, text), columns=['snippet_x', 'snippet_y', 'category_id', 'order'])
    # Duplicate the snippet columns under the integer labels 0 and 1.
    parsed_pairs[0] = parsed_pairs.snippet_x
    parsed_pairs[1] = parsed_pairs.snippet_y

    return (filename,) + metric_parseval(parsed_pairs, gold)#, parsed_pairs, parsed

In [None]:
%%time

import multiprocessing as mp

# Score the test files in five worker processes. The context manager tears
# the workers down even when a task raises; pool.map blocks until every
# result is in, so nothing is lost when the pool exits.
with mp.Pool(5) as pool:
    result = pool.map(parse_golds, test)

In [None]:
# One row per file from the tuples returned by parse_golds:
# (filename, true_pos, all_parsed, all_gold).
results = pd.DataFrame(columns=['filename', 'true_pos', 'all_parsed', 'all_gold'], data=result)

# Per-file recall, precision and F1 from those counts.
results['recall'] = results['true_pos'] / results['all_gold']
results['precision'] = results['true_pos'] / results['all_parsed']
results['F1'] = 2 * results['precision'] * results['recall'] / (results['precision'] + results['recall'])

In [None]:
results.sort_values('F1')

In [None]:
results.F1.mean()