Multi-problem training is enabled by defining MultiProblem subclasses that specify a list of Problem objects to include in training. In some cases, multi-problem training can improve performance compared to training on the individual problems in isolation.
In the following sections we'll discuss MultiProblem first from a usage perspective, and then from the perspective of someone wishing to build upon it.
Please note that the T2T Walkthrough documentation is a good place to start for the component concepts we build on here.
In this discussion we'll consider the following (large) multi-problem, which combines ten different sub-problems:
- A language modeling problem operating on a corpus of German, English, French, and Romanian Wikipedia articles.
- Multiple compatible pairwise translation problems (En -> De, En -> Fr, En -> Ro, De -> En, Fr -> En, Ro -> En).
- A compatible version of the combined CNN/DailyMail news article summarization problem.
- A compatible version of the MultiNLI textual entailment classification problem.
- A compatible version of the SQuAD question/answer problem.
```python
@registry.register_problem
class LanguagemodelMultiWikiTranslate(multi_problem.MultiProblem):
  """Wiki multi-lingual LM and multiple translations."""

  def __init__(self, was_reversed=False, was_copy=False):
    super(LanguagemodelMultiWikiTranslate, self).__init__(
        was_reversed, was_copy)
    self.task_list.append(wiki_lm.LanguagemodelDeEnFrRoWiki64k())
    self.task_list.append(translate_ende.TranslateEndeWmtMulti64k())
    self.task_list.append(translate_enfr.TranslateEnfrWmtMulti64k())
    self.task_list.append(translate_enro.TranslateEnroWmtMultiTiny64k())
    self.task_list.append(translate_ende.TranslateEndeWmtMulti64k(
        was_reversed=True))
    self.task_list.append(translate_enfr.TranslateEnfrWmtMulti64k(
        was_reversed=True))
    self.task_list.append(translate_enro.TranslateEnroWmtMultiTiny64k(
        was_reversed=True))
    self.task_list.append(
        cnn_dailymail.SummarizeCnnDailymailWikiLMMultiVocab64k())
    self.task_list.append(multinli.MultiNLIWikiLMMultiVocab64k())
    self.task_list.append(squad.SquadConcatMulti64k())

  @property
  def vocab_type(self):
    return text_problems.VocabType.SUBWORD
```
The word "compatible" was used a lot above! That's because each of these problems have been modified to use the vocabulary produced by the Wikipedia-based language modeling problem, e.g. the following
```python
@registry.register_problem
class SummarizeCnnDailymailWikiLMMultiVocab64k(SummarizeCnnDailymail32k):
  """Summarize CNN and Daily Mail articles using multi-lingual 64k vocab."""

  @property
  def vocab_filename(self):
    return wiki_lm.LanguagemodelDeEnFrRoWiki64k().vocab_filename
```
Important note: It's easy to miss the key point that, as currently implemented, the first task in the task list must be a language modeling problem, and every other included task must be modified to use the vocabulary that language modeling problem produces.
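To make that constraint concrete, here is a hedged sketch of a minimal two-task multi-problem following the same pattern as the class above: the vocabulary-defining language model first, followed by a single compatible translation task. The class name LanguagemodelWikiTranslateEndeMinimal is hypothetical; the task classes it references are the ones already used above.

```python
@registry.register_problem
class LanguagemodelWikiTranslateEndeMinimal(multi_problem.MultiProblem):
  """Hypothetical minimal multi-problem: shared-vocab LM plus En -> De."""

  def __init__(self, was_reversed=False, was_copy=False):
    super(LanguagemodelWikiTranslateEndeMinimal, self).__init__(
        was_reversed, was_copy)
    # The vocabulary-defining language modeling task must come first.
    self.task_list.append(wiki_lm.LanguagemodelDeEnFrRoWiki64k())
    # Every other task must use the vocabulary produced by that LM task.
    self.task_list.append(translate_ende.TranslateEndeWmtMulti64k())

  @property
  def vocab_type(self):
    return text_problems.VocabType.SUBWORD
```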
With a properly defined and registered multi-problem we can now run datagen as follows:
```bash
t2t-datagen --problem=languagemodel_multi_wiki_translate
```
This will take roughly the following amount of disk space (and several hours to run):
```
(t2t) username@instance-2:~$ du -sh /tmp
99G     /tmp
(t2t) username@instance-2:~$ du -sh /tmp/t2t_datagen
81G     /tmp/t2t_datagen
```
Next we're ready to try training a model on this MultiProblem. Note that because we did not specify --data_dir above, the TFExamples were generated into /tmp by default, so that's what we'll explicitly provide as the data directory here.
```bash
t2t-trainer --problem=languagemodel_multi_wiki_translate \
  --model=transformer \
  --hparams_set=transformer_tall_pretrain_lm_tpu_adafactor_large \
  --output_dir=~/t2t_train/transformer_multi_2jan19 \
  --data_dir=/tmp \
  --train_steps=1 \
  --eval_steps=1
```
The hparams_set parameter we provided above was transformer_tall_pretrain_lm_tpu_adafactor_large, shown below:
```python
@registry.register_hparams
def transformer_tall_pretrain_lm_tpu_adafactor_large():
  """Hparams for transformer on LM pretraining on TPU, large model."""
  hparams = transformer_tall_pretrain_lm_tpu_adafactor()
  hparams.hidden_size = 1024
  hparams.num_heads = 16
  hparams.filter_size = 32768  # max fitting in 16G memory is 49152, batch 2
  hparams.batch_size = 4
  hparams.multiproblem_mixing_schedule = "constant"
  # Task order: lm/en-de/en-fr/en-ro/de-en/fr-en/ro-en/cnndm/mnli/squad.
  hparams.multiproblem_per_task_threshold = "320,80,160,2,80,160,2,20,5,5"
  return hparams
```
Here it's worth noting a couple of things. First, we have specified a multiproblem_mixing_schedule (which is required), consumed by MultiProblem.mix_data. When set to "constant", the sampling strategy is not a function of the training step; examples are sampled in proportion to the per-task "thresholds", which are equal by default (i.e. examples are drawn from each problem with equal probability).
Second, we have also specified the optional multiproblem_per_task_threshold parameter, also consumed by mix_data and used specifically by sample_task, which defines non-uniform thresholds for a weighted random sampling. E.g. for two problems with thresholds 1 and 9, the first would be sampled 1/10 of the time and the second 9/10 of the time. A simplified sketch of this sampling is shown below.
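To make the threshold-weighted sampling concrete, here is a simplified sketch of the idea (an illustration, not the actual mix_data/sample_task implementation; sample_task_index is a hypothetical helper):

```python
import random


def sample_task_index(thresholds):
  """Sample a task index with probability proportional to its threshold.

  This is a simplified sketch of threshold-weighted sampling, assumed to
  mirror the intent of multiproblem_per_task_threshold rather than the
  exact T2T code.
  """
  total = sum(thresholds)
  pick = random.uniform(0, total)
  cumulative = 0.0
  for index, threshold in enumerate(thresholds):
    cumulative += threshold
    if pick <= cumulative:
      return index
  return len(thresholds) - 1


# Task order: lm/en-de/en-fr/en-ro/de-en/fr-en/ro-en/cnndm/mnli/squad.
thresholds = [320, 80, 160, 2, 80, 160, 2, 20, 5, 5]
# e.g. the LM task is sampled 320 / sum(thresholds) ~= 38% of the time.
print(sample_task_index(thresholds))
```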
You can try translating from English to German using a model previously trained on LanguagemodelMultiWikiTranslate (the problem shown above); a checkpoint is available at gs://tensor2tensor-checkpoints/transformer_multi_2jan19/. Just copy the checkpoint down to a local directory such as the one given via --output_dir below:
```bash
t2t-decoder --problem=languagemodel_multi_wiki_translate \
  --model=transformer \
  --hparams_set=transformer_tall_pretrain_lm_tpu_adafactor_large \
  --decode_hparams='batch_size=1,multiproblem_task_id=64510' \
  --hparams="" \
  --output_dir=~/t2t_train/transformer_multi_2jan19 \
  --decode_from_file=~/newstest2014.en \
  --data_dir=~/t2t_train/transformer_multi_2jan19
```
Here we point --data_dir to the checkpoint directory because it contains the vocab file vocab.languagemodel_de_en_fr_ro_wiki64k.64000.subwords; typically data_dir would point to the directory containing your TFRecord example dataset(s).
The file passed to --decode_from_file is simply a plain-text file with one sentence to translate on each line (in its original form, not vocabulary-encoded), e.g. a file containing the single line "hello world was the news of the day" used in the example output below.
A key requirement for multi-problem inference is that we specify the ID of the task for which we want to perform inference. But wait, why is the task ID 64510? We can see from the code for MultiProblem.update_task_ids that task IDs are assigned positions just past the end of the encoder vocabulary: each task's ID is its index in the task list plus the encoder vocabulary size. The En -> De task sits at index 1, so an ID of 64510 implies an encoder vocabulary of 64509 subwords here.
```python
class MultiProblem(problem.Problem):
  """MultiProblem base class."""

  ...

  def update_task_ids(self, encoder_vocab_size):
    """Generate task_ids for each problem.

    These ids correspond to the index of the task in the task_list.

    Args:
      encoder_vocab_size: the size of the vocab which is used to compute
        the index offset.
    """
    for idx, task in enumerate(self.task_list):
      task.set_task_id(idx + encoder_vocab_size)
      tf.logging.info("Task %d (%s) has id %d." %
                      (idx, task.name, task.task_id))
```
We can look up the task_id assigned to each task we may want to use for inference by instantiating the MultiProblem subclass and reading the value, in this case via the following:
```python
task_index = 1  # The second task in the list is En -> De.
LanguagemodelMultiWikiTranslate().task_list[task_index].task_id
```
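To see every task's ID at once, the lookup above extends naturally to a loop (an illustrative sketch; like the snippet above, it assumes the task IDs have already been assigned via update_task_ids for this instance):

```python
# Illustrative sketch: list each sub-problem and its assigned task id so you
# can pick the right multiproblem_task_id for decoding. Like the snippet
# above, this assumes update_task_ids has already run for this instance.
multi_problem_instance = LanguagemodelMultiWikiTranslate()
for index, task in enumerate(multi_problem_instance.task_list):
  print(index, task.name, task.task_id)
```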
For me, running the t2t-decoder command provided above gave the following output:
```
...
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference results INPUT: hello world was the news of the day
INFO:tensorflow:Inference results OUTPUT: Hallo Welt war die Nachricht des Tages
INFO:tensorflow:Elapsed Time: 37.15079
INFO:tensorflow:Averaged Single Token Generation Time: 3.3009222 (time 36.3101439 count 11)
...
```