### Scripts for training and evaluation of the system described in Nguyen et al. on our data.

<i>Nguyen, Thanh-Tung, et al. "RST Parsing from Scratch." Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies. 2021.</i>

In [None]:
# Reload edited project modules automatically on every cell run, and disable jedi-based completion.
%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False

In [None]:
! cp ../isanlp_rst/td_rst_parser .

In [None]:
from tqdm import tqdm
from utils.discourseunit2str import *

In [None]:
from utils.train_test_split import split_rstreebank

# Split the treebank stored in ./data_ru into train/dev/test collections of file paths.
print('Loading RSTreebank:')
train, dev, test = split_rstreebank('./data_ru')
n_train, n_dev, n_test = map(len, (train, dev, test))
print('Train length:', n_train, 'Dev length:', n_dev, 'Test length:', n_test, '(files)')

In [None]:
from isanlp import PipelineCommon
from isanlp.processor_razdel import ProcessorRazdel

# Razdel processor: maps the input `text` to `tokens` and `sentences` annotations,
# which the packing cell below reads from the pipeline output.
razdel_outputs = {'tokens': 'tokens', 'sentences': 'sentences'}
razdel_step = (ProcessorRazdel(), ['text'], razdel_outputs)
ppl = PipelineCommon([razdel_step])

In [None]:
import os
import pickle

# Fields collected for every document chunk; each split maps field -> list (one entry per chunk).
PACKED_KEYS = ('InputDocs', 'EduBreak_TokenLevel', 'SentBreak', 'Docs_structure', 'filename')
# Upper bound on the `<name>_part_<i>.du` chunks looked up per source file.
MAX_PARTS_PER_FILE = 100

packed_train = {key: [] for key in PACKED_KEYS}
packed_dev = {key: [] for key in PACKED_KEYS}
packed_test = {key: [] for key in PACKED_KEYS}

for file in tqdm(train + dev + test):
    # A file listed in several splits goes to the first matching one (train, then dev, then test).
    if file in train:
        packed = packed_train
    elif file in dev:
        packed = packed_dev
    elif file in test:
        packed = packed_test
    else:
        continue

    # news/blog documents live in `corpus_du`, everything else in `dep_corpus_du`.
    path_du = 'corpus_du' if 'news' in file or 'blog' in file else 'dep_corpus_du'
    # Drop the directory and the 5-character extension (presumably '.edus' — TODO confirm).
    pure_filename = file.split('/')[-1][:-5]

    for i in range(MAX_PARTS_PER_FILE):
        filename = os.path.join(path_du, f'{pure_filename}_part_{i}.du')
        if not os.path.isfile(filename):
            break  # stop at the first missing part index

        # NOTE: pickle.load can execute arbitrary code; only use it on trusted, locally produced files.
        with open(filename, 'rb') as du_file:
            trees = [pickle.load(du_file)]
        annot = ppl(trees[0].text)
        edus = get_edu_breaks(trees, annot)

        # Keep only chunks with more than one EDU break (a single EDU presumably yields no relation tree).
        if len(edus) <= 1:
            continue

        packed['InputDocs'].append(get_input_docs(annot))
        packed['EduBreak_TokenLevel'].append(edus)  # reuse the breaks computed above
        packed['SentBreak'].append(get_sentence_breaks(annot))
        packed['Docs_structure'].append(get_docs_structure(trees))
        packed['filename'].append(file)

In [None]:
import os
import pickle

# Write each packed split where the training script expects it (it runs inside td_rst_parser with DATA_PATH='./data').
processed_data_path = 'td_rst_parser/data'
os.makedirs(processed_data_path, exist_ok=True)  # no error if the directory already exists

for output_name, packed in (('train_data', packed_train),
                            ('dev_data', packed_dev),
                            ('test_data', packed_test)):
    with open(os.path.join(processed_data_path, output_name), 'wb') as f:
        pickle.dump(packed, f)

In [None]:
# Sanity check: every packed chunk in every split must have a non-empty discourse structure.
# Prints nothing when all checks pass; on failure, inspect the referenced EDU breaks.
for split_name, packed in (('packed_train', packed_train),
                           ('packed_dev', packed_dev),
                           ('packed_test', packed_test)):
    for i, struct in enumerate(packed['Docs_structure']):
        assert struct, f"Check {split_name}['EduBreak_TokenLevel'][{i}]"

### Word2vec: download, lowercase, and strip POS tags

In [None]:
# %%bash
# Download NLPL repository model 220 into td_rst_parser/data, where the next cell and the training script read w2v.txt.
# cd ./td_rst_parser/data/
# wget http://vectors.nlpl.eu/repository/20/220.zip
# unzip 220.zip model.txt
# mv model.txt w2v.txt

In [None]:
import os

w2v_path = './td_rst_parser/data/w2v.txt'

with open(w2v_path, 'r', encoding='utf-8') as f:
    lines = f.readlines()

# We won't use POS tags, so keep only the first mention of each (lowercased) form in the w2v.
# The file is assumed to be frequency-sorted, so the first mention is the most frequent one.
# `vocab` preserves file order; `seen` makes the membership test O(1) instead of a list scan.
vocab, seen = [], set()
if lines and '\t' in lines[0]:
    # Already converted by an earlier run (keys are tab-separated); re-processing would corrupt it.
    vocab = [line.split('\t', 1)[0] for line in lines]
else:
    # Write to a temporary file and swap it in, so an interrupted run cannot leave w2v.txt half-written.
    tmp_path = w2v_path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        for line in tqdm(lines):
            key, *value = line.strip().split(' ')
            # Strip the POS suffix (`form_POS` -> `form`) and lowercase before de-duplicating,
            # so case variants of one form do not produce duplicate output rows.
            key = key.split('_')[0].lower()
            if key in seen:
                continue
            seen.add(key)
            vocab.append(key)
            f.write(key + '\t' + ' '.join(value).lower() + '\n')
    os.replace(tmp_path, w2v_path)

In [None]:
# Sanity check: the training script passes `--unk 'unknown'`, so this token should be in the vocabulary.
unk_token = 'unknown'
unk_token in vocab

### Train the model 

In [None]:
%%writefile td_rst_parser/discourse_config.ini

; Network settings passed to training via `--conf 'discourse_config.ini'`.
; Lines starting with ';' are INI comments, i.e. disabled settings.
[Network]
;batch_size=10000
;n_embed=300
; presumably the minimum token frequency kept in the vocabulary — confirm against the parser's config loader
min_freq=1

In [None]:
! mkdir td_rst_parser/exp

In [None]:
%%writefile td_rst_parser/run_discourse_doc_goldsegmentation_edu_rep_train.sh

# Launch training of `src.cmds.pointing_discourse_gold_segmentation_edu_rep`.
# Paths are relative to td_rst_parser/, the directory this script is run from.
export DATA_PATH='./data'
export MODE='train'
export FEAT='char'
export LEARNING_RATE_SCHEDULE='Exponential'
# w2v.txt produced by the preprocessing cell; N_EMBED presumably must match its vector size — confirm.
export PRETRAINED_EMBEDDING='./data/w2v.txt'
export N_EMBED=300

# NOTE(review): MODE, PRETRAINED_EMBEDDING and BERT_MODEL are exported but not referenced in the
# command below; confirm the training code reads them from the environment, otherwise they have no effect.
export BERT_MODEL=''
export BATCH_SIZE=4000  
export BEAM_SIZE=20

# Checkpoints go under exp/ptb.pointing.discourse.gold_segmentation_edu_rep.$FEAT; unknown words map to 'unknown'.
python -m src.cmds.pointing_discourse_gold_segmentation_edu_rep train -b -d 1 -p exp/ptb.pointing.discourse.gold_segmentation_edu_rep.$FEAT \
--data_path $DATA_PATH -f $FEAT --learning_rate_schedule $LEARNING_RATE_SCHEDULE \
--batch-size $BATCH_SIZE --conf 'discourse_config.ini' --n-embed $N_EMBED --unk 'unknown' \
--beam-size $BEAM_SIZE

To train, run from a terminal: `cd td_rst_parser && sh run_discourse_doc_goldsegmentation_edu_rep_train.sh`

### Sandbox
 - Evaluate the selected model across documents.
 - Quantize it and compare the speed and quality of the original and quantized models; inference time is crucial for RST parsing.

In [None]:
# Expose the parser's `src` package at the notebook root (replacing any previous copy or link).
# Presumably needed because the parser's modules import `src.*`: training runs `python -m src.cmds...` from inside td_rst_parser.
! rm -r src
! ln -s td_rst_parser/src src

In [None]:
import pickle

# Reload the packed test split from disk (dict of per-chunk lists keyed by field name).
# This rebinds `test`, which earlier held the list of test file paths.
# NOTE: pickle.load can execute arbitrary code — only load files produced locally.
with open('td_rst_parser/data/test_data', 'rb') as f:
    test = pickle.load(f)

In [None]:
from td_rst_parser.predict_interactive import TrainedPredictor

# Checkpoint chosen for evaluation; its filename appears to encode dev-set UF/NF/RF scores.
checkpoint_path = 'td_rst_parser/exp/ptb.pointing.discourse.gold_segmentation_edu_rep.char/2022_05_26_08_58_41/model_dev_UF_65.41_NF_43.00_RF_32.09.pt'
# Load on CPU: quantized models work only on CPU.
pr = TrainedPredictor(checkpoint_path, device='cpu')

In [None]:
%%time

# Time inference of the fp32 model over the packed test split (on CPU).
predictions = pr.predict(test)

In [None]:
import torch

# Dynamic quantization with PyTorch defaults (int8 weights). The model is moved to CPU first,
# then the quantized copy replaces the predictor's model and is saved to disk.
fp32_model = pr.parser.model.to('cpu')
model_int8 = torch.quantization.quantize_dynamic(fp32_model)
pr.parser.model = model_int8
pr.parser.save('quantized_model.pt')  # ~80M on disk vs ~198M for fp32

In [None]:
%%time

# Same inference with the dynamically quantized model, for a timing comparison with the fp32 run.
predictions = pr.predict(test)

In [None]:
def edu2tokens(tree, edu_breaks):
    """Convert EDU-indexed tree spans into token-indexed spans.

    Args:
        tree: iterable of span strings of the form
            ``'(left_begin:left_rel:left_end,right_begin:right_rel:right_end)'``
            with 1-based EDU indices.
        edu_breaks: sequence where ``edu_breaks[k]`` is the index of the last
            token of EDU ``k + 1`` (token indices are 0-based).

    Returns:
        List of strings in the same format, with each span running from the
        first token of its first EDU to the last token of its last EDU.
    """
    result = []
    for node in tree:
        left_begin, left_rel, border, right_rel, right_end = node[1:-1].split(':')
        left_end, _right_begin = border.split(',')

        # First token of EDU `left_begin` is one past the last token of the EDU before it.
        # It must come from `left_begin` (not `left_end`), otherwise multi-EDU left spans are truncated.
        left_begin_edu = int(left_begin)
        if left_begin_edu == 1:
            left_begin_toks = 0
        else:
            left_begin_toks = edu_breaks[left_begin_edu - 2] + 1

        left_end_toks = edu_breaks[int(left_end) - 1]
        right_begin_toks = left_end_toks + 1  # the right span starts right after the left one
        right_end_toks = edu_breaks[int(right_end) - 1]

        result.append(f'({left_begin_toks}:{left_rel}:{left_end_toks},{right_begin_toks}:{right_rel}:{right_end_toks})')
    return result

# Token-level gold trees, one list of span strings per packed test document.
golds = [
    edu2tokens(structure, test['EduBreak_TokenLevel'][doc_idx])
    for doc_idx, structure in enumerate(test['Docs_structure'])
]

In [None]:
from utils.metrics import DiscourseMetricDoc

# Accumulate the metric document by document: gold span list vs. the predicted tree
# string split on spaces into its spans.
metric = DiscourseMetricDoc()
for doc_idx, gold_spans in enumerate(golds):
    pred_spans = predictions['trees'][doc_idx].split(' ')
    metric(golds=gold_spans, preds=pred_spans)
print(metric)

In [None]:
# Single call over the whole test set: each gold document is joined into one space-separated
# string, the same format as the entries of `predictions['trees']`.
gold_docs = [' '.join(gold_spans) for gold_spans in golds]
metric = DiscourseMetricDoc()
metric(golds=gold_docs, preds=predictions['trees'])
print(metric)

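FP32 (198MB, 28s):<br>
UF: 56.91% NF: 39.26% RF: 30.08% Full RNF: 29.54% (metric accumulated document by document)<br>
UF: 56.04% NF: 38.78% RF: 29.70% Full RNF: 29.16% (single metric call over all documents)

Int8 (82MB, 22s):<br>
UF: 56.98% NF: 39.23% RF: 30.19% Full RNF: 29.65% (metric accumulated document by document)<br>
UF: 56.12% NF: 38.74% RF: 29.80% Full RNF: 29.27% (single metric call over all documents)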