Add support for non-incremental decoding + unit test #973

Merged · 4 commits · Jul 23, 2024
48 changes: 42 additions & 6 deletions tests/torchtune/utils/test_generation.py
@@ -34,6 +34,20 @@ def generation_model(self, dtype=torch.float32):
model.eval()
return model

@pytest.fixture
def generation_model_no_kv_cache(self, dtype=torch.float32):
model = llama2(
vocab_size=4_000,
embed_dim=128,
num_layers=2,
num_heads=4,
num_kv_heads=4,
max_seq_len=2048,
)
fixed_init_model(model)
model.eval()
return model

@pytest.fixture
def generation_model_batched(self, dtype=torch.float32):
model = llama2(
@@ -74,27 +88,49 @@ def test_sample_consistency(self):
token = sample(logits, temperature=1, top_k=1)
assert token.item() == 100

def test_reproducibility(self, generation_model, prompt_tokens):
@pytest.mark.parametrize(
"model1,model2,prompt",
[
("generation_model", "generation_model", "prompt_tokens"),
("generation_model", "generation_model_no_kv_cache", "prompt_tokens"),
(
"generation_model_batched",
"generation_model_batched",
"prompt_tokens_batched",
),
(
"generation_model_batched",
"generation_model_no_kv_cache",
"prompt_tokens_batched",
),
],
)
def test_reproducibility(self, request, model1, model2, prompt):
"""
Test to check if the `generate` function produces the same output when run with the same
inputs and a fixed seed.
inputs and a fixed seed. This should hold regardless of batched inputs or whether KV caches are enabled.
"""

model1 = request.getfixturevalue(model1)
model2 = request.getfixturevalue(model2)
prompt = request.getfixturevalue(prompt)

temperature = 0.6
top_k = 100

torch.manual_seed(42)
outputs_first = utils.generate(
model=generation_model,
prompt=prompt_tokens,
model=model1,
prompt=prompt,
max_generated_tokens=10,
temperature=temperature,
top_k=top_k,
)

torch.manual_seed(42)
outputs_second = utils.generate(
model=generation_model,
prompt=prompt_tokens,
model=model2,
prompt=prompt,
max_generated_tokens=10,
temperature=temperature,
top_k=top_k,
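
The parametrization above pairs a KV-cached model with an uncached one and expects identical outputs under a fixed seed. Below is a minimal standalone sketch of that behavior, using only constructors and signatures visible in this diff (llama2, setup_caches, utils.generate); the import paths, the 1-D prompt shape, and the list comparison at the end are assumptions and may differ from the actual test fixtures.

import torch
from torchtune import utils                  # assumed import path
from torchtune.models.llama2 import llama2   # assumed import path

def build_model():
    # Same small config the test fixtures use.
    model = llama2(
        vocab_size=4_000,
        embed_dim=128,
        num_layers=2,
        num_heads=4,
        num_kv_heads=4,
        max_seq_len=2048,
    )
    model.eval()
    return model

cached = build_model()
uncached = build_model()
uncached.load_state_dict(cached.state_dict())            # identical weights
cached.setup_caches(batch_size=1, dtype=torch.float32)   # enables incremental decoding

prompt = torch.arange(2, 10)  # assumed 1-D prompt of token ids

torch.manual_seed(42)
out_cached = utils.generate(
    model=cached, prompt=prompt, max_generated_tokens=10,
    temperature=0.6, top_k=100,
)
torch.manual_seed(42)
out_uncached = utils.generate(
    model=uncached, prompt=prompt, max_generated_tokens=10,
    temperature=0.6, top_k=100,
)
assert out_cached == out_uncached  # same tokens with or without KV caches
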
6 changes: 5 additions & 1 deletion torchtune/modules/transformer.py
@@ -164,9 +164,13 @@ def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None:
torch.ones(self.max_seq_len, self.max_seq_len, dtype=torch.bool)
)

def caches_are_enabled(self) -> bool:
[Review comment · Contributor] nit: can make this a property

"""Check if the key value caches are setup."""
return self.layers[0].attn.kv_cache is not None

def reset_caches(self):
"""Reset the key value caches."""
if self.layers[0].attn.kv_cache is None:
if not self.caches_are_enabled():
raise RuntimeError(
"Key value caches are not setup. Call ``setup_caches()`` first."
)
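
The reviewer's nit above suggests exposing the cache check as a property rather than a method. A hedged sketch of what that alternative could look like (not what was merged):

# Sketch of the reviewer's suggestion: a read-only property on the decoder.
@property
def caches_are_enabled(self) -> bool:
    """True once ``setup_caches()`` has been called."""
    return self.layers[0].attn.kv_cache is not None

# Call sites would then drop the parentheses:
#     if not self.caches_are_enabled:
#         raise RuntimeError(...)
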
23 changes: 19 additions & 4 deletions torchtune/utils/_generation.py
@@ -120,9 +120,10 @@ def generate(
custom_generate_next_token = generate_next_token

# generate the first tokens conditioned on the prompt
input_pos = torch.arange(0, model.max_seq_len, device=prompt.device)
tokens = generate_next_token(
model,
input_pos=torch.arange(0, prompt_length, device=prompt.device),
input_pos=input_pos[:prompt_length],
x=prompt,
temperature=temperature,
top_k=top_k,
@@ -137,7 +138,9 @@
if stop_token_reached.all().item():
return generated_tokens.tolist()

input_pos = torch.tensor([prompt_length], device=prompt.device)
curr_pos = prompt_length
# if key value caches are enabled, we can incrementally decode
incremental_decoding = model.caches_are_enabled()
for _ in range(max_generated_tokens - 1):
# update stop_token_mask if we reached a stop token in a previous step
# by appending the logical not of stop_token_reached to the end of the mask
@@ -147,12 +150,24 @@
[stop_token_mask, ~stop_token_reached.reshape(bsz, 1)], dim=-1
)

# if incremental decoding is enabled, we only need the newest position
# otherwise, we pass every position up to and including the current one
if incremental_decoding:
curr_input_pos = input_pos[curr_pos].unsqueeze(0)
else:
curr_input_pos = input_pos[: curr_pos + 1]
tokens = generated_tokens.clone()
[Review comment · Contributor (author)] I'm not entirely sure I need this clone... cc someone who is better at pytorch than me
[Review comment · Contributor] I think it's safest to leave it since you need to keep track of tokens and generated tokens separately


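The thread above concerns the ``.clone()`` on ``generated_tokens``. The concern it guards against is aliasing: without the copy, ``tokens`` would be the same tensor object as ``generated_tokens``, so any in-place mutation during the forward pass would silently change the running buffer. A small self-contained illustration of the difference (not torchtune code):

import torch

generated_tokens = torch.tensor([[1, 2, 3]])

alias = generated_tokens         # no copy: shares the same storage
copy = generated_tokens.clone()  # independent storage

alias[0, 0] = 99                 # in-place edit through the alias
print(generated_tokens[0, 0])    # tensor(99) -- the running buffer changed too
print(copy[0, 0])                # tensor(1)  -- the clone is unaffected
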
tokens = custom_generate_next_token(
model, input_pos=input_pos, x=tokens, temperature=temperature, top_k=top_k
model,
input_pos=curr_input_pos,
x=tokens,
temperature=temperature,
top_k=top_k,
)

generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)
input_pos += 1
curr_pos += 1

if stop_tokens is not None:
stop_token_reached = update_stop_tokens_tracker(
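
To make the two decoding paths concrete: with KV caches enabled, only the newest position is fed on each loop iteration, since keys and values for earlier positions are already cached; without caches, every position up to and including the current one is fed, together with the full token sequence so far. A small sketch of the ``curr_input_pos`` values for a prompt of length 4 on the first loop iteration (illustrative values only):

import torch

prompt_length = 4
max_seq_len = 16
input_pos = torch.arange(0, max_seq_len)
curr_pos = prompt_length  # position of the first token produced inside the loop

# incremental decoding (caches enabled): just the newest position
curr_input_pos = input_pos[curr_pos].unsqueeze(0)   # tensor([4])

# non-incremental decoding (no caches): all positions so far
curr_input_pos = input_pos[: curr_pos + 1]          # tensor([0, 1, 2, 3, 4])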