## SARI script

A script that computes the SARI evaluation metric. Model outputs must be reconstructed before the SARI score is calculated, so the script provides a separate reconstruction cell for each type of model being evaluated; run only the cell that matches the model whose output you loaded.

SARI is implemented using an adapted version of Wei Coco Xu's script: <https://github.com/cocoxu/simplification>

In [None]:
from SARI import SARIsent
from nltk import word_tokenize
from LSTM_reconstruct import LSTM_reconstruct
from BERT_reconstruct import BERT_reconstruct
from BERT_reconstruct import BERT_rules_reconstruct
import pickle

In [None]:
# Read the test data

def read_data(test_file):
    """Load and return the object pickled in ``test_file``.

    The file is opened in binary mode and closed again before returning.
    Unpickling can execute arbitrary code, so only pass files that come
    from a trusted source.
    """
    with open(test_file, 'rb') as handle:
        unpickled = pickle.load(handle)
    return unpickled

In [None]:
# Set test_file to the pickled model output before running this cell;
# the empty-string default is only a placeholder and cannot be opened.
test_file = "" # Path to test output
# Unpickle the output once; the reconstruction cells below read from test_data.
test_data = read_data(test_file)

In [None]:
# LSTM models only: rebuild every prediction/target pair with LSTM_reconstruct.
# The matching originals are produced separately by the "LSTM only" cell.

sample_count = len(test_data['predictions'])

# One (prediction, target) reconstruction per index, in the stored order
reconstructed_pairs = [
    LSTM_reconstruct(test_data['predictions'][idx], test_data['targets'][idx])
    for idx in range(sample_count)
]

pred_reconstructions = [pred_recon for pred_recon, _ in reconstructed_pairs]
target_reconstructions = [target_recon for _, target_recon in reconstructed_pairs]

In [None]:
# If the output is from a BERT_rules ensemble model: reconstruct those outputs
# This path needs a second pickle (the rules output) in addition to test_data.

# The empty-string path is a placeholder; fill it in before running this cell.
rules_data = read_data("") # Path to rules output

# A single call yields all three lists that the SARI cell consumes.
pred_reconstructions, target_reconstructions, originals = BERT_rules_reconstruct(test_data, rules_data)

In [None]:
# Rule-based models only: the stored outputs are used as they are, so each
# list is simply copied entry by entry (one entry per stored prediction).

sample_count = len(test_data['predictions'])

target_reconstructions = [test_data['targets'][idx] for idx in range(sample_count)]
pred_reconstructions = [test_data['predictions'][idx] for idx in range(sample_count)]
originals = [test_data['originals'][idx] for idx in range(sample_count)]

In [None]:
# If the output is from a BERT model: reconstruct those outputs
# BERT_reconstruct works on the whole test_data object at once, and its result
# is unpacked into the three lists that the SARI cell consumes.

pred_reconstructions, target_reconstructions, originals = BERT_reconstruct(test_data)

In [None]:
# Process original inputs (LSTM only)
#
# Entries of test_data['wholeInput'] that look like the repr of a bytes object
# (b'...' or b"...") have that wrapper stripped, leaving only the text.
# Exactly one entry is appended per prediction index so that originals[i]
# stays aligned with pred_reconstructions[i] and target_reconstructions[i],
# which the SARI cell indexes together. Skipping an entry here would shift
# every later original onto the wrong prediction.

originals = []

for i in range(len(test_data['predictions'])):
    original = test_data['wholeInput'][i]
    if len(original) == 0:
        # Empty input: keep an empty placeholder rather than dropping the slot
        originals.append("")
        continue
    text = str(original)
    if (text.startswith("b'") and text.endswith("'")) or \
            (text.startswith('b"') and text.endswith('"')):
        # Drop the leading b and the surrounding quotes of the bytes repr
        text = text[2:-1]
    # Inputs without the bytes wrapper are kept unchanged as plain text
    originals.append(text)

In [None]:
# Get SARI scores
#
# Computes the sentence-level SARI for every (original, prediction, target)
# triple and prints the mean. The three lists must come from the same
# reconstruction cell; their lengths are checked first because stale or
# shortened lists left over from another cell would otherwise be scored
# against the wrong sentences.

count = len(pred_reconstructions)

if count == 0:
    raise ValueError("No reconstructed predictions to score; "
                     "run one of the reconstruction cells first.")
if len(target_reconstructions) != count or len(originals) != count:
    raise ValueError(
        "Mismatched inputs: {} predictions, {} targets, {} originals".format(
            count, len(target_reconstructions), len(originals)))

rolling_SARI = 0

for original, prediction, target in zip(originals,
                                        pred_reconstructions,
                                        target_reconstructions):
    token_target = word_tokenize(target)
    token_prediction = word_tokenize(prediction)
    token_original = word_tokenize(original)
    # SARIsent takes a list of references; only one target is available here.
    # Only the first of its four return values (the sentence score) is used.
    sentence_SARI, _, _, _ = SARIsent(token_original, token_prediction, [token_target])
    rolling_SARI += sentence_SARI

average_SARI = rolling_SARI / count
print(average_SARI)