In [None]:
# Reload edited project modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Turn off Jedi-based tab completion in IPython.
%config Completer.use_jedi = False

### Compare two rs3 files
 - Parseval metrics, the same ones used to evaluate the parser
 - Inter-annotator agreement from https://aclanthology.org/W19-2712.pdf

In [None]:
from utils.dataset.rs3_forest_splitter import RS3ForestSplitter
from utils.dataset.rst2dis_converter import split_seq, RST2DISConverter
from utils.dataset.dis_file_reading import *
from utils.discourseunit2str import *
from utils import metrics
# Standard-library imports stay below the star-imports so that names such as
# `os` or `glob` cannot be shadowed by whatever those modules export.
import glob
import os
import shutil
import subprocess

In [None]:
def split_forest_rs3(filename1, filename2, output_dir):
    """ Splits two rs3 forests into separate trees, one file per tree.

    Makes in output_dir two subdirectories with separate trees:
        file1/part_0.rs3
        ...
        file1/part_n.rs3
        file2/part_0.rs3
        ...
        file2/part_m.rs3

    Args:
        filename1: path to the first *.rs3 file (goes to output_dir/file1).
        filename2: path to the second *.rs3 file (goes to output_dir/file2).
        output_dir: existing directory in which file1/ and file2/ are created.

    Raises:
        FileExistsError: if output_dir/file1 or output_dir/file2 already exists
            (os.mkdir is used on purpose so stale trees are never mixed in).
    """

    splitter = RS3ForestSplitter()

    for subdir, source in (('file1', filename1), ('file2', filename2)):
        target_dir = os.path.join(output_dir, subdir)
        os.mkdir(target_dir)
        splitter(source, target_dir)

        # Search the directory the splitter actually wrote to; a hard-coded
        # 'temp/...' pattern only works when output_dir happens to be 'temp'.
        for path in glob.glob(os.path.join(target_dir, '*.rs3')):
            # Use the basename only, so a '_part_' inside a directory name
            # (or a file without that marker) cannot produce a broken path.
            new_filename = 'part_' + os.path.basename(path).split('_part_')[-1]
            # Rename in place: keeps the bytes untouched and cannot lose the
            # file between a remove and a failed write.
            os.replace(path, os.path.join(target_dir, new_filename))
                

def run_rsttace(input_dir):
    """ Calls the `rsttace compare` command line tool to compute inter-annotator
        agreement metrics between the trees stored in
        input_dir/file1/part_n.rs3
        input_dir/file2/part_n.rs3
        The tool receives input_dir as its `-o` argument.
    """

    first_dir, second_dir = (os.path.join(input_dir, name) for name in ('file1', 'file2'))
    command = ['rsttace', 'compare', first_dir, second_dir, '-o', input_dir]
    subprocess.run(command)

def _convert_rs3_to_dis(input_dir, converter_url, converter_threads):
    """ Sends every input_dir/file{1,2}/*.rs3 to the converter; results go to the same directory. """
    for part in ('file1', 'file2'):
        part_dir = os.path.join(input_dir, part)
        files = glob.glob(os.path.join(part_dir, '*.rs3'))
        for batch in split_seq(files, converter_threads):
            worker = RST2DISConverter(converter_url, batch, output_dir=part_dir)
            worker.start()
            worker.join()


def _load_trees(part_dir):
    """ Reads *part_0.dis, *part_1.dis, ... from part_dir; unreadable trees are kept as None. """
    suffix = 'part_0.dis'
    candidates = glob.glob(os.path.join(part_dir, '*' + suffix))
    if not candidates:
        # Same exception type the bare `[0]` lookup produced, but with a usable message.
        raise IndexError(f'no *{suffix} file found in {part_dir}')

    # Only the file name is rewritten, so a 'part_0' in a directory name is harmless.
    prefix = os.path.basename(candidates[0])[:-len(suffix)]

    trees = []
    while True:  # no fixed upper bound on the number of trees
        path = os.path.join(part_dir, f'{prefix}part_{len(trees)}.dis')
        if not os.path.isfile(path):
            break
        try:
            trees.append(read_dis(path))
        except Exception as e:
            print(f'Failed to read {path}: {e}')
            # Placeholder keeps list positions equal to the part numbers.
            trees.append(None)
    return trees


def run_parseval(input_dir, matching_trees=False):
    """ a) Converts input_dir/file_*/*.rs3 to *.dis files
        b) Loads them with read_dis
        c) Computes Parseval score for each tree pair (if matching_trees)
           and for the whole document, and prints the results

    Trees are paired by their part number. If part k cannot be read in either
    file, part k is left out of both sides, so the remaining pairs stay aligned.

    Args:
        input_dir: directory containing the file1/ and file2/ subdirectories.
        matching_trees: if True, also print a score for every tree pair.

    Raises:
        IndexError: if a subdirectory has no *part_0.dis file after conversion.
    """

    # Convert to *.dis
    converter_url = 'localhost:5000'  # <- Put address of the rst converter service here
    converter_threads = 1
    _convert_rs3_to_dis(input_dir, converter_url, converter_threads)

    # Collect trees to structures
    structures = {part: _load_trees(os.path.join(input_dir, part))
                  for part in ('file1', 'file2')}

    # Part numbers that failed to load on at least one side.
    unreadable = {i
                  for trees in structures.values()
                  for i, tree in enumerate(trees)
                  if tree is None}

    # Computes Parseval for separate trees and whole document
    if matching_trees:
        for i, (tree1, tree2) in enumerate(zip(structures['file1'], structures['file2'])):
            if i in unreadable:
                print(f'paragraph {i}:\t skipped (tree could not be read)')
                continue
            struct1 = get_docs_structure_charsonly([tree1], needs_preprocessing=False)
            struct2 = get_docs_structure_charsonly([tree2], needs_preprocessing=False)
            local_metric = metrics.DiscourseMetricDoc(eps=1e-10)
            local_metric(golds=struct1, preds=struct2)
            print(f'paragraph {i}:\t', local_metric)

    print()
    global_metric = metrics.DiscourseMetricDoc(eps=1e-20)
    gold_trees = [tree for i, tree in enumerate(structures['file1']) if i not in unreadable]
    pred_trees = [tree for i, tree in enumerate(structures['file2']) if i not in unreadable]
    golds = get_docs_structure_charsonly(gold_trees, needs_preprocessing=False)
    preds = get_docs_structure_charsonly(pred_trees, needs_preprocessing=False)
    global_metric(golds=golds, preds=preds)
    print('document:\t', global_metric)

In [None]:
directory = 'temp'
# Start from an empty working directory so files from an earlier run cannot leak in.
# shutil.rmtree replaces the `! rm -r` shell escape: no shell, works on any platform.
if os.path.isdir(directory):
    shutil.rmtree(directory)

os.mkdir(directory)

# The two annotations to compare.
FILE1, FILE2 = 'corpus/41.txt_local.rs3', 'corpus/20.txt.rs3'
split_forest_rs3(FILE1, FILE2, directory)
run_rsttace(directory)
run_parseval(directory, matching_trees=True)