29 changes: 17 additions & 12 deletions tpu_inference/runner/tpu_runner.py
@@ -15,7 +15,7 @@
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t_dtype
+from torchax.ops.mappings import j2t, j2t_dtype
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
@@ -28,7 +28,7 @@
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, KVConnectorOutput, LogprobsLists,
-                             ModelRunnerOutput)
+                             LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.request import Request
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.worker.kv_connector_model_runner_mixin import \
@@ -122,9 +122,10 @@ def get_output(self) -> ModelRunnerOutput:
             next_tokens_cpu = next_tokens_cpu[self.logits_indices_selector]
         selected_token_ids = np.expand_dims(next_tokens_cpu[:self._num_reqs],
                                             1)
-        valid_sampled_token_ids = selected_token_ids.tolist()
+
+        valid_sampled_token_ids = [token_id for token_id in selected_token_ids]
         for i in self._discard_sampled_tokens_req_indices:
-            valid_sampled_token_ids[i].clear()
+            valid_sampled_token_ids[i] = np.array([])
         self._model_runner_output.sampled_token_ids = valid_sampled_token_ids
         return self._model_runner_output

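Reviewer note: the hunk above stops converting the whole `selected_token_ids` array into nested Python lists and instead keeps one NumPy row per request, masking discarded requests with an empty array (an ndarray element has no `.clear()`). A minimal standalone sketch of that pattern with made-up data, not the runner's real inputs:

```python
import numpy as np

# One sampled token per request, shaped (num_reqs, 1) as in get_output().
next_tokens_cpu = np.array([101, 202, 303, 404])
selected_token_ids = np.expand_dims(next_tokens_cpu, 1)

# Keep each request's tokens as a NumPy array instead of a Python list.
valid_sampled_token_ids = [token_id for token_id in selected_token_ids]

# Requests whose samples must be discarded get an empty array;
# list.clear() would not exist on an ndarray element.
for i in [1]:
    valid_sampled_token_ids[i] = np.array([])

print([arr.tolist() for arr in valid_sampled_token_ids])
# [[101], [], [303], [404]]
```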
@@ -190,7 +191,8 @@ def _substitute_placeholder_token(
     return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)


-def _reorder_logits_indices(logprobs_lists, logits_indices_selector):
+def _reorder_logits_indices(logprobs_lists: LogprobsLists,
+                            logits_indices_selector: List[int]):
     return LogprobsLists(
         logprob_token_ids=[
             logprobs_lists.logprob_token_ids[i]
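The hunk only shows the start of `_reorder_logits_indices`; the visible part permutes per-request logprob entries back to their pre-data-parallel order. A rough standalone sketch of that behaviour, using a simplified stand-in for vLLM's `LogprobsLists` (the real class has more fields than this diff shows):

```python
from typing import List, NamedTuple


class LogprobsListsSketch(NamedTuple):
    # Simplified stand-in for vllm.v1.outputs.LogprobsLists.
    logprob_token_ids: List[List[int]]
    logprobs: List[List[float]]


def reorder(lists: LogprobsListsSketch,
            logits_indices_selector: List[int]) -> LogprobsListsSketch:
    # Pick row i of every field, in selector order.
    return LogprobsListsSketch(
        logprob_token_ids=[
            lists.logprob_token_ids[i] for i in logits_indices_selector
        ],
        logprobs=[lists.logprobs[i] for i in logits_indices_selector],
    )


shuffled = LogprobsListsSketch([[7], [8], [9]], [[-0.1], [-0.2], [-0.3]])
print(reorder(shuffled, [2, 0, 1]))
# LogprobsListsSketch(logprob_token_ids=[[9], [7], [8]],
#                     logprobs=[[-0.3], [-0.1], [-0.2]])
```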
@@ -595,11 +597,11 @@ def _modify_prev_results(self):
             next_tokens_cpu = next_tokens_cpu[pre_logits_indices_selector]
         selected_token_ids = np.expand_dims(next_tokens_cpu[:len(pre_req_ids)],
                                             1)
-        valid_sampled_token_ids = selected_token_ids.tolist()
+        valid_sampled_token_ids = [token_id for token_id in selected_token_ids]

         # Mask out the sampled tokens that should not be sampled.
         for i in pre_discard_sampled_tokens_req_indices:
-            valid_sampled_token_ids[i].clear()
+            valid_sampled_token_ids[i] = np.array([])
         # Append sampled tokens
         for pre_req_idx, req_state, _ in pre_request_seq_lens:
             sampled_ids = valid_sampled_token_ids[pre_req_idx]
@@ -804,6 +806,8 @@ def _sample_from_logits(
         if tpu_sampling_metadata.logprobs:
             logprobs = self._compute_and_gather_logprobs(
                 logits, next_tokens, self.model_config.max_logprobs)
+            logprobs_lists = jax.tree.map(lambda x: j2t(x.astype(jnp.float32)),
+                                          logprobs).tolists()
         else:
             logprobs = None

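With this hunk, the JAX logprob tensors are moved to torch (via torchax's `j2t`) and flattened to lists once, right after sampling, instead of at each later use. Below is a standalone sketch of the `jax.tree.map` pattern over a NamedTuple of arrays; `to_host` is a stand-in for `j2t` so the snippet runs without torchax:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class GatheredLogprobsSketch(NamedTuple):
    # Stand-in for the structure returned by _compute_and_gather_logprobs.
    logprob_token_ids: jax.Array
    logprobs: jax.Array


def to_host(x: jax.Array) -> np.ndarray:
    # Stand-in for torchax's j2t; mirrors the float32 cast in the diff.
    return np.asarray(x.astype(jnp.float32))


gathered = GatheredLogprobsSketch(
    logprob_token_ids=jnp.array([[5, 9]]),
    logprobs=jnp.array([[-0.1, -2.3]], dtype=jnp.bfloat16),
)

# jax.tree.map keeps the NamedTuple structure and converts every leaf.
host_logprobs = jax.tree.map(to_host, gathered)
print(host_logprobs.logprobs.dtype)  # float32
```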
@@ -856,7 +860,6 @@ def _sample_from_logits(

         if logprobs is not None:
             # Map logprobs back to the pre-dp shuffling order
-            logprobs_lists = logprobs.tolists()
             if logits_indices_selector is not None:
                 logprobs_lists = _reorder_logits_indices(
                     logprobs_lists, logits_indices_selector)
@@ -898,7 +901,9 @@ def _sample_from_logits(
             if logits_indices_selector is not None:
                 next_tokens = next_tokens[logits_indices_selector]
             selected_token_ids = np.expand_dims(next_tokens[:num_reqs], 1)
-            valid_sampled_token_ids = selected_token_ids.tolist()
+            valid_sampled_token_ids = [
+                token_id for token_id in selected_token_ids
+            ]
         else:
             valid_sampled_token_ids = self.rejection_sampler.parse_output(
                 next_tokens, self.input_batch.vocab_size,
@@ -907,7 +912,7 @@

         # Mask out the sampled tokens that should not be sampled.
         for i in discard_sampled_tokens_req_indices:
-            valid_sampled_token_ids[i].clear()
+            valid_sampled_token_ids[i] = np.array([])
         # Append sampled tokens
         for req_idx, req_state, _ in request_seq_lens:
             sampled_ids = valid_sampled_token_ids[req_idx]
@@ -929,7 +934,6 @@ def _sample_from_logits(

         if logprobs is not None:
             # Map logprobs back to the pre-dp shuffling order
-            logprobs_lists = logprobs.tolists()
             if logits_indices_selector is not None:
                 logprobs_lists = _reorder_logits_indices(
                     logprobs_lists, logits_indices_selector)
@@ -976,7 +980,8 @@ def select_local_fn(local_array, local_indices):

     @staticmethod
     @functools.partial(jax.jit, static_argnames=("max_logprobs", ))
-    def _compute_and_gather_logprobs(logits, next_tokens, max_logprobs):
+    def _compute_and_gather_logprobs(logits, next_tokens,
+                                     max_logprobs) -> LogprobsTensors:
         logprobs = compute_logprobs(logits)
         return gather_logprobs(logprobs, next_tokens, max_logprobs)

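`compute_logprobs` and `gather_logprobs` are project helpers this diff does not show; the sketch below only illustrates the general shape of such a jitted helper (log-softmax plus a static top-k gather), not the repository's actual implementation:

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames=("max_logprobs", ))
def compute_and_gather_logprobs_sketch(logits, next_tokens, max_logprobs):
    # Log-softmax over the vocabulary dimension.
    logprobs = jax.nn.log_softmax(logits, axis=-1)
    # Top-k candidates per position; k must be static for jit.
    topk_logprobs, topk_ids = jax.lax.top_k(logprobs, max_logprobs)
    # Logprob of the token that was actually sampled.
    sampled = jnp.take_along_axis(logprobs, next_tokens[:, None], axis=-1)
    return topk_ids, topk_logprobs, sampled


logits = jnp.array([[2.0, 0.5, -1.0], [0.0, 1.0, 3.0]])
next_tokens = jnp.array([0, 2])
ids, top, sampled = compute_and_gather_logprobs_sketch(logits, next_tokens, 2)
print(ids.shape, top.shape, sampled.shape)  # (2, 2) (2, 2) (2, 1)
```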