@@ -125,7 +125,43 @@ def batch_process(self, dataframe: pd.DataFrame):
             yield batch_prompts
             batch_prompts = []
 
-    async def call(self, request: RequestPayload) -> Tuple[str, float]:
+    async def fastdeploy_call(self, request: RequestPayload) -> Tuple[str, float]:
+        client = self.get_client()
+        try:
+            async with self.semaphore:
+                start_time = time.time()
+                response = await client.chat.completions.create(
+                    model=self.model,
+                    messages=[{"role": "user", "content": request.prompt}],
+                    temperature=self.args.temperature,
+                    top_p=self.args.top_p,
+                    max_tokens=self.args.max_response_length,
+                    n=1,
+                    stream=True,
+                    timeout=60 * 60,
+                    metadata={
+                        "training": True,
+                        "raw_request": False,
+                    }
+                )
+                # Streaming text is stored in a list of chunks
+                chunks = []
+                # Streaming responses
+                async for chunk in response:
+                    delta = chunk.choices[0].delta
+                    if delta and delta.content:
+                        chunks.append(delta.content)
+                text = "".join(chunks)
+                end_time = time.time()
+                elapsed_time = end_time - start_time
+                logger.debug("Streaming response took %.2f seconds", elapsed_time)
+                return text, round(elapsed_time, 2)
+
+        except Exception as e:
+            logger.error("Error while streaming: %s", e)
+            raise ValueError(e)
+
+    async def vllm_call(self, request: RequestPayload) -> Tuple[str, float]:
         client = self.get_client()
         try:
             async with self.semaphore:
@@ -157,7 +193,12 @@ async def call(self, request: RequestPayload) -> Tuple[str, float]:
 
     async def group_call(self, request: RequestPayload) -> ResponsePayload:
         """Performs n complete token generation rollouts for the given query."""
-        tasks = [self.call(request) for _ in range(request.num_responses)]
+        if self.args.use_fastdeploy == "true":
+            call = self.fastdeploy_call
+        else:
+            call = self.vllm_call
+
+        tasks = [call(request) for _ in range(request.num_responses)]
 
         result = ResponsePayload()
         result.idx = request.idx
@@ -341,9 +382,9 @@ def parse_args():
     parser.add_argument(
         "--limit_rows", type=int, default=-1, help="Maximum number of rows to read from the dataset (-1 means all)"
     )
+    parser.add_argument("--use_fastdeploy", type=str.lower, choices=["true", "false"], default="true", help="Engine selection (true=FastDeploy, false=vLLM, default: true)")
     return parser.parse_args()
 
-
 if __name__ == "__main__":
     args = parse_args()
     task = ApiTask(args)
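
For reference, a minimal standalone sketch of the flag-based engine dispatch introduced in group_call above. Everything here (the EngineDemo class, the prompt strings, and the use of asyncio.gather to await the rollouts) is illustrative only and not part of this commit; the code that actually awaits the tasks lies outside the hunks shown.

import asyncio

class EngineDemo:
    """Toy stand-in for ApiTask, showing only the flag-based engine selection."""

    def __init__(self, use_fastdeploy: str = "true"):
        self.use_fastdeploy = use_fastdeploy

    async def fastdeploy_call(self, prompt: str) -> str:
        await asyncio.sleep(0)  # placeholder for the streaming FastDeploy request
        return f"[fastdeploy] {prompt}"

    async def vllm_call(self, prompt: str) -> str:
        await asyncio.sleep(0)  # placeholder for the vLLM request
        return f"[vllm] {prompt}"

    async def group_call(self, prompt: str, num_responses: int) -> list:
        # Select the engine once per query, then launch all rollouts concurrently.
        call = self.fastdeploy_call if self.use_fastdeploy == "true" else self.vllm_call
        return await asyncio.gather(*[call(prompt) for _ in range(num_responses)])

if __name__ == "__main__":
    print(asyncio.run(EngineDemo(use_fastdeploy="false").group_call("hello", 2)))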