In [7]:
import glob
import os
import pickle
import random
import re
import subprocess
from collections import namedtuple, defaultdict
from itertools import chain, product
from pathlib import Path

import sentencepiece as spm

# Data directories: raw inputs, the RAND (SentencePiece) outputs and the BART (GPT-2 BPE) outputs.
raw, rand, bart = Path("data-raw"), Path("data-rand"), Path("data-bart")
# Only the output directories are created; data-raw must already exist.
for out_dir in (rand, bart):
    os.makedirs(out_dir, exist_ok=True)

In [8]:
# Train BPE tokenizer for RAND
# Special tokens, one per line in new_tokens.txt. Blank lines are skipped:
# an empty alternative in the regex below would match at the start of every
# line (so no line would get its leading space), and it would also be passed
# on to SentencePiece as an empty user-defined symbol.
with open(raw/"new_tokens.txt", "r") as fin:
    special_tokens = [line.strip() for line in fin if line.strip()]
# non_prefix matches a line that starts with one of the special tokens.
# With no special tokens the pattern "()" would match everything, so fall
# back to an always-failing pattern instead.
if special_tokens:
    non_prefix = re.compile("({})".format("|".join(re.escape(i) for i in special_tokens)))
else:
    non_prefix = re.compile("(?!)")

def process_line(line):
    """Normalize whitespace in ``line`` and add a leading space.

    Whitespace runs are collapsed to single spaces and both ends are stripped.
    A single space is then prepended, presumably to stand in for SentencePiece's
    dummy prefix (``add_dummy_prefix=False`` is used when training). The space
    is omitted when the line begins with a special token, as detected by the
    module-level ``non_prefix`` pattern.
    """
    normalized = " ".join(line.split())
    starts_with_special = non_prefix.match(normalized) is not None
    return normalized if starts_with_special else " " + normalized

# Use one shared BPE vocabulary for source and target, for simplicity;
# all.train.src should contain all related tokens.
spm_train_options = dict(
    model_prefix=str(rand/"spm"),
    vocab_size=8000,
    character_coverage=1.0,
    model_type="bpe",
    split_by_whitespace=True,
    user_defined_symbols=",".join(special_tokens),
    # presumably disabled so the leading spaces added by process_line are kept as-is
    add_dummy_prefix=False,
    remove_extra_whitespaces=False,
)
with open(raw/"all.train.src", "r") as corpus:
    spm.SentencePieceTrainer.train(sentence_iterator=map(process_line, corpus), **spm_train_options)
sp = spm.SentencePieceProcessor(model_file=str(rand/"spm.model"))


In [9]:
# Apply the tokenizer for RAND: SentencePiece-encode every plain-text
# .src*/.tgt* file in `raw` line by line into `rand` under the same file name.
for path in chain(raw.glob("*.src*"), raw.glob("*.tgt*")):
    if "pkl" not in path.name:
        print(path)
        with open(path, "r") as fin, open(rand/path.name, "w") as fout:
            fout.writelines(" ".join(sp.encode(process_line(text), out_type=str)) + "\n" for text in fin)

# Constraints: each *.tgt.pkl is unpickled into an iterable of word lists.
# Each list becomes one output line; its entries are encoded separately and
# joined with tabs. (pickle.load: only use on trusted, locally produced files.)
for path in raw.glob("*.tgt.pkl"):
    with open(path, "rb") as fin, open(rand/path.name.replace("pkl", "const"), "w") as fout:
        for words in pickle.load(fin):
            pieces = [" ".join(sp.encode(process_line(word), out_type=str)) for word in words]
            fout.write("\t".join(pieces) + "\n")

data-raw/bnp.0-0.dev.src
data-raw/bnp.0-1.train.src
data-raw/bnp.0-0.5.dev.src
data-raw/base.dev.src.5
data-raw/base.dev.src.1
data-raw/part.0.5-0.train.src
data-raw/base.dev.src.4
data-raw/full.test.src
data-raw/part.1-0.5.dev.src
data-raw/bnp.1-1.dev.src
data-raw/base.dev.src
data-raw/base.dev.src.8
data-raw/bnp.1-0.5.dev.src
data-raw/base.train.src.1
data-raw/base.train.src
data-raw/base.dev.src.7
data-raw/base.train.src.7
data-raw/base.train.src.2
data-raw/ldep.train.src
data-raw/part.0.5-0.5.dev.src
data-raw/part.1-1.train.src
data-raw/part.1-1.dev.src
data-raw/udep.train.src
data-raw/base.train.src.5
data-raw/base.train.src.8
data-raw/bnp.0-0.5.train.src
data-raw/ldep.dev.src
data-raw/ubase.test.src
data-raw/bnp.0.5-0.dev.src
data-raw/brac.test.src
data-raw/base.train.src.4
data-raw/part.0.5-0.5.train.src
data-raw/base.train.src.9
data-raw/base.test.src
data-raw/bnp.1-1.train.src
data-raw/udep.dev.src
data-raw/base.train.src.6
data-raw/part.1-0.dev.src
data-raw/bnp.0.5-0.5.train.

In [14]:
# Prepare BPE tokenizers for BART
# download bart base and extract it to ./bart
# https://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz
# download these files to ./bart
# https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json
# https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe

# Update the BART vocabulary with our special tokens: each zero-count
# "madeupword" placeholder in ./bart/dict.txt is replaced by one special token
# (count 1); every other entry is copied unchanged to bart/dict.txt.
# NOTE(review): these entries only take effect if the BPE-encoded data contains
# the token strings verbatim — confirm the GPT-2 encoder emits them.
with open(raw/"new_tokens.txt", "r") as fin:
    # Blank lines would otherwise become malformed " 1" dictionary entries.
    bart_new_tokens = [line.strip() for line in fin if line.strip()]

# Work on a copy so no global list is left half-consumed. Tokens are taken
# from the end (pop()), matching the original placeholder assignment order.
pending_tokens = list(bart_new_tokens)
with open("./bart/dict.txt", "r") as dict_in, open(bart/"dict.txt", "w") as dict_out:
    for line in dict_in:
        idx, count = line.strip().split(" ")
        if int(count) == 0 and "madeupword" in idx and pending_tokens:
            entry = (pending_tokens.pop(), 1)
        else:
            entry = (idx, count)
        dict_out.write("{} {}\n".format(entry[0], entry[1]))

# Previously, tokens beyond the available placeholder slots were silently dropped.
if pending_tokens:
    raise ValueError(
        "Not enough madeupword slots in ./bart/dict.txt; {} special tokens were not added: {}".format(
            len(pending_tokens), pending_tokens))

In [15]:
# Apply BART BPE tokenization
# Please check the paths to encoder.json and vocab.bpe!
# NOTE(review): `python -m examples.roberta...` presumably has to run from the
# fairseq repository root so that `examples` is importable — confirm.
# The arguments are passed as a list (no shell, so paths need no quoting), and
# check=True stops the cell when the encoder fails instead of silently continuing.
bpe_encoder_cmd = [
    "python", "-m", "examples.roberta.multiprocessing_bpe_encoder",
    "--encoder-json", "bart/encoder.json",
    "--vocab-bpe", "bart/vocab.bpe",
    "--workers", "10", "--keep-empty",
]
for file in chain(raw.glob("*.src*"), raw.glob("*.tgt*")):
    if "pkl" in file.name:
        continue  # pickled constraints are handled below
    print(file)
    subprocess.run(bpe_encoder_cmd + ["--inputs", str(file), "--outputs", str(bart / file.name)],
                   check=True)

for file in raw.glob("*.tgt.pkl"):
    # pickle.load can execute arbitrary code: only load trusted, locally generated files.
    with open(file, "rb") as fin, open(bart / file.name.replace("pkl", "const"), "w") as fout:
        word_lists = pickle.load(fin)
        # BART decodes with fairseq interactive mode, which takes raw text inputs,
        # so the constraints stay raw text: one tab-separated line per entry.
        for words in word_lists:
            fout.write("\t".join(words) + "\n")


data-raw/bnp.0-0.dev.src
data-raw/bnp.0-1.train.src
data-raw/bnp.0-0.5.dev.src
data-raw/base.dev.src.5
data-raw/base.dev.src.1
data-raw/part.0.5-0.train.src
data-raw/base.dev.src.4
data-raw/full.test.src
data-raw/part.1-0.5.dev.src
data-raw/bnp.1-1.dev.src
data-raw/base.dev.src
data-raw/base.dev.src.8
data-raw/bnp.1-0.5.dev.src
data-raw/base.train.src.1
data-raw/base.train.src
data-raw/base.dev.src.7
data-raw/base.train.src.7
data-raw/base.train.src.2
data-raw/ldep.train.src
data-raw/part.0.5-0.5.dev.src
data-raw/part.1-1.train.src
data-raw/part.1-1.dev.src
data-raw/udep.train.src
data-raw/base.train.src.5
data-raw/base.train.src.8
data-raw/bnp.0-0.5.train.src
data-raw/ldep.dev.src
data-raw/ubase.test.src
data-raw/bnp.0.5-0.dev.src
data-raw/brac.test.src
data-raw/base.train.src.4
data-raw/part.0.5-0.5.train.src
data-raw/base.train.src.9
data-raw/base.test.src
data-raw/bnp.1-1.train.src
data-raw/udep.dev.src
data-raw/base.train.src.6
data-raw/part.1-0.dev.src
data-raw/bnp.0.5-0.5.train.

In [17]:
# Settings used to assemble the datasets for each experiment.
# settings[case][split] is a list of Pair(src, tgt, const). `src`/`tgt` name the
# tokenized files (relative to the tokenized data dir) that are concatenated into
# one split file; `const` names the constraint files (None = no constraints).
settings = {}
pair = namedtuple("Pair", ["src", "tgt", "const"])

# "ubase": unconstrained decoding with only BoW inputs
for case in ["base", "brac", "ubase", "pos", "udep", "ldep", "all"]:
    settings[case] = {
        "train": [pair(src=["{}.train.src".format(case)], tgt=["train.tgt"], const=None)],
        "valid": [pair(src=["{}.dev.src".format(case)], tgt=["dev.tgt"], const=["dev.tgt.const"])],
        "test": [pair(src=["{}.test.src".format(case)], tgt=["test.tgt"], const=["test.tgt.const"])]
    }

# baseN: train on base.train.src plus the N variants base.train.src.1..N, each
# paired with train.tgt. Test on base.test.src and on base.dev.src.1..10.
for case in ["base0", "base2", "base4", "base6", "base8"]:
    # Raw string: "\d" in a plain literal is an invalid escape (SyntaxWarning on 3.12+).
    count = int(re.search(r"\d+$", case)[0])
    config = {}
    config["train"] = [pair(
        src=["base.train.src"] + ["base.train.src.{}".format(i) for i in range(1, count + 1)],
        tgt=["train.tgt"] * (count + 1), const=None
    )]
    config["valid"] = [pair(src=["base.dev.src"], tgt=["dev.tgt"], const=None)]
    config["test"] = [pair(src=["base.test.src"], tgt=["test.tgt"], const=["test.tgt.const"])]
    config["test"] += [pair(src=["base.dev.src.{}".format(i)], tgt=["dev.tgt"], const=["dev.tgt.const"]) for i in range(1, 11)]
    settings[case] = config

# Base setting with all subword tokens shuffled (the shuffle happens when the
# dataset is assembled). Copy rather than alias base0's dict, so that editing
# one setting cannot silently change the other.
settings["sbase"] = {split: list(pairs) for split, pairs in settings["base0"].items()}

# bnp/part: all 3x3 (p, d) variants in {0, 0.5, 1}^2 are concatenated for
# train/valid, and each variant is a separate test set.
indices = list(product([0, 0.5, 1], [0, 0.5, 1]))
for case in ["bnp", "part"]:
    config = {}
    config["train"] = [pair(src=["{}.{}-{}.train.src".format(case, p, d) for p, d in indices],
                           tgt=["train.tgt" for _ in indices], const=None)]
    config["valid"] = [pair(src=["{}.{}-{}.dev.src".format(case, p, d) for p, d in indices],
                           tgt=["dev.tgt" for _ in indices], const=None)]
    config["test"] = [pair(src=["{}.{}-{}.dev.src".format(case, p, d)],
                           tgt=["dev.tgt"], const=["bnp.dev.tgt.const" if case == "bnp" else "dev.tgt.const"])
                      for p, d in indices]
    settings[case] = config

In [11]:
# Assemble dataset for RAND (from scratch) training.
#
# For every experiment in `settings`:
# - concatenate the SentencePiece-tokenized files of each Pair into
#   <rand>/<case>/raw/<prefix>.{src,tgt};
# - for splits with constraints, write <rand>/<case>/<prefix>.mix with
#   "<src line>\t<constraint line>" rows;
# - binarize everything with fairseq-preprocess.
# Subprocesses use argument lists (no shell quoting issues) and check=True, so a
# failing `cat` or fairseq-preprocess stops the cell instead of leaving bad data.
rand_shuffle_rng = random.Random(0)  # fixed seed: the "sbase" token shuffle is reproducible

for case, setting in settings.items():
    case_out = rand/case
    case_raw = case_out/"raw"
    os.makedirs(case_raw, exist_ok=True)
    prefixes = defaultdict(list)
    print(case)
    for split, pairs in setting.items():
        # `entry` (not `pair`) so the Pair factory from the settings cell is not shadowed.
        for i, entry in enumerate(pairs):
            # First pair of a split is "train"/"valid"/"test"; later ones are "test1", "test2", ...
            prefix = "{}{}".format(split, str(i) if i else "")
            prefixes[split].append(str(case_raw/prefix))

            src_out = case_raw/"{}.src".format(prefix)
            with open(src_out, "w") as fout:
                subprocess.run(["cat"] + [str(rand/f) for f in entry.src], stdout=fout, check=True)
            if case == "sbase" and split != "test":
                # sbase: shuffle the subword tokens of every train/valid source line.
                with open(src_out, "r") as fin:
                    src_lines = fin.readlines()
                with open(src_out, "w") as fout:
                    for line in src_lines:
                        tokens = line.strip().split()
                        rand_shuffle_rng.shuffle(tokens)
                        fout.write(" ".join(tokens) + "\n")

            tgt_out = case_raw/"{}.tgt".format(prefix)
            with open(tgt_out, "w") as fout:
                subprocess.run(["cat"] + [str(rand/f) for f in entry.tgt], stdout=fout, check=True)

            if entry.const is not None:
                const_out = case_raw/"{}.const".format(prefix)
                with open(const_out, "w") as fout:
                    subprocess.run(["cat"] + [str(rand/f) for f in entry.const], stdout=fout, check=True)
                with open(src_out, "r") as sin, open(const_out, "r") as cin:
                    src_lines = sin.readlines()
                    const_lines = cin.readlines()
                # zip() would silently truncate; a length mismatch means misaligned constraints.
                if len(src_lines) != len(const_lines):
                    raise ValueError("{}/{}: {} source lines but {} constraint lines".format(
                        case, prefix, len(src_lines), len(const_lines)))
                with open(case_out/"{}.mix".format(prefix), "w") as fout:
                    for sline, cline in zip(src_lines, const_lines):
                        fout.write("{}\t{}\n".format(sline.strip(), cline.strip()))
                os.remove(const_out)

    subprocess.run([
        "fairseq-preprocess", "--source-lang", "src", "--target-lang", "tgt",
        "--trainpref", ",".join(prefixes["train"]),
        "--validpref", ",".join(prefixes["valid"]),
        "--testpref", ",".join(prefixes["test"]),
        "--destdir", str(case_out), "--workers", "20", "--joined-dictionary",
    ], check=True)

base
brac
ubase
pos
udep
ldep
full
base0
base2
base4
base6
base8
sbase
bnp
part


In [18]:
# Assemble dataset for BART finetuning.
#
# Same layout as the RAND assembly, but built from the GPT-2 BPE files in `bart`.
# The .mix files pair each *raw* (untokenized) source line from `raw` with its
# raw-text constraints, because decoding uses fairseq interactive mode, which
# applies BPE itself. Binarization uses bart/dict.txt for both source and target.
# Subprocesses use argument lists (no shell quoting issues) and check=True, so a
# failing `cat` or fairseq-preprocess stops the cell instead of leaving bad data.
bart_shuffle_rng = random.Random(0)  # fixed seed: the "sbase" token shuffle is reproducible

for case, setting in settings.items():
    case_out = bart/case
    case_raw = case_out/"raw"
    os.makedirs(case_raw, exist_ok=True)
    prefixes = defaultdict(list)
    print(case)
    for split, pairs in setting.items():
        # `entry` (not `pair`) so the Pair factory from the settings cell is not shadowed.
        for i, entry in enumerate(pairs):
            # First pair of a split is "train"/"valid"/"test"; later ones are "test1", "test2", ...
            prefix = "{}{}".format(split, str(i) if i else "")
            prefixes[split].append(str(case_raw/prefix))

            src_out = case_raw/"{}.src".format(prefix)
            with open(src_out, "w") as fout:
                subprocess.run(["cat"] + [str(bart/f) for f in entry.src], stdout=fout, check=True)
            if case == "sbase" and split != "test":
                # sbase: shuffle the BPE tokens of every train/valid source line.
                with open(src_out, "r") as fin:
                    src_lines = fin.readlines()
                with open(src_out, "w") as fout:
                    for line in src_lines:
                        tokens = line.strip().split()
                        bart_shuffle_rng.shuffle(tokens)
                        fout.write(" ".join(tokens) + "\n")

            tgt_out = case_raw/"{}.tgt".format(prefix)
            with open(tgt_out, "w") as fout:
                subprocess.run(["cat"] + [str(bart/f) for f in entry.tgt], stdout=fout, check=True)

            if entry.const is not None:
                # Since the BPE encoder is applied in interactive mode, use raw inputs.
                const_out = case_raw/"{}.const".format(prefix)
                with open(const_out, "w") as fout:
                    subprocess.run(["cat"] + [str(bart/f) for f in entry.const], stdout=fout, check=True)

                raw_src_out = case_raw/"{}.rawsrc".format(prefix)
                with open(raw_src_out, "w") as fout:
                    subprocess.run(["cat"] + [str(raw/f) for f in entry.src], stdout=fout, check=True)

                with open(raw_src_out, "r") as sin, open(const_out, "r") as cin:
                    raw_src_lines = sin.readlines()
                    const_lines = cin.readlines()
                # zip() would silently truncate; a length mismatch means misaligned constraints.
                if len(raw_src_lines) != len(const_lines):
                    raise ValueError("{}/{}: {} raw source lines but {} constraint lines".format(
                        case, prefix, len(raw_src_lines), len(const_lines)))
                with open(case_out/"{}.mix".format(prefix), "w") as fout:
                    for sline, cline in zip(raw_src_lines, const_lines):
                        fout.write("{}\t{}\n".format(sline.strip(), cline.strip()))
                os.remove(const_out)
                os.remove(raw_src_out)

    subprocess.run([
        "fairseq-preprocess", "--source-lang", "src", "--target-lang", "tgt",
        "--trainpref", ",".join(prefixes["train"]),
        "--validpref", ",".join(prefixes["valid"]),
        "--testpref", ",".join(prefixes["test"]),
        "--destdir", str(case_out), "--workers", "20",
        "--srcdict", str(bart/"dict.txt"), "--tgtdict", str(bart/"dict.txt"),
    ], check=True)

base
brac
ubase
pos
udep
ldep
all
base0
base2
base4
base6
base8
sbase
bnp
part
