diff --git a/dsp/modules/gpt3.py b/dsp/modules/gpt3.py index bbcd1e3b6e..fe1b9feade 100644 --- a/dsp/modules/gpt3.py +++ b/dsp/modules/gpt3.py @@ -44,6 +44,8 @@ class GPT3(LM): api_key (Optional[str], optional): API provider Authentication token. use Defaults to None. api_provider (Literal["openai"], optional): The API provider to use. Defaults to "openai". model_type (Literal["chat", "text"], optional): The type of model that was specified. Mainly to decide the optimal prompting strategy. Defaults to "text". + max_tokens (int, optional): The maximum token budget for a request; at most + half of this value is used for the returned completion. Defaults to 150. **kwargs: Additional arguments to pass to the API provider. """ @@ -55,6 +57,7 @@ def __init__( api_base: Optional[str] = None, model_type: Literal["chat", "text"] = None, system_prompt: Optional[str] = None, + max_tokens: Optional[int] = 150, **kwargs, ): super().__init__(model) @@ -86,7 +89,7 @@ def __init__( self.kwargs = { "temperature": 0.0, - "max_tokens": 150, + "max_tokens": max_tokens, "top_p": 1, "frequency_penalty": 0, "presence_penalty": 0, diff --git a/dsp/primitives/predict.py b/dsp/primitives/predict.py index ffd4f8caa3..85784bed65 100644 --- a/dsp/primitives/predict.py +++ b/dsp/primitives/predict.py @@ -111,6 +111,7 @@ def do_generate( max_tokens_key = "max_tokens" if "max_tokens" in keys else "max_output_tokens" new_kwargs = { **kwargs, + # The provider-specific key varies (max_tokens vs max_output_tokens); set whichever applies. max_tokens_key: max_tokens, "n": 1, "temperature": 0.0,