## Unconditional Generation

In [2]:
%load_ext autoreload
%autoreload 2
%pylab inline
import tree_text_gen.binary.unconditional.evaluate as evaluate
from pprint import pprint as pp
import os
from glob import glob

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Populating the interactive namespace from numpy and matplotlib


### SETUP

First extract the models downloaded from [here](https://drive.google.com/file/d/1HmtxtzGG3tvQBk6tPtKn_OsMnA0LcjqX/view?usp=sharing) into the `tree_text_gen` base directory. The cell below will then load the correct experiment directories. The cell prints the directories and the contents of an example experiment directory.

In [4]:
import tree_text_gen
project_dir = os.path.abspath(os.path.join(os.path.dirname(tree_text_gen.__file__), os.pardir))
dirs = glob(os.path.join(project_dir, 'models/unconditional/*'))
pp(dirs)
d = dirs[0]
!ls $d

['/home/sw1986/projects/phd/tree_text_gen/models/unconditional/leftright',
 '/home/sw1986/projects/phd/tree_text_gen/models/unconditional/uniform',
 '/home/sw1986/projects/phd/tree_text_gen/models/unconditional/annealed']
leftright.checkpoint  model_config.json  tok2i.json


### Load each model specified in `exprs`

In [5]:
CHECKPOINT = True

exprs = {}
for d in dirs:
    expr_name = d.split('/')[-1]
    exprs[expr_name] = d

models = {}
for k, v in exprs.items():
    print(k)
    models[k] = evaluate.load_model(v, k, checkpoint=CHECKPOINT)


leftright
uniform
annealed


### Sample Outputs

#### Visualization

Below shows samples from each model using the specified sampler.

Select `show_trees` to print out corresponding trees.

In [8]:
from __future__ import print_function
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import torch.nn.functional as F
import torch as th
import seaborn as sns
from tree_text_gen.binary.common.data import load_personachat, build_tok2i, SentenceDataset, inds2toks
from torch.utils.data.dataloader import DataLoader
import tree_text_gen.binary.common.samplers as samplers

def run(topk, show_trees, n=fixed(10)):
    for name, model in models.items():
        k = topk
        if k == -1:
            model.sampler.eval_sampler = samplers.StochasticSampler()
        else:
            model.sampler.eval_sampler = samplers.TopkSampler(k, model.device)
        out = evaluate.sample(model, n)
        
        print('=== %s ===' % name)
        for o in out:
            if show_trees:
                evaluate.print_output(o)
            else:
                print(' '.join(o['inorder_tokens']))
        print('---------------------------------------------------')

print("topk is the k for a topk-sampler (-1 for k = all)")
interact(run, topk=[-1, 3, 10, 100, 1000, 10000], show_trees=False);    

topk is the k for a topk-sampler (-1 for k = all)


interactive(children=(Dropdown(description='topk', options=(-1, 3, 10, 100, 1000, 10000), value=-1), Checkbox(…

### Tree Completion

Select root, left child, and right child words for a seed tree. Sampled tree completions are then shown for each model. Select `show_trees` to show each sample's corresponding tree.

In [19]:
words = list(list(models.values())[0].tok2i.keys())

def run(root='favorite', lchild='my', rchild='!', show_trees=False, topk=100, n=fixed(5)):
    for name in ['uniform', 'annealed', 'leftright']:
        model = models[name]
        k = topk
        if k == -1:
            model.sampler.eval_sampler = samplers.StochasticSampler()
        else:
            model.sampler.eval_sampler = samplers.TopkSampler(k, model.device)
        tree_prefix_tokens = [root, lchild, rchild]
        out, scores, samples = evaluate.sample_with_prefix(model, tree_prefix_tokens, n=5)
        print('=== %s ===' % name)
        for o in out:
            if show_trees:
                evaluate.print_output(o)
            else:
                print(' '.join(o['inorder_tokens']))
        print('---------------------------------------------------')

interact(run, root=words, lchild=words, rchild=words, topk=[-1,10,50,100,1000], show_trees=False);

interactive(children=(Dropdown(description='root', index=28, options=('<s>', '<p>', '</s>', '<unk>', '<end>', …