In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np
from utils.print_tree import printBTree
#from utils.rst_annotation import DiscourseUnit

import sys
sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../../')

In [None]:
class DiscourseUnit:
    """A node of a binary discourse tree: a leaf unit or the merge of a left and a right unit."""

    def __init__(self, id, left=None, right=None, text='', start=None, end=None, 
                 orig_text=None, relation=None, nuclearity=None, proba=1.):
        """
        :param int id: identifier of the unit
        :param DiscourseUnit left: left child; when truthy, the span is derived from the children
        :param DiscourseUnit right: right child; its ``end`` is read whenever ``left`` is truthy
        :param str text: (optional) unit text, used only when ``orig_text`` is falsy
        :param int start: start position in original text (ignored when ``left`` is given)
        :param int end: end position in original text (ignored when ``left`` is given)
        :param str orig_text: (optional) full text; when truthy the unit text is ``orig_text[start:end]``
        :param string relation: {the relation between left and right components | 'elementary' | 'root'}
        :param string nuclearity: {'NS' | 'SN' | 'NN'}
        :param float proba: predicted probability of the relation occurrence (stored as a string)
        """
        self.id = id
        self.left, self.right = left, right
        self.relation, self.nuclearity = relation, nuclearity
        self.proba = str(proba)

        if self.left:
            # an inner node spans from the start of its left child to one past the end of its right child
            self.start, self.end = left.start, right.end + 1
        else:
            self.start, self.end = start, end

        raw_text = orig_text[self.start:self.end] if orig_text else text
        self.text = raw_text.strip()

    def __str__(self):
        """Multi-line summary: id, text, relation, the children's texts and the span."""
        left_text = self.left.text if self.left else None
        right_text = self.right.text if self.right else None
        return f"id: {self.id}\ntext: {self.text}\nrelation: {self.relation}\nleft: {left_text}\nright: {right_text}\nstart: {self.start}\nend: {self.end}"


In [None]:
def printTree(tree):
    """
    Render a discourse tree with ``printBTree``.

    Inner nodes are labelled with ``(relation, probability)``, nodes without a relation with their text.

    :param DiscourseUnit tree: root of the tree to print
    :return: whatever ``printBTree`` returns
    """
    def node_info(node):
        if node.relation:
            # DiscourseUnit stores proba as a string, so convert it before "%.2f" formatting
            label = (node.relation, "%.2f" % float(node.proba))
        else:
            label = node.text
        return str(label), node.left, node.right

    # NOTE(review): assumes printBTree(node, nodeInfo) — confirm against utils.print_tree;
    # previously the tree itself was never handed to the printer.
    return printBTree(tree, node_info)

In [None]:
class DiscourseUnitCreator:
    """Factory that builds merged DiscourseUnit nodes with consecutive ids."""

    def __init__(self, id):
        """
        :param int id: last id already in use; the first created unit gets ``id + 1``
        """
        self.id = id
        
    def __call__(self, left_node, right_node, proba):
        """
        Create a new unit joining ``left_node`` and ``right_node``.

        :param DiscourseUnit left_node: left child
        :param DiscourseUnit right_node: right child
        :param float proba: probability attached to the new node
        :return: the new DiscourseUnit (its relation is the constant ``1``)
        """
        self.id += 1
        return DiscourseUnit(
            # use the running counter; a bare `id` here would pass the builtin function
            id=self.id,
            left=left_node,
            right=right_node,
            relation=1,
            proba=proba
        )

In [None]:
#from isanlp.annotation_rst import DiscourseUnit
import pandas as pd


class RSTTreePredictor:
    """Base class bundling the feature processor and the classifiers used while building a tree."""

    def __init__(self, features_processor, relation_predictor, label_predictor):
        """
        :param features_processor: callable producing features for a pair of units
        :param relation_predictor: model scoring whether two units are related
        :param label_predictor: model predicting the relation label, or a falsy value if unavailable
        """
        self.features_processor = features_processor
        self.relation_predictor = relation_predictor
        self.label_predictor = label_predictor
        # always define the attribute so reading it without a label predictor does not raise
        self.labels = self.label_predictor.classes_ if self.label_predictor else None
        self.genre = None

    def predict_label(self, features):
        """
        Predict the relation label for a feature row.

        :return: the label predictor's output, or the placeholder 'relation' when there is none
        """
        if not self.label_predictor:
            return 'relation'

        return self.label_predictor.predict(features)


class GoldTreePredictor(RSTTreePredictor):
    """Oracle predictor that answers from a gold corpus of annotated snippet pairs."""

    def __init__(self, corpus):
        """
        :param corpus: table with ``snippet_x``, ``snippet_y``, ``category_id`` and ``order`` columns
        """
        RSTTreePredictor.__init__(self, None, None, None)
        self.corpus = corpus

    def _rows_in_order(self, features):
        """Rows of the corpus whose (snippet_x, snippet_y) equal the given (left, right) texts."""
        left_snippet, right_snippet = features
        same_order = (self.corpus.snippet_x == left_snippet) & (self.corpus.snippet_y == right_snippet)
        return self.corpus[same_order]

    def extract_features(self, *args):
        """The features of a pair are just the texts of the first two positional arguments."""
        left_node, right_node = args[0], args[1]
        return [left_node.text, right_node.text]

    def predict_pair_proba(self, features):
        """Return 1.0 if the pair occurs in the corpus in either order, else 0.0."""
        left_snippet, right_snippet = features
        corpus = self.corpus

        as_given = (corpus.snippet_x == left_snippet) & (corpus.snippet_y == right_snippet)
        if as_given.sum(axis=0) != 0:
            return 1.0

        swapped = (corpus.snippet_y == left_snippet) & (corpus.snippet_x == right_snippet)
        return float(swapped.sum(axis=0) != 0)

    def predict_label(self, features):
        """Gold ``category_id`` of the pair (given order only), or 'relation' when it is absent."""
        labels = self._rows_in_order(features).category_id.values
        return 'relation' if labels.size == 0 else labels[0]
    
    def predict_nuclearity(self, features):
        """Gold ``order`` value of the pair (given order only), or '_' when it is absent."""
        orders = self._rows_in_order(features).order.values
        return '_' if orders.size == 0 else orders[0]


class CustomTreePredictor(RSTTreePredictor):
    """
    Contains trained classifiers and feature processors needed for tree prediction.
    """
    def __init__(self, features_processor, relation_predictor, label_predictor=None):
        RSTTreePredictor.__init__(self, features_processor, relation_predictor, label_predictor)

    def extract_features(self, left_node: DiscourseUnit, right_node: DiscourseUnit,
                         annot_text, annot_tokens, annot_sentences, annot_postag, annot_morph, annot_lemma,
                         annot_syntax_dep_tree):
        """
        Run the feature processor on the (left, right) pair of unit texts.

        :return: the processor's output, or ``-1`` when it raised an IndexError
                 (the failing pair and text are appended to ``errors.log``)
        """
        pair = pd.DataFrame({
            'snippet_x': [left_node.text.strip()],
            'snippet_y': [right_node.text.strip()],
        })

        try:
            return self.features_processor(pair, annot_text=annot_text,
                                           annot_tokens=annot_tokens, annot_sentences=annot_sentences,
                                           annot_postag=annot_postag, annot_morph=annot_morph,
                                           annot_lemma=annot_lemma, annot_syntax_dep_tree=annot_syntax_dep_tree)
        except IndexError:
            # append instead of truncating, so earlier failures are not lost when several pairs fail
            with open('errors.log', 'a') as f:
                f.write(str(pair.values))
                f.write(annot_text)
                f.write('\n')
            # NOTE(review): callers receive -1 as features; predict_pair_proba is not prepared for it
            return -1

    def predict_pair_proba(self, features):
        """Probability of the second class for the first row of ``features``."""
        return self.relation_predictor.predict_proba(features)[0][1]
    
    def predict_nuclearity(self, features):
        """Nuclearity is not predicted yet; always returns the placeholder 'unavail'."""
        # ToDO:
        return 'unavail'


In [None]:
import numpy as np
import sys

#from isanlp.annotation_rst import DiscourseUnit


class GreedyRSTParser:
    """Greedy bottom-up parser: repeatedly merges the best-scoring pair of adjacent nodes."""

    def __init__(self, tree_predictor, forest_threshold=0.05):
        """
        :param RSTTreePredictor tree_predictor:
        :param float forest_threshold: minimum relation probability to append the pair into the tree
        """
        self.tree_predictor = tree_predictor
        self.forest_threshold = forest_threshold

    def __call__(self, edus, annot_text, annot_tokens, annot_sentences, annot_postag, annot_morph, annot_lemma,
                 annot_syntax_dep_tree, genre=None):
        """
        :param list edus: DiscourseUnit
        :param str annot_text: original text
        :param list annot_tokens: isanlp.annotation.Token
        :param list annot_sentences: isanlp.annotation.Sentence
        :param list annot_postag: lists of str for each sentence
        :param annot_morph: forwarded unchanged to the predictor's feature extraction
        :param annot_lemma: lists of str for each sentence
        :param annot_syntax_dep_tree: list of isanlp.annotation.WordSynt for each sentence
        :param genre: stored on the tree predictor as ``genre`` before parsing
        :return: list of DiscourseUnit containing each extracted tree (empty for empty input)
        """
        predictor = self.tree_predictor
        predictor.genre = genre

        if not edus:
            # nothing to merge; without this guard edus[-1] raises IndexError
            return []

        def pair_features(left_node, right_node):
            # single place that forwards all annotations to the predictor
            return predictor.extract_features(left_node, right_node, annot_text, annot_tokens,
                                              annot_sentences, annot_postag, annot_morph, annot_lemma,
                                              annot_syntax_dep_tree)

        nodes = edus
        max_id = edus[-1].id

        # features[i] / scores[i] always describe the adjacent pair (nodes[i], nodes[i + 1])
        features = [pair_features(nodes[i], nodes[i + 1]) for i in range(len(nodes) - 1)]
        scores = [predictor.predict_pair_proba(pair) for pair in features]

        while len(nodes) > 2 and any(score > self.forest_threshold for score in scores):
            # position of the most probable pair in the list
            j = int(np.argmax(np.array(scores)))

            # make the new node by merging node[j] + node[j+1]
            max_id += 1
            merged = DiscourseUnit(
                id=max_id,
                left=nodes[j],
                right=nodes[j + 1],
                relation=predictor.predict_label(features[j]),
                nuclearity=predictor.predict_nuclearity(features[j]),
                proba=scores[j],
                text=annot_text[nodes[j].start:nodes[j + 1].end].strip()
            )
            nodes = nodes[:j] + [merged] + nodes[j + 2:]

            # Re-score only the pairs that involve the merged node: its left neighbour (if any)
            # and its right neighbour (if any). They replace old entries j-1, j and j+1.
            fresh_features, fresh_scores = [], []
            if j > 0:
                left_pair = pair_features(nodes[j - 1], nodes[j])
                fresh_features.append(left_pair)
                fresh_scores.append(predictor.predict_pair_proba(left_pair))
            if j + 1 < len(nodes):
                right_pair = pair_features(nodes[j], nodes[j + 1])
                fresh_features.append(right_pair)
                fresh_scores.append(predictor.predict_pair_proba(right_pair))

            untouched = max(j - 1, 0)  # number of leading entries that stay valid
            scores = scores[:untouched] + fresh_scores + scores[j + 2:]
            features = features[:untouched] + fresh_features + features[j + 2:]

        # two nodes left and they are related strongly enough: join them under a root
        if len(scores) == 1 and scores[0] > self.forest_threshold:
            root = DiscourseUnit(
                id=max_id + 1,
                left=nodes[0],
                right=nodes[1],
                relation='root',
                proba=scores[0]
            )
            nodes = [root]

        return nodes


In [None]:
import numpy as np
import pandas as pd


def get_embeddings(embedder, X, maxlen=100):
    """
    Embed a whitespace-separated string of ``word_TAG`` tokens.

    :param embedder: mapping from word to vector, exposing ``vector_size``; raises KeyError for unknown words
    :param str X: tokens separated by whitespace; the part after the last ``_`` of each token is dropped
    :param int maxlen: number of rows in the result; longer inputs are truncated
    :return: array of shape ``(maxlen, embedder.vector_size)``; rows of unknown or missing tokens are zero
    """
    # rsplit keeps tokens without an underscore intact (slicing by rfind() == -1 cut their last character)
    words = [token.rsplit('_', 1)[0] for token in X.split()]
    # one row per token, one column per embedding dimension
    result = np.zeros((maxlen, embedder.vector_size))

    for row, word in enumerate(words[:maxlen]):
        try:
            result[row] = embedder[word]
        except KeyError:
            # out-of-vocabulary word: leave the zero row
            continue

    return result


class FeaturesExtractor:
    """Turns a frame of snippet pairs into the numeric feature frame expected by the classifiers."""

    # helper/text columns produced by the processor that must not reach the model
    DROP_COLUMNS = ['snippet_x', 'snippet_y', 'snippet_x_tmp', 'snippet_y_tmp', 'postags_x', 'postags_y']

    def __init__(self, processor, scaler=None, categorical_cols=None, one_hot_encoder=None, label_encoder=None):
        """
        :param processor: callable computing raw features from the pair frame and the annotations
        :param scaler: (optional) fitted scaler exposing ``transform``
        :param list categorical_cols: (optional) names of categorical feature columns
        :param one_hot_encoder: (optional) fitted encoder applied to ``categorical_cols``
        :param label_encoder: (optional) encoder applied column-wise to ``categorical_cols``
        """
        self.processor = processor
        self.scaler = scaler
        self._categorical_cols = categorical_cols
        self.one_hot_encoder = one_hot_encoder
        self.label_encoder = label_encoder

    def __call__(self, df, annot_text, annot_tokens, annot_sentences, annot_postag, annot_morph, annot_lemma, annot_syntax_dep_tree):
        """
        :return: DataFrame of features (scaled when a scaler was given), indexed like the processor output
        """
        X = self.processor(df, annot_text, annot_tokens, annot_sentences, annot_postag, annot_morph, annot_lemma, annot_syntax_dep_tree)
        X = X.drop(columns=self.DROP_COLUMNS)

        if self._categorical_cols:
            if self.label_encoder:
                # NOTE(review): fit_transform re-fits the encoder on every call — confirm this matches training
                X[self._categorical_cols] = X[self._categorical_cols].apply(lambda col: self.label_encoder.fit_transform(col))

            if self.one_hot_encoder:
                encoded = self.one_hot_encoder.transform(X[self._categorical_cols].values)
                # newer scikit-learn exposes get_feature_names_out; fall back to the legacy name otherwise
                feature_names = (getattr(self.one_hot_encoder, 'get_feature_names_out', None)
                                 or self.one_hot_encoder.get_feature_names)
                X_ohe = pd.DataFrame(encoded, index=X.index,
                                     columns=feature_names(self._categorical_cols)).add_prefix('cat_')

                # DROP_COLUMNS were already removed above; dropping them again here raised KeyError
                X = X.join(X_ohe, how='right').drop(columns=self._categorical_cols)

        if self.scaler:
            return pd.DataFrame(self.scaler.transform(X.values), index=X.index, columns=X.columns)

        return X


## Gold tree parsing example 

In [None]:
def extr_pairs(tree):
    """
    Collect ``[left text, right text, relation]`` for every inner node of a tree.

    Nodes are visited in pre-order (node, then its left subtree, then its right subtree).

    :param DiscourseUnit tree: root of the tree
    :return: list of three-element lists; empty for a leaf
    """
    collected = []
    pending = [tree]
    while pending:
        node = pending.pop()
        if not node.left:
            continue
        collected.append([node.left.text, node.right.text, node.relation])
        # push right first so that the left subtree is processed before it
        pending.append(node.right)
        pending.append(node.left)
    return collected

def extr_pairs_forest(forest):
    """
    Concatenate the pairs extracted by ``extr_pairs`` from every tree of a forest, in order.

    :param list forest: trees (DiscourseUnit roots)
    :return: flat list of the per-tree pair lists
    """
    return [pair for tree in forest for pair in extr_pairs(tree)]

In [None]:
from utils.train_test_split import split_data

train, test = split_data('data/', 0.2)

In [None]:
from utils.file_reading import read_edus, read_gold, read_annotation

In [None]:
%%time

from tqdm import tqdm_notebook as tqdm
from utils.file_reading import read_edus, read_gold, read_annotation

# Parse the first three test files with the gold oracle and keep (parsed pairs, gold) per file.
cache = {}
for file in tqdm(test[:3]):
    # drop the last extension to get the common path prefix used by the readers
    filename = '.'.join(file.split('.')[:-1])
    edus = read_edus(filename)
    gold = read_gold(filename)
    annot = read_annotation(filename)
    
    # Locate each EDU in the annotated text, searching from the end of the previous EDU.
    _edus = []
    last_end = 0
    for max_id in range(len(edus)):
        # NOTE(review): if the EDU is not found, find() returns -1 and start becomes last_end - 1
        start = len(annot['text'][:last_end]) + annot['text'][last_end:].find(edus[max_id])
        end = start + len(edus[max_id])
        temp = DiscourseUnit(
                id=max_id,
                left=None,
                right=None,
                relation='edu',
                start=start,
                end=end,
                orig_text=annot['text'],
                proba=1.,
                #text=edus[max_id]  #annot_text[nodes[j].start:nodes[j+1].end]
            )
        _edus.append(temp)
        last_end = end

    # threshold 0. means any pair the gold predictor scores above zero may be merged
    parser = GreedyRSTParser(GoldTreePredictor(gold), forest_threshold=0.)
    parsed = parser(_edus, annot['text'], annot['tokens'], annot['sentences'],
                    annot['postag'], annot['morph'], annot['lemma'], annot['syntax_dep_tree'])
    
    parsed_pairs = pd.DataFrame(extr_pairs_forest(parsed), columns=['snippet_x', 'snippet_y', 'category_id'])
    cache[filename] = (parsed_pairs, gold)

In [None]:
import json

#filename = 'rst_pairs/news_55'
#filename = 'rst_pairs/news_13'
# Load one document and align its EDUs with the annotated text.
filename = 'data/sci.comp_26'
edus = read_edus(filename)
gold = read_gold(filename)
annot = read_annotation(filename)

_edus = []
last_end = 0
for max_id in range(len(edus)):
    # search for the EDU only after the end of the previous one
    # NOTE(review): if the EDU is not found, find() returns -1 and start becomes last_end - 1
    start = len(annot['text'][:last_end]) + annot['text'][last_end:].find(edus[max_id])
    end = start + len(edus[max_id])
    temp = DiscourseUnit(
            id=max_id,
            left=None,
            right=None,
            relation='edu',
            start=start,
            end=end,
            orig_text=annot['text'],
            proba=1.,
            #text=edus[max_id]  #annot_text[nodes[j].start:nodes[j+1].end]
        )
    _edus.append(temp)
    last_end = end

In [None]:
# Build the tree(s) for the document loaded above, using its gold pairs as the oracle.
parser = GreedyRSTParser(GoldTreePredictor(gold), forest_threshold=0.)
parsed = parser(_edus, annot['text'], annot['tokens'], annot['sentences'],
                annot['postag'], annot['morph'], annot['lemma'], annot['syntax_dep_tree'])

## Evaluation (Gold tree construction)

In [None]:
# metric_parseval is only imported further down the notebook; import it here so the cell runs top to bottom
from utils.evaluation import metric_parseval

# Derive the pairs from the tree parsed in the previous cell, so they belong to the same file as `gold`
# (otherwise `parsed_pairs` is whatever an earlier loop left behind).
parsed_pairs = pd.DataFrame(extr_pairs_forest(parsed), columns=['snippet_x', 'snippet_y', 'category_id'])

true_pos = []
all_parsed = []
all_gold = []
fnames = []

c_true_pos, c_all_parsed, c_all_gold = metric_parseval(parsed_pairs, gold)
true_pos.append(c_true_pos)
all_parsed.append(c_all_parsed)
all_gold.append(c_all_gold)

# micro-averaged scores over the collected counts
recall = sum(true_pos) / sum(all_gold)
print('Recall: ', recall)

precision = sum(true_pos) / sum(all_parsed)
print('Precision:', precision)

f1 = 2 * precision * recall / (precision + recall)
print('F1:', f1)
    
# per-file table with the same metrics
aa = pd.DataFrame({'true_pos': true_pos, 'all_parsed': all_parsed, 'all_gold': all_gold})
aa['recall'] = aa.true_pos / aa.all_gold
aa['precision'] = aa.true_pos / aa.all_parsed
aa['f1'] = aa.recall * aa.precision * 2 / (aa.precision + aa.recall)

aa.sort_values('f1')

# Evaluation (Parser)

In [None]:
from utils.features_processor_default import FeaturesProcessor

binary_classifier_model_path = 'models/structure_predictor/'

In [None]:
%%time

features_processor = FeaturesProcessor(model_dir_path='models', verbose=False)

In [None]:
import pickle
import os

# NOTE: pickle.load can execute arbitrary code; only load model files from a trusted source.
# Context managers make sure the file handles are closed after loading.
with open(os.path.join(binary_classifier_model_path, 'scaler.pkl'), 'rb') as scaler_file:
    scaler = pickle.load(scaler_file)
#categorical_cols = pickle.load(open(binary_classifier_model_path + 'categorical_cols.pkl', 'rb'))
#ohe = pickle.load(open(binary_classifier_model_path + 'one_hot_encoder.pkl', 'rb'))
#le = pickle.load(open(binary_classifier_model_path + 'label_encoder.pkl', 'rb'))
with open(os.path.join(binary_classifier_model_path, 'model.pkl'), 'rb') as model_file:
    binary_classifier = pickle.load(model_file)
features_extractor = FeaturesExtractor(features_processor, scaler)

In [None]:
predictor = CustomTreePredictor(features_extractor, binary_classifier, label_predictor=None)

In [None]:
parser = GreedyRSTParser(predictor, forest_threshold=0.3)

In [None]:
from tqdm import tqdm_notebook as tqdm
from utils.file_reading import *
from utils.evaluation import extr_pairs_forest

# Parse every test file with the trained parser; files that fail are collected in broken_files.
cache = {}
broken_files = []

for file in tqdm(test):
    # drop the last extension to get the common path prefix used by the readers
    filename = '.'.join(file.split('.')[:-1])
    edus = read_edus(filename)
    gold = read_gold(filename)
    annot = read_annotation(filename)
    
    _edus = []
    last_end = 0
    # NOTE(review): this skips the last EDU, unlike the gold-parsing cells which use range(len(edus)) —
    # confirm whether that is intentional (e.g. a trailing empty EDU) or an off-by-one.
    for max_id in range(len(edus) - 1):
        start = annot['text'].find(edus[max_id], last_end)
        end = start + len(edus[max_id])
        temp = DiscourseUnit(
                id=max_id,
                left=None,
                right=None,
                relation='edu',
                start=start,
                end=end,
                orig_text=annot['text'],
                proba=1.,
                #text=edus[max_id]  #annot_text[nodes[j].start:nodes[j+1].end]
            )
        _edus.append(temp)
        last_end = end

    try:
        parsed = parser(_edus, 
                        annot['text'], 
                        annot['tokens'], 
                        annot['sentences'], 
                        annot['postag'], 
                        annot['morph'], 
                        annot['lemma'], 
                        annot['syntax_dep_tree'], 
                        genre=filename.split('_')[0])
    except Exception:
        # keep going on parser failures, but let KeyboardInterrupt / SystemExit stop the loop
        broken_files.append(filename)
        continue
    
    parsed_pairs = pd.DataFrame(extr_pairs_forest(parsed), columns=['snippet_x', 'snippet_y', 'category_id'])
    cache[filename] = (parsed_pairs, gold)

In [None]:
parsed_pairs

In [None]:
from utils.evaluation import metric_parseval

# Collect the three counts returned by metric_parseval for every cached (parsed pairs, gold) entry.
filenames = []
true_pos = []
all_parsed = []
all_gold = []

for key, value in cache.items():
    # value is the (parsed_pairs, gold) tuple stored by the parsing loop
    c_true_pos, c_all_parsed, c_all_gold = metric_parseval(value[0], value[1])
    filenames.append(key)
    true_pos.append(c_true_pos)
    all_parsed.append(c_all_parsed)
    all_gold.append(c_all_gold)
    
# one row per file
results = pd.DataFrame({'filename': filenames, 
                    'true_pos': true_pos,
                    'all_parsed': all_parsed,
                    'all_gold': all_gold})

In [None]:
# Per-file recall, precision and their harmonic mean (F1), computed from the counts above.
results['recall'] = results['true_pos'] / results['all_gold']
results['precision'] = results['true_pos'] / results['all_parsed']
results['F1'] = 2 * results['precision'] * results['recall'] / (results['precision'] + results['recall'])

In [None]:
results.sort_values('F1', ascending=False)

In [None]:
broken_files

# Evaluation (Gold)

In [None]:
def parse_golds(file):
    """
    Parse one document with the gold oracle.

    :param str file: path of the ``.edus`` file; the suffix is removed to get the common file prefix
    :return: ``(filename, parsed_pairs, gold)``, or ``None`` when the parser raised
             (the filename is then appended to the global ``broken_files``)
    """
    filename = file.replace('.edus', '')
    edus = read_edus(filename)
    gold = read_gold(filename)
    # keep a single gold row per right-hand snippet
    gold = gold.sort_values('snippet_y').drop_duplicates(subset=['snippet_y'])
    annot = read_annotation(filename)

    # Locate each EDU in the annotated text, searching from the end of the previous EDU.
    _edus = []
    last_end = 0
    for max_id, edu in enumerate(edus):
        start = annot['text'].find(edu, last_end)
        end = start + len(edu)
        _edus.append(DiscourseUnit(
            id=max_id,
            left=None,
            right=None,
            relation='edu',
            start=start,
            end=end,
            orig_text=annot['text'],
            proba=1.,
        ))
        last_end = end

    parser = GreedyRSTParser(GoldTreePredictor(gold), forest_threshold=0.)

    try:
        parsed = parser(_edus,
                        annot['text'],
                        annot['tokens'],
                        annot['sentences'],
                        annot['postag'],
                        annot['morph'],
                        annot['lemma'],
                        annot['syntax_dep_tree'],
                        genre=filename.split('_')[0])
    except Exception:
        # `continue` is a SyntaxError outside a loop; record the failure and signal it with None
        broken_files.append(filename)
        return None

    parsed_pairs = pd.DataFrame(extr_pairs_forest(parsed), columns=['snippet_x', 'snippet_y', 'category_id'])
    return (filename, parsed_pairs, gold)

In [None]:
import glob
import os
from tqdm import tqdm_notebook as tqdm

# Parse every document under data/ with the gold oracle and keep (parsed pairs, gold) per file.
cache = {}
for file in tqdm(glob.glob('data/*.edus')):
    filename = file.replace('.edus', '')
    edus = read_edus(filename)
    gold = read_gold(filename)
    # keep a single gold row per right-hand snippet
    gold = gold.sort_values('snippet_y').drop_duplicates(subset=['snippet_y'])
    annot = read_annotation(filename)
    _edus = []
    last_end = 0
    for max_id in range(len(edus)):
        # search for the EDU only after the end of the previous one
        start = annot['text'].find(edus[max_id], last_end)
        end = start + len(edus[max_id])
        temp = DiscourseUnit(
                id=max_id,
                left=None,
                right=None,
                relation='edu',
                start=start,
                end=end,
                orig_text=annot['text'],
                proba=1.,
                #text=edus[max_id]  #annot_text[nodes[j].start:nodes[j+1].end]
            )
        _edus.append(temp)
        last_end = end
    
    parser = GreedyRSTParser(GoldTreePredictor(gold), forest_threshold=0.)
    #parsed = parser(_edus)
    
    try:
        parsed = parser(_edus, 
                        annot['text'], 
                        annot['tokens'], 
                        annot['sentences'], 
                        annot['postag'], 
                        annot['morph'], 
                        annot['lemma'], 
                        annot['syntax_dep_tree'], 
                        genre=filename.split('_')[0])
    except Exception:
        # keep going on parser failures, but let KeyboardInterrupt / SystemExit stop the loop
        # NOTE(review): broken_files is the list created by the parser-evaluation cell
        broken_files.append(filename)
        continue
    
    # (the pair extraction and cache assignment used to be executed twice; once is enough)
    parsed_pairs = pd.DataFrame(extr_pairs_forest(parsed), columns=['snippet_x', 'snippet_y', 'category_id'])
    cache[filename] = (parsed_pairs, gold)

In [None]:
# Collect the three counts returned by metric_parseval for every cached (parsed pairs, gold) entry.
filenames = []
true_pos = []
all_parsed = []
all_gold = []

for key, value in cache.items():
    # value is the (parsed_pairs, gold) tuple stored by the gold parsing loop
    c_true_pos, c_all_parsed, c_all_gold = metric_parseval(value[0], value[1])
    filenames.append(key)
    true_pos.append(c_true_pos)
    all_parsed.append(c_all_parsed)
    all_gold.append(c_all_gold)
    
# one row per file
results = pd.DataFrame({'filename': filenames, 
                    'true_pos': true_pos,
                    'all_parsed': all_parsed,
                    'all_gold': all_gold})

In [None]:
# Per-file recall, precision and their harmonic mean (F1), computed from the counts above.
results['recall'] = results['true_pos'] / results['all_gold']
results['precision'] = results['true_pos'] / results['all_parsed']
results['F1'] = 2 * results['precision'] * results['recall'] / (results['precision'] + results['recall'])

In [None]:
results[results['filename'].str.contains('comp')].F1.mean()

In [None]:
results[results['filename'].str.contains('ling')].F1.mean()

In [None]:
results[results['filename'].str.contains('news')].F1.mean()

In [None]:
results[results['filename'].str.contains('blogs')].F1.mean()

### Bad file analysis 

In [None]:
# Re-run the gold parsing for a single problematic document so its tree can be inspected.
filename = 'data/news2_17'

edus = read_edus(filename)
gold = read_gold(filename)
# keep a single gold row per right-hand snippet
gold = gold.sort_values('snippet_y').drop_duplicates(subset=['snippet_y'])
annot = read_annotation(filename)
_edus = []
last_end = 0
for max_id in range(len(edus)):
    # search for the EDU only after the end of the previous one
    start = annot['text'].find(edus[max_id], last_end)
    end = start + len(edus[max_id])
    temp = DiscourseUnit(
            id=max_id,
            left=None,
            right=None,
            relation='edu',
            start=start,
            end=end,
            orig_text=annot['text'],
            #text=edus[max_id],
            proba=1.,
            #text=edus[max_id]  #annot_text[nodes[j].start:nodes[j+1].end]
        )
    _edus.append(temp)
    last_end = end

parser = GreedyRSTParser(GoldTreePredictor(gold), forest_threshold=0.)
#parsed = parser(_edus)

# no try/except here: an exception should surface while analysing the bad file
parsed = parser(_edus, 
                annot['text'], 
                annot['tokens'], 
                annot['sentences'], 
                annot['postag'], 
                annot['morph'], 
                annot['lemma'], 
                annot['syntax_dep_tree'], 
                genre=filename.split('_')[0])

In [None]:
for _ in _edus:
    print(vars(_))

In [None]:
# Rebuild the pair table from the tree parsed above so the cached entry really belongs to `filename`;
# without this, `parsed_pairs` may still hold the result of an earlier cell for a different file.
parsed_pairs = pd.DataFrame(extr_pairs_forest(parsed), columns=['snippet_x', 'snippet_y', 'category_id'])
cache[filename] = (parsed_pairs, gold)

In [None]:
counter = 0

for tree in parsed:
    if tree.relation != 'edu':
        print(vars(tree))
        counter += 1
        break

In [None]:
tree = parsed[1]

In [None]:
vars(tree)

In [None]:
vars(tree)

In [None]:
vars(tree.right)

In [None]:
from utils.evaluation import metric_parseval, extr_pairs, extr_pairs_forest, _check_snippet_pair_in_dataset, _not_parsed_as_in_gold

parsed_pairs = pd.DataFrame(extr_pairs_forest(parsed), columns=['snippet_x', 'snippet_y', 'category_id'])
print(parsed_pairs.shape, gold.shape)
errors = _not_parsed_as_in_gold(parsed_pairs, gold)

def find_edu_number(edus, error):
    """
    Yield the positions of the EDUs whose text occurs inside ``error[2]``.

    :param edus: iterable of EDU strings
    :param error: indexable record; only element ``2`` is read, and it must support ``find``
    """
    yield from (
        position
        for position, unit in enumerate(edus)
        if error[2].find(unit) >= 0
    )

In [None]:
errors.iloc[3].values

In [None]:
list(find_edu_number(edus, errors.iloc[3]))