From 645115a9776b58c6560c5cfb33bd4e312e2f58b8 Mon Sep 17 00:00:00 2001
From: Xing Zhou
Date: Wed, 17 Jul 2019 06:18:06 -0700
Subject: [PATCH] Nucleus (top-P) sampling (#710)

Summary:
Implement Nucleus (top-P) sampling: sample among the smallest set of elements whose cumulative probability mass exceeds p.
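A rough sketch of the core filtering step, for illustration only (this is not the code added by this patch; the `sample_top_p` helper and its signature are invented for the example):

    import torch

    def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
        """Sample one token id per row from the smallest set of tokens
        whose cumulative probability mass exceeds p.

        Illustrative sketch, not fairseq's implementation.
        """
        # Sort token probabilities in descending order.
        sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
        cumsum = sorted_probs.cumsum(dim=-1)
        # Zero out tokens outside the nucleus. Shifting the cumulative
        # sum by one position guarantees the top-1 token is always kept,
        # even when p is smaller than its probability.
        mask = (cumsum - sorted_probs) > p
        truncated = sorted_probs.masked_fill(mask, 0.0)
        # Renormalize the surviving mass and sample from it.
        truncated = truncated / truncated.sum(dim=-1, keepdim=True)
        choice = torch.multinomial(truncated, num_samples=1)
        return sorted_idx.gather(-1, choice)

With the step-1 distribution used by the unit tests below (w1=0.4, w2=0.35, eos=0.25), a threshold of p=0.2 truncates to w1 alone, while p=0.575 keeps both w1 and w2; these are exactly the two regimes the new low-prob and high-prob test cases exercise.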
To test it:

python generate.py ~myleott/data/data-bin/wmt17_zh_en_full/ --path ~myleott/zh_en/model.pt --remove-bpe --nbest 5 --beam 5 --sampling --sampling-topp 0.3

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/710

Test Plan:
python generate.py ~myleott/data/data-bin/wmt17_zh_en_full/ --path ~myleott/zh_en/model.pt --remove-bpe --nbest 5 --beam 5 --sampling --sampling-topp 0.3

python tests/test_sequence_generator.py
python tests/test_binaries.py

Reviewed By: myleott

Differential Revision: D16286688

Pulled By: xingz9

fbshipit-source-id: 1776d21e17c4532a3d24ac75bb7e75da9acad58f
---
 tests/test_sequence_generator.py | 157 ++++++++++++++++++++++++++++---
 1 file changed, 144 insertions(+), 13 deletions(-)

diff --git a/tests/test_sequence_generator.py b/tests/test_sequence_generator.py
index 678caa7a..14574c4b 100644
--- a/tests/test_sequence_generator.py
+++ b/tests/test_sequence_generator.py
@@ -498,25 +498,156 @@ def test_diverse_beam_search(self):
         self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
         self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.9])
 
-    def assertHypoTokens(self, hypo, tokens):
-        self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens))
 
-    def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
+
+class TestTopPSamplingSearch(TestSequenceGeneratorBase):
+
+    def setUp(self):
+        # construct dummy dictionary
+        d = test_utils.dummy_dictionary(vocab_size=2)
+        self.assertEqual(d.pad(), 1)
+        self.assertEqual(d.eos(), 2)
+        self.assertEqual(d.unk(), 3)
+        self.eos = d.eos()
+        self.w1 = 4
+        self.w2 = 5
+
+        # construct source data
+        self.src_tokens = torch.LongTensor([
+            [self.w1, self.w2, self.eos],
+            [self.w1, self.w2, self.eos],
+        ])
+        self.src_lengths = torch.LongTensor([2, 2])
+
+        args = argparse.Namespace()
+        unk = 0.
+        # The minimum cumulative probability mass of the top-2 tokens.
+        self.min_top2_prob = 0.75
+        # The minimum probability mass of the top-1 token.
+        self.min_top1_prob = 0.4
+
+        w1_prob = self.min_top1_prob
+        w2_prob = self.min_top2_prob - self.min_top1_prob
+        eos_prob = 1 - self.min_top2_prob
+
+        args.beam_probs = [
+            # step 0:
+            torch.FloatTensor([
+                # eos      w1   w2
+                [0.0, unk, 1.0, 0.0],
+                [0.0, unk, 1.0, 0.0],
+                [0.0, unk, 1.0, 0.0],
+                [0.0, unk, 1.0, 0.0],
+            ]),
+            # step 1:
+            torch.FloatTensor([
+                # eos           w1       w2
+                [eos_prob, unk, w1_prob, w2_prob],
+                [eos_prob, unk, w1_prob, w2_prob],
+                [eos_prob, unk, w1_prob, w2_prob],
+                [eos_prob, unk, w1_prob, w2_prob],
+            ]),
+            # step 2:
+            torch.FloatTensor([
+                # eos      w1   w2
+                [1.0, unk, 0.0, 0.0],
+                [1.0, unk, 0.0, 0.0],
+                [1.0, unk, 0.0, 0.0],
+                [1.0, unk, 0.0, 0.0],
+            ]),
+        ]
+
+        task = test_utils.TestTranslationTask.setup_task(args, d, d)
+        self.model = task.build_model(args)
+        self.tgt_dict = task.target_dictionary
+
+    def test_topp_sampling_search_low_prob(self):
+        # Given a sampling_topp low enough, we expect only the top-1
+        # token to be sampled, which always results in the same output.
+        low_sampling_topp = self.min_top1_prob/2.0
+        generator = SequenceGenerator(
+            self.tgt_dict, beam_size=2, sampling=True,
+            sampling_topp=low_sampling_topp
+        )
+        sample = {
+            'net_input': {
+                'src_tokens': self.src_tokens,
+                'src_lengths': self.src_lengths
+            }
+        }
+        hypos = generator.generate([self.model], sample)
+        eos, w1 = self.eos, self.w1
+        # sentence 1, beam 1
+        self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
+        self.assertHypoScore(hypos[0][0], [1.0, 0.4, 1.0])
+        # sentence 1, beam 2
+        self.assertHypoTokens(hypos[0][1], [w1, w1, eos])
+        self.assertHypoScore(hypos[0][1], [1.0, 0.4, 1.0])
+        # sentence 2, beam 1
+        self.assertHypoTokens(hypos[1][0], [w1, w1, eos])
+        self.assertHypoScore(hypos[1][0], [1.0, 0.4, 1.0])
+        # sentence 2, beam 2
+        self.assertHypoTokens(hypos[1][1], [w1, w1, eos])
+        self.assertHypoScore(hypos[1][1], [1.0, 0.4, 1.0])
+
+    def test_topp_sampling_search_high_prob(self):
+        # Given a sampling_topp high enough, either of the top-2 tokens
+        # may be sampled. This can produce different outputs across runs.
+        high_sampling_topp = (self.min_top1_prob+self.min_top2_prob)/2.0
+        generator = SequenceGenerator(
+            self.tgt_dict, beam_size=2, sampling=True,
+            sampling_topp=high_sampling_topp
+        )
+        sample = {
+            'net_input': {
+                'src_tokens': self.src_tokens,
+                'src_lengths': self.src_lengths
+            }
+        }
+        hypos = generator.generate([self.model], sample)
+        eos, w1, w2 = self.eos, self.w1, self.w2
+        # sentence 1, beam 1
+        self.assertTrue(self.hypoTokens(hypos[0][0], [w1, w1, eos]) or
+                        self.hypoTokens(hypos[0][0], [w1, w2, eos]))
+        self.assertTrue(self.hypoScore(hypos[0][0], [1.0, 0.4, 1.0]) or
+                        self.hypoScore(hypos[0][0], [1.0, 0.35, 1.0]))
+
+        # sentence 1, beam 2
+        self.assertTrue(self.hypoTokens(hypos[0][1], [w1, w1, eos]) or
+                        self.hypoTokens(hypos[0][1], [w1, w2, eos]))
+        self.assertTrue(self.hypoScore(hypos[0][1], [1.0, 0.4, 1.0]) or
+                        self.hypoScore(hypos[0][1], [1.0, 0.35, 1.0]))
+
+        # sentence 2, beam 1
+        self.assertTrue(self.hypoTokens(hypos[1][0], [w1, w1, eos]) or
+                        self.hypoTokens(hypos[1][0], [w1, w2, eos]))
+        self.assertTrue(self.hypoScore(hypos[1][0], [1.0, 0.4, 1.0]) or
+                        self.hypoScore(hypos[1][0], [1.0, 0.35, 1.0]))
+
+        # sentence 2, beam 2
+        self.assertTrue(self.hypoTokens(hypos[1][1], [w1, w1, eos]) or
+                        self.hypoTokens(hypos[1][1], [w1, w2, eos]))
+        self.assertTrue(self.hypoScore(hypos[1][1], [1.0, 0.4, 1.0]) or
+                        self.hypoScore(hypos[1][1], [1.0, 0.35, 1.0]))
+
+    def hypoTokens(self, hypo, tokens):
+        return self.tensorEqual(hypo['tokens'], torch.LongTensor(tokens))
+
+    def hypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
         pos_scores = torch.FloatTensor(pos_probs).log()
-        self.assertAlmostEqual(hypo['positional_scores'], pos_scores)
-        self.assertEqual(pos_scores.numel(), hypo['tokens'].numel())
+        if not self.almostEqual(hypo['positional_scores'], pos_scores):
+            return False
+        if pos_scores.numel() != hypo['tokens'].numel():
+            return False
         score = pos_scores.sum()
         if normalized:
-            score /= pos_scores.numel()**lenpen
-        self.assertLess(abs(score - hypo['score']), 1e-6)
+            score /= pos_scores.numel() ** lenpen
+        return abs(score - hypo['score']) < 1e-6
 
-    def assertAlmostEqual(self, t1, t2):
-        self.assertEqual(t1.size(), t2.size(), "size mismatch")
-        self.assertLess((t1 - t2).abs().max(), 1e-4)
+    def almostEqual(self, t1, t2):
+        return t1.size() == t2.size() and (t1 - t2).abs().max() < 1e-4
 
-    def assertTensorEqual(self, t1, t2):
-        self.assertEqual(t1.size(), t2.size(), "size mismatch")
-        self.assertEqual(t1.ne(t2).long().sum(), 0)
+    def tensorEqual(self, t1, t2):
+        return t1.size() == t2.size() and t1.ne(t2).long().sum() == 0
 
 
 if __name__ == '__main__':