## Word Reordering

## Setup

1. Download the models from [here](https://drive.google.com/file/d/1UX5_6E7vOiBzFoO0tCh5Pu5CzAWm2flE/view?usp=sharing) and extract them into the `tree_text_gen` base directory.

2. Download the `personachat` sentences dataset: [train](https://drive.google.com/file/d/1JnBTNKiVJGLB3tcyYo36NTNvvWXIsX8P/view?usp=sharing), [valid](https://drive.google.com/file/d/179s3ONLEqMEaGjdueuEC4XQyXhh7Sck0/view?usp=sharing), [test](https://drive.google.com/file/d/12VhPsvp-RgQzg9TowYKO6Ym2Wscd4o1A/view?usp=sharing). Put these files in the `data_dir` specified in the next step.

3. Set `data_dir` to the directory holding the dataset files from step 2.

In [1]:
data_dir = '/home/sw1986/datasets/personachat/'
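
Before running the later cells, it can help to confirm that the dataset files are where the notebook expects them. The sketch below is a hypothetical helper (`missing_dataset_files` is not part of `tree_text_gen`). It assumes all three splits follow the `personachat_all_sentences_<split>.jsonl` pattern used by the loading cells in this notebook; the `train` file name in particular is inferred from that pattern.

```python
import os

# Split names come from the download links in the setup steps.
SPLITS = ('train', 'valid', 'test')

def missing_dataset_files(data_dir, splits=SPLITS):
    """Return the expected dataset paths that do not exist in data_dir."""
    expected = [
        # Assumed naming pattern, taken from the loading cells in this notebook.
        os.path.join(data_dir, 'personachat_all_sentences_%s.jsonl' % split)
        for split in splits
    ]
    return [path for path in expected if not os.path.isfile(path)]
```

Calling `missing_dataset_files(data_dir)` should return an empty list once step 2 is complete.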

In [8]:
%load_ext autoreload
%autoreload 2
%pylab inline
import tree_text_gen.binary.bagorder.evaluate as evaluate
from pprint import pprint as pp
import os
from glob import glob

Populating the interactive namespace from numpy and matplotlib


### Experiment Directories

If the setup steps were completed, the cell below finds the experiment directories under `models/bagorder/`. It prints the directory names and lists the contents of one example experiment directory.

In [9]:
import tree_text_gen
project_dir = os.path.abspath(os.path.join(os.path.dirname(tree_text_gen.__file__), os.pardir))
dirs = glob(os.path.join(project_dir, 'models/bagorder/*'))
pp(dirs)
d = dirs[0]
!ls $d

['/home/sw1986/projects/phd/tree_text_gen/models/bagorder/leftright',
 '/home/sw1986/projects/phd/tree_text_gen/models/bagorder/uniform',
 '/home/sw1986/projects/phd/tree_text_gen/models/bagorder/annealed']
leftright  leftright.checkpoint  model_config.json  tok2i.json
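
If model loading fails later, a quick check of each experiment directory may narrow down the cause. The helper below is a hypothetical sketch, not part of `tree_text_gen`. It assumes every experiment directory mirrors the `leftright` listing above: an entry named after the experiment, a `<name>.checkpoint` entry, `model_config.json`, and `tok2i.json`.

```python
import os

def missing_experiment_files(expr_dir):
    """Return expected entry names that are absent from one experiment directory.

    Assumption: the layout matches the `leftright` listing, i.e.
    <name>, <name>.checkpoint, model_config.json, and tok2i.json.
    """
    name = os.path.basename(os.path.normpath(expr_dir))
    expected = [name, name + '.checkpoint', 'model_config.json', 'tok2i.json']
    return [f for f in expected if not os.path.exists(os.path.join(expr_dir, f))]
```

For example, `{d: missing_experiment_files(d) for d in dirs}` would map each directory to whatever is missing from it.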


### Load each model from `dirs`

`exprs` maps each experiment name to its directory, and `models` maps each experiment name to the loaded PyTorch model.

In [11]:
CHECKPOINT = False

exprs = {}
for d in dirs:
    expr_name = d.split('/')[-1]
    exprs[expr_name] = d

models = {}
for k, v in exprs.items():
    print(k)
    models[k] = evaluate.load_model(v, k, checkpoint=CHECKPOINT)


leftright
uniform
annealed


### BLEU Evaluation

In [12]:
kind = 'valid'
# kind = 'test'

from tree_text_gen.binary.common.data import load_personachat, build_tok2i, SentenceDataset, inds2toks
from torch.utils.data.dataloader import DataLoader

tok2i = list(models.values())[0].tok2i
dataset = load_personachat(os.path.join(data_dir, 'personachat_all_sentences_%s.jsonl' % kind))
dataset = SentenceDataset(dataset, tok2i)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=dataset.collate, drop_last=True)

16181 sentences


In [13]:
from pprint import pprint as pp
print(kind)
for name, model in models.items():
    print("=== %s === " % (name))
    ms, predictions = evaluate.eval_dataset(model, dataloader)
    print(('em', ms['eval/em']), ('f1', ms['eval/f1']))
    !cat __hyp.txt | sacrebleu --force --smooth none __ref.txt

!rm __hyp.txt
!rm __ref.txt

valid
=== leftright ===
('em', 0.2301) ('f1', 0.91)
BLEU+case.mixed+numrefs.1+smooth.none+tok.13a+version.1.2.12 = 46.6 89.5/53.7/36.7/27.1 (BP = 0.997 ratio = 0.997 hyp_len = 195217 ref_len = 195797)

=== uniform ===
('em', 0.2095) ('f1', 0.9676)
BLEU+case.mixed+numrefs.1+smooth.none+tok.13a+version.1.2.12 = 44.7 96.6/53.7/33.7/23.4 (BP = 0.995 ratio = 0.995 hyp_len = 194822 ref_len = 195789)

=== annealed ===
('em', 0.23) ('f1', 0.9613)
BLEU+case.mixed+numrefs.1+smooth.none+tok.13a+version.1.2.12 = 46.8 96.1/55.6/36.3/25.9 (BP = 0.989 ratio = 0.989 hyp_len = 193595 ref_len = 195768)
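
The `em` and `f1` values are computed inside `evaluate.eval_dataset`, which is not shown here. As a rough guide to reading them, the sketch below uses one common definition: exact match compares the full token sequences, and F1 compares the predicted and reference sentences as bags of tokens. These definitions and the function names are assumptions for illustration and may differ from what `tree_text_gen` implements.

```python
from collections import Counter

def exact_match(pred_tokens, ref_tokens):
    """1.0 if the predicted token sequence equals the reference, else 0.0."""
    return float(list(pred_tokens) == list(ref_tokens))

def bag_f1(pred_tokens, ref_tokens):
    """Token-level F1 between two sentences treated as multisets (order ignored)."""
    pred_counts, ref_counts = Counter(pred_tokens), Counter(ref_tokens)
    # Multiset intersection: tokens shared by both, counted with multiplicity.
    overlap = sum((pred_counts & ref_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)

# Illustrative sentence pair (not taken from the dataset).
ref = 'i like to play the guitar'.split()
pred = 'i like the guitar to play'.split()
print(exact_match(pred, ref), bag_f1(pred, ref))  # prints: 0.0 1.0
```

Under this reading, an output that uses exactly the input words in the wrong order scores an F1 of 1 and an exact match of 0, which would be consistent with the high F1 and low exact-match values reported above.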


## Extra Visualizations

#### Utilities

In [11]:
import torch.nn.functional as F
import torch as th
import seaborn as sns

def viz_step_distributions(model, scores, preds, x, out):
    ps = F.softmax(scores[0], dim=1)
    content_timesteps = ((preds[0] != model.tok2i['<end>']) & (preds[0] != model.tok2i['<p>'])).nonzero().view(-1)[:len(out['genorder_tokens'])]
    input_word_idxs = x[0][x[0] != model.tok2i['</s>']]

    label_ps = []
    for t in content_timesteps:
        label_ps.append(ps[t][input_word_idxs])

    label_ps = th.stack(label_ps).cpu().numpy()
    fig, ax = plt.subplots(1)
    sns.heatmap(label_ps, ax=ax, cmap='gist_gray', cbar=True, vmin=0., vmax=1.)
    ax.set_xticklabels(out['gt_tokens'], rotation='vertical')
    ax.set_yticklabels(list(enumerate(out['genorder_tokens'])), rotation='horizontal')
    ax.set_xlabel('ground-truth tokens')
    ax.set_ylabel('(time, generated word)')
    ax.set_title('Per-step probabilities over ground-truth tokens')
    plt.show()

#### Select and Display an example from the validation set

Each row of the heatmap shows a policy's token probabilities at one generation step, restricted to the ground-truth tokens of the sentence (the correct actions). The y-axis labels give the step index and the token the policy sampled at that step.

The annealed policy tends to produce lower-entropy distributions than the uniform policy. Intuitively, the annealed policy has learned preferences over the set of valid actions at each step, while the uniform policy tries to spread probability mass evenly over all valid actions.
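
To put a number on the entropy comparison, one option is to compute the entropy of each heatmap row. The sketch below is a minimal illustration, not code from `tree_text_gen`. The function name `step_entropies` is hypothetical, and renormalizing each row to sum to 1 is an assumption, made because the rows only cover the ground-truth tokens. The two example rows are invented values, not model outputs.

```python
import math

def step_entropies(label_ps):
    """Entropy (in nats) of each row after renormalizing it to sum to 1.

    `label_ps` is a list of rows, one per generation step, holding the
    probabilities assigned to the ground-truth tokens at that step.
    """
    entropies = []
    for row in label_ps:
        total = sum(row)
        if total <= 0:
            entropies.append(0.0)  # convention chosen here for an all-zero row
            continue
        probs = [p / total for p in row]
        entropies.append(-sum(p * math.log(p) for p in probs if p > 0))
    return entropies

# Invented rows: one spreads mass evenly, the other prefers a single token.
uniform_like = [[0.25, 0.25, 0.25, 0.25]]
peaked = [[0.85, 0.05, 0.05, 0.05]]
print(step_entropies(uniform_like), step_entropies(peaked))
```

The evenly spread row reaches the maximum entropy for four options, `log(4)`, while the peaked row scores lower. The `label_ps` array built inside `viz_step_distributions` could be passed in as `label_ps.tolist()` if that function were changed to return it.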

In [28]:
from __future__ import print_function
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

from tree_text_gen.binary.common.data import load_personachat, build_tok2i, SentenceDataset, inds2toks
from torch.utils.data.dataloader import DataLoader

valid = load_personachat(os.path.join(data_dir, 'personachat_all_sentences_valid.jsonl'))
sentences = list(sorted([' '.join(x['tokens']) for x in valid], key=lambda x: len(x)))

def run(idx):
    sentence = sentences[idx]
    for name, model in models.items():
        out, x, scores, preds = evaluate.eval_single(model, sentence)

        print('=== %s ===' % name)
        evaluate.print_output(out)
        viz_step_distributions(model, scores, preds, x, out)
        print('---------------------------------------------------')

interact(run, idx=(0, len(sentences)-1));

16181 sentences

