31 changes: 31 additions & 0 deletions tests/layers/jax/sample/test_rejection_sampler.py
@@ -436,6 +436,9 @@ def run_rejection_sampler_test(
batch_size=len(num_draft_tokens),
padded_tokens_length=int(sum(num_draft_tokens)))

+ # Convert numpy arrays to lists for comparison
+ parsed_output = [x.tolist() for x in parsed_output]

assert parsed_output == test_case.expected, \
f"Test '{test_case.name}': Expected {test_case.expected}, got {parsed_output}"

@@ -512,6 +515,9 @@ def test_parse_output_basic(self, rejection_sampler):
batch_size=len(num_draft_tokens),
padded_tokens_length=int(sum(num_draft_tokens)))

+ # Convert numpy arrays to lists for comparison
+ parsed_output = [x.tolist() for x in parsed_output]

expected = [[10, 20, 30, 40], [50, 60, 70]]
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"

@@ -535,6 +541,9 @@ def test_parse_output_with_placeholders(self, rejection_sampler):
batch_size=len(num_draft_tokens),
padded_tokens_length=int(sum(num_draft_tokens)))

+ # Convert numpy arrays to lists for comparison
+ parsed_output = [x.tolist() for x in parsed_output]

expected = [[10], [20, 30, 40]]
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"

@@ -556,6 +565,9 @@ def test_parse_output_invalid_tokens(self, rejection_sampler):
batch_size=len(num_draft_tokens),
padded_tokens_length=int(sum(num_draft_tokens)))

+ # Convert numpy arrays to lists for comparison
+ parsed_output = [x.tolist() for x in parsed_output]

expected = [[10, 20]] # Invalid tokens filtered out
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"

@@ -577,6 +589,9 @@ def test_parse_output_empty_sequences(self, rejection_sampler):
batch_size=len(num_draft_tokens),
padded_tokens_length=int(sum(num_draft_tokens)))

+ # Convert numpy arrays to lists for comparison
+ parsed_output = [x.tolist() for x in parsed_output]

expected = [[50], [60]] # Only bonus tokens
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"

@@ -632,6 +647,9 @@ def test_extreme_padding(self, rejection_sampler, test_helper):
batch_size=len(num_draft_tokens),
padded_tokens_length=int(sum(num_draft_tokens)))

+ # Convert numpy arrays to lists for comparison
+ parsed_output = [x.tolist() for x in parsed_output]

expected = [[1, 5]] # Should ignore all padding
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"

@@ -777,6 +795,9 @@ def test_single_long_sequence(self, rejection_sampler, test_helper):
batch_size=len(num_draft_tokens),
padded_tokens_length=int(sum(num_draft_tokens)))

+ # Convert numpy arrays to lists for comparison
+ parsed_output = [x.tolist() for x in parsed_output]

expected = [list(range(1, 28)) + [99]] # Tokens up to mismatch point
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"

@@ -884,6 +905,9 @@ def test_non_greedy_deterministic_with_seed(self, rejection_sampler,
num_draft_tokens_cpu=np.asarray(num_draft_tokens),
batch_size=1,
padded_tokens_length=4)

+ # Convert numpy arrays to lists for comparison
+ parsed_output = [x.tolist() for x in parsed_output]
outputs.append(parsed_output)

# All outputs should be identical with same seed
@@ -1064,6 +1088,9 @@ def test_non_greedy_empty_sequence(self, rejection_sampler, test_helper):
batch_size=2,
padded_tokens_length=0)

+ # Convert numpy arrays to lists for comparison
+ parsed_output = [x.tolist() for x in parsed_output]

# Should get bonus tokens for empty sequences
expected = [[77], [88]]
assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
@@ -1152,6 +1179,10 @@ def test_non_greedy_vs_greedy_same_perfect_case(self, rejection_sampler,
non_greedy_parsed = rejection_sampler.parse_output(
non_greedy_output, VOCAB_SIZE, np.asarray(num_draft_tokens), 1, 3)

+ # Convert numpy arrays to lists for comparison
+ greedy_parsed = [x.tolist() for x in greedy_parsed]
+ non_greedy_parsed = [x.tolist() for x in non_greedy_parsed]

# For perfect match, greedy should have all tokens + bonus
assert greedy_parsed == [[5, 15, 25, 99]]

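Note on the .tolist() conversions above: parse_output now returns one np.ndarray per sequence, and comparing that directly against nested list literals evaluates == elementwise, so Python raises ValueError instead of returning False. A minimal standalone sketch of the pitfall (values are illustrative, not taken from this PR):

import numpy as np

parsed_output = [np.array([10, 20, 30, 40]), np.array([50, 60, 70])]
expected = [[10, 20, 30, 40], [50, 60, 70]]

# parsed_output == expected would raise
# "ValueError: The truth value of an array with more than one element is
# ambiguous", because list equality calls bool() on each elementwise
# ndarray comparison.

# Converting each array first restores plain list semantics:
assert [x.tolist() for x in parsed_output] == expected

Comparing pairs with np.array_equal would also work; .tolist() keeps the asserts readable when pytest prints a mismatch.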
2 changes: 1 addition & 1 deletion tests/runner/test_speculative_decoding_manager.py
@@ -321,7 +321,7 @@ def test_propose_eagle3_draft_token_ids(self,
)

# Inputs
- sampled_token_ids = [[1], [2]]
+ sampled_token_ids = [np.array([1]), np.array([2])]
aux_hidden_states = MagicMock()
attn_metadata = MagicMock()
attn_metadata.seq_lens.shape = [2]
4 changes: 2 additions & 2 deletions tpu_inference/layers/jax/sample/rejection_sampler.py
@@ -128,7 +128,7 @@ def parse_output(
num_draft_tokens_cpu: np.ndarray,
batch_size: int,
padded_tokens_length: int,
- ) -> list[list[int]]:
+ ) -> list[np.ndarray]:
"""Parse the output of the rejection sampler.
Args:
@@ -177,7 +177,7 @@ def parse_output(
else:
seq_tokens = valid_main_tokens

- outputs.append(seq_tokens.tolist())
+ outputs.append(seq_tokens)
start_idx = end_idx

return outputs
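With this change, parse_output hands each sequence's tokens back as an np.ndarray instead of eagerly calling .tolist() per sequence; callers that need Python lists (the tests above) convert at the boundary, while callers that only index into the array (the Eagle3 path below) skip the conversion. A minimal usage sketch; the variable names here are assumed for illustration, only the call signature follows the tests:

import numpy as np

parsed = rejection_sampler.parse_output(
    output_tokens, vocab_size,
    num_draft_tokens_cpu=np.asarray(num_draft_tokens),
    batch_size=len(num_draft_tokens),
    padded_tokens_length=int(sum(num_draft_tokens)))

# Each element is an np.ndarray; index directly, convert only when needed.
last_tokens = [int(seq[-1]) for seq in parsed if seq.size]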
4 changes: 2 additions & 2 deletions tpu_inference/runner/speculative_decoding_manager.py
@@ -78,7 +78,7 @@ def propose_draft_token_ids(

def propose_eagle3_draft_token_ids(
self,
- sampled_token_ids: list[list[int]],
+ sampled_token_ids: list[np.ndarray],
aux_hidden_states: Optional[tuple[jnp.ndarray, ...]],
attn_metadata: AttentionMetadata,
spec_decode_metadata: Optional[SpecDecodeMetadata],
@@ -91,7 +91,7 @@ def propose_eagle3_draft_token_ids(
req_ids = self.runner.input_batch.req_ids
next_token_ids: list[int] = []
for i, token_ids in enumerate(sampled_token_ids):
- if token_ids:
+ if token_ids.size != 0:
# Common case.
next_token_id = token_ids[-1]
else:
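The .size comparison here (and the matching sampled_ids.size == 0 check in tpu_runner.py below) replaces plain truthiness because each element is now an np.ndarray: bool() on a multi-element array raises, and a one-element array is falsy when its single token happens to be 0. A standalone sketch of both pitfalls, assuming nothing beyond numpy itself:

import numpy as np

token_ids = np.array([3, 7])
# `if token_ids:` raises ValueError ("truth value of an array with more than
# one element is ambiguous"), and bool(np.array([0])) is False even though
# that array is non-empty, so emptiness must be tested via .size.
if token_ids.size != 0:
    next_token_id = int(token_ids[-1])  # common case: take the last token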
2 changes: 1 addition & 1 deletion tpu_inference/runner/tpu_runner.py
@@ -935,7 +935,7 @@ def _sample_from_logits(
# Append sampled tokens
for req_idx, req_state, _ in request_seq_lens:
sampled_ids = valid_sampled_token_ids[req_idx]
- if not sampled_ids:
+ if sampled_ids.size == 0:
continue

start_idx = self.input_batch.num_tokens_no_spec[req_idx]