In [None]:
import sys
import os
from collections import defaultdict
import random
import numpy as np

# Seed the two in-process RNGs used further down (np.random.shuffle and
# random.choice in the sample-writing loop) so their results are repeatable.
# NOTE(review): the fairseq-generate commands are launched through the shell,
# so these seeds presumably do not affect their sampling -- confirm.
np.random.seed(1234)
random.seed(1234)

# IPython magic: sets the CUDA_VISIBLE_DEVICES environment variable to 0 for
# this kernel (presumably inherited by the shell commands launched below).
%env CUDA_VISIBLE_DEVICES=0


# Language pair; used to build the data/output paths and the dictionary file names.
src_lang, dst_lang = 'de', 'en'
# Model checkpoint handed to fairseq-generate via --path (downloaded below if missing).
checkpoint_path = 'baseline_checkpoint.pt'
# Data directory handed to fairseq-generate; also holds the dict.<lang>.txt files.
data_path = 'data-bin/iwslt14.tokenized.{}-{}/'.format(src_lang, dst_lang)
# Directory that receives the tmp_* dumps and the final bpe_<subset>.txt files.
output_path = 'edit_iwslt14.tokenized.{}-{}'.format(src_lang, dst_lang)
# Passed as --beam and --nbest respectively; both are set to the same value.
beam = nbest = 32
# Passed as --max-tokens.
max_tokens = 1024
# Dataset subsets processed by the main loop, each passed as --gen-subset.
keys = ['valid', 'test', 'train']


# Must run before the `from fairseq...` import in a later cell.
# NOTE(review): the path is relative, so this presumably expects a fairseq
# checkout in the notebook's working directory -- confirm.
sys.path.append('fairseq')

In [None]:
# Create the output directory when it does not exist yet; if it already
# exists it must be empty, otherwise abort instead of mixing results.
if not os.path.isdir(output_path):
    os.makedirs(output_path, exist_ok=True)
elif os.listdir(output_path):
    raise ValueError('Output directory {} is not empty'.format(output_path))

In [None]:
if not os.path.isfile(checkpoint_path):
    ! wget https://www.dropbox.com/s/iksezig9qi4g92e/baseline_checkpoint.pt?dl=1 -O {checkpoint_path}

In [None]:
# Imported here rather than in the first cell because it has to run after
# sys.path.append('fairseq').
from fairseq.data.dictionary import Dictionary

# Load the source- and target-language dictionaries stored next to the
# binarized data (files named dict.<lang>.txt inside data_path).
src_voc, dst_voc = (
    Dictionary.load(os.path.join(data_path, 'dict.{}.txt'.format(lang)))
    for lang in (src_lang, dst_lang)
)

In [None]:
def generate_alternatives(key, file):
    """Run fairseq-generate on one dataset subset and dump its stdout to a file.

    The command is executed through the IPython ``!`` shell escape using the
    notebook-level ``data_path``, ``checkpoint_path``, ``beam``, ``nbest`` and
    ``max_tokens`` settings.

    Args:
        key: subset name passed to ``--gen-subset``.
        file: path the command's standard output is redirected to; the
            redirect (``>``) overwrites any existing file.
    """
    !fairseq-generate {data_path} --path {checkpoint_path} \
     --beam {beam} --nbest {nbest} \
     --max-tokens {max_tokens} \
     --gen-subset {key} > {file}

In [None]:
def generate_edits(key, file):
    """Run fairseq-generate with sampling on one subset and dump stdout to a file.

    Same invocation as ``generate_alternatives`` (same data, checkpoint, beam,
    nbest and max-tokens settings) with ``--sampling --temperature 1.2`` added.
    NOTE(review): presumably this is meant to yield more varied hypotheses
    than the plain run -- confirm against the fairseq version in use.

    Args:
        key: subset name passed to ``--gen-subset``.
        file: path the command's standard output is redirected to; the
            redirect (``>``) overwrites any existing file.
    """
    !fairseq-generate {data_path} --path {checkpoint_path} \
     --beam {beam} --nbest {nbest} \
     --max-tokens {max_tokens} \
     --sampling --temperature 1.2 \
     --gen-subset {key} > {file}

In [None]:
def _parse_generate_output(path):
    """Group the lines of a fairseq-generate dump by their leading tag.

    Every line is split on tabs: the first field (e.g. ``S-12`` or ``H-12``)
    is the tag, the remaining fields are kept as one payload list per line.
    Lines without a tab (log output) land under their own full-line tag and
    are simply never looked up.

    Args:
        path: path of the text file written by one of the generate_* helpers.

    Returns:
        A ``defaultdict(list)`` mapping tag -> list of payload lists.
    """
    records = defaultdict(list)
    with open(path) as f_in:
        for line in f_in:
            tag, *payload = line.split('\t')
            records[tag].append(payload)
    return records


def _source_ids(records):
    """Return, in ascending order, every sample id that has an ``S-<id>`` tag."""
    return sorted(
        int(tag[2:]) for tag in records
        if tag.startswith('S-') and tag[2:].isdigit()
    )


# For every subset: generate beam-search alternatives and sampled edits, then
# write one line per sample as  source \t one random edit \t shuffled hypotheses.
for key in keys:
    # File names are part of the on-disk layout and are kept exactly as before
    # (including the 'alternaatives' spelling).
    tmp_alternatives_output_file_path = os.path.join(output_path, 'tmp_alternaatives_{}.txt'.format(key))
    tmp_edits_output_file_path = os.path.join(output_path, 'tmp_edits_{}.txt'.format(key))
    bpe_output_file_path = os.path.join(output_path, 'bpe_{}.txt'.format(key))

    print('Start generating {} alternatives'.format(key))
    generate_alternatives(key, tmp_alternatives_output_file_path)
    print('Start generating {} edits'.format(key))
    generate_edits(key, tmp_edits_output_file_path)

    print('Start parsing the alternatives beam search output')
    logs = _parse_generate_output(tmp_alternatives_output_file_path)

    print('Start parsing the edits beam search output')
    edit_logs = _parse_generate_output(tmp_edits_output_file_path)

    print('Start edit samples generating')

    written = skipped = 0
    with open(bpe_output_file_path, 'w') as f_out:
        # Walk the ids that are actually present instead of counting up from 1:
        # fairseq-generate numbers samples from 0, and ids may have gaps, so a
        # counter that starts at 1 and stops at the first missing id drops
        # sample 0 and everything after the first gap.
        for i in _source_ids(logs):
            source_tag = 'S-{}'.format(i)
            hypo_tag = 'H-{}'.format(i)

            # .get() keeps the defaultdicts from growing empty entries.
            hypos = [hypo.strip() for prob, hypo in logs.get(hypo_tag, ())]
            edits = [edit.strip() for prob, edit in edit_logs.get(hypo_tag, ())]

            # A sample is usable only if both runs produced it and each run
            # has at least one hypothesis (random.choice fails on an empty list).
            if source_tag not in edit_logs or not hypos or not edits:
                skipped += 1
                continue

            source = logs[source_tag][0][0].strip()
            edit = random.choice(edits)
            np.random.shuffle(hypos)
            f_out.write('{}\t{}\t{}\n'.format(source, edit, '\t'.join(hypos)))
            written += 1

    print('Wrote {} samples to {} ({} incomplete samples skipped)'.format(
        written, bpe_output_file_path, skipped))
    print('_'*100)
    print('\n'*3)