Skip to content

Commit

Permalink
Generic OpenAI API provider that can be configured via command line
Browse files Browse the repository at this point in the history
Part of #111
  • Loading branch information
bauersimon authored and Munsio committed May 23, 2024
1 parent df7641e commit e0232cc
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 1 deletion.
36 changes: 35 additions & 1 deletion cmd/eval-dev-quality/cmd/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ import (
_ "github.com/symflower/eval-dev-quality/language/java" // Register language.
"github.com/symflower/eval-dev-quality/log"
"github.com/symflower/eval-dev-quality/model"
"github.com/symflower/eval-dev-quality/model/llm"
"github.com/symflower/eval-dev-quality/provider"
_ "github.com/symflower/eval-dev-quality/provider/ollama" // Register provider.
_ "github.com/symflower/eval-dev-quality/provider/ollama" // Register provider.
openaiapi "github.com/symflower/eval-dev-quality/provider/openai-api"
_ "github.com/symflower/eval-dev-quality/provider/openrouter" // Register provider.
_ "github.com/symflower/eval-dev-quality/provider/symflower" // Register provider.
"github.com/symflower/eval-dev-quality/tools"
Expand All @@ -45,6 +47,8 @@ type Evaluate struct {
Models []string `long:"model" description:"Evaluate with this model. By default all models are used."`
// ProviderTokens holds all API tokens for the providers.
ProviderTokens map[string]string `long:"tokens" description:"API tokens for model providers (of the form '$provider:$token'). When using the environment variable, separate multiple definitions with ','." env:"PROVIDER_TOKEN" env-delim:","`
// ProviderUrls holds all custom inference endpoint urls for the providers.
ProviderUrls map[string]string `long:"urls" description:"Custom OpenAI API compatible inference endpoints (of the form '$provider:$url,...'). Use '$provider=custom-$name' to manually register a custom OpenAI API endpoint provider. Note that the models of a custom OpenAI API endpoint provider must be declared explicitly using the '--model' option. When using the environment variable, separate multiple definitions with ','." env:"PROVIDER_URL" env-delim:","`
// QueryAttempts holds the number of query attempts to perform when a model request errors in the process of solving a task.
QueryAttempts uint `long:"attempts" description:"Number of query attempts to perform when a model request errors in the process of solving a task." default:"3"`

Expand Down Expand Up @@ -112,6 +116,36 @@ func (command *Evaluate) Execute(args []string) (err error) {
}
}

// Register custom OpenAI API providers and models.
{
	// First pass: instantiate and register a provider for every "custom-"
	// URL definition so that models can be attached to them afterwards.
	customProviders := map[string]*openaiapi.Provider{}
	for providerID, providerURL := range command.ProviderUrls {
		// Only "custom-" entries declare new providers; other entries of
		// "--urls" are ignored here.
		if !strings.HasPrefix(providerID, "custom-") {
			continue
		}

		p := openaiapi.NewProvider(providerID, providerURL)
		provider.Register(p)
		customProviders[providerID] = p
	}
	// Second pass: attach every explicitly selected "custom-" model to its
	// provider, since models of a custom endpoint must be declared manually
	// through the "--model" option instead of being queried automatically.
	for _, model := range command.Models {
		if !strings.HasPrefix(model, "custom-") {
			continue
		}

		// Model identifiers have the form "$provider/$model".
		providerID, _, ok := strings.Cut(model, provider.ProviderModelSeparator)
		if !ok {
			log.Panicf("ERROR: cannot split %q into provider and model name by %q", model, provider.ProviderModelSeparator)
		}
		modelProvider, ok := customProviders[providerID]
		if !ok {
			log.Panicf("ERROR: unknown custom provider %q for model %q", providerID, model)
		}

		modelProvider.AddModel(llm.NewModel(modelProvider, model))
	}
}

// Gather languages.
languagesSelected := map[string]language.Language{}
{
Expand Down
49 changes: 49 additions & 0 deletions cmd/eval-dev-quality/cmd/evaluate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cmd

import (
"fmt"
"net/url"
"os"
"path/filepath"
"regexp"
Expand Down Expand Up @@ -526,6 +527,54 @@ func TestEvaluateExecute(t *testing.T) {
})
}
})
t.Run("OpenAI API", func(t *testing.T) {
	if !osutil.IsLinux() {
		// NOTE(review): "t.Skip" would suffice here since there are no format arguments.
		t.Skipf("Installation of Ollama is not supported on this OS")
	}

	{
		// Ensure the Ollama service started in "Before" is terminated even
		// when the validation fails.
		var shutdown func() (err error)
		defer func() {
			if shutdown != nil {
				require.NoError(t, shutdown())
			}
		}()
		// Ollama exposes an OpenAI-API-compatible endpoint under "/v1".
		ollamaOpenAIAPIUrl, err := url.JoinPath(tools.OllamaURL, "v1")
		require.NoError(t, err)
		validate(t, &testCase{
			Name: "Ollama",

			Before: func(t *testing.T, logger *log.Logger, resultPath string) {
				var err error
				shutdown, err = tools.OllamaStart(logger, tools.OllamaPath, tools.OllamaURL)
				require.NoError(t, err)

				// Use a small model to keep the test lightweight.
				require.NoError(t, tools.OllamaPull(logger, tools.OllamaPath, tools.OllamaURL, "qwen:0.5b"))
			},

			// Register Ollama's OpenAI-compatible endpoint as a custom
			// provider and declare its model explicitly.
			Arguments: []string{
				"--language", "golang",
				"--urls", "custom-ollama:" + ollamaOpenAIAPIUrl,
				"--model", "custom-ollama/qwen:0.5b",
				"--repository", filepath.Join("golang", "plain"),
			},

			ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){
				"categories.svg": nil,
				"evaluation.csv": nil,
				"evaluation.log": func(t *testing.T, filePath, data string) {
					// Since the model is non-deterministic, we can only assert that the model did at least not error.
					assert.Contains(t, data, `Evaluation score for "custom-ollama/qwen:0.5b"`)
					assert.Contains(t, data, "response-no-error=1")
				},
				"golang-summed.csv": nil,
				"models-summed.csv": nil,
				"README.md": nil,
				"custom-ollama_qwen:0.5b/golang/golang/plain.log": nil,
			},
		})
	}
})
})

t.Run("Runs", func(t *testing.T) {
Expand Down
96 changes: 96 additions & 0 deletions provider/openai-api/openai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package openaiapi

import (
"context"
"fmt"
"strings"

pkgerrors "github.com/pkg/errors"
"github.com/sashabaranov/go-openai"

"github.com/symflower/eval-dev-quality/log"
"github.com/symflower/eval-dev-quality/model"
"github.com/symflower/eval-dev-quality/provider"
)

// Provider holds a generic "OpenAI API" provider.
type Provider struct {
	// baseURL is the OpenAI-API-compatible inference endpoint queried by this provider.
	baseURL string
	// token is an optional API access token used to authenticate requests.
	token string
	// id is the unique identifier under which this provider is registered.
	id string
	// models holds the models that were explicitly registered via "AddModel".
	models []model.Model
}

// NewProvider returns a generic "OpenAI API" provider that serves requests
// from the given OpenAI-API-compatible endpoint under the given provider ID.
func NewProvider(id string, baseURL string) (p *Provider) {
	p = &Provider{
		id:      id,
		baseURL: baseURL,
	}

	return p
}

var _ provider.Provider = (*Provider)(nil)

// Available checks if the provider is ready to be used.
// This might include checking for an installation or making sure an API access token is valid.
// A custom endpoint might not require authentication or installation at all,
// so there is nothing that can be verified upfront here.
func (p *Provider) Available(logger *log.Logger) (err error) {
	return nil // We cannot know if a custom provider requires an API.
}

// ID returns the unique ID of this provider.
func (p *Provider) ID() (id string) {
	id = p.id

	return id
}

// Models returns which models are available to be queried via this provider.
// Only models that were registered through "AddModel" are reported.
func (p *Provider) Models() (models []model.Model, err error) {
	models = p.models

	return models, nil
}

// AddModel manually adds a model to the provider.
func (p *Provider) AddModel(m model.Model) {
	extended := append(p.models, m)
	p.models = extended
}

var _ provider.InjectToken = (*Provider)(nil)

// SetToken sets a potential token to be used in case the provider needs to authenticate a remote API.
// The token is passed on to every request made through "Query".
func (p *Provider) SetToken(token string) {
	p.token = token
}

var _ provider.Query = (*Provider)(nil)

// Query queries the provider with the given model name.
// The model identifier's provider prefix is stripped before it is sent to the
// remote API, and the first choice of the chat completion is returned.
func (p *Provider) Query(ctx context.Context, modelIdentifier string, promptText string) (response string, err error) {
	// The provider prefix is only meaningful for routing within the
	// evaluation, so it must not reach the actual API.
	modelName := strings.TrimPrefix(modelIdentifier, p.ID()+provider.ProviderModelSeparator)

	request := openai.ChatCompletionRequest{
		Model: modelName,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleUser,
				Content: promptText,
			},
		},
	}

	result, err := p.client().CreateChatCompletion(ctx, request)
	if err != nil {
		return "", pkgerrors.WithStack(err)
	}
	if len(result.Choices) == 0 {
		return "", pkgerrors.WithStack(fmt.Errorf("empty LLM %q response: %+v", modelName, result))
	}

	return result.Choices[0].Message.Content, nil
}

// client returns a new client with the current configuration, pointing at the
// provider's custom base URL and carrying its access token.
func (p *Provider) client() (client *openai.Client) {
	configuration := openai.DefaultConfig(p.token)
	configuration.BaseURL = p.baseURL
	client = openai.NewClientWithConfig(configuration)

	return client
}

0 comments on commit e0232cc

Please sign in to comment.