diff --git a/cmd/eval-dev-quality/cmd/evaluate.go b/cmd/eval-dev-quality/cmd/evaluate.go
index d3b657bf..95a52cd0 100644
--- a/cmd/eval-dev-quality/cmd/evaluate.go
+++ b/cmd/eval-dev-quality/cmd/evaluate.go
@@ -20,8 +20,10 @@ import (
 	_ "github.com/symflower/eval-dev-quality/language/java" // Register language.
 	"github.com/symflower/eval-dev-quality/log"
 	"github.com/symflower/eval-dev-quality/model"
+	"github.com/symflower/eval-dev-quality/model/llm"
 	"github.com/symflower/eval-dev-quality/provider"
-	_ "github.com/symflower/eval-dev-quality/provider/ollama" // Register provider.
+	_ "github.com/symflower/eval-dev-quality/provider/ollama" // Register provider.
+	openaiapi "github.com/symflower/eval-dev-quality/provider/openai-api"
 	_ "github.com/symflower/eval-dev-quality/provider/openrouter" // Register provider.
 	_ "github.com/symflower/eval-dev-quality/provider/symflower" // Register provider.
 	"github.com/symflower/eval-dev-quality/tools"
@@ -51,6 +53,8 @@ type Evaluate struct {
 
 	// ProviderTokens holds all API tokens for the providers.
 	ProviderTokens map[string]string `long:"tokens" description:"API tokens for model providers (of the form '$provider:$token'). When using the environment variable, separate multiple definitions with ','." env:"PROVIDER_TOKEN" env-delim:","`
+	// ProviderUrls holds all custom inference endpoint URLs for the providers.
+	ProviderUrls map[string]string `long:"urls" description:"Custom OpenAI API compatible inference endpoints (of the form '$provider:$url,...'). Use '$provider=custom-$name' to manually register a custom OpenAI API endpoint provider. Note that the models of a custom OpenAI API endpoint provider must be declared explicitly using the '--model' option. When using the environment variable, separate multiple definitions with ','." env:"PROVIDER_URL" env-delim:","`
 
 	// logger holds the logger of the command.
 	logger *log.Logger
@@ -99,6 +103,36 @@ func (command *Evaluate) Execute(args []string) (err error) {
 		}
 	}
 
+	// Register custom OpenAI API providers and models.
+	{
+		customProviders := map[string]*openaiapi.Provider{}
+		for providerID, providerURL := range command.ProviderUrls {
+			if !strings.HasPrefix(providerID, "custom-") {
+				continue
+			}
+
+			p := openaiapi.NewProvider(providerID, providerURL)
+			provider.Register(p)
+			customProviders[providerID] = p
+		}
+		for _, model := range command.Models {
+			if !strings.HasPrefix(model, "custom-") {
+				continue
+			}
+
+			providerID, _, ok := strings.Cut(model, provider.ProviderModelSeparator)
+			if !ok {
+				log.Panicf("ERROR: cannot split %q into provider and model name by %q", model, provider.ProviderModelSeparator)
+			}
+			modelProvider, ok := customProviders[providerID]
+			if !ok {
+				log.Panicf("ERROR: unknown custom provider %q for model %q", providerID, model)
+			}
+
+			modelProvider.AddModel(llm.NewModel(modelProvider, model))
+		}
+	}
+
 	// Gather languages.
 	languagesSelected := map[string]language.Language{}
 	{
diff --git a/cmd/eval-dev-quality/cmd/evaluate_test.go b/cmd/eval-dev-quality/cmd/evaluate_test.go
index 737891f9..3e54598f 100644
--- a/cmd/eval-dev-quality/cmd/evaluate_test.go
+++ b/cmd/eval-dev-quality/cmd/evaluate_test.go
@@ -3,6 +3,7 @@ package cmd
 import (
 	"errors"
 	"fmt"
+	"net/url"
 	"os"
 	"path/filepath"
 	"regexp"
@@ -541,6 +542,54 @@ func TestEvaluateExecute(t *testing.T) {
 				})
 			}
 		})
+		t.Run("OpenAI API", func(t *testing.T) {
+			if !osutil.IsLinux() {
+				t.Skipf("Installation of Ollama is not supported on this OS")
+			}
+
+			{
+				var shutdown func() (err error)
+				defer func() {
+					if shutdown != nil {
+						require.NoError(t, shutdown())
+					}
+				}()
+				ollamaOpenAIAPIUrl, err := url.JoinPath(tools.OllamaURL, "v1")
+				require.NoError(t, err)
+				validate(t, &testCase{
+					Name: "Ollama",
+
+					Before: func(t *testing.T, logger *log.Logger, resultPath string) {
+						var err error
+						shutdown, err = tools.OllamaStart(logger, tools.OllamaPath, tools.OllamaURL)
+						require.NoError(t, err)
+
+						require.NoError(t, tools.OllamaPull(logger, tools.OllamaPath, tools.OllamaURL, "qwen:0.5b"))
+					},
+
+					Arguments: []string{
+						"--language", "golang",
+						"--urls", "custom-ollama:" + ollamaOpenAIAPIUrl,
+						"--model", "custom-ollama/qwen:0.5b",
+						"--repository", filepath.Join("golang", "plain"),
+					},
+
+					ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){
+						"categories.svg": nil,
+						"evaluation.csv": nil,
+						"evaluation.log": func(t *testing.T, filePath, data string) {
+							// Since the model is non-deterministic, we can only assert that the model at least did not error.
+							assert.Contains(t, data, `Evaluation score for "custom-ollama/qwen:0.5b"`)
+							assert.Contains(t, data, "response-no-error=1")
+						},
+						"golang-summed.csv": nil,
+						"models-summed.csv": nil,
+						"README.md":         nil,
+						"custom-ollama_qwen:0.5b/golang/golang/plain.log": nil,
+					},
+				})
+			}
+		})
 	})
 
 	// This case checks a beautiful bug where the Markdown export crashed when the current working directory contained a README.md file. While this is not the case during the tests (as the current work directory is the directory of this file), it certainly caused problems when our binary was executed from the repository root (which of course contained a README.md). Therefore, we sadly have to modify the current work directory right within the tests of this case to reproduce the problem and fix it forever.
diff --git a/provider/openai-api/openai.go b/provider/openai-api/openai.go
new file mode 100644
index 00000000..ac4317a6
--- /dev/null
+++ b/provider/openai-api/openai.go
@@ -0,0 +1,96 @@
+package openaiapi
+
+import (
+	"context"
+	"fmt"
+	"strings"
+
+	pkgerrors "github.com/pkg/errors"
+	"github.com/sashabaranov/go-openai"
+
+	"github.com/symflower/eval-dev-quality/log"
+	"github.com/symflower/eval-dev-quality/model"
+	"github.com/symflower/eval-dev-quality/provider"
+)
+
+// Provider holds a generic "OpenAI API" provider.
+type Provider struct {
+	baseURL string
+	token   string
+	id      string
+	models  []model.Model
+}
+
+// NewProvider returns a generic "OpenAI API" provider.
+func NewProvider(id string, baseURL string) (provider *Provider) {
+	return &Provider{
+		baseURL: baseURL,
+		id:      id,
+	}
+}
+
+var _ provider.Provider = (*Provider)(nil)
+
+// Available checks if the provider is ready to be used.
+// This might include checking for an installation or making sure an API access token is valid.
+func (p *Provider) Available(logger *log.Logger) (err error) {
+	return nil // We cannot know whether a custom provider requires an API access token.
+}
+
+// ID returns the unique ID of this provider.
+func (p *Provider) ID() (id string) {
+	return p.id
+}
+
+// Models returns which models are available to be queried via this provider.
+func (p *Provider) Models() (models []model.Model, err error) {
+	return p.models, nil
+}
+
+// AddModel manually adds a model to the provider.
+func (p *Provider) AddModel(m model.Model) {
+	p.models = append(p.models, m)
+}
+
+var _ provider.InjectToken = (*Provider)(nil)
+
+// SetToken sets a potential token to be used in case the provider needs to authenticate a remote API.
+func (p *Provider) SetToken(token string) {
+	p.token = token
+}
+
+var _ provider.Query = (*Provider)(nil)
+
+// Query queries the provider with the given model name.
+func (p *Provider) Query(ctx context.Context, modelIdentifier string, promptText string) (response string, err error) {
+	client := p.client()
+	modelIdentifier = strings.TrimPrefix(modelIdentifier, p.ID()+provider.ProviderModelSeparator)
+
+	apiResponse, err := client.CreateChatCompletion(
+		ctx,
+		openai.ChatCompletionRequest{
+			Model: modelIdentifier,
+			Messages: []openai.ChatCompletionMessage{
+				{
+					Role:    openai.ChatMessageRoleUser,
+					Content: promptText,
+				},
+			},
+		},
+	)
+	if err != nil {
+		return "", pkgerrors.WithStack(err)
+	} else if len(apiResponse.Choices) == 0 {
+		return "", pkgerrors.WithStack(fmt.Errorf("empty LLM %q response: %+v", modelIdentifier, apiResponse))
+	}
+
+	return apiResponse.Choices[0].Message.Content, nil
+}
+
+// client returns a new client with the current configuration.
+func (p *Provider) client() (client *openai.Client) {
+	config := openai.DefaultConfig(p.token)
+	config.BaseURL = p.baseURL
+
+	return openai.NewClientWithConfig(config)
+}
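Usage note: with this change, any OpenAI API compatible endpoint can be evaluated end to end by registering it under a `custom-` prefixed provider ID via `--urls` and declaring its models explicitly via `--model`. Mirroring the new test case, a hypothetical invocation would be `eval-dev-quality evaluate --language golang --urls custom-ollama:http://127.0.0.1:11434/v1 --model custom-ollama/qwen:0.5b --repository golang/plain` (the binary and subcommand names are inferred from the repository layout, and the URL assumes Ollama's default listen address; neither appears verbatim in this diff).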
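For reviewers who want to exercise the new `openaiapi` package directly, here is a minimal sketch, assuming a locally running Ollama instance that already serves `qwen:0.5b` and exposes its OpenAI compatibility layer under `/v1` (the base URL and model name are assumptions mirroring the test above):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/symflower/eval-dev-quality/model/llm"
	openaiapi "github.com/symflower/eval-dev-quality/provider/openai-api"
)

func main() {
	// Assumption: Ollama's default listen address, with its OpenAI compatibility path.
	p := openaiapi.NewProvider("custom-ollama", "http://127.0.0.1:11434/v1")

	// Models of a custom provider must be added explicitly, mirroring the
	// "--model" requirement of the "evaluate" command.
	p.AddModel(llm.NewModel(p, "custom-ollama/qwen:0.5b"))

	// "Query" strips the "custom-ollama/" prefix (provider ID plus
	// "provider.ProviderModelSeparator"), so the endpoint only ever sees the
	// plain model name "qwen:0.5b".
	response, err := p.Query(context.Background(), "custom-ollama/qwen:0.5b", "Write a \"Hello World!\" application in Go.")
	if err != nil {
		log.Fatalf("query failed: %+v", err)
	}
	fmt.Println(response)
}
```

Since `Available` unconditionally returns `nil` for this provider, a misconfigured base URL only surfaces at query time, and `SetToken` is needed only when the endpoint enforces authentication.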