diff --git a/tests/layers/jax/sample/test_rejection_sampler.py b/tests/layers/jax/sample/test_rejection_sampler.py
index 51b65b447..57d644c69 100644
--- a/tests/layers/jax/sample/test_rejection_sampler.py
+++ b/tests/layers/jax/sample/test_rejection_sampler.py
@@ -436,6 +436,9 @@ def run_rejection_sampler_test(
         batch_size=len(num_draft_tokens),
         padded_tokens_length=int(sum(num_draft_tokens)))
 
+    # Convert numpy arrays to lists for comparison
+    parsed_output = [x.tolist() for x in parsed_output]
+
     assert parsed_output == test_case.expected, \
         f"Test '{test_case.name}': Expected {test_case.expected}, got {parsed_output}"
 
@@ -512,6 +515,9 @@ def test_parse_output_basic(self, rejection_sampler):
             batch_size=len(num_draft_tokens),
             padded_tokens_length=int(sum(num_draft_tokens)))
 
+        # Convert numpy arrays to lists for comparison
+        parsed_output = [x.tolist() for x in parsed_output]
+
         expected = [[10, 20, 30, 40], [50, 60, 70]]
         assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
 
@@ -535,6 +541,9 @@ def test_parse_output_with_placeholders(self, rejection_sampler):
             batch_size=len(num_draft_tokens),
             padded_tokens_length=int(sum(num_draft_tokens)))
 
+        # Convert numpy arrays to lists for comparison
+        parsed_output = [x.tolist() for x in parsed_output]
+
         expected = [[10], [20, 30, 40]]
         assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
 
@@ -556,6 +565,9 @@ def test_parse_output_invalid_tokens(self, rejection_sampler):
             batch_size=len(num_draft_tokens),
             padded_tokens_length=int(sum(num_draft_tokens)))
 
+        # Convert numpy arrays to lists for comparison
+        parsed_output = [x.tolist() for x in parsed_output]
+
         expected = [[10, 20]]  # Invalid tokens filtered out
         assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
 
@@ -577,6 +589,9 @@ def test_parse_output_empty_sequences(self, rejection_sampler):
             batch_size=len(num_draft_tokens),
             padded_tokens_length=int(sum(num_draft_tokens)))
 
+        # Convert numpy arrays to lists for comparison
+        parsed_output = [x.tolist() for x in parsed_output]
+
         expected = [[50], [60]]  # Only bonus tokens
         assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
 
@@ -632,6 +647,9 @@ def test_extreme_padding(self, rejection_sampler, test_helper):
             batch_size=len(num_draft_tokens),
             padded_tokens_length=int(sum(num_draft_tokens)))
 
+        # Convert numpy arrays to lists for comparison
+        parsed_output = [x.tolist() for x in parsed_output]
+
         expected = [[1, 5]]  # Should ignore all padding
         assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
 
@@ -777,6 +795,9 @@ def test_single_long_sequence(self, rejection_sampler, test_helper):
             batch_size=len(num_draft_tokens),
             padded_tokens_length=int(sum(num_draft_tokens)))
 
+        # Convert numpy arrays to lists for comparison
+        parsed_output = [x.tolist() for x in parsed_output]
+
         expected = [list(range(1, 28)) + [99]]  # Tokens up to mismatch point
         assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
 
@@ -884,6 +905,9 @@ def test_non_greedy_deterministic_with_seed(self, rejection_sampler,
                 num_draft_tokens_cpu=np.asarray(num_draft_tokens),
                 batch_size=1,
                 padded_tokens_length=4)
+
+            # Convert numpy arrays to lists for comparison
+            parsed_output = [x.tolist() for x in parsed_output]
             outputs.append(parsed_output)
 
         # All outputs should be identical with same seed
@@ -1064,6 +1088,9 @@ def test_non_greedy_empty_sequence(self, rejection_sampler, test_helper):
             batch_size=2,
             padded_tokens_length=0)
 
+        # Convert numpy arrays to lists for comparison
+        parsed_output = [x.tolist() for x in parsed_output]
+
         # Should get bonus tokens for empty sequences
         expected = [[77], [88]]
         assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
@@ -1152,6 +1179,10 @@ def test_non_greedy_vs_greedy_same_perfect_case(self, rejection_sampler,
         non_greedy_parsed = rejection_sampler.parse_output(
             non_greedy_output, VOCAB_SIZE, np.asarray(num_draft_tokens), 1, 3)
 
+        # Convert numpy arrays to lists for comparison
+        greedy_parsed = [x.tolist() for x in greedy_parsed]
+        non_greedy_parsed = [x.tolist() for x in non_greedy_parsed]
+
         # For perfect match, greedy should have all tokens + bonus
         assert greedy_parsed == [[5, 15, 25, 99]]
diff --git a/tests/runner/test_speculative_decoding_manager.py b/tests/runner/test_speculative_decoding_manager.py
index ea772e2ed..da55c1c7e 100644
--- a/tests/runner/test_speculative_decoding_manager.py
+++ b/tests/runner/test_speculative_decoding_manager.py
@@ -321,7 +321,7 @@ def test_propose_eagle3_draft_token_ids(self,
         )
 
         # Inputs
-        sampled_token_ids = [[1], [2]]
+        sampled_token_ids = [np.array([1]), np.array([2])]
         aux_hidden_states = MagicMock()
         attn_metadata = MagicMock()
         attn_metadata.seq_lens.shape = [2]
diff --git a/tpu_inference/layers/jax/sample/rejection_sampler.py b/tpu_inference/layers/jax/sample/rejection_sampler.py
index 84d37eb1b..6623a6873 100644
--- a/tpu_inference/layers/jax/sample/rejection_sampler.py
+++ b/tpu_inference/layers/jax/sample/rejection_sampler.py
@@ -128,7 +128,7 @@ def parse_output(
         num_draft_tokens_cpu: np.ndarray,
         batch_size: int,
         padded_tokens_length: int,
-    ) -> list[list[int]]:
+    ) -> list[np.ndarray]:
         """Parse the output of the rejection sampler.
 
         Args:
@@ -177,7 +177,7 @@ def parse_output(
             else:
                 seq_tokens = valid_main_tokens
 
-            outputs.append(seq_tokens.tolist())
+            outputs.append(seq_tokens)
             start_idx = end_idx
 
         return outputs
diff --git a/tpu_inference/runner/speculative_decoding_manager.py b/tpu_inference/runner/speculative_decoding_manager.py
index 82a9d7342..9bcb025a1 100644
--- a/tpu_inference/runner/speculative_decoding_manager.py
+++ b/tpu_inference/runner/speculative_decoding_manager.py
@@ -78,7 +78,7 @@ def propose_draft_token_ids(
 
     def propose_eagle3_draft_token_ids(
         self,
-        sampled_token_ids: list[list[int]],
+        sampled_token_ids: list[np.ndarray],
         aux_hidden_states: Optional[tuple[jnp.ndarray, ...]],
         attn_metadata: AttentionMetadata,
         spec_decode_metadata: Optional[SpecDecodeMetadata],
@@ -91,7 +91,7 @@ def propose_eagle3_draft_token_ids(
         req_ids = self.runner.input_batch.req_ids
         next_token_ids: list[int] = []
         for i, token_ids in enumerate(sampled_token_ids):
-            if token_ids:
+            if token_ids.size != 0:
                 # Common case.
                 next_token_id = token_ids[-1]
             else:
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index e4f8e791e..bc525c791 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -935,7 +935,7 @@ def _sample_from_logits(
         # Append sampled tokens
         for req_idx, req_state, _ in request_seq_lens:
             sampled_ids = valid_sampled_token_ids[req_idx]
-            if not sampled_ids:
+            if sampled_ids.size == 0:
                 continue
 
             start_idx = self.input_batch.num_tokens_no_spec[req_idx]
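
# --- Illustrative sketch (not part of the patch). The toy values below show
# why the truthiness checks (`if token_ids:` / `if not sampled_ids:`) must
# change once `parse_output` returns np.ndarray instead of list[int], and why
# the tests convert back with .tolist() before asserting.
import numpy as np

token_ids = np.array([1, 2])
# bool() on a multi-element array raises "ValueError: The truth value of an
# array with more than one element is ambiguous", so emptiness is tested
# explicitly via .size, as the patch does:
if token_ids.size != 0:
    next_token_id = token_ids[-1]  # last accepted token

# Likewise, == between an array and a plain list compares element-wise and
# cannot collapse to a single bool, so the tests call .tolist() first:
parsed_output = [np.array([10, 20]), np.array([30])]
assert [x.tolist() for x in parsed_output] == [[10, 20], [30]]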