From 8ff0bf5271b8e043c64213df502d81e8969501e5 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 22 Mar 2024 10:01:27 -0700 Subject: [PATCH] Make max_new_tokens optional, default to max_total_tokens - input_length (#353) --- clients/python/lorax/client.py | 16 +++++----- clients/python/lorax/types.py | 2 +- docs/reference/openapi.json | 4 +-- router/src/lib.rs | 22 ++++--------- router/src/validation.rs | 57 ++++++++++++++++++++-------------- 5 files changed, 50 insertions(+), 51 deletions(-) diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index 18bbf8085..46627c3d4 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -67,7 +67,7 @@ def generate( merged_adapters: Optional[MergedAdapters] = None, api_token: Optional[str] = None, do_sample: bool = False, - max_new_tokens: int = 20, + max_new_tokens: Optional[int] = None, ignore_eos_token: bool = False, best_of: Optional[int] = None, repetition_penalty: Optional[float] = None, @@ -101,7 +101,7 @@ def generate( API token for accessing private adapters do_sample (`bool`): Activate logits sampling - max_new_tokens (`int`): + max_new_tokens (`Optional[int]`): Maximum number of generated tokens ignore_eos_token (`bool`): Whether to ignore EOS tokens during generation @@ -201,7 +201,7 @@ def generate_stream( merged_adapters: Optional[MergedAdapters] = None, api_token: Optional[str] = None, do_sample: bool = False, - max_new_tokens: int = 20, + max_new_tokens: Optional[int] = None, ignore_eos_token: bool = False, repetition_penalty: Optional[float] = None, return_full_text: bool = False, @@ -232,7 +232,7 @@ def generate_stream( API token for accessing private adapters do_sample (`bool`): Activate logits sampling - max_new_tokens (`int`): + max_new_tokens (`Optional[int]`): Maximum number of generated tokens ignore_eos_token (`bool`): Whether to ignore EOS tokens during generation @@ -388,7 +388,7 @@ async def generate( merged_adapters: Optional[MergedAdapters] 
= None, api_token: Optional[str] = None, do_sample: bool = False, - max_new_tokens: int = 20, + max_new_tokens: Optional[int] = None, ignore_eos_token: bool = False, best_of: Optional[int] = None, repetition_penalty: Optional[float] = None, @@ -422,7 +422,7 @@ async def generate( API token for accessing private adapters do_sample (`bool`): Activate logits sampling - max_new_tokens (`int`): + max_new_tokens (`Optional[int]`): Maximum number of generated tokens ignore_eos_token (`bool`): Whether to ignore EOS tokens during generation @@ -517,7 +517,7 @@ async def generate_stream( merged_adapters: Optional[MergedAdapters] = None, api_token: Optional[str] = None, do_sample: bool = False, - max_new_tokens: int = 20, + max_new_tokens: Optional[int] = None, ignore_eos_token: bool = False, repetition_penalty: Optional[float] = None, return_full_text: bool = False, @@ -550,7 +550,7 @@ async def generate_stream( API token for accessing private adapters do_sample (`bool`): Activate logits sampling - max_new_tokens (`int`): + max_new_tokens (`Optional[int]`): Maximum number of generated tokens ignore_eos_token (`bool`): Whether to ignore EOS tokens during generation diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py index 2b9aa32fd..e00dcdf42 100644 --- a/clients/python/lorax/types.py +++ b/clients/python/lorax/types.py @@ -79,7 +79,7 @@ class Parameters(BaseModel): # Activate logits sampling do_sample: bool = False # Maximum number of generated tokens - max_new_tokens: int = 20 + max_new_tokens: Optional[int] = None # Whether to ignore the EOS token during generation ignore_eos_token: bool = False # The parameter for repetition penalty. 1.0 means no penalty. 
diff --git a/docs/reference/openapi.json b/docs/reference/openapi.json index 51130c4e7..d1997b5d6 100644 --- a/docs/reference/openapi.json +++ b/docs/reference/openapi.json @@ -745,9 +745,9 @@ "max_new_tokens": { "type": "integer", "format": "int32", - "default": "20", + "default": "null", + "nullable": true, "minimum": 0.0, - "exclusiveMaximum": 512.0, "exclusiveMinimum": 0.0 }, "ignore_eos_token": { diff --git a/router/src/lib.rs b/router/src/lib.rs index b218b3df9..701e6ff1c 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -220,9 +220,9 @@ pub(crate) struct GenerateParameters { #[serde(default)] #[schema(default = "false", example = true)] pub do_sample: bool, - #[serde(default = "default_max_new_tokens")] - #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")] - pub max_new_tokens: u32, + #[serde(default)] + #[schema(exclusive_minimum = 0, default = "null")] + pub max_new_tokens: Option<u32>, #[serde(default)] #[schema(default = "false", example = true)] pub ignore_eos_token: bool, @@ -267,10 +267,6 @@ pub(crate) struct GenerateParameters { pub response_format: Option<ResponseFormat>, } -fn default_max_new_tokens() -> u32 { - 20 -} - fn default_parameters() -> GenerateParameters { GenerateParameters { adapter_id: None, @@ -284,7 +280,7 @@ fn default_parameters() -> GenerateParameters { top_p: None, typical_p: None, do_sample: false, - max_new_tokens: default_max_new_tokens(), + max_new_tokens: None, ignore_eos_token: false, return_full_text: None, stop: Vec::new(), @@ -621,10 +617,7 @@ impl From<CompletionRequest> for CompatGenerateRequest { top_p: req.top_p, typical_p: None, do_sample: !req.n.is_none(), - 
max_new_tokens: req - .max_tokens - .map(|x| x as u32) - .unwrap_or(default_max_new_tokens()), + max_new_tokens: req.max_tokens.map(|x| x as u32), ignore_eos_token: req.ignore_eos_token.unwrap_or(false), return_full_text: None, stop: req.stop, diff --git a/router/src/validation.rs b/router/src/validation.rs index 1cf9130b5..0a7c5f7f5 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -65,7 +65,7 @@ impl Validation { &self, inputs: String, truncate: Option<usize>, - max_new_tokens: u32, + max_new_tokens: Option<u32>, ) -> Result<(String, usize), ValidationError> { // If we have a fast tokenizer if let Some(sender) = &self.sender { @@ -81,16 +81,18 @@ impl Validation { // Unwrap is safe here let (inputs, input_length) = response_receiver.await.unwrap()?; - // Get total tokens - let total_tokens = input_length + max_new_tokens as usize; - - // Validate MaxTotalTokens - if total_tokens > self.max_total_tokens { - return Err(ValidationError::MaxTotalTokens( - self.max_total_tokens, - input_length, - max_new_tokens, - )); + if let Some(max_new_tokens) = max_new_tokens { + // Get total tokens + let total_tokens = input_length + max_new_tokens as usize; + + // Validate MaxTotalTokens + if total_tokens > self.max_total_tokens { + return Err(ValidationError::MaxTotalTokens( + self.max_total_tokens, + input_length, + max_new_tokens, + )); + } } // Validate InputLength @@ -111,12 +113,13 @@ impl Validation { // We make sure that truncate + max_new_tokens <= self.max_total_tokens let input_length = truncate.unwrap_or(self.max_input_length); - // Validate MaxNewTokens - if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 { - return Err(ValidationError::MaxNewTokens( - self.max_total_tokens - self.max_input_length, - max_new_tokens, - )); + if let Some(max_new_tokens) = max_new_tokens { + if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 { + return Err(ValidationError::MaxNewTokens( + self.max_total_tokens - 
self.max_input_length, + max_new_tokens, + )); + } } Ok((inputs, input_length)) @@ -231,7 +234,7 @@ impl Validation { }) .unwrap_or(Ok(0))?; - if max_new_tokens == 0 { + if max_new_tokens.is_some() && max_new_tokens.unwrap() == 0 { return Err(ValidationError::NegativeMaxNewTokens); } @@ -294,13 +297,19 @@ impl Validation { schema, return_k_alternatives, }; + + let effective_max_new_tokens = + max_new_tokens.unwrap_or((self.max_total_tokens - input_length) as u32); let stopping_parameters = StoppingCriteriaParameters { - max_new_tokens, + max_new_tokens: effective_max_new_tokens, stop_sequences, ignore_eos_token, }; - metrics::histogram!("lorax_request_max_new_tokens", max_new_tokens as f64); + metrics::histogram!( + "lorax_request_max_new_tokens", + effective_max_new_tokens as f64 + ); Ok(ValidGenerateRequest { inputs, @@ -461,7 +470,7 @@ mod tests { max_total_tokens, ); - let max_new_tokens = 10; + let max_new_tokens = Some(10); match validation .validate_input("Hello".to_string(), None, max_new_tokens) .await @@ -488,7 +497,7 @@ mod tests { max_total_tokens, ); - let max_new_tokens = 10; + let max_new_tokens = Some(10); match validation .validate_input("Hello".to_string(), None, max_new_tokens) .await @@ -588,7 +597,7 @@ mod tests { inputs: "Hello".to_string(), parameters: GenerateParameters { top_p: Some(0.99), - max_new_tokens: 1, + max_new_tokens: Some(1), ..default_parameters() }, }, @@ -614,7 +623,7 @@ mod tests { inputs: "Hello".to_string(), parameters: GenerateParameters { top_p: None, - max_new_tokens: 1, + max_new_tokens: Some(1), ..default_parameters() }, },