diff --git a/examples/amazon-bedrock/promptfooconfig.claude.yaml b/examples/amazon-bedrock/promptfooconfig.claude.yaml index 86c8195564..c18c019434 100644 --- a/examples/amazon-bedrock/promptfooconfig.claude.yaml +++ b/examples/amazon-bedrock/promptfooconfig.claude.yaml @@ -2,7 +2,7 @@ prompts: - 'Convert this English to {{language}}: {{input}}' - 'Translate to {{language}}: {{input}}' providers: - - id: bedrock:completion:anthropic.claude-instant-v1 + - model: bedrock:completion:anthropic.claude-instant-v1 config: region: 'us-west-2' temperature: 0.7 diff --git a/examples/amazon-bedrock/promptfooconfig.titan-text.yaml b/examples/amazon-bedrock/promptfooconfig.titan-text.yaml index c18307c648..413c616b1b 100644 --- a/examples/amazon-bedrock/promptfooconfig.titan-text.yaml +++ b/examples/amazon-bedrock/promptfooconfig.titan-text.yaml @@ -2,7 +2,7 @@ prompts: - 'Convert this English to {{language}}: {{input}}' - 'Translate to {{language}}: {{input}}' providers: - - id: bedrock:completion:amazon.titan-text-lite-v1 + - model: bedrock:completion:amazon.titan-text-lite-v1 config: region: 'us-west-2' textGenerationConfig: diff --git a/examples/azure-openai/promptfooconfig.yaml b/examples/azure-openai/promptfooconfig.yaml index 254c12406d..d6e246785a 100644 --- a/examples/azure-openai/promptfooconfig.yaml +++ b/examples/azure-openai/promptfooconfig.yaml @@ -2,7 +2,7 @@ prompts: - 'Generate one very interesting fact about {{topic}}' providers: - - id: azureopenai:chat:gpt-35-turbo-deployment1 + - model: azureopenai:chat:gpt-35-turbo-deployment1 config: apiHost: 'your-org.openai.azure.com' @@ -20,6 +20,6 @@ tests: - type: similar value: Bananas are naturally radioactive. provider: - id: azureopenai:embeddings:ada-deployment1 + model: azureopenai:embeddings:ada-deployment1 config: apiHost: 'your-org.openai.azure.com' diff --git a/examples/cohere/simple_config.yaml b/examples/cohere/simple_config.yaml index 034839d322..796c0ef03e 100644 --- a/examples/cohere/simple_config.yaml +++ b/examples/cohere/simple_config.yaml @@ -2,12 +2,12 @@ prompts: - "Write a tweet about {{topic}}" providers: - - id: cohere:command + - model: cohere:command config: temperature: 0.5 prompt_truncation: AUTO connectors: - - id: web-search + - model: web-search showSearchQueries: true tests: diff --git a/examples/custom-provider/promptfooconfig.yaml b/examples/custom-provider/promptfooconfig.yaml index 58a5fd82d5..7742b89c84 100644 --- a/examples/custom-provider/promptfooconfig.yaml +++ b/examples/custom-provider/promptfooconfig.yaml @@ -4,11 +4,11 @@ tests: vars.csv # To compare two of the same provider, you can do the following: # # providers: -# - customProvider.js: -# id: custom-provider-hightemp -# config: -# temperature: 1.0 -# - customProvider.js: -# id: custom-provider-lowtemp -# config: -# temperature: 0 +# - model: customProvider.js +# label: custom-provider-hightemp +# config: +# temperature: 1.0 +# - model: customProvider.js +# label: custom-provider-lowtemp +# config: +# temperature: 0 diff --git a/examples/external-provider-config/gpt-3.5.yaml b/examples/external-provider-config/gpt-3.5.yaml index 52ae9403a3..0e43de7f13 100644 --- a/examples/external-provider-config/gpt-3.5.yaml +++ b/examples/external-provider-config/gpt-3.5.yaml @@ -1,4 +1,4 @@ -id: 'openai:chat:gpt-3.5-turbo-0613' +model: 'openai:chat:gpt-3.5-turbo-0613' config: functions: [ diff --git a/examples/gemma-vs-llama/promptfooconfig.yaml b/examples/gemma-vs-llama/promptfooconfig.yaml index 46bb779b32..ffe265ab71 100644 --- 
a/examples/gemma-vs-llama/promptfooconfig.yaml +++ b/examples/gemma-vs-llama/promptfooconfig.yaml @@ -2,7 +2,7 @@ prompts: - "{{message}}" providers: - - id: replicate:meta/llama-2-7b-chat + - model: replicate:meta/llama-2-7b-chat config: temperature: 0.01 # minimum temperature max_new_tokens: 1024 @@ -10,7 +10,7 @@ providers: prefix: "[INST] " suffix: "[/INST] " - - id: replicate:cjwbw/gemma-7b-it:2790a695e5dcae15506138cc4718d1106d0d475e6dca4b1d43f42414647993d5 + - model: replicate:cjwbw/gemma-7b-it:2790a695e5dcae15506138cc4718d1106d0d475e6dca4b1d43f42414647993d5 config: temperature: 0.01 max_new_tokens: 1024 diff --git a/examples/gemma-vs-mistral/promptfooconfig.yaml b/examples/gemma-vs-mistral/promptfooconfig.yaml index b345557f5a..bd43512968 100644 --- a/examples/gemma-vs-mistral/promptfooconfig.yaml +++ b/examples/gemma-vs-mistral/promptfooconfig.yaml @@ -6,7 +6,7 @@ defaultTest: transform: output.trim() providers: - - id: replicate:mistralai/mistral-7b-instruct-v0.2 + - model: replicate:mistralai/mistral-7b-instruct-v0.2 config: temperature: 0.01 max_new_tokens: 1024 @@ -14,7 +14,7 @@ providers: prefix: "[INST] " suffix: " [/INST]" - - id: replicate:mistralai/mixtral-8x7b-instruct-v0.1 + - model: replicate:mistralai/mixtral-8x7b-instruct-v0.1 config: temperature: 0.01 max_new_tokens: 1024 @@ -22,7 +22,7 @@ providers: prefix: "[INST] " suffix: " [/INST]" - - id: replicate:cjwbw/gemma-7b-it:2790a695e5dcae15506138cc4718d1106d0d475e6dca4b1d43f42414647993d5 + - model: replicate:cjwbw/gemma-7b-it:2790a695e5dcae15506138cc4718d1106d0d475e6dca4b1d43f42414647993d5 config: temperature: 0.01 max_new_tokens: 1024 diff --git a/examples/google-aistudio-gemini/promptfooconfig.yaml b/examples/google-aistudio-gemini/promptfooconfig.yaml index bc45131cfa..472df7a71b 100644 --- a/examples/google-aistudio-gemini/promptfooconfig.yaml +++ b/examples/google-aistudio-gemini/promptfooconfig.yaml @@ -3,7 +3,7 @@ prompts: - "Write a very concise, funny tweet about {{topic}}" providers: - - id: google:gemini-pro + - model: google:gemini-pro config: generationConfig: temperature: 0 diff --git a/examples/gpt-3.5-temperature-comparison/promptfooconfig.yaml b/examples/gpt-3.5-temperature-comparison/promptfooconfig.yaml index 191e3618db..751d203bb2 100644 --- a/examples/gpt-3.5-temperature-comparison/promptfooconfig.yaml +++ b/examples/gpt-3.5-temperature-comparison/promptfooconfig.yaml @@ -2,14 +2,14 @@ prompts: - 'Respond to the following instruction: {{message}}' providers: - - openai:gpt-3.5-turbo-0613: - id: openai-gpt-3.5-turbo-lowtemp - config: - temperature: 0 - - openai:gpt-3.5-turbo-0613: - id: openai-gpt-3.5-turbo-hightemp - config: - temperature: 1 + - model: openai:gpt-3.5-turbo-0613 + label: openai-gpt-3.5-turbo-lowtemp + config: + temperature: 0 + - model: openai:gpt-3.5-turbo-0613 + label: openai-gpt-3.5-turbo-hightemp + config: + temperature: 1 tests: - vars: diff --git a/examples/huggingface-inference-endpoint/promptfooconfig.yaml b/examples/huggingface-inference-endpoint/promptfooconfig.yaml index 5b87cfea0c..ebf8e43e96 100644 --- a/examples/huggingface-inference-endpoint/promptfooconfig.yaml +++ b/examples/huggingface-inference-endpoint/promptfooconfig.yaml @@ -4,7 +4,7 @@ prompts: - "Write a tweet about {{topic}}:" providers: - - id: huggingface:text-generation:gemma-7b-it + - model: huggingface:text-generation:gemma-7b-it config: apiEndpoint: https://v9igsezez4ei3cq4.us-east-1.aws.endpoints.huggingface.cloud # apiKey: abc123 # Or set HF_API_TOKEN environment variable diff --git 
a/examples/llama-gpt-comparison/promptfooconfig.yaml b/examples/llama-gpt-comparison/promptfooconfig.yaml index 21daa8688e..6db6f97759 100644 --- a/examples/llama-gpt-comparison/promptfooconfig.yaml +++ b/examples/llama-gpt-comparison/promptfooconfig.yaml @@ -3,23 +3,27 @@ prompts: prompts/completion_prompt.txt: completion_prompt providers: - - openai:gpt-3.5-turbo-0613: - id: openai-gpt-3.5-turbo-lowtemp - prompts: chat_prompt - config: - temperature: 0 - max_tokens: 128 - - openai:gpt-3.5-turbo-0613: - id: openai-gpt-3.5-turbo-hightemp - prompts: chat_prompt - config: - temperature: 1 - max_tokens: 128 - - replicate:replicate/llama70b-v2-chat:e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48: - prompts: completion_prompt - config: - temperature: 0.01 # minimum temperature - max_length: 128 + - model: openai:gpt-3.5-turbo-0613 + label: openai-gpt-3.5-turbo-lowtemp + prompts: + - chat_prompt + config: + temperature: 0 + max_tokens: 128 + - model: openai:gpt-3.5-turbo-0613 + label: openai-gpt-3.5-turbo-hightemp + prompts: + - chat_prompt + config: + temperature: 1 + max_tokens: 128 + - model: replicate:meta/llama70b-v2-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3 + label: llama70b-v2-chat + prompts: + - completion_prompt + config: + temperature: 0.01 # minimum temperature + max_length: 128 tests: - vars: diff --git a/examples/llama-gpt-comparison/prompts/chat_prompt.json b/examples/llama-gpt-comparison/prompts/chat_prompt.json index a7f9d83598..2d570ae2f0 100644 --- a/examples/llama-gpt-comparison/prompts/chat_prompt.json +++ b/examples/llama-gpt-comparison/prompts/chat_prompt.json @@ -1,4 +1,8 @@ [ + { + "role": "system", + "content": "You are a pirate" + }, { "role": "user", "content": "{{message}}" diff --git a/examples/llama-gpt-comparison/prompts/completion_prompt.txt b/examples/llama-gpt-comparison/prompts/completion_prompt.txt index 2876fa0449..aa5c61d591 100644 --- a/examples/llama-gpt-comparison/prompts/completion_prompt.txt +++ b/examples/llama-gpt-comparison/prompts/completion_prompt.txt @@ -1,2 +1,4 @@ -User: {{message}} -Assistant: +[INST] <> +You are a pirate +<> +{{message}} \ No newline at end of file diff --git a/examples/mistral-llama-comparison/promptfooconfig.yaml b/examples/mistral-llama-comparison/promptfooconfig.yaml index 91dd721ad4..9eb7e0f5b7 100644 --- a/examples/mistral-llama-comparison/promptfooconfig.yaml +++ b/examples/mistral-llama-comparison/promptfooconfig.yaml @@ -3,22 +3,25 @@ prompts: prompts/llama_prompt.txt: llama_prompt providers: - - huggingface:text-generation:mistralai/Mistral-7B-Instruct-v0.1: - prompts: mistral_prompt - config: - temperature: 0.01 - max_new_tokens: 128 - - replicate:mistralai/mixtral-8x7b-instruct-v0.1:2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e: - prompts: mistral_prompt - config: - temperature: 0.01 - max_new_tokens: 128 - prompt_template: '{prompt}' - - replicate:meta/llama-2-7b-chat:8e6975e5ed6174911a6ff3d60540dfd4844201974602551e10e9e87ab143d81e: - prompts: llama_prompt - config: - temperature: 0.01 - max_new_tokens: 128 + - model: huggingface:text-generation:mistralai/Mistral-7B-Instruct-v0.1 + prompts: + - mistral_prompt + config: + temperature: 0.01 + max_new_tokens: 128 + - model: replicate:mistralai/mixtral-8x7b-instruct-v0.1 + prompts: + - mistral_prompt + config: + temperature: 0.01 + max_new_tokens: 128 + prompt_template: '{prompt}' + - model: replicate:meta/llama-2-7b-chat:2d19859030ff705a87c746f7e96eea03aefb71f166725aee39692f1476566d48 + prompts: + - 
llama_prompt + config: + temperature: 0.01 + max_new_tokens: 128 tests: - vars: diff --git a/examples/ollama-comparison/promptfooconfig.yaml b/examples/ollama-comparison/promptfooconfig.yaml index cab5a8ed9c..21e5fbdeb8 100644 --- a/examples/ollama-comparison/promptfooconfig.yaml +++ b/examples/ollama-comparison/promptfooconfig.yaml @@ -3,16 +3,18 @@ prompts: prompts/llama_prompt.txt: llama_prompt providers: - - ollama:llama2: - prompts: llama_prompt - config: - num_predict: 1024 - - ollama:llama2-uncensored: - prompts: llama_prompt - config: - num_predict: 1024 - - openai:gpt-3.5-turbo: - prompts: openai_prompt + - id: ollama:llama2 + prompts: + - llama_prompt + config: + num_predict: 1024 + - id: ollama:llama2-uncensored + prompts: + - llama_prompt + config: + num_predict: 1024 + - id: openai:gpt-3.5-turbo + prompts: openai_prompt defaultTest: assert: diff --git a/examples/perplexity.ai-example/promptfooconfig.yaml b/examples/perplexity.ai-example/promptfooconfig.yaml index 59f55c378f..cb42db2471 100644 --- a/examples/perplexity.ai-example/promptfooconfig.yaml +++ b/examples/perplexity.ai-example/promptfooconfig.yaml @@ -3,8 +3,8 @@ prompts: providers: # Compare GPT 3.5 vs Perplexity 70B - - id: openai:chat:gpt-3.5-turbo-0613 - - id: openai:chat:pplx-70b-online + - model: openai:chat:gpt-3.5-turbo-0613 + - model: openai:chat:pplx-70b-online config: apiBaseUrl: https://api.perplexity.ai apiKeyEnvar: PERPLEXITY_API_KEY diff --git a/examples/python-provider/promptfooconfig.yaml b/examples/python-provider/promptfooconfig.yaml index 8f5144eca0..fc0dde64a1 100644 --- a/examples/python-provider/promptfooconfig.yaml +++ b/examples/python-provider/promptfooconfig.yaml @@ -3,7 +3,7 @@ prompts: - "Write a very concise, funny tweet about {{topic}}" providers: - - id: python:provider.py # or provider_async.py + - model: python:provider.py # or provider_async.py config: someOption: foobar diff --git a/examples/replicate-lifeboat/promptfooconfig.yaml b/examples/replicate-lifeboat/promptfooconfig.yaml index cbbe6d4f33..84779148e0 100644 --- a/examples/replicate-lifeboat/promptfooconfig.yaml +++ b/examples/replicate-lifeboat/promptfooconfig.yaml @@ -2,12 +2,12 @@ prompts: - 'Respond to the user concisely: {{message}}' providers: - - id: openai:chat:gpt-3.5-turbo + - model: openai:chat:gpt-3.5-turbo config: apiKey: '...' temperature: 0.01 max_tokens: 512 - - id: openai:chat:meta/llama-2-70b-chat + - model: openai:chat:meta/llama-2-70b-chat config: apiKey: '...' apiBaseUrl: https://openai-proxy.replicate.com diff --git a/site/docs/configuration/expected-outputs/model-graded.md b/site/docs/configuration/expected-outputs/model-graded.md index 3ac6a9fa8c..68b1f12c2b 100644 --- a/site/docs/configuration/expected-outputs/model-graded.md +++ b/site/docs/configuration/expected-outputs/model-graded.md @@ -115,7 +115,7 @@ tests: ## Examples (comparison) -The `select-best` assertion type is used to compare multiple outputs in the same TestCase row and select the one that best meets a specified criterion. +The `select-best` assertion type is used to compare multiple outputs in the same TestCase row and select the one that best meets a specified criterion. 
Here's an example of how to use `select-best` in a configuration file: @@ -178,7 +178,7 @@ Use the `provider.config` field to set custom parameters: ```yaml provider: - - id: openai:gpt-3.5-turbo + - model: openai:gpt-3.5-turbo config: temperature: 0 ``` diff --git a/site/docs/configuration/parameters.md b/site/docs/configuration/parameters.md index 924b027209..4e0acd8fee 100644 --- a/site/docs/configuration/parameters.md +++ b/site/docs/configuration/parameters.md @@ -131,12 +131,12 @@ prompts: prompts/llama_completion_prompt.txt: llama_completion_prompt providers: - - openai:gpt-3.5-turbo-0613: - prompts: gpt_chat_prompt - - openai:gpt-4-turbo-0613: - prompts: gpt_chat_prompt - - replicate:replicate/llama70b-v2-chat:e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48: - prompts: llama_completion_prompt + - model: openai:gpt-3.5-turbo-0613 + prompts: gpt_chat_prompt + - model: openai:gpt-4-turbo-0613 + prompts: gpt_chat_prompt + - model: replicate:replicate/llama70b-v2-chat:e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48 + prompts: llama_completion_prompt ``` In this configuration, the `gpt_chat_prompt` is used for both GPT-3.5 and GPT-4 models, while the `llama_completion_prompt` is used for the Llama v2 model. The prompts are defined in separate files within the `prompts` directory. diff --git a/site/docs/guides/azure-vs-openai.md b/site/docs/guides/azure-vs-openai.md index bf27749a72..8b6c9077b9 100644 --- a/site/docs/guides/azure-vs-openai.md +++ b/site/docs/guides/azure-vs-openai.md @@ -45,8 +45,8 @@ Edit your `promptfooconfig.yaml` to include both OpenAI and Azure OpenAI as prov ```yaml providers: - - id: openai:chat:gpt-3.5-turbo - - id: azureopenai:chat:my-gpt-35-turbo-deployment + - model: openai:chat:gpt-3.5-turbo + - model: azureopenai:chat:my-gpt-35-turbo-deployment config: apiHost: myazurehost.openai.azure.com ``` @@ -59,11 +59,11 @@ For each provider, you may configure additional parameters such as `temperature` ```yaml providers: - - id: openai:chat:gpt-3.5-turbo + - model: openai:chat:gpt-3.5-turbo config: temperature: 0 max_tokens: 128 - - id: azureopenai:chat:my-gpt-35-turbo-deployment + - model: azureopenai:chat:my-gpt-35-turbo-deployment config: apiHost: your_azure_openai_host temperature: 0 diff --git a/site/docs/guides/compare-llama2-vs-gpt.md b/site/docs/guides/compare-llama2-vs-gpt.md index 949ef5feb9..a933f53d04 100644 --- a/site/docs/guides/compare-llama2-vs-gpt.md +++ b/site/docs/guides/compare-llama2-vs-gpt.md @@ -69,12 +69,12 @@ prompts: prompts/completion_prompt.txt: completion_prompt providers: - - openai:gpt-3.5-turbo-0613: - prompts: chat_prompt - - openai:gpt-4-0613: - prompts: chat_prompt - - replicate:replicate/llama70b-v2-chat:e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48: - prompts: completion_prompt + - model: openai:gpt-3.5-turbo-0613 + prompts: chat_prompt + - model: openai:gpt-4-0613 + prompts: chat_prompt + - model: replicate:replicate/llama70b-v2-chat:e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48 + prompts: completion_prompt ``` :::info @@ -161,27 +161,27 @@ Each model has a `config` field where you can specify additional parameters. 
Let ```yaml title=promptfooconfig.yaml providers: - - openai:gpt-3.5-turbo-0613: - prompts: chat_prompt - // highlight-start - config: - temperature: 0 - max_tokens: 128 - // highlight-end - - openai:gpt-4-0613: - prompts: chat_prompt - // highlight-start - config: - temperature: 0 - max_tokens: 128 - // highlight-end - - replicate:replicate/llama70b-v2-chat:e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48: - prompts: completion_prompt - // highlight-start - config: - temperature: 0.01 # minimum temperature - max_length: 128 - // highlight-end + - model: openai:gpt-3.5-turbo-0613 + prompts: chat_prompt + // highlight-start + config: + temperature: 0 + max_tokens: 128 + // highlight-end + - model: openai:gpt-4-0613 + prompts: chat_prompt + // highlight-start + config: + temperature: 0 + max_tokens: 128 + // highlight-end + - model: replicate:replicate/llama70b-v2-chat:e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48 + prompts: completion_prompt + // highlight-start + config: + temperature: 0.01 # minimum temperature + max_length: 128 + // highlight-end ``` Here's what each parameter means: diff --git a/site/docs/guides/evaluate-llm-temperature.md b/site/docs/guides/evaluate-llm-temperature.md index fd27642ac5..deb12710eb 100644 --- a/site/docs/guides/evaluate-llm-temperature.md +++ b/site/docs/guides/evaluate-llm-temperature.md @@ -4,11 +4,11 @@ The `temperature`` setting in language models is like a dial that adjusts how pr In general, a low temperature leads to "safer", more expected words, while a higher temperature encourages the model to choose less obvious words. This is why higher temperature is commonly associated with more creative outputs. -Under the hood, `temperature` adjusts how the model calculates the likelihood of each word it might pick next. +Under the hood, `temperature` adjusts how the model calculates the likelihood of each word it might pick next. The `temperature` parameter affects each output token by scaling the logits (the raw output scores from the model) before they are passed through the softmax function that turns them into probabilities. Lower temperatures sharpen the distinction between high and low scores, making the high scores more dominant, while higher temperatures flatten this distinction, giving lower-scoring words a better chance of being chosen. -## Finding the optimal temperature +## Finding the optimal temperature The best way to find the optimal temperature parameter is to run a systematic *evaluation*. @@ -40,11 +40,11 @@ prompts: providers: - openai:gpt-3.5-turbo-0613: - id: openai-gpt-3.5-turbo-lowtemp + model: openai-gpt-3.5-turbo-lowtemp config: temperature: 0.2 - openai:gpt-3.5-turbo-0613: - id: openai-gpt-3.5-turbo-hightemp + model: openai-gpt-3.5-turbo-hightemp config: temperature: 0.9 @@ -71,7 +71,7 @@ To run the evaluation, use the following command: npx promptfoo@latest eval ``` -This command shows the outputs side-by-side in the command line. +This command shows the outputs side-by-side in the command line. 
## Adding automated checks @@ -131,13 +131,13 @@ Set a constant seed in the provider config: ```yaml providers: - openai:gpt-3.5-turbo-0613: - id: openai-gpt-3.5-turbo-lowtemp + model: openai-gpt-3.5-turbo-lowtemp config: temperature: 0.2 // highlight-next-line seed: 0 - openai:gpt-3.5-turbo-0613: - id: openai-gpt-3.5-turbo-hightemp + model: openai-gpt-3.5-turbo-hightemp config: temperature: 0.9 // highlight-next-line @@ -150,4 +150,4 @@ The `eval` command also has a parameter, `repeat`, which runs each test multiple promptfoo eval --repeat 3 ``` -The above command runs the LLM three times for each test case, helping you get a more complete sample of how it performs at a given temperature. \ No newline at end of file +The above command runs the LLM three times for each test case, helping you get a more complete sample of how it performs at a given temperature. diff --git a/site/docs/guides/evaluate-openai-assistants.md b/site/docs/guides/evaluate-openai-assistants.md index a1214464f2..75e323ada0 100644 --- a/site/docs/guides/evaluate-openai-assistants.md +++ b/site/docs/guides/evaluate-openai-assistants.md @@ -74,7 +74,7 @@ If you want to override the configuration of an assistant for a specific test, y ```yaml providers: - - id: openai:assistant:asst_fEhNN3MClMamLfKLkIaoIpgB + - model: openai:assistant:asst_fEhNN3MClMamLfKLkIaoIpgB config: model: gpt-4-1106-preview instructions: 'Enter a replacement for system-level instructions here' @@ -96,10 +96,10 @@ Here's an example that compares the saved Assistant settings against new potenti ```yaml providers: # Original - - id: openai:assistant:asst_fEhNN3MClMamLfKLkIaoIpgB + - model: openai:assistant:asst_fEhNN3MClMamLfKLkIaoIpgB # Modified - - id: openai:assistant:asst_fEhNN3MClMamLfKLkIaoIpgB + - model: openai:assistant:asst_fEhNN3MClMamLfKLkIaoIpgB config: model: gpt-4-1106-preview instructions: 'Always talk like a pirate' diff --git a/site/docs/guides/evaluate-replicate-lifeboat.md b/site/docs/guides/evaluate-replicate-lifeboat.md index 49489bc44f..661ce69d4c 100644 --- a/site/docs/guides/evaluate-replicate-lifeboat.md +++ b/site/docs/guides/evaluate-replicate-lifeboat.md @@ -33,12 +33,12 @@ prompts: - 'Respond to the user concisely: {{message}}' providers: - - id: openai:chat:gpt-3.5-turbo + - model: openai:chat:gpt-3.5-turbo config: apiKey: 'your_openai_api_key' temperature: 0.01 max_tokens: 512 - - id: openai:chat:meta/llama-2-70b-chat + - model: openai:chat:meta/llama-2-70b-chat config: apiKey: 'your_replicate_api_key' apiBaseUrl: https://openai-proxy.replicate.com diff --git a/site/docs/guides/gemini-vs-gpt.md b/site/docs/guides/gemini-vs-gpt.md index 88a51b39af..aee83875f4 100644 --- a/site/docs/guides/gemini-vs-gpt.md +++ b/site/docs/guides/gemini-vs-gpt.md @@ -55,11 +55,11 @@ prompts: prompts/gemini_prompt.json: gemini_prompt providers: - - id: vertex:gemini-pro + - model: vertex:gemini-pro prompts: gemini_prompt - - id: openai:gpt-3.5-turbo + - model: openai:gpt-3.5-turbo prompts: gpt_prompt - - id: openai:gpt-4 + - model: openai:gpt-4 prompts: gpt_prompt ``` @@ -99,7 +99,7 @@ npx promptfoo@latest view ## Step 5: Add automatic evals (optional) -Automatic evals are a nice way to scale your work, so you don't need to check each outputs every time. +Automatic evals are a nice way to scale your work, so you don't need to check each outputs every time. To add automatic evaluations to your test cases, you'll include assertions in your test cases. 
Assertions are conditions that the output of the language model must meet for the test case to be considered successful. Here's how you can add them: @@ -150,4 +150,4 @@ In our tiny eval, we observed that GPT 3.5 and Gemini Pro had similar failure mo **The key here is that your results may vary based on your LLM needs, so I encourage you to enter your own test cases and choose the model that is best for you.** -See the [getting started guide](/docs/getting-started) to begin! \ No newline at end of file +See the [getting started guide](/docs/getting-started) to begin! diff --git a/site/docs/guides/gemma-vs-llama.md b/site/docs/guides/gemma-vs-llama.md index 4f4e01a5bd..fe8cb10c37 100644 --- a/site/docs/guides/gemma-vs-llama.md +++ b/site/docs/guides/gemma-vs-llama.md @@ -27,7 +27,7 @@ Let's start by creating a new directory for our eval: npx promptfoo@latest init gemma-vs-llama ``` -`cd gemma-vs-llama` and begin editing `promptfooconfig.yaml`. +`cd gemma-vs-llama` and begin editing `promptfooconfig.yaml`. This config is where you define how you will interact with the Gemma and Llama models. It includes details such as the models you're comparing, the parameters for generating responses, and the format of your prompts. @@ -40,7 +40,7 @@ prompts: - "{{message}}" ``` -Each prompt in this list will be run through both Gemma and Llama. +Each prompt in this list will be run through both Gemma and Llama. You should modify this prompt to match the use case you want to test. For example: @@ -57,7 +57,7 @@ The next section of the configuration file deals with the providers, which in th ##### Llama Configuration ```yaml -- id: replicate:meta/llama-2-7b-chat +- model: replicate:meta/llama-2-7b-chat config: temperature: 0.01 max_new_tokens: 128 @@ -74,7 +74,7 @@ The next section of the configuration file deals with the providers, which in th ##### Gemma Configuration ```yaml -- id: replicate:cjwbw/gemma-7b-it:2790a695e5dcae15506138cc4718d1106d0d475e6dca4b1d43f42414647993d5 +- model: replicate:cjwbw/gemma-7b-it:2790a695e5dcae15506138cc4718d1106d0d475e6dca4b1d43f42414647993d5 config: temperature: 0.01 max_new_tokens: 128 @@ -96,7 +96,7 @@ prompts: - "{{message}}" providers: - - id: replicate:meta/llama-2-7b-chat + - model: replicate:meta/llama-2-7b-chat config: temperature: 0.01 max_new_tokens: 128 @@ -104,7 +104,7 @@ providers: prefix: "[INST] " suffix: "[/INST] " - - id: replicate:cjwbw/gemma-7b-it:2790a695e5dcae15506138cc4718d1106d0d475e6dca4b1d43f42414647993d5 + - model: replicate:cjwbw/gemma-7b-it:2790a695e5dcae15506138cc4718d1106d0d475e6dca4b1d43f42414647993d5 config: temperature: 0.01 max_new_tokens: 128 @@ -115,7 +115,7 @@ providers: ## Step 2: Defining Test Cases -Test cases are where you specify the inputs that will be fed to both models. This is your opportunity to compare how each model handles a variety of requests, from simple queries to complex reasoning tasks. +Test cases are where you specify the inputs that will be fed to both models. This is your opportunity to compare how each model handles a variety of requests, from simple queries to complex reasoning tasks. **_Modify these test cases to fit your needs_**. Here are some examples: @@ -205,12 +205,12 @@ npx promptfoo@latest view After running the evaluation, you'll have a dataset that compares the responses from Gemma and Llama across your test cases. Look for patterns in the results: -- Which model is more accurate or relevant in its responses? +- Which model is more accurate or relevant in its responses? 
- In our small example set, Llama was a little more likely to hallucinate., e.g. claiming to know the weather in New York. -- Are there noticeable differences in how they handle certain types of questions? +- Are there noticeable differences in how they handle certain types of questions? - It seems like Gemma is more likely to respond verbosely and include markdown formatting. - Llama has a weird habit of roleplaying (e.g. extra output such as `*adjusts glasses*`) and by default prefers to preface responses with "Of course!" Consider the implications of these results for your specific application or use case. Although Gemma outperforms Llama on generic test sets, you must create your own test set in order to really pick a winner! -To learn more about setting up promptfoo, see [Getting Started](/docs/getting-started) or our more detailed [Configuration Guide](/docs/configuration/guide). \ No newline at end of file +To learn more about setting up promptfoo, see [Getting Started](/docs/getting-started) or our more detailed [Configuration Guide](/docs/configuration/guide). diff --git a/site/docs/guides/gemma-vs-mistral.md b/site/docs/guides/gemma-vs-mistral.md index fa385d7d4f..0b52dd2157 100644 --- a/site/docs/guides/gemma-vs-mistral.md +++ b/site/docs/guides/gemma-vs-mistral.md @@ -53,7 +53,7 @@ Next, specify the models you're comparing by setting up their configurations: #### Mistral Configuration ```yaml -- id: replicate:mistralai/mistral-7b-instruct-v0.2 +- model: replicate:mistralai/mistral-7b-instruct-v0.2 config: temperature: 0.01 max_new_tokens: 1024 @@ -65,7 +65,7 @@ Next, specify the models you're comparing by setting up their configurations: #### Mixtral Configuration ```yaml -- id: replicate:mistralai/mixtral-8x7b-instruct-v0.1 +- model: replicate:mistralai/mixtral-8x7b-instruct-v0.1 config: temperature: 0.01 max_new_tokens: 1024 @@ -77,7 +77,7 @@ Next, specify the models you're comparing by setting up their configurations: #### Gemma Configuration ```yaml -- id: replicate:cjwbw/gemma-7b-it:2790a695e5dcae15506138cc4718d1106d0d475e6dca4b1d43f42414647993d5 +- model: replicate:cjwbw/gemma-7b-it:2790a695e5dcae15506138cc4718d1106d0d475e6dca4b1d43f42414647993d5 config: temperature: 0.01 max_new_tokens: 1024 @@ -95,7 +95,7 @@ prompts: - "{{message}}" providers: - - id: replicate:mistralai/mistral-7b-instruct-v0.2 + - model: replicate:mistralai/mistral-7b-instruct-v0.2 config: temperature: 0.01 max_new_tokens: 1024 @@ -103,7 +103,7 @@ providers: prefix: "[INST] " suffix: " [/INST]" - - id: replicate:mistralai/mixtral-8x7b-instruct-v0.1 + - model: replicate:mistralai/mixtral-8x7b-instruct-v0.1 config: temperature: 0.01 max_new_tokens: 1024 @@ -111,7 +111,7 @@ providers: prefix: "[INST] " suffix: " [/INST]" - - id: replicate:cjwbw/gemma-7b-it:2790a695e5dcae15506138cc4718d1106d0d475e6dca4b1d43f42414647993d5 + - model: replicate:cjwbw/gemma-7b-it:2790a695e5dcae15506138cc4718d1106d0d475e6dca4b1d43f42414647993d5 config: temperature: 0.01 max_new_tokens: 1024 @@ -122,7 +122,7 @@ providers: ## Step 2: Build a test set -Design test cases that reflect a variety of requests that are representative of your app's use case. +Design test cases that reflect a variety of requests that are representative of your app's use case. For this example, we're focusing on riddles to test the models' ability to understand and generate creative and logical responses. 
@@ -239,4 +239,4 @@ Here's what we noticed from our small riddle test set: When constructing your own test set, think about edge cases and unusual criteria that are specific to your app and may not be in model training data. Ideally, it's best to set up a feedback loop where real users of your app can flag failure cases. Use this to build your test set over time. -To learn more about setting up promptfoo, see [Getting Started](/docs/getting-started) or our more detailed [Configuration Guide](/docs/configuration/guide). \ No newline at end of file +To learn more about setting up promptfoo, see [Getting Started](/docs/getting-started) or our more detailed [Configuration Guide](/docs/configuration/guide). diff --git a/site/docs/guides/llama2-uncensored-benchmark-ollama.md b/site/docs/guides/llama2-uncensored-benchmark-ollama.md index 88047ab5f0..738b4ea83c 100644 --- a/site/docs/guides/llama2-uncensored-benchmark-ollama.md +++ b/site/docs/guides/llama2-uncensored-benchmark-ollama.md @@ -76,12 +76,14 @@ prompts: prompts/llama_prompt.txt: llama_prompt providers: - - ollama:llama2: - prompts: llama_prompt - - ollama:llama2-uncensored: - prompts: llama_prompt - - openai:gpt-3.5-turbo: - prompts: openai_prompt + - model: ollama:llama2 + prompts: + - llama_prompt + - model: ollama:llama2-uncensored + prompts: + - llama_prompt + - model: openai:gpt-3.5-turbo + prompts: openai_prompt ``` ## Add test cases @@ -120,12 +122,14 @@ prompts: prompts/llama_prompt.txt: llama_prompt providers: - - ollama:llama2: - prompts: llama_prompt - - ollama:llama2-uncensored: - prompts: llama_prompt - - openai:gpt-3.5-turbo: - prompts: openai_prompt + - model: ollama:llama2 + prompts: + - llama_prompt + - model: ollama:llama2-uncensored + prompts: + - llama_prompt + - model: openai:gpt-3.5-turbo + prompts: openai_prompt ``` Let's set up a few assertions to automatically assess the output for correctness. The `defaultTest` block is a shorthand that adds the `assert` to every test: @@ -171,13 +175,14 @@ prompts: llama_prompt.txt: llama_prompt providers: - - ollama:llama2: - prompts: llama_prompt - - ollama:llama2-uncensored: - prompts: llama_prompt - - openai:gpt-3.5-turbo: - prompts: openai_prompt - + - model: ollama:llama2 + prompts: + - llama_prompt + - model: ollama:llama2-uncensored + prompts: + - llama_prompt + - model: openai:gpt-3.5-turbo + prompts: openai_prompt ``` :::info diff --git a/site/docs/guides/mistral-vs-llama.md b/site/docs/guides/mistral-vs-llama.md index 79454996a4..96e4045cc4 100644 --- a/site/docs/guides/mistral-vs-llama.md +++ b/site/docs/guides/mistral-vs-llama.md @@ -65,12 +65,15 @@ prompts: prompts/llama_prompt.txt: llama_prompt providers: - - huggingface:text-generation:mistralai/Mistral-7B-Instruct-v0.1: - prompts: mistral_prompt - - replicate:mistralai/mixtral-8x7b-instruct-v0.1:2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e: - prompts: mistral prompt - - replicate:meta/llama-2-7b-chat:8e6975e5ed6174911a6ff3d60540dfd4844201974602551e10e9e87ab143d81e: - prompts: llama_prompt + - model: huggingface:text-generation:mistralai/Mistral-7B-Instruct-v0.1 + prompts: + - mistral_prompt + - model: replicate:mistralai/mixtral-8x7b-instruct-v0.1:2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e + prompts: + - mistral_prompt + - model: replicate:meta/llama-2-7b-chat:8e6975e5ed6174911a6ff3d60540dfd4844201974602551e10e9e87ab143d81e + prompts: + - llama_prompt ``` :::tip @@ -83,28 +86,31 @@ Each model has a `config` field where you can specify additional parameters. 
Let ```yaml title=promptfooconfig.yaml providers: - - huggingface:text-generation:mistralai/Mistral-7B-Instruct-v0.1: - prompts: mistral_prompt - // highlight-start - config: - temperature: 0.01 - max_new_tokens: 128 - // highlight-end - - replicate:mistralai/mixtral-8x7b-instruct-v0.1:2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e - prompts: mistral_prompt - // highlight-start - config: - prompt_template: '{prompt}' - temperature: 0.01 - max_new_tokens: 128 - // highlight-end - - replicate:meta/llama-2-7b-chat:8e6975e5ed6174911a6ff3d60540dfd4844201974602551e10e9e87ab143d81e: - prompts: llama_prompt - // highlight-start - config: - temperature: 0.01 - max_new_tokens: 128 - // highlight-end + - model: huggingface:text-generation:mistralai/Mistral-7B-Instruct-v0.1 + prompts: + - mistral_prompt + // highlight-start + config: + temperature: 0.01 + max_new_tokens: 128 + // highlight-end + - model: replicate:mistralai/mixtral-8x7b-instruct-v0.1:2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e + prompts: + - mistral_prompt + // highlight-start + config: + prompt_template: '{prompt}' + temperature: 0.01 + max_new_tokens: 128 + // highlight-end + - model: replicate:meta/llama-2-7b-chat:8e6975e5ed6174911a6ff3d60540dfd4844201974602551e10e9e87ab143d81e + prompts: + - llama_prompt + // highlight-start + config: + temperature: 0.01 + max_new_tokens: 128 + // highlight-end ``` Mistral supports [HuggingFace text generation parameters](https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task) whereas Replicate's API has its own set of [supported parameters](https://replicate.com/meta/llama-2-7b-chat/api). diff --git a/site/docs/guides/mixtral-vs-gpt.md b/site/docs/guides/mixtral-vs-gpt.md index e8c0b404f3..7ef6d568ac 100644 --- a/site/docs/guides/mixtral-vs-gpt.md +++ b/site/docs/guides/mixtral-vs-gpt.md @@ -45,8 +45,8 @@ export OPENAI_API_KEY=your_openai_api_key In this example, we're using Replicate, but you can also use providers like [HuggingFace](/docs/providers/huggingface), [TogetherAI](/docs/providers/togetherai), etc: ```yaml -- huggingface:text-generation:mistralai/Mistral-7B-Instruct-v0.1 -- id: openai:chat:mistralai/Mixtral-8x7B-Instruct-v0.1 +- model: huggingface:text-generation:mistralai/Mistral-7B-Instruct-v0.1 +- model: openai:chat:mistralai/Mixtral-8x7B-Instruct-v0.1 config: apiBaseUrl: https://api.together.xyz ``` @@ -60,19 +60,19 @@ Customize the behavior of each model by setting parameters such as `temperature` ```yaml title=promptfooconfig.yaml providers: - - id: openai:gpt-3.5-turbo-0613 + - model: openai:gpt-3.5-turbo-0613 // highlight-start config: temperature: 0 max_tokens: 128 // highlight-end - - id: openai:gpt-4-turbo-0613 + - model: openai:gpt-4-turbo-0613 // highlight-start config: temperature: 0 max_tokens: 128 // highlight-end - - id: replicate:mistralai/mixtral-8x7b-instruct-v0.1:2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e + - model: replicate:mistralai/mixtral-8x7b-instruct-v0.1:2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e // highlight-start config: temperature: 0.01 diff --git a/site/docs/providers/anthropic.md b/site/docs/providers/anthropic.md index 8e8fe2fa8b..de7799e2fa 100644 --- a/site/docs/providers/anthropic.md +++ b/site/docs/providers/anthropic.md @@ -103,7 +103,7 @@ Config parameters may also be passed like so: ```yaml providers: - - id: anthropic:completion:claude-1 + - model: anthropic:completion:claude-1 prompts: chat_prompt config: temperature: 0 @@ -121,7 
+121,7 @@ The easiest way to do this for _all_ your test cases is to add the [`defaultTest defaultTest: options: provider: - id: provider:chat:modelname + model: provider:chat:modelname config: # Provider config options ``` @@ -134,7 +134,7 @@ assert: - type: llm-rubric value: Do not mention that you are an AI or chat assistant provider: - id: provider:chat:modelname + model: provider:chat:modelname config: # Provider config options ``` @@ -148,10 +148,10 @@ tests: # ... options: provider: - id: provider:chat:modelname + model: provider:chat:modelname config: # Provider config options assert: - type: llm-rubric value: Do not mention that you are an AI or chat assistant -``` \ No newline at end of file +``` diff --git a/site/docs/providers/aws-bedrock.md b/site/docs/providers/aws-bedrock.md index 9795a1990f..b14bacc801 100644 --- a/site/docs/providers/aws-bedrock.md +++ b/site/docs/providers/aws-bedrock.md @@ -23,7 +23,7 @@ Additional config parameters are passed like so: ```yaml providers: - - id: bedrock:completion:anthropic.claude-v1 + - model: bedrock:completion:anthropic.claude-v1 // highlight-start config: region: 'us-west-2' @@ -44,7 +44,7 @@ The easiest way to do this for _all_ your test cases is to add the [`defaultTest defaultTest: options: provider: - id: provider:chat:modelname + model: provider:chat:modelname config: # Provider config options ``` @@ -57,7 +57,7 @@ assert: - type: llm-rubric value: Do not mention that you are an AI or chat assistant provider: - id: provider:chat:modelname + model: provider:chat:modelname config: # Provider config options ``` @@ -71,7 +71,7 @@ tests: # ... options: provider: - id: provider:chat:modelname + model: provider:chat:modelname config: # Provider config options assert: diff --git a/site/docs/providers/azure.md b/site/docs/providers/azure.md index 373e86716d..813688fc2d 100644 --- a/site/docs/providers/azure.md +++ b/site/docs/providers/azure.md @@ -19,7 +19,7 @@ Also set the `apiHost` value to point to your endpoint: ```yaml providers: - - id: azureopenai:chat:deploymentNameHere + - model: azureopenai:chat:deploymentNameHere config: apiHost: 'xxxxxxxx.openai.azure.com' ``` @@ -28,7 +28,7 @@ Additional config parameters are passed like so: ```yaml providers: - - id: azureopenai:chat:deploymentNameHere + - model: azureopenai:chat:deploymentNameHere config: apiHost: 'xxxxxxxx.openai.azure.com' // highlight-start @@ -47,7 +47,7 @@ To use client credentials for authentication with Azure, you need to set the fol ```yaml providers: - - id: azureopenai:chat:deploymentNameHere + - model: azureopenai:chat:deploymentNameHere config: apiHost: 'xxxxxxxx.openai.azure.com' azureClientId: 'your-azure-client-id' @@ -77,7 +77,7 @@ The easiest way to do this for _all_ your test cases is to add the [`defaultTest defaultTest: options: provider: - id: azureopenai:chat:gpt-4-deployment-name + model: azureopenai:chat:gpt-4-deployment-name config: apiHost: 'xxxxxxx.openai.azure.com' ``` @@ -90,7 +90,7 @@ assert: - type: llm-rubric value: Do not mention that you are an AI or chat assistant provider: - id: azureopenai:chat:xxxx + model: azureopenai:chat:xxxx config: apiHost: 'xxxxxxx.openai.azure.com' ``` @@ -104,7 +104,7 @@ tests: # ... 
options: provider: - id: azureopenai:chat:xxxx + model: azureopenai:chat:xxxx config: apiHost: 'xxxxxxx.openai.azure.com' assert: @@ -122,7 +122,7 @@ You may also specify `deployment_id` and `dataSources`, used to integrate with t ```yaml providers: - - id: azureopenai:chat:deploymentNameHere + - model: azureopenai:chat:deploymentNameHere config: apiHost: 'xxxxxxxx.openai.azure.com' deployment_id: 'abc123' diff --git a/site/docs/providers/cohere.md b/site/docs/providers/cohere.md index 8ef04f21aa..95570a7b10 100644 --- a/site/docs/providers/cohere.md +++ b/site/docs/providers/cohere.md @@ -14,7 +14,7 @@ Here's an example configuration: ```yaml providers: - - id: cohere:command + - model: cohere:command config: temperature: 0.5 max_tokens: 256 diff --git a/site/docs/providers/custom-api.md b/site/docs/providers/custom-api.md index 37baf346f2..c935f2679b 100644 --- a/site/docs/providers/custom-api.md +++ b/site/docs/providers/custom-api.md @@ -9,7 +9,14 @@ To create a custom API provider, implement the `ApiProvider` interface in a sepa ```javascript class ApiProvider { constructor(options: { id?: string; config: Record}); - id: () => string; + + // Unique identifier + model: string; + + // Displayed in UI + label: string; + + // Fetch response from LLM callApi: (prompt: string, context: { vars: Record }) => Promise; } ``` @@ -27,16 +34,20 @@ import fetch from 'node-fetch'; class CustomApiProvider { constructor(options) { // Provider ID can be overridden by the config file (e.g. when using multiple of the same provider) - this.providerId = options.id || 'custom provider'; + this.providerId = options.model || 'custom provider'; // options.config contains any custom options passed to the provider this.config = options.config; } - id() { + get model() { return this.providerId; } + get label() { + return `Custom provider with temperature ${this.config.temperature}`; + } + async callApi(prompt, context) { // Add your custom API logic here // Use options like: `this.config.temperature`, `this.config.max_tokens`, etc. @@ -84,12 +95,12 @@ You can instantiate multiple providers of the same type with distinct IDs. 
In th ```yaml providers: - - customProvider.js: - id: custom-provider-hightemp - config: - temperature: 1.0 - - customProvider.js: - id: custom-provider-lowtemp - config: - temperature: 0 + - model: customProvider.js + label: custom-provider-hightemp + config: + temperature: 1.0 + - model: customProvider.js + label: custom-provider-lowtemp + config: + temperature: 0 ``` diff --git a/site/docs/providers/huggingface.md b/site/docs/providers/huggingface.md index 0b715a7252..c35b88204d 100644 --- a/site/docs/providers/huggingface.md +++ b/site/docs/providers/huggingface.md @@ -74,7 +74,7 @@ Here's an example of how this provider might appear in your promptfoo config: ```yaml providers: - - id: huggingface:text-generation:mistralai/Mistral-7B-v0.1 + - model: huggingface:text-generation:mistralai/Mistral-7B-v0.1 config: temperature: 0.1 max_length: 1024 @@ -99,7 +99,7 @@ prompts: - "Write a tweet about {{topic}}:" providers: - - id: huggingface:text-generation:gemma-7b-it + - model: huggingface:text-generation:gemma-7b-it config: apiEndpoint: https://v9igsezez4ei3cq4.us-east-1.aws.endpoints.huggingface.cloud # apiKey: abc123 # Or set HF_API_TOKEN environment variable @@ -117,7 +117,7 @@ If you're running the [Huggingface Text Generation Inference](https://github.com ```yaml providers: - - id: huggingface:text-generation:my-local-model + - model: huggingface:text-generation:my-local-model config: apiEndpoint: http://127.0.0.1:8080/generate ``` diff --git a/site/docs/providers/localai.md b/site/docs/providers/localai.md index 8e274cc9a1..21b1668f2f 100644 --- a/site/docs/providers/localai.md +++ b/site/docs/providers/localai.md @@ -27,7 +27,7 @@ You can set parameters like `temperature` and `apiBaseUrl` ([full list here](htt ```yaml title=promptfooconfig.yaml providers: - - id: localai:lunademo + - model: localai:lunademo config: temperature: 0.5 ``` diff --git a/site/docs/providers/mistral.md b/site/docs/providers/mistral.md index adc21a10bc..674e81d1f7 100644 --- a/site/docs/providers/mistral.md +++ b/site/docs/providers/mistral.md @@ -45,7 +45,7 @@ Example configuration with options: ```yaml providers: -- id: mistral:mistral-large-latest +- model: mistral:mistral-large-latest config: temperature: 0.7 max_tokens: 512 diff --git a/site/docs/providers/ollama.md b/site/docs/providers/ollama.md index 91a41894f9..9b271e1eb9 100644 --- a/site/docs/providers/ollama.md +++ b/site/docs/providers/ollama.md @@ -36,7 +36,7 @@ To pass configuration options to Ollama, use the `config` key like so: ```yaml title=promptfooconfig.yaml providers: - - id: ollama:llama2 + - model: ollama:llama2 config: num_predict: 1024 ``` diff --git a/site/docs/providers/openai.md b/site/docs/providers/openai.md index 87c0f993ae..36ed1e5ee8 100644 --- a/site/docs/providers/openai.md +++ b/site/docs/providers/openai.md @@ -32,7 +32,7 @@ The OpenAI provider supports a handful of [configuration options](https://github ```yaml title=promptfooconfig.yaml providers: - - id: openai:gpt-3.5-turbo + - model: openai:gpt-3.5-turbo config: temperature: 0 max_tokens: 1024 @@ -64,7 +64,7 @@ The `providers` list takes a `config` key that allows you to set parameters like ```yaml providers: - - id: openai:gpt-3.5-turbo-0613 + - model: openai:gpt-3.5-turbo-0613 config: temperature: 0 max_tokens: 128 @@ -451,7 +451,7 @@ In addition, you can use `functions` to define custom functions. 
Each function s ```yaml prompts: [prompt.txt] providers: - - id: openai:chat:gpt-3.5-turbo-0613 + - model: openai:chat:gpt-3.5-turbo-0613 // highlight-start config: functions: @@ -526,7 +526,7 @@ providers: Here's an example of how your `provider_with_function.yaml` might look: ```yaml -id: openai:chat:gpt-3.5-turbo-0613 +model: openai:chat:gpt-3.5-turbo-0613 config: functions: - name: get_current_weather @@ -598,7 +598,7 @@ prompts: - 'Write a tweet about {{topic}}' providers: // highlight-start - - id: openai:assistant:asst_fEhNN3MClMamLfKLkIaoIpgZ + - model: openai:assistant:asst_fEhNN3MClMamLfKLkIaoIpgZ config: model: gpt-4-1106-preview instructions: "You always speak like a pirate" @@ -630,7 +630,7 @@ module.exports = /** @type {import('promptfoo').TestSuiteConfig} */ ({ prompts: 'Please add the following numbers together: {{a}} and {{b}}', providers: [ { - id: 'openai:assistant:asst_fEhNN3MClMamLfKLkIaoIpgZ', + model: 'openai:assistant:asst_fEhNN3MClMamLfKLkIaoIpgZ', config: /** @type {InstanceType["config"]} */ ({ model: 'gpt-4-1106-preview', diff --git a/site/docs/providers/openrouter.md b/site/docs/providers/openrouter.md index dd995b65ca..5d74bdf931 100644 --- a/site/docs/providers/openrouter.md +++ b/site/docs/providers/openrouter.md @@ -8,7 +8,7 @@ Here's an example of how to configure the provider to use the `mistralai/mistral ```yaml providers: - - id: openai:chat:mistralai/mistral-medium # or any other chat model + - model: openai:chat:mistralai/mistral-medium # or any other chat model config: apiBaseUrl: https://openrouter.ai/api apiKey: openrouter_api_key diff --git a/site/docs/providers/palm.md b/site/docs/providers/palm.md index 02d956f29e..3bb2230a44 100644 --- a/site/docs/providers/palm.md +++ b/site/docs/providers/palm.md @@ -26,7 +26,7 @@ The PaLM provider supports various [configuration options](https://github.com/pr ```yaml providers: - - id: google:gemini-pro + - model: google:gemini-pro config: temperature: 0 maxOutputTokens: 1024 diff --git a/site/docs/providers/perplexity.md b/site/docs/providers/perplexity.md index 0e4214838d..6a2e5de673 100644 --- a/site/docs/providers/perplexity.md +++ b/site/docs/providers/perplexity.md @@ -8,11 +8,11 @@ Here's an example config that compares Perplexity's 70B model with Llama-2 70B. ```yaml providers: - - id: openai:chat:pplx-70b-chat-alpha + - model: openai:chat:pplx-70b-chat-alpha config: apiHost: api.perplexity.ai apiKeyEnvar: PERPLEXITY_API_KEY - - id: openai:chat:llama-2-70b-chat + - model: openai:chat:llama-2-70b-chat config: apiHost: api.perplexity.ai apiKeyEnvar: PERPLEXITY_API_KEY diff --git a/site/docs/providers/python.md b/site/docs/providers/python.md index 68084ed9d7..b2d8367cbf 100644 --- a/site/docs/providers/python.md +++ b/site/docs/providers/python.md @@ -12,7 +12,8 @@ To configure the Python provider, you need to specify the path to your Python sc ```yaml providers: - - id: 'python:my_script.py' + - model: 'python:my_script.py' + label: 'RAG fetch' config: additionalOption: 123 ``` @@ -64,7 +65,7 @@ The types passed into the Python script function and the `ProviderResponse` retu ```python class ProviderOptions: - id: Optional[str] + model: Optional[str] config: Optional[Dict[str, Any]] class CallApiContextParams: diff --git a/site/docs/providers/replicate.md b/site/docs/providers/replicate.md index 3cbba7ffcc..5aaa95ed2c 100644 --- a/site/docs/providers/replicate.md +++ b/site/docs/providers/replicate.md @@ -18,7 +18,7 @@ Here's an example of using Llama on Replicate. 
In the case of Llama, the version ```yaml providers: - - id: replicate:meta/llama-2-7b-chat + - model: replicate:meta/llama-2-7b-chat config: temperature: 0.01 max_length: 1024 @@ -31,7 +31,7 @@ Here's an example of using Gemma on Replicate. Note that unlike Llama, it does ```yaml providers: - - id: replicate:cjwbw/gemma-7b-it:2790a695e5dcae15506138cc4718d1106d0d475e6dca4b1d43f42414647993d5 + - model: replicate:cjwbw/gemma-7b-it:2790a695e5dcae15506138cc4718d1106d0d475e6dca4b1d43f42414647993d5 config: temperature: 0.01 max_new_tokens: 1024 @@ -79,4 +79,4 @@ Supported environment variables: - `REPLICATE_TOP_K` - Controls the top-k sampling: the number of highest probability vocabulary tokens to keep for top-k-filtering. - `REPLICATE_SEED` - Sets a seed for reproducible results. - `REPLICATE_STOP_SEQUENCES` - Specifies stopping sequences that halt the generation. -- `REPLICATE_SYSTEM_PROMPT` - Sets a system-level prompt for all requests. \ No newline at end of file +- `REPLICATE_SYSTEM_PROMPT` - Sets a system-level prompt for all requests. diff --git a/site/docs/providers/text-generation-webui.md b/site/docs/providers/text-generation-webui.md index 47d4c671bd..3d48d97764 100644 --- a/site/docs/providers/text-generation-webui.md +++ b/site/docs/providers/text-generation-webui.md @@ -10,22 +10,22 @@ python server.py --loader --model --api ``` -Usage is compatible with the [OpenAI API](/docs/providers/openai). +Usage is compatible with the [OpenAI API](/docs/providers/openai). In promptfoo we can address the API as follows. ```yaml providers: - - openai:chat:: - id: - config: - apiKey: placeholder - apiBaseUrl: http://localhost:5000 - temperature: 0.8 - max_tokens: 1024 - passthrough: # These config values are passed directly to the API - mode: instruct - instruction_template: LLama-v2 + - model: openai:chat: + label: + config: + apiKey: placeholder + apiBaseUrl: http://localhost:5000 + temperature: 0.8 + max_tokens: 1024 + passthrough: # These config values are passed directly to the API + mode: instruct + instruction_template: LLama-v2 ``` -If desired, you can instead use the `OPENAI_API_BASE_URL` and `OPENAI_API_KEY` environment variables instead of the `apiBaseUrl` and `apiKey` configs. \ No newline at end of file +If desired, you can instead use the `OPENAI_API_BASE_URL` and `OPENAI_API_KEY` environment variables instead of the `apiBaseUrl` and `apiKey` configs. 
diff --git a/site/docs/providers/togetherai.md b/site/docs/providers/togetherai.md
index 45010163db..30f39e810b 100644
--- a/site/docs/providers/togetherai.md
+++ b/site/docs/providers/togetherai.md
@@ -8,7 +8,7 @@ Here's an example config that uses Mixtral provided by Together AI:
 
 ```yaml
 providers:
-  - id: openai:chat:mistralai/Mixtral-8x7B-Instruct-v0.1
+  - model: openai:chat:mistralai/Mixtral-8x7B-Instruct-v0.1
     config:
       apiBaseUrl: https://api.together.xyz
       apiKeyEnvar: TOGETHER_API_KEY
diff --git a/site/docs/providers/vertex.md b/site/docs/providers/vertex.md
index 302932355a..eaa50751d0 100644
--- a/site/docs/providers/vertex.md
+++ b/site/docs/providers/vertex.md
@@ -35,7 +35,7 @@ The Vertex provider also supports various [configuration options](https://github.
 
 ```yaml
 providers:
-  - id: vertex:chat-bison-32k
+  - model: vertex:chat-bison-32k
     config:
       temperature: 0
       maxOutputTokens: 1024
diff --git a/site/docs/providers/vllm.md b/site/docs/providers/vllm.md
index 8b9ab935ef..2577681042 100644
--- a/site/docs/providers/vllm.md
+++ b/site/docs/providers/vllm.md
@@ -8,7 +8,7 @@ Here's an example config that uses Mixtral-8x7b for text completions:
 
 ```yaml
 providers:
-  - id: openai:completion:mistralai/Mixtral-8x7B-v0.1
+  - model: openai:completion:mistralai/Mixtral-8x7B-v0.1
     config:
       apiBaseUrl: http://localhost:8080
 ```
diff --git a/site/docs/providers/webhook.md b/site/docs/providers/webhook.md
index 86787a6e4a..ca87c316cd 100644
--- a/site/docs/providers/webhook.md
+++ b/site/docs/providers/webhook.md
@@ -34,7 +34,7 @@ It is possible to set webhook provider properties under the `config` key by usin
 
 ```yaml
 providers:
-  - id: webhook:http://example.com/webhook
+  - model: webhook:http://example.com/webhook
     config:
       foo: bar
       test: 123
diff --git a/src/evaluator.ts b/src/evaluator.ts
index c4ff2e25a0..7aaf75c53f 100644
--- a/src/evaluator.ts
+++ b/src/evaluator.ts
@@ -211,7 +211,7 @@ class Evaluator {
 
     // Set up the special _conversation variable
     const vars = test.vars || {};
-    const conversationKey = `${provider.id()}:${prompt.id}`;
+    const conversationKey = `${provider.label}:${prompt.id}`;
     const usesConversation = prompt.raw.includes('_conversation');
     if (
       !process.env.PROMPTFOO_DISABLE_CONVERSATION_VAR &&
@@ -231,7 +231,8 @@
 
     const setup = {
       provider: {
-        id: provider.id(),
+        model: provider.model,
+        label: provider.label,
       },
       prompt: {
         raw: renderedPrompt,
@@ -411,15 +412,15 @@
     for (const provider of testSuite.providers) {
       // Check if providerPromptMap exists and if it contains the current prompt's display
       if (testSuite.providerPromptMap) {
-        const allowedPrompts = testSuite.providerPromptMap[provider.id()];
+        const allowedPrompts = testSuite.providerPromptMap[provider.label];
         if (allowedPrompts && !allowedPrompts.includes(prompt.display)) {
           continue;
         }
       }
-      prompts.push({
+      const completedPrompt = {
        ...prompt,
        id: sha256(typeof prompt.raw === 'object' ? JSON.stringify(prompt.raw) : prompt.raw),
-        provider: provider.id(),
+        provider: provider.label || provider.model,
        display: prompt.display,
        metrics: {
          score: 0,
@@ -437,10 +438,11 @@
          namedScores: {},
          cost: 0,
        },
-      });
+      };
+      prompts.push(completedPrompt);
     }
   }
 
   // Aggregate all vars across test cases
   let tests =
     testSuite.tests && testSuite.tests.length > 0
@@ -539,7 +541,7 @@
     for (const prompt of testSuite.prompts) {
       for (const provider of testSuite.providers) {
         if (testSuite.providerPromptMap) {
-          const allowedPrompts = testSuite.providerPromptMap[provider.id()];
+          const allowedPrompts = testSuite.providerPromptMap[provider.label];
           if (allowedPrompts && !allowedPrompts.includes(prompt.display)) {
             // This prompt should not be used with this provider.
             continue;
@@ -624,7 +626,7 @@
         numComplete++;
         if (progressbar) {
           progressbar.increment({
-            provider: evalStep.provider.id(),
+            provider: evalStep.provider.label,
             prompt: evalStep.prompt.raw.slice(0, 10).replace(/\n/g, ' '),
             vars: Object.entries(evalStep.test.vars || {})
               .map(([k, v]) => `${k}=${v}`)
@@ -686,7 +688,7 @@
         namedScores: row.namedScores,
         text: resultText,
         prompt: row.prompt.raw,
-        provider: row.provider.id,
+        provider: row.provider.label,
         latencyMs: row.latencyMs,
         tokenUsage: row.response?.tokenUsage,
         gradingResult: row.gradingResult,
@@ -792,7 +794,7 @@
       providerPrefixes: Array.from(
         new Set(
           testSuite.providers.map((p) => {
-            const idParts = p.id().split(':');
+            const idParts = p.model.split(':');
             return idParts.length > 1 ? idParts[0] : 'unknown';
           }),
         ),
diff --git a/src/index.ts b/src/index.ts
index 0675d89d3b..53ce4bef3f 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -56,8 +56,8 @@ async function evaluate(testSuite: EvaluateTestSuite, options: EvaluateOptions =
       if (assertion.provider) {
         if (typeof assertion.provider === 'object') {
           const casted = assertion.provider as ProviderOptions;
-          invariant(casted.id, 'Provider object must have an id');
-          assertion.provider = await loadApiProvider(casted.id, { options: casted });
+          invariant(casted.model, 'Provider object must have a `model` property');
+          assertion.provider = await loadApiProvider(casted.model, { options: casted });
         } else if (typeof assertion.provider === 'string') {
           assertion.provider = await loadApiProvider(assertion.provider);
         } else {
diff --git a/src/matchers.ts b/src/matchers.ts
index ec51d3aeba..206a2204be 100644
--- a/src/matchers.ts
+++ b/src/matchers.ts
@@ -71,9 +71,10 @@ async function loadFromProviderOptions(provider: ProviderOptions) {
     !Array.isArray(provider),
     `Provider must be an object, but received an array: ${JSON.stringify(provider)}`,
   );
-  invariant(provider.id, 'Provider supplied to assertion must have an id');
+  const model = provider.model || provider.id;
+  invariant(typeof model === 'string', 'Provider supplied to assertion must have a `model` property');
   // TODO(ian): set basepath if invoked from filesystem config
-  return loadApiProvider(provider.id, { options: provider as ProviderOptions });
+  return loadApiProvider(model, { options: provider as ProviderOptions });
 }
 
 export async function getGradingProvider(
@@ -85,7 +86,7 @@
   if (typeof provider === 'string') {
     // Defined as a string
     finalProvider = await loadApiProvider(provider);
-  } else if (typeof provider === 'object' && typeof (provider as ApiProvider).id === 'function') {
+  } else if (typeof provider === 'object' && typeof (provider as ApiProvider).model === 'string') {
     // Defined as an ApiProvider interface
     finalProvider = provider as ApiProvider;
   } else if (typeof provider === 'object') {
@@ -93,7 +94,7 @@
     if (typeValue) {
       // Defined as embedding, classification, or text record
       finalProvider = await getGradingProvider(type, typeValue, defaultProvider);
-    } else if ((provider as ProviderOptions).id) {
+    } else if ((provider as ProviderOptions).model) {
       // Defined as ProviderOptions
       finalProvider = await loadFromProviderOptions(provider as ProviderOptions);
     } else {
@@ -138,12 +139,12 @@
   if (!isValidProviderType) {
     if (defaultProvider) {
       logger.warn(
-        `Provider ${matchedProvider.id()} is not a valid ${type} provider for '${checkName}', falling back to default`,
+        `Provider ${matchedProvider.label} is not a valid ${type} provider for '${checkName}', falling back to default`,
       );
       return defaultProvider;
     } else {
       throw new Error(
-        `Provider ${matchedProvider.id()} is not a valid ${type} provider for '${checkName}'`,
+        `Provider ${matchedProvider.label} is not a valid ${type} provider for '${checkName}'`,
       );
     }
   }
@@ -532,7 +533,7 @@ export async function matchesAnswerRelevance(
 
   invariant(
     typeof embeddingProvider.callEmbeddingApi === 'function',
-    `Provider ${embeddingProvider.id} must implement callEmbeddingApi for similarity check`,
+    `Provider ${embeddingProvider.label} must implement callEmbeddingApi for similarity check`,
   );
 
   const inputEmbeddingResp = await embeddingProvider.callEmbeddingApi(input);
diff --git a/src/prompts.ts b/src/prompts.ts
index 00caf821aa..2940af988f 100644
--- a/src/prompts.ts
+++ b/src/prompts.ts
@@ -47,22 +47,24 @@
   for (const provider of config.providers) {
     if (typeof provider === 'object') {
       // It's either a ProviderOptionsMap or a ProviderOptions
-      if (provider.id) {
-        const rawProvider = provider as ProviderOptions;
+      const providerOptions = provider as ProviderOptions;
+      const key = providerOptions.label || providerOptions.model;
+      if (key) {
         invariant(
-          rawProvider.id,
-          'You must specify an `id` on the Provider when you override options.prompts',
+          key,
+          'You must specify a `model` or `label` property on the Provider when you override options.prompts',
         );
-        ret[rawProvider.id] = rawProvider.prompts || allPrompts;
+        ret[key] = providerOptions.prompts || allPrompts;
       } else {
+        // Backwards compatibility with deprecated provider options map (2024-03-05)
         const rawProvider = provider as ProviderOptionsMap;
         const originalId = Object.keys(rawProvider)[0];
         const providerObject = rawProvider[originalId];
-        const id = providerObject.id || originalId;
+        const id = providerObject.label || originalId;
         ret[id] = rawProvider[originalId].prompts || allPrompts;
       }
     }
   }
 
   return ret;
 }
diff --git a/src/providers.ts b/src/providers.ts
index 60bec1f06d..30ae125303 100644
--- a/src/providers.ts
+++ b/src/providers.ts
@@ -69,7 +69,8 @@ export async function loadApiProviders(
   } else if (typeof providerPaths === 'function') {
     return [
       {
-        id: () => 'custom-function',
+        model: 'custom-function',
+        label: 'Custom Function',
         callApi: providerPaths,
       },
     ];
@@ -80,22 +81,26 @@
        return loadApiProvider(provider, { basePath, env });
      } else if (typeof provider === 'function') {
        return {
-          id: () => `custom-function-${idx}`,
+          model: `custom-function-${idx}`,
+          label: `Custom Function ${idx}`,
          callApi: provider,
        };
-      } else if (provider.id) {
+      } else if (provider.model || provider.id) {
        // List of ProviderConfig objects
-        return loadApiProvider((provider as ProviderOptions).id!, {
+        const providerOptions = provider as ProviderOptions;
+        const model = providerOptions.model || providerOptions.id;
+        invariant(typeof model === 'string', 'Provider must have a `model` property');
+        return loadApiProvider(model, {
          options: provider,
          basePath,
          env,
        });
      } else {
-        // List of { id: string, config: ProviderConfig } objects
-        const id = Object.keys(provider)[0];
-        const providerObject = (provider as ProviderOptionsMap)[id];
-        const context = { ...providerObject, id: providerObject.id || id };
-        return loadApiProvider(id, { options: context, basePath, env });
+        // List of { model: string, config: ProviderConfig } objects
+        const model = Object.keys(provider)[0];
+        const providerObject = (provider as ProviderOptionsMap)[model];
+        const context = { ...providerObject, model: providerObject.model || providerObject.id || model };
+        return loadApiProvider(model, { options: context, basePath, env });
      }
    }),
  );
@@ -111,9 +116,10 @@ export async function loadApiProvider(
     env?: EnvOverrides;
   } = {},
 ): Promise<ApiProvider> {
   const { options = {}, basePath, env } = context;
-  const providerOptions = {
-    id: options.id,
+  const providerOptions: ProviderOptions = {
+    model: options.model || options.id, // backwards compatibility 2024-03-06
+    label: options.label || options.model || options.id,
     config: {
       ...options.config,
       basePath,
@@ -124,12 +130,13 @@
     const filePath = providerPath.slice('file://'.length);
     const yamlContent = yaml.load(fs.readFileSync(filePath, 'utf8')) as ProviderOptions;
     invariant(yamlContent, `Provider config ${filePath} is undefined`);
-    invariant(yamlContent.id, `Provider config ${filePath} must have an id`);
-    logger.info(`Loaded provider ${yamlContent.id} from ${filePath}`);
-    return loadApiProvider(yamlContent.id, { ...context, options: yamlContent });
+    invariant(yamlContent.model, `Provider config ${filePath} must have a model property`);
+    logger.info(`Loaded provider ${yamlContent.model} from ${filePath}`);
+    return loadApiProvider(yamlContent.model, { ...context, options: yamlContent });
   } else if (providerPath === 'echo') {
     return {
-      id: () => 'echo',
+      model: 'echo',
+      label: 'Echo',
       callApi: async (input) => ({ output: input }),
     };
   } else if (providerPath?.startsWith('exec:')) {
diff --git a/src/providers/anthropic.ts b/src/providers/anthropic.ts
index cd4fa574cb..de2a740fb0 100644
--- a/src/providers/anthropic.ts
+++ b/src/providers/anthropic.ts
@@ -1,11 +1,11 @@
 import Anthropic from '@anthropic-ai/sdk';
 
 import logger from '../logger';
-import type { ApiProvider, EnvOverrides, ProviderResponse, TokenUsage } from '../types.js';
-
 import { getCache, isCacheEnabled } from '../cache';
 import { parseChatPrompt } from './shared';
 
+import type { ApiProvider, EnvOverrides, ProviderOptions, ProviderResponse, TokenUsage } from '../types.js';
+
 interface AnthropicMessageOptions {
   apiKey?: string;
   temperature?: number;
@@ -213,6 +213,8 @@ interface AnthropicCompletionOptions {
   top_k?: number;
 }
 
+type AnthropicGenericOptions = ProviderOptions & { config?: AnthropicCompletionOptions };
+
 export class AnthropicCompletionProvider implements ApiProvider {
   static
ANTHROPIC_COMPLETION_MODELS = [ 'claude-1', @@ -227,24 +229,29 @@ export class AnthropicCompletionProvider implements ApiProvider { modelName: string; apiKey?: string; anthropic: Anthropic; + options: AnthropicGenericOptions; config: AnthropicCompletionOptions; constructor( modelName: string, - options: { config?: AnthropicCompletionOptions; id?: string; env?: EnvOverrides } = {}, + options: AnthropicGenericOptions = {}, ) { - const { config, id, env } = options; + const { config, env } = options; this.modelName = modelName; this.apiKey = config?.apiKey || env?.ANTHROPIC_API_KEY || process.env.ANTHROPIC_API_KEY; this.anthropic = new Anthropic({ apiKey: this.apiKey }); - this.config = config || {}; - this.id = id ? () => id : this.id; + this.options = options; + this.config = options.config; } - id(): string { + get model(): string { return `anthropic:${this.modelName}`; } + get label(): string { + return this.options.label || this.model; + } + toString(): string { return `[Anthropic Provider ${this.modelName}]`; } diff --git a/src/providers/azureopenai.ts b/src/providers/azureopenai.ts index 8bd0392c1b..22aaa1632d 100644 --- a/src/providers/azureopenai.ts +++ b/src/providers/azureopenai.ts @@ -8,6 +8,7 @@ import type { CallApiOptionsParams, EnvOverrides, ProviderEmbeddingResponse, + ProviderOptions, ProviderResponse, } from '../types'; @@ -47,18 +48,21 @@ interface AzureOpenAiCompletionOptions { passthrough?: object; } +type AzureOpenAiGenericOptions = ProviderOptions & { config?: AzureOpenAiCompletionOptions }; + class AzureOpenAiGenericProvider implements ApiProvider { deploymentName: string; apiHost?: string; apiBaseUrl?: string; + options: AzureOpenAiGenericOptions; config: AzureOpenAiCompletionOptions; constructor( deploymentName: string, - options: { config?: AzureOpenAiCompletionOptions; id?: string; env?: EnvOverrides } = {}, + options: AzureOpenAiGenericOptions = {}, ) { - const { config, id, env } = options; + const { config, env } = options; this.deploymentName = deploymentName; @@ -67,8 +71,16 @@ class AzureOpenAiGenericProvider implements ApiProvider { this.apiBaseUrl = config?.apiBaseUrl || env?.AZURE_OPENAI_API_BASE_URL || process.env.AZURE_OPENAI_API_BASE_URL; - this.config = config || {}; - this.id = id ? 
() => id : this.id;
+    this.options = options;
+    this.config = options.config;
+  }
+
+  get model() {
+    return `azureopenai:${this.deploymentName}`;
+  }
+
+  get label() {
+    return this.options.label || this.model;
   }

   _cachedApiKey?: string;
diff --git a/src/providers/bedrock.ts b/src/providers/bedrock.ts
index e7462e9363..17feeff949 100644
--- a/src/providers/bedrock.ts
+++ b/src/providers/bedrock.ts
@@ -5,7 +5,7 @@ import { getCache, isCacheEnabled } from '../cache';

 import type { BedrockRuntime } from '@aws-sdk/client-bedrock-runtime';

-import type { ApiProvider, EnvOverrides, ProviderResponse } from '../types.js';
+import type { ApiProvider, EnvOverrides, ProviderOptions, ProviderResponse } from '../types.js';

 interface BedrockOptions {
   region?: string;
@@ -73,29 +73,36 @@ const AWS_BEDROCK_MODELS: Record = {
   'amazon.titan-text-express-v1': BEDROCK_MODEL.TITAN_TEXT,
 };

+type AwsBedrockGenericOptions = ProviderOptions & { config?: BedrockOptions };
+
 export class AwsBedrockCompletionProvider implements ApiProvider {
   static AWS_BEDROCK_COMPLETION_MODELS = Object.keys(AWS_BEDROCK_MODELS);

   modelName: string;
+  options: AwsBedrockGenericOptions;
   config: BedrockOptions;
   env?: EnvOverrides;
   bedrock?: BedrockRuntime;

   constructor(
     modelName: string,
-    options: { config?: BedrockOptions; id?: string; env?: EnvOverrides } = {},
+    options: ProviderOptions = {},
   ) {
-    const { config, id, env } = options;
+    const { config, env } = options;
     this.env = env;
     this.modelName = modelName;
-    this.config = config || {};
-    this.id = id ? () => id : this.id;
+    this.options = options;
+    this.config = config;
   }

-  id(): string {
+  get model(): string {
     return `bedrock:${this.modelName}`;
   }

+  get label(): string {
+    return this.options.label || this.model;
+  }
+
   async getBedrockInstance() {
     if (!this.bedrock) {
       try {
diff --git a/src/providers/cohere.ts b/src/providers/cohere.ts
index 62c75aab62..c71c93ec53 100644
--- a/src/providers/cohere.ts
+++ b/src/providers/cohere.ts
@@ -1,7 +1,7 @@
 import { fetchWithCache } from '../cache';
 import logger from '../logger';

-import type { ApiProvider, EnvOverrides, ProviderResponse, TokenUsage } from '../types';
+import type { ApiProvider, EnvOverrides, ProviderOptions, ProviderResponse, TokenUsage } from '../types';
 import { REQUEST_TIMEOUT_MS } from './shared';

 interface CohereChatOptions {
@@ -46,28 +46,33 @@ export class CohereChatCompletionProvider implements ApiProvider {
     'command-light-nightly',
   ];

-  private apiKey: string;
-  private modelName: string;
-  private config: CohereChatOptions;
+  apiKey: string;
+  modelName: string;
+  options: ProviderOptions & { config?: CohereChatOptions };
+  config: CohereChatOptions;

   constructor(
     modelName: string,
-    options: { config?: CohereChatOptions; id?: string; env?: EnvOverrides } = {},
+    options: ProviderOptions & { config?: CohereChatOptions } = {},
   ) {
-    const { config, id, env } = options;
-    this.apiKey = config?.apiKey || env?.COHERE_API_KEY || process.env.COHERE_API_KEY || '';
     this.modelName = modelName;
+    this.options = options;
+    this.config = options.config || {} as CohereChatOptions;
+    this.apiKey = this.config.apiKey || this.options.env?.COHERE_API_KEY || process.env.COHERE_API_KEY || '';
+
     if (!CohereChatCompletionProvider.COHERE_CHAT_MODELS.includes(this.modelName)) {
       logger.warn(`Using unknown Cohere chat model: ${this.modelName}`);
     }
-    this.id = id ?
() => id : this.id; - this.config = config || {}; } - id() { + get model() { return `cohere:${this.modelName}`; } + get label() { + return this.options.label || this.model; + } + async callApi(prompt: string): Promise { if (!this.apiKey) { return { error: 'Cohere API key is not set. Please provide a valid apiKey.' }; diff --git a/src/providers/huggingface.ts b/src/providers/huggingface.ts index 73550a5e6f..664c57503d 100644 --- a/src/providers/huggingface.ts +++ b/src/providers/huggingface.ts @@ -6,11 +6,38 @@ import { ApiSimilarityProvider, ProviderClassificationResponse, ProviderEmbeddingResponse, + ProviderOptions, ProviderResponse, ProviderSimilarityResponse, } from '../types'; import { REQUEST_TIMEOUT_MS } from './shared'; +abstract class HuggingfaceGenericProvider implements ApiProvider { + modelName: string; + options: ProviderOptions & { config?: TOptions }; + config: TOptions; + + constructor( + modelName: string, + options: ProviderOptions & { config?: TOptions } = {}, + ) { + this.modelName = modelName; + this.options = options; + this.config = options.config || {} as TOptions; + } + + get model() { + return `huggingface:text-generation:${this.modelName}`; + } + + get label() { + return this.options.label || this.model; + } + + abstract callApi(prompt: string): Promise; +} + + interface HuggingfaceTextGenerationOptions { apiKey?: string; apiEndpoint?: string; @@ -41,24 +68,7 @@ const HuggingFaceTextGenerationKeys = new Set id : this.id; - this.config = config || {}; - } - - id(): string { - return `huggingface:text-generation:${this.modelName}`; - } - +export class HuggingfaceTextGenerationProvider extends HuggingfaceGenericProvider { toString(): string { return `[Huggingface Text Generation Provider ${this.modelName}]`; } @@ -135,22 +145,12 @@ interface HuggingfaceTextClassificationOptions { apiEndpoint?: string; } -export class HuggingfaceTextClassificationProvider implements ApiProvider { - modelName: string; - config: HuggingfaceTextClassificationOptions; - +export class HuggingfaceTextClassificationProvider extends HuggingfaceGenericProvider { constructor( modelName: string, - options: { id?: string; config?: HuggingfaceTextClassificationOptions } = {}, + options: ProviderOptions & { config?: HuggingfaceTextClassificationOptions } = {}, ) { - const { id, config } = options; - this.modelName = modelName; - this.id = id ? () => id : this.id; - this.config = config || {}; - } - - id(): string { - return `huggingface:text-classification:${this.modelName}`; + super(modelName, options); } toString(): string { @@ -225,24 +225,7 @@ interface HuggingfaceFeatureExtractionOptions { wait_for_model?: boolean; } -export class HuggingfaceFeatureExtractionProvider implements ApiProvider { - modelName: string; - config: HuggingfaceFeatureExtractionOptions; - - constructor( - modelName: string, - options: { id?: string; config?: HuggingfaceFeatureExtractionOptions } = {}, - ) { - const { id, config } = options; - this.modelName = modelName; - this.id = id ? 
() => id : this.id; - this.config = config || {}; - } - - id(): string { - return `huggingface:feature-extraction:${this.modelName}`; - } - +export class HuggingfaceFeatureExtractionProvider extends HuggingfaceGenericProvider { toString(): string { return `[Huggingface Feature Extraction Provider ${this.modelName}]`; } @@ -309,24 +292,7 @@ interface HuggingfaceSentenceSimilarityOptions { wait_for_model?: boolean; } -export class HuggingfaceSentenceSimilarityProvider implements ApiSimilarityProvider { - modelName: string; - config: HuggingfaceSentenceSimilarityOptions; - - constructor( - modelName: string, - options: { id?: string; config?: HuggingfaceSentenceSimilarityOptions } = {}, - ) { - const { id, config } = options; - this.modelName = modelName; - this.id = id ? () => id : this.id; - this.config = config || {}; - } - - id(): string { - return `huggingface:sentence-similarity:${this.modelName}`; - } - +export class HuggingfaceSentenceSimilarityProvider extends HuggingfaceGenericProvider { toString(): string { return `[Huggingface Sentence Similarity Provider ${this.modelName}]`; } diff --git a/src/providers/llama.ts b/src/providers/llama.ts index d0bf7fa938..5cff5f3f48 100644 --- a/src/providers/llama.ts +++ b/src/providers/llama.ts @@ -1,7 +1,7 @@ import { fetchWithCache } from '../cache'; import { REQUEST_TIMEOUT_MS } from './shared'; -import type { ApiProvider, ProviderResponse } from '../types.js'; +import type { ApiProvider, ProviderOptions, ProviderResponse } from '../types.js'; interface LlamaCompletionOptions { n_predict?: number; @@ -25,19 +25,23 @@ interface LlamaCompletionOptions { export class LlamaProvider implements ApiProvider { modelName: string; + options?: ProviderOptions; config?: LlamaCompletionOptions; - constructor(modelName: string, options: { config?: LlamaCompletionOptions; id?: string } = {}) { - const { config, id } = options; + constructor(modelName: string, options: ProviderOptions & { config?: LlamaCompletionOptions } = {}) { this.modelName = modelName; - this.config = config; - this.id = id ? () => id : this.id; + this.options = options; + this.config = options.config; } - id(): string { + get model(): string { return `llama:${this.modelName}`; } + get label(): string { + return this.options?.label || this.model; + } + toString(): string { return `[Llama Provider ${this.modelName}]`; } diff --git a/src/providers/localai.ts b/src/providers/localai.ts index 93cb094740..7070b97d81 100644 --- a/src/providers/localai.ts +++ b/src/providers/localai.ts @@ -6,6 +6,7 @@ import type { ApiProvider, EnvOverrides, ProviderEmbeddingResponse, + ProviderOptions, ProviderResponse, } from '../types.js'; @@ -16,28 +17,31 @@ interface LocalAiCompletionOptions { class LocalAiGenericProvider implements ApiProvider { modelName: string; - apiBaseUrl: string; + options: ProviderOptions & { config?: LocalAiCompletionOptions }; config: LocalAiCompletionOptions; + apiBaseUrl: string; constructor( modelName: string, - options: { config?: LocalAiCompletionOptions; id?: string; env?: EnvOverrides } = {}, + options: ProviderOptions & { config?: LocalAiCompletionOptions } = {}, ) { - const { id, config, env } = options; this.modelName = modelName; + this.options = options; + this.config = options.config || {}; this.apiBaseUrl = - config?.apiBaseUrl || - env?.LOCALAI_BASE_URL || + this.options.env?.LOCALAI_BASE_URL || + this.config.apiBaseUrl || process.env.LOCALAI_BASE_URL || 'http://localhost:8080/v1'; - this.config = config || {}; - this.id = id ? 
() => id : this.id; } - id(): string { + get model(): string { return `localai:${this.modelName}`; } + get label(): string { + return this.options.label || this.model; + } toString(): string { return `[LocalAI Provider ${this.modelName}]`; } diff --git a/src/providers/mistral.ts b/src/providers/mistral.ts index a2c02867de..903a3796d7 100644 --- a/src/providers/mistral.ts +++ b/src/providers/mistral.ts @@ -4,6 +4,7 @@ import { fetchWithCache } from '../cache'; import { ApiProvider, EnvOverrides, + ProviderOptions, ProviderResponse, TokenUsage, } from '../types'; @@ -59,6 +60,7 @@ function calculateCost( export class MistralChatCompletionProvider implements ApiProvider { modelName: string; + options: ProviderOptions; config: MistralChatCompletionOptions; env?: EnvOverrides; @@ -106,22 +108,26 @@ export class MistralChatCompletionProvider implements ApiProvider { constructor( modelName: string, - options: { id?: string; config?: MistralChatCompletionOptions; env?: EnvOverrides } = {}, + options: ProviderOptions & { config?: MistralChatCompletionOptions } = {}, ) { if (!MistralChatCompletionProvider.MISTRAL_CHAT_MODELS_NAMES.includes(modelName)) { logger.warn(`Using unknown Mistral chat model: ${modelName}`); } - const { id, config, env } = options; + const { config, env } = options; this.env = env; this.modelName = modelName; - this.id = id ? () => id : this.id; - this.config = config || {}; + this.options = options; + this.config = config; } - id(): string { + get model(): string { return `mistral:${this.modelName}`; } + get label(): string { + return this.options.label || this.model; + } + toString(): string { return `[Mistral Provider ${this.modelName}]`; } diff --git a/src/providers/ollama.ts b/src/providers/ollama.ts index dbd1449153..2abd91b310 100644 --- a/src/providers/ollama.ts +++ b/src/providers/ollama.ts @@ -2,7 +2,7 @@ import logger from '../logger'; import { fetchWithCache } from '../cache'; import { REQUEST_TIMEOUT_MS, parseChatPrompt } from './shared'; -import type { ApiProvider, ProviderEmbeddingResponse, ProviderResponse } from '../types.js'; +import type { ApiProvider, ProviderEmbeddingResponse, ProviderOptions, ProviderResponse } from '../types.js'; interface OllamaCompletionOptions { // From https://github.com/jmorganca/ollama/blob/v0.1.0/api/types.go#L161 @@ -116,17 +116,21 @@ interface OllamaChatJsonL { export class OllamaCompletionProvider implements ApiProvider { modelName: string; + options: ProviderOptions; config: OllamaCompletionOptions; - constructor(modelName: string, options: { id?: string; config?: OllamaCompletionOptions } = {}) { - const { id, config } = options; + constructor(modelName: string, options: ProviderOptions & { config?: OllamaCompletionOptions } = {}) { this.modelName = modelName; - this.id = id ? 
() => id : this.id; - this.config = config || {}; + this.options = options; + this.config = options.config || {}; } - id(): string { - return `ollama:completion:${this.modelName}`; + get model(): string { + return `ollama:${this.modelName}`; + } + + get label(): string { + return this.options.label || this.model; } toString(): string { @@ -200,19 +204,22 @@ export class OllamaCompletionProvider implements ApiProvider { export class OllamaChatProvider implements ApiProvider { modelName: string; + options: ProviderOptions; config: OllamaCompletionOptions; - constructor(modelName: string, options: { id?: string; config?: OllamaCompletionOptions } = {}) { - const { id, config } = options; + constructor(modelName: string, options: ProviderOptions & { config?: OllamaCompletionOptions } = {}) { this.modelName = modelName; - this.id = id ? () => id : this.id; - this.config = config || {}; + this.options = options; + this.config = options.config || {}; } - id(): string { + get model(): string { return `ollama:chat:${this.modelName}`; } + get label(): string { + return this.options.label || this.model; + } toString(): string { return `[Ollama Chat Provider ${this.modelName}]`; } diff --git a/src/providers/openai.ts b/src/providers/openai.ts index 934997c1f3..3541a08bf7 100644 --- a/src/providers/openai.ts +++ b/src/providers/openai.ts @@ -11,6 +11,7 @@ import type { CallApiOptionsParams, EnvOverrides, ProviderEmbeddingResponse, + ProviderOptions, ProviderResponse, TokenUsage, } from '../types.js'; @@ -67,27 +68,34 @@ function getTokenUsage(data: any, cached: boolean): Partial { return {}; } +type OpenAiGenericOptions = ProviderOptions & { config?: OpenAiSharedOptions }; + export class OpenAiGenericProvider implements ApiProvider { modelName: string; - config: OpenAiSharedOptions; + options: OpenAiGenericOptions; + config: OpenAiGenericOptions['config']; env?: EnvOverrides; constructor( modelName: string, - options: { config?: OpenAiSharedOptions; id?: string; env?: EnvOverrides } = {}, + options: ProviderOptions & { config?: OpenAiSharedOptions } = {}, ) { - const { config, id, env } = options; + const { config, env } = options; + this.options = options; this.env = env; - this.modelName = modelName; this.config = config || {}; - this.id = id ? 
() => id : this.id; + this.modelName = modelName; } - id(): string { + get model() { return `openai:${this.modelName}`; } + get label() { + return this.options?.label || this.model; + } + toString(): string { return `[OpenAI Provider ${this.modelName}]`; } @@ -223,7 +231,7 @@ export class OpenAiCompletionProvider extends OpenAiGenericProvider { constructor( modelName: string, - options: { config?: OpenAiCompletionOptions; id?: string; env?: EnvOverrides } = {}, + options: OpenAiGenericOptions = {}, ) { super(modelName, options); this.config = options.config || {}; @@ -371,7 +379,7 @@ export class OpenAiChatCompletionProvider extends OpenAiGenericProvider { constructor( modelName: string, - options: { config?: OpenAiCompletionOptions; id?: string; env?: EnvOverrides } = {}, + options: OpenAiGenericOptions = {}, ) { if (!OpenAiChatCompletionProvider.OPENAI_CHAT_MODEL_NAMES.includes(modelName)) { logger.warn(`Using unknown OpenAI chat model: ${modelName}`); diff --git a/src/providers/palm.ts b/src/providers/palm.ts index 3e6490c0e2..656ee967b0 100644 --- a/src/providers/palm.ts +++ b/src/providers/palm.ts @@ -2,7 +2,7 @@ import logger from '../logger'; import { fetchWithCache } from '../cache'; import { parseChatPrompt, REQUEST_TIMEOUT_MS } from './shared'; -import type { ApiProvider, EnvOverrides, ProviderResponse } from '../types.js'; +import type { ApiProvider, EnvOverrides, ProviderOptions, ProviderResponse } from '../types.js'; const DEFAULT_API_HOST = 'generativelanguage.googleapis.com'; @@ -20,25 +20,25 @@ interface PalmCompletionOptions { class PalmGenericProvider implements ApiProvider { modelName: string; - + options: ProviderOptions & { config?: PalmCompletionOptions }; config: PalmCompletionOptions; - env?: EnvOverrides; constructor( modelName: string, - options: { config?: PalmCompletionOptions; id?: string; env?: EnvOverrides } = {}, + options: ProviderOptions & { config?: PalmCompletionOptions } = {}, ) { - const { config, id, env } = options; - this.env = env; this.modelName = modelName; - this.config = config || {}; - this.id = id ? 
() => id : this.id; + this.options = options; + this.config = options.config || {}; } - id(): string { + get model() { return `palm:${this.modelName}`; } + get label() { + return this.options.label || this.model; + } toString(): string { return `[Google AI Studio Provider ${this.modelName}]`; } @@ -46,8 +46,8 @@ class PalmGenericProvider implements ApiProvider { getApiHost(): string | undefined { return ( this.config.apiHost || - this.env?.GOOGLE_API_HOST || - this.env?.PALM_API_HOST || + this.options.env?.GOOGLE_API_HOST || + this.options.env?.PALM_API_HOST || process.env.GOOGLE_API_HOST || process.env.PALM_API_HOST || DEFAULT_API_HOST @@ -57,8 +57,8 @@ class PalmGenericProvider implements ApiProvider { getApiKey(): string | undefined { return ( this.config.apiKey || - this.env?.GOOGLE_API_KEY || - this.env?.PALM_API_KEY || + this.options.env?.GOOGLE_API_KEY || + this.options.env?.PALM_API_KEY || process.env.GOOGLE_API_KEY || process.env.PALM_API_KEY ); @@ -75,7 +75,7 @@ export class PalmChatProvider extends PalmGenericProvider { constructor( modelName: string, - options: { config?: PalmCompletionOptions; id?: string; env?: EnvOverrides } = {}, + options: ProviderOptions & { config?: PalmCompletionOptions } = {}, ) { if (!PalmChatProvider.CHAT_MODELS.includes(modelName)) { logger.warn(`Using unknown Google chat model: ${modelName}`); diff --git a/src/providers/pythonCompletion.ts b/src/providers/pythonCompletion.ts index b9a124e84c..099aadaf52 100644 --- a/src/providers/pythonCompletion.ts +++ b/src/providers/pythonCompletion.ts @@ -16,10 +16,14 @@ import type { export class PythonProvider implements ApiProvider { constructor(private scriptPath: string, private options?: ProviderOptions) {} - id() { + get model() { return `python:${this.scriptPath}`; } + get label() { + return this.options?.label || this.model; + } + async callApi(prompt: string, context?: CallApiContextParams): Promise { const absPath = path.resolve( path.join(this.options?.config.basePath, this.scriptPath), diff --git a/src/providers/replicate.ts b/src/providers/replicate.ts index cbb7e49b68..f4ede9685d 100644 --- a/src/providers/replicate.ts +++ b/src/providers/replicate.ts @@ -4,7 +4,7 @@ import fetch from 'node-fetch'; import logger from '../logger'; import { getCache, isCacheEnabled } from '../cache'; -import type { ApiProvider, EnvOverrides, ProviderResponse } from '../types.js'; +import type { ApiProvider, EnvOverrides, ProviderOptions, ProviderResponse } from '../types.js'; interface ReplicateCompletionOptions { apiKey?: string; @@ -27,17 +27,20 @@ interface ReplicateCompletionOptions { [key: string]: any; } +type ReplicateGenericOptions = ProviderOptions & { config?: ReplicateCompletionOptions }; + export class ReplicateProvider implements ApiProvider { modelName: string; apiKey?: string; replicate: any; + options: ReplicateGenericOptions; config: ReplicateCompletionOptions; constructor( modelName: string, - options: { config?: ReplicateCompletionOptions; id?: string; env?: EnvOverrides } = {}, + options: ReplicateGenericOptions = {}, ) { - const { config, id, env } = options; + const { config, env } = options; this.modelName = modelName; this.apiKey = config?.apiKey || @@ -45,14 +48,18 @@ export class ReplicateProvider implements ApiProvider { env?.REPLICATE_API_TOKEN || process.env.REPLICATE_API_TOKEN || process.env.REPLICATE_API_KEY; - this.config = config || {}; - this.id = id ? 
() => id : this.id; + this.options = options; + this.config = config; } - id(): string { + get model(): string { return `replicate:${this.modelName}`; } + get label(): string { + return this.options.label || this.model; + } + toString(): string { return `[Replicate Provider ${this.modelName}]`; } diff --git a/src/providers/scriptCompletion.ts b/src/providers/scriptCompletion.ts index 2f8374655b..4ddc2039ae 100644 --- a/src/providers/scriptCompletion.ts +++ b/src/providers/scriptCompletion.ts @@ -21,10 +21,14 @@ function stripText(text: string) { export class ScriptCompletionProvider implements ApiProvider { constructor(private scriptPath: string, private options?: ProviderOptions) {} - id() { + get model() { return `exec:${this.scriptPath}`; } + get label() { + return this.options?.label || this.model; + } + async callApi(prompt: string, context?: CallApiContextParams): Promise { const cacheKey = `exec:${this.scriptPath}:${prompt}:${JSON.stringify(this.options)}`; const cache = await getCache(); diff --git a/src/providers/vertex.ts b/src/providers/vertex.ts index b1993b93bc..a2d4b16987 100644 --- a/src/providers/vertex.ts +++ b/src/providers/vertex.ts @@ -2,7 +2,7 @@ import logger from '../logger'; import { fetchWithCache } from '../cache'; import { parseChatPrompt, REQUEST_TIMEOUT_MS } from './shared'; -import type { ApiProvider, EnvOverrides, ProviderResponse } from '../types.js'; +import type { ApiProvider, EnvOverrides, ProviderOptions, ProviderResponse } from '../types.js'; import { maybeCoerceToGeminiFormat, type GeminiApiResponse, type ResponseData } from './vertexUtil'; interface VertexCompletionOptions { @@ -24,25 +24,26 @@ interface VertexCompletionOptions { class VertexGenericProvider implements ApiProvider { modelName: string; - + options: ProviderOptions & { config?: VertexCompletionOptions }; config: VertexCompletionOptions; - env?: EnvOverrides; constructor( modelName: string, - options: { config?: VertexCompletionOptions; id?: string; env?: EnvOverrides } = {}, + options: ProviderOptions & { config?: VertexCompletionOptions } = {}, ) { - const { config, id, env } = options; - this.env = env; this.modelName = modelName; - this.config = config || {}; - this.id = id ? 
() => id : this.id; + this.options = options; + this.config = options.config || {} as VertexCompletionOptions; } - id(): string { + get model() { return `vertex:${this.modelName}`; } + get label() { + return this.options.label || this.model; + } + toString(): string { return `[Google Vertex Provider ${this.modelName}]`; } @@ -50,30 +51,30 @@ class VertexGenericProvider implements ApiProvider { getApiHost(): string | undefined { return ( this.config.apiHost || - this.env?.VERTEX_API_HOST || + this.options.env?.VERTEX_API_HOST || process.env.VERTEX_API_HOST || `${this.getRegion()}-aiplatform.googleapis.com` ); } getProjectId(): string | undefined { - return this.config.projectId || this.env?.VERTEX_PROJECT_ID || process.env.VERTEX_PROJECT_ID; + return this.config.projectId || this.options.env?.VERTEX_PROJECT_ID || process.env.VERTEX_PROJECT_ID; } getApiKey(): string | undefined { - return this.config.apiKey || this.env?.VERTEX_API_KEY || process.env.VERTEX_API_KEY; + return this.config.apiKey || this.options.env?.VERTEX_API_KEY || process.env.VERTEX_API_KEY; } getRegion(): string { return ( - this.config.region || this.env?.VERTEX_REGION || process.env.VERTEX_REGION || 'us-central1' + this.config.region || this.options.env?.VERTEX_REGION || process.env.VERTEX_REGION || 'us-central1' ); } getPublisher(): string | undefined { return ( this.config.publisher || - this.env?.VERTEX_PUBLISHER || + this.options.env?.VERTEX_PUBLISHER || process.env.VERTEX_PUBLISHER || 'google' ); @@ -101,7 +102,7 @@ export class VertexChatProvider extends VertexGenericProvider { constructor( modelName: string, - options: { config?: VertexCompletionOptions; id?: string; env?: EnvOverrides } = {}, + options: ProviderOptions & { config?: VertexCompletionOptions } = {}, ) { if (!VertexChatProvider.CHAT_MODELS.includes(modelName)) { logger.warn(`Using unknown Google Vertex chat model: ${modelName}`); diff --git a/src/providers/webhook.ts b/src/providers/webhook.ts index e464ac5b17..e2ea6ee14c 100644 --- a/src/providers/webhook.ts +++ b/src/providers/webhook.ts @@ -3,23 +3,27 @@ import { fetchWithCache } from '../cache'; import { REQUEST_TIMEOUT_MS } from './shared'; -import type { ApiProvider, ProviderResponse } from '../types.js'; +import type { ApiProvider, ProviderOptions, ProviderResponse } from '../types.js'; export class WebhookProvider implements ApiProvider { webhookUrl: string; + options: ProviderOptions; config?: object; - constructor(webhookUrl: string, options: { id?: string; config?: object } = {}) { - const { id, config } = options; + constructor(webhookUrl: string, options: ProviderOptions = {}) { this.webhookUrl = webhookUrl; - this.id = id ? () => id : this.id; - this.config = config; + this.options = options; + this.config = options.config; } - id(): string { + get model(): string { return `webhook:${this.webhookUrl}`; } + get label(): string { + return this.options.label || this.model; + } + toString(): string { return `[Webhook Provider ${this.webhookUrl}]`; } diff --git a/src/types.ts b/src/types.ts index 146b5f0b6a..20fe90ea42 100644 --- a/src/types.ts +++ b/src/types.ts @@ -59,9 +59,15 @@ export interface EnvOverrides { } export interface ProviderOptions { - id?: ProviderId; + /** + * @deprecated Use `model` instead. 
+ */ + id?: ModelId; + model?: ModelId; + label?: ProviderLabel; config?: any; prompts?: string[]; // List of prompt display strings + env?: EnvOverrides; } export interface CallApiContextParams { @@ -73,13 +79,23 @@ export interface CallApiOptionsParams { } export interface ApiProvider { - id: () => string; + // Unique identifier for the provider + model: string; + + // Human-readable label for the provider, shown on output + label: ProviderLabel; + + // Text generation function callApi: ( prompt: string, context?: CallApiContextParams, options?: CallApiOptionsParams, ) => Promise; + + // Embedding function callEmbeddingApi?: (prompt: string) => Promise; + + // Classification function callClassificationApi?: (prompt: string) => Promise; } @@ -216,7 +232,7 @@ export interface PromptWithMetadata { } export interface EvaluateResult { - provider: Pick; + provider: Pick; prompt: Prompt; vars: Record; response?: ProviderResponse; @@ -461,11 +477,13 @@ export interface TestSuite { env?: EnvOverrides; } -export type ProviderId = string; +export type ModelId = string; + +export type ProviderLabel = string; export type ProviderFunction = ApiProvider['callApi']; -export type ProviderOptionsMap = Record; +export type ProviderOptionsMap = Record; // TestSuiteConfig = Test Suite, but before everything is parsed and resolved. Providers are just strings, prompts are filepaths, tests can be filepath or inline. export interface TestSuiteConfig { @@ -473,7 +491,7 @@ export interface TestSuiteConfig { description?: string; // One or more LLM APIs to use, for example: openai:gpt-3.5-turbo, openai:gpt-4, localai:chat:vicuna - providers: ProviderId | ProviderFunction | (ProviderId | ProviderOptionsMap | ProviderOptions)[]; + providers: ModelId | ProviderFunction | (ModelId | ProviderOptionsMap | ProviderOptions)[]; // One or more prompt files to load prompts: FilePath | FilePath[] | Record; diff --git a/src/web/nextui/src/app/setup/ProviderSelector.tsx b/src/web/nextui/src/app/setup/ProviderSelector.tsx index 7d34e1773c..4421640c2c 100644 --- a/src/web/nextui/src/app/setup/ProviderSelector.tsx +++ b/src/web/nextui/src/app/setup/ProviderSelector.tsx @@ -4,12 +4,12 @@ import ProviderConfigDialog from './ProviderConfigDialog'; import type { ProviderOptions } from '@/../../../types'; -const defaultProviders: ProviderOptions[] = ([] as (ProviderOptions & { id: string })[]) +const defaultProviders: ProviderOptions[] = ([] as (ProviderOptions & { model: string })[]) .concat( [ 'replicate:replicate/flan-t5-small:69716ad8c34274043bf4a135b7315c7c569ec931d8f23d6826e249e1c142a264', - ].map((id) => ({ - id, + ].map((model) => ({ + model, config: { temperature: 0.5, max_length: 1024, repetition_penality: 1.0 }, })), ) @@ -18,8 +18,8 @@ const defaultProviders: ProviderOptions[] = ([] as (ProviderOptions & { id: stri 'replicate:replicate/codellama-7b-instruct:0103579e86fc75ba0d65912890fa19ef03c84a68554635319accf2e0ba93d3ae', 'replicate:replicate/codellama-13b-instruct:da5676342de1a5a335b848383af297f592b816b950a43d251a0a9edd0113604b', 'replicate:replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', - ].map((id) => ({ - id, + ].map((model) => ({ + model, config: { system_prompt: '', temperature: 0.75, @@ -37,8 +37,8 @@ const defaultProviders: ProviderOptions[] = ([] as (ProviderOptions & { id: stri 'replicate:replicate/codellama-13b:1c914d844307b0588599b8393480a3ba917b660c7e9dfae681542b5325f228db', 
'replicate:replicate/codellama-34b-python:9048743d22a7b19cd0abb018066809ea6af4f2b4717bef9aad3c5ae21ceac00d', 'replicate:replicate/codellama-34b:0666717e5ead8557dff55ee8f11924b5c0309f5f1ca52f64bb8eec405fdb38a7', - ].map((id) => ({ - id, + ].map((model) => ({ + model, config: { temperature: 0.75, top_p: 0.9, @@ -52,8 +52,8 @@ const defaultProviders: ProviderOptions[] = ([] as (ProviderOptions & { id: stri [ 'replicate:a16z-infra/llama-2-7b-chat:7b0bfc9aff140d5b75bacbed23e91fd3c34b01a1e958d32132de6e0a19796e2c', 'replicate:a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', - ].map((id) => ({ - id, + ].map((model) => ({ + model, config: { temperature: 0.95, top_p: 0.95, @@ -71,14 +71,14 @@ const defaultProviders: ProviderOptions[] = ([] as (ProviderOptions & { id: stri 'anthropic:claude-1-100k', 'anthropic:claude-instant-1', 'anthropic:claude-instant-1-100k', - ].map((id) => ({ id, config: { max_tokens_to_sample: 256, temperature: 0.5 } })), + ].map((model) => ({ model, config: { max_tokens_to_sample: 256, temperature: 0.5 } })), ) .concat( [ 'bedrock:anthropic.claude-instant-v1', 'bedrock:anthropic.claude-v1', 'bedrock:anthropic.claude-v2', - ].map((id) => ({ id, config: { max_tokens_to_sample: 256, temperature: 0.5 } })), + ].map((model) => ({ model, config: { max_tokens_to_sample: 256, temperature: 0.5 } })), ) .concat( [ @@ -92,8 +92,8 @@ const defaultProviders: ProviderOptions[] = ([] as (ProviderOptions & { id: stri 'openai:gpt-4-0613', 'openai:gpt-4-32k', 'openai:gpt-4-32k-0314', - ].map((id) => ({ - id, + ].map((model) => ({ + model, config: { organization: '', temperature: 0.5, @@ -119,8 +119,8 @@ const defaultProviders: ProviderOptions[] = ([] as (ProviderOptions & { id: stri 'azureopenai:gpt-4-0613', 'azureopenai:gpt-4-32k', 'azureopenai:gpt-4-32k-0314', - ].map((id) => ({ - id, + ].map((model) => ({ + model, config: { temperature: 0.5, max_tokens: 1024, @@ -139,8 +139,8 @@ const defaultProviders: ProviderOptions[] = ([] as (ProviderOptions & { id: stri 'vertex:chat-bison', 'vertex:chat-bison-32k', 'vertex:chat-bison-32k@001', - ].map((id) => ({ - id, + ].map((model) => ({ + model, config: { context: undefined, examples: undefined, @@ -153,7 +153,7 @@ const defaultProviders: ProviderOptions[] = ([] as (ProviderOptions & { id: stri }, })), ) - .sort((a, b) => a.id.localeCompare(b.id)); + .sort((a, b) => a.model.localeCompare(b.model)); const PREFIX_TO_PROVIDER: Record = { anthropic: 'Anthropic', @@ -183,14 +183,14 @@ const ProviderSelector: React.FC = ({ providers, onChange if (typeof provider === 'string') { return provider; } - return provider.id || 'Unknown provider'; + return provider.model || 'Unknown provider'; }; const getProviderKey = (provider: string | ProviderOptions, index: number) => { if (typeof provider === 'string') { return provider; } - return provider.id || index; + return provider.model || index; }; const handleProviderClick = (provider: string | ProviderOptions) => { @@ -206,7 +206,7 @@ const ProviderSelector: React.FC = ({ providers, onChange const handleSave = (config: ProviderOptions['config']) => { if (selectedProvider) { const updatedProviders = providers.map((provider) => - provider.id === selectedProvider.id ? { ...provider, config } : provider, + provider.model === selectedProvider.model ? 
{ ...provider, config } : provider, ); onChange(updatedProviders); setSelectedProvider(null); @@ -220,9 +220,9 @@ const ProviderSelector: React.FC = ({ providers, onChange freeSolo options={defaultProviders} value={providers} - groupBy={(option) => getGroupName(option.id)} + groupBy={(option) => getGroupName(option.model)} onChange={(event, newValue: (string | ProviderOptions)[]) => { - onChange(newValue.map((value) => (typeof value === 'string' ? { id: value } : value))); + onChange(newValue.map((value) => (typeof value === 'string' ? { model: value } : value))); }} getOptionLabel={(option) => { if (!option) { @@ -234,10 +234,10 @@ const ProviderSelector: React.FC = ({ providers, onChange optionString = option; } if ( - (option as ProviderOptions).id && - typeof (option as ProviderOptions).id === 'string' + (option as ProviderOptions).model && + typeof (option as ProviderOptions).model === 'string' ) { - optionString = (option as ProviderOptions).id!; + optionString = (option as ProviderOptions).model!; } const splits = optionString.split(':'); if (splits.length > 1) { @@ -270,10 +270,10 @@ const ProviderSelector: React.FC = ({ providers, onChange /> )} /> - {selectedProvider && selectedProvider.id && ( + {selectedProvider && selectedProvider.model && ( setSelectedProvider(null)} onSave={handleSave}