In [None]:
# Host names of the remote services used when the pipeline addresses are built:
# SERVER0 is paired with ports 4333 and 4343, SERVER1 with port 3490.
# They are empty placeholders here — presumably to be filled in before running; confirm.
SERVER0 = ''
SERVER1 = ''

In [None]:
from isanlp.processor_remote import ProcessorRemote
from isanlp.processor_syntaxnet_remote import ProcessorSyntaxNetRemote
from isanlp import PipelineCommon
from isanlp.ru.converter_mystem_to_ud import ConverterMystemToUd

# (host, port) pairs of the remote services. Judging by the names: morphology and
# syntax live on SERVER0, the RST parser on SERVER1.
address_morph = (SERVER0, 4333)
address_syntax = (SERVER0, 4343)
address_rst = (SERVER1, 3490)

# Each pipeline step is a 3-tuple: (processor, list of input annotation names,
# dict of output annotation names).
# NOTE(review): the exact direction of the output mapping is defined by
# isanlp.PipelineCommon — confirm against the library before renaming keys.
# The last step consumes everything produced by the previous three and yields 'rst'.
ppl = PipelineCommon([(ProcessorRemote(address_morph[0], address_morph[1], 'default'),
                 ['text'],
                 {'tokens': 'tokens',
                  'sentences': 'sentences',
                  'postag': 'mystem_postag',
                  'lemma': 'lemma'}),
                (ProcessorSyntaxNetRemote(address_syntax[0], address_syntax[1]),
                 ['tokens', 'sentences'],
                 {'syntax_dep_tree': 'syntax_dep_tree'}),
                (ConverterMystemToUd(),
                 ['mystem_postag'],
                 {'morph': 'morph',
                  'postag': 'postag'}),
                (ProcessorRemote(address_rst[0], address_rst[1], 'default'),
                 ['text', 'tokens', 'sentences', 'postag', 'morph', 'lemma', 'syntax_dep_tree'],
                 {'rst': 'rst'})])

In [None]:
res = ppl('Внутри 22-й Московской международной книжной выставки-ярмарки, проходившей в начале сентября 2009 года на ВВЦ, работала внутренняя выставка - «Книгабайт», посвящённая электронному книгоизданию.')

In [None]:
res = ppl('Как сообщили «Ведомости», правительство внесло в Госдуму пакет поправок в законодательство по помощи регионам в связи с потерями, которые они понесут из-за кризиса. Общая стоимость помощи - около 100 млрд рублей. Однако налоговые доходы регионов, по данным Минфина, могут снизиться в 2009 году на 700-800 млрд рублей.')

In [None]:
print(res['rst'])

In [None]:
def extr_pairs(tree):
    """
    Collect one ``[left text, right text, relation]`` triple per internal node.

    A node counts as internal when ``tree.left`` is truthy. Triples are listed in
    pre-order: the node itself, then its left subtree, then its right subtree.

    :param tree: node exposing ``left``, ``right``, ``text`` and ``relation``
    :return: list of three-element lists; empty for a leaf
    """
    if not tree.left:
        return []

    collected = [[tree.left.text, tree.right.text, tree.relation]]
    for child in (tree.left, tree.right):
        collected.extend(extr_pairs(child))
    return collected

In [None]:
print(res['rst'][0])

In [None]:
print(res['rst'][0].left.left)

In [None]:
print(res['rst'][0].right)

In [None]:
extr_pairs(res['rst'][0])

In [None]:
import pandas as pd

# Text normalisation table: every key is replaced by its value, in insertion order.
# NOTE(review): the same keys are used in two different ways in this cell —
#   * as literal substrings (str.replace) in read_edus and read_annotation,
#   * as regular expressions (regex=True) in read_gold and read_json.
# Entries such as r'\n', r'\^', r'^' and the run of backslashes therefore do not
# mean the same thing in both code paths — confirm this is intended.
# The last two entries presumably map Latin look-alike letters to Cyrillic ones — verify.
text_html_map = {
    r'\n': r' ',
    r'&gt;': r'>',
    r'&lt;': r'<',
    r'&amp;': r'&',
    r'&quot;': r'"',
    r'&ndash;': r'–',
    r'##### ': r'',
    r'\\\\\\\\': r'\\',
    r'  ': r' ',
    r'——': r'-',
    r'—': r'-',
    r'/': r'',
    r'\^': r'',
    r'^': r'',
    r'±': r'+',
    r'y': r'у',
    r'x': r'х'
}

def read_edus(filename):
    """
    Read ``<filename>.edus`` and return its lines after normalisation.

    Every line is stripped of surrounding whitespace and then each key of
    ``text_html_map`` is replaced, as a literal substring, by its value.

    :param str filename: path without the ``.edus`` extension
    :return: list of normalised lines, one per line of the file
    """
    normalized = []
    with open(filename + '.edus', 'r') as edus_file:
        for raw_line in edus_file:
            text = raw_line.strip()
            for pattern, substitute in text_html_map.items():
                text = text.replace(pattern, substitute)
            normalized.append(text)
    return normalized

def read_gold(filename):
    """
    Load ``<filename>.gold.pkl`` and normalise its ``snippet_x`` / ``snippet_y`` columns.

    Every key of ``text_html_map`` is applied as a regular expression and replaced
    by its value, in the order the table defines.

    :param str filename: path without the ``.gold.pkl`` extension
    :return: the loaded DataFrame with both snippet columns normalised
    """
    # NOTE(review): unpickling runs arbitrary code — only load files from a trusted source.
    df = pd.read_pickle(filename + '.gold.pkl')
    for column in ('snippet_x', 'snippet_y'):
        for key, value in text_html_map.items():
            # Assign the result back instead of calling replace(inplace=True) on
            # df[column]: with pandas Copy-on-Write the chained in-place call works
            # on a temporary and leaves df unchanged.
            df[column] = df[column].replace(key, value, regex=True)

    return df

def read_json(filename):
    """
    Load ``<filename>.json`` and normalise its ``snippet_x`` / ``snippet_y`` columns.

    Every key of ``text_html_map`` is applied as a regular expression and replaced
    by its value, in the order the table defines.

    :param str filename: path without the ``.json`` extension
    :return: the loaded DataFrame with both snippet columns normalised
    """
    df = pd.read_json(filename + '.json')
    for column in ('snippet_x', 'snippet_y'):
        for key, value in text_html_map.items():
            # Assign the result back instead of calling replace(inplace=True) on
            # df[column]: with pandas Copy-on-Write the chained in-place call works
            # on a temporary and leaves df unchanged.
            df[column] = df[column].replace(key, value, regex=True)

    return df

def read_annotation(filename):
    """
    Load ``<filename>.annot.pkl`` and normalise its text and token texts.

    For each entry of ``text_html_map`` the key is replaced by the value in
    ``annot['text']`` and in the ``text`` attribute of every item of ``annot['tokens']``.
    NOTE(review): assumes ``annot['text']`` is a ``str`` (so the replacement is
    literal, not a regex) — confirm against the pickled object.

    :param str filename: path without the ``.annot.pkl`` extension
    :return: the loaded annotation object, modified in place
    """
    annotation = pd.read_pickle(filename + '.annot.pkl')
    for pattern, substitute in text_html_map.items():
        annotation['text'] = annotation['text'].replace(pattern, substitute)
        for tok in annotation['tokens']:
            tok.text = tok.text.replace(pattern, substitute)

    return annotation

In [None]:
filename = 'data/news1_62'
edus = read_edus(filename)
gold = read_json(filename)
annot = read_annotation(filename)

In [None]:
annot['text']

In [None]:
edus

In [None]:
gold

In [None]:
class DiscourseUnit:
    """Node of a binary discourse tree: either a leaf span of text or a merge of two nodes."""

    def __init__(self, id, left=None, right=None, text='', start=None, end=None, 
                 orig_text=None, relation=None, nuclearity=None, proba=1.):
        """
        :param int id:
        :param DiscourseUnit left:
        :param DiscourseUnit right:
        :param str text: (optional) text of a leaf; used only when the node has no
            children and ``orig_text`` is not given
        :param int start: start position in original text
        :param int end: end position in original text
        :param str orig_text: (optional) original text; a leaf takes
            ``orig_text[start:end]`` (stripped) as its text
        :param string relation: {the relation between left and right components | 'elementary' | 'root'}
        :param string nuclearity: {'NS' | 'SN' | 'NN'}
        :param float proba: predicted probability of the relation occurrence
            (stored as its ``str`` representation)
        """
        self.id = id
        self.left = left
        self.right = right
        self.relation = relation
        self.nuclearity = nuclearity
        self.proba = str(proba)
        self.start = start
        self.end = end

        if self.left:
            # An internal node spans from the start of its left child to the end
            # of its right child; its text is the children's texts joined by a space.
            self.start = left.start
            self.end = right.end
            self.text = ' '.join([self.left.text, self.right.text])
        elif orig_text is not None:
            self.text = orig_text[self.start:self.end].strip()
        else:
            # No source text to slice: fall back to the explicitly supplied text
            # instead of failing on ``None[start:end]``.
            self.text = text.strip()
    
    def __str__(self):
        return f"id: {self.id}\ntext: {self.text}\nrelation: {self.relation}\nleft: {self.left.text if self.left else None}\nright: {self.right.text if self.right else None}\nstart: {self.start}\nend: {self.end}"


In [None]:
#from isanlp.annotation_rst import DiscourseUnit
import pandas as pd


class RSTTreePredictor:
    """
    Base class of the predictors used by the tree parser.

    Subclasses are expected to add ``extract_features`` and ``predict_pair_proba``;
    this class supplies default answers for the label and the nuclearity.
    """

    def __init__(self, features_processor, relation_predictor, label_predictor):
        """
        :param features_processor: object turning a pair of units into features (may be None)
        :param relation_predictor: model scoring whether a pair is related (may be None)
        :param label_predictor: model naming the relation (may be None); when given,
            its ``classes_`` are exposed as ``self.labels``
        """
        self.features_processor = features_processor
        self.relation_predictor = relation_predictor
        self.label_predictor = label_predictor
        if self.label_predictor:
            self.labels = self.label_predictor.classes_
        self.genre = None

    def predict_label(self, features):
        """Return the label predictor's output, or ``'relation'`` when there is no label predictor."""
        if not self.label_predictor:
            return 'relation'

        return self.label_predictor.predict(features)

    def predict_nuclearity(self, features):
        """
        Default nuclearity for predictors that have no nuclearity model.

        GreedyRSTParser calls this method on every predictor, so the base class
        must provide it; ``'_'`` is the same "unknown" marker GoldTreePredictor
        returns when a pair is not found in its corpus.
        """
        return '_'


class GoldTreePredictor(RSTTreePredictor):
    """Predictor that answers every query by looking the snippet pair up in an annotated corpus."""

    def __init__(self, corpus):
        """
        :param corpus: table with ``snippet_x``, ``snippet_y``, ``category_id`` and ``order`` columns
        """
        RSTTreePredictor.__init__(self, None, None, None)
        self.corpus = corpus

    def _rows_for(self, x_snippet, y_snippet):
        """Boolean mask of the corpus rows whose (snippet_x, snippet_y) equal the given pair."""
        return (self.corpus.snippet_x == x_snippet) & (self.corpus.snippet_y == y_snippet)

    def extract_features(self, *args):
        """The features of a pair are simply the texts of its first two arguments."""
        first_unit, second_unit = args[0], args[1]
        return [first_unit.text, second_unit.text]

    def predict_pair_proba(self, features):
        """Return 1.0 when the pair occurs in the corpus in either order, otherwise 0.0."""
        left_snippet, right_snippet = features

        seen_forward = self._rows_for(left_snippet, right_snippet).sum(axis=0) != 0
        if seen_forward:
            return float(seen_forward)

        seen_backward = self._rows_for(right_snippet, left_snippet).sum(axis=0) != 0
        return float(seen_backward)

    def predict_label(self, features):
        """Return ``category_id`` of the first row matching (left, right), or ``'relation'`` if none."""
        left_snippet, right_snippet = features
        found = self.corpus[self._rows_for(left_snippet, right_snippet)].category_id.values
        return found[0] if found.size != 0 else 'relation'
    
    def predict_nuclearity(self, features):
        """Return ``order`` of the first row matching (left, right), or ``'_'`` if none."""
        left_snippet, right_snippet = features
        found = self.corpus[self._rows_for(left_snippet, right_snippet)].order.values
        return found[0] if found.size != 0 else '_'


class CustomTreePredictor(RSTTreePredictor):
    """Predictor backed by a feature processor and a trained relation classifier."""

    def __init__(self, features_processor, relation_predictor, label_predictor=None):
        RSTTreePredictor.__init__(self, features_processor, relation_predictor, label_predictor)

    def extract_features(self, left_node: DiscourseUnit, right_node: DiscourseUnit,
                         annot_text, annot_tokens, annot_sentences, annot_postag, annot_morph, annot_lemma,
                         annot_syntax_dep_tree):
        """
        Build the feature representation of a candidate pair of adjacent units.

        :param DiscourseUnit left_node: left unit of the pair
        :param DiscourseUnit right_node: right unit of the pair
        :param annot_*: annotations forwarded unchanged to the features processor
        :return: whatever the features processor returns, or ``-1`` when it
            raises ``IndexError`` (the failing pair is then logged to ``errors.log``)
        """
        pair = pd.DataFrame({
            'snippet_x': [left_node.text.strip()],
            'snippet_y': [right_node.text.strip()],
            #'genre': self.genre
        })

        try:
            return self.features_processor(pair, annot_text=annot_text,
                                           annot_tokens=annot_tokens, annot_sentences=annot_sentences,
                                           annot_postag=annot_postag, annot_morph=annot_morph,
                                           annot_lemma=annot_lemma, annot_syntax_dep_tree=annot_syntax_dep_tree)
        except IndexError:
            # Append rather than truncate, so that every failing pair of a run is
            # kept (the previous 'w+' mode erased the log on each failure), and
            # write UTF-8 explicitly because the logged text is not ASCII.
            with open('errors.log', 'a', encoding='utf-8') as f:
                f.write(str(pair.values))
                f.write('\n')
                f.write(annot_text)
                f.write('\n')
            return -1

    def predict_pair_proba(self, features):
        """Return ``predict_proba(features)[0][1]`` of the relation predictor."""
        return self.relation_predictor.predict_proba(features)[0][1]

In [None]:
import numpy as np
import sys

#from isanlp.annotation_rst import DiscourseUnit


class GreedyRSTParser:
    """Bottom-up greedy parser: repeatedly merges the best-scoring pair of adjacent nodes."""

    def __init__(self, tree_predictor, forest_threshold=0.05):
        """
        :param RSTTreePredictor tree_predictor:
        :param float forest_threshold: minimum relation probability to append the pair into the tree
        """
        self.tree_predictor = tree_predictor
        self.forest_threshold = forest_threshold

    def __call__(self, edus, annot_text, annot_tokens, annot_sentences, annot_postag, annot_morph, annot_lemma,
                 annot_syntax_dep_tree, genre=None):
        """
        :param list edus: DiscourseUnit
        :param str annot_text: original text
        :param list annot_tokens: isanlp.annotation.Token
        :param list annot_sentences: isanlp.annotation.Sentence
        :param list annot_postag: lists of str for each sentence
        :param annot_morph: morphological annotation, forwarded to the predictor
        :param annot_lemma: lists of str for each sentence
        :param annot_syntax_dep_tree: list of isanlp.annotation.WordSynt for each sentence
        :param genre: stored on the tree predictor as ``genre`` before parsing
        :return: list of DiscourseUnit containing each extracted tree
            (empty when ``edus`` is empty)
        """
        predictor = self.tree_predictor
        predictor.genre = genre

        # Nothing to parse: avoid indexing edus[-1] on an empty list.
        if not edus:
            return []

        def pair_features(left_node, right_node):
            # Features of one candidate pair of neighbouring nodes.
            return predictor.extract_features(left_node, right_node, annot_text, annot_tokens,
                                              annot_sentences, annot_postag, annot_morph,
                                              annot_lemma, annot_syntax_dep_tree)

        nodes = edus

        for edu in nodes:
            print(edu, file=sys.stderr)

        max_id = edus[-1].id

        # features[i] / scores[i] describe the pair (nodes[i], nodes[i + 1])
        features = [pair_features(nodes[i], nodes[i + 1]) for i in range(len(nodes) - 1)]
        scores = [predictor.predict_pair_proba(pair) for pair in features]

        while len(nodes) > 2 and any(score > self.forest_threshold for score in scores):
            # position of the best-scoring pair in the list
            j = int(np.argmax(np.array(scores)))

            # make the new node by merging node[j] + node[j+1]
            max_id += 1
            merged = DiscourseUnit(
                id=max_id,
                left=nodes[j],
                right=nodes[j + 1],
                relation=predictor.predict_label(features[j]),
                nuclearity=predictor.predict_nuclearity(features[j]),
                proba=scores[j],
                text=nodes[j].text + nodes[j + 1].text
            )

            print(merged, file=sys.stderr)

            nodes = nodes[:j] + [merged] + nodes[j + 2:]

            # Only the pairs touching the new node need rescoring: the one to its
            # left (if any) and the one to its right (if any).
            neighbours = []
            if j > 0:
                neighbours.append((nodes[j - 1], nodes[j]))
            if j + 1 < len(nodes):
                neighbours.append((nodes[j], nodes[j + 1]))

            fresh_features, fresh_scores = [], []
            for left_node, right_node in neighbours:
                pair = pair_features(left_node, right_node)
                fresh_features.append(pair)
                fresh_scores.append(predictor.predict_pair_proba(pair))

            # Old entries j-1, j and j+1 all involved a merged node; replace them.
            kept = max(j - 1, 0)
            scores = scores[:kept] + fresh_scores + scores[j + 2:]
            features = features[:kept] + fresh_features + features[j + 2:]

        # Two nodes left and their pair is above the threshold: join them under a root.
        if len(scores) == 1 and scores[0] > self.forest_threshold:
            root = DiscourseUnit(
                id=max_id + 1,
                left=nodes[0],
                right=nodes[1],
                relation='root',
                proba=scores[0]
            )
            nodes = [root]

        return nodes


In [None]:
# Turn the EDU strings into leaf DiscourseUnits by locating each one in the
# annotated text, searching from the end of the previous match.
# NOTE(review): when an EDU is not found, find() yields -1 and the resulting
# offsets are silently shifted — verify the data guarantees a match.
_edus = []
last_end = 0
for max_id, edu_text in enumerate(edus):
    start = len(annot['text'][:last_end]) + annot['text'][last_end:].find(edu_text)
    end = start + len(edu_text)
    temp = DiscourseUnit(id=max_id, left=None, right=None, relation='edu',
                         start=start, end=end, orig_text=annot['text'], proba=1.)
    _edus.append(temp)
    last_end = end

# Parse with the gold-corpus lookup predictor; a zero threshold means that any
# pair with a positive score may be merged.
parser = GreedyRSTParser(GoldTreePredictor(gold), forest_threshold=0.)
parsed = parser(
    _edus,
    annot['text'],
    annot['tokens'],
    annot['sentences'],
    annot['postag'],
    annot['morph'],
    annot['lemma'],
    annot['syntax_dep_tree'],
)

In [None]:
parsed

In [None]:
import functools as fn


def printBTree(node, nodeInfo=None, inverted=False, isTop=True):
    """
    Render a binary tree as ASCII art.

    :param node: root of the (sub)tree to render
    :param nodeInfo: callable invoked as ``nodeInfo(node)``; it must return either a
        falsy value or a 3-tuple ``(stringValue, leftNode, rightNode)``. Despite the
        ``None`` default it is called unconditionally, so a callable has to be supplied.
    :param bool inverted: swaps the two slash characters and the underline glyph used
        for the link bars; at the top level it also reverses the order of the lines
    :param bool isTop: True for the outermost call; the recursive calls pass False
    :return: for ``isTop`` a single newline-joined string, otherwise the list of text
        lines of the subtree. When ``nodeInfo`` returns a falsy value nothing is
        returned (implicit ``None``).
    """
    # node value string and sub nodes
    info = nodeInfo(node)
    
    if info:
        stringValue, leftNode, rightNode = info

        stringValueWidth  = len(stringValue)

        # recurse to sub nodes to obtain line blocks on left and right
        leftTextBlock     = [] if not leftNode else printBTree(leftNode, nodeInfo, inverted, False)
        rightTextBlock    = [] if not rightNode else printBTree(rightNode, nodeInfo, inverted, False)

        # count common and maximum number of sub node lines
        commonLines       = min(len(leftTextBlock),len(rightTextBlock))
        subLevelLines     = max(len(rightTextBlock),len(leftTextBlock))

        # extend lines on shallower side to get same number of lines on both sides
        leftSubLines      = leftTextBlock  + [""] *  (subLevelLines - len(leftTextBlock))
        rightSubLines     = rightTextBlock + [""] *  (subLevelLines - len(rightTextBlock))

        # compute location of value or link bar for all left and right sub nodes
        #   * left node's value ends at line's width
        #   * right node's value starts after initial spaces
        leftLineWidths    = [ len(line) for line in leftSubLines  ]                            
        rightLineIndents  = [ len(line)-len(line.lstrip(" ")) for line in rightSubLines ]

        # top line value locations, will be used to determine position of current node & link bars
        firstLeftWidth    = (leftLineWidths   + [0])[0]  
        firstRightIndent  = (rightLineIndents + [0])[0] 

        # width of sub node link under node value (i.e. with slashes if any)
        # aims to center link bars under the value if value is wide enough
        # 
        # ValueLine:    v     vv    vvvvvv   vvvvv
        # LinkLine:    / \   /  \    /  \     / \ 
        #
        linkSpacing       = min(stringValueWidth, 2 - stringValueWidth % 2)
        leftLinkBar       = 1 if leftNode  else 0
        rightLinkBar      = 1 if rightNode else 0
        minLinkWidth      = leftLinkBar + linkSpacing + rightLinkBar
        valueOffset       = (stringValueWidth - linkSpacing) // 2

        # find optimal position for right side top node
        #   * must allow room for link bars above and between left and right top nodes
        #   * must not overlap lower level nodes on any given line (allow gap of minSpacing)
        #   * can be offset to the left if lower subNodes of right node 
        #     have no overlap with subNodes of left node                                                                                                                                 
        minSpacing        = 2
        rightNodePosition = fn.reduce(lambda r,i: max(r,i[0] + minSpacing + firstRightIndent - i[1]), \
                                     zip(leftLineWidths,rightLineIndents[0:commonLines]), \
                                     firstLeftWidth + minLinkWidth)

        # extend basic link bars (slashes) with underlines to reach left and right
        # top nodes.  
        #
        #        vvvvv
        #       __/ \__
        #      L       R
        #
        linkExtraWidth    = max(0, rightNodePosition - firstLeftWidth - minLinkWidth )
        rightLinkExtra    = linkExtraWidth // 2
        leftLinkExtra     = linkExtraWidth - rightLinkExtra

        # build value line taking into account left indent and link bar extension (on left side)
        valueIndent       = max(0, firstLeftWidth + leftLinkExtra + leftLinkBar - valueOffset)
        valueLine         = " " * max(0,valueIndent) + stringValue
        slash             = "\\" if inverted else  "/"
        backslash         = "/" if inverted else  "\\"
        uLine             = "¯" if inverted else  "_"

        # build left side of link line
        leftLink          = "" if not leftNode else ( " " * firstLeftWidth + uLine * leftLinkExtra + slash)

        # build right side of link line (includes blank spaces under top node value) 
        rightLinkOffset   = linkSpacing + valueOffset * (1 - leftLinkBar)                      
        rightLink         = "" if not rightNode else ( " " * rightLinkOffset + backslash + uLine * rightLinkExtra )

        # full link line (will be empty if there are no sub nodes)                                                                                                    
        linkLine          = leftLink + rightLink

        # will need to offset left side lines if right side sub nodes extend beyond left margin
        # can happen if left subtree is shorter (in height) than right side subtree                                                
        leftIndentWidth   = max(0,firstRightIndent - rightNodePosition) 
        leftIndent        = " " * leftIndentWidth
        indentedLeftLines = [ (leftIndent if line else "") + line for line in leftSubLines ]

        # compute distance between left and right sublines based on their value position
        # can be negative if leading spaces need to be removed from right side
        mergeOffsets      = [ len(line) for line in indentedLeftLines ]
        mergeOffsets      = [ leftIndentWidth + rightNodePosition - firstRightIndent - w for w in mergeOffsets ]
        mergeOffsets      = [ p if rightSubLines[i] else 0 for i,p in enumerate(mergeOffsets) ]

        # combine left and right lines using computed offsets
        #   * indented left sub lines
        #   * spaces between left and right lines
        #   * right sub line with extra leading blanks removed.
        mergedSubLines    = zip(range(len(mergeOffsets)), mergeOffsets, indentedLeftLines)
        mergedSubLines    = [ (i,p,line + (" " * max(0,p)) )       for i,p,line in mergedSubLines ]
        mergedSubLines    = [ line + rightSubLines[i][max(0,-p):]  for i,p,line in mergedSubLines ]                        

        # Assemble final result combining
        #  * node value string
        #  * link line (if any)
        #  * merged lines from left and right sub trees (if any)
        treeLines = [leftIndent + valueLine] + ( [] if not linkLine else [leftIndent + linkLine] ) + mergedSubLines

        # invert final result if requested
        treeLines = reversed(treeLines) if inverted and isTop else treeLines

        # return intermediate tree lines or print final result
        if isTop : return "\n".join(treeLines)
        else     : return treeLines        

def print_rst_tree(tree, file):
    """
    Write an ASCII-art rendering of a discourse tree to ``file``.

    Leaves are shown by their text, internal nodes by ``(relation, proba)``.
    A node is treated as a leaf when its relation is ``'elementary'`` or ``'edu'``:
    the leaves built in this notebook carry ``'edu'``, so checking only
    ``'elementary'`` printed ``('edu', '1.0')`` for every leaf instead of its text.

    :param tree: root node exposing ``relation``, ``proba``, ``text``, ``left``, ``right``
    :param file: object with a ``write`` method; it is neither flushed nor closed here
    """
    leaf_relations = ('elementary', 'edu')

    def describe(n):
        # (label to draw, left child, right child) in the shape printBTree expects
        if n.relation in leaf_relations:
            value = n.text
        else:
            value = (n.relation, n.proba)

        return str(value), n.left, n.right

    lines = printBTree(tree, describe)
    file.write(lines)


In [None]:
# Use a context manager so the file is flushed and closed once the tree is
# written; the bare open(...) passed as an argument was never closed.
with open('tmp.txt', 'w') as tree_file:
    print_rst_tree(parsed[2], tree_file)

In [None]:
def extract_relations_list(tree):
    """
    Return the sorted list of distinct ``relation`` values found in the tree.

    The children are visited only when ``tree.left`` is truthy.

    :param tree: node exposing ``relation``, ``left`` and ``right``
    :return: sorted list without duplicates
    """
    found = {tree.relation}

    if tree.left:
        for child in (tree.left, tree.right):
            found.update(extract_relations_list(child))

    return sorted(found)

def extract_segments(tree):
    """
    List the elementary units (children whose relation is ``'edu'``) of the tree.

    Each entry is a tuple ``(id, parent, relation, text)`` where ``id`` and ``text``
    belong to the EDU child, ``parent`` is the id of its sibling and ``relation``
    is the relation of the node holding both.

    :param tree: node exposing ``left``, ``right``, ``relation``; children expose
        ``id``, ``relation`` and ``text``
    :return: list of tuples sorted by their first element
    """
    edus = []

    if tree.left and tree.left.relation == 'edu':
        edus.append((tree.left.id, tree.right.id, tree.relation, tree.left.text))

    if tree.right and tree.right.relation == 'edu':
        edus.append((tree.right.id, tree.left.id, tree.relation, tree.right.text))

    # Descend into both children; an EDU child has no children of its own and
    # contributes nothing further.
    for child in (tree.left, tree.right):
        if child:
            edus += extract_segments(child)

    return sorted(edus, key=lambda x: x[0])

def extract_groups(tree):
    """
    List the non-elementary children (relation other than ``'edu'``) of the tree.

    Each entry is a tuple ``(id, parent, relation, text)`` where ``id`` is the id of
    the group child, ``parent`` the id of its sibling, and ``relation`` / ``text``
    are those of the node holding both.

    The whole tree is traversed. Previously the recursion sat in an ``else`` bound
    to the right-child test only, so nothing below a node whose right child was a
    group was ever visited and nested groups were lost.

    :param tree: node exposing ``left``, ``right``, ``relation`` and ``text``
    :return: list of tuples in pre-order
    """
    groups = []

    if tree.left and tree.left.relation != 'edu':
        groups.append((tree.left.id, tree.right.id, tree.relation, tree.text))  # id, parent, relation, text

    if tree.right and tree.right.relation != 'edu':
        groups.append((tree.right.id, tree.left.id, tree.relation, tree.text))

    if tree.left:
        groups += extract_groups(tree.left)
    if tree.right:
        groups += extract_groups(tree.right)

    return groups
    
def export_rs3(tree):
    """
    Serialise a discourse tree into an rs3-style XML string.

    The header lists the relations returned by ``extract_relations_list``; the
    body lists the segments from ``extract_segments`` followed by the groups from
    ``extract_groups``.

    Attribute values and segment text are XML-escaped and every ``<group>`` is
    emitted as a self-closing element, so the result is well-formed XML even when
    the text contains ``<``, ``>`` or ``&``.

    :param tree: root node of the tree to export
    :return: the document as a single string
    """
    from xml.sax.saxutils import escape, quoteattr

    def attr(value):
        # Quoted and escaped attribute value (includes the surrounding quotes).
        return quoteattr(str(value))

    def make_header():
        relation_lines = ''.join(
            f'\t\t\t<rel name={attr(relation)} type="rst" />\n'
            for relation in extract_relations_list(tree)
        )
        return ('\t<header>\n'
                '\t\t<relations>\n' + relation_lines + '\t\t</relations>\n'
                '\t</header>\n')

    def make_body():
        # segment / group tuples are (id, parent, relation, text)
        segment_lines = ''.join(
            f'\t\t<segment id={attr(segment[0])} parent={attr(segment[1])} relname={attr(segment[2])}>'
            + escape(str(segment[3]))
            + '</segment>\n'
            for segment in extract_segments(tree)
        )
        group_lines = ''.join(
            f'\t\t<group id={attr(group[0])} type="span" parent={attr(group[1])} relname={attr(group[2])} />\n'
            for group in extract_groups(tree)
        )
        return '\t<body>\n' + segment_lines + group_lines + '\t</body>\n'

    return '<rst>\n' + make_header() + make_body() + '</rst>'

In [None]:
class Span:
    """Plain record tying a span id to the ids of its left and right members."""

    def __init__(self, id, left_id, right_id):
        self.id, self.left_id, self.right_id = id, left_id, right_id

In [None]:
def extract_spans(tree):
    """
    Collect spans and EDU records from the tree.

    A ``Span(tree.id + 1, left.id, right.id)`` is created for every node whose left
    child has relation ``'edu'``. For such a node two EDU tuples
    ``(id, parent, relation, text)`` are added according to its nuclearity
    (``'NS'``, ``'SN'`` or ``'NN'``); any other nuclearity adds none.

    :param tree: node exposing ``id``, ``left``, ``right``, ``relation``, ``nuclearity``
    :return: pair ``(spans sorted by id, edus)``
    """
    new_span_id = tree.id + 1
    found_spans = []
    found_edus = []

    if tree and tree.left and tree.left.relation == 'edu':
        found_spans.append(Span(new_span_id, tree.left.id, tree.right.id))

        kind = tree.nuclearity
        if kind == 'NS':
            # nucleus on the left hangs off the span, satellite points at the nucleus
            found_edus.append((tree.left.id, new_span_id, 'span', tree.left.text))
            found_edus.append((tree.right.id, tree.left.id, tree.relation, tree.right.text))
        elif kind == 'SN':
            found_edus.append((tree.left.id, tree.right.id, tree.relation, tree.left.text))
            found_edus.append((tree.right.id, new_span_id, 'span', tree.right.text))
        elif kind == 'NN':
            found_edus.append((tree.left.id, new_span_id, 'multinuc', tree.left.text))
            found_edus.append((tree.right.id, new_span_id, 'multinuc', tree.right.text))

    for child in (tree.left, tree.right):
        if child:
            child_spans, child_edus = extract_spans(child)
            found_spans += child_spans
            found_edus += child_edus

    return sorted(found_spans, key=lambda span: span.id), found_edus

In [None]:
temp = read_json('data/news1_62')

In [None]:
temp.order.value_counts()

In [None]:
parsed[0].nuclearity

In [None]:
tree = parsed[0]

In [None]:
print(export_rs3(parsed[0]))