From 25f5241bbac5a3563d77fecf4f9aaf7b8e83081b Mon Sep 17 00:00:00 2001 From: Simon Bauer Date: Tue, 14 May 2024 15:16:47 +0200 Subject: [PATCH 1/4] Custom delimiter for multiple provider tokens via environment because there is no default --- cmd/eval-dev-quality/cmd/evaluate.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cmd/eval-dev-quality/cmd/evaluate.go b/cmd/eval-dev-quality/cmd/evaluate.go index d0205dbb..17b29d7e 100644 --- a/cmd/eval-dev-quality/cmd/evaluate.go +++ b/cmd/eval-dev-quality/cmd/evaluate.go @@ -40,11 +40,10 @@ type Evaluate struct { // Languages determines which language should be used for the evaluation, or empty if all languages should be used. Languages []string `long:"language" description:"Evaluate with this language. By default all languages are used."` - // Models determines which models should be used for the evaluation, or empty if all models should be used. Models []string `long:"model" description:"Evaluate with this model. By default all models are used."` // ProviderTokens holds all API tokens for the providers. - ProviderTokens map[string]string `long:"tokens" description:"API tokens for model providers (of the form '$provider:$token,...')." env:"PROVIDER_TOKEN"` + ProviderTokens map[string]string `long:"tokens" description:"API tokens for model providers (of the form '$provider:$token'). When using the environment variable, separate multiple definitions with ','." env:"PROVIDER_TOKEN" env-delim:","` // QueryAttempts holds the number of query attempts to perform when a model request errors in the process of solving a task. QueryAttempts uint `long:"attempts" description:"Number of query attempts to perform when a model request errors in the process of solving a task." 
default:"3"` From d0bc31101076dd6c7dd50c37d3057c5dda8a8dbb Mon Sep 17 00:00:00 2001 From: Simon Bauer Date: Tue, 14 May 2024 15:17:42 +0200 Subject: [PATCH 2/4] Generic OpenAI API provider that can be configured via command line Part of #111 --- cmd/eval-dev-quality/cmd/evaluate.go | 36 ++++++++- cmd/eval-dev-quality/cmd/evaluate_test.go | 49 ++++++++++++ provider/openai-api/openai.go | 96 +++++++++++++++++++++++ 3 files changed, 180 insertions(+), 1 deletion(-) create mode 100644 provider/openai-api/openai.go diff --git a/cmd/eval-dev-quality/cmd/evaluate.go b/cmd/eval-dev-quality/cmd/evaluate.go index 17b29d7e..9d0a1179 100644 --- a/cmd/eval-dev-quality/cmd/evaluate.go +++ b/cmd/eval-dev-quality/cmd/evaluate.go @@ -20,8 +20,10 @@ import ( _ "github.com/symflower/eval-dev-quality/language/java" // Register language. "github.com/symflower/eval-dev-quality/log" "github.com/symflower/eval-dev-quality/model" + "github.com/symflower/eval-dev-quality/model/llm" "github.com/symflower/eval-dev-quality/provider" - _ "github.com/symflower/eval-dev-quality/provider/ollama" // Register provider. + _ "github.com/symflower/eval-dev-quality/provider/ollama" // Register provider. + openaiapi "github.com/symflower/eval-dev-quality/provider/openai-api" _ "github.com/symflower/eval-dev-quality/provider/openrouter" // Register provider. _ "github.com/symflower/eval-dev-quality/provider/symflower" // Register provider. "github.com/symflower/eval-dev-quality/tools" @@ -44,6 +46,8 @@ type Evaluate struct { Models []string `long:"model" description:"Evaluate with this model. By default all models are used."` // ProviderTokens holds all API tokens for the providers. ProviderTokens map[string]string `long:"tokens" description:"API tokens for model providers (of the form '$provider:$token'). When using the environment variable, separate multiple definitions with ','." env:"PROVIDER_TOKEN" env-delim:","` + // ProviderUrls holds all custom inference endpoint urls for the providers. 
+ ProviderUrls map[string]string `long:"urls" description:"Custom OpenAI API compatible inference endpoints (of the form '$provider:$url,...'). Use '$provider=custom-$name' to manually register a custom OpenAI API endpoint provider. Note that the models of a custom OpenAI API endpoint provider must be declared explicitly using the '--model' option. When using the environment variable, separate multiple definitions with ','." env:"PROVIDER_URL" env-delim:","` // QueryAttempts holds the number of query attempts to perform when a model request errors in the process of solving a task. QueryAttempts uint `long:"attempts" description:"Number of query attempts to perform when a model request errors in the process of solving a task." default:"3"` @@ -111,6 +115,36 @@ func (command *Evaluate) Execute(args []string) (err error) { } } + // Register custom OpenAI API providers and models. + { + customProviders := map[string]*openaiapi.Provider{} + for providerID, providerURL := range command.ProviderUrls { + if !strings.HasPrefix(providerID, "custom-") { + continue + } + + p := openaiapi.NewProvider(providerID, providerURL) + provider.Register(p) + customProviders[providerID] = p + } + for _, model := range command.Models { + if !strings.HasPrefix(model, "custom-") { + continue + } + + providerID, _, ok := strings.Cut(model, provider.ProviderModelSeparator) + if !ok { + log.Panicf("ERROR: cannot split %q into provider and model name by %q", model, provider.ProviderModelSeparator) + } + modelProvider, ok := customProviders[providerID] + if !ok { + log.Panicf("ERROR: unknown custom provider %q for model %q", providerID, model) + } + + modelProvider.AddModel(llm.NewModel(modelProvider, model)) + } + } + // Gather languages. 
languagesSelected := map[string]language.Language{} { diff --git a/cmd/eval-dev-quality/cmd/evaluate_test.go b/cmd/eval-dev-quality/cmd/evaluate_test.go index 2060aab7..5ba50798 100644 --- a/cmd/eval-dev-quality/cmd/evaluate_test.go +++ b/cmd/eval-dev-quality/cmd/evaluate_test.go @@ -2,6 +2,7 @@ package cmd import ( "fmt" + "net/url" "os" "path/filepath" "regexp" @@ -528,6 +529,54 @@ func TestEvaluateExecute(t *testing.T) { }) } }) + t.Run("OpenAI API", func(t *testing.T) { + if !osutil.IsLinux() { + t.Skipf("Installation of Ollama is not supported on this OS") + } + + { + var shutdown func() (err error) + defer func() { + if shutdown != nil { + require.NoError(t, shutdown()) + } + }() + ollamaOpenAIAPIUrl, err := url.JoinPath(tools.OllamaURL, "v1") + require.NoError(t, err) + validate(t, &testCase{ + Name: "Ollama", + + Before: func(t *testing.T, logger *log.Logger, resultPath string) { + var err error + shutdown, err = tools.OllamaStart(logger, tools.OllamaPath, tools.OllamaURL) + require.NoError(t, err) + + require.NoError(t, tools.OllamaPull(logger, tools.OllamaPath, tools.OllamaURL, "qwen:0.5b")) + }, + + Arguments: []string{ + "--language", "golang", + "--urls", "custom-ollama:" + ollamaOpenAIAPIUrl, + "--model", "custom-ollama/qwen:0.5b", + "--repository", filepath.Join("golang", "plain"), + }, + + ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){ + "categories.svg": nil, + "evaluation.csv": nil, + "evaluation.log": func(t *testing.T, filePath, data string) { + // Since the model is non-deterministic, we can only assert that the model did at least not error. 
+ assert.Contains(t, data, `Evaluation score for "custom-ollama/qwen:0.5b"`) + assert.Contains(t, data, "response-no-error=1") + }, + "golang-summed.csv": nil, + "models-summed.csv": nil, + "README.md": nil, + "custom-ollama_qwen:0.5b/golang/golang/plain.log": nil, + }, + }) + } + }) }) t.Run("Runs", func(t *testing.T) { diff --git a/provider/openai-api/openai.go b/provider/openai-api/openai.go new file mode 100644 index 00000000..ac4317a6 --- /dev/null +++ b/provider/openai-api/openai.go @@ -0,0 +1,96 @@ +package openaiapi + +import ( + "context" + "fmt" + "strings" + + pkgerrors "github.com/pkg/errors" + "github.com/sashabaranov/go-openai" + + "github.com/symflower/eval-dev-quality/log" + "github.com/symflower/eval-dev-quality/model" + "github.com/symflower/eval-dev-quality/provider" +) + +// Provider holds a generic "OpenAI API" provider. +type Provider struct { + baseURL string + token string + id string + models []model.Model +} + +// NewProvider returns a generic "OpenAI API" provider. +func NewProvider(id string, baseURL string) (provider *Provider) { + return &Provider{ + baseURL: baseURL, + id: id, + } +} + +var _ provider.Provider = (*Provider)(nil) + +// Available checks if the provider is ready to be used. +// This might include checking for an installation or making sure an API access token is valid. +func (p *Provider) Available(logger *log.Logger) (err error) { + return nil // We cannot know if a custom provider requires an API. +} + +// ID returns the unique ID of this provider. +func (p *Provider) ID() (id string) { + return p.id +} + +// Models returns which models are available to be queried via this provider. +func (p *Provider) Models() (models []model.Model, err error) { + return p.models, nil +} + +// AddModel manually adds a model to the provider. 
+func (p *Provider) AddModel(m model.Model) { + p.models = append(p.models, m) +} + +var _ provider.InjectToken = (*Provider)(nil) + +// SetToken sets a potential token to be used in case the provider needs to authenticate a remote API. +func (p *Provider) SetToken(token string) { + p.token = token +} + +var _ provider.Query = (*Provider)(nil) + +// Query queries the provider with the given model name. +func (p *Provider) Query(ctx context.Context, modelIdentifier string, promptText string) (response string, err error) { + client := p.client() + modelIdentifier = strings.TrimPrefix(modelIdentifier, p.ID()+provider.ProviderModelSeparator) + + apiResponse, err := client.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: modelIdentifier, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: promptText, + }, + }, + }, + ) + if err != nil { + return "", pkgerrors.WithStack(err) + } else if len(apiResponse.Choices) == 0 { + return "", pkgerrors.WithStack(fmt.Errorf("empty LLM %q response: %+v", modelIdentifier, apiResponse)) + } + + return apiResponse.Choices[0].Message.Content, nil +} + +// client returns a new client with the current configuration. 
+func (p *Provider) client() (client *openai.Client) { + config := openai.DefaultConfig(p.token) + config.BaseURL = p.baseURL + + return openai.NewClientWithConfig(config) +} From 3dff4c4adb9e63b3a4076539dad98faeb4e4bede Mon Sep 17 00:00:00 2001 From: Simon Bauer Date: Mon, 27 May 2024 11:07:10 +0200 Subject: [PATCH 3/4] refactor, Unify and reuse OpenAI API query logic between generic OpenAI API, Ollama and Openrouter Closes #111 --- provider/ollama/ollama.go | 22 ++------------------- provider/openai-api/openai.go | 22 +-------------------- provider/openai-api/query.go | 32 +++++++++++++++++++++++++++++++ provider/openrouter/openrouter.go | 22 ++------------------- 4 files changed, 37 insertions(+), 61 deletions(-) create mode 100644 provider/openai-api/query.go diff --git a/provider/ollama/ollama.go b/provider/ollama/ollama.go index 2d8118f9..103b6df9 100644 --- a/provider/ollama/ollama.go +++ b/provider/ollama/ollama.go @@ -2,7 +2,6 @@ package ollama import ( "context" - "fmt" "net/url" "strings" @@ -13,6 +12,7 @@ import ( "github.com/symflower/eval-dev-quality/model" "github.com/symflower/eval-dev-quality/model/llm" "github.com/symflower/eval-dev-quality/provider" + openaiapi "github.com/symflower/eval-dev-quality/provider/openai-api" "github.com/symflower/eval-dev-quality/tools" ) @@ -85,25 +85,7 @@ func (p *Provider) Query(ctx context.Context, modelIdentifier string, promptText client := p.client() modelIdentifier = strings.TrimPrefix(modelIdentifier, p.ID()+provider.ProviderModelSeparator) - apiResponse, err := client.CreateChatCompletion( - ctx, - openai.ChatCompletionRequest{ - Model: modelIdentifier, - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleUser, - Content: promptText, - }, - }, - }, - ) - if err != nil { - return "", pkgerrors.WithStack(err) - } else if len(apiResponse.Choices) == 0 { - return "", pkgerrors.WithStack(fmt.Errorf("empty LLM %q response: %+v", modelIdentifier, apiResponse)) - } - - return 
apiResponse.Choices[0].Message.Content, nil + return openaiapi.QueryOpenAIAPIModel(ctx, client, modelIdentifier, promptText) } // client returns a new client with the current configuration. diff --git a/provider/openai-api/openai.go b/provider/openai-api/openai.go index ac4317a6..18afcad6 100644 --- a/provider/openai-api/openai.go +++ b/provider/openai-api/openai.go @@ -2,10 +2,8 @@ package openaiapi import ( "context" - "fmt" "strings" - pkgerrors "github.com/pkg/errors" "github.com/sashabaranov/go-openai" "github.com/symflower/eval-dev-quality/log" @@ -66,25 +64,7 @@ func (p *Provider) Query(ctx context.Context, modelIdentifier string, promptText client := p.client() modelIdentifier = strings.TrimPrefix(modelIdentifier, p.ID()+provider.ProviderModelSeparator) - apiResponse, err := client.CreateChatCompletion( - ctx, - openai.ChatCompletionRequest{ - Model: modelIdentifier, - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleUser, - Content: promptText, - }, - }, - }, - ) - if err != nil { - return "", pkgerrors.WithStack(err) - } else if len(apiResponse.Choices) == 0 { - return "", pkgerrors.WithStack(fmt.Errorf("empty LLM %q response: %+v", modelIdentifier, apiResponse)) - } - - return apiResponse.Choices[0].Message.Content, nil + return QueryOpenAIAPIModel(ctx, client, modelIdentifier, promptText) } // client returns a new client with the current configuration. diff --git a/provider/openai-api/query.go b/provider/openai-api/query.go new file mode 100644 index 00000000..8411fea0 --- /dev/null +++ b/provider/openai-api/query.go @@ -0,0 +1,32 @@ +package openaiapi + +import ( + "context" + "fmt" + + pkgerrors "github.com/pkg/errors" + "github.com/sashabaranov/go-openai" +) + +// QueryOpenAIAPIModel queries an OpenAI API model. 
+func QueryOpenAIAPIModel(ctx context.Context, client *openai.Client, modelIdentifier string, promptText string) (response string, err error) { + apiResponse, err := client.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: modelIdentifier, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: promptText, + }, + }, + }, + ) + if err != nil { + return "", pkgerrors.WithStack(err) + } else if len(apiResponse.Choices) == 0 { + return "", pkgerrors.WithStack(fmt.Errorf("empty LLM %q response: %+v", modelIdentifier, apiResponse)) + } + + return apiResponse.Choices[0].Message.Content, nil +} diff --git a/provider/openrouter/openrouter.go b/provider/openrouter/openrouter.go index e0122ad0..f149115d 100644 --- a/provider/openrouter/openrouter.go +++ b/provider/openrouter/openrouter.go @@ -3,7 +3,6 @@ package openrouter import ( "context" "errors" - "fmt" "strings" pkgerrors "github.com/pkg/errors" @@ -13,6 +12,7 @@ import ( "github.com/symflower/eval-dev-quality/model" "github.com/symflower/eval-dev-quality/model/llm" "github.com/symflower/eval-dev-quality/provider" + openaiapi "github.com/symflower/eval-dev-quality/provider/openai-api" ) // Provider holds an "openrouter.ai" provider using its public REST API. 
@@ -79,25 +79,7 @@ func (p *Provider) Query(ctx context.Context, modelIdentifier string, promptText client := p.client() modelIdentifier = strings.TrimPrefix(modelIdentifier, p.ID()+provider.ProviderModelSeparator) - apiResponse, err := client.CreateChatCompletion( - ctx, - openai.ChatCompletionRequest{ - Model: modelIdentifier, - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleUser, - Content: promptText, - }, - }, - }, - ) - if err != nil { - return "", pkgerrors.WithStack(err) - } else if len(apiResponse.Choices) == 0 { - return "", pkgerrors.WithStack(fmt.Errorf("empty LLM %q response: %+v", modelIdentifier, apiResponse)) - } - - return apiResponse.Choices[0].Message.Content, nil + return openaiapi.QueryOpenAIAPIModel(ctx, client, modelIdentifier, promptText) } // client returns a new client with the current configuration. From cd93bb461bdf3c364925e415c6862c04696a520e Mon Sep 17 00:00:00 2001 From: Simon Bauer Date: Mon, 27 May 2024 15:36:59 +0200 Subject: [PATCH 4/4] refactor, Extract default Ollama test model to avoid thousand magic constants --- cmd/eval-dev-quality/cmd/evaluate_test.go | 17 +++++++++-------- provider/ollama/ollama_test.go | 5 +++-- provider/testing/helper.go | 3 +++ tools/ollama_test.go | 11 ++++++----- 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/cmd/eval-dev-quality/cmd/evaluate_test.go b/cmd/eval-dev-quality/cmd/evaluate_test.go index 5ba50798..dfd218fc 100644 --- a/cmd/eval-dev-quality/cmd/evaluate_test.go +++ b/cmd/eval-dev-quality/cmd/evaluate_test.go @@ -18,6 +18,7 @@ import ( "github.com/symflower/eval-dev-quality/evaluate/metrics" metricstesting "github.com/symflower/eval-dev-quality/evaluate/metrics/testing" "github.com/symflower/eval-dev-quality/log" + providertesting "github.com/symflower/eval-dev-quality/provider/testing" "github.com/symflower/eval-dev-quality/tools" ) @@ -502,12 +503,12 @@ func TestEvaluateExecute(t *testing.T) { shutdown, err = tools.OllamaStart(logger, 
tools.OllamaPath, tools.OllamaURL) require.NoError(t, err) - require.NoError(t, tools.OllamaPull(logger, tools.OllamaPath, tools.OllamaURL, "qwen:0.5b")) + require.NoError(t, tools.OllamaPull(logger, tools.OllamaPath, tools.OllamaURL, providertesting.OllamaTestModel)) }, Arguments: []string{ "--language", "golang", - "--model", "ollama/qwen:0.5b", + "--model", "ollama/" + providertesting.OllamaTestModel, "--repository", filepath.Join("golang", "plain"), }, @@ -516,7 +517,7 @@ func TestEvaluateExecute(t *testing.T) { "evaluation.csv": nil, "evaluation.log": func(t *testing.T, filePath, data string) { // Since the model is non-deterministic, we can only assert that the model did at least not error. - assert.Contains(t, data, `Evaluation score for "ollama/qwen:0.5b"`) + assert.Contains(t, data, fmt.Sprintf(`Evaluation score for "ollama/%s"`, providertesting.OllamaTestModel)) assert.Contains(t, data, "response-no-error=1") assert.Contains(t, data, "preloading model") assert.Contains(t, data, "unloading model") @@ -524,7 +525,7 @@ func TestEvaluateExecute(t *testing.T) { "golang-summed.csv": nil, "models-summed.csv": nil, "README.md": nil, - "ollama_qwen:0.5b/golang/golang/plain.log": nil, + "ollama_" + providertesting.OllamaTestModel + "/golang/golang/plain.log": nil, }, }) } @@ -551,13 +552,13 @@ func TestEvaluateExecute(t *testing.T) { shutdown, err = tools.OllamaStart(logger, tools.OllamaPath, tools.OllamaURL) require.NoError(t, err) - require.NoError(t, tools.OllamaPull(logger, tools.OllamaPath, tools.OllamaURL, "qwen:0.5b")) + require.NoError(t, tools.OllamaPull(logger, tools.OllamaPath, tools.OllamaURL, providertesting.OllamaTestModel)) }, Arguments: []string{ "--language", "golang", "--urls", "custom-ollama:" + ollamaOpenAIAPIUrl, - "--model", "custom-ollama/qwen:0.5b", + "--model", "custom-ollama/" + providertesting.OllamaTestModel, "--repository", filepath.Join("golang", "plain"), }, @@ -566,13 +567,13 @@ func TestEvaluateExecute(t *testing.T) { 
"evaluation.csv": nil, "evaluation.log": func(t *testing.T, filePath, data string) { // Since the model is non-deterministic, we can only assert that the model did at least not error. - assert.Contains(t, data, `Evaluation score for "custom-ollama/qwen:0.5b"`) + assert.Contains(t, data, fmt.Sprintf(`Evaluation score for "custom-ollama/%s"`, providertesting.OllamaTestModel)) assert.Contains(t, data, "response-no-error=1") }, "golang-summed.csv": nil, "models-summed.csv": nil, "README.md": nil, - "custom-ollama_qwen:0.5b/golang/golang/plain.log": nil, + "custom-ollama_" + providertesting.OllamaTestModel + "/golang/golang/plain.log": nil, }, }) } diff --git a/provider/ollama/ollama_test.go b/provider/ollama/ollama_test.go index b37a3ed7..c67a2a80 100644 --- a/provider/ollama/ollama_test.go +++ b/provider/ollama/ollama_test.go @@ -9,6 +9,7 @@ import ( "github.com/zimmski/osutil" "github.com/symflower/eval-dev-quality/log" + providertesting "github.com/symflower/eval-dev-quality/provider/testing" "github.com/symflower/eval-dev-quality/tools" ) @@ -143,11 +144,11 @@ func TestProviderModels(t *testing.T) { Name: "Local Model", LocalModels: []string{ - "qwen:0.5b", + providertesting.OllamaTestModel, }, ExpectedModels: []string{ - "ollama/qwen:0.5b", + "ollama/" + providertesting.OllamaTestModel, }, }) } diff --git a/provider/testing/helper.go b/provider/testing/helper.go index 734f4523..f165c95f 100644 --- a/provider/testing/helper.go +++ b/provider/testing/helper.go @@ -16,3 +16,6 @@ func NewMockProviderNamedWithModels(t *testing.T, id string, models []model.Mode return m } + +// OllamaTestModel holds the smallest Ollama model that we use for testing. 
+const OllamaTestModel = "qwen:0.5b" diff --git a/tools/ollama_test.go b/tools/ollama_test.go index 1ae2a45f..46295e9a 100644 --- a/tools/ollama_test.go +++ b/tools/ollama_test.go @@ -10,6 +10,7 @@ import ( "github.com/zimmski/osutil" "github.com/symflower/eval-dev-quality/log" + providertesting "github.com/symflower/eval-dev-quality/provider/testing" "github.com/symflower/eval-dev-quality/util" ) @@ -61,10 +62,10 @@ func TestOllamaLoading(t *testing.T) { defer func() { require.NoError(t, shutdown()) }() - require.NoError(t, OllamaPull(log, OllamaPath, url, "qwen:0.5b")) + require.NoError(t, OllamaPull(log, OllamaPath, url, providertesting.OllamaTestModel)) t.Run("Load Model", func(t *testing.T) { - assert.NoError(t, OllamaLoad(url, "qwen:0.5b")) + assert.NoError(t, OllamaLoad(url, providertesting.OllamaTestModel)) output, err := util.CommandWithResult(context.Background(), log, &util.Command{ Command: []string{ @@ -76,10 +77,10 @@ func TestOllamaLoading(t *testing.T) { }, }) assert.NoError(t, err) - assert.Contains(t, output, "qwen:0.5b") + assert.Contains(t, output, providertesting.OllamaTestModel) }) t.Run("unload Model", func(t *testing.T) { - assert.NoError(t, OllamaUnload(url, "qwen:0.5b")) + assert.NoError(t, OllamaUnload(url, providertesting.OllamaTestModel)) // Give it a few seconds for the unloading completes. time.Sleep(2 * time.Second) @@ -94,6 +95,6 @@ func TestOllamaLoading(t *testing.T) { }, }) assert.NoError(t, err) - assert.NotContains(t, output, "qwen:0.5b") + assert.NotContains(t, output, providertesting.OllamaTestModel) }) }