/
ray_tokenizer_group.py
166 lines (139 loc) · 6.24 KB
/
ray_tokenizer_group.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import asyncio
import os
from typing import List, Optional
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig
from vllm.engine.ray_utils import ray
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
TokenizerGroup)
class RayTokenizerGroupPool(BaseTokenizerGroup):
    """A Ray-based pool of TokenizerGroups for async tokenization.

    A local ``TokenizerGroup`` is kept in-process for direct access to the
    underlying HF tokenizers. ``encode``/``encode_async`` dispatch work to a
    fixed set of Ray actors, each wrapping its own ``TokenizerGroup``.
    Idle actors are tracked in an ``asyncio.Queue``, created lazily on first
    use.
    """

    # Class to use for workers making up the pool.
    _worker_cls = TokenizerGroup

    @classmethod
    def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig,
                    **init_kwargs) -> "RayTokenizerGroupPool":
        """Build a pool from a ``TokenizerPoolConfig``.

        ``tokenizer_pool_config.extra_config`` (if set) provides the base Ray
        actor options; otherwise the actors request ``num_cpus=0``. Unless a
        ``scheduling_strategy`` is supplied, actors are softly pinned to the
        current Ray node. The current process environment variables are
        merged into the actors' ``runtime_env``.

        Args:
            tokenizer_pool_config: Pool configuration; ``pool_size`` sets the
                number of actors.
            **init_kwargs: Forwarded to ``__init__`` (tokenizer id, LoRA
                settings, extra tokenizer config, ...).

        Returns:
            A new ``RayTokenizerGroupPool``.
        """
        # Work on a copy so the defaults injected below do not leak back into
        # the caller's config object (extra_config used to be mutated in
        # place).
        ray_actor_options = dict(tokenizer_pool_config.extra_config
                                 or {"num_cpus": 0})
        ray_actor_options.setdefault(
            "scheduling_strategy",
            NodeAffinitySchedulingStrategy(
                node_id=ray.get_runtime_context().get_node_id(), soft=True))

        # Carry over the env vars to the actors.
        # This is necessary for API keys and such.
        # The nested runtime_env is copied as well, because the helper
        # modifies the dict it is given in place.
        runtime_env = dict(ray_actor_options.get("runtime_env") or {})
        _carry_over_env_vars_to_runtime_env(runtime_env)
        ray_actor_options["runtime_env"] = runtime_env

        init_kwargs["num_actors"] = tokenizer_pool_config.pool_size
        init_kwargs["ray_actor_options"] = ray_actor_options

        return cls(**init_kwargs)

    def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int], num_actors: int,
                 ray_actor_options: dict, **tokenizer_config):
        """Create the local tokenizer group and ``num_actors`` Ray actors.

        Args:
            tokenizer_id: Tokenizer name/path passed to every worker.
            enable_lora: Whether the workers support LoRA tokenizers.
            max_num_seqs: Passed through to each worker.
            max_input_length: Passed through to each worker.
            num_actors: Number of remote tokenizer actors to start.
            ray_actor_options: Options applied via ``.options(...)`` to the
                remote worker class.
            **tokenizer_config: Extra keyword arguments for every worker,
                both local and remote.
        """
        # Store a local copy of the TokenizerGroup for quick access
        # to underlying HF tokenizers. It receives the same tokenizer_config
        # as the remote actors so local and remote tokenizers agree.
        self._local_tokenizer_group = self._worker_cls(
            tokenizer_id=tokenizer_id,
            enable_lora=enable_lora,
            max_num_seqs=max_num_seqs,
            max_input_length=max_input_length,
            **tokenizer_config,
        )

        ray_tokenizer_group_cls = ray.remote(
            self._worker_cls).options(**ray_actor_options)
        self.tokenizer_actors = [
            ray_tokenizer_group_cls.remote(tokenizer_id, enable_lora,
                                           max_num_seqs, max_input_length,
                                           **tokenizer_config)
            for _ in range(num_actors)
        ]
        # Created lazily by _ensure_queue_initialized on first encode call.
        self._idle_actors: Optional[asyncio.Queue] = None

    @property
    def pool_size(self) -> int:
        """Number of remote tokenizer actors in the pool."""
        return len(self.tokenizer_actors)

    def ping(self):
        """Call ``ping`` on every actor and block until all have answered.

        Returns the collected results from ``ray.get``.
        """
        return ray.get(
            [actor.ping.remote() for actor in self.tokenizer_actors])

    def _ensure_queue_initialized(self):
        """Create the idle-actor queue on first use, filled with all actors."""
        if self._idle_actors is None:
            self._idle_actors = asyncio.Queue()
            for actor in self.tokenizer_actors:
                self._idle_actors.put_nowait(actor)

    def encode(self,
               prompt: str,
               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Encode a prompt using the tokenizer group.

        We pick an idle actor and use it to encode the prompt.
        The actor is then put back in the queue for future use.
        This is blocking.

        Raises:
            RuntimeError: If no actor is idle at call time (this method does
                not wait for one).
        """
        self._ensure_queue_initialized()

        try:
            actor = self._idle_actors.get_nowait()
        except asyncio.QueueEmpty:
            raise RuntimeError("No idle actors available.") from None

        try:
            return ray.get(
                actor.encode.remote(request_id=request_id,
                                    prompt=prompt,
                                    lora_request=lora_request))
        finally:
            # Put the actor back in the queue.
            # This is done in a finally block to ensure that the actor is
            # always put back in the queue, even if an exception/cancellation
            # is raised.
            self._idle_actors.put_nowait(actor)

    async def encode_async(
            self,
            prompt: str,
            request_id: Optional[str] = None,
            lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Encode a prompt using the tokenizer group.

        We pick an idle actor and use it to encode the prompt.
        If there are no idle actors, we wait until one becomes
        available.
        The actor is then put back in the queue for future use.
        This is non-blocking.
        """
        self._ensure_queue_initialized()

        actor = await self._idle_actors.get()
        try:
            return await actor.encode.remote(request_id=request_id,
                                             prompt=prompt,
                                             lora_request=lora_request)
        finally:
            # Put the actor back in the queue.
            # This is done in a finally block to ensure that the actor is
            # always put back in the queue, even if an exception/cancellation
            # is raised.
            self._idle_actors.put_nowait(actor)

    def get_max_input_len(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        """Get the maximum input length for the LoRA request."""
        return self._local_tokenizer_group.get_max_input_len(lora_request)

    def get_lora_tokenizer(
            self,
            lora_request: Optional[LoRARequest] = None
    ) -> "PreTrainedTokenizer":
        """Return the tokenizer for ``lora_request`` from the local group."""
        return self._local_tokenizer_group.get_lora_tokenizer(lora_request)

    async def get_lora_tokenizer_async(
            self,
            lora_request: Optional[LoRARequest] = None
    ) -> "PreTrainedTokenizer":
        """Async variant of ``get_lora_tokenizer`` (local group)."""
        return await self._local_tokenizer_group.get_lora_tokenizer_async(
            lora_request)
def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None:
    """Merge this process's environment variables into ``runtime_env``.

    The merged mapping is stored under ``runtime_env["env_vars"]``. It starts
    as a copy of ``os.environ`` and is then overlaid with whatever
    ``runtime_env["env_vars"]`` already held, so explicitly configured values
    win over inherited ones. ``runtime_env`` is modified in place; the
    previously stored ``env_vars`` mapping is only read, never mutated.
    """
    merged_env = dict(os.environ)
    # Entries already present in runtime_env override the inherited ones.
    merged_env.update(runtime_env.setdefault("env_vars", {}))
    runtime_env["env_vars"] = merged_env