In [None]:
# Reload imported modules (e.g. the local utils package) before each cell runs,
# so edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import razdel
from isanlp import PipelineCommon
from isanlp.annotation import Token
from isanlp.processor_razdel import ProcessorRazdel
from isanlp.processor_remote import ProcessorRemote
from isanlp.ru.converter_mystem_to_ud import ConverterMystemToUd
from isanlp.ru.processor_mystem import ProcessorMystem


# (host, port) of the remote isanlp services used by `ppl` below.
# SERVER0/SERVER2 placeholders were never defined (NameError on a fresh kernel);
# configure the hosts/ports through the environment instead.
address_syntax = (os.environ.get('ISANLP_SYNTAX_HOST', 'localhost'),
                  int(os.environ.get('ISANLP_SYNTAX_PORT', '3134')))  # udpipe parser
address_rst = (os.environ.get('ISANLP_RST_HOST', 'localhost'),
               int(os.environ.get('ISANLP_RST_PORT', '3335')))  # rst parser


def tokenize(text):
    """Tokenize ``text`` with razdel while keeping paragraph boundaries.

    Runs of consecutive newlines are collapsed into one, every resulting line is
    treated as a paragraph whose razdel tokens are joined by single spaces, and
    the paragraphs are re-joined with '\\n'. Outer whitespace is stripped.
    """
    collapsed = text
    while '\n\n' in collapsed:
        collapsed = collapsed.replace('\n\n', '\n')

    tokenized_lines = [
        ' '.join(token.text for token in razdel.tokenize(line))
        for line in collapsed.split('\n')
    ]
    return '\n'.join(tokenized_lines).strip()


class WhitespaceTokenizer:
    """Performs dummy tokenization when you want to process pretokenized text.

    Every maximal run of non-whitespace characters becomes one ``Token`` whose
    begin/end offsets are character positions in the ``text`` passed in.
    Returns ``{'tokens': [Token, ...]}``.
    """

    def __call__(self, text):
        result = []
        cursor = 0
        for token in text.split():
            # Search from the end of the previous token instead of assuming exactly
            # one space between tokens, so offsets stay correct with leading
            # whitespace, runs of spaces, tabs or newlines. Only whitespace can lie
            # between `cursor` and the token, so the first match is the right one.
            begin = text.index(token, cursor)
            end = begin + len(token)
            result.append(Token(token, begin, end))
            cursor = end
        return {'tokens': result}


# Each stage is (processor, input annotation names, {processor output: annotation name}),
# listed in execution order.
ppl = PipelineCommon([
    # razdel: tokens and sentence boundaries from the raw text
    (ProcessorRazdel(),
     ['text'],
     dict(tokens='tokens', sentences='sentences')),
    # remote udpipe parser: lemmas, dependency trees and its POS tags (kept as 'ud_postag')
    (ProcessorRemote(*address_syntax, '0'),
     ['tokens', 'sentences'],
     dict(lemma='lemma', syntax_dep_tree='syntax_dep_tree', postag='ud_postag')),
    # Mystem POS tags
    (ProcessorMystem(delay_init=False),
     ['tokens', 'sentences'],
     dict(postag='postag')),
    # Mystem -> UD conversion: produces 'morph' and overwrites 'postag'
    (ConverterMystemToUd(),
     ['postag'],
     dict(morph='morph', postag='postag')),
    # remote rst parser consuming all annotations produced above
    (ProcessorRemote(*address_rst, 'default'),
     ['text', 'tokens', 'sentences', 'postag', 'morph', 'lemma', 'syntax_dep_tree'],
     dict(rst='rst')),
])

In [None]:
from utils.file_reading import read_edus, read_gold
from utils.evaluation import *

# Single example document used for a quick end-to-end sanity check.
example = 'data_ru/news2_4'
# Close the handle deterministically and don't depend on the platform's default encoding.
with open(example + '.txt', 'r', encoding='utf-8') as example_file:
    text = example_file.read().strip()
gold_edus = read_edus(example)
gold_pairs = prepare_gold_pairs(read_gold(example, features=True))

In [None]:
%%time

# Parse the example document end to end (razdel tokenization -> `ppl`);
# result['rst'] holds the predicted RST trees.
result = ppl(tokenize(text))

In [None]:
from isanlp.annotation_rst import ForestExporter
# Export the predicted RST forest of the example document as an .rs3 file named
# after the document (derived from `example` so the two cannot drift apart).
ex = ForestExporter(encoding='utf8')
ex(result['rst'], example.split('/')[-1] + '_pred.rs3')

#### Data loading: gold relation pairs

In [None]:
from utils import metrics
from utils.discourseunit2str import du_to_docs_structure_charsonly
from utils.evaluation import *

In [None]:
import pandas as pd

# Gold relation pairs for the whole corpus, one row per related (snippet_x, snippet_y).
pairs = pd.read_feather('data_ru/all_pairs.fth')
pairs = prepare_gold_pairs(pairs)

# Normalize both snippet columns identically so they can be located inside the
# normalized document text.
# NOTE(review): at this point `charsonly` comes from `utils.evaluation` (star import);
# it is redefined later in this notebook. If it removes all whitespace, the
# hyphen-spacing replaces below are no-ops — confirm both versions agree.
for _column in ('snippet_x', 'snippet_y'):
    pairs[_column] = pairs[_column].map(
        lambda snippet: charsonly(prepare_string(snippet).replace(' -', ' - ').replace('- ', ' - ').replace('  ', ' ')))

In [None]:
def gold_tree_as_string(pairs, filename, text):
    """Serialize relation pairs into the span string scored by ``metrics.DiscourseMetricDoc``.

    Each selected row of ``pairs`` becomes one item
    ``(lb:LeftNuc=LeftRel:le,rb:RightNuc=RightRel:re)``; offsets are positions
    in ``charsonly(prepare_string(text)).strip()``.

    Args:
        pairs: DataFrame with columns ``snippet_x``, ``snippet_y``, ``category_id``,
            ``order`` (and ``filename`` when filtering); snippets are expected to be
            normalized the same way as ``text``. It is not modified.
        filename: if truthy, only rows whose ``filename`` equals it are used;
            otherwise all rows are used.
        text: text the snippets were taken from.

    Returns:
        Space-separated items, or '' when no row is selected.
    """
    selected = pairs[pairs.filename == filename] if filename else pairs
    # Nothing to serialize. Also avoids DataFrame.apply on an empty frame, which
    # returns a DataFrame that cannot be assigned to a single new column.
    if selected.empty:
        return ''

    _text = charsonly(prepare_string(text)).strip()

    result = []
    for row in selected.itertuples(index=False):
        # The pair is located as one contiguous string; str.find gives -1 when it
        # is absent, and such pairs are still emitted.
        left_begin = _text.find(row.snippet_x + row.snippet_y)
        left_end = right_begin = left_begin + len(row.snippet_x)
        right_end = right_begin + len(row.snippet_y)
        # In NS/SN the nucleus is marked as a span; in NN both nuclei carry the relation.
        left_rel = 'span' if row.order == 'NS' else row.category_id
        right_rel = 'span' if row.order == 'SN' else row.category_id
        left_nuc = 'Satellite' if row.order == 'SN' else 'Nucleus'
        right_nuc = 'Satellite' if row.order == 'NS' else 'Nucleus'
        result.append(f'({left_begin}:{left_nuc}={left_rel}:{left_end},'
                      f'{right_begin}:{right_nuc}={right_rel}:{right_end})')

    return ' '.join(result)

#### Document-level evaluation

In [None]:
import os

# Output directory for cached parses (.pkl) and exported trees (.rs3).
PARSING_RES_PATH = 'end2end-rstreebank'
# exist_ok makes the cell idempotent without a separate check-then-create step.
os.makedirs(PARSING_RES_PATH, exist_ok=True)

In [None]:
from isanlp.annotation_rst import ForestExporter

# Exporter used by the document-level loop to write predicted trees as .rs3 files.
# NOTE(review): 'utf8' is passed positionally; presumably it is the `encoding`
# argument (cf. ForestExporter(encoding='utf8') earlier) — confirm.
exp = ForestExporter('utf8')

In [None]:
from utils.train_test_split import split_rstreebank

# Split the documents under ./data_ru into train/dev/test lists; only `test` is used below.
train, dev, test = split_rstreebank('./data_ru')

In [None]:
from tqdm.autonotebook import tqdm
from utils.train_test_split import split_rstreebank, split_essays
from utils.file_reading import read_edus
from utils.evaluation import *
from isanlp.annotation_rst import ForestExporter
import pickle

train, dev, test = split_rstreebank('./data_ru')

# Uncomment to evaluate on the news subset only:
# test = [filename for filename in test if 'news' in filename]
test.sort()

global_metric = metrics.DiscourseMetricDoc()

for file in tqdm(test):
    file = file.replace('.edus', '')
    pure_filename = file.split('/')[-1]
    with open(file + '.txt', 'r', encoding='utf-8') as fin:
        text = fin.read().strip()

    # NOTE(review): `text_html_map` is not defined in this notebook; presumably it
    # comes from `from utils.evaluation import *` — confirm.
    for key, value in text_html_map.items():
        text = text.replace(key, value)

    text = tokenize(text)
    result = ppl(text)
    # To rescore without re-parsing, load the cached parse instead:
    # result = pickle.load(open(os.path.join(PARSING_RES_PATH, pure_filename + '.pkl'), 'rb'))

    # Cache the parse (read back by the sentence/paragraph-level cell) and export the trees.
    with open(f'{PARSING_RES_PATH}/{pure_filename}.pkl', 'wb') as fout:
        pickle.dump(result, fout)
    exp(result['rst'], f'{PARSING_RES_PATH}/{pure_filename}.rs3')

    # Collect the predicted span strings of every tree in the forest.
    doc_chars = charsonly(text)  # loop-invariant, computed once per document
    pred_spans = []
    for tree in result['rst']:
        pred_spans += du_to_docs_structure_charsonly(tree, doc_chars) or []
    pred = ' '.join(pred_spans)

    gold = gold_tree_as_string(pairs, pure_filename, text)

    cur_metric = metrics.DiscourseMetricDoc()
    cur_metric([pred], [gold])
    print('Current metric:', pure_filename, cur_metric)
    global_metric([pred], [gold])

In [None]:
global_metric  # document-level scores accumulated over the test set

#### Sentence- and paragraph-level evaluation

In [None]:
import pickle

def charsonly(text):
    """Return ``text`` with every whitespace character (spaces, tabs, newlines, ...) removed.

    NOTE(review): this shadows any `charsonly` brought in by `from utils.evaluation import *`.
    """
    non_blank_runs = text.split()
    return ''.join(non_blank_runs)

# Uncomment to evaluate on the news subset only (keep in sync with the document-level cell):
# test = [filename for filename in test if 'news' in filename]
test.sort()

global_metric_s = metrics.DiscourseMetricDoc()
global_metric_p = metrics.DiscourseMetricDoc()


def _normalize_snippet(snippet):
    """Normalize a parsed snippet exactly like the gold `pairs` snippets."""
    return charsonly(prepare_string(snippet).replace(' -', ' - ').replace('- ', ' - ').replace('  ', ' '))


def _pairs_inside(df, segment):
    """Rows of ``df`` whose both snippets occur inside the normalized ``segment``."""
    if df.empty:
        return df
    mask = [charsonly(x) in segment and charsonly(y) in segment
            for x, y in zip(df.snippet_x, df.snippet_y)]
    return df[mask]


def _score_segments(segments, metric, gold_edus, gold_pairs, parsed_pairs, doc_name):
    """Update ``metric`` with every segment (sentence or paragraph) of one document.

    A segment is scored only when it contains more than one gold EDU, at least one
    gold pair, and the document has some predicted pairs.
    """
    norm_edus = [prepare_string(charsonly(edu)) for edu in gold_edus]
    for segment in tqdm(segments, leave=False):
        _text = prepare_string(charsonly(segment))
        n_edus = sum(edu in _text for edu in norm_edus)
        seg_gold = _pairs_inside(gold_pairs, _text)
        if n_edus > 1 and not parsed_pairs.empty and not seg_gold.empty:
            seg_pred = _pairs_inside(parsed_pairs, _text)
            gold = gold_tree_as_string(seg_gold, doc_name, _text)
            pred = '' if seg_pred.empty else gold_tree_as_string(seg_pred, doc_name, _text)
            metric([pred], [gold])


for file in tqdm(test):
    pure_filename = file.replace('.edus', '')
    doc_name = pure_filename.split('/')[-1]

    # Parse cached by the document-level cell; only unpickle files you produced yourself.
    with open(os.path.join(PARSING_RES_PATH, doc_name + '.pkl'), 'rb') as fin:
        result = pickle.load(fin)

    parsed_pairs = pd.DataFrame(extr_pairs_forest(result['rst']),
                                columns=['snippet_x', 'snippet_y', 'category_id', 'order'])
    parsed_pairs['filename'] = doc_name
    parsed_pairs['snippet_x'] = parsed_pairs.snippet_x.map(_normalize_snippet)
    parsed_pairs['snippet_y'] = parsed_pairs.snippet_y.map(_normalize_snippet)
    # Locate snippets in *this* document's text (previously the stale global `text`
    # left over from the document-level loop was used).
    doc_chars = charsonly(prepare_string(result['text']))
    parsed_pairs['loc_x'] = parsed_pairs.snippet_x.map(doc_chars.find)
    parsed_pairs['loc_y'] = parsed_pairs.snippet_y.map(doc_chars.find)

    gold_edus = read_edus(pure_filename)
    gold_edus = [tokenize(edu.replace(' -', ' - ').replace('  ', ' ')) for edu in gold_edus]
    # Restrict gold pairs to this document so snippets of other documents that happen
    # to occur in a sentence are not counted as gold.
    doc_gold_pairs = pairs[pairs.filename == doc_name]

    sentences = [tok.text for paragraph in result['text'].split('\n')
                 for tok in razdel.sentenize(paragraph)]
    _score_segments(sentences, global_metric_s, gold_edus, doc_gold_pairs, parsed_pairs, doc_name)
    _score_segments(result['text'].split('\n'), global_metric_p, gold_edus, doc_gold_pairs,
                    parsed_pairs, doc_name)

In [None]:
global_metric_p  # paragraph-level scores accumulated over the test set

In [None]:
global_metric_s  # sentence-level scores accumulated over the test set