
# Multi-problem training

Multi-problem training is possible by defining `MultiProblem` subclasses that specify a list of `Problem` objects to include in training. In some cases, multi-problem training can improve performance compared to training on the individual problems alone.

In the following sections we'll discuss `MultiProblem` first from a usage perspective and then from the perspective of someone wishing to build on it.

Please note that the T2T Walkthrough documentation is a good place to start for understanding the component concepts we'll build on here.

## Usage

### Problem definition and datagen

In this discussion we'll consider the following (large) multi-problem, which combines ten sub-problems:

1. A language modeling problem over a corpus of German, English, French, and Romanian Wikipedia articles.
2. Six compatible pairwise translation problems (En -> De, En -> Fr, En -> Ro, De -> En, Fr -> En, Ro -> En).
3. A compatible version of the combined CNN/Daily Mail news article summarization problem.
4. A compatible version of the MultiNLI textual entailment classification problem.
5. A compatible version of the SQuAD question answering problem.

The corresponding `MultiProblem` subclass is defined as follows:
```python
@registry.register_problem
class LanguagemodelMultiWikiTranslate(multi_problem.MultiProblem):
  """Wiki multi-lingual LM and multiple translations."""

  def __init__(self, was_reversed=False, was_copy=False):
    super(LanguagemodelMultiWikiTranslate, self).__init__(
        was_reversed, was_copy)
    self.task_list.append(wiki_lm.LanguagemodelDeEnFrRoWiki64k())
    self.task_list.append(translate_ende.TranslateEndeWmtMulti64k())
    self.task_list.append(translate_enfr.TranslateEnfrWmtMulti64k())
    self.task_list.append(translate_enro.TranslateEnroWmtMultiTiny64k())
    self.task_list.append(translate_ende.TranslateEndeWmtMulti64k(
        was_reversed=True))
    self.task_list.append(translate_enfr.TranslateEnfrWmtMulti64k(
        was_reversed=True))
    self.task_list.append(translate_enro.TranslateEnroWmtMultiTiny64k(
        was_reversed=True))
    self.task_list.append(
        cnn_dailymail.SummarizeCnnDailymailWikiLMMultiVocab64k())
    self.task_list.append(multinli.MultiNLIWikiLMMultiVocab64k())
    self.task_list.append(squad.SquadConcatMulti64k())

  @property
  def vocab_type(self):
    return text_problems.VocabType.SUBWORD
```

The word "compatible" was used a lot above! That's because each of these problems have been modified to use the vocabulary produced by the Wikipedia-based language modeling problem, e.g. the following

```python
@registry.register_problem
class SummarizeCnnDailymailWikiLMMultiVocab64k(SummarizeCnnDailymail32k):
  """Summarize CNN and Daily Mail articles using multi-lingual 64k vocab."""

  @property
  def vocab_filename(self):
    return wiki_lm.LanguagemodelDeEnFrRoWiki64k().vocab_filename
```

**Important note:** It's easy to miss the key point that, as currently implemented, the first task in the task list must be a language modeling problem and every other included task must be modified to use the resulting vocabulary.

With a properly defined and registered multi-problem we can now run datagen as follows:

```bash
t2t-datagen --problem=languagemodel_multi_wiki_translate
```
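If you'd rather not rely on the `/tmp` defaults, datagen also accepts explicit directories via `--data_dir` and `--tmp_dir`; the paths below are just placeholders (this walkthrough keeps the defaults, which is why the disk usage shown next is under `/tmp`):

```bash
t2t-datagen \
    --problem=languagemodel_multi_wiki_translate \
    --data_dir=$HOME/t2t_data \
    --tmp_dir=$HOME/t2t_tmp
```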

Datagen will take several hours and approximately the following amount of disk space:

```
(t2t) username@instance-2:~$ du -sh /tmp
99G     /tmp
(t2t) username@instance-2:~$ du -sh /tmp/t2t_datagen
81G     /tmp/t2t_datagen
```

## Training

Next we're ready to train a model on this `MultiProblem`. Note that because we didn't specify `--data_dir` above, TFExamples were generated into `/tmp` by default, so that's the directory we explicitly provide here.

```bash
t2t-trainer --problem=languagemodel_multi_wiki_translate \
    --model=transformer \
    --hparams_set=transformer_tall_pretrain_lm_tpu_adafactor_large \
    --output_dir ~/t2t_train/transformer_multi_2jan19 \
    --data_dir=/tmp \
    --train_steps=1 \
    --eval_steps=1
```

The `hparams_set` we passed above was `transformer_tall_pretrain_lm_tpu_adafactor_large`, shown below:

```python
@registry.register_hparams
def transformer_tall_pretrain_lm_tpu_adafactor_large():
  """Hparams for transformer on LM pretraining on TPU, large model."""
  hparams = transformer_tall_pretrain_lm_tpu_adafactor()
  hparams.hidden_size = 1024
  hparams.num_heads = 16
  hparams.filter_size = 32768  # max fitting in 16G memory is 49152, batch 2
  hparams.batch_size = 4
  hparams.multiproblem_mixing_schedule = "constant"
  # Task order: lm/en-de/en-fr/en-ro/de-en/fr-en/ro-en/cnndm/mnli/squad.
  hparams.multiproblem_per_task_threshold = "320,80,160,2,80,160,2,20,5,5"
  return hparams
```

Here it's worth noting a couple of things. First, we have specified `multiproblem_mixing_schedule` (which is required), consumed by `MultiProblem.mix_data`. When set to `"constant"`, the sampling strategy is not a function of the training step and is proportional only to the per-task "thresholds", which are equal by default (i.e., examples are sampled from each problem with equal probability).

Second, notice we have also specified the optional `multiproblem_per_task_threshold` parameter, also consumed by `mix_data` and specifically used by `sample_task`; it defines non-uniform thresholds for weighted random sampling. For example, given two problems with thresholds 1 and 9, the first would be sampled 1/10 of the time and the second 9/10 of the time.
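For intuition, here's a minimal sketch of how such thresholds translate into sampling probabilities; this is an illustration only, not the actual `mix_data`/`sample_task` implementation:

```python
import random

# Illustration: per-task thresholds act as unnormalized sampling weights
# under the "constant" mixing schedule.
thresholds = [320, 80, 160, 2, 80, 160, 2, 20, 5, 5]
total = sum(thresholds)
print([round(t / total, 3) for t in thresholds])  # sampling probability per task


def sample_task_index(rng=random):
  """Weighted random choice over task indices, proportional to thresholds."""
  r = rng.uniform(0, total)
  cumulative = 0.0
  for idx, t in enumerate(thresholds):
    cumulative += t
    if r < cumulative:
      return idx
  return len(thresholds) - 1
```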

## Inference

You can try translating from English to German using a model previously trained on `LanguagemodelMultiWikiTranslate` (the one shown above), available at `gs://tensor2tensor-checkpoints/transformer_multi_2jan19/`. Just copy the checkpoint down to a local directory such as the one given via `--output_dir` below:

```bash
t2t-decoder --problem=languagemodel_multi_wiki_translate \
    --model=transformer \
    --hparams_set=transformer_tall_pretrain_lm_tpu_adafactor_large \
    --decode_hparams='batch_size=1,multiproblem_task_id=64510' \
    --hparams="" \
    --output_dir=~/t2t_train/transformer_multi_2jan19 \
    --decode_from_file ~/newstest2014.en \
    --data_dir=~/t2t_train/transformer_multi_2jan19
```

Here we point `--data_dir` at the checkpoint directory, which includes the vocab file `vocab.languagemodel_de_en_fr_ro_wiki64k.64000.subwords`; typically `--data_dir` would point to the directory containing your TFRecord example dataset(s).

The file passed to `--decode_from_file` is simply a plain-text file with one sentence to translate per line (in its original form, not vocabulary-encoded).
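For example, a toy input file could look like the following (the file name and second sentence are hypothetical; the first sentence matches the sample output shown later). You would then pass its path to `--decode_from_file`:

```bash
cat > ~/toy_input.en <<EOF
hello world was the news of the day
this is a second sentence to translate
EOF
```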

A key requirement for multi-problem inference is that we specify the ID of the task for which we want to perform inference (via `multiproblem_task_id` above). But wait, why is the task ID 64510? We can see from the code for `MultiProblem.update_task_ids` that task IDs are assigned slots just past the end of the vocabulary: each task's ID is its index in the task list plus the encoder vocabulary size, so an ID of 64510 for the En -> De task (index 1) implies an encoder vocabulary of 64,509 subwords.

```python
class MultiProblem(problem.Problem):
  """MultiProblem base class."""

  ...

  def update_task_ids(self, encoder_vocab_size):
    """Generate task_ids for each problem.
    These ids correspond to the index of the task in the task_list.
    Args:
      encoder_vocab_size: the size of the vocab which is used to compute
        the index offset.
    """
    for idx, task in enumerate(self.task_list):
      task.set_task_id(idx + encoder_vocab_size)
      tf.logging.info("Task %d (%s) has id %d." %
                      (idx, task.name, task.task_id))
```

We can look up the `task_id` assigned to each task we may want to use for inference by instantiating the `MultiProblem` subclass and reading the value, in this case via the following:

```python
task_index = 1  # The second task in the list is En -> De.
LanguagemodelMultiWikiTranslate().task_list[task_index].task_id
```

For me, running the `t2t-decoder` command provided above gave the following output:

```
...

INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference results INPUT: hello world was the news of the day
INFO:tensorflow:Inference results OUTPUT: Hallo Welt war die Nachricht des Tages
INFO:tensorflow:Elapsed Time: 37.15079
INFO:tensorflow:Averaged Single Token Generation Time: 3.3009222 (time 36.3101439 count 11)

...
```