## Translation (Transformer)

In [1]:
%load_ext autoreload
%autoreload 2
%pylab inline
import tree_text_gen.binary.translation.evaluate as evaluate

import os
from glob import glob
from pprint import pprint as pp

Populating the interactive namespace from numpy and matplotlib


### SETUP

First extract the models downloaded from [here](https://drive.google.com/file/d/172Ir1oNvHBgnLO1hWqDeiAcBH5i6pfwi/view?usp=sharing) into the `tree_text_gen` base directory. The cell below will then load the correct experiment directories. The cell prints the directories and the contents of an example experiment directory.

In [3]:
import tree_text_gen
project_dir = os.path.abspath(os.path.join(os.path.dirname(tree_text_gen.__file__), os.pardir))
dirs = glob(os.path.join(project_dir, 'models/translation/*'))
pp(dirs)
d = dirs[0]
!ls $d

['/home/sw1986/projects/phd/tree_text_gen/models/translation/leftright',
 '/home/sw1986/projects/phd/tree_text_gen/models/translation/annealed_tree',
 '/home/sw1986/projects/phd/tree_text_gen/models/translation/uniform',
 '/home/sw1986/projects/phd/tree_text_gen/models/translation/annealed']
leftright  leftright.checkpoint  model_config.json  tok2i.json


### Load each model from `dirs`

`exprs` will map experiment/model name to pytorch model

In [4]:
CHECKPOINT = False

exprs = {}
models = {}
for d in dirs:
    expr_name = d.split('/')[-1]
    print(expr_name)
    models[expr_name] = evaluate.load_transformer_eval(d, expr_name, checkpoint=CHECKPOINT)
    

leftright
annealed_tree
uniform
annealed


### Load Data

In [5]:
import argparse
from tree_text_gen.binary.translation.args import common_args
from tree_text_gen.binary.translation.data import load_iwslt, load_iwslt_vocab
from tree_text_gen.binary.common.util import setup
import copy
import torchtext.data

parser = argparse.ArgumentParser()
common_args(parser)
args = parser.parse_args(args=[])
args = setup(args, log=False)
args.translate = "{}_{}".format(args.src, args.trg)
args.logger = None

# -- DATA
train_data, dev_data, test_data, SRC, TRG = load_iwslt(args)
tok2i, i2tok, SRC, TRG = load_iwslt_vocab(args, SRC, TRG, args.data_prefix)

SRC = copy.deepcopy(SRC)
for data_ in [train_data, dev_data, test_data]:
    if not data_ is None:
        data_.fields['src'] = SRC
        
sort_key = lambda x: len(x.src)
trainloader = torchtext.data.BucketIterator(dataset=train_data, batch_size=args.batch_size, device=args.device, train=True, repeat=False, shuffle=True, sort_key=sort_key, sort_within_batch=True)
validloader = torchtext.data.BucketIterator(dataset=dev_data, batch_size=args.batch_size, device=args.device, train=False, repeat=False, shuffle=True, sort_key=sort_key, sort_within_batch=True)
testloader = torchtext.data.BucketIterator(dataset=test_data, batch_size=args.batch_size, device=args.device, train=False, repeat=False, shuffle=False, sort_key=sort_key, sort_within_batch=True)

### Evaluation

For the `+ tree-encoding` result (i.e. tree encoding without end tuning), set `end_tuned = False` below.


#### Meteor and Yisi Note
Meteor and Yisi require additional setup; **comment out** `common_eval.eval_meteor()` **and** `common_eval.eval_yisi()` **below unless you have them set up**. 

You will probably need to look at the `common_eval.eval_meteor()` and `common_eval.eval_yisi()` code and arguments to get it running on your environment.
- [meteor setup](https://www.cs.cmu.edu/~alavie/METEOR/README.html)
- [yisi setup](https://github.com/chikiulo/yisi)

In [6]:
import tree_text_gen.binary.translation.evaluate as evaluate
import tree_text_gen.binary.common.evaluate as common_eval
import tree_text_gen.binary.common.samplers as samplers

In [7]:
dataloader = validloader
# dataloader = testloader

end_tuned = True

for name, model in models.items():
    if end_tuned and 'tree' in name:
        model.module.stop_prob = 0.67
        name += '_end_tuned'
        
    print('=== %s ===' % (name))
    ms, predictions = evaluate.eval_dataset(model, dataloader)
    out = {}
    out['meteor'] = common_eval.eval_meteor()
    out['ribes'] = common_eval.eval_ribes()
    out['yisi'] = common_eval.eval_yisi()
    out['bleu'] = common_eval.eval_sacrebleu()
    pp(out)

  0%|          | 0/32 [00:00<?, ?it/s]

=== leftright ===


100%|██████████| 32/32 [00:40<00:00,  1.23s/it]
# RIBES evaluation start at 2019-04-30 16:37:33.497879
# RIBES evaluation done at 2019-04-30 16:37:33.708625
  0%|          | 0/32 [00:00<?, ?it/s]

BLEU+case.mixed+numrefs.1+smooth.none+tok.13a+version.1.2.12 = 32.3 65.7/41.1/28.1/19.7 (BP = 0.923 ratio = 0.926 hyp_len = 20841 ref_len = 22509)

{'bleu': 32.3,
 'meteor': 0.3195953154322145,
 'ribes': 0.848027,
 'yisi': 0.694072}
=== annealed_tree_end_tuned ===


100%|██████████| 32/32 [00:53<00:00,  1.53s/it]
# RIBES evaluation start at 2019-04-30 16:39:06.556216
# RIBES evaluation done at 2019-04-30 16:39:06.748772
  0%|          | 0/32 [00:00<?, ?it/s]

BLEU+case.mixed+numrefs.1+smooth.none+tok.13a+version.1.2.12 = 29.1 61.4/35.9/22.9/15.1 (BP = 0.986 ratio = 0.986 hyp_len = 22193 ref_len = 22509)

{'bleu': 29.1,
 'meteor': 0.30999861976223303,
 'ribes': 0.835089,
 'yisi': 0.688148}
=== uniform ===


100%|██████████| 32/32 [00:41<00:00,  1.29s/it]
# RIBES evaluation start at 2019-04-30 16:40:26.363024
# RIBES evaluation done at 2019-04-30 16:40:26.589651
  0%|          | 0/32 [00:00<?, ?it/s]

BLEU+case.mixed+numrefs.1+smooth.none+tok.13a+version.1.2.12 = 24.5 64.3/36.3/22.2/14.1 (BP = 0.838 ratio = 0.849 hyp_len = 19119 ref_len = 22509)

{'bleu': 24.5,
 'meteor': 0.27981935032263144,
 'ribes': 0.826607,
 'yisi': 0.663999}
=== annealed ===


100%|██████████| 32/32 [00:43<00:00,  1.34s/it]
# RIBES evaluation start at 2019-04-30 16:41:48.825809
# RIBES evaluation done at 2019-04-30 16:41:49.003298


BLEU+case.mixed+numrefs.1+smooth.none+tok.13a+version.1.2.12 = 26.8 64.6/37.1/23.2/15.3 (BP = 0.882 ratio = 0.889 hyp_len = 20007 ref_len = 22509)

{'bleu': 26.8,
 'meteor': 0.2967353647960145,
 'ribes': 0.836156,
 'yisi': 0.678767}
