In [None]:
import re
import os
import numpy as np
import tqdm

# Text file to which get_edit_bleu appends one "bleu success complexity" line
# for every run that is given an edit sample index.
fout_path = 'bleus_editable.txt'

# Refuse to continue when the file is already on disk (presumably so results
# from an earlier run are not mixed with new ones -- confirm).
results_already_present = os.path.isfile(fout_path)
if results_already_present:
    raise ValueError("File already exists")

def get_edit_bleu(checkpoint_path, edit_sample_index=None):
    edit_sample_str = ''
    if edit_sample_index is not None:
        edit_sample_str = '--edit-sample-index {}'.format(edit_sample_index)

    res = !python edited_generate.py data-bin/iwslt14.tokenized.de-en  \
           --path {checkpoint_path} \
           --edit-samples-path edit_iwslt14.tokenized.de-en/bpe_test.txt \
           {edit_sample_str} \
           --no-progress-bar \
           --beam 5 --remove-bpe --sacrebleu --moses-detokenizer de 
        
    if edit_sample_index is not None:
        bleu_id = -2
    else:
        bleu_id = -1

    try:
        bleu = float(re.findall('BLEU\(score=(\d+(\.\d+|)),', res[bleu_id])[0][0])
    except:
        bleu = None
    
    if edit_sample_index is not None:
        try:
            success, complexity = re.findall('EditResult\(success=(True|False), complexity=(\d+)\)', res[-1])[0]
            success = eval(success)
            complexity = int(complexity)
        except:
            success = complexity = None
        
        with open(fout_path, 'a') as f:
            print(bleu, success, complexity, file=f)
    else:
        success = complexity = None
    
    return bleu, success, complexity

In [None]:
# Run without an edit sample index: only BLEU is parsed and nothing is
# appended to the results file.
get_edit_bleu(checkpoint_path='checkpoints_editable/checkpoint_best.pt')

In [None]:
# Seed the global NumPy RNG so the same 1000 indices out of range(6749) are
# selected on every run (6749 is presumably the number of available edit
# samples -- confirm).
np.random.seed(42)
ids = np.random.permutation(6749)[:1000]

# One run per selected index; the list of (bleu, success, complexity) tuples
# is the cell's displayed output.
[
    get_edit_bleu('checkpoints_editable/checkpoint_best.pt', sample_index)
    for sample_index in tqdm.tqdm_notebook(ids)
]

In [None]:
# Aggregate per-sample results from a text file holding one
# "bleu success complexity" triple per line.
# NOTE(review): this reads a different file than the 'bleus_editable.txt'
# written by get_edit_bleu -- confirm that this is the intended input.
bleus = []
successes = []
complexities = []
with open('bleus_samples_stability1_kl_almost_last.txt') as f:
    for line in f:
        # split() already discards the trailing newline; slicing off the last
        # character would truncate a final line that has no newline.
        fields = line.split()
        if not fields:
            continue  # ignore blank lines instead of failing on unpacking
        bleu, success, complexity = fields

        try:
            bleu = float(bleu)
            complexity = float(complexity)
            # The field is text: bool("False") is True, so compare against
            # the literal instead of relying on truthiness.
            success = success == 'True'
        except ValueError:
            # Unparseable entries (e.g. "None None None") are kept as
            # sentinel values, so they do count towards the means below.
            bleu = 0
            success = False
            complexity = -1

        bleus.append(bleu)
        successes.append(success)
        complexities.append(complexity)

bleus = np.array(bleus)
successes = np.array(successes)
complexities = np.array(complexities)

print('Mean BLEU:', np.mean(bleus))
print('Success rate:', np.mean(successes))
print('Mean complexity:', np.mean(complexities))