@@ -131,16 +131,16 @@ def put(self, values):
131131 self .text_index_cache [i ] += len (printable_text )
132132 output .append (printable_text )
133133 if any (output ):
134- self .text_queue .put (output , self . timeout )
134+ self .text_queue .put (output )
135135
136136 def end (self ):
137137 self .next_tokens_are_prompt = True
138138 output = []
139139 for i , tokens in enumerate (self .token_cache ):
140140 text = self .tokenizer .decode (tokens , ** self .decode_kwargs )
141141 output .append (text [self .text_index_cache [i ] :])
142- self .text_queue .put (output , self . timeout )
143- self .text_queue .put (self .stop_signal , self . timeout )
142+ self .text_queue .put (output )
143+ self .text_queue .put (self .stop_signal )
144144
145145 def __iter__ (self ):
146146 return self
@@ -264,12 +264,13 @@ def __init__(self, model_name, **kwargs):
264264 if self .tokenizer .pad_token is None :
265265 self .tokenizer .pad_token = self .tokenizer .eos_token
266266
267- def stream (self , input , ** kwargs ):
267+ def stream (self , input , timeout = None , ** kwargs ):
268268 streamer = None
269269 generation_kwargs = None
270270 if self .task == "conversational" :
271271 streamer = TextIteratorStreamer (
272272 self .tokenizer ,
273+ timeout = timeout ,
273274 skip_prompt = True ,
274275 )
275276 if "chat_template" in kwargs :
@@ -286,7 +287,10 @@ def stream(self, input, **kwargs):
286287 input = self .tokenizer (input , return_tensors = "pt" ).to (self .model .device )
287288 generation_kwargs = dict (input , streamer = streamer , ** kwargs )
288289 else :
289- streamer = TextIteratorStreamer (self .tokenizer )
290+ streamer = TextIteratorStreamer (
291+ self .tokenizer ,
292+ timeout = timeout ,
293+ )
290294 input = self .tokenizer (input , return_tensors = "pt" , padding = True ).to (
291295 self .model .device
292296 )
@@ -355,7 +359,7 @@ def create_pipeline(task):
355359 return pipe
356360
357361
358- def transform_using (pipeline , args , inputs , stream = False ):
362+ def transform_using (pipeline , args , inputs , stream = False , timeout = None ):
359363 args = orjson .loads (args )
360364 inputs = orjson .loads (inputs )
361365
@@ -364,7 +368,7 @@ def transform_using(pipeline, args, inputs, stream=False):
364368 convert_eos_token (pipeline .tokenizer , args )
365369
366370 if stream :
367- return pipeline .stream (inputs , ** args )
371+ return pipeline .stream (inputs , timeout = timeout , ** args )
368372 return orjson .dumps (pipeline (inputs , ** args ), default = orjson_default ).decode ()
369373
370374
0 commit comments