Commit 25b5e63

Change to use parametrize
joecummings committed May 13, 2024
1 parent 433499d commit 25b5e63
Showing 1 changed file with 28 additions and 38 deletions.
tests/torchtune/utils/test_generation.py: 66 changes (28 additions & 38 deletions)
@@ -88,66 +88,56 @@ def test_sample_consistency(self):
         token = sample(logits, temperature=1, top_k=1)
         assert token.item() == 100
 
-    def test_reproducibility(self, generation_model, prompt_tokens):
+    @pytest.mark.parametrize(
+        "model1,model2,prompt",
+        [
+            ("generation_model", "generation_model", "prompt_tokens"),
+            ("generation_model", "generation_model_no_kv_cache", "prompt_tokens"),
+            (
+                "generation_model_batched",
+                "generation_model_batched",
+                "prompt_tokens_batched",
+            ),
+            (
+                "generation_model_batched",
+                "generation_model_no_kv_cache",
+                "prompt_tokens_batched",
+            ),
+        ],
+    )
+    def test_reproducibility(self, request, model1, model2, prompt):
         """
         Test to check if the `generate` function produces the same output when run with the same
-        inputs and a fixed seed.
+        inputs and a fixed seed. This should work regardless of batched input or kv cache.
         """
+
+        model1 = request.getfixturevalue(model1)
+        model2 = request.getfixturevalue(model2)
+        prompt = request.getfixturevalue(prompt)
+
         temperature = 0.6
         top_k = 100
 
         torch.manual_seed(42)
         outputs_first = utils.generate(
-            model=generation_model,
-            prompt=prompt_tokens,
+            model=model1,
+            prompt=prompt,
             max_generated_tokens=10,
             temperature=temperature,
             top_k=top_k,
         )
 
         torch.manual_seed(42)
         outputs_second = utils.generate(
-            model=generation_model,
-            prompt=prompt_tokens,
+            model=model2,
+            prompt=prompt,
             max_generated_tokens=10,
             temperature=temperature,
             top_k=top_k,
         )
 
         assert outputs_first == outputs_second
 
-    def test_reproducibility_kv_cache_vs_no_kv_cache(
-        self, generation_model, generation_model_no_kv_cache, prompt_tokens
-    ):
-        """
-        Test to check if the `generate` function produces the same output when one model
-        has a kv cache enabled and the other doesn't
-        """
-        temperature = 0.6
-        top_k = 100
-
-        torch.manual_seed(42)
-        assert generation_model.caches_are_enabled()
-        outputs_first = utils.generate(
-            model=generation_model,
-            prompt=prompt_tokens,
-            max_generated_tokens=20,
-            temperature=temperature,
-            top_k=top_k,
-        )
-
-        torch.manual_seed(42)
-        assert not generation_model_no_kv_cache.caches_are_enabled()
-        outputs_no_kv_cache = utils.generate(
-            model=generation_model_no_kv_cache,
-            prompt=prompt_tokens,
-            max_generated_tokens=20,
-            temperature=temperature,
-            top_k=top_k,
-        )
-
-        assert outputs_first == outputs_no_kv_cache
-
     def test_batched_generate(self, generation_model_batched, prompt_tokens_batched):
         """Test batched generation works as expected."""
         temperature = 0.6
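For readers unfamiliar with the pattern this commit adopts: pytest.mark.parametrize can only pass plain values, so the test receives fixture names as strings and resolves them at runtime through pytest's built-in request fixture via request.getfixturevalue. Below is a minimal, self-contained sketch of that pattern; the small_model and large_model fixtures are hypothetical stand-ins for illustration, not the fixtures defined in torchtune's test suite.

import pytest


@pytest.fixture
def small_model():
    # Hypothetical fixture, used only to illustrate the pattern.
    return {"name": "small", "num_layers": 2}


@pytest.fixture
def large_model():
    # Hypothetical fixture, used only to illustrate the pattern.
    return {"name": "large", "num_layers": 8}


@pytest.mark.parametrize(
    "model_fixture",
    ["small_model", "large_model"],
)
def test_model_has_layers(request, model_fixture):
    # `request` is pytest's built-in fixture; getfixturevalue looks up a
    # fixture by its string name, which parametrize alone cannot do.
    model = request.getfixturevalue(model_fixture)
    assert model["num_layers"] > 0

One test body then covers every combination listed in the decorator, which is how the separate kv-cache comparison test could be folded into test_reproducibility in this commit.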
