In [5]:
import os
import math
import pickle
from pathlib import Path
from collections import Counter, defaultdict
from itertools import chain, product
from ptb_utils import parse_ptb_tree_data, partial_tree, linearize

# Dataset splits processed throughout the notebook.
splits = ["train", "dev", "test"]

# One entry per source variant. keep_pos / keep_dep are forwarded to
# partial_tree, bracket / rel are forwarded to linearize.
_mode_keys = ("keep_pos", "keep_dep", "bracket", "rel")
mode2param = {
    mode_name: dict(zip(_mode_keys, values))
    for mode_name, values in [
        ("base", (0, 0, False, False)),
        ("brac", (0, 0, True, False)),
        ("pos", (1, 0, True, False)),
        ("udep", (0, 1, True, False)),
        ("ldep", (0, 1, True, True)),
        ("full", (1, 1, True, True)),
    ]
}

out = Path("data-raw")  # every generated file is written here
data_path = Path("ptb_trees")  # input tree files: {split}.txt and {split}.bnp.txt
out.mkdir(parents=True, exist_ok=True)


In [6]:
# Load the PTB trees: one parse_ptb_tree_data result per split, read from
# ptb_trees/{split}.txt.
data = {
    split: parse_ptb_tree_data(data_path / "{}.txt".format(split))
    for split in splits
}

In [7]:
# Collect the special tokens that must be kept intact during tokenization:
# the start/end markers, then every POS tag and every dependency relation,
# each group ordered from most to least frequent.
pcount = Counter()
rcount = Counter()
specials = []
for nodes in chain.from_iterable(data.values()):
    if not specials and nodes:
        # Take the start/end markers from the first non-empty sentence.
        # NOTE(review): assumes every node carries the same start/end markers,
        # so which sentence they are read from does not matter; confirm in ptb_utils.
        specials = [nodes[0].start, nodes[0].end]
    pcount.update(node.pos for node in nodes)
    rcount.update(node.rel for node in nodes)

if not specials:
    # Previously this fell back on the leaked loop variable `nodes`, which
    # raises NameError (or uses stale state) when no sentences were loaded.
    raise ValueError("no non-empty sentences loaded; cannot determine start/end tokens")
specials.extend(pos for pos, _ in pcount.most_common())
specials.extend(rel for rel, _ in rcount.most_common())
with open(out / "new_tokens.txt", "w") as fout:
    for token in specials:
        fout.write("{}\n".format(token))

In [8]:
# Different types of inputs share the same targets: {split}.tgt holds one
# space-joined sentence per line, in the same order as data[split].
for split in splits:
    word_lists = [[node.word.strip() for node in nodes] for nodes in data[split]]
    with open(out / "{}.tgt".format(split), "w") as fout:
        for words in word_lists:
            fout.write(" ".join(words) + "\n")
    if split != "train":
        # The pickle keeps each sentence as a list of word strings, so the
        # word and phrase boundaries survive for constrained decoding.
        with open(out / "{}.tgt.pkl".format(split), "wb") as fout:
            pickle.dump(word_lists, fout)


In [9]:
def write_src_without_repeat(trees, path_out, get_line, uniqs, max_tries=30):
    """Write one linearized source line per tree, avoiding repeats when possible.

    ``get_line`` may return a different line on each call (e.g. when it
    shuffles), so each tree is resampled up to ``max_tries`` times to find a
    line that is not already in ``uniqs``. This is best effort: if every
    attempt collides, the last sampled line is written anyway so that every
    tree still yields exactly one output line.

    Args:
        trees: iterable of trees to linearize, written in order.
        path_out: output file path; overwritten.
        get_line: callable mapping a tree to a single-line string.
        uniqs: set of lines already emitted; updated in place so callers can
            share it across several output files.
        max_tries: maximum number of samples drawn per tree (default 30).

    Raises:
        ValueError: if ``max_tries`` is less than 1.
    """
    if max_tries < 1:
        raise ValueError("max_tries must be at least 1, got {}".format(max_tries))
    with open(path_out, "w") as fout:
        for tree in trees:
            for _ in range(max_tries):
                line = get_line(tree)
                if line not in uniqs:
                    break
            uniqs.add(line)
            fout.write(line + "\n")

In [10]:
# Base sources: shuffled linearizations under the "base" configuration.
# Corresponds to results in section 3.3 (main word ordering results)
# and 3.5 (effects of input permutation).
# train/dev get the primary file plus `num_copy` extra permutations
# ({mode}.{split}.src.{k}) for data augmentation; test gets the primary file only.
num_copy = 10
mode = "base"
param = mode2param[mode]
# Per-split set of lines already written. It is shared by all copies of a split,
# so the augmented copies use distinct permutations where possible.
uniqs = defaultdict(set)


def get_line(tree):
    return linearize(tree, bracket=param["bracket"], shuffle=True, rel=param["rel"])


for split in ["train", "dev", "test"]:
    print(split)
    trees = [partial_tree(nodes, param["keep_pos"], param["keep_dep"]) for nodes in data[split]]
    last_copy = 0 if split == "test" else num_copy
    for copy_id in range(last_copy + 1):
        suffix = "" if copy_id == 0 else ".{}".format(copy_id)
        path_out = out / "{}.{}.src{}".format(mode, split, suffix)
        write_src_without_repeat(trees, path_out, get_line, uniqs[split])


train
dev
test


In [11]:
# Sources that contain only the "<PUNCT:>" placeholder (one line per sentence),
# simulating unconditional generation.
# Corresponds to results in section 3.5 (effects of conditional modeling).
source = "<PUNCT:>"
mode = "base"
param = mode2param[mode]  # not used by this cell
for split in ["train", "dev", "test"]:
    with open(out / "u{}.{}.src".format(mode, split), "w") as fout:
        fout.writelines(source + "\n" for _ in data[split])

In [12]:
# Sources for every non-base feature configuration in mode2param.
# Corresponds to results in section 4 (understanding why BART helps).


def get_line(tree):
    # Reads the global `param`, which the loop below rebinds for each mode.
    return linearize(tree, bracket=param["bracket"], shuffle=True, rel=param["rel"])


for mode, param in mode2param.items():
    if mode != "base":
        print(mode)
        for split in ["train", "dev", "test"]:
            trees = [partial_tree(nodes, param["keep_pos"], param["keep_dep"]) for nodes in data[split]]
            # A fresh set per output file keeps duplicate inputs within that file minimal.
            uniqs = set()
            path_out = out / "{}.{}.src".format(mode, split)
            write_src_without_repeat(trees, path_out, get_line, uniqs)


brac
pos
udep
ldep
full


In [13]:
# Sources for partial tree linearization over a 3x3 grid of keep_pos / keep_dep
# values (train and dev only).
# Corresponds to results in section 5 (extension to partial tree linearization).
# keep_pos / keep_dep are forwarded to partial_tree; presumably they are the
# fraction of POS / dependency annotations retained (TODO confirm in ptb_utils).


def get_line(tree):
    return linearize(tree, bracket=True, shuffle=True, rel=True)


for keep_pos in [0, 0.5, 1]:
    for keep_dep in [0, 0.5, 1]:
        print(keep_pos, keep_dep)
        for split in ["train", "dev"]:
            trees = [partial_tree(nodes, keep_pos, keep_dep) for nodes in data[split]]
            uniqs = set()
            path_out = out / "part.{}-{}.{}.src".format(keep_pos, keep_dep, split)
            write_src_without_repeat(trees, path_out, get_line, uniqs)

0 0
0 0.5
0 1
0.5 0
0.5 0.5
0.5 1
1 0
1 0.5
1 1


In [14]:
# Sources for bnp partial tree linearization, read from ptb_trees/{split}.bnp.txt.
# Corresponds to results in section 5 (extension to partial tree linearization).
bnpdata = {
    split: parse_ptb_tree_data(data_path / "{}.bnp.txt".format(split))
    for split in splits
}

# Target constraints, dev split only. This pickle is used in constrained
# decoding, so each sentence is stored as a list of word strings to keep the
# word and phrase boundaries.
split = "dev"
word_lists = [[node.word.strip() for node in nodes] for nodes in bnpdata[split]]
with open(out / "bnp.{}.tgt.pkl".format(split), "wb") as fout:
    pickle.dump(word_lists, fout)


def get_line(tree):
    return linearize(tree, bracket=True, shuffle=True, rel=True)


# Same 3x3 keep_pos / keep_dep grid as the plain partial-tree sources
# (train and dev only).
for keep_pos in [0, 0.5, 1]:
    for keep_dep in [0, 0.5, 1]:
        print(keep_pos, keep_dep)
        for split in ["train", "dev"]:
            trees = [partial_tree(nodes, keep_pos, keep_dep) for nodes in bnpdata[split]]
            uniqs = set()
            path_out = out / "bnp.{}-{}.{}.src".format(keep_pos, keep_dep, split)
            write_src_without_repeat(trees, path_out, get_line, uniqs)

0 0
0 0.5
0 1
0.5 0
0.5 0.5
0.5 1
1 0
1 0.5
1 1
