Nucleus (top-P) sampling (#710)
Summary:
Implement Nucleus (top-P) sampling: sample from the smallest set of tokens whose cumulative probability mass exceeds p.
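
For intuition, here is a minimal sketch of the truncation step, assuming a
1-D tensor of next-token probabilities (illustrative only, not the actual
fairseq implementation; the name top_p_filter is invented for this example):

    import torch

    def top_p_filter(probs, p):
        # Keep the smallest set of tokens whose cumulative probability
        # mass exceeds p, renormalize, and sample one token from that set.
        sorted_probs, sorted_idx = probs.sort(descending=True)
        cum_probs = sorted_probs.cumsum(dim=-1)
        # A token survives if the mass of strictly more probable tokens
        # is still below p; the most probable token always survives.
        keep = (cum_probs - sorted_probs) < p
        trimmed = sorted_probs * keep
        trimmed = trimmed / trimmed.sum()
        choice = torch.multinomial(trimmed, num_samples=1)
        return sorted_idx[choice]

    # e.g. with probs = torch.tensor([0.4, 0.35, 0.25]): p = 0.2 can only
    # ever return the top token, while p = 0.575 can return the top two.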

To test it:
python generate.py ~myleott/data/data-bin/wmt17_zh_en_full/ --path ~myleott/zh_en/model.pt --remove-bpe --nbest 5 --beam 5 --sampling --sampling-topp 0.3
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/710

Test Plan:
python generate.py ~myleott/data/data-bin/wmt17_zh_en_full/ --path ~myleott/zh_en/model.pt --remove-bpe --nbest 5 --beam 5 --sampling --sampling-topp 0.3

python tests/test_sequence_generator.py

python tests/test_binaries.py

Reviewed By: myleott

Differential Revision: D16286688

Pulled By: xingz9

fbshipit-source-id: 1776d21e17c4532a3d24ac75bb7e75da9acad58f
Xing Zhou authored and yzpang committed Feb 19, 2021
1 parent 7253686 commit 645115a
Showing 1 changed file with 144 additions and 13 deletions: tests/test_sequence_generator.py
@@ -498,25 +498,156 @@ def test_diverse_beam_search(self):
         self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
         self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.9])
 
-    def assertHypoTokens(self, hypo, tokens):
-        self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens))
 
-    def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
+
+class TestTopPSamplingSearch(TestSequenceGeneratorBase):
+
+    def setUp(self):
+        # construct dummy dictionary
+        d = test_utils.dummy_dictionary(vocab_size=2)
+        self.assertEqual(d.pad(), 1)
+        self.assertEqual(d.eos(), 2)
+        self.assertEqual(d.unk(), 3)
+        self.eos = d.eos()
+        self.w1 = 4
+        self.w2 = 5
+
+        # construct source data
+        self.src_tokens = torch.LongTensor([
+            [self.w1, self.w2, self.eos],
+            [self.w1, self.w2, self.eos],
+        ])
+        self.src_lengths = torch.LongTensor([2, 2])
+
+        args = argparse.Namespace()
+        unk = 0.
+        # the minimum cumulative probability of the top 2 tokens
+        self.min_top2_prob = 0.75
+        # the minimum probability of the top 1 token
+        self.min_top1_prob = 0.4
+
+        w1_prob = self.min_top1_prob
+        w2_prob = self.min_top2_prob - self.min_top1_prob
+        eos_prob = 1 - self.min_top2_prob
+
+        args.beam_probs = [
+            # step 0:
+            torch.FloatTensor([
+                # eos  unk  w1   w2
+                [0.0, unk, 1.0, 0.0],
+                [0.0, unk, 1.0, 0.0],
+                [0.0, unk, 1.0, 0.0],
+                [0.0, unk, 1.0, 0.0],
+            ]),
+            # step 1:
+            torch.FloatTensor([
+                # eos       unk  w1       w2
+                [eos_prob, unk, w1_prob, w2_prob],
+                [eos_prob, unk, w1_prob, w2_prob],
+                [eos_prob, unk, w1_prob, w2_prob],
+                [eos_prob, unk, w1_prob, w2_prob],
+            ]),
+            # step 2:
+            torch.FloatTensor([
+                # eos  unk  w1   w2
+                [1.0, unk, 0.0, 0.0],
+                [1.0, unk, 0.0, 0.0],
+                [1.0, unk, 0.0, 0.0],
+                [1.0, unk, 0.0, 0.0],
+            ]),
+        ]
+
+        task = test_utils.TestTranslationTask.setup_task(args, d, d)
+        self.model = task.build_model(args)
+        self.tgt_dict = task.target_dictionary
+
+    def test_topp_sampling_search_low_prob(self):
+        # With sampling_topp below the probability of the top-1 token,
+        # only that token can be sampled, so the output is always the same.
+        low_sampling_topp = self.min_top1_prob / 2.0
+        generator = SequenceGenerator(
+            self.tgt_dict, beam_size=2, sampling=True,
+            sampling_topp=low_sampling_topp,
+        )
+        sample = {
+            'net_input': {
+                'src_tokens': self.src_tokens,
+                'src_lengths': self.src_lengths,
+            }
+        }
+        hypos = generator.generate([self.model], sample)
+        eos, w1 = self.eos, self.w1
+        # sentence 1, beam 1
+        self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
+        self.assertHypoScore(hypos[0][0], [1.0, 0.4, 1.0])
+        # sentence 1, beam 2
+        self.assertHypoTokens(hypos[0][1], [w1, w1, eos])
+        self.assertHypoScore(hypos[0][1], [1.0, 0.4, 1.0])
+        # sentence 2, beam 1
+        self.assertHypoTokens(hypos[1][0], [w1, w1, eos])
+        self.assertHypoScore(hypos[1][0], [1.0, 0.4, 1.0])
+        # sentence 2, beam 2
+        self.assertHypoTokens(hypos[1][1], [w1, w1, eos])
+        self.assertHypoScore(hypos[1][1], [1.0, 0.4, 1.0])
+
+    def test_topp_sampling_search_high_prob(self):
+        # With sampling_topp high enough to admit the top 2 tokens, either
+        # of them may be sampled, so different outputs are possible.
+        high_sampling_topp = (self.min_top1_prob + self.min_top2_prob) / 2.0
+        generator = SequenceGenerator(
+            self.tgt_dict, beam_size=2, sampling=True,
+            sampling_topp=high_sampling_topp,
+        )
+        sample = {
+            'net_input': {
+                'src_tokens': self.src_tokens,
+                'src_lengths': self.src_lengths,
+            }
+        }
+        hypos = generator.generate([self.model], sample)
+        eos, w1, w2 = self.eos, self.w1, self.w2
+        # sentence 1, beam 1
+        self.assertTrue(self.hypoTokens(hypos[0][0], [w1, w1, eos]) or
+                        self.hypoTokens(hypos[0][0], [w1, w2, eos]))
+        self.assertTrue(self.hypoScore(hypos[0][0], [1.0, 0.4, 1.0]) or
+                        self.hypoScore(hypos[0][0], [1.0, 0.35, 1.0]))
+
+        # sentence 1, beam 2
+        self.assertTrue(self.hypoTokens(hypos[0][1], [w1, w1, eos]) or
+                        self.hypoTokens(hypos[0][1], [w1, w2, eos]))
+        self.assertTrue(self.hypoScore(hypos[0][1], [1.0, 0.4, 1.0]) or
+                        self.hypoScore(hypos[0][1], [1.0, 0.35, 1.0]))
+
+        # sentence 2, beam 1
+        self.assertTrue(self.hypoTokens(hypos[1][0], [w1, w1, eos]) or
+                        self.hypoTokens(hypos[1][0], [w1, w2, eos]))
+        self.assertTrue(self.hypoScore(hypos[1][0], [1.0, 0.4, 1.0]) or
+                        self.hypoScore(hypos[1][0], [1.0, 0.35, 1.0]))
+
+        # sentence 2, beam 2
+        self.assertTrue(self.hypoTokens(hypos[1][1], [w1, w1, eos]) or
+                        self.hypoTokens(hypos[1][1], [w1, w2, eos]))
+        self.assertTrue(self.hypoScore(hypos[1][1], [1.0, 0.4, 1.0]) or
+                        self.hypoScore(hypos[1][1], [1.0, 0.35, 1.0]))
+
+    def hypoTokens(self, hypo, tokens):
+        return self.tensorEqual(hypo['tokens'], torch.LongTensor(tokens))
+
+    def hypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
         pos_scores = torch.FloatTensor(pos_probs).log()
-        self.assertAlmostEqual(hypo['positional_scores'], pos_scores)
-        self.assertEqual(pos_scores.numel(), hypo['tokens'].numel())
+        if not self.almostEqual(hypo['positional_scores'], pos_scores):
+            return False
+        if pos_scores.numel() != hypo['tokens'].numel():
+            return False
         score = pos_scores.sum()
         if normalized:
-            score /= pos_scores.numel()**lenpen
-        self.assertLess(abs(score - hypo['score']), 1e-6)
+            score /= pos_scores.numel() ** lenpen
+        return abs(score - hypo['score']) < 1e-6
 
-    def assertAlmostEqual(self, t1, t2):
-        self.assertEqual(t1.size(), t2.size(), "size mismatch")
-        self.assertLess((t1 - t2).abs().max(), 1e-4)
+    def almostEqual(self, t1, t2):
+        return t1.size() == t2.size() and (t1 - t2).abs().max() < 1e-4
 
-    def assertTensorEqual(self, t1, t2):
-        self.assertEqual(t1.size(), t2.size(), "size mismatch")
-        self.assertEqual(t1.ne(t2).long().sum(), 0)
+    def tensorEqual(self, t1, t2):
+        return t1.size() == t2.size() and t1.ne(t2).long().sum() == 0
 
 
 if __name__ == '__main__':
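
A note on the thresholds used above, derived from the test's step-1
distribution (w1 = 0.4, w2 = 0.35, eos = 0.25): the low-prob test sets
sampling_topp = 0.4 / 2 = 0.2, and the smallest set whose cumulative mass
exceeds 0.2 is {w1}, so generation is deterministic. The high-prob test sets
sampling_topp = (0.4 + 0.75) / 2 = 0.575, and that set is {w1, w2}, so either
token may be drawn at step 1, which is why its assertions accept a step-1
positional score of either 0.4 or 0.35.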
